Commit 9c3cb05f authored by Kenton Varda's avatar Kenton Varda

Another iteration of refactoring the Message classes.

parent fe29fa08
...@@ -34,10 +34,11 @@ Arena::~Arena() {} ...@@ -34,10 +34,11 @@ Arena::~Arena() {}
// ======================================================================================= // =======================================================================================
ReaderArena::ReaderArena(std::unique_ptr<ReaderContext> context) ReaderArena::ReaderArena(MessageReader* message)
: context(std::move(context)), : message(message),
readLimiter(this->context->getReadLimit() * WORDS), readLimiter(this->message->getOptions().traversalLimitInWords * WORDS),
segment0(this, SegmentId(0), this->context->getSegment(0), &readLimiter) {} ignoreErrors(false),
segment0(this, SegmentId(0), this->message->getSegment(0), &readLimiter) {}
ReaderArena::~ReaderArena() {} ReaderArena::~ReaderArena() {}
...@@ -61,7 +62,7 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -61,7 +62,7 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
} }
} }
ArrayPtr<const word> newSegment = context->getSegment(id.value); ArrayPtr<const word> newSegment = message->getSegment(id.value);
if (newSegment == nullptr) { if (newSegment == nullptr) {
return nullptr; return nullptr;
} }
...@@ -77,17 +78,26 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -77,17 +78,26 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
} }
void ReaderArena::reportInvalidData(const char* description) { void ReaderArena::reportInvalidData(const char* description) {
context->reportError(description); if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(description);
}
} }
void ReaderArena::reportReadLimitReached() { void ReaderArena::reportReadLimitReached() {
context->reportError("Exceeded read limit."); if (!ignoreErrors) {
message->getOptions().errorReporter->reportError(
"Exceeded message traversal limit. See capnproto::ReaderOptions.");
// Ignore further errors since they are likely repeats or caused by the read limit being
// reached.
ignoreErrors = true;
}
} }
// ======================================================================================= // =======================================================================================
BuilderArena::BuilderArena(std::unique_ptr<BuilderContext> context) BuilderArena::BuilderArena(MessageBuilder* message)
: context(std::move(context)), segment0(nullptr, SegmentId(0), nullptr, nullptr) {} : message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BuilderArena::~BuilderArena() {} BuilderArena::~BuilderArena() {}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) { SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
...@@ -105,7 +115,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable ...@@ -105,7 +115,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable
if (segment0.getArena() == nullptr) { if (segment0.getArena() == nullptr) {
// We're allocating the first segment. // We're allocating the first segment.
ArrayPtr<word> ptr = context->allocateSegment(minimumAvailable / WORDS); ArrayPtr<word> ptr = message->allocateSegment(minimumAvailable / WORDS);
// Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any // Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any
// pointers to this segment yet, so it should be fine. // pointers to this segment yet, so it should be fine.
...@@ -132,7 +142,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable ...@@ -132,7 +142,7 @@ SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable
std::unique_ptr<SegmentBuilder> newBuilder = std::unique_ptr<SegmentBuilder>( std::unique_ptr<SegmentBuilder> newBuilder = std::unique_ptr<SegmentBuilder>(
new SegmentBuilder(this, SegmentId(moreSegments->builders.size() + 1), new SegmentBuilder(this, SegmentId(moreSegments->builders.size() + 1),
context->allocateSegment(minimumAvailable / WORDS), &this->dummyLimiter)); message->allocateSegment(minimumAvailable / WORDS), &this->dummyLimiter));
SegmentBuilder* result = newBuilder.get(); SegmentBuilder* result = newBuilder.get();
moreSegments->builders.push_back(std::move(newBuilder)); moreSegments->builders.push_back(std::move(newBuilder));
......
...@@ -116,8 +116,6 @@ private: ...@@ -116,8 +116,6 @@ private:
word* pos; word* pos;
CAPNPROTO_DISALLOW_COPY(SegmentBuilder); CAPNPROTO_DISALLOW_COPY(SegmentBuilder);
// TODO: Do we need mutex locking?
}; };
class Arena { class Arena {
...@@ -158,7 +156,7 @@ public: ...@@ -158,7 +156,7 @@ public:
class ReaderArena final: public Arena { class ReaderArena final: public Arena {
public: public:
ReaderArena(std::unique_ptr<ReaderContext> context); ReaderArena(MessageReader* message);
~ReaderArena(); ~ReaderArena();
CAPNPROTO_DISALLOW_COPY(ReaderArena); CAPNPROTO_DISALLOW_COPY(ReaderArena);
...@@ -168,8 +166,9 @@ public: ...@@ -168,8 +166,9 @@ public:
void reportReadLimitReached() override; void reportReadLimitReached() override;
private: private:
std::unique_ptr<ReaderContext> context; MessageReader* message;
ReadLimiter readLimiter; ReadLimiter readLimiter;
bool ignoreErrors;
// Optimize for single-segment messages so that small messages are handled quickly. // Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0; SegmentReader segment0;
...@@ -180,7 +179,7 @@ private: ...@@ -180,7 +179,7 @@ private:
class BuilderArena final: public Arena { class BuilderArena final: public Arena {
public: public:
BuilderArena(std::unique_ptr<BuilderContext> context); BuilderArena(MessageBuilder* message);
~BuilderArena(); ~BuilderArena();
CAPNPROTO_DISALLOW_COPY(BuilderArena); CAPNPROTO_DISALLOW_COPY(BuilderArena);
...@@ -204,7 +203,7 @@ public: ...@@ -204,7 +203,7 @@ public:
void reportReadLimitReached() override; void reportReadLimitReached() override;
private: private:
std::unique_ptr<BuilderContext> context; MessageBuilder* message;
ReadLimiter dummyLimiter; ReadLimiter dummyLimiter;
SegmentBuilder segment0; SegmentBuilder segment0;
......
...@@ -227,78 +227,80 @@ void checkMessage(Reader reader) { ...@@ -227,78 +227,80 @@ void checkMessage(Reader reader) {
} }
TEST(Encoding, AllTypes) { TEST(Encoding, AllTypes) {
Message<TestAllTypes>::Builder builder; MallocMessageBuilder builder;
initMessage(builder.initRoot()); initMessage(builder.initRoot<TestAllTypes>());
checkMessage(builder.getRoot()); checkMessage(builder.getRoot<TestAllTypes>());
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestAllTypes>().asReader());
Message<TestAllTypes>::Reader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestAllTypes>());
ASSERT_EQ(1u, builder.getSegmentsForOutput().size()); ASSERT_EQ(1u, builder.getSegmentsForOutput().size());
checkMessage(Message<TestAllTypes>::readTrusted(builder.getSegmentsForOutput()[0].begin())); checkMessage(readMessageTrusted<TestAllTypes>(builder.getSegmentsForOutput()[0].begin()));
} }
TEST(Encoding, AllTypesMultiSegment) { TEST(Encoding, AllTypesMultiSegment) {
Message<TestAllTypes>::Builder builder(newFixedWidthBuilderContext(0)); MallocMessageBuilder builder(0, AllocationStrategy::FIXED_SIZE);
initMessage(builder.initRoot()); initMessage(builder.initRoot<TestAllTypes>());
checkMessage(builder.getRoot()); checkMessage(builder.getRoot<TestAllTypes>());
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestAllTypes>().asReader());
Message<TestAllTypes>::Reader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Encoding, Defaults) { TEST(Encoding, Defaults) {
AlignedData<1> nullRoot = {{0, 0, 0, 0, 0, 0, 0, 0}}; AlignedData<1> nullRoot = {{0, 0, 0, 0, 0, 0, 0, 0}};
ArrayPtr<const word> segments[1] = {arrayPtr(nullRoot.words, 1)}; ArrayPtr<const word> segments[1] = {arrayPtr(nullRoot.words, 1)};
Message<TestDefaults>::Reader reader(arrayPtr(segments, 1)); SegmentArrayMessageReader reader(arrayPtr(segments, 1));
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestDefaults>());
checkMessage(Message<TestDefaults>::readTrusted(nullRoot.words)); checkMessage(readMessageTrusted<TestDefaults>(nullRoot.words));
} }
TEST(Encoding, DefaultInitialization) { TEST(Encoding, DefaultInitialization) {
Message<TestDefaults>::Builder builder; MallocMessageBuilder builder;
checkMessage(builder.getRoot()); // first pass initializes to defaults checkMessage(builder.getRoot<TestDefaults>()); // first pass initializes to defaults
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestDefaults>().asReader());
checkMessage(builder.getRoot()); // second pass just reads the initialized structure checkMessage(builder.getRoot<TestDefaults>()); // second pass just reads the initialized structure
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestDefaults>().asReader());
Message<TestDefaults>::Reader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestDefaults>());
} }
TEST(Encoding, DefaultInitializationMultiSegment) { TEST(Encoding, DefaultInitializationMultiSegment) {
Message<TestDefaults>::Builder builder(newFixedWidthBuilderContext(0)); MallocMessageBuilder builder(0, AllocationStrategy::FIXED_SIZE);
checkMessage(builder.getRoot()); // first pass initializes to defaults // first pass initializes to defaults
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestDefaults>());
checkMessage(builder.getRoot<TestDefaults>().asReader());
checkMessage(builder.getRoot()); // second pass just reads the initialized structure // second pass just reads the initialized structure
checkMessage(builder.getRoot().asReader()); checkMessage(builder.getRoot<TestDefaults>());
checkMessage(builder.getRoot<TestDefaults>().asReader());
Message<TestDefaults>::Reader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestDefaults>());
} }
TEST(Encoding, DefaultsFromEmptyMessage) { TEST(Encoding, DefaultsFromEmptyMessage) {
AlignedData<1> emptyMessage = {{4, 0, 0, 0, 0, 0, 0, 0}}; AlignedData<1> emptyMessage = {{4, 0, 0, 0, 0, 0, 0, 0}};
ArrayPtr<const word> segments[1] = {arrayPtr(emptyMessage.words, 1)}; ArrayPtr<const word> segments[1] = {arrayPtr(emptyMessage.words, 1)};
Message<TestDefaults>::Reader reader(arrayPtr(segments, 1)); SegmentArrayMessageReader reader(arrayPtr(segments, 1));
checkMessage(reader.getRoot()); checkMessage(reader.getRoot<TestDefaults>());
checkMessage(Message<TestDefaults>::readTrusted(emptyMessage.words)); checkMessage(readMessageTrusted<TestDefaults>(emptyMessage.words));
} }
} // namespace } // namespace
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This header contains internal interfaces relied upon by message.h and implemented in message.c++.
// These declarations should be thought of as being part of message.h, but I moved them here to make
// message.h more readable. The problem is that the interface people really care about in
// message.h -- namely, Message -- has to be declared after internal::Message. Having an internal
// interface appear in the middle of the header ahead of the public interface was distracting and
// confusing.
#ifndef CAPNPROTO_MESSAGE_INTERNAL_H_
#define CAPNPROTO_MESSAGE_INTERNAL_H_
#include <cstddef>
#include <memory>
#include "type-safety.h"
#include "wire-format.h"
namespace capnproto {
class ReaderContext;
class BuilderContext;
}
namespace capnproto {
namespace internal {
// TODO: Move to message-internal.h so that this header looks nicer?
class ReaderArena;
class BuilderArena;
struct MessageImpl {
// Underlying implementation of capnproto::Message. All the parts that don't need to be templated
// are implemented by this class, so that they can be shared and non-inline.
MessageImpl() = delete;
class Reader {
public:
Reader(ArrayPtr<const ArrayPtr<const word>> segments);
Reader(std::unique_ptr<ReaderContext> context);
Reader(Reader&& other) = default;
CAPNPROTO_DISALLOW_COPY(Reader);
~Reader();
StructReader getRoot(const word* defaultValue);
private:
uint recursionLimit;
// Space in which we can construct a ReaderArena. We don't use ReaderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a ReaderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages,
// particularly when the context itself is freelisted and so no other allocation is necessary.
void* arenaSpace[15];
ReaderArena* arena() { return reinterpret_cast<ReaderArena*>(arenaSpace); }
};
class Builder {
public:
Builder();
Builder(std::unique_ptr<BuilderContext> context);
Builder(Builder&& other) = default;
CAPNPROTO_DISALLOW_COPY(Builder);
~Builder();
StructBuilder initRoot(const word* defaultValue);
StructBuilder getRoot(const word* defaultValue);
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
private:
SegmentBuilder* rootSegment;
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a BuilderArena because that would require an
// extra malloc on every message which could be expensive when processing small messages,
// particularly when the context itself is freelisted and so no other allocation is necessary.
void* arenaSpace[15];
BuilderArena* arena() { return reinterpret_cast<BuilderArena*>(arenaSpace); }
static SegmentBuilder* allocateRoot(BuilderArena* arena);
};
};
} // namespace internal
} // namespace capnproto
#endif // CAPNPROTO_MESSAGE_INTERNAL_H_
This diff is collapsed.
This diff is collapsed.
...@@ -106,6 +106,77 @@ inline ArrayPtr<T> arrayPtr(T* begin, T* end) { ...@@ -106,6 +106,77 @@ inline ArrayPtr<T> arrayPtr(T* begin, T* end) {
return ArrayPtr<T>(begin, end); return ArrayPtr<T>(begin, end);
} }
template <typename T>
class Array {
// An owned array which will automatically be deleted in the destructor. Can be moved, but not
// copied.
public:
inline Array(): ptr(nullptr), size_(0) {}
inline Array(std::nullptr_t): ptr(nullptr), size_(0) {}
inline Array(Array&& other): ptr(other.ptr), size_(other.size_) {
other.ptr = nullptr;
other.size_ = 0;
}
CAPNPROTO_DISALLOW_COPY(Array);
inline ~Array() noexcept { delete[] ptr; }
inline operator ArrayPtr<T>() {
return ArrayPtr<T>(ptr, size_);
}
inline ArrayPtr<T> asPtr() {
return ArrayPtr<T>(ptr, size_);
}
inline std::size_t size() const { return size_; }
inline T& operator[](std::size_t index) const {
CAPNPROTO_DEBUG_ASSERT(index < size_, "Out-of-bounds Array access.");
return ptr[index];
}
inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; }
inline ArrayPtr<T> slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds Array::slice().");
return ArrayPtr<T>(ptr + start, end - start);
}
inline bool operator==(std::nullptr_t) { return ptr == nullptr; }
inline bool operator!=(std::nullptr_t) { return ptr != nullptr; }
inline Array& operator=(std::nullptr_t) {
delete[] ptr;
ptr = nullptr;
size_ = 0;
return *this;
}
inline Array& operator=(Array&& other) {
delete[] ptr;
ptr = other.ptr;
size_ = other.size_;
other.ptr = nullptr;
other.size_ = 0;
return *this;
}
private:
T* ptr;
std::size_t size_;
inline explicit Array(std::size_t size): ptr(new T[size]), size_(size) {}
template <typename U>
friend Array<U> newArray(size_t size);
};
template <typename T>
inline Array<T> newArray(size_t size) {
return Array<T>(size);
}
// ======================================================================================= // =======================================================================================
// IDs // IDs
...@@ -381,8 +452,8 @@ inline constexpr auto operator*(UnitRatio<Number1, Unit2, Unit> ratio, ...@@ -381,8 +452,8 @@ inline constexpr auto operator*(UnitRatio<Number1, Unit2, Unit> ratio,
// ======================================================================================= // =======================================================================================
// Raw memory types and measures // Raw memory types and measures
class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); }; class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); public: byte() = default; };
class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); }; class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); public: word() = default; };
// byte and word are opaque types with sizes of 8 and 64 bits, respectively. These types are useful // byte and word are opaque types with sizes of 8 and 64 bits, respectively. These types are useful
// only to make pointer arithmetic clearer. Since the contents are private, the only way to access // only to make pointer arithmetic clearer. Since the contents are private, the only way to access
// them is to first reinterpret_cast to some other pointer type. // them is to first reinterpret_cast to some other pointer type.
......
...@@ -264,7 +264,8 @@ static void checkStruct(StructReader reader) { ...@@ -264,7 +264,8 @@ static void checkStruct(StructReader reader) {
} }
TEST(WireFormat, StructRoundTrip_OneSegment) { TEST(WireFormat, StructRoundTrip_OneSegment) {
BuilderArena arena(newBuilderContext()); MallocMessageBuilder message;
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS); SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
...@@ -298,7 +299,8 @@ TEST(WireFormat, StructRoundTrip_OneSegment) { ...@@ -298,7 +299,8 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
} }
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
BuilderArena arena(newFixedWidthBuilderContext(0)); MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE);
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS); SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
...@@ -333,7 +335,8 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { ...@@ -333,7 +335,8 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
} }
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) { TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
BuilderArena arena(newFixedWidthBuilderContext(8)); MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE);
BuilderArena arena(&message);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS); SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
......
...@@ -612,7 +612,7 @@ struct WireHelpers { ...@@ -612,7 +612,7 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference( static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue, SegmentReader* segment, const WireReference* ref, const word* defaultValue,
int recursionLimit)) { int nestingLimit)) {
const word* ptr; const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
...@@ -621,9 +621,9 @@ struct WireHelpers { ...@@ -621,9 +621,9 @@ struct WireHelpers {
ref = reinterpret_cast<const WireReference*>(defaultValue); ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target(); ptr = ref->target();
} else if (segment != nullptr) { } else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) { if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
goto useDefault; goto useDefault;
} }
...@@ -655,25 +655,25 @@ struct WireHelpers { ...@@ -655,25 +655,25 @@ struct WireHelpers {
ref->structRef.fieldCount.get(), ref->structRef.fieldCount.get(),
ref->structRef.dataSize.get(), ref->structRef.dataSize.get(),
ref->structRef.refCount.get(), ref->structRef.refCount.get(),
0 * BITS, recursionLimit - 1); 0 * BITS, nestingLimit - 1);
} }
static CAPNPROTO_ALWAYS_INLINE(ListReader readListReference( static CAPNPROTO_ALWAYS_INLINE(ListReader readListReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue, SegmentReader* segment, const WireReference* ref, const word* defaultValue,
FieldSize expectedElementSize, int recursionLimit)) { FieldSize expectedElementSize, int nestingLimit)) {
const word* ptr; const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr) { if (defaultValue == nullptr) {
return ListReader(nullptr, nullptr, 0 * ELEMENTS, 0 * BITS / ELEMENTS, recursionLimit - 1); return ListReader(nullptr, nullptr, 0 * ELEMENTS, 0 * BITS / ELEMENTS, nestingLimit - 1);
} }
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WireReference*>(defaultValue); ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target(); ptr = ref->target();
} else if (segment != nullptr) { } else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) { if (CAPNPROTO_EXPECT_FALSE(nestingLimit == 0)) {
segment->getArena()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
goto useDefault; goto useDefault;
} }
...@@ -780,7 +780,7 @@ struct WireHelpers { ...@@ -780,7 +780,7 @@ struct WireHelpers {
tag->structRef.fieldCount.get(), tag->structRef.fieldCount.get(),
tag->structRef.dataSize.get(), tag->structRef.dataSize.get(),
tag->structRef.refCount.get(), tag->structRef.refCount.get(),
recursionLimit - 1); nestingLimit - 1);
} else { } else {
// The elements of the list are NOT structs. // The elements of the list are NOT structs.
...@@ -796,7 +796,7 @@ struct WireHelpers { ...@@ -796,7 +796,7 @@ struct WireHelpers {
} }
if (ref->listRef.elementSize() == expectedElementSize) { if (ref->listRef.elementSize() == expectedElementSize) {
return ListReader(segment, ptr, ref->listRef.elementCount(), step, recursionLimit - 1); return ListReader(segment, ptr, ref->listRef.elementCount(), step, nestingLimit - 1);
} else if (expectedElementSize == FieldSize::INLINE_COMPOSITE) { } else if (expectedElementSize == FieldSize::INLINE_COMPOSITE) {
// We were expecting a struct list, but we received a list of some other type. Perhaps a // We were expecting a struct list, but we received a list of some other type. Perhaps a
// non-struct list was recently upgraded to a struct list, but the sender is using the // non-struct list was recently upgraded to a struct list, but the sender is using the
...@@ -832,7 +832,7 @@ struct WireHelpers { ...@@ -832,7 +832,7 @@ struct WireHelpers {
} }
return ListReader(segment, ptr, ref->listRef.elementCount(), step, FieldNumber(1), return ListReader(segment, ptr, ref->listRef.elementCount(), step, FieldNumber(1),
dataSize, referenceCount, recursionLimit - 1); dataSize, referenceCount, nestingLimit - 1);
} else { } else {
CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type."); CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type.");
segment->getArena()->reportInvalidData("A list had incompatible element type."); segment->getArena()->reportInvalidData("A list had incompatible element type.");
...@@ -1028,27 +1028,27 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def ...@@ -1028,27 +1028,27 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def
} }
StructReader StructReader::readRoot(const word* location, const word* defaultValue, StructReader StructReader::readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit) { SegmentReader* segment, int nestingLimit) {
if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) { if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) {
segment->getArena()->reportInvalidData("Root location out-of-bounds."); segment->getArena()->reportInvalidData("Root location out-of-bounds.");
location = nullptr; location = nullptr;
} }
return WireHelpers::readStructReference(segment, reinterpret_cast<const WireReference*>(location), return WireHelpers::readStructReference(segment, reinterpret_cast<const WireReference*>(location),
defaultValue, recursionLimit); defaultValue, nestingLimit);
} }
StructReader StructReader::getStructField( StructReader StructReader::getStructField(
WireReferenceCount refIndex, const word* defaultValue) const { WireReferenceCount refIndex, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex; const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readStructReference(segment, ref, defaultValue, recursionLimit); return WireHelpers::readStructReference(segment, ref, defaultValue, nestingLimit);
} }
ListReader StructReader::getListField( ListReader StructReader::getListField(
WireReferenceCount refIndex, FieldSize expectedElementSize, const word* defaultValue) const { WireReferenceCount refIndex, FieldSize expectedElementSize, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex; const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readListReference( return WireHelpers::readListReference(
segment, ref, defaultValue, expectedElementSize, recursionLimit); segment, ref, defaultValue, expectedElementSize, nestingLimit);
} }
Text::Reader StructReader::getTextField( Text::Reader StructReader::getTextField(
...@@ -1131,10 +1131,10 @@ ListReader ListBuilder::asReader(FieldNumber fieldCount, WordCount dataSize, ...@@ -1131,10 +1131,10 @@ ListReader ListBuilder::asReader(FieldNumber fieldCount, WordCount dataSize,
} }
StructReader ListReader::getStructElement(ElementCount index, const word* defaultValue) const { StructReader ListReader::getStructElement(ElementCount index, const word* defaultValue) const {
if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (recursionLimit == 0))) { if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (nestingLimit == 0))) {
segment->getArena()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.");
return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, recursionLimit); return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, nestingLimit);
} else { } else {
BitCount64 indexBit = ElementCount64(index) * stepBits; BitCount64 indexBit = ElementCount64(index) * stepBits;
const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE; const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE;
...@@ -1142,7 +1142,7 @@ StructReader ListReader::getStructElement(ElementCount index, const word* defaul ...@@ -1142,7 +1142,7 @@ StructReader ListReader::getStructElement(ElementCount index, const word* defaul
segment, structPtr, segment, structPtr,
reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD), reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD),
structFieldCount, structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE, structFieldCount, structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE,
recursionLimit - 1); nestingLimit - 1);
} }
} }
...@@ -1150,7 +1150,7 @@ ListReader ListReader::getListElement( ...@@ -1150,7 +1150,7 @@ ListReader ListReader::getListElement(
WireReferenceCount index, FieldSize expectedElementSize) const { WireReferenceCount index, FieldSize expectedElementSize) const {
return WireHelpers::readListReference( return WireHelpers::readListReference(
segment, reinterpret_cast<const WireReference*>(ptr) + index, segment, reinterpret_cast<const WireReference*>(ptr) + index,
nullptr, expectedElementSize, recursionLimit); nullptr, expectedElementSize, nestingLimit);
} }
Text::Reader ListReader::getTextElement(WireReferenceCount index) const { Text::Reader ListReader::getTextElement(WireReferenceCount index) const {
......
...@@ -159,11 +159,11 @@ class StructReader { ...@@ -159,11 +159,11 @@ class StructReader {
public: public:
inline StructReader() inline StructReader()
: segment(nullptr), data(nullptr), references(nullptr), fieldCount(0), dataSize(0), : segment(nullptr), data(nullptr), references(nullptr), fieldCount(0), dataSize(0),
referenceCount(0), bit0Offset(0 * BITS), recursionLimit(0) {} referenceCount(0), bit0Offset(0 * BITS), nestingLimit(0) {}
static StructReader readRootTrusted(const word* location, const word* defaultValue); static StructReader readRootTrusted(const word* location, const word* defaultValue);
static StructReader readRoot(const word* location, const word* defaultValue, static StructReader readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit); SegmentReader* segment, int nestingLimit);
template <typename T> template <typename T>
CAPNPROTO_ALWAYS_INLINE( CAPNPROTO_ALWAYS_INLINE(
...@@ -213,16 +213,16 @@ private: ...@@ -213,16 +213,16 @@ private:
// instead of the usual zero. This is needed to allow a boolean list to be upgraded to a list // instead of the usual zero. This is needed to allow a boolean list to be upgraded to a list
// of structs. // of structs.
int recursionLimit; int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks. // Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned. // Once this reaches zero, further pointers will be pruned.
inline StructReader(SegmentReader* segment, const void* data, const WireReference* references, inline StructReader(SegmentReader* segment, const void* data, const WireReference* references,
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount referenceCount, FieldNumber fieldCount, WordCount dataSize, WireReferenceCount referenceCount,
BitCount bit0Offset, int recursionLimit) BitCount bit0Offset, int nestingLimit)
: segment(segment), data(data), references(references), fieldCount(fieldCount), : segment(segment), data(data), references(references), fieldCount(fieldCount),
dataSize(dataSize), referenceCount(referenceCount), bit0Offset(bit0Offset), dataSize(dataSize), referenceCount(referenceCount), bit0Offset(bit0Offset),
recursionLimit(recursionLimit) {} nestingLimit(nestingLimit) {}
friend class ListReader; friend class ListReader;
friend class StructBuilder; friend class StructBuilder;
...@@ -306,7 +306,7 @@ public: ...@@ -306,7 +306,7 @@ public:
inline ListReader() inline ListReader()
: segment(nullptr), ptr(nullptr), elementCount(0), : segment(nullptr), ptr(nullptr), elementCount(0),
stepBits(0 * BITS / ELEMENTS), structFieldCount(0), structDataSize(0), stepBits(0 * BITS / ELEMENTS), structFieldCount(0), structDataSize(0),
structReferenceCount(0), recursionLimit(0) {} structReferenceCount(0), nestingLimit(0) {}
inline ElementCount size(); inline ElementCount size();
// The number of elements in the list. // The number of elements in the list.
...@@ -348,22 +348,22 @@ private: ...@@ -348,22 +348,22 @@ private:
// only used to check for field presence; the data size is also used to compute the reference // only used to check for field presence; the data size is also used to compute the reference
// pointer. // pointer.
int recursionLimit; int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks. // Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned. // Once this reaches zero, further pointers will be pruned.
inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount, inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount,
decltype(BITS / ELEMENTS) stepBits, int recursionLimit) decltype(BITS / ELEMENTS) stepBits, int nestingLimit)
: segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits), : segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits),
structFieldCount(0), structDataSize(0), structReferenceCount(0), structFieldCount(0), structDataSize(0), structReferenceCount(0),
recursionLimit(recursionLimit) {} nestingLimit(nestingLimit) {}
inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount, inline ListReader(SegmentReader* segment, const void* ptr, ElementCount elementCount,
decltype(BITS / ELEMENTS) stepBits, decltype(BITS / ELEMENTS) stepBits,
FieldNumber structFieldCount, WordCount structDataSize, FieldNumber structFieldCount, WordCount structDataSize,
WireReferenceCount structReferenceCount, int recursionLimit) WireReferenceCount structReferenceCount, int nestingLimit)
: segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits), : segment(segment), ptr(ptr), elementCount(elementCount), stepBits(stepBits),
structFieldCount(structFieldCount), structDataSize(structDataSize), structFieldCount(structFieldCount), structDataSize(structDataSize),
structReferenceCount(structReferenceCount), recursionLimit(recursionLimit) {} structReferenceCount(structReferenceCount), nestingLimit(nestingLimit) {}
friend class StructReader; friend class StructReader;
friend class ListBuilder; friend class ListBuilder;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment