Commit 9c3cb05f authored by Kenton Varda

Another iteration of refactoring the Message classes.

parent fe29fa08
......@@ -34,10 +34,11 @@ Arena::~Arena() {}
// =======================================================================================
ReaderArena::ReaderArena(std::unique_ptr<ReaderContext> context)
: context(std::move(context)),
readLimiter(this->context->getReadLimit() * WORDS),
segment0(this, SegmentId(0), this->context->getSegment(0), &readLimiter) {}
ReaderArena::ReaderArena(MessageReader* message)
: message(message),
readLimiter(this->message->getOptions().traversalLimitInWords * WORDS),
ignoreErrors(false),
segment0(this, SegmentId(0), this->message->getSegment(0), &readLimiter) {}
ReaderArena::~ReaderArena() {}
......@@ -61,7 +62,7 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
}
}
ArrayPtr<const word> newSegment = context->getSegment(id.value);
ArrayPtr<const word> newSegment = message->getSegment(id.value);
if (newSegment == nullptr) {
return nullptr;
}
......@@ -77,17 +78,26 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
}
void ReaderArena::reportInvalidData(const char* description) {
context->reportError(description);
if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(description);
}
}
void ReaderArena::reportReadLimitReached() {
context->reportError("Exceeded read limit.");
if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(
"Exceeded message traversal limit. See capnproto::ReaderOptions.");
// Ignore further errors since they are likely repeats or caused by the read limit being
// reached.
ignoreErrors = true;
}
}
// =======================================================================================
BuilderArena::BuilderArena(std::unique_ptr<BuilderContext> context)
: context(std::move(context)), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BuilderArena::BuilderArena(MessageBuilder* message)
: message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BuilderArena::~BuilderArena() {}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
......@@ -105,7 +115,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable
if (segment0.getArena() == nullptr) {
// We're allocating the first segment.
ArrayPtr<word> ptr = context->allocateSegment(minimumAvailable / WORDS);
ArrayPtr<word> ptr = message->allocateSegment(minimumAvailable / WORDS);
// Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any
// pointers to this segment yet, so it should be fine.
......@@ -132,7 +142,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable
std::unique_ptr<SegmentBuilder> newBuilder = std::unique_ptr<SegmentBuilder>(
new SegmentBuilder(this, SegmentId(moreSegments->builders.size() + 1),
context->allocateSegment(minimumAvailable / WORDS), &this->dummyLimiter));
message->allocateSegment(minimumAvailable / WORDS), &this->dummyLimiter));
SegmentBuilder* result = newBuilder.get();
moreSegments->builders.push_back(std::move(newBuilder));
......
......@@ -116,8 +116,6 @@ private:
word* pos;
CAPNPROTO_DISALLOW_COPY(SegmentBuilder);
// TODO: Do we need mutex locking?
};
class Arena {
......@@ -158,7 +156,7 @@ public:
class ReaderArena final: public Arena {
public:
ReaderArena(std::unique_ptr<ReaderContext> context);
ReaderArena(MessageReader* message);
~ReaderArena();
CAPNPROTO_DISALLOW_COPY(ReaderArena);
......@@ -168,8 +166,9 @@ public:
void reportReadLimitReached() override;
private:
std::unique_ptr<ReaderContext> context;
MessageReader* message;
ReadLimiter readLimiter;
bool ignoreErrors;
// Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0;
......@@ -180,7 +179,7 @@ private:
class BuilderArena final: public Arena {
public:
BuilderArena(std::unique_ptr<BuilderContext> context);
BuilderArena(MessageBuilder* message);
~BuilderArena();
CAPNPROTO_DISALLOW_COPY(BuilderArena);
......@@ -204,7 +203,7 @@ public:
void reportReadLimitReached() override;
private:
std::unique_ptr<BuilderContext> context;
MessageBuilder* message;
ReadLimiter dummyLimiter;
SegmentBuilder segment0;
......
......@@ -227,78 +227,80 @@ void checkMessage(Reader reader) {
}
TEST(Encoding, AllTypes) {
Message<TestAllTypes>::Builder builder;
MallocMessageBuilder builder;
initMessage(builder.initRoot());
checkMessage(builder.getRoot());
checkMessage(builder.getRoot().asReader());
initMessage(builder.initRoot<TestAllTypes>());
checkMessage(builder.getRoot<TestAllTypes>());
checkMessage(builder.getRoot<TestAllTypes>().asReader());
Message<TestAllTypes>::Reader reader(builder.getSegmentsForOutput());
SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot());
checkMessage(reader.getRoot<TestAllTypes>());
ASSERT_EQ(1u, builder.getSegmentsForOutput().size());
checkMessage(Message<TestAllTypes>::readTrusted(builder.getSegmentsForOutput()[0].begin()));
checkMessage(readMessageTrusted<TestAllTypes>(builder.getSegmentsForOutput()[0].begin()));
}
TEST(Encoding, AllTypesMultiSegment) {
Message<TestAllTypes>::Builder builder(newFixedWidthBuilderContext(0));
MallocMessageBuilder builder(0, AllocationStrategy::FIXED_SIZE);
initMessage(builder.initRoot());
checkMessage(builder.getRoot());
checkMessage(builder.getRoot().asReader());
initMessage(builder.initRoot<TestAllTypes>());
checkMessage(builder.getRoot<TestAllTypes>());
checkMessage(builder.getRoot<TestAllTypes>().asReader());
Message<TestAllTypes>::Reader reader(builder.getSegmentsForOutput());
SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot());
checkMessage(reader.getRoot<TestAllTypes>());
}
TEST(Encoding, Defaults) {
AlignedData<1> nullRoot = {{0, 0, 0, 0, 0, 0, 0, 0}};
ArrayPtr<const word> segments[1] = {arrayPtr(nullRoot.words, 1)};
Message<TestDefaults>::Reader reader(arrayPtr(segments, 1));
SegmentArrayMessageReader reader(arrayPtr(segments, 1));
checkMessage(reader.getRoot());
checkMessage(Message<TestDefaults>::readTrusted(nullRoot.words));
checkMessage(reader.getRoot<TestDefaults>());
checkMessage(readMessageTrusted<TestDefaults>(nullRoot.words));
}
TEST(Encoding, DefaultInitialization) {
Message<TestDefaults>::Builder builder;
MallocMessageBuilder builder;
checkMessage(builder.getRoot()); // first pass initializes to defaults
checkMessage(builder.getRoot().asReader());
checkMessage(builder.getRoot<TestDefaults>()); // first pass initializes to defaults
checkMessage(builder.getRoot<TestDefaults>().asReader());
checkMessage(builder.getRoot()); // second pass just reads the initialized structure
checkMessage(builder.getRoot().asReader());
checkMessage(builder.getRoot<TestDefaults>()); // second pass just reads the initialized structure
checkMessage(builder.getRoot<TestDefaults>().asReader());
Message<TestDefaults>::Reader reader(builder.getSegmentsForOutput());
SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot());
checkMessage(reader.getRoot<TestDefaults>());
}
TEST(Encoding, DefaultInitializationMultiSegment) {
Message<TestDefaults>::Builder builder(newFixedWidthBuilderContext(0));
MallocMessageBuilder builder(0, AllocationStrategy::FIXED_SIZE);
checkMessage(builder.getRoot()); // first pass initializes to defaults
checkMessage(builder.getRoot().asReader());
// first pass initializes to defaults
checkMessage(builder.getRoot<TestDefaults>());
checkMessage(builder.getRoot<TestDefaults>().asReader());
checkMessage(builder.getRoot()); // second pass just reads the initialized structure
checkMessage(builder.getRoot().asReader());
// second pass just reads the initialized structure
checkMessage(builder.getRoot<TestDefaults>());
checkMessage(builder.getRoot<TestDefaults>().asReader());
Message<TestDefaults>::Reader reader(builder.getSegmentsForOutput());
SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot());
checkMessage(reader.getRoot<TestDefaults>());
}
TEST(Encoding, DefaultsFromEmptyMessage) {
AlignedData<1> emptyMessage = {{4, 0, 0, 0, 0, 0, 0, 0}};
ArrayPtr<const word> segments[1] = {arrayPtr(emptyMessage.words, 1)};
Message<TestDefaults>::Reader reader(arrayPtr(segments, 1));
SegmentArrayMessageReader reader(arrayPtr(segments, 1));
checkMessage(reader.getRoot());
checkMessage(Message<TestDefaults>::readTrusted(emptyMessage.words));
checkMessage(reader.getRoot<TestDefaults>());
checkMessage(readMessageTrusted<TestDefaults>(emptyMessage.words));
}
} // namespace
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This header contains internal interfaces relied upon by message.h and implemented in message.c++.
// These declarations should be thought of as being part of message.h, but I moved them here to make
// message.h more readable. The problem is that the interface people really care about in
// message.h -- namely, Message -- has to be declared after internal::MessageImpl. Having an
// internal interface appear in the middle of the header ahead of the public interface was
// distracting and confusing.
#ifndef CAPNPROTO_MESSAGE_INTERNAL_H_
#define CAPNPROTO_MESSAGE_INTERNAL_H_
#include <cstddef>
#include <memory>
#include "type-safety.h"
#include "wire-format.h"
namespace capnproto {
class ReaderContext;
class BuilderContext;
}
namespace capnproto {
namespace internal {
// TODO: Move to message-internal.h so that this header looks nicer?
class ReaderArena;
class BuilderArena;
struct MessageImpl {
// Underlying implementation of capnproto::Message. All the parts that don't need to be templated
// are implemented by this class, so that they can be shared and non-inline.
// The struct itself only scopes the nested Reader and Builder classes; it cannot be instantiated.
MessageImpl() = delete;
class Reader {
// Non-templated reading half: owns a ReaderArena constructed in-place inside arenaSpace.
public:
// Reads from the given segment list, using a context obtained from newReaderContext(segments).
Reader(ArrayPtr<const ArrayPtr<const word>> segments);
// Reads through a caller-supplied context; ownership of the context is handed to the arena.
Reader(std::unique_ptr<ReaderContext> context);
// NOTE(review): a defaulted move copies arenaSpace as raw pointers while the destructor still
// runs on the moved-from object -- looks like the arena could be destroyed twice; confirm that
// moved-from Readers are never destroyed with a live arena.
Reader(Reader&& other) = default;
CAPNPROTO_DISALLOW_COPY(Reader);
// Explicitly runs the ReaderArena destructor on the in-place storage.
~Reader();
// Returns a reader for the root struct found at the start of segment 0. If segment 0 has no room
// for a root pointer, an error is reported to the arena and defaultValue is read instead.
StructReader getRoot(const word* defaultValue);
private:
// Nesting limit captured from the context at construction time and passed on to readRoot().
uint recursionLimit;
// Space in which we can construct a ReaderArena. We don't use ReaderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a ReaderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages,
// particularly when the context itself is freelisted and so no other allocation is necessary.
// The constructors static_assert that sizeof(ReaderArena) fits in this buffer.
void* arenaSpace[15];
// Views the raw storage as the arena that was placement-constructed in it.
ReaderArena* arena() { return reinterpret_cast<ReaderArena*>(arenaSpace); }
};
class Builder {
// Non-templated building half: owns a BuilderArena constructed in-place inside arenaSpace.
public:
// Builds into a context obtained from newBuilderContext() with its default arguments.
Builder();
// Builds into a caller-supplied context; ownership of the context is handed to the arena.
Builder(std::unique_ptr<BuilderContext> context);
// NOTE(review): same concern as Reader's defaulted move -- arenaSpace is copied as raw pointers
// and both objects' destructors would run ~BuilderArena(); confirm this is never exercised.
Builder(Builder&& other) = default;
CAPNPROTO_DISALLOW_COPY(Builder);
// Explicitly runs the BuilderArena destructor on the in-place storage.
~Builder();
// Both accessors lazily allocate the root segment on first use, then delegate to
// StructBuilder::initRoot / StructBuilder::getRoot on the first word of that segment.
StructBuilder initRoot(const word* defaultValue);
StructBuilder getRoot(const word* defaultValue);
// Forwards to the arena's getSegmentsForOutput().
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
private:
// Segment holding the root pointer; nullptr until initRoot() or getRoot() is first called.
SegmentBuilder* rootSegment;
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a BuilderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages,
// particularly when the context itself is freelisted and so no other allocation is necessary.
// The constructors static_assert that sizeof(BuilderArena) fits in this buffer.
void* arenaSpace[15];
// Views the raw storage as the arena that was placement-constructed in it.
BuilderArena* arena() { return reinterpret_cast<BuilderArena*>(arenaSpace); }
// Reserves one reference-sized slot in the arena and asserts that it landed at the very first
// word of segment 0; returns that segment.
static SegmentBuilder* allocateRoot(BuilderArena* arena);
};
};
} // namespace internal
} // namespace capnproto
#endif // CAPNPROTO_MESSAGE_INTERNAL_H_
......@@ -31,8 +31,86 @@
namespace capnproto {
ReaderContext::~ReaderContext() {}
BuilderContext::~BuilderContext() {}
// Records the caller's options. No arena exists yet: getRoot() constructs it on first use, so
// the reader starts out flagged as "arena not allocated".
MessageReader::MessageReader(ReaderOptions options)
    : options(options),
      allocatedArena(false) {}
MessageReader::~MessageReader() {
  // The arena is placement-constructed inside arenaSpace by getRoot(), so it has to be torn down
  // by hand -- and only if getRoot() was ever called.
  if (!allocatedArena) return;

  arena()->~ReaderArena();
}
// Returns a reader for the message's root struct.
//
// The ReaderArena is constructed lazily, in-place inside arenaSpace, on the first call. The root
// pointer is expected to occupy the first word of segment 0; if segment 0 is missing or too small
// to contain that word, the problem is reported to the arena and a reader over `defaultValue` is
// returned instead.
internal::StructReader MessageReader::getRoot(const word* defaultValue) {
  if (!allocatedArena) {
    static_assert(sizeof(internal::ReaderArena) <= sizeof(arenaSpace),
        "arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
        "ABI compatibility.");
    new(arena()) internal::ReaderArena(this);
    allocatedArena = true;
  }

  internal::SegmentReader* segment = arena()->tryGetSegment(SegmentId(0));
  if (segment == nullptr ||
      !segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1)) {
    // Report through our own arena rather than segment->getArena(): `segment` is null in the
    // first half of this condition, so dereferencing it here would crash instead of reporting.
    arena()->reportInvalidData("Message did not contain a root pointer.");
    return internal::StructReader::readRootTrusted(defaultValue, defaultValue);
  }

  return internal::StructReader::readRoot(
      segment->getStartPtr(), defaultValue, segment, options.nestingLimit);
}
// -------------------------------------------------------------------
// The BuilderArena is created lazily by getRootSegment(); until then only the flag is set.
MessageBuilder::MessageBuilder()
    : allocatedArena(false) {}
MessageBuilder::~MessageBuilder() {
  // The arena is placement-constructed inside arenaSpace, so it has to be destroyed explicitly --
  // but only if getRootSegment() ever constructed it.
  if (!allocatedArena) return;

  arena()->~BuilderArena();
}
// Returns the segment that holds the root pointer (segment 0).
//
// On the first call this constructs the BuilderArena in-place inside arenaSpace and reserves one
// reference-sized slot, asserting that the slot is the very first word of segment 0. Later calls
// simply look segment 0 up in the existing arena.
internal::SegmentBuilder* MessageBuilder::getRootSegment() {
  if (allocatedArena) {
    // Arena already exists, so the root slot was reserved by an earlier call.
    return arena()->getSegment(SegmentId(0));
  }

  static_assert(sizeof(internal::BuilderArena) <= sizeof(arenaSpace),
      "arenaSpace is too small to hold a BuilderArena. Please increase it. This will break "
      "ABI compatibility.");
  new(arena()) internal::BuilderArena(this);
  allocatedArena = true;

  // Space for exactly one reference: the root pointer.
  WordCount rootRefSize = 1 * REFERENCES * WORDS_PER_REFERENCE;

  internal::SegmentBuilder* segment = arena()->getSegmentWithAvailable(rootRefSize);
  CAPNPROTO_ASSERT(segment->getSegmentId() == SegmentId(0),
      "First allocated word of new arena was not in segment ID 0.");

  word* location = segment->allocate(rootRefSize);
  CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS),
      "First allocated word of new arena was not the first word in its segment.");

  return segment;
}
// Initializes the root struct at the first word of the root segment, creating the arena and the
// root slot first if this is the first access.
internal::StructBuilder MessageBuilder::initRoot(const word* defaultValue) {
  internal::SegmentBuilder* segment = getRootSegment();
  auto rootRef = segment->getPtrUnchecked(0 * WORDS);
  return internal::StructBuilder::initRoot(segment, rootRef, defaultValue);
}
// Returns a builder for the root struct located at the first word of the root segment, creating
// the arena and the root slot first if this is the first access.
internal::StructBuilder MessageBuilder::getRoot(const word* defaultValue) {
  internal::SegmentBuilder* segment = getRootSegment();
  auto rootRef = segment->getPtrUnchecked(0 * WORDS);
  return internal::StructBuilder::getRoot(segment, rootRef, defaultValue);
}
// Returns the arena's segment list, or a null ArrayPtr when no root has been requested yet and
// the arena therefore was never constructed.
ArrayPtr<const ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() {
  if (!allocatedArena) {
    return nullptr;
  }
  return arena()->getSegmentsForOutput();
}
// =======================================================================================
ErrorReporter::~ErrorReporter() {}
class ParseException: public std::exception {
public:
......@@ -48,192 +126,107 @@ private:
std::string message;
};
class DefaultReaderContext: public ReaderContext {
class ThrowingErrorReporter: public ErrorReporter {
public:
DefaultReaderContext(ArrayPtr<const ArrayPtr<const word>> segments,
ErrorBehavior errorBehavior, uint64_t readLimit, uint nestingLimit)
: segments(segments), errorBehavior(errorBehavior), readLimit(readLimit),
nestingLimit(nestingLimit) {}
~DefaultReaderContext() {}
virtual ~ThrowingErrorReporter() {}
ArrayPtr<const word> getSegment(uint id) override {
if (id < segments.size()) {
return segments[id];
} else {
return nullptr;
}
void reportError(const char* description) override {
std::string message("Cap'n Proto message was invalid: ");
message += description;
throw ParseException(std::move(message));
}
};
uint64_t getReadLimit() override {
return readLimit;
}
// Hands out the single shared ThrowingErrorReporter, constructed on first call.
ErrorReporter* getThrowingErrorReporter() {
  static ThrowingErrorReporter reporter;
  return &reporter;
}
uint getNestingLimit() override {
return nestingLimit;
}
class StderrErrorReporter: public ErrorReporter {
public:
virtual ~StderrErrorReporter() {}
void reportError(const char* description) override {
std::string message("ERROR: Cap'n Proto parse error: ");
std::string message("ERROR: Cap'n Proto message was invalid: ");
message += description;
message += '\n';
switch (errorBehavior) {
case ErrorBehavior::THROW_EXCEPTION:
throw ParseException(std::move(message));
break;
case ErrorBehavior::REPORT_TO_STDERR_AND_RETURN_DEFAULT:
write(STDERR_FILENO, message.data(), message.size());
break;
case ErrorBehavior::IGNORE_AND_RETURN_DEFAULT:
break;
}
}
private:
ArrayPtr<const ArrayPtr<const word>> segments;
ErrorBehavior errorBehavior;
uint64_t readLimit;
uint nestingLimit;
};
std::unique_ptr<ReaderContext> newReaderContext(
ArrayPtr<const ArrayPtr<const word>> segments,
ErrorBehavior errorBehavior, uint64_t readLimit, uint nestingLimit) {
return std::unique_ptr<ReaderContext>(new DefaultReaderContext(
segments, errorBehavior, readLimit, nestingLimit));
// Hands out the single shared StderrErrorReporter, constructed on first call.
ErrorReporter* getStderrErrorReporter() {
  static StderrErrorReporter reporter;
  return &reporter;
}
class DefaultBuilderContext: public BuilderContext {
class IgnoringErrorReporter: public ErrorReporter {
public:
DefaultBuilderContext(uint firstSegmentWords, bool enableGrowthHeursitic)
: nextSize(firstSegmentWords), enableGrowthHeursitic(enableGrowthHeursitic),
firstSegment(nullptr) {}
~DefaultBuilderContext() {
free(firstSegment);
for (void* ptr: moreSegments) {
free(ptr);
}
}
virtual ~IgnoringErrorReporter() {}
ArrayPtr<word> allocateSegment(uint minimumSize) override {
uint size = std::max(minimumSize, nextSize);
void* result = calloc(size, sizeof(word));
if (result == nullptr) {
throw std::bad_alloc();
}
if (firstSegment == nullptr) {
firstSegment = result;
if (enableGrowthHeursitic) nextSize = size;
} else {
moreSegments.push_back(result);
if (enableGrowthHeursitic) nextSize += size;
}
return arrayPtr(reinterpret_cast<word*>(result), size);
}
private:
uint nextSize;
bool enableGrowthHeursitic;
// Avoid allocating the vector if there is only one segment.
void* firstSegment;
std::vector<void*> moreSegments;
void reportError(const char* description) override {}
};
std::unique_ptr<BuilderContext> newBuilderContext(uint firstSegmentWords) {
return std::unique_ptr<BuilderContext>(new DefaultBuilderContext(firstSegmentWords, true));
}
std::unique_ptr<BuilderContext> newFixedWidthBuilderContext(uint firstSegmentWords) {
return std::unique_ptr<BuilderContext>(new DefaultBuilderContext(firstSegmentWords, false));
// Hands out the single shared IgnoringErrorReporter, constructed on first call.
ErrorReporter* getIgnoringErrorReporter() {
  static IgnoringErrorReporter reporter;
  return &reporter;
}
// =======================================================================================
namespace internal {
MessageImpl::Reader::Reader(ArrayPtr<const ArrayPtr<const word>> segments) {
std::unique_ptr<ReaderContext> context = newReaderContext(segments);
recursionLimit = context->getNestingLimit();
static_assert(sizeof(ReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) ReaderArena(std::move(context));
}
MessageImpl::Reader::Reader(std::unique_ptr<ReaderContext> context)
: recursionLimit(context->getNestingLimit()) {
static_assert(sizeof(ReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) ReaderArena(std::move(context));
}
// Forwards the options to MessageReader and keeps the segment table handed in by the caller.
// NOTE(review): ArrayPtr looks like a non-owning view, so the segments presumably have to outlive
// this reader -- confirm against ArrayPtr's definition.
SegmentArrayMessageReader::SegmentArrayMessageReader(
    ArrayPtr<const ArrayPtr<const word>> segments,
    ReaderOptions options)
    : MessageReader(options),
      segments(segments) {}
MessageImpl::Reader::~Reader() {
arena()->~ReaderArena();
}
SegmentArrayMessageReader::~SegmentArrayMessageReader() {}
StructReader MessageImpl::Reader::getRoot(const word* defaultValue) {
SegmentReader* segment = arena()->tryGetSegment(SegmentId(0));
if (segment == nullptr ||
!segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1)) {
segment->getArena()->reportInvalidData("Message did not contain a root pointer.");
return StructReader::readRootTrusted(defaultValue, defaultValue);
ArrayPtr<const word> SegmentArrayMessageReader::getSegment(uint id) {
if (id < segments.size()) {
return segments[id];
} else {
return StructReader::readRoot(segment->getStartPtr(), defaultValue, segment, recursionLimit);
return nullptr;
}
}
MessageImpl::Builder::Builder(): rootSegment(nullptr) {
std::unique_ptr<BuilderContext> context = newBuilderContext();
// -------------------------------------------------------------------
static_assert(sizeof(BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) BuilderArena(std::move(context));
}
// Holds every segment allocated after the first one. Defined here rather than in the header, and
// only instantiated by allocateSegment() once a second segment is actually needed.
struct MallocMessageBuilder::MoreSegments {
// Raw calloc() results; each entry is released with free() by ~MallocMessageBuilder().
std::vector<void*> segments;
};
MessageImpl::Builder::Builder(std::unique_ptr<BuilderContext> context): rootSegment(nullptr) {
static_assert(sizeof(BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) BuilderArena(std::move(context));
}
// Nothing is allocated up front: firstSegmentWords only seeds the size used for the first
// allocateSegment() call, and allocationStrategy decides whether that size grows afterwards.
MallocMessageBuilder::MallocMessageBuilder(
    uint firstSegmentWords,
    AllocationStrategy allocationStrategy)
    : nextSize(firstSegmentWords),
      allocationStrategy(allocationStrategy),
      firstSegment(nullptr) {}
MessageImpl::Builder::~Builder() {
arena()->~BuilderArena();
MallocMessageBuilder::~MallocMessageBuilder() {
  // free(nullptr) is a no-op, so this is safe even if no segment was ever allocated.
  free(firstSegment);

  // The overflow list only exists once a second segment has been requested.
  if (moreSegments == nullptr) return;

  for (void* segment: moreSegments->segments) {
    free(segment);
  }
}
StructBuilder MessageImpl::Builder::initRoot(const word* defaultValue) {
if (rootSegment == nullptr) rootSegment = allocateRoot(arena());
return StructBuilder::initRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
}
ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) {
uint size = std::max(minimumSize, nextSize);
StructBuilder MessageImpl::Builder::getRoot(const word* defaultValue) {
if (rootSegment == nullptr) rootSegment = allocateRoot(arena());
return StructBuilder::getRoot(rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
}
void* result = calloc(size, sizeof(word));
if (result == nullptr) {
throw std::bad_alloc();
}
ArrayPtr<const ArrayPtr<const word>> MessageImpl::Builder::getSegmentsForOutput() {
return arena()->getSegmentsForOutput();
}
if (firstSegment == nullptr) {
firstSegment = result;
if (allocationStrategy == AllocationStrategy::GROW_HEURISTICALLY) nextSize = size;
} else {
if (moreSegments == nullptr) {
moreSegments = std::unique_ptr<MoreSegments>(new MoreSegments);
}
moreSegments->segments.push_back(result);
if (allocationStrategy == AllocationStrategy::GROW_HEURISTICALLY) nextSize += size;
}
SegmentBuilder* MessageImpl::Builder::allocateRoot(BuilderArena* arena) {
WordCount refSize = 1 * REFERENCES * WORDS_PER_REFERENCE;
SegmentBuilder* segment = arena->getSegmentWithAvailable(refSize);
CAPNPROTO_ASSERT(segment->getSegmentId() == SegmentId(0),
"First allocated word of new arena was not in segment ID 0.");
word* location = segment->allocate(refSize);
CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS),
"First allocated word of new arena was not the first word in its segment.");
return segment;
return arrayPtr(reinterpret_cast<word*>(result), size);
}
} // namespace internal
} // namespace capnproto
......@@ -26,189 +26,279 @@
#include "macros.h"
#include "type-safety.h"
#include "wire-format.h"
#include "message-internal.h"
#ifndef CAPNPROTO_MESSAGE_H_
#define CAPNPROTO_MESSAGE_H_
namespace capnproto {
namespace internal {
class ReaderArena;
class BuilderArena;
}
class Segment;
typedef Id<uint32_t, Segment> SegmentId;
// =======================================================================================
class ReaderContext {
class ErrorReporter {
// Abstract interface for a class which receives notification of errors found in an input message.
public:
virtual ~ReaderContext();
virtual ~ErrorReporter();
virtual void reportError(const char* description) = 0;
// Reports an error discovered while validating a message. This happens lazily, as the message
// is traversed. E.g., it can happen when a get() accessor is called for a sub-struct or list,
// and that object is found to be out-of-bounds or has the wrong type.
//
// This method can throw an exception. If it does not, then the getter that was called will
// return the default value. Returning a default value is sufficient to prevent invalid messages
// from being a security threat, since an attacker could always construct a valid message
// containing the default value to get the same effect. However, returning a default value is
// not ideal when handling messages that were accidentally corrupted -- it may lead to the wrong
// behavior, e.g. storing the wrong data to disk, which could cause further problems down the
// road. Therefore, throwing an exception is preferred -- if your code is exception-safe, of
// course.
};
ErrorReporter* getThrowingErrorReporter();
// Returns a singleton ErrorReporter which throws an exception (deriving from std::exception) on
// error.
ErrorReporter* getStderrErrorReporter();
// Returns a singleton ErrorReporter which prints a message to stderr on error, then replaces the
// invalid data with the default value.
ErrorReporter* getIgnoringErrorReporter();
// Returns a singleton ErrorReporter which silently replaces invalid data with its default value.
struct ReaderOptions {
// Options controlling how data is read.
uint64_t traversalLimitInWords = 8 * 1024 * 1024;
// Limits how many total words of data are allowed to be traversed. Traversal is counted when
// a new struct or list reader is obtained, e.g. from a get() accessor. This means that calling
// the getter for the same sub-struct multiple times will cause it to be double-counted. Once
// the traversal limit is reached, an error will be reported.
//
// This limit exists for security reasons. It is possible for an attacker to construct a message
// in which multiple pointers point at the same location. This is technically invalid, but hard
// to detect. Using such a message, an attacker could cause a message which is small on the wire
// to appear much larger when actually traversed, possibly exhausting server resources leading to
// denial-of-service.
//
// It makes sense to set a traversal limit that is much larger than the underlying message.
// Together with sensible coding practices (e.g. trying to avoid calling sub-object getters
// multiple times, which is expensive anyway), this should provide adequate protection without
// inconvenience.
//
// The default limit is 64 MiB (the value above is expressed in words, not bytes). This may or
// may not be a sensible number for any given use case, but probably at least prevents easy
// exploitation while also avoiding causing problems in most typical cases.
uint nestingLimit = 64;
// Limits how deeply-nested a message structure can be, e.g. structs containing other structs or
// lists of structs.
//
// Like the traversal limit, this limit exists for security reasons. Since it is common to use
// recursive code to traverse recursive data structures, an attacker could easily cause a stack
// overflow by sending a very-deeply-nested (or even cyclic) message, without the message even
// being very large. The default limit of 64 is probably low enough to prevent any chance of
// stack overflow, yet high enough that it is never a problem in practice.
ErrorReporter* errorReporter = getThrowingErrorReporter();
// How to report errors. Defaults to the shared reporter that throws an exception; see
// getStderrErrorReporter() and getIgnoringErrorReporter() for the non-throwing alternatives.
};
class MessageReader {
public:
MessageReader(ReaderOptions options);
// It is suggested that subclasses take ReaderOptions as a constructor parameter, but give it a
// default value of "ReaderOptions()". The base class constructor doesn't have a default value
// in order to remind subclasses that they really need to give the user a way to provide this.
virtual ~MessageReader();
virtual ArrayPtr<const word> getSegment(uint id) = 0;
// Gets the segment with the given ID, or returns null if no such segment exists.
virtual uint64_t getReadLimit() = 0;
virtual uint getNestingLimit() = 0;
inline const ReaderOptions& getOptions();
// Get the options passed to the constructor.
virtual void reportError(const char* description) = 0;
};
template <typename RootType>
typename RootType::Reader getRoot();
enum class ErrorBehavior {
THROW_EXCEPTION,
REPORT_TO_STDERR_AND_RETURN_DEFAULT,
IGNORE_AND_RETURN_DEFAULT
};
private:
ReaderOptions options;
// Space in which we can construct a ReaderArena. We don't use ReaderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a ReaderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages.
void* arenaSpace[15];
bool allocatedArena;
std::unique_ptr<ReaderContext> newReaderContext(
ArrayPtr<const ArrayPtr<const word>> segments,
ErrorBehavior errorBehavior = ErrorBehavior::THROW_EXCEPTION,
uint64_t readLimit = 64 * 1024 * 1024, uint nestingLimit = 64);
// Creates a ReaderContext pointing at the given segment list, without taking ownership of the
// segments. All arrays passed in must remain valid until the context is destroyed.
internal::ReaderArena* arena() { return reinterpret_cast<internal::ReaderArena*>(arenaSpace); }
internal::StructReader getRoot(const word* defaultValue);
};
class BuilderContext {
class MessageBuilder {
public:
virtual ~BuilderContext();
MessageBuilder();
virtual ~MessageBuilder();
virtual ArrayPtr<word> allocateSegment(uint minimumSize) = 0;
// Allocates an array of at least the given number of words, throwing an exception or crashing if
// this is not possible. It is expected that this method will usually return more space than
// requested, and the caller should use that extra space as much as possible before allocating
// more. All returned space is deleted when the context is destroyed.
};
std::unique_ptr<BuilderContext> newBuilderContext(uint firstSegmentWords = 1024);
// Creates a BuilderContext which allocates at least the given number of words for the first
// segment, and then heuristically decides how much to allocate for subsequent segments. This
// should work well for most use cases that do not require writing messages to specific locations
// in memory. When choosing a value for firstSegmentWords, consider that:
// 1) Reading and writing messages gets slower when multiple segments are involved, so it's good
// if most messages fit in a single segment.
// 2) Unused bytes will not be written to the wire, so generally it is not a big deal to allocate
// more space than you need. It only becomes problematic if you are allocating many messages
// in parallel and thus use lots of memory, or if you allocate so much extra space that just
// zeroing it out becomes a bottleneck.
// The default has been chosen to be reasonable for most people, so don't change it unless you have
// reason to believe you need to.
std::unique_ptr<BuilderContext> newFixedWidthBuilderContext(uint preferredSegmentWords = 1024);
// Creates a BuilderContext which will always prefer to allocate segments with the given size with
// no heuristic growth. It will still allocate larger segments when the preferred size is too small
// for some single object. You can force every single object to be located in a separate segment by
// passing zero for the parameter to this function, but this isn't a good idea. This context
// implementation is probably most useful for testing purposes, where you want to verify that your
// serializer works when a message is split across segments and you want those segments to be
// somewhat predictable.
// =======================================================================================
// more. The returned space remains valid at least until the MessageBuilder is destroyed.
template <typename RootType>
struct Message {
Message() = delete;
template <typename RootType>
typename RootType::Builder initRoot();
template <typename RootType>
typename RootType::Builder getRoot();
class Reader {
public:
Reader(ArrayPtr<const ArrayPtr<const word>> segments);
// Make a Reader that reads from the given segments, as if the context were created using
// newReaderContext(segments).
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
Reader(std::unique_ptr<ReaderContext> context);
private:
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a BuilderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages.
void* arenaSpace[15];
bool allocatedArena = false;
internal::BuilderArena* arena() { return reinterpret_cast<internal::BuilderArena*>(arenaSpace); }
internal::SegmentBuilder* getRootSegment();
internal::StructBuilder initRoot(const word* defaultValue);
internal::StructBuilder getRoot(const word* defaultValue);
};
CAPNPROTO_DISALLOW_COPY(Reader);
Reader(Reader&& other) = default;
template <typename RootType>
static typename RootType::Reader readMessageTrusted(const word* data);
// IF THE INPUT IS INVALID, THIS MAY CRASH, CORRUPT MEMORY, CREATE A SECURITY HOLE IN YOUR APP,
// MURDER YOUR FIRST-BORN CHILD, AND/OR BRING ABOUT ETERNAL DAMNATION ON ALL OF HUMANITY. DO NOT
// USE UNLESS YOU UNDERSTAND THE CONSEQUENCES.
//
// Given a pointer to a known-valid message located in a single contiguous memory segment,
// returns a reader for that message. No bounds-checking will be done while traversing this
// message. Use this only if you are absolutely sure that the input data is a valid message
// created by your own system. Never use this to read messages received from others.
//
// To create a trusted message, build a message using a MallocMessageBuilder whose first segment
// size is larger than the message size. This guarantees that the message will be allocated as a
// single segment, meaning getSegmentsForOutput() returns a single word array. That word array
// is your message; you may pass a pointer to its first word into readMessageTrusted() to read the
// message.
//
// This can be particularly handy for embedding messages in generated code: you can
// embed the raw bytes (using AlignedData) then make a Reader for it using this. This is the way
// default values are embedded in code generated by the Cap'n Proto compiler. E.g., if you have
// a message MyMessage, you can read its default value like so:
// MyMessage::Reader reader = readMessageTrusted<MyMessage>(MyMessage::DEFAULT.words);
typename RootType::Reader getRoot();
// Get a reader pointing to the message root.
// =======================================================================================
private:
internal::MessageImpl::Reader internal;
};
class SegmentArrayMessageReader: public MessageReader {
// A simple MessageReader that reads from an array of word arrays representing all segments.
// In particular you can read directly from the output of MessageBuilder::getSegmentsForOutput()
// (although it would probably make more sense to call builder.getRoot().asReader() in that case).
class Builder {
public:
Builder();
// Make a Builder as if with a context created by newBuilderContext().
public:
SegmentArrayMessageReader(ArrayPtr<const ArrayPtr<const word>> segments,
ReaderOptions options = ReaderOptions());
// Creates a message pointing at the given segment array, without taking ownership of the
// segments. All arrays passed in must remain valid until the MessageReader is destroyed.
Builder(std::unique_ptr<BuilderContext> context);
CAPNPROTO_DISALLOW_COPY(SegmentArrayMessageReader);
~SegmentArrayMessageReader();
CAPNPROTO_DISALLOW_COPY(Builder);
Builder(Builder&& other) = default;
virtual ArrayPtr<const word> getSegment(uint id) override;
typename RootType::Builder initRoot();
// Allocate and initialize the message root. If already initialized, the old data is discarded.
private:
ArrayPtr<const ArrayPtr<const word>> segments;
};
typename RootType::Builder getRoot();
// Get the message root, initializing it to the type's default value if it isn't initialized
// already.
enum class AllocationStrategy {
// Selects how a message builder sizes the segments it allocates after the first one. Passed to
// the MallocMessageBuilder constructor.

FIXED_SIZE,
// The builder will prefer to allocate the same amount of space for each segment with no
// heuristic growth. It will still allocate larger segments when the preferred size is too small
// for some single object. This mode is generally not recommended, but can be particularly useful
// for testing in order to force a message to allocate a predictable number of segments. Note
// that you can force every single object in the message to be located in a separate segment by
// using this mode with firstSegmentWords = 0.

GROW_HEURISTICALLY
// The builder will heuristically decide how much space to allocate for each segment. Each
// allocated segment will be progressively larger than the previous segments on the assumption
// that message sizes are exponentially distributed. The total number of segments that will be
// allocated for a message of size n is O(log n).
};
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
// Suggested defaults for message builders: the size of the first segment, measured in words, and
// the strategy used to size subsequent segments.
constexpr uint SUGGESTED_FIRST_SEGMENT_WORDS = 1024;
constexpr AllocationStrategy SUGGESTED_ALLOCATION_STRATEGY = AllocationStrategy::GROW_HEURISTICALLY;
private:
internal::MessageImpl::Builder internal;
};
class MallocMessageBuilder: public MessageBuilder {
// A simple MessageBuilder that uses malloc() (actually, calloc()) to allocate segments. This
// implementation should be reasonable for any case that doesn't require writing the message to
// a specific location in memory.
static typename RootType::Reader readTrusted(const word* data);
// IF THE INPUT IS INVALID, THIS MAY CRASH, CORRUPT MEMORY, CREATE A SECURITY HOLE IN YOUR APP,
// MURDER YOUR FIRST-BORN CHILD, AND/OR BRING ABOUT ETERNAL DAMNATION ON ALL OF HUMANITY. DO NOT
// USE UNLESS YOU UNDERSTAND THE CONSEQUENCES.
//
// Given a pointer to a known-valid message located in a single contiguous memory segment,
// returns a reader for that message. No bounds-checking will be done while traversing this
// message. Use this only if you are absolutely sure that the input data is a valid message
// created by your own system. Never use this to read messages received from others.
//
// To create a trusted message, build a message using a MallocAllocator whose preferred segment
// size is larger than the message size. This guarantees that the message will be allocated as a
// single segment, meaning getSegmentsForOutput() returns a single word array. That word array
// is your message; you may pass a pointer to its first word into readTrusted() to read the
// message.
//
// This can be particularly handy for embedding messages in generated code: you can
// embed the raw bytes (using AlignedData) then make a Reader for it using this. This is the way
// default values are embedded in code generated by the Cap'n Proto compiler. E.g., if you have
// a message MyMessage, you can read its default value like so:
// MyMessage::Reader reader = Message<MyMessage>::ReadTrusted(MyMessage::DEFAULT.words);
public:
MallocMessageBuilder(uint firstSegmentWords = SUGGESTED_FIRST_SEGMENT_WORDS,
    AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// Creates a MessageBuilder which allocates at least the given number of words for the first
// segment, and then uses the given strategy to decide how much to allocate for subsequent
// segments. When choosing a value for firstSegmentWords, consider that:
// 1) Reading and writing messages gets slower when multiple segments are involved, so it's good
// if most messages fit in a single segment.
// 2) Unused bytes will not be written to the wire, so generally it is not a big deal to allocate
// more space than you need. It only becomes problematic if you are allocating many messages
// in parallel and thus use lots of memory, or if you allocate so much extra space that just
// zeroing it out becomes a bottleneck.
// The defaults (SUGGESTED_FIRST_SEGMENT_WORDS and SUGGESTED_ALLOCATION_STRATEGY) have been chosen
// to be reasonable for most people, so don't change them unless you have reason to believe you
// need to.
CAPNPROTO_DISALLOW_COPY(MallocMessageBuilder);
virtual ~MallocMessageBuilder();
virtual ArrayPtr<word> allocateSegment(uint minimumSize) override;
private:
uint nextSize;
AllocationStrategy allocationStrategy;
void* firstSegment;
struct MoreSegments;
std::unique_ptr<MoreSegments> moreSegments;
};
// =======================================================================================
// implementation details
template <typename RootType>
inline Message<RootType>::Reader::Reader(ArrayPtr<const ArrayPtr<const word>> segments)
: internal(segments) {}
template <typename RootType>
inline Message<RootType>::Reader::Reader(std::unique_ptr<ReaderContext> context)
: internal(std::move(context)) {}
template <typename RootType>
inline typename RootType::Reader Message<RootType>::Reader::getRoot() {
return typename RootType::Reader(internal.getRoot(RootType::DEFAULT.words));
// Returns a reference to the ReaderOptions stored in this reader's `options` member.
inline const ReaderOptions& MessageReader::getOptions() {
return options;
}
template <typename RootType>
inline Message<RootType>::Builder::Builder()
: internal() {}
template <typename RootType>
inline Message<RootType>::Builder::Builder(std::unique_ptr<BuilderContext> context)
: internal(std::move(context)) {}
template <typename RootType>
inline typename RootType::Builder Message<RootType>::Builder::initRoot() {
return typename RootType::Builder(internal.initRoot(RootType::DEFAULT.words));
inline typename RootType::Reader MessageReader::getRoot() {
return typename RootType::Reader(getRoot(RootType::DEFAULT.words));
}
template <typename RootType>
inline typename RootType::Builder Message<RootType>::Builder::getRoot() {
return typename RootType::Builder(internal.getRoot(RootType::DEFAULT.words));
inline typename RootType::Builder MessageBuilder::initRoot() {
return typename RootType::Builder(initRoot(RootType::DEFAULT.words));
}
template <typename RootType>
inline ArrayPtr<const ArrayPtr<const word>> Message<RootType>::Builder::getSegmentsForOutput() {
return internal.getSegmentsForOutput();
inline typename RootType::Builder MessageBuilder::getRoot() {
return typename RootType::Builder(getRoot(RootType::DEFAULT.words));
}
template <typename RootType>
typename RootType::Reader Message<RootType>::readTrusted(const word* data) {
typename RootType::Reader readMessageTrusted(const word* data) {
return typename RootType::Reader(internal::StructReader::readRootTrusted(
data, RootType::DEFAULT.words));
}
......
......@@ -106,6 +106,77 @@ inline ArrayPtr<T> arrayPtr(T* begin, T* end) {
return ArrayPtr<T>(begin, end);
}
template <typename T>
class Array {
  // An owned array which will automatically be deleted in the destructor.  Can be moved, but not
  // copied.
  //
  // A moved-from Array is left empty (null pointer, zero size).  Non-empty arrays are created
  // through newArray<T>(size), which is the only code allowed to call the sizing constructor.

public:
  inline Array() noexcept: ptr(nullptr), size_(0) {}
  inline Array(std::nullptr_t) noexcept: ptr(nullptr), size_(0) {}
  inline Array(Array&& other) noexcept: ptr(other.ptr), size_(other.size_) {
    // Steal the buffer and leave `other` empty so that its destructor does not free it.
    other.ptr = nullptr;
    other.size_ = 0;
  }

  CAPNPROTO_DISALLOW_COPY(Array);
  inline ~Array() noexcept { delete[] ptr; }

  inline operator ArrayPtr<T>() {
    return ArrayPtr<T>(ptr, size_);
  }
  inline ArrayPtr<T> asPtr() {
    return ArrayPtr<T>(ptr, size_);
  }
  // Both produce a non-owning view of the whole array.

  inline std::size_t size() const { return size_; }
  inline T& operator[](std::size_t index) const {
    CAPNPROTO_DEBUG_ASSERT(index < size_, "Out-of-bounds Array access.");
    return ptr[index];
  }

  inline T* begin() const { return ptr; }
  inline T* end() const { return ptr + size_; }

  inline ArrayPtr<T> slice(size_t start, size_t end) {
    // Returns a non-owning view of the elements in [start, end).
    CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds Array::slice().");
    return ArrayPtr<T>(ptr + start, end - start);
  }

  inline bool operator==(std::nullptr_t) const { return ptr == nullptr; }
  inline bool operator!=(std::nullptr_t) const { return ptr != nullptr; }
  // Const-qualified so that a const Array can also be tested against nullptr.

  inline Array& operator=(std::nullptr_t) {
    // Releases the owned buffer and leaves the array empty.
    delete[] ptr;
    ptr = nullptr;
    size_ = 0;
    return *this;
  }

  inline Array& operator=(Array&& other) noexcept {
    // Guard against self-move: without it the buffer would be freed and the array silently
    // emptied by `a = std::move(a)`.
    if (this != &other) {
      delete[] ptr;
      ptr = other.ptr;
      size_ = other.size_;
      other.ptr = nullptr;
      other.size_ = 0;
    }
    return *this;
  }

private:
  T* ptr;
  std::size_t size_;

  inline explicit Array(std::size_t size): ptr(new T[size]), size_(size) {}
  // Allocates `size` elements with new[].  Private; reachable only through newArray().

  template <typename U>
  friend Array<U> newArray(size_t size);
};
template <typename T>
inline Array<T> newArray(size_t size) {
  // Factory for owned arrays: builds an Array<T> of `size` elements through Array's private
  // sizing constructor and hands ownership to the caller.
  Array<T> result(size);
  return result;
}
// =======================================================================================
// IDs
......@@ -381,8 +452,8 @@ inline constexpr auto operator*(UnitRatio<Number1, Unit2, Unit> ratio,
// =======================================================================================
// Raw memory types and measures
class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); };
class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); };
// NOTE(review): the explicitly defaulted public constructors keep byte and word
// default-constructible; presumably CAPNPROTO_DISALLOW_COPY declares a copy constructor, which
// would otherwise suppress the implicit default constructor -- confirm against the macro.
class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); public: byte() = default; };
class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); public: word() = default; };
// byte and word are opaque types with sizes of 8 and 64 bits, respectively. These types are useful
// only to make pointer arithmetic clearer. Since the contents are private, the only way to access
// them is to first reinterpret_cast to some other pointer type.
......
......@@ -264,7 +264,8 @@ static void checkStruct(StructReader reader) {
}
TEST(WireFormat, StructRoundTrip_OneSegment) {
BuilderArena arena(newBuilderContext());
MallocMessageBuilder message;
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
......@@ -298,7 +299,8 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
}
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
BuilderArena arena(newFixedWidthBuilderContext(0));
MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE);
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
......@@ -333,7 +335,8 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
}
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
BuilderArena arena(newFixedWidthBuilderContext(8));
MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE);
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
......
......@@ -612,7 +612,7 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue,
int recursionLimit)) {
int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
......@@ -621,9 +621,9 @@ struct WireHelpers {
ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) {
if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles.");
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
goto useDefault;
}
......@@ -655,25 +655,25 @@ struct WireHelpers {
ref->structRef.fieldCount.get(),
ref->structRef.dataSize.get(),
ref->structRef.refCount.get(),
0 * BITS, recursionLimit - 1);
0 * BITS, nestingLimit - 1);
}
static CAPNPROTO_ALWAYS_INLINE(ListReader readListReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue,
FieldSize expectedElementSize, int recursionLimit)) {
FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr) {
return ListReader(nullptr, nullptr, 0 * ELEMENTS, 0 * BITS / ELEMENTS, recursionLimit - 1);
return ListReader(nullptr, nullptr, 0 * ELEMENTS, 0 * BITS / ELEMENTS, nestingLimit - 1);
}
segment = nullptr;
ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) {
if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles.");
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
goto useDefault;
}
......@@ -780,7 +780,7 @@ struct WireHelpers {
tag->structRef.fieldCount.get(),
tag->structRef.dataSize.get(),
tag->structRef.refCount.get(),
recursionLimit - 1);
nestingLimit - 1);
} else {
// The elements of the list are NOT structs.
......@@ -796,7 +796,7 @@ struct WireHelpers {
}
if (ref->listRef.elementSize() == expectedElementSize) {
return ListReader(segment, ptr, ref->listRef.elementCount(), step, recursionLimit - 1);
return ListReader(segment, ptr, ref->listRef.elementCount(), step, nestingLimit - 1);
} else if (expectedElementSize == FieldSize::INLINE_COMPOSITE) {
// We were expecting a struct list, but we received a list of some other type. Perhaps a
// non-struct list was recently upgraded to a struct list, but the sender is using the
......@@ -832,7 +832,7 @@ struct WireHelpers {
}
return ListReader(segment, ptr, ref->listRef.elementCount(), step, FieldNumber(1),
dataSize, referenceCount, recursionLimit - 1);
dataSize, referenceCount, nestingLimit - 1);
} else {
CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type.");
segment->getArena()->reportInvalidData("A list had incompatible element type.");
......@@ -1028,27 +1028,27 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def
}
StructReader StructReader::readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit) {
SegmentReader* segment, int nestingLimit) {
if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) {
segment->getArena()->reportInvalidData("Root location out-of-bounds.");
location = nullptr;
}
return WireHelpers::readStructReference(segment, reinterpret_cast<const WireReference*>(location),
defaultValue, recursionLimit);
defaultValue, nestingLimit);
}
StructReader StructReader::getStructField(
WireReferenceCount refIndex, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readStructReference(segment, ref, defaultValue, recursionLimit);
return WireHelpers::readStructReference(segment, ref, defaultValue, nestingLimit);
}
ListReader StructReader::getListField(
WireReferenceCount refIndex, FieldSize expectedElementSize, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readListReference(
segment, ref, defaultValue, expectedElementSize, recursionLimit);
segment, ref, defaultValue, expectedElementSize, nestingLimit);
}
Text::Reader StructReader::getTextField(
......@@ -1131,10 +1131,10 @@ ListReader ListBuilder::asReader(FieldNumber fieldCount, WordCount dataSize,
}
StructReader ListReader::getStructElement(ElementCount index, const word* defaultValue) const {
if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (recursionLimit == 0))) {
if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (nestingLimit == 0))) {
segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles.");
return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, recursionLimit);
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, nestingLimit);
} else {
BitCount64 indexBit = ElementCount64(index) * stepBits;
const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE;
......@@ -1142,7 +1142,7 @@ StructReader ListReader::getStructElement(ElementCount index, const word* defaul
segment, structPtr,
reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD),
structFieldCount, structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE,
recursionLimit - 1);
nestingLimit - 1);
}
}
......@@ -1150,7 +1150,7 @@ ListReader ListReader::getListElement(
WireReferenceCount index, FieldSize expectedElementSize) const {
return WireHelpers::readListReference(
segment, reinterpret_cast<const WireReference*>(ptr) + index,
nullptr, expectedElementSize, recursionLimit);
nullptr, expectedElementSize, nestingLimit);
}
Text::Reader ListReader::getTextElement(WireReferenceCount index) const {
......
......@@ -159,11 +159,11 @@ class StructReader {
public:
inline StructReader()
: segment(nullptr), data(nullptr), references(nullptr), fieldCount(0), dataSize(0),
referenceCount(0), bit0Offset(0 * BITS), recursionLimit(0) {}
referenceCount(0), bit0Offset(0 * BITS), nestingLimit(0) {}
static StructReader readRootTrusted(const word* location, const word* defaultValue);
static StructReader readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit);
SegmentReader* segment, int nestingLimit);
template <typename T>
CAPNPROTO_ALWAYS_INLINE(
......@@ -213,16 +213,16 @@ private:
// instead of the usual zero. This is needed to allow a boolean list to be upgraded to a list
// of structs.
int recursionLimit;
int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned.
inline StructReader(SegmentReader* segment, const void* data, const WireReference* references,
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount referenceCount,
BitCount bit0Offset, int recursionLimit)
BitCount bit0Offset, int nestingLimit)
: segment(segment), data(data), references(references), fieldCount(fieldCount),
dataSize(dataSize), referenceCount(referenceCount), bit0Offset(bit0Offset),
recursionLimit(recursionLimit) {}
nestingLimit(nestingLimit) {}
friend class ListReader;
friend class StructBuilder;
......@@ -306,7 +306,7 @@ public:
inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0),
stepBits(0 * BITS / ELEMENTS), structFieldCount(0), structDataSize(0),
structReferenceCount(0), recursionLimit(0) {}
structReferenceCount(0), nestingLimit(0) {}
inline ElementCount size();
// The number of elements in the list.
......@@ -348,22 +348,22 @@ private:
// only used to check for field presence; the data size is also used to compute the reference
// pointer.
int recursionLimit;
int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned.
inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount,
decltype(BITS / ELEMENTS) stepBits, int recursionLimit)
decltype(BITS / ELEMENTS) stepBits, int nestingLimit)
: segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits),
structFieldCount(0), structDataSize(0), structReferenceCount(0),
recursionLimit(recursionLimit) {}
nestingLimit(nestingLimit) {}
inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount,
decltype(BITS / ELEMENTS) stepBits,
FieldNumber structFieldCount, WordCount structDataSize,
WireReferenceCount structReferenceCount, int recursionLimit)
WireReferenceCount structReferenceCount, int nestingLimit)
: segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits),
structFieldCount(structFieldCount), structDataSize(structDataSize),
structReferenceCount(structReferenceCount), recursionLimit(recursionLimit) {}
structReferenceCount(structReferenceCount), nestingLimit(nestingLimit) {}
friend class StructReader;
friend class ListBuilder;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment