Commit de15e658 authored by Kenton Varda's avatar Kenton Varda

Don't trust trusted messages so much. Just trust that the pointers are valid…

Don't trust trusted messages so much.  Just trust that the pointers are valid and in-bounds, but don't assume the data matches the schema.  This makes it easier to safely create a 'trusted' message: just copy any message into a MessageBuilder, and if it ends up all in one segment, it's safe to use.
parent d7081cca
......@@ -209,6 +209,12 @@ struct WireHelpers {
return (bits + 7 * BITS) / BITS_PER_BYTE;
}
static CAPNPROTO_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) {
// If segment is null, this is a trusted message, so all pointers have been checked already.
return segment == nullptr || segment->containsInterval(start, end);
}
static CAPNPROTO_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, WordCount amount,
WirePointer::Kind kind)) {
......@@ -263,7 +269,8 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(
const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) {
if (ref->kind() == WirePointer::FAR) {
// If the segment is null, this is a trusted message, so there are no FAR pointers.
if (segment != nullptr && ref->kind() == WirePointer::FAR) {
// Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") {
......@@ -273,7 +280,7 @@ struct WireHelpers {
// Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + padWords),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") {
return nullptr;
}
......@@ -433,22 +440,15 @@ struct WireHelpers {
}
--nestingLimit;
const word* ptr;
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
const word* ptr = followFars(ref, segment);
WordCount64 result = 0 * WORDS;
switch (ref->kind()) {
case WirePointer::STRUCT: {
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
break;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
break;
}
result += ref->structRef.wordSize();
......@@ -473,11 +473,9 @@ struct WireHelpers {
WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") {
break;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") {
break;
}
result += totalWords;
break;
......@@ -485,11 +483,9 @@ struct WireHelpers {
case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
break;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
break;
}
result += count * WORDS_PER_POINTER;
......@@ -502,12 +498,10 @@ struct WireHelpers {
}
case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) {
VALIDATE_INPUT(
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
break;
}
VALIDATE_INPUT(
boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
break;
}
result += wordCount + POINTER_SIZE_IN_WORDS;
......@@ -515,16 +509,14 @@ struct WireHelpers {
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) {
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") {
break;
}
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") {
break;
}
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") {
break;
}
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") {
break;
}
WordCount dataSize = elementTag->structRef.dataSize.get();
......@@ -738,70 +730,66 @@ struct WireHelpers {
WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) {
if (ref->isNull()) {
useDefault:
word* ptr;
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
ptr = allocate(ref, segment, size.total(), WirePointer::STRUCT);
ref->structRef.set(size);
} else {
ptr = copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
return initStructPointer(ref, segment, size);
}
return StructBuilder(segment, ptr, reinterpret_cast<WirePointer*>(ptr + size.data),
size.data * BITS_PER_WORD, size.pointers, 0 * BITS);
} else {
WirePointer* oldRef = ref;
SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment);
copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
VALIDATE_INPUT(oldRef->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault;
}
WirePointer* oldRef = ref;
SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment);
WordCount oldDataSize = oldRef->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldRef->structRef.ptrCount.get();
WirePointer* oldPointerSection =
reinterpret_cast<WirePointer*>(oldPtr + oldDataSize);
VALIDATE_INPUT(oldRef->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault;
}
if (oldDataSize < size.data || oldPointerCount < size.pointers) {
// The space allocated for this struct is too small. Unlike with readers, we can't just
// run with it and do bounds checks at access time, because how would we handle writes?
// Instead, we have to copy the struct to a new space now.
WordCount oldDataSize = oldRef->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldRef->structRef.ptrCount.get();
WirePointer* oldPointerSection =
reinterpret_cast<WirePointer*>(oldPtr + oldDataSize);
WordCount newDataSize = std::max<WordCount>(oldDataSize, size.data);
WirePointerCount newPointerCount =
std::max<WirePointerCount>(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
if (oldDataSize < size.data || oldPointerCount < size.pointers) {
// The space allocated for this struct is too small. Unlike with readers, we can't just
// run with it and do bounds checks at access time, because how would we handle writes?
// Instead, we have to copy the struct to a new space now.
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(segment, ref);
WordCount newDataSize = std::max<WordCount>(oldDataSize, size.data);
WirePointerCount newPointerCount =
std::max<WirePointerCount>(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
ref->structRef.set(newDataSize, newPointerCount);
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(segment, ref);
// Copy data section.
memcpy(ptr, oldPtr, oldDataSize * BYTES_PER_WORD / BYTES);
word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
ref->structRef.set(newDataSize, newPointerCount);
// Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
// Zero out old location. This has two purposes:
// 1) We don't want to leak the original contents of the struct when the message is written
// out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire.
memset(oldPtr, 0,
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
// Copy data section.
memcpy(ptr, oldPtr, oldDataSize * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount, 0 * BITS);
} else {
return StructBuilder(oldSegment, oldPtr, oldPointerSection, oldDataSize * BITS_PER_WORD,
oldPointerCount, 0 * BITS);
// Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
// Zero out old location. This has two purposes:
// 1) We don't want to leak the original contents of the struct when the message is written
// out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire.
memset(oldPtr, 0,
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount, 0 * BITS);
} else {
return StructBuilder(oldSegment, oldPtr, oldPointerSection, oldDataSize * BITS_PER_WORD,
oldPointerCount, 0 * BITS);
}
}
......@@ -869,93 +857,86 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
reinterpret_cast<const WirePointer*>(defaultValue));
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
} else {
// The pointer is already initialized. We must verify that it has the right size. Unlike
// in getWritableStructListReference(), we never need to "upgrade" the data, because this
// method is called only for non-struct lists, and there is no allowed upgrade path *to*
// a non-struct list, only *from* them.
// We must verify that the pointer has the right size. Unlike in
// getWritableStructListReference(), we never need to "upgrade" the data, because this
// method is called only for non-struct lists, and there is no allowed upgrade path *to*
// a non-struct list, only *from* them.
WirePointer* ref = origRef;
SegmentBuilder* segment = origSegment;
word* ptr = followFars(ref, segment);
WirePointer* ref = origRef;
SegmentBuilder* segment = origSegment;
word* ptr = followFars(ref, segment);
VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault;
}
VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault;
}
FieldSize oldSize = ref->listRef.elementSize();
FieldSize oldSize = ref->listRef.elementSize();
if (oldSize == FieldSize::INLINE_COMPOSITE) {
// The existing element size is INLINE_COMPOSITE, which means that it is at least two
// words, which makes it bigger than the expected element size. Since fields can only
// grow when upgraded, the existing data must have been written with a newer version of
// the protocol. We therefore never need to upgrade the data in this case, but we do
// need to validate that it is a valid upgrade from what we expected.
if (oldSize == FieldSize::INLINE_COMPOSITE) {
// The existing element size is INLINE_COMPOSITE, which means that it is at least two
// words, which makes it bigger than the expected element size. Since fields can only
// grow when upgraded, the existing data must have been written with a newer version of
// the protocol. We therefore never need to upgrade the data in this case, but we do
// need to validate that it is a valid upgrade from what we expected.
// Read the tag to get the actual element count.
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
PRECOND(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.");
ptr += POINTER_SIZE_IN_WORDS;
// Read the tag to get the actual element count.
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
PRECOND(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.");
ptr += POINTER_SIZE_IN_WORDS;
WordCount dataSize = tag->structRef.dataSize.get();
WirePointerCount pointerCount = tag->structRef.ptrCount.get();
WordCount dataSize = tag->structRef.dataSize.get();
WirePointerCount pointerCount = tag->structRef.ptrCount.get();
switch (elementSize) {
case FieldSize::VOID:
// Anything is a valid upgrade from Void.
break;
switch (elementSize) {
case FieldSize::VOID:
// Anything is a valid upgrade from Void.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(dataSize >= 1 * WORDS,
"Existing list value is incompatible with expected type.");
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(dataSize >= 1 * WORDS,
"Existing list value is incompatible with expected type.");
break;
case FieldSize::POINTER:
VALIDATE_INPUT(pointerCount >= 1 * POINTERS,
"Existing list value is incompatible with expected type.");
// Adjust the pointer to point at the reference segment.
ptr += dataSize;
break;
case FieldSize::POINTER:
VALIDATE_INPUT(pointerCount >= 1 * POINTERS,
"Existing list value is incompatible with expected type.");
// Adjust the pointer to point at the reference segment.
ptr += dataSize;
break;
case FieldSize::INLINE_COMPOSITE:
FAIL_CHECK("Can't get here.");
break;
}
case FieldSize::INLINE_COMPOSITE:
FAIL_CHECK("Can't get here.");
break;
}
// OK, looks valid.
// OK, looks valid.
return ListBuilder(segment, ptr,
tag->structRef.wordSize() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
dataSize * BITS_PER_WORD, pointerCount);
} else {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS;
return ListBuilder(segment, ptr,
tag->structRef.wordSize() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
dataSize * BITS_PER_WORD, pointerCount);
} else {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS;
VALIDATE_INPUT(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.");
VALIDATE_INPUT(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.");
VALIDATE_INPUT(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.");
VALIDATE_INPUT(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.");
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(segment, ptr, step, ref->listRef.elementCount(),
dataSize, pointerCount);
}
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(segment, ptr, step, ref->listRef.elementCount(),
dataSize, pointerCount);
}
}
......@@ -968,268 +949,249 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
reinterpret_cast<const WirePointer*>(defaultValue));
copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
// Assume the default value is valid.
if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
return ListBuilder(origSegment, tag + 1, elementSize.total() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
elementSize.data * BITS_PER_WORD,
elementSize.pointers);
} else {
BitCount dataSize = dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WirePointerCount pointerCount =
pointersPerElement(elementSize.preferredListEncoding) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
// We must verify that the pointer has the right size and potentially upgrade it if not.
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
}
WirePointer* oldRef = origRef;
SegmentBuilder* oldSegment = origSegment;
word* oldPtr = followFars(oldRef, oldSegment);
} else {
// The pointer is already initialized. We must verify that it has the right size and
// potentially upgrade it if not.
VALIDATE_INPUT(oldRef->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault;
}
FieldSize oldSize = oldRef->listRef.elementSize();
WirePointer* oldRef = origRef;
SegmentBuilder* oldSegment = origSegment;
word* oldPtr = followFars(oldRef, oldSegment);
if (oldSize == FieldSize::INLINE_COMPOSITE) {
// Existing list is INLINE_COMPOSITE, but we need to verify that the sizes match.
VALIDATE_INPUT(oldRef->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") {
WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr);
oldPtr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(oldTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.") {
goto useDefault;
}
FieldSize oldSize = oldRef->listRef.elementSize();
WordCount oldDataSize = oldTag->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldTag->structRef.ptrCount.get();
auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldTag->inlineCompositeListElementCount();
if (oldSize == FieldSize::INLINE_COMPOSITE) {
// Existing list is INLINE_COMPOSITE, but we need to verify that the sizes match.
if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) {
// Old size is at least as large as we need. Ship it.
return ListBuilder(oldSegment, oldPtr, oldStep * BITS_PER_WORD, elementCount,
oldDataSize * BITS_PER_WORD, oldPointerCount);
}
WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr);
oldPtr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(oldTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.") {
goto useDefault;
}
// The structs in this list are smaller than expected, probably written using an older
// version of the protocol. We need to make a copy and expand them.
WordCount oldDataSize = oldTag->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldTag->structRef.ptrCount.get();
auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldTag->inlineCompositeListElementCount();
WordCount newDataSize = std::max<WordCount>(oldDataSize, elementSize.data);
WirePointerCount newPointerCount =
std::max<WirePointerCount>(oldPointerCount, elementSize.pointers);
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalSize = newStep * elementCount;
if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) {
// Old size is at least as large as we need. Ship it.
return ListBuilder(oldSegment, oldPtr, oldStep * BITS_PER_WORD, elementCount,
oldDataSize * BITS_PER_WORD, oldPointerCount);
}
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
// The structs in this list are smaller than expected, probably written using an older
// version of the protocol. We need to make a copy and expand them.
word* newPtr = allocate(origRef, origSegment, totalSize + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalSize);
WordCount newDataSize = std::max<WordCount>(oldDataSize, elementSize.data);
WirePointerCount newPointerCount =
std::max<WirePointerCount>(oldPointerCount, elementSize.pointers);
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalSize = newStep * elementCount;
WirePointer* newTag = reinterpret_cast<WirePointer*>(newPtr);
newTag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
newTag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS;
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalSize + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalSize);
WirePointer* newTag = reinterpret_cast<WirePointer*>(newPtr);
newTag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
newTag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS;
word* src = oldPtr;
word* dst = newPtr;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
// Copy data section.
memcpy(dst, src, oldDataSize * BYTES_PER_WORD / BYTES);
// Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize);
WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(origSegment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
word* src = oldPtr;
word* dst = newPtr;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
// Copy data section.
memcpy(dst, src, oldDataSize * BYTES_PER_WORD / BYTES);
dst += newStep * (1 * ELEMENTS);
src += oldStep * (1 * ELEMENTS);
// Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize);
WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(origSegment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, oldStep * elementCount * BYTES_PER_WORD / BYTES);
dst += newStep * (1 * ELEMENTS);
src += oldStep * (1 * ELEMENTS);
}
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
} else if (oldSize == elementSize.preferredListEncoding) {
// Old size matches exactly.
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, oldStep * elementCount * BYTES_PER_WORD / BYTES);
auto dataSize = dataBitsPerElement(oldSize);
auto pointerCount = pointersPerElement(oldSize);
auto step = dataSize + pointerCount * BITS_PER_POINTER;
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
} else if (oldSize == elementSize.preferredListEncoding) {
// Old size matches exactly.
return ListBuilder(oldSegment, oldPtr, step, oldRef->listRef.elementCount(),
dataSize * (1 * ELEMENTS), pointerCount * (1 * ELEMENTS));
} else {
switch (elementSize.preferredListEncoding) {
case FieldSize::VOID:
// No expectations.
break;
case FieldSize::POINTER:
VALIDATE_INPUT(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID,
"Struct list has incompatible element size.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE:
// Old size can be anything.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
// Preferred size is data-only.
VALIDATE_INPUT(oldSize != FieldSize::POINTER,
"Struct list has incompatible element size.") {
goto useDefault;
}
break;
}
auto dataSize = dataBitsPerElement(oldSize);
auto pointerCount = pointersPerElement(oldSize);
auto step = dataSize + pointerCount * BITS_PER_POINTER;
return ListBuilder(oldSegment, oldPtr, step, oldRef->listRef.elementCount(),
dataSize * (1 * ELEMENTS), pointerCount * (1 * ELEMENTS));
} else {
switch (elementSize.preferredListEncoding) {
case FieldSize::VOID:
// No expectations.
break;
case FieldSize::POINTER:
VALIDATE_INPUT(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID,
"Struct list has incompatible element size.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE:
// Old size can be anything.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
// Preferred size is data-only.
VALIDATE_INPUT(oldSize != FieldSize::POINTER,
"Struct list has incompatible element size.") {
goto useDefault;
}
break;
}
// OK, the old size is compatible with the preferred, but is not exactly the same. We may
// need to upgrade it.
// OK, the old size is compatible with the preferred, but is not exactly the same. We may
// need to upgrade it.
BitCount oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount oldPointerCount = pointersPerElement(oldSize) * ELEMENTS;
auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldRef->listRef.elementCount();
BitCount oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount oldPointerCount = pointersPerElement(oldSize) * ELEMENTS;
auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldRef->listRef.elementCount();
if (oldSize >= elementSize.preferredListEncoding) {
// The old size is at least as large as the preferred, so we don't need to upgrade.
return ListBuilder(oldSegment, oldPtr, oldStep, elementCount,
oldDataSize, oldPointerCount);
}
if (oldSize >= elementSize.preferredListEncoding) {
// The old size is at least as large as the preferred, so we don't need to upgrade.
return ListBuilder(oldSegment, oldPtr, oldStep, elementCount,
oldDataSize, oldPointerCount);
}
// Upgrade is necessary.
// Upgrade is necessary.
if (oldSize == FieldSize::VOID) {
// Nothing to copy, just allocate a new list.
return initStructListPointer(origRef, origSegment, elementCount, elementSize);
} else if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
// Upgrading to an inline composite list.
if (oldSize == FieldSize::VOID) {
// Nothing to copy, just allocate a new list.
return initStructListPointer(origRef, origSegment, elementCount, elementSize);
} else if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
// Upgrading to an inline composite list.
WordCount newDataSize = elementSize.data;
WirePointerCount newPointerCount = elementSize.pointers;
WordCount newDataSize = elementSize.data;
WirePointerCount newPointerCount = elementSize.pointers;
if (oldSize == FieldSize::POINTER) {
newPointerCount = std::max(newPointerCount, 1 * POINTERS);
} else {
// Old list contains data elements, so we need at least 1 word of data.
newDataSize = std::max(newDataSize, 1 * WORDS);
}
if (oldSize == FieldSize::POINTER) {
newPointerCount = std::max(newPointerCount, 1 * POINTERS);
} else {
// Old list contains data elements, so we need at least 1 word of data.
newDataSize = std::max(newDataSize, 1 * WORDS);
}
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalWords = elementCount * newStep;
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalWords = elementCount * newStep;
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalWords);
word* newPtr = allocate(origRef, origSegment, totalWords + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalWords);
WirePointer* tag = reinterpret_cast<WirePointer*>(newPtr);
tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
tag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS;
WirePointer* tag = reinterpret_cast<WirePointer*>(newPtr);
tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
tag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS;
if (oldSize == FieldSize::POINTER) {
WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize);
WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
transferPointer(origSegment, dst, oldSegment, src);
dst += newStep / WORDS_PER_POINTER * (1 * ELEMENTS);
++src;
}
} else if (oldSize == FieldSize::BIT) {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*reinterpret_cast<char*>(dst) = (src[i/8] >> (i%8)) & 1;
dst += newStep * (1 * ELEMENTS);
}
} else {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
ByteCount oldByteStep = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(dst, src, oldByteStep / BYTES);
src += oldByteStep / BYTES;
dst += newStep * (1 * ELEMENTS);
}
if (oldSize == FieldSize::POINTER) {
WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize);
WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
transferPointer(origSegment, dst, oldSegment, src);
dst += newStep / WORDS_PER_POINTER * (1 * ELEMENTS);
++src;
}
} else if (oldSize == FieldSize::BIT) {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*reinterpret_cast<char*>(dst) = (src[i/8] >> (i%8)) & 1;
dst += newStep * (1 * ELEMENTS);
}
} else {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
ByteCount oldByteStep = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(dst, src, oldByteStep / BYTES);
src += oldByteStep / BYTES;
dst += newStep * (1 * ELEMENTS);
}
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
} else {
// If oldSize were POINTER or EIGHT_BYTES then the preferred size must be
// INLINE_COMPOSITE because any other compatible size would not require an upgrade.
CHECK(oldSize < FieldSize::EIGHT_BYTES);
} else {
// If oldSize were POINTER or EIGHT_BYTES then the preferred size must be
// INLINE_COMPOSITE because any other compatible size would not require an upgrade.
CHECK(oldSize < FieldSize::EIGHT_BYTES);
// If the preferred size were BIT then oldSize must be VOID, but we handled that case
// above.
CHECK(elementSize.preferredListEncoding >= FieldSize::BIT);
// If the preferred size were BIT then oldSize must be VOID, but we handled that case
// above.
CHECK(elementSize.preferredListEncoding >= FieldSize::BIT);
// OK, so the expected list elements are all data and between 1 byte and 1 word each,
// and the old element are data between 1 bit and 4 bytes. We're upgrading from one
// primitive data type to another, larger one.
// OK, so the expected list elements are all data and between 1 byte and 1 word each,
// and the old element are data between 1 bit and 4 bytes. We're upgrading from one
// primitive data type to another, larger one.
BitCount newDataSize =
dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
BitCount newDataSize =
dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WordCount totalWords =
roundUpToWords(BitCount64(newDataSize) * (elementCount / ELEMENTS));
WordCount totalWords =
roundUpToWords(BitCount64(newDataSize) * (elementCount / ELEMENTS));
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords, WirePointer::LIST);
origRef->listRef.set(elementSize.preferredListEncoding, elementCount);
word* newPtr = allocate(origRef, origSegment, totalWords, WirePointer::LIST);
origRef->listRef.set(elementSize.preferredListEncoding, elementCount);
char* newBytePtr = reinterpret_cast<char*>(newPtr);
char* oldBytePtr = reinterpret_cast<char*>(oldPtr);
ByteCount newDataByteSize = newDataSize / BITS_PER_BYTE;
if (oldSize == FieldSize::BIT) {
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*newBytePtr = (oldBytePtr[i/8] >> (i%8)) & 1;
newBytePtr += newDataByteSize / BYTES;
}
} else {
ByteCount oldDataByteSize = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(newBytePtr, oldBytePtr, oldDataByteSize / BYTES);
oldBytePtr += oldDataByteSize / BYTES;
newBytePtr += newDataByteSize / BYTES;
}
char* newBytePtr = reinterpret_cast<char*>(newPtr);
char* oldBytePtr = reinterpret_cast<char*>(oldPtr);
ByteCount newDataByteSize = newDataSize / BITS_PER_BYTE;
if (oldSize == FieldSize::BIT) {
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*newBytePtr = (oldBytePtr[i/8] >> (i%8)) & 1;
newBytePtr += newDataByteSize / BYTES;
}
} else {
ByteCount oldDataByteSize = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(newBytePtr, oldBytePtr, oldDataByteSize / BYTES);
oldBytePtr += oldDataByteSize / BYTES;
newBytePtr += newDataByteSize / BYTES;
}
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newDataSize / ELEMENTS, elementCount,
newDataSize, 0 * POINTERS);
}
return ListBuilder(origSegment, newPtr, newDataSize / ELEMENTS, elementCount,
newDataSize, 0 * POINTERS);
}
}
}
......@@ -1459,42 +1421,36 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return StructReader(nullptr, nullptr, nullptr, 0 * BITS, 0 * POINTERS, 0 * BITS,
std::numeric_limits<int>::max());
return StructReader();
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(ref->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault;
}
const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
VALIDATE_INPUT(ref->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
return StructReader(
......@@ -1507,7 +1463,6 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
......@@ -1516,26 +1471,23 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported error.
goto useDefault;
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where list pointer was expected.") {
goto useDefault;
}
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported error.
goto useDefault;
}
VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where list pointer was expected.") {
goto useDefault;
}
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
......@@ -1548,75 +1500,63 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
VALIDATE_INPUT(size * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
// If a struct list was not expected, then presumably a non-struct list was upgraded to a
// struct list. We need to manipulate the pointer to point at the first field of the
// struct. Together with the "stepBits", this will allow the struct list to be accessed as
// if it were a primitive list without branching.
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
// Check whether the size is compatible.
switch (expectedElementSize) {
case FieldSize::VOID:
break;
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
case FieldSize::BIT:
FAIL_VALIDATE_INPUT("Expected a bit list, but got a list of structs.") {
goto useDefault;
}
break;
VALIDATE_INPUT(size * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault;
}
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault;
}
break;
// If a struct list was not expected, then presumably a non-struct list was upgraded to a
// struct list. We need to manipulate the pointer to point at the first field of the
// struct. Together with the "stepBits", this will allow the struct list to be accessed as
// if it were a primitive list without branching.
case FieldSize::POINTER:
// We expected a list of pointers but got a list of structs. Assuming the first field
// in the struct is the pointer we were looking for, we want to munge the pointer to
// point at the first element's pointer segment.
ptr += tag->structRef.dataSize.get();
VALIDATE_INPUT(tag->structRef.ptrCount.get() > 0 * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") {
goto useDefault;
}
break;
// Check whether the size is compatible.
switch (expectedElementSize) {
case FieldSize::VOID:
break;
case FieldSize::INLINE_COMPOSITE:
break;
}
case FieldSize::BIT:
FAIL_VALIDATE_INPUT("Expected a bit list, but got a list of structs.") {
goto useDefault;
}
break;
} else {
// Trusted message.
// This logic is equivalent to the other branch, above, but skipping all the checks.
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault;
}
break;
if (expectedElementSize == FieldSize::POINTER) {
case FieldSize::POINTER:
// We expected a list of pointers but got a list of structs. Assuming the first field
// in the struct is the pointer we were looking for, we want to munge the pointer to
// point at the first element's pointer segment.
ptr += tag->structRef.dataSize.get();
}
VALIDATE_INPUT(tag->structRef.ptrCount.get() > 0 * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE:
break;
}
return ListReader(
......@@ -1632,33 +1572,29 @@ struct WireHelpers {
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
if (segment != nullptr) {
// Verify that the elements are at least as large as the expected type. Note that if we
// expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking
// will be performed at field access time. So this check here is for the case where we
// expected a list of some primitive or pointer type.
// Verify that the elements are at least as large as the expected type. Note that if we
// expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking
// will be performed at field access time. So this check here is for the case where we
// expected a list of some primitive or pointer type.
BitCount expectedDataBitsPerElement =
dataBitsPerElement(expectedElementSize) * ELEMENTS;
WirePointerCount expectedPointersPerElement =
pointersPerElement(expectedElementSize) * ELEMENTS;
BitCount expectedDataBitsPerElement =
dataBitsPerElement(expectedElementSize) * ELEMENTS;
WirePointerCount expectedPointersPerElement =
pointersPerElement(expectedElementSize) * ELEMENTS;
VALIDATE_INPUT(expectedDataBitsPerElement <= dataSize,
"Message contained list with incompatible element type.") {
goto useDefault;
}
VALIDATE_INPUT(expectedPointersPerElement <= pointerCount,
"Message contained list with incompatible element type.") {
goto useDefault;
}
VALIDATE_INPUT(expectedDataBitsPerElement <= dataSize,
"Message contained list with incompatible element type.") {
goto useDefault;
}
VALIDATE_INPUT(expectedPointersPerElement <= pointerCount,
"Message contained list with incompatible element type.") {
goto useDefault;
}
return ListReader(segment, ptr, ref->listRef.elementCount(), step,
......@@ -1671,15 +1607,8 @@ struct WireHelpers {
const void* defaultValue, ByteCount defaultSize)) {
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
defaultValue = "";
}
if (defaultValue == nullptr) defaultValue = "";
return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Text::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS - 1);
} else {
const word* ptr = followFars(ref, segment);
......@@ -1700,7 +1629,7 @@ struct WireHelpers {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") {
goto useDefault;
......@@ -1727,10 +1656,6 @@ struct WireHelpers {
if (ref == nullptr || ref->isNull()) {
useDefault:
return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Data::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS);
} else {
const word* ptr = followFars(ref, segment);
......@@ -1751,7 +1676,7 @@ struct WireHelpers {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") {
goto useDefault;
......@@ -1771,7 +1696,6 @@ struct WireHelpers {
// Not always-inline because it is called from several places in the copying code, and anyway
// is relatively rarely used.
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
......@@ -1780,29 +1704,25 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
} else {
ptr = ref->target();
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
const word* ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
switch (ref->kind()) {
case WirePointer::STRUCT:
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
return ObjectReader(
StructReader(segment, ptr,
......@@ -1813,11 +1733,9 @@ struct WireHelpers {
case WirePointer::LIST: {
FieldSize elementSize = ref->listRef.elementSize();
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
if (elementSize == FieldSize::INLINE_COMPOSITE) {
......@@ -1825,26 +1743,21 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS,
ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.");
}
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.");
return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD,
......@@ -1857,11 +1770,9 @@ struct WireHelpers {
ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
return ObjectReader(
......@@ -1992,18 +1903,18 @@ StructReader StructBuilder::asReader() const {
StructReader StructReader::readRootTrusted(const word* location) {
return WireHelpers::readStructPointer(nullptr, reinterpret_cast<const WirePointer*>(location),
nullptr, std::numeric_limits<int>::max());
nullptr, std::numeric_limits<int>::max());
}
StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) {
VALIDATE_INPUT(segment->containsInterval(location, location + POINTER_SIZE_IN_WORDS),
VALIDATE_INPUT(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
return WireHelpers::readStructPointer(segment, reinterpret_cast<const WirePointer*>(location),
nullptr, nestingLimit);
nullptr, nestingLimit);
}
StructReader StructReader::getStructField(
......@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() {
}
StructReader ListReader::getStructElement(ElementCount index) const {
VALIDATE_INPUT((segment == nullptr) | (nestingLimit > 0),
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader();
}
......
......@@ -416,7 +416,7 @@ class StructReader {
public:
inline StructReader()
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0) {}
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
static StructReader readRootTrusted(const word* location);
static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit);
......@@ -609,7 +609,7 @@ class ListReader {
public:
inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS),
structDataSize(0), structPointerCount(0), nestingLimit(0) {}
structDataSize(0), structPointerCount(0), nestingLimit(0x7fffffff) {}
inline ElementCount size() const;
// The number of elements in the list.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment