Commit de15e658 authored by Kenton Varda's avatar Kenton Varda

Don't trust trusted messages so much. Just trust that the pointers are valid…

Don't trust trusted messages so much.  Just trust that the pointers are valid and in-bounds, but don't assume the data matches the schema.  This makes it easier to safely create a 'trusted' message: just copy any message into a MessageBuilder, and if it ends up all in one segment, it's safe to use.
parent d7081cca
...@@ -209,6 +209,12 @@ struct WireHelpers { ...@@ -209,6 +209,12 @@ struct WireHelpers {
return (bits + 7 * BITS) / BITS_PER_BYTE; return (bits + 7 * BITS) / BITS_PER_BYTE;
} }
static CAPNPROTO_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) {
// If segment is null, this is a trusted message, so all pointers have been checked already.
return segment == nullptr || segment->containsInterval(start, end);
}
static CAPNPROTO_ALWAYS_INLINE(word* allocate( static CAPNPROTO_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, WordCount amount, WirePointer*& ref, SegmentBuilder*& segment, WordCount amount,
WirePointer::Kind kind)) { WirePointer::Kind kind)) {
...@@ -263,7 +269,8 @@ struct WireHelpers { ...@@ -263,7 +269,8 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE( static CAPNPROTO_ALWAYS_INLINE(
const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) { const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) {
if (ref->kind() == WirePointer::FAR) { // If the segment is null, this is a trusted message, so there are no FAR pointers.
if (segment != nullptr && ref->kind() == WirePointer::FAR) {
// Look up the segment containing the landing pad. // Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") { VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") {
...@@ -273,7 +280,7 @@ struct WireHelpers { ...@@ -273,7 +280,7 @@ struct WireHelpers {
// Find the landing pad and check that it is within bounds. // Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment(); const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS; WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + padWords), VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") { "Message contains out-of-bounds far pointer.") {
return nullptr; return nullptr;
} }
...@@ -433,22 +440,15 @@ struct WireHelpers { ...@@ -433,22 +440,15 @@ struct WireHelpers {
} }
--nestingLimit; --nestingLimit;
const word* ptr; const word* ptr = followFars(ref, segment);
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
WordCount64 result = 0 * WORDS; WordCount64 result = 0 * WORDS;
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()), "Message contained out-of-bounds struct pointer.") {
"Message contained out-of-bounds struct pointer.") { break;
break;
}
} }
result += ref->structRef.wordSize(); result += ref->structRef.wordSize();
...@@ -473,11 +473,9 @@ struct WireHelpers { ...@@ -473,11 +473,9 @@ struct WireHelpers {
WordCount totalWords = roundUpToWords( WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) * ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize())); dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + totalWords),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords), "Message contained out-of-bounds list pointer.") {
"Message contained out-of-bounds list pointer.") { break;
break;
}
} }
result += totalWords; result += totalWords;
break; break;
...@@ -485,11 +483,9 @@ struct WireHelpers { ...@@ -485,11 +483,9 @@ struct WireHelpers {
case FieldSize::POINTER: { case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER), "Message contained out-of-bounds list pointer.") {
"Message contained out-of-bounds list pointer.") { break;
break;
}
} }
result += count * WORDS_PER_POINTER; result += count * WORDS_PER_POINTER;
...@@ -502,12 +498,10 @@ struct WireHelpers { ...@@ -502,12 +498,10 @@ struct WireHelpers {
} }
case FieldSize::INLINE_COMPOSITE: { case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount(); WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) { VALIDATE_INPUT(
VALIDATE_INPUT( boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS), "Message contained out-of-bounds list pointer.") {
"Message contained out-of-bounds list pointer.") { break;
break;
}
} }
result += wordCount + POINTER_SIZE_IN_WORDS; result += wordCount + POINTER_SIZE_IN_WORDS;
...@@ -515,16 +509,14 @@ struct WireHelpers { ...@@ -515,16 +509,14 @@ struct WireHelpers {
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) { VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT, "Don't know how to handle non-STRUCT inline composite.") {
"Don't know how to handle non-STRUCT inline composite.") { break;
break; }
}
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount, VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
break; break;
}
} }
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
...@@ -738,70 +730,66 @@ struct WireHelpers { ...@@ -738,70 +730,66 @@ struct WireHelpers {
WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) { WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) {
if (ref->isNull()) { if (ref->isNull()) {
useDefault: useDefault:
word* ptr;
if (defaultValue == nullptr || if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
ptr = allocate(ref, segment, size.total(), WirePointer::STRUCT); return initStructPointer(ref, segment, size);
ref->structRef.set(size);
} else {
ptr = copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
} }
return StructBuilder(segment, ptr, reinterpret_cast<WirePointer*>(ptr + size.data), copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
size.data * BITS_PER_WORD, size.pointers, 0 * BITS); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else { }
WirePointer* oldRef = ref;
SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment);
VALIDATE_INPUT(oldRef->kind() == WirePointer::STRUCT, WirePointer* oldRef = ref;
"Message contains non-struct pointer where struct pointer was expected.") { SegmentBuilder* oldSegment = segment;
goto useDefault; word* oldPtr = followFars(oldRef, oldSegment);
}
WordCount oldDataSize = oldRef->structRef.dataSize.get(); VALIDATE_INPUT(oldRef->kind() == WirePointer::STRUCT,
WirePointerCount oldPointerCount = oldRef->structRef.ptrCount.get(); "Message contains non-struct pointer where struct pointer was expected.") {
WirePointer* oldPointerSection = goto useDefault;
reinterpret_cast<WirePointer*>(oldPtr + oldDataSize); }
if (oldDataSize < size.data || oldPointerCount < size.pointers) { WordCount oldDataSize = oldRef->structRef.dataSize.get();
// The space allocated for this struct is too small. Unlike with readers, we can't just WirePointerCount oldPointerCount = oldRef->structRef.ptrCount.get();
// run with it and do bounds checks at access time, because how would we handle writes? WirePointer* oldPointerSection =
// Instead, we have to copy the struct to a new space now. reinterpret_cast<WirePointer*>(oldPtr + oldDataSize);
WordCount newDataSize = std::max<WordCount>(oldDataSize, size.data); if (oldDataSize < size.data || oldPointerCount < size.pointers) {
WirePointerCount newPointerCount = // The space allocated for this struct is too small. Unlike with readers, we can't just
std::max<WirePointerCount>(oldPointerCount, size.pointers); // run with it and do bounds checks at access time, because how would we handle writes?
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER; // Instead, we have to copy the struct to a new space now.
// Don't let allocate() zero out the object just yet. WordCount newDataSize = std::max<WordCount>(oldDataSize, size.data);
zeroPointerAndFars(segment, ref); WirePointerCount newPointerCount =
std::max<WirePointerCount>(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT); // Don't let allocate() zero out the object just yet.
ref->structRef.set(newDataSize, newPointerCount); zeroPointerAndFars(segment, ref);
// Copy data section. word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
memcpy(ptr, oldPtr, oldDataSize * BYTES_PER_WORD / BYTES); ref->structRef.set(newDataSize, newPointerCount);
// Copy pointer section. // Copy data section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize); memcpy(ptr, oldPtr, oldDataSize * BYTES_PER_WORD / BYTES);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
// Zero out old location. This has two purposes:
// 1) We don't want to leak the original contents of the struct when the message is written
// out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire.
memset(oldPtr, 0,
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD, // Copy pointer section.
newPointerCount, 0 * BITS); WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize);
} else { for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
return StructBuilder(oldSegment, oldPtr, oldPointerSection, oldDataSize * BITS_PER_WORD, transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
oldPointerCount, 0 * BITS);
} }
// Zero out old location. This has two purposes:
// 1) We don't want to leak the original contents of the struct when the message is written
// out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire.
memset(oldPtr, 0,
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount, 0 * BITS);
} else {
return StructBuilder(oldSegment, oldPtr, oldPointerSection, oldDataSize * BITS_PER_WORD,
oldPointerCount, 0 * BITS);
} }
} }
...@@ -869,93 +857,86 @@ struct WireHelpers { ...@@ -869,93 +857,86 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder(); return ListBuilder();
} }
word* ptr = copyMessage(origSegment, origRef, copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
reinterpret_cast<const WirePointer*>(defaultValue)); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
} else { // We must verify that the pointer has the right size. Unlike in
// The pointer is already initialized. We must verify that it has the right size. Unlike // getWritableStructListReference(), we never need to "upgrade" the data, because this
// in getWritableStructListReference(), we never need to "upgrade" the data, because this // method is called only for non-struct lists, and there is no allowed upgrade path *to*
// method is called only for non-struct lists, and there is no allowed upgrade path *to* // a non-struct list, only *from* them.
// a non-struct list, only *from* them.
WirePointer* ref = origRef; WirePointer* ref = origRef;
SegmentBuilder* segment = origSegment; SegmentBuilder* segment = origSegment;
word* ptr = followFars(ref, segment); word* ptr = followFars(ref, segment);
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") { "Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault; goto useDefault;
} }
FieldSize oldSize = ref->listRef.elementSize(); FieldSize oldSize = ref->listRef.elementSize();
if (oldSize == FieldSize::INLINE_COMPOSITE) { if (oldSize == FieldSize::INLINE_COMPOSITE) {
// The existing element size is INLINE_COMPOSITE, which means that it is at least two // The existing element size is INLINE_COMPOSITE, which means that it is at least two
// words, which makes it bigger than the expected element size. Since fields can only // words, which makes it bigger than the expected element size. Since fields can only
// grow when upgraded, the existing data must have been written with a newer version of // grow when upgraded, the existing data must have been written with a newer version of
// the protocol. We therefore never need to upgrade the data in this case, but we do // the protocol. We therefore never need to upgrade the data in this case, but we do
// need to validate that it is a valid upgrade from what we expected. // need to validate that it is a valid upgrade from what we expected.
// Read the tag to get the actual element count. // Read the tag to get the actual element count.
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr); WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
PRECOND(tag->kind() == WirePointer::STRUCT, PRECOND(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported."); "INLINE_COMPOSITE list with non-STRUCT elements not supported.");
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
WordCount dataSize = tag->structRef.dataSize.get(); WordCount dataSize = tag->structRef.dataSize.get();
WirePointerCount pointerCount = tag->structRef.ptrCount.get(); WirePointerCount pointerCount = tag->structRef.ptrCount.get();
switch (elementSize) { switch (elementSize) {
case FieldSize::VOID: case FieldSize::VOID:
// Anything is a valid upgrade from Void. // Anything is a valid upgrade from Void.
break; break;
case FieldSize::BIT: case FieldSize::BIT:
case FieldSize::BYTE: case FieldSize::BYTE:
case FieldSize::TWO_BYTES: case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES: case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(dataSize >= 1 * WORDS, VALIDATE_INPUT(dataSize >= 1 * WORDS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.");
break; break;
case FieldSize::POINTER: case FieldSize::POINTER:
VALIDATE_INPUT(pointerCount >= 1 * POINTERS, VALIDATE_INPUT(pointerCount >= 1 * POINTERS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.");
// Adjust the pointer to point at the reference segment. // Adjust the pointer to point at the reference segment.
ptr += dataSize; ptr += dataSize;
break; break;
case FieldSize::INLINE_COMPOSITE: case FieldSize::INLINE_COMPOSITE:
FAIL_CHECK("Can't get here."); FAIL_CHECK("Can't get here.");
break; break;
} }
// OK, looks valid. // OK, looks valid.
return ListBuilder(segment, ptr, return ListBuilder(segment, ptr,
tag->structRef.wordSize() * BITS_PER_WORD / ELEMENTS, tag->structRef.wordSize() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(), tag->inlineCompositeListElementCount(),
dataSize * BITS_PER_WORD, pointerCount); dataSize * BITS_PER_WORD, pointerCount);
} else { } else {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS; BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS; WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS;
VALIDATE_INPUT(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS, VALIDATE_INPUT(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.");
VALIDATE_INPUT(pointerCount >= pointersPerElement(elementSize) * ELEMENTS, VALIDATE_INPUT(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.");
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(segment, ptr, step, ref->listRef.elementCount(), return ListBuilder(segment, ptr, step, ref->listRef.elementCount(),
dataSize, pointerCount); dataSize, pointerCount);
}
} }
} }
...@@ -968,268 +949,249 @@ struct WireHelpers { ...@@ -968,268 +949,249 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder(); return ListBuilder();
} }
word* ptr = copyMessage(origSegment, origRef, copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
reinterpret_cast<const WirePointer*>(defaultValue)); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
// Assume the default value is valid. // We must verify that the pointer has the right size and potentially upgrade it if not.
if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
return ListBuilder(origSegment, tag + 1, elementSize.total() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
elementSize.data * BITS_PER_WORD,
elementSize.pointers);
} else {
BitCount dataSize = dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WirePointerCount pointerCount =
pointersPerElement(elementSize.preferredListEncoding) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(), WirePointer* oldRef = origRef;
dataSize, pointerCount); SegmentBuilder* oldSegment = origSegment;
} word* oldPtr = followFars(oldRef, oldSegment);
} else { VALIDATE_INPUT(oldRef->kind() == WirePointer::LIST,
// The pointer is already initialized. We must verify that it has the right size and "Called getList{Field,Element}() but existing pointer is not a list.") {
// potentially upgrade it if not. goto useDefault;
}
FieldSize oldSize = oldRef->listRef.elementSize();
WirePointer* oldRef = origRef; if (oldSize == FieldSize::INLINE_COMPOSITE) {
SegmentBuilder* oldSegment = origSegment; // Existing list is INLINE_COMPOSITE, but we need to verify that the sizes match.
word* oldPtr = followFars(oldRef, oldSegment);
VALIDATE_INPUT(oldRef->kind() == WirePointer::LIST, WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr);
"Called getList{Field,Element}() but existing pointer is not a list.") { oldPtr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(oldTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.") {
goto useDefault; goto useDefault;
} }
FieldSize oldSize = oldRef->listRef.elementSize(); WordCount oldDataSize = oldTag->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldTag->structRef.ptrCount.get();
auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldTag->inlineCompositeListElementCount();
if (oldSize == FieldSize::INLINE_COMPOSITE) { if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) {
// Existing list is INLINE_COMPOSITE, but we need to verify that the sizes match. // Old size is at least as large as we need. Ship it.
return ListBuilder(oldSegment, oldPtr, oldStep * BITS_PER_WORD, elementCount,
oldDataSize * BITS_PER_WORD, oldPointerCount);
}
WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr); // The structs in this list are smaller than expected, probably written using an older
oldPtr += POINTER_SIZE_IN_WORDS; // version of the protocol. We need to make a copy and expand them.
VALIDATE_INPUT(oldTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.") {
goto useDefault;
}
WordCount oldDataSize = oldTag->structRef.dataSize.get(); WordCount newDataSize = std::max<WordCount>(oldDataSize, elementSize.data);
WirePointerCount oldPointerCount = oldTag->structRef.ptrCount.get(); WirePointerCount newPointerCount =
auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS; std::max<WirePointerCount>(oldPointerCount, elementSize.pointers);
ElementCount elementCount = oldTag->inlineCompositeListElementCount(); auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalSize = newStep * elementCount;
if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) { // Don't let allocate() zero out the object just yet.
// Old size is at least as large as we need. Ship it. zeroPointerAndFars(origSegment, origRef);
return ListBuilder(oldSegment, oldPtr, oldStep * BITS_PER_WORD, elementCount,
oldDataSize * BITS_PER_WORD, oldPointerCount);
}
// The structs in this list are smaller than expected, probably written using an older word* newPtr = allocate(origRef, origSegment, totalSize + POINTER_SIZE_IN_WORDS,
// version of the protocol. We need to make a copy and expand them. WirePointer::LIST);
origRef->listRef.setInlineComposite(totalSize);
WordCount newDataSize = std::max<WordCount>(oldDataSize, elementSize.data); WirePointer* newTag = reinterpret_cast<WirePointer*>(newPtr);
WirePointerCount newPointerCount = newTag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
std::max<WirePointerCount>(oldPointerCount, elementSize.pointers); newTag->structRef.set(newDataSize, newPointerCount);
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS; newPtr += POINTER_SIZE_IN_WORDS;
WordCount totalSize = newStep * elementCount;
// Don't let allocate() zero out the object just yet. word* src = oldPtr;
zeroPointerAndFars(origSegment, origRef); word* dst = newPtr;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
word* newPtr = allocate(origRef, origSegment, totalSize + POINTER_SIZE_IN_WORDS, // Copy data section.
WirePointer::LIST); memcpy(dst, src, oldDataSize * BYTES_PER_WORD / BYTES);
origRef->listRef.setInlineComposite(totalSize);
WirePointer* newTag = reinterpret_cast<WirePointer*>(newPtr);
newTag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
newTag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS;
word* src = oldPtr;
word* dst = newPtr;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
// Copy data section.
memcpy(dst, src, oldDataSize * BYTES_PER_WORD / BYTES);
// Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize);
WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(origSegment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
dst += newStep * (1 * ELEMENTS); // Copy pointer section.
src += oldStep * (1 * ELEMENTS); WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize);
WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) {
transferPointer(origSegment, newPointerSection + i, oldSegment, oldPointerSection + i);
} }
// Zero out old location. See explanation in getWritableStructPointer(). dst += newStep * (1 * ELEMENTS);
memset(oldPtr, 0, oldStep * elementCount * BYTES_PER_WORD / BYTES); src += oldStep * (1 * ELEMENTS);
}
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount, // Zero out old location. See explanation in getWritableStructPointer().
newDataSize * BITS_PER_WORD, newPointerCount); memset(oldPtr, 0, oldStep * elementCount * BYTES_PER_WORD / BYTES);
} else if (oldSize == elementSize.preferredListEncoding) {
// Old size matches exactly.
auto dataSize = dataBitsPerElement(oldSize); return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
auto pointerCount = pointersPerElement(oldSize); newDataSize * BITS_PER_WORD, newPointerCount);
auto step = dataSize + pointerCount * BITS_PER_POINTER; } else if (oldSize == elementSize.preferredListEncoding) {
// Old size matches exactly.
return ListBuilder(oldSegment, oldPtr, step, oldRef->listRef.elementCount(), auto dataSize = dataBitsPerElement(oldSize);
dataSize * (1 * ELEMENTS), pointerCount * (1 * ELEMENTS)); auto pointerCount = pointersPerElement(oldSize);
} else { auto step = dataSize + pointerCount * BITS_PER_POINTER;
switch (elementSize.preferredListEncoding) {
case FieldSize::VOID: return ListBuilder(oldSegment, oldPtr, step, oldRef->listRef.elementCount(),
// No expectations. dataSize * (1 * ELEMENTS), pointerCount * (1 * ELEMENTS));
break; } else {
case FieldSize::POINTER: switch (elementSize.preferredListEncoding) {
VALIDATE_INPUT(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID, case FieldSize::VOID:
"Struct list has incompatible element size.") { // No expectations.
goto useDefault; break;
} case FieldSize::POINTER:
break; VALIDATE_INPUT(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID,
case FieldSize::INLINE_COMPOSITE: "Struct list has incompatible element size.") {
// Old size can be anything. goto useDefault;
break; }
case FieldSize::BIT: break;
case FieldSize::BYTE: case FieldSize::INLINE_COMPOSITE:
case FieldSize::TWO_BYTES: // Old size can be anything.
case FieldSize::FOUR_BYTES: break;
case FieldSize::EIGHT_BYTES: case FieldSize::BIT:
// Preferred size is data-only. case FieldSize::BYTE:
VALIDATE_INPUT(oldSize != FieldSize::POINTER, case FieldSize::TWO_BYTES:
"Struct list has incompatible element size.") { case FieldSize::FOUR_BYTES:
goto useDefault; case FieldSize::EIGHT_BYTES:
} // Preferred size is data-only.
break; VALIDATE_INPUT(oldSize != FieldSize::POINTER,
} "Struct list has incompatible element size.") {
goto useDefault;
}
break;
}
// OK, the old size is compatible with the preferred, but is not exactly the same. We may // OK, the old size is compatible with the preferred, but is not exactly the same. We may
// need to upgrade it. // need to upgrade it.
BitCount oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS; BitCount oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount oldPointerCount = pointersPerElement(oldSize) * ELEMENTS; WirePointerCount oldPointerCount = pointersPerElement(oldSize) * ELEMENTS;
auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS; auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldRef->listRef.elementCount(); ElementCount elementCount = oldRef->listRef.elementCount();
if (oldSize >= elementSize.preferredListEncoding) { if (oldSize >= elementSize.preferredListEncoding) {
// The old size is at least as large as the preferred, so we don't need to upgrade. // The old size is at least as large as the preferred, so we don't need to upgrade.
return ListBuilder(oldSegment, oldPtr, oldStep, elementCount, return ListBuilder(oldSegment, oldPtr, oldStep, elementCount,
oldDataSize, oldPointerCount); oldDataSize, oldPointerCount);
} }
// Upgrade is necessary. // Upgrade is necessary.
if (oldSize == FieldSize::VOID) { if (oldSize == FieldSize::VOID) {
// Nothing to copy, just allocate a new list. // Nothing to copy, just allocate a new list.
return initStructListPointer(origRef, origSegment, elementCount, elementSize); return initStructListPointer(origRef, origSegment, elementCount, elementSize);
} else if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) { } else if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
// Upgrading to an inline composite list. // Upgrading to an inline composite list.
WordCount newDataSize = elementSize.data; WordCount newDataSize = elementSize.data;
WirePointerCount newPointerCount = elementSize.pointers; WirePointerCount newPointerCount = elementSize.pointers;
if (oldSize == FieldSize::POINTER) { if (oldSize == FieldSize::POINTER) {
newPointerCount = std::max(newPointerCount, 1 * POINTERS); newPointerCount = std::max(newPointerCount, 1 * POINTERS);
} else { } else {
// Old list contains data elements, so we need at least 1 word of data. // Old list contains data elements, so we need at least 1 word of data.
newDataSize = std::max(newDataSize, 1 * WORDS); newDataSize = std::max(newDataSize, 1 * WORDS);
} }
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS; auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalWords = elementCount * newStep; WordCount totalWords = elementCount * newStep;
// Don't let allocate() zero out the object just yet. // Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef); zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords + POINTER_SIZE_IN_WORDS, word* newPtr = allocate(origRef, origSegment, totalWords + POINTER_SIZE_IN_WORDS,
WirePointer::LIST); WirePointer::LIST);
origRef->listRef.setInlineComposite(totalWords); origRef->listRef.setInlineComposite(totalWords);
WirePointer* tag = reinterpret_cast<WirePointer*>(newPtr); WirePointer* tag = reinterpret_cast<WirePointer*>(newPtr);
tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount); tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, elementCount);
tag->structRef.set(newDataSize, newPointerCount); tag->structRef.set(newDataSize, newPointerCount);
newPtr += POINTER_SIZE_IN_WORDS; newPtr += POINTER_SIZE_IN_WORDS;
if (oldSize == FieldSize::POINTER) { if (oldSize == FieldSize::POINTER) {
WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize); WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize);
WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr); WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) { for (uint i = 0; i < elementCount / ELEMENTS; i++) {
transferPointer(origSegment, dst, oldSegment, src); transferPointer(origSegment, dst, oldSegment, src);
dst += newStep / WORDS_PER_POINTER * (1 * ELEMENTS); dst += newStep / WORDS_PER_POINTER * (1 * ELEMENTS);
++src; ++src;
}
} else if (oldSize == FieldSize::BIT) {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*reinterpret_cast<char*>(dst) = (src[i/8] >> (i%8)) & 1;
dst += newStep * (1 * ELEMENTS);
}
} else {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
ByteCount oldByteStep = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(dst, src, oldByteStep / BYTES);
src += oldByteStep / BYTES;
dst += newStep * (1 * ELEMENTS);
}
} }
} else if (oldSize == FieldSize::BIT) {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*reinterpret_cast<char*>(dst) = (src[i/8] >> (i%8)) & 1;
dst += newStep * (1 * ELEMENTS);
}
} else {
word* dst = newPtr;
char* src = reinterpret_cast<char*>(oldPtr);
ByteCount oldByteStep = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(dst, src, oldByteStep / BYTES);
src += oldByteStep / BYTES;
dst += newStep * (1 * ELEMENTS);
}
}
// Zero out old location. See explanation in getWritableStructPointer(). // Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES); memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount, return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount); newDataSize * BITS_PER_WORD, newPointerCount);
} else { } else {
// If oldSize were POINTER or EIGHT_BYTES then the preferred size must be // If oldSize were POINTER or EIGHT_BYTES then the preferred size must be
// INLINE_COMPOSITE because any other compatible size would not require an upgrade. // INLINE_COMPOSITE because any other compatible size would not require an upgrade.
CHECK(oldSize < FieldSize::EIGHT_BYTES); CHECK(oldSize < FieldSize::EIGHT_BYTES);
// If the preferred size were BIT then oldSize must be VOID, but we handled that case // If the preferred size were BIT then oldSize must be VOID, but we handled that case
// above. // above.
CHECK(elementSize.preferredListEncoding >= FieldSize::BIT); CHECK(elementSize.preferredListEncoding >= FieldSize::BIT);
// OK, so the expected list elements are all data and between 1 byte and 1 word each, // OK, so the expected list elements are all data and between 1 byte and 1 word each,
// and the old element are data between 1 bit and 4 bytes. We're upgrading from one // and the old element are data between 1 bit and 4 bytes. We're upgrading from one
// primitive data type to another, larger one. // primitive data type to another, larger one.
BitCount newDataSize = BitCount newDataSize =
dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS; dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WordCount totalWords = WordCount totalWords =
roundUpToWords(BitCount64(newDataSize) * (elementCount / ELEMENTS)); roundUpToWords(BitCount64(newDataSize) * (elementCount / ELEMENTS));
// Don't let allocate() zero out the object just yet. // Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef); zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords, WirePointer::LIST); word* newPtr = allocate(origRef, origSegment, totalWords, WirePointer::LIST);
origRef->listRef.set(elementSize.preferredListEncoding, elementCount); origRef->listRef.set(elementSize.preferredListEncoding, elementCount);
char* newBytePtr = reinterpret_cast<char*>(newPtr); char* newBytePtr = reinterpret_cast<char*>(newPtr);
char* oldBytePtr = reinterpret_cast<char*>(oldPtr); char* oldBytePtr = reinterpret_cast<char*>(oldPtr);
ByteCount newDataByteSize = newDataSize / BITS_PER_BYTE; ByteCount newDataByteSize = newDataSize / BITS_PER_BYTE;
if (oldSize == FieldSize::BIT) { if (oldSize == FieldSize::BIT) {
for (uint i = 0; i < elementCount / ELEMENTS; i++) { for (uint i = 0; i < elementCount / ELEMENTS; i++) {
*newBytePtr = (oldBytePtr[i/8] >> (i%8)) & 1; *newBytePtr = (oldBytePtr[i/8] >> (i%8)) & 1;
newBytePtr += newDataByteSize / BYTES; newBytePtr += newDataByteSize / BYTES;
}
} else {
ByteCount oldDataByteSize = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(newBytePtr, oldBytePtr, oldDataByteSize / BYTES);
oldBytePtr += oldDataByteSize / BYTES;
newBytePtr += newDataByteSize / BYTES;
}
} }
} else {
ByteCount oldDataByteSize = oldDataSize / BITS_PER_BYTE;
for (uint i = 0; i < elementCount / ELEMENTS; i++) {
memcpy(newBytePtr, oldBytePtr, oldDataByteSize / BYTES);
oldBytePtr += oldDataByteSize / BYTES;
newBytePtr += newDataByteSize / BYTES;
}
}
// Zero out old location. See explanation in getWritableStructPointer(). // Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES); memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newDataSize / ELEMENTS, elementCount, return ListBuilder(origSegment, newPtr, newDataSize / ELEMENTS, elementCount,
newDataSize, 0 * POINTERS); newDataSize, 0 * POINTERS);
}
} }
} }
} }
...@@ -1459,42 +1421,36 @@ struct WireHelpers { ...@@ -1459,42 +1421,36 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer( static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue, SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
int nestingLimit)) { int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return StructReader(nullptr, nullptr, nullptr, 0 * BITS, 0 * POINTERS, 0 * BITS, return StructReader();
std::numeric_limits<int>::max());
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment); VALIDATE_INPUT(nestingLimit > 0,
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
// Already reported the error. goto useDefault;
goto useDefault; }
}
VALIDATE_INPUT(ref->kind() == WirePointer::STRUCT, const word* ptr = followFars(ref, segment);
"Message contains non-struct pointer where struct pointer was expected.") { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
goto useDefault; // Already reported the error.
} goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()), VALIDATE_INPUT(ref->kind() == WirePointer::STRUCT,
"Message contained out-of-bounds struct pointer.") { "Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault; goto useDefault;
} }
} else {
// Trusted messages don't contain far pointers. VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
ptr = ref->target(); "Message contained out-of-bounds struct pointer.") {
goto useDefault;
} }
return StructReader( return StructReader(
...@@ -1507,7 +1463,6 @@ struct WireHelpers { ...@@ -1507,7 +1463,6 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer( static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue, SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
FieldSize expectedElementSize, int nestingLimit)) { FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
...@@ -1516,26 +1471,23 @@ struct WireHelpers { ...@@ -1516,26 +1471,23 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment); VALIDATE_INPUT(nestingLimit > 0,
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
// Already reported error. goto useDefault;
goto useDefault; }
}
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, const word* ptr = followFars(ref, segment);
"Message contains non-list pointer where list pointer was expected.") { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
goto useDefault; // Already reported error.
} goto useDefault;
} else { }
// Trusted messages don't contain far pointers.
ptr = ref->target(); VALIDATE_INPUT(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where list pointer was expected.") {
goto useDefault;
} }
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) { if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
...@@ -1548,75 +1500,63 @@ struct WireHelpers { ...@@ -1548,75 +1500,63 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount), "Message contains out-of-bounds list pointer.") {
"Message contains out-of-bounds list pointer.") { goto useDefault;
goto useDefault; }
}
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
VALIDATE_INPUT(size * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault;
}
// If a struct list was not expected, then presumably a non-struct list was upgraded to a VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
// struct list. We need to manipulate the pointer to point at the first field of the "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
// struct. Together with the "stepBits", this will allow the struct list to be accessed as goto useDefault;
// if it were a primitive list without branching. }
// Check whether the size is compatible. size = tag->inlineCompositeListElementCount();
switch (expectedElementSize) { wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
case FieldSize::VOID:
break;
case FieldSize::BIT: VALIDATE_INPUT(size * wordsPerElement <= wordCount,
FAIL_VALIDATE_INPUT("Expected a bit list, but got a list of structs.") { "INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault; goto useDefault;
} }
break;
case FieldSize::BYTE: // If a struct list was not expected, then presumably a non-struct list was upgraded to a
case FieldSize::TWO_BYTES: // struct list. We need to manipulate the pointer to point at the first field of the
case FieldSize::FOUR_BYTES: // struct. Together with the "stepBits", this will allow the struct list to be accessed as
case FieldSize::EIGHT_BYTES: // if it were a primitive list without branching.
VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault;
}
break;
case FieldSize::POINTER: // Check whether the size is compatible.
// We expected a list of pointers but got a list of structs. Assuming the first field switch (expectedElementSize) {
// in the struct is the pointer we were looking for, we want to munge the pointer to case FieldSize::VOID:
// point at the first element's pointer segment. break;
ptr += tag->structRef.dataSize.get();
VALIDATE_INPUT(tag->structRef.ptrCount.get() > 0 * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE: case FieldSize::BIT:
break; FAIL_VALIDATE_INPUT("Expected a bit list, but got a list of structs.") {
} goto useDefault;
}
break;
} else { case FieldSize::BYTE:
// Trusted message. case FieldSize::TWO_BYTES:
// This logic is equivalent to the other branch, above, but skipping all the checks. case FieldSize::FOUR_BYTES:
size = tag->inlineCompositeListElementCount(); case FieldSize::EIGHT_BYTES:
wordsPerElement = tag->structRef.wordSize() / ELEMENTS; VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault;
}
break;
if (expectedElementSize == FieldSize::POINTER) { case FieldSize::POINTER:
// We expected a list of pointers but got a list of structs. Assuming the first field
// in the struct is the pointer we were looking for, we want to munge the pointer to
// point at the first element's pointer segment.
ptr += tag->structRef.dataSize.get(); ptr += tag->structRef.dataSize.get();
} VALIDATE_INPUT(tag->structRef.ptrCount.get() > 0 * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE:
break;
} }
return ListReader( return ListReader(
...@@ -1632,33 +1572,29 @@ struct WireHelpers { ...@@ -1632,33 +1572,29 @@ struct WireHelpers {
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS; pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)), "Message contains out-of-bounds list pointer.") {
"Message contains out-of-bounds list pointer.") { goto useDefault;
goto useDefault;
}
} }
if (segment != nullptr) { // Verify that the elements are at least as large as the expected type. Note that if we
// Verify that the elements are at least as large as the expected type. Note that if we // expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking
// expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking // will be performed at field access time. So this check here is for the case where we
// will be performed at field access time. So this check here is for the case where we // expected a list of some primitive or pointer type.
// expected a list of some primitive or pointer type.
BitCount expectedDataBitsPerElement = BitCount expectedDataBitsPerElement =
dataBitsPerElement(expectedElementSize) * ELEMENTS; dataBitsPerElement(expectedElementSize) * ELEMENTS;
WirePointerCount expectedPointersPerElement = WirePointerCount expectedPointersPerElement =
pointersPerElement(expectedElementSize) * ELEMENTS; pointersPerElement(expectedElementSize) * ELEMENTS;
VALIDATE_INPUT(expectedDataBitsPerElement <= dataSize, VALIDATE_INPUT(expectedDataBitsPerElement <= dataSize,
"Message contained list with incompatible element type.") { "Message contained list with incompatible element type.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(expectedPointersPerElement <= pointerCount, VALIDATE_INPUT(expectedPointersPerElement <= pointerCount,
"Message contained list with incompatible element type.") { "Message contained list with incompatible element type.") {
goto useDefault; goto useDefault;
}
} }
return ListReader(segment, ptr, ref->listRef.elementCount(), step, return ListReader(segment, ptr, ref->listRef.elementCount(), step,
...@@ -1671,15 +1607,8 @@ struct WireHelpers { ...@@ -1671,15 +1607,8 @@ struct WireHelpers {
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, ByteCount defaultSize)) {
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr) defaultValue = "";
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
defaultValue = "";
}
return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES); return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Text::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS - 1);
} else { } else {
const word* ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
...@@ -1700,7 +1629,7 @@ struct WireHelpers { ...@@ -1700,7 +1629,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") { "Message contained out-of-bounds text pointer.") {
goto useDefault; goto useDefault;
...@@ -1727,10 +1656,6 @@ struct WireHelpers { ...@@ -1727,10 +1656,6 @@ struct WireHelpers {
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES); return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Data::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS);
} else { } else {
const word* ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
...@@ -1751,7 +1676,7 @@ struct WireHelpers { ...@@ -1751,7 +1676,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") { "Message contained out-of-bounds data pointer.") {
goto useDefault; goto useDefault;
...@@ -1771,7 +1696,6 @@ struct WireHelpers { ...@@ -1771,7 +1696,6 @@ struct WireHelpers {
// Not always-inline because it is called from several places in the copying code, and anyway // Not always-inline because it is called from several places in the copying code, and anyway
// is relatively rarely used. // is relatively rarely used.
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
...@@ -1780,29 +1704,25 @@ struct WireHelpers { ...@@ -1780,29 +1704,25 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { const word* ptr = WireHelpers::followFars(ref, segment);
// Already reported the error. if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
goto useDefault; // Already reported the error.
} goto useDefault;
} else {
ptr = ref->target();
} }
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: case WirePointer::STRUCT:
if (segment != nullptr) { VALIDATE_INPUT(nestingLimit > 0,
VALIDATE_INPUT(nestingLimit > 0, "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { goto useDefault;
goto useDefault; }
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()), VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
}
} }
return ObjectReader( return ObjectReader(
StructReader(segment, ptr, StructReader(segment, ptr,
...@@ -1813,11 +1733,9 @@ struct WireHelpers { ...@@ -1813,11 +1733,9 @@ struct WireHelpers {
case WirePointer::LIST: { case WirePointer::LIST: {
FieldSize elementSize = ref->listRef.elementSize(); FieldSize elementSize = ref->listRef.elementSize();
if (segment != nullptr) { VALIDATE_INPUT(nestingLimit > 0,
VALIDATE_INPUT(nestingLimit > 0, "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { goto useDefault;
goto useDefault;
}
} }
if (elementSize == FieldSize::INLINE_COMPOSITE) { if (elementSize == FieldSize::INLINE_COMPOSITE) {
...@@ -1825,26 +1743,21 @@ struct WireHelpers { ...@@ -1825,26 +1743,21 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS, "Message contains out-of-bounds list pointer.") {
ptr + wordCount), goto useDefault;
"Message contains out-of-bounds list pointer.") { }
goto useDefault;
}
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT, VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
}
} }
ElementCount elementCount = tag->inlineCompositeListElementCount(); ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS; auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (segment != nullptr) { VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount, "INLINE_COMPOSITE list's elements overrun its word count.");
"INLINE_COMPOSITE list's elements overrun its word count.");
}
return ObjectReader( return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD, ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD,
...@@ -1857,11 +1770,9 @@ struct WireHelpers { ...@@ -1857,11 +1770,9 @@ struct WireHelpers {
ElementCount elementCount = ref->listRef.elementCount(); ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step); WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + wordCount), "Message contains out-of-bounds list pointer.") {
"Message contains out-of-bounds list pointer.") { goto useDefault;
goto useDefault;
}
} }
return ObjectReader( return ObjectReader(
...@@ -1992,18 +1903,18 @@ StructReader StructBuilder::asReader() const { ...@@ -1992,18 +1903,18 @@ StructReader StructBuilder::asReader() const {
StructReader StructReader::readRootTrusted(const word* location) { StructReader StructReader::readRootTrusted(const word* location) {
return WireHelpers::readStructPointer(nullptr, reinterpret_cast<const WirePointer*>(location), return WireHelpers::readStructPointer(nullptr, reinterpret_cast<const WirePointer*>(location),
nullptr, std::numeric_limits<int>::max()); nullptr, std::numeric_limits<int>::max());
} }
StructReader StructReader::readRoot( StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) { const word* location, SegmentReader* segment, int nestingLimit) {
VALIDATE_INPUT(segment->containsInterval(location, location + POINTER_SIZE_IN_WORDS), VALIDATE_INPUT(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") { "Root location out-of-bounds.") {
location = nullptr; location = nullptr;
} }
return WireHelpers::readStructPointer(segment, reinterpret_cast<const WirePointer*>(location), return WireHelpers::readStructPointer(segment, reinterpret_cast<const WirePointer*>(location),
nullptr, nestingLimit); nullptr, nestingLimit);
} }
StructReader StructReader::getStructField( StructReader StructReader::getStructField(
...@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() { ...@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() {
} }
StructReader ListReader::getStructElement(ElementCount index) const { StructReader ListReader::getStructElement(ElementCount index) const {
VALIDATE_INPUT((segment == nullptr) | (nestingLimit > 0), VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader(); return StructReader();
} }
......
...@@ -416,7 +416,7 @@ class StructReader { ...@@ -416,7 +416,7 @@ class StructReader {
public: public:
inline StructReader() inline StructReader()
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0), : segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0) {} pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
static StructReader readRootTrusted(const word* location); static StructReader readRootTrusted(const word* location);
static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit); static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit);
...@@ -609,7 +609,7 @@ class ListReader { ...@@ -609,7 +609,7 @@ class ListReader {
public: public:
inline ListReader() inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS), : segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS),
structDataSize(0), structPointerCount(0), nestingLimit(0) {} structDataSize(0), structPointerCount(0), nestingLimit(0x7fffffff) {}
inline ElementCount size() const; inline ElementCount size() const;
// The number of elements in the list. // The number of elements in the list.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment