Commit 2ca8e411 authored by Kenton Varda's avatar Kenton Varda

Refactor bounds checks to avoid ever creating out-of-bounds pointer values,…

Refactor bounds checks to avoid ever creating out-of-bounds pointer values, which is technically UB even if not dereferenced.
parent 52bc9564
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <vector> #include <vector>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#if !CAPNP_LITE #if !CAPNP_LITE
#include "capability.h" #include "capability.h"
...@@ -48,6 +49,12 @@ void ReadLimiter::unread(WordCount64 amount) { ...@@ -48,6 +49,12 @@ void ReadLimiter::unread(WordCount64 amount) {
} }
} }
void SegmentReader::abortCheckObjectFault() {
KJ_LOG(FATAL, "checkObject()'s parameter is not in-range; this would segfault in opt mode",
"this is a serious bug in Cap'n Proto; please notify security@sandstorm.io");
abort();
}
void SegmentBuilder::throwNotWritable() { void SegmentBuilder::throwNotWritable() {
KJ_FAIL_REQUIRE( KJ_FAIL_REQUIRE(
"Tried to form a Builder to an external data segment referenced by the MessageBuilder. " "Tried to form a Builder to an external data segment referenced by the MessageBuilder. "
......
...@@ -117,7 +117,20 @@ public: ...@@ -117,7 +117,20 @@ public:
inline SegmentReader(Arena* arena, SegmentId id, const word* ptr, SegmentWordCount size, inline SegmentReader(Arena* arena, SegmentId id, const word* ptr, SegmentWordCount size,
ReadLimiter* readLimiter); ReadLimiter* readLimiter);
KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to)); KJ_ALWAYS_INLINE(const word* checkOffset(const word* from, ptrdiff_t offset));
// Adds the given offset to the given pointer, checks that it is still within the bounds of the
// segment, then returns it. Note that the "end" pointer of the segment (which technically points
// to the word after the last in the segment) is considered in-bounds for this purpose, so you
// can't necessarily dereference it. You must call checkObject() next to check that the object
// you want to read is entirely in-bounds.
//
// If `from + offset` is out-of-range, this returns a pointer to the end of the segment. Thus,
// any non-zero-sized object will fail `checkObject()`. We do this instead of throwing to save
// some code footprint.
KJ_ALWAYS_INLINE(bool checkObject(const word* start, WordCountN<31> size));
// Assuming that `start` is in-bounds for this segment (probably checked using `checkOffset()`),
// check that `start + size` is also in-bounds, and hence the whole area in-between is valid.
KJ_ALWAYS_INLINE(bool amplifiedRead(WordCount virtualAmount)); KJ_ALWAYS_INLINE(bool amplifiedRead(WordCount virtualAmount));
// Indicates that the reader should pretend that `virtualAmount` additional data was read even // Indicates that the reader should pretend that `virtualAmount` additional data was read even
...@@ -147,6 +160,9 @@ private: ...@@ -147,6 +160,9 @@ private:
KJ_DISALLOW_COPY(SegmentReader); KJ_DISALLOW_COPY(SegmentReader);
friend class SegmentBuilder; friend class SegmentBuilder;
static void abortCheckObjectFault();
// Called in debug mode in cases that would segfault in opt mode. (Should be impossible!)
}; };
class SegmentBuilder: public SegmentReader { class SegmentBuilder: public SegmentReader {
...@@ -367,18 +383,25 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, const word* ptr, ...@@ -367,18 +383,25 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, const word* ptr,
: arena(arena), id(id), ptr(kj::arrayPtr(ptr, unbound(size / WORDS))), : arena(arena), id(id), ptr(kj::arrayPtr(ptr, unbound(size / WORDS))),
readLimiter(readLimiter) {} readLimiter(readLimiter) {}
inline bool SegmentReader::containsInterval(const void* from, const void* to) { inline const word* SegmentReader::checkOffset(const word* from, ptrdiff_t offset) {
uintptr_t start = reinterpret_cast<uintptr_t>(from) - reinterpret_cast<uintptr_t>(ptr.begin()); ptrdiff_t min = ptr.begin() - from;
uintptr_t end = reinterpret_cast<uintptr_t>(to) - reinterpret_cast<uintptr_t>(ptr.begin()); ptrdiff_t max = ptr.end() - from;
uintptr_t bound = ptr.size() * sizeof(capnp::word); if (offset >= min && offset <= max) {
return from + offset;
return start <= bound && end <= bound && start <= end && } else {
readLimiter->canRead( return ptr.end();
intervalLength(reinterpret_cast<const byte*>(from), }
reinterpret_cast<const byte*>(to), }
MAX_SEGMENT_WORDS * BYTES_PER_WORD)
/ BYTES_PER_WORD, inline bool SegmentReader::checkObject(const word* start, WordCountN<31> size) {
arena); auto startOffset = intervalLength(ptr.begin(), start, MAX_SEGMENT_WORDS);
#ifdef KJ_DEBUG
if (startOffset > bounded(ptr.size()) * WORDS) {
abortCheckObjectFault();
}
#endif
return startOffset + size <= bounded(ptr.size()) * WORDS &&
readLimiter->canRead(size, arena);
} }
inline bool SegmentReader::amplifiedRead(WordCount virtualAmount) { inline bool SegmentReader::amplifiedRead(WordCount virtualAmount) {
......
...@@ -122,9 +122,14 @@ struct WirePointer { ...@@ -122,9 +122,14 @@ struct WirePointer {
KJ_ALWAYS_INLINE(word* target()) { KJ_ALWAYS_INLINE(word* target()) {
return reinterpret_cast<word*>(this) + 1 + (static_cast<int32_t>(offsetAndKind.get()) >> 2); return reinterpret_cast<word*>(this) + 1 + (static_cast<int32_t>(offsetAndKind.get()) >> 2);
} }
KJ_ALWAYS_INLINE(const word* target() const) { KJ_ALWAYS_INLINE(const word* target(SegmentReader* segment) const) {
return reinterpret_cast<const word*>(this) + 1 + if (segment == nullptr) {
return reinterpret_cast<const word*>(this + 1) +
(static_cast<int32_t>(offsetAndKind.get()) >> 2); (static_cast<int32_t>(offsetAndKind.get()) >> 2);
} else {
return segment->checkOffset(reinterpret_cast<const word*>(this + 1),
static_cast<int32_t>(offsetAndKind.get()) >> 2);
}
} }
KJ_ALWAYS_INLINE(void setKindAndTarget(Kind kind, word* target, SegmentBuilder* segment)) { KJ_ALWAYS_INLINE(void setKindAndTarget(Kind kind, word* target, SegmentBuilder* segment)) {
// Check that the target is really in the same segment, otherwise subtracting pointers is // Check that the target is really in the same segment, otherwise subtracting pointers is
...@@ -140,9 +145,14 @@ struct WirePointer { ...@@ -140,9 +145,14 @@ struct WirePointer {
// So now when the pointers are not aligned the same, we can end up corrupting the bottom // So now when the pointers are not aligned the same, we can end up corrupting the bottom
// two bits, where `kind` is stored. For example, this turns a struct into a far pointer. // two bits, where `kind` is stored. For example, this turns a struct into a far pointer.
// Ouch! // Ouch!
KJ_DREQUIRE(segment->containsInterval( KJ_DREQUIRE(reinterpret_cast<uintptr_t>(this) >=
reinterpret_cast<word*>(this), reinterpret_cast<word*>(this + 1))); reinterpret_cast<uintptr_t>(segment->getStartPtr()));
KJ_DREQUIRE(segment->containsInterval(target, target)); KJ_DREQUIRE(reinterpret_cast<uintptr_t>(this) <
reinterpret_cast<uintptr_t>(segment->getStartPtr() + segment->getSize()));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(target) >=
reinterpret_cast<uintptr_t>(segment->getStartPtr()));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(target) <=
reinterpret_cast<uintptr_t>(segment->getStartPtr() + segment->getSize()));
offsetAndKind.set(((target - reinterpret_cast<word*>(this) - 1) << 2) | kind); offsetAndKind.set(((target - reinterpret_cast<word*>(this) - 1) << 2) | kind);
} }
KJ_ALWAYS_INLINE(void setKindWithZeroOffset(Kind kind)) { KJ_ALWAYS_INLINE(void setKindWithZeroOffset(Kind kind)) {
...@@ -174,15 +184,20 @@ struct WirePointer { ...@@ -174,15 +184,20 @@ struct WirePointer {
offsetAndKind.set(unboundAs<uint32_t>((elementCount / ELEMENTS) << G(2)) | kind); offsetAndKind.set(unboundAs<uint32_t>((elementCount / ELEMENTS) << G(2)) | kind);
} }
KJ_ALWAYS_INLINE(SegmentWordCount farPositionInSegment() const) { KJ_ALWAYS_INLINE(const word* farTarget(SegmentReader* segment) const) {
KJ_DREQUIRE(kind() == FAR,
"farTarget() should only be called on FAR pointers.");
return segment->checkOffset(segment->getStartPtr(), offsetAndKind.get() >> 3);
}
KJ_ALWAYS_INLINE(word* farTarget(SegmentBuilder* segment) const) {
KJ_DREQUIRE(kind() == FAR, KJ_DREQUIRE(kind() == FAR,
"positionInSegment() should only be called on FAR pointers."); "farTarget() should only be called on FAR pointers.");
return (bounded(offsetAndKind.get()) >> G(3)) * WORDS; return segment->getPtrUnchecked((bounded(offsetAndKind.get()) >> G(3)) * WORDS);
} }
KJ_ALWAYS_INLINE(bool isDoubleFar() const) { KJ_ALWAYS_INLINE(bool isDoubleFar() const) {
KJ_DREQUIRE(kind() == FAR, KJ_DREQUIRE(kind() == FAR,
"isDoubleFar() should only be called on FAR pointers."); "isDoubleFar() should only be called on FAR pointers.");
return unbound((bounded(offsetAndKind.get()) >> G(2)) & G(1)); return (offsetAndKind.get() >> 2) & 1;
} }
KJ_ALWAYS_INLINE(void setFar(bool isDoubleFar, WordCountN<29> pos)) { KJ_ALWAYS_INLINE(void setFar(bool isDoubleFar, WordCountN<29> pos)) {
offsetAndKind.set(unboundAs<uint32_t>((pos / WORDS) << G(3)) | offsetAndKind.set(unboundAs<uint32_t>((pos / WORDS) << G(3)) |
...@@ -403,9 +418,9 @@ struct WireHelpers { ...@@ -403,9 +418,9 @@ struct WireHelpers {
} }
static KJ_ALWAYS_INLINE(bool boundsCheck( static KJ_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) { SegmentReader* segment, const word* start, WordCountN<31> size)) {
// If segment is null, this is an unchecked message, so we don't do bounds checks. // If segment is null, this is an unchecked message, so we don't do bounds checks.
return segment == nullptr || segment->containsInterval(start, end); return segment == nullptr || segment->checkObject(start, size);
} }
static KJ_ALWAYS_INLINE(bool amplifiedRead(SegmentReader* segment, WordCount virtualAmount)) { static KJ_ALWAYS_INLINE(bool amplifiedRead(SegmentReader* segment, WordCount virtualAmount)) {
...@@ -498,8 +513,7 @@ struct WireHelpers { ...@@ -498,8 +513,7 @@ struct WireHelpers {
if (ref->kind() == WirePointer::FAR) { if (ref->kind() == WirePointer::FAR) {
segment = segment->getArena()->getSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
WirePointer* pad = WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(segment));
reinterpret_cast<WirePointer*>(segment->getPtrUnchecked(ref->farPositionInSegment()));
if (!ref->isDoubleFar()) { if (!ref->isDoubleFar()) {
ref = pad; ref = pad;
return pad->target(); return pad->target();
...@@ -510,7 +524,7 @@ struct WireHelpers { ...@@ -510,7 +524,7 @@ struct WireHelpers {
ref = pad + 1; ref = pad + 1;
segment = segment->getArena()->getSegment(pad->farRef.segmentId.get()); segment = segment->getArena()->getSegment(pad->farRef.segmentId.get());
return segment->getPtrUnchecked(pad->farPositionInSegment()); return pad->farTarget(segment);
} else { } else {
return refTarget; return refTarget;
} }
...@@ -536,9 +550,9 @@ struct WireHelpers { ...@@ -536,9 +550,9 @@ struct WireHelpers {
} }
// Find the landing pad and check that it is within bounds. // Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment(); const word* ptr = ref->farTarget(segment);
WordCount padWords = bounded(1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS; auto padWords = (ONE + bounded(ref->isDoubleFar())) * POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + padWords), KJ_REQUIRE(boundsCheck(segment, ptr, padWords),
"Message contains out-of-bounds far pointer.") { "Message contains out-of-bounds far pointer.") {
return nullptr; return nullptr;
} }
...@@ -548,7 +562,7 @@ struct WireHelpers { ...@@ -548,7 +562,7 @@ struct WireHelpers {
// If this is not a double-far then the landing pad is our final pointer. // If this is not a double-far then the landing pad is our final pointer.
if (!ref->isDoubleFar()) { if (!ref->isDoubleFar()) {
ref = pad; ref = pad;
return pad->target(); return pad->target(segment);
} }
// Landing pad is another far pointer. It is followed by a tag describing the pointed-to // Landing pad is another far pointer. It is followed by a tag describing the pointed-to
...@@ -566,7 +580,7 @@ struct WireHelpers { ...@@ -566,7 +580,7 @@ struct WireHelpers {
} }
segment = newSegment; segment = newSegment;
return segment->getStartPtr() + pad->farPositionInSegment(); return pad->farTarget(segment);
} else { } else {
return refTarget; return refTarget;
} }
...@@ -589,14 +603,12 @@ struct WireHelpers { ...@@ -589,14 +603,12 @@ struct WireHelpers {
case WirePointer::FAR: { case WirePointer::FAR: {
segment = segment->getArena()->getSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
if (segment->isWritable()) { // Don't zero external data. if (segment->isWritable()) { // Don't zero external data.
WirePointer* pad = WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(segment));
reinterpret_cast<WirePointer*>(segment->getPtrUnchecked(ref->farPositionInSegment()));
if (ref->isDoubleFar()) { if (ref->isDoubleFar()) {
segment = segment->getArena()->getSegment(pad->farRef.segmentId.get()); segment = segment->getArena()->getSegment(pad->farRef.segmentId.get());
if (segment->isWritable()) { if (segment->isWritable()) {
zeroObject(segment, capTable, zeroObject(segment, capTable, pad + 1, pad->farTarget(segment));
pad + 1, segment->getPtrUnchecked(pad->farPositionInSegment()));
} }
zeroMemory(pad, G(2) * POINTERS); zeroMemory(pad, G(2) * POINTERS);
} else { } else {
...@@ -712,8 +724,7 @@ struct WireHelpers { ...@@ -712,8 +724,7 @@ struct WireHelpers {
if (ref->kind() == WirePointer::FAR) { if (ref->kind() == WirePointer::FAR) {
SegmentBuilder* padSegment = segment->getArena()->getSegment(ref->farRef.segmentId.get()); SegmentBuilder* padSegment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
if (padSegment->isWritable()) { // Don't zero external data. if (padSegment->isWritable()) { // Don't zero external data.
WirePointer* pad = reinterpret_cast<WirePointer*>( WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(padSegment));
padSegment->getPtrUnchecked(ref->farPositionInSegment()));
if (ref->isDoubleFar()) { if (ref->isDoubleFar()) {
zeroMemory(pad, G(2) * POINTERS); zeroMemory(pad, G(2) * POINTERS);
} else { } else {
...@@ -743,11 +754,11 @@ struct WireHelpers { ...@@ -743,11 +754,11 @@ struct WireHelpers {
} }
--nestingLimit; --nestingLimit;
const word* ptr = followFars(ref, ref->target(), segment); const word* ptr = followFars(ref, ref->target(segment), segment);
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
return result; return result;
} }
...@@ -773,7 +784,7 @@ struct WireHelpers { ...@@ -773,7 +784,7 @@ struct WireHelpers {
auto totalWords = roundBitsUpToWords( auto totalWords = roundBitsUpToWords(
upgradeBound<uint64_t>(ref->listRef.elementCount()) * upgradeBound<uint64_t>(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize())); dataBitsPerElement(ref->listRef.elementSize()));
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords), KJ_REQUIRE(boundsCheck(segment, ptr, totalWords),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
...@@ -781,9 +792,9 @@ struct WireHelpers { ...@@ -781,9 +792,9 @@ struct WireHelpers {
break; break;
} }
case ElementSize::POINTER: { case ElementSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); auto count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER), KJ_REQUIRE(boundsCheck(segment, ptr, count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
...@@ -798,7 +809,7 @@ struct WireHelpers { ...@@ -798,7 +809,7 @@ struct WireHelpers {
} }
case ElementSize::INLINE_COMPOSITE: { case ElementSize::INLINE_COMPOSITE: {
auto wordCount = ref->listRef.inlineCompositeWordCount(); auto wordCount = ref->listRef.inlineCompositeWordCount();
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS), KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
...@@ -889,7 +900,7 @@ struct WireHelpers { ...@@ -889,7 +900,7 @@ struct WireHelpers {
zeroMemory(dst); zeroMemory(dst);
return nullptr; return nullptr;
} else { } else {
const word* srcPtr = src->target(); const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate( word* dstPtr = allocate(
dst, segment, capTable, src->structRef.wordSize(), WirePointer::STRUCT, nullptr); dst, segment, capTable, src->structRef.wordSize(), WirePointer::STRUCT, nullptr);
...@@ -911,7 +922,7 @@ struct WireHelpers { ...@@ -911,7 +922,7 @@ struct WireHelpers {
auto wordCount = roundBitsUpToWords( auto wordCount = roundBitsUpToWords(
upgradeBound<uint64_t>(src->listRef.elementCount()) * upgradeBound<uint64_t>(src->listRef.elementCount()) *
dataBitsPerElement(src->listRef.elementSize())); dataBitsPerElement(src->listRef.elementSize()));
const word* srcPtr = src->target(); const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate(dst, segment, capTable, wordCount, WirePointer::LIST, nullptr); word* dstPtr = allocate(dst, segment, capTable, wordCount, WirePointer::LIST, nullptr);
copyMemory(dstPtr, srcPtr, wordCount); copyMemory(dstPtr, srcPtr, wordCount);
...@@ -920,7 +931,7 @@ struct WireHelpers { ...@@ -920,7 +931,7 @@ struct WireHelpers {
} }
case ElementSize::POINTER: { case ElementSize::POINTER: {
const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target()); const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target(nullptr));
WirePointer* dstRefs = reinterpret_cast<WirePointer*>( WirePointer* dstRefs = reinterpret_cast<WirePointer*>(
allocate(dst, segment, capTable, src->listRef.elementCount() * allocate(dst, segment, capTable, src->listRef.elementCount() *
(ONE * POINTERS / ELEMENTS) * WORDS_PER_POINTER, (ONE * POINTERS / ELEMENTS) * WORDS_PER_POINTER,
...@@ -937,7 +948,7 @@ struct WireHelpers { ...@@ -937,7 +948,7 @@ struct WireHelpers {
} }
case ElementSize::INLINE_COMPOSITE: { case ElementSize::INLINE_COMPOSITE: {
const word* srcPtr = src->target(); const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate(dst, segment, capTable, word* dstPtr = allocate(dst, segment, capTable,
assertMaxBits<SEGMENT_WORD_COUNT_BITS>( assertMaxBits<SEGMENT_WORD_COUNT_BITS>(
src->listRef.inlineCompositeWordCount() + POINTER_SIZE_IN_WORDS, src->listRef.inlineCompositeWordCount() + POINTER_SIZE_IN_WORDS,
...@@ -1873,7 +1884,7 @@ struct WireHelpers { ...@@ -1873,7 +1884,7 @@ struct WireHelpers {
int nestingLimit, BuilderArena* orphanArena = nullptr, int nestingLimit, BuilderArena* orphanArena = nullptr,
bool canonical = false)) { bool canonical = false)) {
return copyPointer(dstSegment, dstCapTable, dst, return copyPointer(dstSegment, dstCapTable, dst,
srcSegment, srcCapTable, src, src->target(), srcSegment, srcCapTable, src, src->target(srcSegment),
nestingLimit, orphanArena, canonical); nestingLimit, orphanArena, canonical);
} }
...@@ -1908,7 +1919,7 @@ struct WireHelpers { ...@@ -1908,7 +1919,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + src->structRef.wordSize()), KJ_REQUIRE(boundsCheck(srcSegment, ptr, src->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1931,13 +1942,14 @@ struct WireHelpers { ...@@ -1931,13 +1942,14 @@ struct WireHelpers {
if (elementSize == ElementSize::INLINE_COMPOSITE) { if (elementSize == ElementSize::INLINE_COMPOSITE) {
auto wordCount = src->listRef.inlineCompositeWordCount(); auto wordCount = src->listRef.inlineCompositeWordCount();
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(srcSegment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount), KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(tag->kind() == WirePointer::STRUCT, KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
...@@ -1974,7 +1986,7 @@ struct WireHelpers { ...@@ -1974,7 +1986,7 @@ struct WireHelpers {
auto elementCount = src->listRef.elementCount(); auto elementCount = src->listRef.elementCount();
auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step); auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + wordCount), KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2086,7 +2098,8 @@ struct WireHelpers { ...@@ -2086,7 +2098,8 @@ struct WireHelpers {
SegmentReader* segment, CapTableReader* capTable, SegmentReader* segment, CapTableReader* capTable,
const WirePointer* ref, const word* defaultValue, const WirePointer* ref, const word* defaultValue,
int nestingLimit)) { int nestingLimit)) {
return readStructPointer(segment, capTable, ref, ref->target(), defaultValue, nestingLimit); return readStructPointer(segment, capTable, ref, ref->target(segment),
defaultValue, nestingLimit);
} }
static KJ_ALWAYS_INLINE(StructReader readStructPointer( static KJ_ALWAYS_INLINE(StructReader readStructPointer(
...@@ -2101,7 +2114,7 @@ struct WireHelpers { ...@@ -2101,7 +2114,7 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
refTarget = ref->target(); refTarget = ref->target(segment);
defaultValue = nullptr; // If the default value is itself invalid, don't use it again. defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} }
...@@ -2121,7 +2134,7 @@ struct WireHelpers { ...@@ -2121,7 +2134,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2169,7 +2182,7 @@ struct WireHelpers { ...@@ -2169,7 +2182,7 @@ struct WireHelpers {
SegmentReader* segment, CapTableReader* capTable, SegmentReader* segment, CapTableReader* capTable,
const WirePointer* ref, const word* defaultValue, const WirePointer* ref, const word* defaultValue,
ElementSize expectedElementSize, int nestingLimit, bool checkElementSize = true)) { ElementSize expectedElementSize, int nestingLimit, bool checkElementSize = true)) {
return readListPointer(segment, capTable, ref, ref->target(), defaultValue, return readListPointer(segment, capTable, ref, ref->target(segment), defaultValue,
expectedElementSize, nestingLimit, checkElementSize); expectedElementSize, nestingLimit, checkElementSize);
} }
...@@ -2186,7 +2199,7 @@ struct WireHelpers { ...@@ -2186,7 +2199,7 @@ struct WireHelpers {
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue); ref = reinterpret_cast<const WirePointer*>(defaultValue);
refTarget = ref->target(); refTarget = ref->target(segment);
defaultValue = nullptr; // If the default value is itself invalid, don't use it again. defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} }
...@@ -2212,13 +2225,14 @@ struct WireHelpers { ...@@ -2212,13 +2225,14 @@ struct WireHelpers {
// An INLINE_COMPOSITE list points to a tag, which is formatted like a pointer. // An INLINE_COMPOSITE list points to a tag, which is formatted like a pointer.
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(tag->kind() == WirePointer::STRUCT, KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
...@@ -2301,7 +2315,7 @@ struct WireHelpers { ...@@ -2301,7 +2315,7 @@ struct WireHelpers {
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step); auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr, wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2352,7 +2366,7 @@ struct WireHelpers { ...@@ -2352,7 +2366,7 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(Text::Reader readTextPointer( static KJ_ALWAYS_INLINE(Text::Reader readTextPointer(
SegmentReader* segment, const WirePointer* ref, SegmentReader* segment, const WirePointer* ref,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, ByteCount defaultSize)) {
return readTextPointer(segment, ref, ref->target(), defaultValue, defaultSize); return readTextPointer(segment, ref, ref->target(segment), defaultValue, defaultSize);
} }
static KJ_ALWAYS_INLINE(Text::Reader readTextPointer( static KJ_ALWAYS_INLINE(Text::Reader readTextPointer(
...@@ -2383,7 +2397,7 @@ struct WireHelpers { ...@@ -2383,7 +2397,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)), KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)),
"Message contained out-of-bounds text pointer.") { "Message contained out-of-bounds text pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2406,7 +2420,7 @@ struct WireHelpers { ...@@ -2406,7 +2420,7 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer( static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
SegmentReader* segment, const WirePointer* ref, SegmentReader* segment, const WirePointer* ref,
const void* defaultValue, BlobSize defaultSize)) { const void* defaultValue, BlobSize defaultSize)) {
return readDataPointer(segment, ref, ref->target(), defaultValue, defaultSize); return readDataPointer(segment, ref, ref->target(segment), defaultValue, defaultSize);
} }
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer( static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
...@@ -2436,7 +2450,7 @@ struct WireHelpers { ...@@ -2436,7 +2450,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)), KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)),
"Message contained out-of-bounds data pointer.") { "Message contained out-of-bounds data pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2608,7 +2622,7 @@ PointerBuilder PointerBuilder::imbue(CapTableBuilder* capTable) { ...@@ -2608,7 +2622,7 @@ PointerBuilder PointerBuilder::imbue(CapTableBuilder* capTable) {
PointerReader PointerReader::getRoot(SegmentReader* segment, CapTableReader* capTable, PointerReader PointerReader::getRoot(SegmentReader* segment, CapTableReader* capTable,
const word* location, int nestingLimit) { const word* location, int nestingLimit) {
KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS), KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") { "Root location out-of-bounds.") {
location = nullptr; location = nullptr;
} }
......
...@@ -95,7 +95,7 @@ AnyPointer::Reader MessageReader::getRootInternal() { ...@@ -95,7 +95,7 @@ AnyPointer::Reader MessageReader::getRootInternal() {
_::SegmentReader* segment = arena()->tryGetSegment(_::SegmentId(0)); _::SegmentReader* segment = arena()->tryGetSegment(_::SegmentId(0));
KJ_REQUIRE(segment != nullptr && KJ_REQUIRE(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1), segment->checkObject(segment->getStartPtr(), ONE * WORDS),
"Message did not contain a root pointer.") { "Message did not contain a root pointer.") {
return AnyPointer::Reader(); return AnyPointer::Reader();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment