Commit bb7d1dff authored by Kenton Varda's avatar Kenton Varda

Get rid of redundant ErrorReporter that goes in ReaderOptions. Registering an…

Get rid of redundant ErrorReporter that goes in ReaderOptions.  Registering an ExceptionCallback is a better approach.
parent 5dcbfe3a
...@@ -39,7 +39,6 @@ Arena::~Arena() {} ...@@ -39,7 +39,6 @@ Arena::~Arena() {}
ReaderArena::ReaderArena(MessageReader* message) ReaderArena::ReaderArena(MessageReader* message)
: message(message), : message(message),
readLimiter(message->getOptions().traversalLimitInWords * WORDS), readLimiter(message->getOptions().traversalLimitInWords * WORDS),
ignoreErrors(false),
segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {} segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {}
ReaderArena::~ReaderArena() {} ReaderArena::~ReaderArena() {}
...@@ -79,21 +78,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -79,21 +78,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
return slot->get(); return slot->get();
} }
void ReaderArena::reportInvalidData(const char* description) {
if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(description);
}
}
void ReaderArena::reportReadLimitReached() { void ReaderArena::reportReadLimitReached() {
if (!ignoreErrors) { FAIL_VALIDATE_INPUT("Exceeded message traversal limit. See capnproto::ReaderOptions.");
message->getOptions().errorReporter->reportError(
"Exceeded message traversal limit. See capnproto::ReaderOptions.");
// Ignore further errors since they are likely repeats or caused by the read limit being
// reached.
ignoreErrors = true;
}
} }
// ======================================================================================= // =======================================================================================
...@@ -204,14 +190,9 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) { ...@@ -204,14 +190,9 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
} }
} }
void BuilderArena::reportInvalidData(const char* description) {
// TODO: Better error reporting.
fprintf(stderr, "BuilderArena: Parse error: %s\n", description);
}
void BuilderArena::reportReadLimitReached() { void BuilderArena::reportReadLimitReached() {
// TODO: Better error reporting. FAIL_RECOVERABLE_CHECK(
fputs("BuilderArena: Exceeded read limit.\n", stderr); "Read limit reached for BuilderArena, but it should have been unlimited.") {}
} }
} // namespace internal } // namespace internal
......
...@@ -129,31 +129,10 @@ public: ...@@ -129,31 +129,10 @@ public:
virtual SegmentReader* tryGetSegment(SegmentId id) = 0; virtual SegmentReader* tryGetSegment(SegmentId id) = 0;
// Gets the segment with the given ID, or return nullptr if no such segment exists. // Gets the segment with the given ID, or return nullptr if no such segment exists.
virtual void reportInvalidData(const char* description) = 0;
// Called to report that the message data is invalid.
//
// Implementations should, ideally, report the error to the sender, if possible. They may also
// want to write a debug message, etc.
//
// Implementations may choose to throw an exception in order to cut short further processing of
// the message. If no exception is thrown, then the caller will attempt to work around the
// invalid data by using a default value instead. This is good enough to guard against
// maliciously-crafted messages (the sender could just as easily have sent a perfectly-valid
// message containing the default value), but in the case of accidentally-corrupted messages this
// behavior may propagate the corruption.
//
// TODO: Give more information about the error, e.g. the segment and offset at which the invalid
// data was encountered, any relevant type/field names if known, etc.
virtual void reportReadLimitReached() = 0; virtual void reportReadLimitReached() = 0;
// Called to report that the read limit has been reached. See ReadLimiter, below. // Called to report that the read limit has been reached. See ReadLimiter, below. This invokes
// // the VALIDATE_INPUT() macro which may throw an exception; if it return normally, the caller
// As with reportInvalidData(), this may throw an exception, and if it doesn't, default values // will need to continue with default values.
// will be used in place of the actual message data.
//
// If this method returns rather that throwing, many other errors are likely to be reported as
// a side-effect of reading being blocked. The Arena should ignore all further errors
// after this call.
// TODO: Methods to deal with bundled capabilities. // TODO: Methods to deal with bundled capabilities.
}; };
...@@ -166,13 +145,11 @@ public: ...@@ -166,13 +145,11 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
private: private:
MessageReader* message; MessageReader* message;
ReadLimiter readLimiter; ReadLimiter readLimiter;
bool ignoreErrors;
// Optimize for single-segment messages so that small messages are handled quickly. // Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0; SegmentReader segment0;
...@@ -203,7 +180,6 @@ public: ...@@ -203,7 +180,6 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
private: private:
......
This diff is collapsed.
...@@ -50,13 +50,13 @@ internal::StructReader MessageReader::getRootInternal() { ...@@ -50,13 +50,13 @@ internal::StructReader MessageReader::getRootInternal() {
} }
internal::SegmentReader* segment = arena()->tryGetSegment(SegmentId(0)); internal::SegmentReader* segment = arena()->tryGetSegment(SegmentId(0));
if (segment == nullptr || VALIDATE_INPUT(segment != nullptr &&
!segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1)) { segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
arena()->reportInvalidData("Message did not contain a root pointer."); "Message did not contain a root pointer.") {
return internal::StructReader::readEmpty(); return internal::StructReader::readEmpty();
} else {
return internal::StructReader::readRoot(segment->getStartPtr(), segment, options.nestingLimit);
} }
return internal::StructReader::readRoot(segment->getStartPtr(), segment, options.nestingLimit);
} }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
...@@ -111,53 +111,6 @@ ArrayPtr<const ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() { ...@@ -111,53 +111,6 @@ ArrayPtr<const ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() {
// ======================================================================================= // =======================================================================================
ErrorReporter::~ErrorReporter() {}
class ThrowingErrorReporter: public ErrorReporter {
public:
virtual ~ThrowingErrorReporter() {}
void reportError(const char* description) override {
FAIL_VALIDATE_INPUT("Invalid Cap'n Proto message", description);
}
};
ErrorReporter* getThrowingErrorReporter() {
static ThrowingErrorReporter instance;
return &instance;
}
class StderrErrorReporter: public ErrorReporter {
public:
virtual ~StderrErrorReporter() {}
void reportError(const char* description) override {
std::string message("ERROR: Cap'n Proto message was invalid: ");
message += description;
message += '\n';
write(STDERR_FILENO, message.data(), message.size());
}
};
ErrorReporter* getStderrErrorReporter() {
static StderrErrorReporter instance;
return &instance;
}
class IgnoringErrorReporter: public ErrorReporter {
public:
virtual ~IgnoringErrorReporter() {}
void reportError(const char* description) override {}
};
ErrorReporter* getIgnoringErrorReporter() {
static IgnoringErrorReporter instance;
return &instance;
}
// =======================================================================================
SegmentArrayMessageReader::SegmentArrayMessageReader( SegmentArrayMessageReader::SegmentArrayMessageReader(
ArrayPtr<const ArrayPtr<const word>> segments, ReaderOptions options) ArrayPtr<const ArrayPtr<const word>> segments, ReaderOptions options)
: MessageReader(options), segments(segments) {} : MessageReader(options), segments(segments) {}
......
...@@ -42,38 +42,6 @@ typedef Id<uint32_t, Segment> SegmentId; ...@@ -42,38 +42,6 @@ typedef Id<uint32_t, Segment> SegmentId;
// ======================================================================================= // =======================================================================================
class ErrorReporter {
// Abstract interface for a class which receives notification of errors found in an input message.
public:
virtual ~ErrorReporter();
virtual void reportError(const char* description) = 0;
// Reports an error discovered while validating a message. This happens lazily, as the message
// is traversed. E.g., it can happen when a get() accessor is called for a sub-struct or list,
// and that object is found to be out-of-bounds or has the wrong type.
//
// This method can throw an exception. If it does not, then the getter that was called will
// return the default value. Returning a default value is sufficient to prevent invalid messages
// from being a security threat, since an attacker could always construct a valid message
// containing the default value to get the same effect. However, returning a default value is
// not ideal when handling messages that were accidentally corrupted -- it may lead to the wrong
// behavior, e.g. storing the wrong data to disk, which could cause further problems down the
// road. Therefore, throwing an exception is preferred -- if your code is exception-safe, of
// course.
};
ErrorReporter* getThrowingErrorReporter();
// Returns a singleton ErrorReporter which throws an exception (deriving from std::exception) on
// error.
ErrorReporter* getStderrErrorReporter();
// Returns a singleton ErrorReporter which prints a message to stderr on error, then replaces the
// invalid data with the default value.
ErrorReporter* getIgnoringErrorReporter();
// Returns a singleton ErrorReporter which silently replaces invalid data with its default value.
struct ReaderOptions { struct ReaderOptions {
// Options controlling how data is read. // Options controlling how data is read.
...@@ -107,9 +75,6 @@ struct ReaderOptions { ...@@ -107,9 +75,6 @@ struct ReaderOptions {
// overflow by sending a very-deeply-nested (or even cyclic) message, without the message even // overflow by sending a very-deeply-nested (or even cyclic) message, without the message even
// being very large. The default limit of 64 is probably low enough to prevent any chance of // being very large. The default limit of 64 is probably low enough to prevent any chance of
// stack overflow, yet high enough that it is never a problem in practice. // stack overflow, yet high enough that it is never a problem in practice.
ErrorReporter* errorReporter = getThrowingErrorReporter();
// How to report errors.
}; };
class MessageReader { class MessageReader {
......
...@@ -41,8 +41,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade ...@@ -41,8 +41,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
uint segmentCount = table[0].get() + 1; uint segmentCount = table[0].get() + 1;
size_t offset = segmentCount / 2u + 1u; size_t offset = segmentCount / 2u + 1u;
if (array.size() < offset) { VALIDATE_INPUT(array.size() >= offset, "Message ends prematurely in segment table.") {
options.errorReporter->reportError("Message ends prematurely in segment table.");
return; return;
} }
...@@ -52,8 +51,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade ...@@ -52,8 +51,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
uint segmentSize = table[1].get(); uint segmentSize = table[1].get();
if (array.size() < offset + segmentSize) { VALIDATE_INPUT(array.size() >= offset + segmentSize,
options.errorReporter->reportError("Message ends prematurely in first segment."); "Message ends prematurely in first segment.") {
return; return;
} }
...@@ -66,9 +65,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade ...@@ -66,9 +65,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(ArrayPtr<const word> array, Reade
for (uint i = 1; i < segmentCount; i++) { for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = table[i + 1].get(); uint segmentSize = table[i + 1].get();
if (array.size() < offset + segmentSize) { VALIDATE_INPUT(array.size() >= offset + segmentSize, "Message ends prematurely.") {
moreSegments = nullptr; moreSegments = nullptr;
options.errorReporter->reportError("Message ends prematurely.");
return; return;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment