Commit de15e658 authored by Kenton Varda's avatar Kenton Varda

Don't trust trusted messages so much. Just trust that the pointers are valid…

Don't trust trusted messages so much.  Just trust that the pointers are valid and in-bounds, but don't assume the data matches the schema.  This makes it easier to safely create a 'trusted' message: just copy any message into a MessageBuilder, and if it ends up all in one segment, it's safe to use.
parent d7081cca
...@@ -209,6 +209,12 @@ struct WireHelpers { ...@@ -209,6 +209,12 @@ struct WireHelpers {
return (bits + 7 * BITS) / BITS_PER_BYTE; return (bits + 7 * BITS) / BITS_PER_BYTE;
} }
static CAPNPROTO_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) {
// If segment is null, this is a trusted message, so all pointers have been checked already.
return segment == nullptr || segment->containsInterval(start, end);
}
static CAPNPROTO_ALWAYS_INLINE(word* allocate( static CAPNPROTO_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, WordCount amount, WirePointer*& ref, SegmentBuilder*& segment, WordCount amount,
WirePointer::Kind kind)) { WirePointer::Kind kind)) {
...@@ -263,7 +269,8 @@ struct WireHelpers { ...@@ -263,7 +269,8 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE( static CAPNPROTO_ALWAYS_INLINE(
const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) { const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) {
if (ref->kind() == WirePointer::FAR) { // If the segment is null, this is a trusted message, so there are no FAR pointers.
if (segment != nullptr && ref->kind() == WirePointer::FAR) {
// Look up the segment containing the landing pad. // Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") { VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") {
...@@ -273,7 +280,7 @@ struct WireHelpers { ...@@ -273,7 +280,7 @@ struct WireHelpers {
// Find the landing pad and check that it is within bounds. // Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment(); const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS; WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + padWords), VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") { "Message contains out-of-bounds far pointer.") {
return nullptr; return nullptr;
} }
...@@ -433,23 +440,16 @@ struct WireHelpers { ...@@ -433,23 +440,16 @@ struct WireHelpers {
} }
--nestingLimit; --nestingLimit;
const word* ptr; const word* ptr = followFars(ref, segment);
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
WordCount64 result = 0 * WORDS; WordCount64 result = 0 * WORDS;
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
break; break;
} }
}
result += ref->structRef.wordSize(); result += ref->structRef.wordSize();
const WirePointer* pointerSection = const WirePointer* pointerSection =
...@@ -473,24 +473,20 @@ struct WireHelpers { ...@@ -473,24 +473,20 @@ struct WireHelpers {
WordCount totalWords = roundUpToWords( WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) * ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize())); dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + totalWords),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
break; break;
} }
}
result += totalWords; result += totalWords;
break; break;
} }
case FieldSize::POINTER: { case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
break; break;
} }
}
result += count * WORDS_PER_POINTER; result += count * WORDS_PER_POINTER;
...@@ -502,20 +498,17 @@ struct WireHelpers { ...@@ -502,20 +498,17 @@ struct WireHelpers {
} }
case FieldSize::INLINE_COMPOSITE: { case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount(); WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) {
VALIDATE_INPUT( VALIDATE_INPUT(
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS), boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
break; break;
} }
}
result += wordCount + POINTER_SIZE_IN_WORDS; result += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) {
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT, VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") { "Don't know how to handle non-STRUCT inline composite.") {
break; break;
...@@ -525,7 +518,6 @@ struct WireHelpers { ...@@ -525,7 +518,6 @@ struct WireHelpers {
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
break; break;
} }
}
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
...@@ -738,17 +730,14 @@ struct WireHelpers { ...@@ -738,17 +730,14 @@ struct WireHelpers {
WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) { WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) {
if (ref->isNull()) { if (ref->isNull()) {
useDefault: useDefault:
word* ptr;
if (defaultValue == nullptr || if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
ptr = allocate(ref, segment, size.total(), WirePointer::STRUCT); return initStructPointer(ref, segment, size);
ref->structRef.set(size);
} else {
ptr = copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
} }
return StructBuilder(segment, ptr, reinterpret_cast<WirePointer*>(ptr + size.data), copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
size.data * BITS_PER_WORD, size.pointers, 0 * BITS); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else { }
WirePointer* oldRef = ref; WirePointer* oldRef = ref;
SegmentBuilder* oldSegment = segment; SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment); word* oldPtr = followFars(oldRef, oldSegment);
...@@ -803,7 +792,6 @@ struct WireHelpers { ...@@ -803,7 +792,6 @@ struct WireHelpers {
oldPointerCount, 0 * BITS); oldPointerCount, 0 * BITS);
} }
} }
}
static CAPNPROTO_ALWAYS_INLINE(ListBuilder initListPointer( static CAPNPROTO_ALWAYS_INLINE(ListBuilder initListPointer(
WirePointer* ref, SegmentBuilder* segment, ElementCount elementCount, WirePointer* ref, SegmentBuilder* segment, ElementCount elementCount,
...@@ -869,18 +857,12 @@ struct WireHelpers { ...@@ -869,18 +857,12 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder(); return ListBuilder();
} }
word* ptr = copyMessage(origSegment, origRef, copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
reinterpret_cast<const WirePointer*>(defaultValue)); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
} else { // We must verify that the pointer has the right size. Unlike in
// The pointer is already initialized. We must verify that it has the right size. Unlike // getWritableStructListReference(), we never need to "upgrade" the data, because this
// in getWritableStructListReference(), we never need to "upgrade" the data, because this
// method is called only for non-struct lists, and there is no allowed upgrade path *to* // method is called only for non-struct lists, and there is no allowed upgrade path *to*
// a non-struct list, only *from* them. // a non-struct list, only *from* them.
...@@ -957,7 +939,6 @@ struct WireHelpers { ...@@ -957,7 +939,6 @@ struct WireHelpers {
dataSize, pointerCount); dataSize, pointerCount);
} }
} }
}
static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableStructListPointer( static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableStructListPointer(
WirePointer* origRef, SegmentBuilder* origSegment, StructSize elementSize, WirePointer* origRef, SegmentBuilder* origSegment, StructSize elementSize,
...@@ -968,29 +949,11 @@ struct WireHelpers { ...@@ -968,29 +949,11 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder(); return ListBuilder();
} }
word* ptr = copyMessage(origSegment, origRef, copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
reinterpret_cast<const WirePointer*>(defaultValue)); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
// Assume the default value is valid.
if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
return ListBuilder(origSegment, tag + 1, elementSize.total() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
elementSize.data * BITS_PER_WORD,
elementSize.pointers);
} else {
BitCount dataSize = dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WirePointerCount pointerCount =
pointersPerElement(elementSize.preferredListEncoding) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
} }
} else { // We must verify that the pointer has the right size and potentially upgrade it if not.
// The pointer is already initialized. We must verify that it has the right size and
// potentially upgrade it if not.
WirePointer* oldRef = origRef; WirePointer* oldRef = origRef;
SegmentBuilder* oldSegment = origSegment; SegmentBuilder* oldSegment = origSegment;
...@@ -1232,7 +1195,6 @@ struct WireHelpers { ...@@ -1232,7 +1195,6 @@ struct WireHelpers {
} }
} }
} }
}
static CAPNPROTO_ALWAYS_INLINE(Text::Builder initTextPointer( static CAPNPROTO_ALWAYS_INLINE(Text::Builder initTextPointer(
WirePointer* ref, SegmentBuilder* segment, ByteCount size)) { WirePointer* ref, SegmentBuilder* segment, ByteCount size)) {
...@@ -1459,25 +1421,23 @@ struct WireHelpers { ...@@ -1459,25 +1421,23 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer( static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue, SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
int nestingLimit)) { int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) { reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return StructReader(nullptr, nullptr, nullptr, 0 * BITS, 0 * POINTERS, 0 * BITS, return StructReader();
std::numeric_limits<int>::max());
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
VALIDATE_INPUT(nestingLimit > 0, VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error. // Already reported the error.
goto useDefault; goto useDefault;
...@@ -1488,14 +1448,10 @@ struct WireHelpers { ...@@ -1488,14 +1448,10 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()), VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
}
return StructReader( return StructReader(
segment, ptr, reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()), segment, ptr, reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()),
...@@ -1507,7 +1463,6 @@ struct WireHelpers { ...@@ -1507,7 +1463,6 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer( static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue, SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
FieldSize expectedElementSize, int nestingLimit)) { FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
...@@ -1516,14 +1471,15 @@ struct WireHelpers { ...@@ -1516,14 +1471,15 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
VALIDATE_INPUT(nestingLimit > 0, VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported error. // Already reported error.
goto useDefault; goto useDefault;
...@@ -1533,10 +1489,6 @@ struct WireHelpers { ...@@ -1533,10 +1489,6 @@ struct WireHelpers {
"Message contains non-list pointer where list pointer was expected.") { "Message contains non-list pointer where list pointer was expected.") {
goto useDefault; goto useDefault;
} }
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
}
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) { if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
decltype(WORDS/ELEMENTS) wordsPerElement; decltype(WORDS/ELEMENTS) wordsPerElement;
...@@ -1548,8 +1500,7 @@ struct WireHelpers { ...@@ -1548,8 +1500,7 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1608,17 +1559,6 @@ struct WireHelpers { ...@@ -1608,17 +1559,6 @@ struct WireHelpers {
break; break;
} }
} else {
// Trusted message.
// This logic is equivalent to the other branch, above, but skipping all the checks.
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (expectedElementSize == FieldSize::POINTER) {
ptr += tag->structRef.dataSize.get();
}
}
return ListReader( return ListReader(
segment, ptr, size, wordsPerElement * BITS_PER_WORD, segment, ptr, size, wordsPerElement * BITS_PER_WORD,
tag->structRef.dataSize.get() * BITS_PER_WORD, tag->structRef.dataSize.get() * BITS_PER_WORD,
...@@ -1632,15 +1572,12 @@ struct WireHelpers { ...@@ -1632,15 +1572,12 @@ struct WireHelpers {
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS; pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)), roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
}
if (segment != nullptr) {
// Verify that the elements are at least as large as the expected type. Note that if we // Verify that the elements are at least as large as the expected type. Note that if we
// expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking // expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking
// will be performed at field access time. So this check here is for the case where we // will be performed at field access time. So this check here is for the case where we
...@@ -1659,7 +1596,6 @@ struct WireHelpers { ...@@ -1659,7 +1596,6 @@ struct WireHelpers {
"Message contained list with incompatible element type.") { "Message contained list with incompatible element type.") {
goto useDefault; goto useDefault;
} }
}
return ListReader(segment, ptr, ref->listRef.elementCount(), step, return ListReader(segment, ptr, ref->listRef.elementCount(), step,
dataSize, pointerCount, nestingLimit - 1); dataSize, pointerCount, nestingLimit - 1);
...@@ -1671,15 +1607,8 @@ struct WireHelpers { ...@@ -1671,15 +1607,8 @@ struct WireHelpers {
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, ByteCount defaultSize)) {
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr) defaultValue = "";
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
defaultValue = "";
}
return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES); return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Text::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS - 1);
} else { } else {
const word* ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
...@@ -1700,7 +1629,7 @@ struct WireHelpers { ...@@ -1700,7 +1629,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") { "Message contained out-of-bounds text pointer.") {
goto useDefault; goto useDefault;
...@@ -1727,10 +1656,6 @@ struct WireHelpers { ...@@ -1727,10 +1656,6 @@ struct WireHelpers {
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES); return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Data::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS);
} else { } else {
const word* ptr = followFars(ref, segment); const word* ptr = followFars(ref, segment);
...@@ -1751,7 +1676,7 @@ struct WireHelpers { ...@@ -1751,7 +1676,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") { "Message contained out-of-bounds data pointer.") {
goto useDefault; goto useDefault;
...@@ -1771,7 +1696,6 @@ struct WireHelpers { ...@@ -1771,7 +1696,6 @@ struct WireHelpers {
// Not always-inline because it is called from several places in the copying code, and anyway // Not always-inline because it is called from several places in the copying code, and anyway
// is relatively rarely used. // is relatively rarely used.
const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
...@@ -1780,30 +1704,26 @@ struct WireHelpers { ...@@ -1780,30 +1704,26 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target(); defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} else if (segment != nullptr) { }
ptr = WireHelpers::followFars(ref, segment);
const word* ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error. // Already reported the error.
goto useDefault; goto useDefault;
} }
} else {
ptr = ref->target();
}
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: case WirePointer::STRUCT:
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0, VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()), VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
}
return ObjectReader( return ObjectReader(
StructReader(segment, ptr, StructReader(segment, ptr,
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()), reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()),
...@@ -1813,21 +1733,17 @@ struct WireHelpers { ...@@ -1813,21 +1733,17 @@ struct WireHelpers {
case WirePointer::LIST: { case WirePointer::LIST: {
FieldSize elementSize = ref->listRef.elementSize(); FieldSize elementSize = ref->listRef.elementSize();
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0, VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
}
if (elementSize == FieldSize::INLINE_COMPOSITE) { if (elementSize == FieldSize::INLINE_COMPOSITE) {
WordCount wordCount = ref->listRef.inlineCompositeWordCount(); WordCount wordCount = ref->listRef.inlineCompositeWordCount();
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS,
ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1836,15 +1752,12 @@ struct WireHelpers { ...@@ -1836,15 +1752,12 @@ struct WireHelpers {
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
} }
}
ElementCount elementCount = tag->inlineCompositeListElementCount(); ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS; auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount, VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count."); "INLINE_COMPOSITE list's elements overrun its word count.");
}
return ObjectReader( return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD, ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD,
...@@ -1857,12 +1770,10 @@ struct WireHelpers { ...@@ -1857,12 +1770,10 @@ struct WireHelpers {
ElementCount elementCount = ref->listRef.elementCount(); ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step); WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
if (segment != nullptr) { VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + wordCount),
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
}
return ObjectReader( return ObjectReader(
ListReader(segment, ptr, elementCount, step, dataSize, pointerCount, ListReader(segment, ptr, elementCount, step, dataSize, pointerCount,
...@@ -1997,7 +1908,7 @@ StructReader StructReader::readRootTrusted(const word* location) { ...@@ -1997,7 +1908,7 @@ StructReader StructReader::readRootTrusted(const word* location) {
StructReader StructReader::readRoot( StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) { const word* location, SegmentReader* segment, int nestingLimit) {
VALIDATE_INPUT(segment->containsInterval(location, location + POINTER_SIZE_IN_WORDS), VALIDATE_INPUT(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") { "Root location out-of-bounds.") {
location = nullptr; location = nullptr;
} }
...@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() { ...@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() {
} }
StructReader ListReader::getStructElement(ElementCount index) const { StructReader ListReader::getStructElement(ElementCount index) const {
VALIDATE_INPUT((segment == nullptr) | (nestingLimit > 0), VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader(); return StructReader();
} }
......
...@@ -416,7 +416,7 @@ class StructReader { ...@@ -416,7 +416,7 @@ class StructReader {
public: public:
inline StructReader() inline StructReader()
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0), : segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0) {} pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
static StructReader readRootTrusted(const word* location); static StructReader readRootTrusted(const word* location);
static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit); static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit);
...@@ -609,7 +609,7 @@ class ListReader { ...@@ -609,7 +609,7 @@ class ListReader {
public: public:
inline ListReader() inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS), : segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS),
structDataSize(0), structPointerCount(0), nestingLimit(0) {} structDataSize(0), structPointerCount(0), nestingLimit(0x7fffffff) {}
inline ElementCount size() const; inline ElementCount size() const;
// The number of elements in the list. // The number of elements in the list.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment