Commit de15e658 authored by Kenton Varda's avatar Kenton Varda

Don't trust trusted messages so much. Just trust that the pointers are valid…

Don't trust trusted messages so much.  Just trust that the pointers are valid and in-bounds, but don't assume the data matches the schema.  This makes it easier to safely create a 'trusted' message: just copy any message into a MessageBuilder, and if it ends up all in one segment, it's safe to use.
parent d7081cca
......@@ -209,6 +209,12 @@ struct WireHelpers {
return (bits + 7 * BITS) / BITS_PER_BYTE;
}
static CAPNPROTO_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) {
// If segment is null, this is a trusted message, so all pointers have been checked already.
return segment == nullptr || segment->containsInterval(start, end);
}
static CAPNPROTO_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, WordCount amount,
WirePointer::Kind kind)) {
......@@ -263,7 +269,8 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(
const word* followFars(const WirePointer*& ref, SegmentReader*& segment)) {
if (ref->kind() == WirePointer::FAR) {
// If the segment is null, this is a trusted message, so there are no FAR pointers.
if (segment != nullptr && ref->kind() == WirePointer::FAR) {
// Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") {
......@@ -273,7 +280,7 @@ struct WireHelpers {
// Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + padWords),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") {
return nullptr;
}
......@@ -433,23 +440,16 @@ struct WireHelpers {
}
--nestingLimit;
const word* ptr;
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
const word* ptr = followFars(ref, segment);
WordCount64 result = 0 * WORDS;
switch (ref->kind()) {
case WirePointer::STRUCT: {
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
break;
}
}
result += ref->structRef.wordSize();
const WirePointer* pointerSection =
......@@ -473,24 +473,20 @@ struct WireHelpers {
WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += totalWords;
break;
}
case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += count * WORDS_PER_POINTER;
......@@ -502,20 +498,17 @@ struct WireHelpers {
}
case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) {
VALIDATE_INPUT(
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) {
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") {
break;
......@@ -525,7 +518,6 @@ struct WireHelpers {
"Struct list pointer's elements overran size.") {
break;
}
}
WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
......@@ -738,17 +730,14 @@ struct WireHelpers {
WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) {
if (ref->isNull()) {
useDefault:
word* ptr;
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
ptr = allocate(ref, segment, size.total(), WirePointer::STRUCT);
ref->structRef.set(size);
} else {
ptr = copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
return initStructPointer(ref, segment, size);
}
return StructBuilder(segment, ptr, reinterpret_cast<WirePointer*>(ptr + size.data),
size.data * BITS_PER_WORD, size.pointers, 0 * BITS);
} else {
copyMessage(segment, ref, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
WirePointer* oldRef = ref;
SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment);
......@@ -803,7 +792,6 @@ struct WireHelpers {
oldPointerCount, 0 * BITS);
}
}
}
static CAPNPROTO_ALWAYS_INLINE(ListBuilder initListPointer(
WirePointer* ref, SegmentBuilder* segment, ElementCount elementCount,
......@@ -869,18 +857,12 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
reinterpret_cast<const WirePointer*>(defaultValue));
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
} else {
// The pointer is already initialized. We must verify that it has the right size. Unlike
// in getWritableStructListReference(), we never need to "upgrade" the data, because this
// We must verify that the pointer has the right size. Unlike in
// getWritableStructListReference(), we never need to "upgrade" the data, because this
// method is called only for non-struct lists, and there is no allowed upgrade path *to*
// a non-struct list, only *from* them.
......@@ -957,7 +939,6 @@ struct WireHelpers {
dataSize, pointerCount);
}
}
}
static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableStructListPointer(
WirePointer* origRef, SegmentBuilder* origSegment, StructSize elementSize,
......@@ -968,29 +949,11 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
reinterpret_cast<const WirePointer*>(defaultValue));
// Assume the default value is valid.
if (elementSize.preferredListEncoding == FieldSize::INLINE_COMPOSITE) {
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
return ListBuilder(origSegment, tag + 1, elementSize.total() * BITS_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount(),
elementSize.data * BITS_PER_WORD,
elementSize.pointers);
} else {
BitCount dataSize = dataBitsPerElement(elementSize.preferredListEncoding) * ELEMENTS;
WirePointerCount pointerCount =
pointersPerElement(elementSize.preferredListEncoding) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(origSegment, ptr, step, origRef->listRef.elementCount(),
dataSize, pointerCount);
copyMessage(origSegment, origRef, reinterpret_cast<const WirePointer*>(defaultValue));
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
} else {
// The pointer is already initialized. We must verify that it has the right size and
// potentially upgrade it if not.
// We must verify that the pointer has the right size and potentially upgrade it if not.
WirePointer* oldRef = origRef;
SegmentBuilder* oldSegment = origSegment;
......@@ -1232,7 +1195,6 @@ struct WireHelpers {
}
}
}
}
static CAPNPROTO_ALWAYS_INLINE(Text::Builder initTextPointer(
WirePointer* ref, SegmentBuilder* segment, ByteCount size)) {
......@@ -1459,25 +1421,23 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
return StructReader(nullptr, nullptr, nullptr, 0 * BITS, 0 * POINTERS, 0 * BITS,
std::numeric_limits<int>::max());
return StructReader();
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment);
const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
......@@ -1488,14 +1448,10 @@ struct WireHelpers {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
}
return StructReader(
segment, ptr, reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()),
......@@ -1507,7 +1463,6 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(ListReader readListPointer(
SegmentReader* segment, const WirePointer* ref, const word* defaultValue,
FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
......@@ -1516,14 +1471,15 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment);
const word* ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported error.
goto useDefault;
......@@ -1533,10 +1489,6 @@ struct WireHelpers {
"Message contains non-list pointer where list pointer was expected.") {
goto useDefault;
}
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
}
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
decltype(WORDS/ELEMENTS) wordsPerElement;
......@@ -1548,8 +1500,7 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
......@@ -1608,17 +1559,6 @@ struct WireHelpers {
break;
}
} else {
// Trusted message.
// This logic is equivalent to the other branch, above, but skipping all the checks.
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (expectedElementSize == FieldSize::POINTER) {
ptr += tag->structRef.dataSize.get();
}
}
return ListReader(
segment, ptr, size, wordsPerElement * BITS_PER_WORD,
tag->structRef.dataSize.get() * BITS_PER_WORD,
......@@ -1632,15 +1572,12 @@ struct WireHelpers {
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
}
if (segment != nullptr) {
// Verify that the elements are at least as large as the expected type. Note that if we
// expected INLINE_COMPOSITE, the expected sizes here will be zero, because bounds checking
// will be performed at field access time. So this check here is for the case where we
......@@ -1659,7 +1596,6 @@ struct WireHelpers {
"Message contained list with incompatible element type.") {
goto useDefault;
}
}
return ListReader(segment, ptr, ref->listRef.elementCount(), step,
dataSize, pointerCount, nestingLimit - 1);
......@@ -1671,15 +1607,8 @@ struct WireHelpers {
const void* defaultValue, ByteCount defaultSize)) {
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
defaultValue = "";
}
if (defaultValue == nullptr) defaultValue = "";
return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Text::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS - 1);
} else {
const word* ptr = followFars(ref, segment);
......@@ -1700,7 +1629,7 @@ struct WireHelpers {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") {
goto useDefault;
......@@ -1727,10 +1656,6 @@ struct WireHelpers {
if (ref == nullptr || ref->isNull()) {
useDefault:
return Data::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES);
} else if (segment == nullptr) {
// Trusted message.
return Data::Reader(reinterpret_cast<const char*>(ref->target()),
ref->listRef.elementCount() / ELEMENTS);
} else {
const word* ptr = followFars(ref, segment);
......@@ -1751,7 +1676,7 @@ struct WireHelpers {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") {
goto useDefault;
......@@ -1771,7 +1696,6 @@ struct WireHelpers {
// Not always-inline because it is called from several places in the copying code, and anyway
// is relatively rarely used.
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr ||
......@@ -1780,30 +1704,26 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
ptr = WireHelpers::followFars(ref, segment);
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
const word* ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
} else {
ptr = ref->target();
}
switch (ref->kind()) {
case WirePointer::STRUCT:
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
}
return ObjectReader(
StructReader(segment, ptr,
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()),
......@@ -1813,21 +1733,17 @@ struct WireHelpers {
case WirePointer::LIST: {
FieldSize elementSize = ref->listRef.elementSize();
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
}
if (elementSize == FieldSize::INLINE_COMPOSITE) {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr - POINTER_SIZE_IN_WORDS,
ptr + wordCount),
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
......@@ -1836,15 +1752,12 @@ struct WireHelpers {
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
}
ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.");
}
return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD,
......@@ -1857,12 +1770,10 @@ struct WireHelpers {
ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + wordCount),
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
}
return ObjectReader(
ListReader(segment, ptr, elementCount, step, dataSize, pointerCount,
......@@ -1997,7 +1908,7 @@ StructReader StructReader::readRootTrusted(const word* location) {
StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) {
VALIDATE_INPUT(segment->containsInterval(location, location + POINTER_SIZE_IN_WORDS),
VALIDATE_INPUT(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
......@@ -2220,7 +2131,7 @@ Data::Reader ListReader::asData() {
}
StructReader ListReader::getStructElement(ElementCount index) const {
VALIDATE_INPUT((segment == nullptr) | (nestingLimit > 0),
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader();
}
......
......@@ -416,7 +416,7 @@ class StructReader {
public:
inline StructReader()
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0) {}
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
static StructReader readRootTrusted(const word* location);
static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit);
......@@ -609,7 +609,7 @@ class ListReader {
public:
inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS),
structDataSize(0), structPointerCount(0), nestingLimit(0) {}
structDataSize(0), structPointerCount(0), nestingLimit(0x7fffffff) {}
inline ElementCount size() const;
// The number of elements in the list.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment