Commit bb7d1dff authored by Kenton Varda's avatar Kenton Varda

Get rid of redundant ErrorReporter that goes in ReaderOptions. Registering an…

Get rid of redundant ErrorReporter that goes in ReaderOptions.  Registering an ExceptionCallback is a better approach.
parent 5dcbfe3a
......@@ -39,7 +39,6 @@ Arena::~Arena() {}
ReaderArena::ReaderArena(MessageReader* message)
: message(message),
readLimiter(message->getOptions().traversalLimitInWords * WORDS),
ignoreErrors(false),
segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {}
ReaderArena::~ReaderArena() {}
......@@ -79,21 +78,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
return slot->get();
}
void ReaderArena::reportInvalidData(const char* description) {
if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(description);
}
}
void ReaderArena::reportReadLimitReached() {
if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(
"Exceeded message traversal limit. See capnproto::ReaderOptions.");
// Ignore further errors since they are likely repeats or caused by the read limit being
// reached.
ignoreErrors = true;
}
FAIL_VALIDATE_INPUT("Exceeded message traversal limit. See capnproto::ReaderOptions.");
}
// =======================================================================================
......@@ -204,14 +190,9 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
}
}
void BuilderArena::reportInvalidData(const char* description) {
// TODO: Better error reporting.
fprintf(stderr, "BuilderArena: Parse error: %s\n", description);
}
void BuilderArena::reportReadLimitReached() {
// TODO: Better error reporting.
fputs("BuilderArena: Exceeded read limit.\n", stderr);
FAIL_RECOVERABLE_CHECK(
"Read limit reached for BuilderArena, but it should have been unlimited.") {}
}
} // namespace internal
......
......@@ -129,31 +129,10 @@ public:
virtual SegmentReader* tryGetSegment(SegmentId id) = 0;
// Gets the segment with the given ID, or return nullptr if no such segment exists.
virtual void reportInvalidData(const char* description) = 0;
// Called to report that the message data is invalid.
//
// Implementations should, ideally, report the error to the sender, if possible. They may also
// want to write a debug message, etc.
//
// Implementations may choose to throw an exception in order to cut short further processing of
// the message. If no exception is thrown, then the caller will attempt to work around the
// invalid data by using a default value instead. This is good enough to guard against
// maliciously-crafted messages (the sender could just as easily have sent a perfectly-valid
// message containing the default value), but in the case of accidentally-corrupted messages this
// behavior may propagate the corruption.
//
// TODO: Give more information about the error, e.g. the segment and offset at which the invalid
// data was encountered, any relevant type/field names if known, etc.
virtual void reportReadLimitReached() = 0;
// Called to report that the read limit has been reached. See ReadLimiter, below.
//
// As with reportInvalidData(), this may throw an exception, and if it doesn't, default values
// will be used in place of the actual message data.
//
// If this method returns rather that throwing, many other errors are likely to be reported as
// a side-effect of reading being blocked. The Arena should ignore all further errors
// after this call.
// Called to report that the read limit has been reached. See ReadLimiter, below. This invokes
// the VALIDATE_INPUT() macro which may throw an exception; if it return normally, the caller
// will need to continue with default values.
// TODO: Methods to deal with bundled capabilities.
};
......@@ -166,13 +145,11 @@ public:
// implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override;
private:
MessageReader* message;
ReadLimiter readLimiter;
bool ignoreErrors;
// Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0;
......@@ -203,7 +180,6 @@ public:
// implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override;
private:
......
......@@ -258,14 +258,15 @@ struct WireHelpers {
if (ref->kind() == WireReference::FAR) {
// Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) {
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") {
return nullptr;
}
// Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * REFERENCE_SIZE_IN_WORDS;
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + padWords))) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") {
return nullptr;
}
......@@ -282,7 +283,8 @@ struct WireHelpers {
ref = pad + 1;
segment = segment->getArena()->tryGetSegment(pad->farRef.segmentId.get());
if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) {
VALIDATE_INPUT(segment != nullptr,
"Message contains double-far pointer to unknown segment.") {
return nullptr;
}
......@@ -608,28 +610,24 @@ struct WireHelpers {
ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getArena()->reportInvalidData(
"Message contains invalid far reference.");
// Already reported the error.
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::STRUCT)) {
segment->getArena()->reportInvalidData(
"Message contains non-struct reference where struct reference was expected.");
VALIDATE_INPUT(ref->kind() == WireReference::STRUCT,
"Message contains non-struct reference where struct reference was expected.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + ref->structRef.wordSize()))){
segment->getArena()->reportInvalidData(
"Message contained out-of-bounds struct reference.");
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct reference.") {
goto useDefault;
}
} else {
......@@ -657,22 +655,19 @@ struct WireHelpers {
ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getArena()->reportInvalidData(
"Message contains invalid far reference.");
// Already reported error.
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getArena()->reportInvalidData(
"Message contains non-list reference where list reference was expected.");
VALIDATE_INPUT(ref->kind() == WireReference::LIST,
"Message contains non-list reference where list reference was expected.") {
goto useDefault;
}
} else {
......@@ -691,25 +686,21 @@ struct WireHelpers {
ptr += REFERENCE_SIZE_IN_WORDS;
if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(
ptr - REFERENCE_SIZE_IN_WORDS, ptr + wordCount))) {
segment->getArena()->reportInvalidData(
"Message contains out-of-bounds list reference.");
VALIDATE_INPUT(segment->containsInterval(ptr - REFERENCE_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list reference.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(tag->kind() != WireReference::STRUCT)) {
segment->getArena()->reportInvalidData(
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.");
VALIDATE_INPUT(tag->kind() == WireReference::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(size * wordsPerElement > wordCount)) {
segment->getArena()->reportInvalidData(
"INLINE_COMPOSITE list's elements overrun its word count.");
VALIDATE_INPUT(size * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault;
}
......@@ -719,10 +710,8 @@ struct WireHelpers {
// if it were a primitive list without branching.
// Check whether the size is compatible.
bool compatible = false;
switch (expectedElementSize) {
case FieldSize::VOID:
compatible = true;
break;
case FieldSize::BIT:
......@@ -730,7 +719,10 @@ struct WireHelpers {
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
compatible = tag->structRef.dataSize.get() > 0 * WORDS;
VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault;
}
break;
case FieldSize::REFERENCE:
......@@ -738,19 +730,16 @@ struct WireHelpers {
// in the struct is the reference we were looking for, we want to munge the pointer to
// point at the first element's reference segment.
ptr += tag->structRef.dataSize.get();
compatible = tag->structRef.refCount.get() > 0 * REFERENCES;
VALIDATE_INPUT(tag->structRef.refCount.get() > 0 * REFERENCES,
"Expected a pointer list, but got a list of data-only structs.") {
goto useDefault;
}
break;
case FieldSize::INLINE_COMPOSITE:
compatible = true;
break;
}
if (CAPNPROTO_EXPECT_FALSE(!compatible)) {
segment->getArena()->reportInvalidData("A list had incompatible element type.");
goto useDefault;
}
} else {
// Trusted message.
// This logic is equivalent to the other branch, above, but skipping all the checks.
......@@ -770,10 +759,9 @@ struct WireHelpers {
decltype(BITS/ELEMENTS) step = bitsPerElement(ref->listRef.elementSize());
if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)))) {
segment->getArena()->reportInvalidData(
"Message contained out-of-bounds list reference.");
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contained out-of-bounds list reference.") {
goto useDefault;
}
}
......@@ -818,7 +806,6 @@ struct WireHelpers {
dataSize, referenceCount, nestingLimit - 1);
} else {
PRECOND(segment != nullptr, "Trusted message had incompatible list element type.");
segment->getArena()->reportInvalidData("A list had incompatible element type.");
goto useDefault;
}
}
......@@ -839,39 +826,34 @@ struct WireHelpers {
ref->listRef.elementCount() / ELEMENTS - 1);
} else {
const word* ptr = followFars(ref, segment);
uint size = ref->listRef.elementCount() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getArena()->reportInvalidData(
"Message contains invalid far reference.");
// Already reported error.
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getArena()->reportInvalidData(
"Message contains non-list reference where text was expected.");
uint size = ref->listRef.elementCount() / ELEMENTS;
VALIDATE_INPUT(ref->kind() == WireReference::LIST,
"Message contains non-list reference where text was expected.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) {
segment->getArena()->reportInvalidData(
"Message contains list reference of non-bytes where text was expected.");
VALIDATE_INPUT(ref->listRef.elementSize() == FieldSize::BYTE,
"Message contains list reference of non-bytes where text was expected.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) {
segment->getArena()->reportInvalidData(
"Message contained out-of-bounds text reference.");
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text reference.") {
goto useDefault;
}
const char* cptr = reinterpret_cast<const char*>(ptr);
--size; // NUL terminator
if (CAPNPROTO_EXPECT_FALSE(cptr[size] != '\0')) {
segment->getArena()->reportInvalidData(
"Message contains text that is not NUL-terminated.");
VALIDATE_INPUT(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") {
goto useDefault;
}
......@@ -891,30 +873,27 @@ struct WireHelpers {
ref->listRef.elementCount() / ELEMENTS);
} else {
const word* ptr = followFars(ref, segment);
uint size = ref->listRef.elementCount() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getArena()->reportInvalidData(
"Message contains invalid far reference.");
// Already reported error.
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getArena()->reportInvalidData(
"Message contains non-list reference where data was expected.");
uint size = ref->listRef.elementCount() / ELEMENTS;
VALIDATE_INPUT(ref->kind() == WireReference::LIST,
"Message contains non-list reference where data was expected.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) {
segment->getArena()->reportInvalidData(
"Message contains list reference of non-bytes where data was expected.");
VALIDATE_INPUT(ref->listRef.elementSize() == FieldSize::BYTE,
"Message contains list reference of non-bytes where data was expected.") {
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) {
segment->getArena()->reportInvalidData(
"Message contained out-of-bounds data reference.");
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data reference.") {
goto useDefault;
}
......@@ -1009,8 +988,8 @@ StructReader StructReader::readRootTrusted(const word* location) {
StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) {
if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) {
segment->getArena()->reportInvalidData("Root location out-of-bounds.");
VALIDATE_INPUT(segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
......@@ -1115,18 +1094,17 @@ ListReader ListBuilder::asReader(WordCount dataSize, WireReferenceCount referenc
}
StructReader ListReader::getStructElement(ElementCount index) const {
if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (nestingLimit == 0))) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
VALIDATE_INPUT((segment == nullptr) | (nestingLimit > 0),
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader::readEmpty();
} else {
BitCount64 indexBit = ElementCount64(index) * stepBits;
const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE;
return StructReader(
segment, structPtr,
reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD),
structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE, nestingLimit - 1);
}
BitCount64 indexBit = ElementCount64(index) * stepBits;
const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE;
return StructReader(
segment, structPtr,
reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD),
structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE, nestingLimit - 1);
}
ListReader ListReader::getListElement(
......
......@@ -50,13 +50,13 @@ internal::StructReader MessageReader::getRootInternal() {
}
internal::SegmentReader* segment = arena()->tryGetSegment(SegmentId(0));
if (segment == nullptr ||
!segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1)) {
arena()->reportInvalidData("Message did not contain a root pointer.");
VALIDATE_INPUT(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
"Message did not contain a root pointer.") {
return internal::StructReader::readEmpty();
} else {
return internal::StructReader::readRoot(segment->getStartPtr(), segment, options.nestingLimit);
}
return internal::StructReader::readRoot(segment->getStartPtr(), segment, options.nestingLimit);
}
// -------------------------------------------------------------------
......@@ -111,53 +111,6 @@ ArrayPtr<const ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() {
// =======================================================================================
ErrorReporter::~ErrorReporter() {}
class ThrowingErrorReporter: public ErrorReporter {
public:
virtual ~ThrowingErrorReporter() {}
void reportError(const char* description) override {
FAIL_VALIDATE_INPUT("Invalid Cap'n Proto message", description);
}
};
ErrorReporter* getThrowingErrorReporter() {
static ThrowingErrorReporter instance;
return &instance;
}
class StderrErrorReporter: public ErrorReporter {
public:
virtual ~StderrErrorReporter() {}
void reportError(const char* description) override {
std::string message("ERROR: Cap'n Proto message was invalid: ");
message += description;
message += '\n';
write(STDERR_FILENO, message.data(), message.size());
}
};
ErrorReporter* getStderrErrorReporter() {
static StderrErrorReporter instance;
return &instance;
}
class IgnoringErrorReporter: public ErrorReporter {
public:
virtual ~IgnoringErrorReporter() {}
void reportError(const char* description) override {}
};
ErrorReporter* getIgnoringErrorReporter() {
static IgnoringErrorReporter instance;
return &instance;
}
// =======================================================================================
SegmentArrayMessageReader::SegmentArrayMessageReader(
ArrayPtr<const ArrayPtr<const word>> segments, ReaderOptions options)
: MessageReader(options), segments(segments) {}
......
......@@ -42,38 +42,6 @@ typedef Id<uint32_t, Segment> SegmentId;
// =======================================================================================
class ErrorReporter {
// Abstract interface for a class which receives notification of errors found in an input message.
public:
virtual ~ErrorReporter();
virtual void reportError(const char* description) = 0;
// Reports an error discovered while validating a message. This happens lazily, as the message
// is traversed. E.g., it can happen when a get() accessor is called for a sub-struct or list,
// and that object is found to be out-of-bounds or has the wrong type.
//
// This method can throw an exception. If it does not, then the getter that was called will
// return the default value. Returning a default value is sufficient to prevent invalid messages
// from being a security threat, since an attacker could always construct a valid message
// containing the default value to get the same effect. However, returning a default value is
// not ideal when handling messages that were accidentally corrupted -- it may lead to the wrong
// behavior, e.g. storing the wrong data to disk, which could cause further problems down the
// road. Therefore, throwing an exception is preferred -- if your code is exception-safe, of
// course.
};
ErrorReporter* getThrowingErrorReporter();
// Returns a singleton ErrorReporter which throws an exception (deriving from std::exception) on
// error.
ErrorReporter* getStderrErrorReporter();
// Returns a singleton ErrorReporter which prints a message to stderr on error, then replaces the
// invalid data with the default value.
ErrorReporter* getIgnoringErrorReporter();
// Returns a singleton ErrorReporter which silently replaces invalid data with its default value.
struct ReaderOptions {
// Options controlling how data is read.
......@@ -107,9 +75,6 @@ struct ReaderOptions {
// overflow by sending a very-deeply-nested (or even cyclic) message, without the message even
// being very large. The default limit of 64 is probably low enough to prevent any chance of
// stack overflow, yet high enough that it is never a problem in practice.
ErrorReporter* errorReporter = getThrowingErrorReporter();
// How to report errors.
};
class MessageReader {
......
......@@ -41,8 +41,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
uint segmentCount = table[0].get() + 1;
size_t offset = segmentCount / 2u + 1u;
if (array.size() < offset) {
options.errorReporter->reportError("Message ends prematurely in segment table.");
VALIDATE_INPUT(array.size() >= offset, "Message ends prematurely in segment table.") {
return;
}
......@@ -52,8 +51,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
uint segmentSize = table[1].get();
if (array.size() < offset + segmentSize) {
options.errorReporter->reportError("Message ends prematurely in first segment.");
VALIDATE_INPUT(array.size() >= offset + segmentSize,
"Message ends prematurely in first segment.") {
return;
}
......@@ -66,9 +65,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = table[i + 1].get();
if (array.size() < offset + segmentSize) {
VALIDATE_INPUT(array.size() >= offset + segmentSize, "Message ends prematurely.") {
moreSegments = nullptr;
options.errorReporter->reportError("Message ends prematurely.");
return;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment