Commit 2ca8e411 authored by Kenton Varda's avatar Kenton Varda

Refactor bounds checks to avoid ever creating out-of-bounds pointer values,…

Refactor bounds checks to avoid ever creating out-of-bounds pointer values, which is technically UB even if not dereferenced.
parent 52bc9564
......@@ -27,6 +27,7 @@
#include <vector>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#if !CAPNP_LITE
#include "capability.h"
......@@ -48,6 +49,12 @@ void ReadLimiter::unread(WordCount64 amount) {
}
}
void SegmentReader::abortCheckObjectFault() {
KJ_LOG(FATAL, "checkObject()'s parameter is not in-range; this would segfault in opt mode",
"this is a serious bug in Cap'n Proto; please notify security@sandstorm.io");
abort();
}
void SegmentBuilder::throwNotWritable() {
KJ_FAIL_REQUIRE(
"Tried to form a Builder to an external data segment referenced by the MessageBuilder. "
......
......@@ -117,7 +117,20 @@ public:
inline SegmentReader(Arena* arena, SegmentId id, const word* ptr, SegmentWordCount size,
ReadLimiter* readLimiter);
KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to));
KJ_ALWAYS_INLINE(const word* checkOffset(const word* from, ptrdiff_t offset));
// Adds the given offset to the given pointer, checks that it is still within the bounds of the
// segment, then returns it. Note that the "end" pointer of the segment (which technically points
// to the word after the last in the segment) is considered in-bounds for this purpose, so you
// can't necessarily dereference it. You must call checkObject() next to check that the object
// you want to read is entirely in-bounds.
//
// If `from + offset` is out-of-range, this returns a pointer to the end of the segment. Thus,
// any non-zero-sized object will fail `checkObject()`. We do this instead of throwing to save
// some code footprint.
KJ_ALWAYS_INLINE(bool checkObject(const word* start, WordCountN<31> size));
// Assuming that `start` is in-bounds for this segment (probably checked using `checkOffset()`),
// check that `start + size` is also in-bounds, and hence the whole area in-between is valid.
KJ_ALWAYS_INLINE(bool amplifiedRead(WordCount virtualAmount));
// Indicates that the reader should pretend that `virtualAmount` additional data was read even
......@@ -147,6 +160,9 @@ private:
KJ_DISALLOW_COPY(SegmentReader);
friend class SegmentBuilder;
static void abortCheckObjectFault();
// Called in debug mode in cases that would segfault in opt mode. (Should be impossible!)
};
class SegmentBuilder: public SegmentReader {
......@@ -367,18 +383,25 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, const word* ptr,
: arena(arena), id(id), ptr(kj::arrayPtr(ptr, unbound(size / WORDS))),
readLimiter(readLimiter) {}
inline bool SegmentReader::containsInterval(const void* from, const void* to) {
uintptr_t start = reinterpret_cast<uintptr_t>(from) - reinterpret_cast<uintptr_t>(ptr.begin());
uintptr_t end = reinterpret_cast<uintptr_t>(to) - reinterpret_cast<uintptr_t>(ptr.begin());
uintptr_t bound = ptr.size() * sizeof(capnp::word);
return start <= bound && end <= bound && start <= end &&
readLimiter->canRead(
intervalLength(reinterpret_cast<const byte*>(from),
reinterpret_cast<const byte*>(to),
MAX_SEGMENT_WORDS * BYTES_PER_WORD)
/ BYTES_PER_WORD,
arena);
inline const word* SegmentReader::checkOffset(const word* from, ptrdiff_t offset) {
ptrdiff_t min = ptr.begin() - from;
ptrdiff_t max = ptr.end() - from;
if (offset >= min && offset <= max) {
return from + offset;
} else {
return ptr.end();
}
}
inline bool SegmentReader::checkObject(const word* start, WordCountN<31> size) {
auto startOffset = intervalLength(ptr.begin(), start, MAX_SEGMENT_WORDS);
#ifdef KJ_DEBUG
if (startOffset > bounded(ptr.size()) * WORDS) {
abortCheckObjectFault();
}
#endif
return startOffset + size <= bounded(ptr.size()) * WORDS &&
readLimiter->canRead(size, arena);
}
inline bool SegmentReader::amplifiedRead(WordCount virtualAmount) {
......
......@@ -122,9 +122,14 @@ struct WirePointer {
KJ_ALWAYS_INLINE(word* target()) {
return reinterpret_cast<word*>(this) + 1 + (static_cast<int32_t>(offsetAndKind.get()) >> 2);
}
KJ_ALWAYS_INLINE(const word* target() const) {
return reinterpret_cast<const word*>(this) + 1 +
KJ_ALWAYS_INLINE(const word* target(SegmentReader* segment) const) {
if (segment == nullptr) {
return reinterpret_cast<const word*>(this + 1) +
(static_cast<int32_t>(offsetAndKind.get()) >> 2);
} else {
return segment->checkOffset(reinterpret_cast<const word*>(this + 1),
static_cast<int32_t>(offsetAndKind.get()) >> 2);
}
}
KJ_ALWAYS_INLINE(void setKindAndTarget(Kind kind, word* target, SegmentBuilder* segment)) {
// Check that the target is really in the same segment, otherwise subtracting pointers is
......@@ -140,9 +145,14 @@ struct WirePointer {
// So now when the pointers are not aligned the same, we can end up corrupting the bottom
// two bits, where `kind` is stored. For example, this turns a struct into a far pointer.
// Ouch!
KJ_DREQUIRE(segment->containsInterval(
reinterpret_cast<word*>(this), reinterpret_cast<word*>(this + 1)));
KJ_DREQUIRE(segment->containsInterval(target, target));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(this) >=
reinterpret_cast<uintptr_t>(segment->getStartPtr()));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(this) <
reinterpret_cast<uintptr_t>(segment->getStartPtr() + segment->getSize()));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(target) >=
reinterpret_cast<uintptr_t>(segment->getStartPtr()));
KJ_DREQUIRE(reinterpret_cast<uintptr_t>(target) <=
reinterpret_cast<uintptr_t>(segment->getStartPtr() + segment->getSize()));
offsetAndKind.set(((target - reinterpret_cast<word*>(this) - 1) << 2) | kind);
}
KJ_ALWAYS_INLINE(void setKindWithZeroOffset(Kind kind)) {
......@@ -174,15 +184,20 @@ struct WirePointer {
offsetAndKind.set(unboundAs<uint32_t>((elementCount / ELEMENTS) << G(2)) | kind);
}
KJ_ALWAYS_INLINE(SegmentWordCount farPositionInSegment() const) {
KJ_ALWAYS_INLINE(const word* farTarget(SegmentReader* segment) const) {
KJ_DREQUIRE(kind() == FAR,
"farTarget() should only be called on FAR pointers.");
return segment->checkOffset(segment->getStartPtr(), offsetAndKind.get() >> 3);
}
KJ_ALWAYS_INLINE(word* farTarget(SegmentBuilder* segment) const) {
KJ_DREQUIRE(kind() == FAR,
"positionInSegment() should only be called on FAR pointers.");
return (bounded(offsetAndKind.get()) >> G(3)) * WORDS;
"farTarget() should only be called on FAR pointers.");
return segment->getPtrUnchecked((bounded(offsetAndKind.get()) >> G(3)) * WORDS);
}
KJ_ALWAYS_INLINE(bool isDoubleFar() const) {
KJ_DREQUIRE(kind() == FAR,
"isDoubleFar() should only be called on FAR pointers.");
return unbound((bounded(offsetAndKind.get()) >> G(2)) & G(1));
return (offsetAndKind.get() >> 2) & 1;
}
KJ_ALWAYS_INLINE(void setFar(bool isDoubleFar, WordCountN<29> pos)) {
offsetAndKind.set(unboundAs<uint32_t>((pos / WORDS) << G(3)) |
......@@ -403,9 +418,9 @@ struct WireHelpers {
}
static KJ_ALWAYS_INLINE(bool boundsCheck(
SegmentReader* segment, const word* start, const word* end)) {
SegmentReader* segment, const word* start, WordCountN<31> size)) {
// If segment is null, this is an unchecked message, so we don't do bounds checks.
return segment == nullptr || segment->containsInterval(start, end);
return segment == nullptr || segment->checkObject(start, size);
}
static KJ_ALWAYS_INLINE(bool amplifiedRead(SegmentReader* segment, WordCount virtualAmount)) {
......@@ -498,8 +513,7 @@ struct WireHelpers {
if (ref->kind() == WirePointer::FAR) {
segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
WirePointer* pad =
reinterpret_cast<WirePointer*>(segment->getPtrUnchecked(ref->farPositionInSegment()));
WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(segment));
if (!ref->isDoubleFar()) {
ref = pad;
return pad->target();
......@@ -510,7 +524,7 @@ struct WireHelpers {
ref = pad + 1;
segment = segment->getArena()->getSegment(pad->farRef.segmentId.get());
return segment->getPtrUnchecked(pad->farPositionInSegment());
return pad->farTarget(segment);
} else {
return refTarget;
}
......@@ -536,9 +550,9 @@ struct WireHelpers {
}
// Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = bounded(1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + padWords),
const word* ptr = ref->farTarget(segment);
auto padWords = (ONE + bounded(ref->isDoubleFar())) * POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr, padWords),
"Message contains out-of-bounds far pointer.") {
return nullptr;
}
......@@ -548,7 +562,7 @@ struct WireHelpers {
// If this is not a double-far then the landing pad is our final pointer.
if (!ref->isDoubleFar()) {
ref = pad;
return pad->target();
return pad->target(segment);
}
// Landing pad is another far pointer. It is followed by a tag describing the pointed-to
......@@ -566,7 +580,7 @@ struct WireHelpers {
}
segment = newSegment;
return segment->getStartPtr() + pad->farPositionInSegment();
return pad->farTarget(segment);
} else {
return refTarget;
}
......@@ -589,14 +603,12 @@ struct WireHelpers {
case WirePointer::FAR: {
segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
if (segment->isWritable()) { // Don't zero external data.
WirePointer* pad =
reinterpret_cast<WirePointer*>(segment->getPtrUnchecked(ref->farPositionInSegment()));
WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(segment));
if (ref->isDoubleFar()) {
segment = segment->getArena()->getSegment(pad->farRef.segmentId.get());
if (segment->isWritable()) {
zeroObject(segment, capTable,
pad + 1, segment->getPtrUnchecked(pad->farPositionInSegment()));
zeroObject(segment, capTable, pad + 1, pad->farTarget(segment));
}
zeroMemory(pad, G(2) * POINTERS);
} else {
......@@ -712,8 +724,7 @@ struct WireHelpers {
if (ref->kind() == WirePointer::FAR) {
SegmentBuilder* padSegment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
if (padSegment->isWritable()) { // Don't zero external data.
WirePointer* pad = reinterpret_cast<WirePointer*>(
padSegment->getPtrUnchecked(ref->farPositionInSegment()));
WirePointer* pad = reinterpret_cast<WirePointer*>(ref->farTarget(padSegment));
if (ref->isDoubleFar()) {
zeroMemory(pad, G(2) * POINTERS);
} else {
......@@ -743,11 +754,11 @@ struct WireHelpers {
}
--nestingLimit;
const word* ptr = followFars(ref, ref->target(), segment);
const word* ptr = followFars(ref, ref->target(segment), segment);
switch (ref->kind()) {
case WirePointer::STRUCT: {
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
return result;
}
......@@ -773,7 +784,7 @@ struct WireHelpers {
auto totalWords = roundBitsUpToWords(
upgradeBound<uint64_t>(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords),
KJ_REQUIRE(boundsCheck(segment, ptr, totalWords),
"Message contained out-of-bounds list pointer.") {
return result;
}
......@@ -781,9 +792,9 @@ struct WireHelpers {
break;
}
case ElementSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
auto count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
KJ_REQUIRE(boundsCheck(segment, ptr, count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
return result;
}
......@@ -798,7 +809,7 @@ struct WireHelpers {
}
case ElementSize::INLINE_COMPOSITE: {
auto wordCount = ref->listRef.inlineCompositeWordCount();
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
return result;
}
......@@ -889,7 +900,7 @@ struct WireHelpers {
zeroMemory(dst);
return nullptr;
} else {
const word* srcPtr = src->target();
const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate(
dst, segment, capTable, src->structRef.wordSize(), WirePointer::STRUCT, nullptr);
......@@ -911,7 +922,7 @@ struct WireHelpers {
auto wordCount = roundBitsUpToWords(
upgradeBound<uint64_t>(src->listRef.elementCount()) *
dataBitsPerElement(src->listRef.elementSize()));
const word* srcPtr = src->target();
const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate(dst, segment, capTable, wordCount, WirePointer::LIST, nullptr);
copyMemory(dstPtr, srcPtr, wordCount);
......@@ -920,7 +931,7 @@ struct WireHelpers {
}
case ElementSize::POINTER: {
const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target());
const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target(nullptr));
WirePointer* dstRefs = reinterpret_cast<WirePointer*>(
allocate(dst, segment, capTable, src->listRef.elementCount() *
(ONE * POINTERS / ELEMENTS) * WORDS_PER_POINTER,
......@@ -937,7 +948,7 @@ struct WireHelpers {
}
case ElementSize::INLINE_COMPOSITE: {
const word* srcPtr = src->target();
const word* srcPtr = src->target(nullptr);
word* dstPtr = allocate(dst, segment, capTable,
assertMaxBits<SEGMENT_WORD_COUNT_BITS>(
src->listRef.inlineCompositeWordCount() + POINTER_SIZE_IN_WORDS,
......@@ -1873,7 +1884,7 @@ struct WireHelpers {
int nestingLimit, BuilderArena* orphanArena = nullptr,
bool canonical = false)) {
return copyPointer(dstSegment, dstCapTable, dst,
srcSegment, srcCapTable, src, src->target(),
srcSegment, srcCapTable, src, src->target(srcSegment),
nestingLimit, orphanArena, canonical);
}
......@@ -1908,7 +1919,7 @@ struct WireHelpers {
goto useDefault;
}
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + src->structRef.wordSize()),
KJ_REQUIRE(boundsCheck(srcSegment, ptr, src->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
......@@ -1931,13 +1942,14 @@ struct WireHelpers {
if (elementSize == ElementSize::INLINE_COMPOSITE) {
auto wordCount = src->listRef.inlineCompositeWordCount();
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(srcSegment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
......@@ -1974,7 +1986,7 @@ struct WireHelpers {
auto elementCount = src->listRef.elementCount();
auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + wordCount),
KJ_REQUIRE(boundsCheck(srcSegment, ptr, wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
......@@ -2086,7 +2098,8 @@ struct WireHelpers {
SegmentReader* segment, CapTableReader* capTable,
const WirePointer* ref, const word* defaultValue,
int nestingLimit)) {
return readStructPointer(segment, capTable, ref, ref->target(), defaultValue, nestingLimit);
return readStructPointer(segment, capTable, ref, ref->target(segment),
defaultValue, nestingLimit);
}
static KJ_ALWAYS_INLINE(StructReader readStructPointer(
......@@ -2101,7 +2114,7 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
refTarget = ref->target();
refTarget = ref->target(segment);
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
......@@ -2121,7 +2134,7 @@ struct WireHelpers {
goto useDefault;
}
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
KJ_REQUIRE(boundsCheck(segment, ptr, ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
......@@ -2169,7 +2182,7 @@ struct WireHelpers {
SegmentReader* segment, CapTableReader* capTable,
const WirePointer* ref, const word* defaultValue,
ElementSize expectedElementSize, int nestingLimit, bool checkElementSize = true)) {
return readListPointer(segment, capTable, ref, ref->target(), defaultValue,
return readListPointer(segment, capTable, ref, ref->target(segment), defaultValue,
expectedElementSize, nestingLimit, checkElementSize);
}
......@@ -2186,7 +2199,7 @@ struct WireHelpers {
}
segment = nullptr;
ref = reinterpret_cast<const WirePointer*>(defaultValue);
refTarget = ref->target();
refTarget = ref->target(segment);
defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
}
......@@ -2212,13 +2225,14 @@ struct WireHelpers {
// An INLINE_COMPOSITE list points to a tag, which is formatted like a pointer.
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
KJ_REQUIRE(boundsCheck(segment, ptr, wordCount + POINTER_SIZE_IN_WORDS),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
ptr += POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
......@@ -2301,7 +2315,7 @@ struct WireHelpers {
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
auto wordCount = roundBitsUpToWords(upgradeBound<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount),
KJ_REQUIRE(boundsCheck(segment, ptr, wordCount),
"Message contains out-of-bounds list pointer.") {
goto useDefault;
}
......@@ -2352,7 +2366,7 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(Text::Reader readTextPointer(
SegmentReader* segment, const WirePointer* ref,
const void* defaultValue, ByteCount defaultSize)) {
return readTextPointer(segment, ref, ref->target(), defaultValue, defaultSize);
return readTextPointer(segment, ref, ref->target(segment), defaultValue, defaultSize);
}
static KJ_ALWAYS_INLINE(Text::Reader readTextPointer(
......@@ -2383,7 +2397,7 @@ struct WireHelpers {
goto useDefault;
}
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)),
KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)),
"Message contained out-of-bounds text pointer.") {
goto useDefault;
}
......@@ -2406,7 +2420,7 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
SegmentReader* segment, const WirePointer* ref,
const void* defaultValue, BlobSize defaultSize)) {
return readDataPointer(segment, ref, ref->target(), defaultValue, defaultSize);
return readDataPointer(segment, ref, ref->target(segment), defaultValue, defaultSize);
}
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
......@@ -2436,7 +2450,7 @@ struct WireHelpers {
goto useDefault;
}
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)),
KJ_REQUIRE(boundsCheck(segment, ptr, roundBytesUpToWords(size)),
"Message contained out-of-bounds data pointer.") {
goto useDefault;
}
......@@ -2608,7 +2622,7 @@ PointerBuilder PointerBuilder::imbue(CapTableBuilder* capTable) {
PointerReader PointerReader::getRoot(SegmentReader* segment, CapTableReader* capTable,
const word* location, int nestingLimit) {
KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
......
......@@ -95,7 +95,7 @@ AnyPointer::Reader MessageReader::getRootInternal() {
_::SegmentReader* segment = arena()->tryGetSegment(_::SegmentId(0));
KJ_REQUIRE(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
segment->checkObject(segment->getStartPtr(), ONE * WORDS),
"Message did not contain a root pointer.") {
return AnyPointer::Reader();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment