Commit 6a72d324 authored by Kenton Varda's avatar Kenton Varda

Rename the 'Message' interfaces to Arena and make them internal. Make a new,…

Rename the 'Message' interfaces to Arena and make them internal.  Make a new, intuitive 'Message' interface for creating / reading messages.
parent 20b1ea55
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "arena.h"
#include "message.h"
#include <vector>
#include <string.h>
#include <iostream>
namespace capnproto {
namespace internal {
Arena::~Arena() {}
// =======================================================================================
ReaderArena::ReaderArena(
ArrayPtr<const ArrayPtr<const word>> segments,
ErrorReporter* errorReporter,
WordCount64 readLimit)
: segments(segments),
errorReporter(errorReporter),
readLimiter(readLimit) {
segmentReaders.reserve(segments.size());
uint i = 0;
for (auto segment: segments) {
segmentReaders.emplace_back(new SegmentReader(this, SegmentId(i++), segment, &readLimiter));
}
}
ReaderArena::~ReaderArena() {}
ArrayPtr<const ArrayPtr<const word>> ReaderArena::getSegmentsForOutput() {
return segments;
}
SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
if (id.value >= segments.size()) {
return nullptr;
} else {
return segmentReaders[id.value].get();
}
}
void ReaderArena::reportInvalidData(const char* description) {
errorReporter->reportError(description);
}
void ReaderArena::reportReadLimitReached() {
errorReporter->reportError("Exceeded read limit.");
}
// =======================================================================================
BuilderArena::BuilderArena(Allocator* allocator): allocator(allocator) {}
BuilderArena::~BuilderArena() {
// TODO: This is wrong because we aren't taking into account how much of each segment is actually
// allocated.
uint i = 0;
for (ArrayPtr<word> ptr: memory) {
// The memory array contains Array<const word> only to ease implementation of getSegmentsForOutput().
// We actually own this space and can de-constify it.
allocator->free(SegmentId(i++), ptr);
}
}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
return segments[id.value].get();
}
SegmentBuilder* BuilderArena::getSegmentWithAvailable(WordCount minimumAvailable) {
if (segments.empty() || segments.back()->available() < minimumAvailable) {
ArrayPtr<word> array = allocator->allocate(
SegmentId(segments.size()), minimumAvailable / WORDS);
memory.push_back(array);
segments.push_back(std::unique_ptr<SegmentBuilder>(new SegmentBuilder(
this, SegmentId(segments.size()), array, &dummyLimiter)));
}
return segments.back().get();
}
ArrayPtr<const ArrayPtr<const word>> BuilderArena::getSegmentsForOutput() {
segmentsForOutput.resize(segments.size());
for (uint i = 0; i < segments.size(); i++) {
segmentsForOutput[i] = segments[i]->currentlyAllocated();
}
return arrayPtr(&*segmentsForOutput.begin(), segmentsForOutput.size());
}
SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
if (id.value >= segments.size()) {
return nullptr;
} else {
return segments[id.value].get();
}
}
void BuilderArena::reportInvalidData(const char* description) {
// TODO: Better error reporting.
std::cerr << "BuilderArena: Parse error: " << description << std::endl;
}
void BuilderArena::reportReadLimitReached() {
// TODO: Better error reporting.
std::cerr << "BuilderArena: Exceeded read limit." << std::endl;
}
} // namespace internal
} // namespace capnproto
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// THIS HEADER IS NOT INCLUDABLE BY CLIENTS, even generated code. It is entirely internal to the
// library, which means we can safely #include STL stuff.
#include <vector>
#include <memory>
#include "macros.h"
#include "type-safety.h"
#include "message.h"
#ifndef CAPNPROTO_ARENA_H_
#define CAPNPROTO_ARENA_H_
namespace capnproto {
namespace internal {
class SegmentReader;
class SegmentBuilder;
class Arena;
class ReaderArena;
class BuilderArena;
class ReadLimiter;
class ReadLimiter {
// Used to keep track of how much data has been processed from a message, and cut off further
// processing if and when a particular limit is reached. This is primarily intended to guard
// against maliciously-crafted messages which contain cycles or overlapping structures. Cycles
// and overlapping are not permitted by the Cap'n Proto format because in many cases they could
// be used to craft a deceptively small message which could consume excessive server resources to
// process, perhaps even sending it into an infinite loop. Actually detecting overlaps would be
// time-consuming, so instead we just keep track of how many words worth of data structures the
// receiver has actually dereferenced and error out if this gets too high.
//
// This counting takes place as you call getters (for non-primitive values) on the message
// readers. If you call the same getter twice, the data it returns may be double-counted. This
// should not be a big deal in most cases -- just set the read limit high enough that it will
// only trigger in unreasonable cases.
public:
inline explicit ReadLimiter(); // No limit.
inline explicit ReadLimiter(WordCount64 limit); // Limit to the given number of words.
CAPNPROTO_ALWAYS_INLINE(bool canRead(WordCount amount, Arena* arena));
private:
WordCount64 limit;
CAPNPROTO_DISALLOW_COPY(ReadLimiter);
};
class Arena {
public:
virtual ~Arena();
virtual ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput() = 0;
// Get an array of all the segments, suitable for writing out. For BuilderArena, this only
// returns the allocated portion of each segment, whereas tryGetSegment() returns something that
// includes not-yet-allocated space.
virtual SegmentReader* tryGetSegment(SegmentId id) = 0;
// Gets the segment with the given ID, or return nullptr if no such segment exists.
virtual void reportInvalidData(const char* description) = 0;
// Called to report that the message data is invalid.
//
// Implementations should, ideally, report the error to the sender, if possible. They may also
// want to write a debug message, etc.
//
// Implementations may choose to throw an exception in order to cut short further processing of
// the message. If no exception is thrown, then the caller will attempt to work around the
// invalid data by using a default value instead. This is good enough to guard against
// maliciously-crafted messages (the sender could just as easily have sent a perfectly-valid
// message containing the default value), but in the case of accidentally-corrupted messages this
// behavior may propagate the corruption.
//
// TODO: Give more information about the error, e.g. the segment and offset at which the invalid
// data was encountered, any relevant type/field names if known, etc.
virtual void reportReadLimitReached() = 0;
// Called to report that the read limit has been reached. See ReadLimiter, below.
//
// As with reportInvalidData(), this may throw an exception, and if it doesn't, default values
// will be used in place of the actual message data.
//
// If this method returns rather that throwing, many other errors are likely to be reported as
// a side-effect of reading being blocked. The Arena should ignore all further errors
// after this call.
// TODO: Methods to deal with bundled capabilities.
};
class ReaderArena final: public Arena {
public:
ReaderArena(ArrayPtr<const ArrayPtr<const word>> segments, ErrorReporter* errorReporter,
WordCount64 readLimit);
~ReaderArena();
// implements Arena ------------------------------------------------
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput() override;
SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override;
private:
ArrayPtr<const ArrayPtr<const word>> segments;
ErrorReporter* errorReporter;
ReadLimiter readLimiter;
std::vector<std::unique_ptr<SegmentReader>> segmentReaders;
};
class BuilderArena final: public Arena {
public:
BuilderArena(Allocator* allocator);
~BuilderArena();
SegmentBuilder* getSegment(SegmentId id);
// Get the segment with the given id. Crashes or throws an exception if no such segment exists.
SegmentBuilder* getSegmentWithAvailable(WordCount minimumAvailable);
// Get a segment which has at least the given amount of space available, allocating it if
// necessary. Crashes or throws an exception if there is not enough memory.
// TODO: Methods to deal with bundled capabilities.
// implements Arena ------------------------------------------------
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput() override;
SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
void reportReadLimitReached() override;
private:
Allocator* allocator;
std::vector<std::unique_ptr<SegmentBuilder>> segments;
std::vector<ArrayPtr<word>> memory;
std::vector<ArrayPtr<const word>> segmentsForOutput;
ReadLimiter dummyLimiter;
};
class SegmentReader {
public:
inline SegmentReader(Arena* arena, SegmentId id, ArrayPtr<const word> ptr,
ReadLimiter* readLimiter);
CAPNPROTO_ALWAYS_INLINE(bool containsInterval(const word* from, const word* to));
inline Arena* getArena();
inline SegmentId getSegmentId();
inline const word* getStartPtr();
inline WordCount getOffsetTo(const word* ptr);
inline WordCount getSize();
private:
Arena* arena;
SegmentId id;
ArrayPtr<const word> ptr;
ReadLimiter* readLimiter;
CAPNPROTO_DISALLOW_COPY(SegmentReader);
friend class SegmentBuilder;
};
class SegmentBuilder: public SegmentReader {
public:
inline SegmentBuilder(BuilderArena* arena, SegmentId id, ArrayPtr<word> ptr,
ReadLimiter* readLimiter);
CAPNPROTO_ALWAYS_INLINE(word* allocate(WordCount amount));
inline word* getPtrUnchecked(WordCount offset);
inline BuilderArena* getArena();
inline WordCount available();
inline ArrayPtr<const word> currentlyAllocated();
private:
word* pos;
CAPNPROTO_DISALLOW_COPY(SegmentBuilder);
// TODO: Do we need mutex locking?
};
// =======================================================================================
inline ReadLimiter::ReadLimiter()
// I didn't want to #include <limits> just for this one lousy constant.
: limit(uint64_t(0x7fffffffffffffffll) * WORDS) {}
inline ReadLimiter::ReadLimiter(WordCount64 limit): limit(limit) {}
inline bool ReadLimiter::canRead(WordCount amount, Arena* arena) {
if (CAPNPROTO_EXPECT_FALSE(amount > limit)) {
arena->reportReadLimitReached();
return false;
} else {
limit -= amount;
return true;
}
}
// -------------------------------------------------------------------
inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, ArrayPtr<const word> ptr,
ReadLimiter* readLimiter)
: arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {}
inline bool SegmentReader::containsInterval(const word* from, const word* to) {
return from >= this->ptr.begin() && to <= this->ptr.end() &&
readLimiter->canRead(intervalLength(from, to), arena);
}
inline Arena* SegmentReader::getArena() { return arena; }
inline SegmentId SegmentReader::getSegmentId() { return id; }
inline const word* SegmentReader::getStartPtr() { return ptr.begin(); }
inline WordCount SegmentReader::getOffsetTo(const word* ptr) {
return intervalLength(this->ptr.begin(), ptr);
}
inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
// -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder(
BuilderArena* arena, SegmentId id, ArrayPtr<word> ptr, ReadLimiter* readLimiter)
: SegmentReader(arena, id, ptr, readLimiter),
pos(ptr.begin()) {}
inline word* SegmentBuilder::allocate(WordCount amount) {
if (amount > intervalLength(pos, ptr.end())) {
return nullptr;
} else {
// TODO: Atomic increment, backtracking if we go over, would make this thread-safe. How much
// would it cost in the single-threaded case? Is it free? Benchmark it.
word* result = pos;
pos += amount;
return result;
}
}
inline word* SegmentBuilder::getPtrUnchecked(WordCount offset) {
// const_cast OK because SegmentBuilder's constructor always initializes its SegmentReader base
// class with a pointer that was originally non-const.
return const_cast<word*>(ptr.begin() + offset);
}
inline BuilderArena* SegmentBuilder::getArena() {
// Down-cast safe because SegmentBuilder's constructor always initializes its SegmentReader base
// class with an Arena pointer that actually points to a BuilderArena.
return static_cast<BuilderArena*>(arena);
}
inline WordCount SegmentBuilder::available() {
return intervalLength(pos, ptr.end());
}
inline ArrayPtr<const word> SegmentBuilder::currentlyAllocated() {
return arrayPtr(ptr.begin(), pos - ptr.begin());
}
} // namespace internal
} // namespace capnproto
#endif // CAPNPROTO_ARENA_H_
...@@ -227,56 +227,86 @@ void checkMessage(Reader reader) { ...@@ -227,56 +227,86 @@ void checkMessage(Reader reader) {
} }
TEST(Encoding, AllTypes) { TEST(Encoding, AllTypes) {
auto root = newMallocMessageRoot<TestAllTypes>(); Message<TestAllTypes>::Builder builder;
initMessage(root.builder); initMessage(builder.initRoot());
checkMessage(root.builder); checkMessage(builder.getRoot());
checkMessage(root.builder.asReader()); checkMessage(builder.getRoot().asReader());
Message<TestAllTypes>::Reader reader(
builder.getSegmentsForOutput(), 64, 1 << 30, ThrowingErrorReporter::getDefaultInstance());
checkMessage(reader.getRoot());
ASSERT_EQ(1u, builder.getSegmentsForOutput().size());
checkMessage(Message<TestAllTypes>::readTrusted(builder.getSegmentsForOutput()[0].begin()));
} }
TEST(Encoding, AllTypesMultiSegment) { TEST(Encoding, AllTypesMultiSegment) {
auto root = newMallocMessageRoot<TestAllTypes>(0 * WORDS); MallocAllocator allocator(0);
Message<TestAllTypes>::Builder builder(&allocator);
initMessage(root.builder); initMessage(builder.initRoot());
checkMessage(root.builder); checkMessage(builder.getRoot());
checkMessage(root.builder.asReader()); checkMessage(builder.getRoot().asReader());
Message<TestAllTypes>::Reader reader(
builder.getSegmentsForOutput(), 64, 1 << 30, ThrowingErrorReporter::getDefaultInstance());
checkMessage(reader.getRoot());
} }
TEST(Encoding, Defaults) { TEST(Encoding, Defaults) {
auto root = newMallocMessageRoot<TestDefaults>(); AlignedData<1> nullRoot = {{0, 0, 0, 0, 0, 0, 0, 0}};
ArrayPtr<const word> segments[1] = {arrayPtr(nullRoot.words, 1)};
Message<TestDefaults>::Reader reader(arrayPtr(segments, 1), 64, 1 << 30,
ThrowingErrorReporter::getDefaultInstance());
checkMessage(root.builder.asReader()); checkMessage(reader.getRoot());
checkMessage(Message<TestDefaults>::readTrusted(nullRoot.words));
} }
TEST(Encoding, DefaultInitialization) { TEST(Encoding, DefaultInitialization) {
auto root = newMallocMessageRoot<TestDefaults>(); Message<TestDefaults>::Builder builder;
checkMessage(builder.getRoot()); // first pass initializes to defaults
checkMessage(builder.getRoot().asReader());
checkMessage(root.builder); checkMessage(builder.getRoot()); // second pass just reads the initialized structure
checkMessage(root.builder.asReader()); checkMessage(builder.getRoot().asReader());
Message<TestDefaults>::Reader reader(
builder.getSegmentsForOutput(), 64, 1 << 30, ThrowingErrorReporter::getDefaultInstance());
checkMessage(reader.getRoot());
} }
TEST(Encoding, DefaultInitializationMultiSegment) { TEST(Encoding, DefaultInitializationMultiSegment) {
auto root = newMallocMessageRoot<TestDefaults>(0 * WORDS); MallocAllocator allocator(0);
Message<TestDefaults>::Builder builder(&allocator);
checkMessage(builder.getRoot()); // first pass initializes to defaults
checkMessage(builder.getRoot().asReader());
checkMessage(root.builder); checkMessage(builder.getRoot()); // second pass just reads the initialized structure
checkMessage(root.builder.asReader()); checkMessage(builder.getRoot().asReader());
Message<TestDefaults>::Reader reader(
builder.getSegmentsForOutput(), 64, 1 << 30, ThrowingErrorReporter::getDefaultInstance());
checkMessage(reader.getRoot());
} }
TEST(Encoding, DefaultsNotOnWire) { TEST(Encoding, DefaultsFromEmptyMessage) {
AlignedData<1> emptyMessage = {{4, 0, 0, 0, 0, 0, 0, 0}}; AlignedData<1> emptyMessage = {{4, 0, 0, 0, 0, 0, 0, 0}};
std::unique_ptr<MessageBuilder> message = newMallocMessage(512 * WORDS); ArrayPtr<const word> segments[1] = {arrayPtr(emptyMessage.words, 1)};
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS); Message<TestDefaults>::Reader reader(arrayPtr(segments, 1), 64, 1 << 30,
word* rootLocation = segment->allocate(1 * WORDS); ThrowingErrorReporter::getDefaultInstance());
memcpy(rootLocation, emptyMessage.words, sizeof(word));
TestDefaults::Reader reader(StructReader::readRoot(
emptyMessage.words, TestDefaults::DEFAULT.words, segment, 64));
checkMessage(reader);
TestDefaults::Reader reader2(StructReader::readRootTrusted( checkMessage(reader.getRoot());
emptyMessage.words, TestDefaults::DEFAULT.words)); checkMessage(Message<TestDefaults>::readTrusted(emptyMessage.words));
checkMessage(reader2);
} }
} // namespace } // namespace
......
...@@ -22,10 +22,9 @@ ...@@ -22,10 +22,9 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "macros.h" #include "macros.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <exception> #include <exception>
#include <string>
#include <unistd.h>
namespace capnproto { namespace capnproto {
namespace internal { namespace internal {
...@@ -35,39 +34,33 @@ class Exception: public std::exception { ...@@ -35,39 +34,33 @@ class Exception: public std::exception {
public: public:
Exception(const char* file, int line, const char* expectation, const char* message); Exception(const char* file, int line, const char* expectation, const char* message);
virtual ~Exception() noexcept; ~Exception() noexcept;
const char* getFile() { return file; } const char* what() const noexcept override;
int getLine() { return line; }
const char* getExpectation() { return expectation; }
const char* getMessage() { return message; }
virtual const char* what();
private: private:
const char* file; std::string description;
int line;
const char* expectation;
const char* message;
char* whatBuffer;
}; };
Exception::Exception( Exception::Exception(
const char* file, int line, const char* expectation, const char* message) const char* file, int line, const char* expectation, const char* message) {
: file(file), line(line), expectation(expectation), message(message), whatBuffer(nullptr) { description = "Captain Proto debug assertion failed:\n ";
fprintf(stderr, "Captain Proto debug assertion failed:\n %s:%d: %s\n %s", description += file;
file, line, expectation, message); description += ':';
} description += line;
description += ": ";
description += expectation;
description += "\n ";
description += message;
description += "\n";
Exception::~Exception() noexcept { write(STDERR_FILENO, description.data(), description.size());
delete [] whatBuffer;
} }
const char* Exception::what() { Exception::~Exception() noexcept {}
whatBuffer = new char[strlen(file) + strlen(expectation) + strlen(message) + 256];
sprintf(whatBuffer, "Captain Proto debug assertion failed:\n %s:%d: %s\n %s", const char* Exception::what() const noexcept {
file, line, expectation, message); return description.c_str();
return whatBuffer;
} }
void assertionFailure(const char* file, int line, const char* expectation, const char* message) { void assertionFailure(const char* file, int line, const char* expectation, const char* message) {
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This header contains internal interfaces relied upon by message.h and implemented in message.c++.
// These declarations should be thought of as being part of message.h, but I moved them here to make
// message.h more readable. The problem is that the interface people really care about in
// message.h -- namely, Message -- has to be declared after internal::Message. Having an internal
// interface appear in the middle of the header ahead of the public interface was distracting and
// confusing.
#ifndef CAPNPROTO_MESSAGE_INTERNAL_H_
#define CAPNPROTO_MESSAGE_INTERNAL_H_
#include <cstddef>
#include <memory>
#include "type-safety.h"
#include "wire-format.h"
namespace capnproto {
class Allocator;
class ErrorReporter;
}
namespace capnproto {
namespace internal {
// TODO: Move to message-internal.h so that this header looks nicer?
class Arena;
class BuilderArena;
struct MessageImpl {
// Underlying implementation of capnproto::Message. All the parts that don't need to be templated
// are implemented by this class, so that they can be shared and non-inline.
MessageImpl() = delete;
class Reader {
public:
Reader(ArrayPtr<const ArrayPtr<const word>> segments,
uint recursionLimit, uint64_t readLimit, ErrorReporter* errorReporter);
Reader(Reader&& other) = default;
CAPNPROTO_DISALLOW_COPY(Reader);
~Reader();
StructReader getRoot(const word* defaultValue);
private:
std::unique_ptr<Arena> arena;
uint recursionLimit;
};
class Builder {
public:
Builder();
Builder(Allocator* allocator);
Builder(Builder&& other) = default;
CAPNPROTO_DISALLOW_COPY(Builder);
~Builder();
StructBuilder initRoot(const word* defaultValue);
StructBuilder getRoot(const word* defaultValue);
ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
private:
std::unique_ptr<BuilderArena> arena;
SegmentBuilder* rootSegment;
static SegmentBuilder* allocateRoot(BuilderArena* arena);
};
};
} // namespace internal
} // namespace capnproto
#endif // CAPNPROTO_MESSAGE_INTERNAL_H_
...@@ -22,75 +22,131 @@ ...@@ -22,75 +22,131 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "message.h" #include "message.h"
#include <vector> #include "arena.h"
#include <string.h> #include "stdlib.h"
#include <iostream> #include <exception>
#include <stdlib.h> #include <string>
#include <unistd.h>
namespace capnproto { namespace capnproto {
MessageReader::~MessageReader() {} Allocator::~Allocator() {}
MessageBuilder::~MessageBuilder() {} ErrorReporter::~ErrorReporter() {}
class MallocMessage: public MessageBuilder { MallocAllocator::MallocAllocator(uint preferredSegmentSizeWords)
: preferredSegmentSizeWords(preferredSegmentSizeWords) {}
MallocAllocator::~MallocAllocator() {}
MallocAllocator* MallocAllocator::getDefaultInstance() {
static MallocAllocator defaultInstance(1024);
return &defaultInstance;
}
ArrayPtr<word> MallocAllocator::allocate(SegmentId id, uint minimumSize) {
uint size = std::max(minimumSize, preferredSegmentSizeWords);
return arrayPtr(reinterpret_cast<word*>(calloc(size, sizeof(word))), size);
}
void MallocAllocator::free(SegmentId id, ArrayPtr<word> ptr) {
::free(ptr.begin());
}
StderrErrorReporter::~StderrErrorReporter() {}
StderrErrorReporter* StderrErrorReporter::getDefaultInstance() {
static StderrErrorReporter defaultInstance;
return &defaultInstance;
}
void StderrErrorReporter::reportError(const char* description) {
std::string message("ERROR: Cap'n Proto parse error: ");
message += description;
message += '\n';
write(STDERR_FILENO, message.data(), message.size());
}
class ParseException: public std::exception {
public: public:
MallocMessage(WordCount preferredSegmentSize); ParseException(const char* description);
~MallocMessage(); ~ParseException() noexcept;
SegmentReader* tryGetSegment(SegmentId id); const char* what() const noexcept override;
void reportInvalidData(const char* description);
void reportReadLimitReached();
SegmentBuilder* getSegment(SegmentId id);
SegmentBuilder* getSegmentWithAvailable(WordCount minimumAvailable);
private: private:
WordCount preferredSegmentSize; std::string description;
std::vector<std::unique_ptr<SegmentBuilder>> segments;
std::vector<word*> memory;
}; };
MallocMessage::MallocMessage(WordCount preferredSegmentSize) ParseException::ParseException(const char* description)
: preferredSegmentSize(preferredSegmentSize) {} : description(description) {}
MallocMessage::~MallocMessage() {
for (word* ptr: memory) { ParseException::~ParseException() noexcept {}
free(ptr);
} const char* ParseException::what() const noexcept {
return description.c_str();
}
ThrowingErrorReporter::~ThrowingErrorReporter() {}
ThrowingErrorReporter* ThrowingErrorReporter::getDefaultInstance() {
static ThrowingErrorReporter defaultInstance;
return &defaultInstance;
}
void ThrowingErrorReporter::reportError(const char* description) {
throw ParseException(description);
} }
SegmentReader* MallocMessage::tryGetSegment(SegmentId id) { // =======================================================================================
if (id.value >= segments.size()) {
return nullptr; namespace internal {
MessageImpl::Reader::Reader(ArrayPtr<const ArrayPtr<const word>> segments,
uint recursionLimit, uint64_t readLimit, ErrorReporter* errorReporter)
: arena(new ReaderArena(segments, errorReporter, readLimit * WORDS)),
recursionLimit(recursionLimit) {}
MessageImpl::Reader::~Reader() {}
StructReader MessageImpl::Reader::getRoot(const word* defaultValue) {
SegmentReader* segment = arena->tryGetSegment(SegmentId(0));
if (segment == nullptr ||
!segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1)) {
segment->getArena()->reportInvalidData("Message did not contain a root pointer.");
return StructReader::readRootTrusted(defaultValue, defaultValue);
} else { } else {
return segments[id.value].get(); return StructReader::readRoot(segment->getStartPtr(), defaultValue, segment, recursionLimit);
} }
} }
void MallocMessage::reportInvalidData(const char* description) { MessageImpl::Builder::Builder()
// TODO: Better error reporting. : arena(new BuilderArena(MallocAllocator::getDefaultInstance())),
std::cerr << "MallocMessage: Parse error: " << description << std::endl; rootSegment(allocateRoot(arena.get())) {}
} MessageImpl::Builder::Builder(Allocator* allocator)
: arena(new BuilderArena(allocator)),
rootSegment(allocateRoot(arena.get())) {}
MessageImpl::Builder::~Builder() {}
void MallocMessage::reportReadLimitReached() { StructBuilder MessageImpl::Builder::initRoot(const word* defaultValue) {
// TODO: Better error reporting. return StructBuilder::initRoot(
std::cerr << "MallocMessage: Exceeded read limit." << std::endl; rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
} }
SegmentBuilder* MallocMessage::getSegment(SegmentId id) { StructBuilder MessageImpl::Builder::getRoot(const word* defaultValue) {
return segments[id.value].get(); return StructBuilder::getRoot(rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
} }
SegmentBuilder* MallocMessage::getSegmentWithAvailable(WordCount minimumAvailable) { ArrayPtr<const ArrayPtr<const word>> MessageImpl::Builder::getSegmentsForOutput() {
if (segments.empty() || segments.back()->available() < minimumAvailable) { return arena->getSegmentsForOutput();
WordCount newSize = std::max(minimumAvailable, preferredSegmentSize);
memory.push_back(reinterpret_cast<word*>(calloc(newSize / WORDS, sizeof(word))));
segments.push_back(std::unique_ptr<SegmentBuilder>(new SegmentBuilder(
this, SegmentId(segments.size()), memory.back(), newSize)));
}
return segments.back().get();
} }
std::unique_ptr<MessageBuilder> newMallocMessage(WordCount preferredSegmentSize) { SegmentBuilder* MessageImpl::Builder::allocateRoot(BuilderArena* arena) {
return std::unique_ptr<MessageBuilder>(new MallocMessage(preferredSegmentSize)); WordCount refSize = 1 * REFERENCES * WORDS_PER_REFERENCE;
SegmentBuilder* segment = arena->getSegmentWithAvailable(refSize);
CAPNPROTO_ASSERT(segment->getSegmentId() == SegmentId(0),
"First allocated word of new arena was not in segment ID 0.");
word* location = segment->allocate(refSize);
CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS),
"First allocated word of new arena was not the first word in its segment.");
return segment;
} }
} // namespace internal
} // namespace capnproto } // namespace capnproto
...@@ -26,251 +26,173 @@ ...@@ -26,251 +26,173 @@
#include "macros.h" #include "macros.h"
#include "type-safety.h" #include "type-safety.h"
#include "wire-format.h" #include "wire-format.h"
#include "list.h" #include "message-internal.h"
#ifndef CAPNPROTO_MESSAGE_H_ #ifndef CAPNPROTO_MESSAGE_H_
#define CAPNPROTO_MESSAGE_H_ #define CAPNPROTO_MESSAGE_H_
namespace capnproto { namespace capnproto {
class SegmentReader; class Segment;
class SegmentBuilder; typedef Id<uint32_t, Segment> SegmentId;
class MessageReader;
class MessageBuilder;
class ReadLimiter;
typedef Id<uint32_t, SegmentReader> SegmentId; // =======================================================================================
class MessageReader {
// Abstract interface encapsulating a readable message. By implementing this interface, you can
// control how memory is allocated for the message. Or use MallocMessage to make things easy.
class Allocator {
public: public:
virtual ~MessageReader(); virtual ~Allocator();
virtual SegmentReader* tryGetSegment(SegmentId id) = 0;
// Gets the segment with the given ID, or return nullptr if no such segment exists.
virtual void reportInvalidData(const char* description) = 0; virtual ArrayPtr<word> allocate(SegmentId id, uint minimumSize) = 0;
// Called to report that the message data is invalid. virtual void free(SegmentId id, ArrayPtr<word> ptr) = 0;
// };
// Implementations should, ideally, report the error to the sender, if possible. They may also
// want to write a debug message, etc.
//
// Implementations may choose to throw an exception in order to cut short further processing of
// the message. If no exception is thrown, then the caller will attempt to work around the
// invalid data by using a default value instead. This is good enough to guard against
// maliciously-crafted messages (the sender could just as easily have sent a perfectly-valid
// message containing the default value), but in the case of accidentally-corrupted messages this
// behavior may propagate the corruption.
//
// TODO: Give more information about the error, e.g. the segment and offset at which the invalid
// data was encountered, any relevant type/field names if known, etc.
virtual void reportReadLimitReached() = 0; class ErrorReporter {
// Called to report that the read limit has been reached. See ReadLimiter, below. public:
// virtual ~ErrorReporter();
// As with reportInvalidData(), this may throw an exception, and if it doesn't, default values
// will be used in place of the actual message data.
//
// If this method returns rather that throwing, many other errors are likely to be reported as
// a side-effect of reading being blocked. The MessageReader should ignore all further errors
// after this call.
// TODO: Methods to deal with bundled capabilities. virtual void reportError(const char* description) = 0;
}; };
class MessageBuilder: public MessageReader { // =======================================================================================
// Abstract interface encapsulating a writable message. By implementing this interface, you can
// control how memory is allocated for the message. Or use MallocMessage to make things easy.
public: template <typename RootType>
virtual ~MessageBuilder(); struct Message {
Message() = delete;
virtual SegmentBuilder* getSegment(SegmentId id) = 0; class Reader {
// Get the segment with the given id. Crashes or throws an exception if no such segment exists. public:
Reader(ArrayPtr<const ArrayPtr<const word>> segments,
uint recursionLimit, uint64_t readLimit, ErrorReporter* errorReporter);
Reader(Reader&& other) = default;
CAPNPROTO_DISALLOW_COPY(Reader);
virtual SegmentBuilder* getSegmentWithAvailable(WordCount minimumAvailable) = 0; typename RootType::Reader getRoot();
// Get a segment which has at least the given amount of space available, allocating it if
// necessary. Crashes or throws an exception if there is not enough memory.
// TODO: Methods to deal with bundled capabilities. private:
}; internal::MessageImpl::Reader internal;
};
std::unique_ptr<MessageBuilder> newMallocMessage(WordCount preferredSegmentSize = 512 * WORDS); class Builder {
// Returns a simple MessageBuilder implementation that uses standard allocation. public:
Builder();
// Make a Builder that allocates using malloc, using the default segment size.
template <typename T> Builder(Allocator* allocator);
struct MessageRoot { // Make a Builder that allocates memory using the given allocator.
std::unique_ptr<MessageBuilder> message;
typename T::Builder builder;
MessageRoot() = default; Builder(Builder&& other) = default;
MessageRoot(std::unique_ptr<MessageBuilder> message, typename T::Builder builder) CAPNPROTO_DISALLOW_COPY(Builder);
: message(move(message)), builder(builder) {}
};
template <typename T> typename RootType::Builder initRoot();
MessageRoot<T> newMallocMessageRoot(WordCount preferredSegmentSize = 512 * WORDS); typename RootType::Builder getRoot();
// Starts a new message with the given root type backed by a MallocMessage.
// T must be a Cap'n Proto struct type.
class ReadLimiter {
// Used to keep track of how much data has been processed from a message, and cut off further
// processing if and when a particular limit is reached. This is primarily intended to guard
// against maliciously-crafted messages which contain cycles or overlapping structures. Cycles
// and overlapping are not permitted by the Cap'n Proto format because in many cases they could
// be used to craft a deceptively small message which could consume excessive server resources to
// process, perhaps even sending it into an infinite loop. Actually detecting overlaps would be
// time-consuming, so instead we just keep track of how many words worth of data structures the
// receiver has actually dereferenced and error out if this gets too high.
//
// This counting takes place as you call getters (for non-primitive values) on the message
// readers. If you call the same getter twice, the data it returns may be double-counted. This
// should not be a big deal in most cases -- just set the read limit high enough that it will
// only trigger in unreasonable cases.
public: ArrayPtr<const ArrayPtr<const word>> getSegmentsForOutput();
inline explicit ReadLimiter(); // No limit.
inline explicit ReadLimiter(WordCount64 limit); // Limit to the given number of words.
CAPNPROTO_ALWAYS_INLINE(bool canRead(WordCount amount, MessageReader* message));
private: private:
WordCount64 limit; internal::MessageImpl::Builder internal;
};
CAPNPROTO_DISALLOW_COPY(ReadLimiter); static typename RootType::Reader readTrusted(const word* data);
// IF THE INPUT IS INVALID, THIS MAY CRASH, CORRUPT MEMORY, CREATE A SECURITY HOLE IN YOUR APP,
// MURDER YOUR FIRST-BORN CHILD, AND/OR BRING ABOUT ETERNAL DAMNATION ON ALL OF HUMANITY. DO NOT
// USE UNLESS YOU UNDERSTAND THE CONSEQUENCES.
//
// Given a pointer to a known-valid message located in a single contiguous memory segment,
// returns a reader for that message. No bounds-checking will be done while tranversing this
// message. Use this only if you are absolutely sure that the input data is a valid message
// created by your own system. Never use this to read messages received from others.
//
// To create a trusted message, build a message using a MallocAllocator whose preferred segment
// size is larger than the message size. This guarantees that the message will be allocated as a
// single segment, meaning getSegmentsForOutput() returns a single word array. That word array
// is your message; you may pass a pointer to its first word into readTrusted() to read the
// message.
//
// This can be particularly handy for embedding messages in generated code: you can
// embed the raw bytes (using AlignedData) then make a Reader for it using this. This is the way
// default values are embedded in code generated by the Cap'n Proto compiler. E.g., if you have
// a message MyMessage, you can read its default value like so:
// MyMessage::Reader reader = Message<MyMessage>::ReadTrusted(MyMessage::DEFAULT.words);
}; };
class SegmentReader { // =======================================================================================
public: // Standard implementations of allocators and error reporters.
inline SegmentReader(MessageReader* message, SegmentId id, const word ptr[], WordCount size,
ReadLimiter* readLimiter);
CAPNPROTO_ALWAYS_INLINE(bool containsInterval(const word* from, const word* to)); class MallocAllocator: public Allocator {
public:
explicit MallocAllocator(uint preferredSegmentSizeWords);
~MallocAllocator();
inline MessageReader* getMessage(); static MallocAllocator* getDefaultInstance();
inline SegmentId getSegmentId();
inline const word* getStartPtr(); // implements Allocator --------------------------------------------
inline WordCount getOffsetTo(const word* ptr); ArrayPtr<word> allocate(SegmentId id, uint minimumSize) override;
inline WordCount getSize(); void free(SegmentId id, ArrayPtr<word> ptr) override;
private: private:
MessageReader* message; uint preferredSegmentSizeWords;
SegmentId id;
WordCount size;
const word* start;
ReadLimiter* readLimiter;
CAPNPROTO_DISALLOW_COPY(SegmentReader);
friend class SegmentBuilder;
}; };
class SegmentBuilder: public SegmentReader { class StderrErrorReporter: public ErrorReporter {
public: public:
inline SegmentBuilder(MessageBuilder* message, SegmentId id, word ptr[], WordCount available); ~StderrErrorReporter();
CAPNPROTO_ALWAYS_INLINE(word* allocate(WordCount amount)); static StderrErrorReporter* getDefaultInstance();
inline word* getPtrUnchecked(WordCount offset);
inline MessageBuilder* getMessage(); // implements ErrorReporter ----------------------------------------
void reportError(const char* description) override;
inline WordCount available(); };
private: class ThrowingErrorReporter: public ErrorReporter {
word* pos; public:
word* end; ~ThrowingErrorReporter();
ReadLimiter dummyLimiter;
CAPNPROTO_DISALLOW_COPY(SegmentBuilder); static ThrowingErrorReporter* getDefaultInstance();
// TODO: Do we need mutex locking? // implements ErrorReporter ----------------------------------------
void reportError(const char* description) override;
}; };
// ======================================================================================= // =======================================================================================
// implementation details
inline ReadLimiter::ReadLimiter() template <typename RootType>
// I didn't want to #include <limits> just for this one lousy constant. inline Message<RootType>::Reader::Reader(ArrayPtr<const ArrayPtr<const word>> segments,
: limit(uint64_t(0x7fffffffffffffffll) * WORDS) {} uint recursionLimit, uint64_t readLimit, ErrorReporter* errorReporter)
: internal(segments, recursionLimit, readLimit, errorReporter) {}
inline ReadLimiter::ReadLimiter(WordCount64 limit): limit(limit) {}
inline bool ReadLimiter::canRead(WordCount amount, MessageReader* message) { template <typename RootType>
if (CAPNPROTO_EXPECT_FALSE(amount > limit)) { inline typename RootType::Reader Message<RootType>::Reader::getRoot() {
message->reportReadLimitReached(); return typename RootType::Reader(internal.getRoot(RootType::DEFAULT.words));
return false;
} else {
limit -= amount;
return true;
}
} }
// ------------------------------------------------------------------- template <typename RootType>
inline Message<RootType>::Builder::Builder()
: internal() {}
inline SegmentReader::SegmentReader(MessageReader* message, SegmentId id, const word ptr[], template <typename RootType>
WordCount size, ReadLimiter* readLimiter) inline Message<RootType>::Builder::Builder(Allocator* allocator)
: message(message), id(id), size(size), start(ptr), readLimiter(readLimiter) {} : internal(allocator) {}
inline bool SegmentReader::containsInterval(const word* from, const word* to) { template <typename RootType>
return from >= this->start && to <= this->start + size && inline typename RootType::Builder Message<RootType>::Builder::initRoot() {
readLimiter->canRead(intervalLength(from, to), message); return typename RootType::Builder(internal.initRoot(RootType::DEFAULT.words));
} }
inline MessageReader* SegmentReader::getMessage() { return message; } template <typename RootType>
inline SegmentId SegmentReader::getSegmentId() { return id; } inline typename RootType::Builder Message<RootType>::Builder::getRoot() {
inline const word* SegmentReader::getStartPtr() { return start; } return typename RootType::Builder(internal.getRoot(RootType::DEFAULT.words));
inline WordCount SegmentReader::getOffsetTo(const word* ptr) {
return intervalLength(start, ptr);
}
inline WordCount SegmentReader::getSize() { return size; }
// -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder(
MessageBuilder* message, SegmentId id, word ptr[], WordCount available)
: SegmentReader(message, id, ptr, 0 * WORDS, &dummyLimiter),
pos(ptr),
end(pos + available) {}
inline word* SegmentBuilder::allocate(WordCount amount) {
if (amount > intervalLength(pos, end)) {
return nullptr;
} else {
word* result = pos;
pos += amount;
size += amount;
return result;
}
} }
inline word* SegmentBuilder::getPtrUnchecked(WordCount offset) { template <typename RootType>
// const_cast OK because SegmentBuilder's constructor always initializes its SegmentReader base inline ArrayPtr<const ArrayPtr<const word>> Message<RootType>::Builder::getSegmentsForOutput() {
// class with a pointer that was originally non-const. return internal.getSegmentsForOutput();
return const_cast<word*>(start + offset);
} }
inline MessageBuilder* SegmentBuilder::getMessage() { template <typename RootType>
// Down-cast safe because SegmentBuilder's constructor always initializes its SegmentReader base typename RootType::Reader Message<RootType>::readTrusted(const word* data) {
// class with a MessageReader pointer that actually points to a MessageBuilder. return typename RootType::Reader(internal::StructReader::readRootTrusted(
return static_cast<MessageBuilder*>(message); data, RootType::DEFAULT.words));
}
inline WordCount SegmentBuilder::available() {
return intervalLength(pos, end);
}
// -------------------------------------------------------------------
template <typename T>
MessageRoot<T> newMallocMessageRoot(WordCount preferredSegmentSize) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(preferredSegmentSize);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
return MessageRoot<T>(move(message),
typename T::Builder(internal::StructBuilder::initRoot(
segment, rootLocation, T::DEFAULT.words)));
} }
} // namespace capnproto } // namespace capnproto
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#define CAPNPROTO_TYPE_SAFETY_H_ #define CAPNPROTO_TYPE_SAFETY_H_
#include "macros.h" #include "macros.h"
#include <cstddef>
namespace capnproto { namespace capnproto {
...@@ -50,6 +51,55 @@ struct NoInfer { ...@@ -50,6 +51,55 @@ struct NoInfer {
typedef T Type; typedef T Type;
}; };
// =======================================================================================
// ArrayPtr
template <typename T>
class ArrayPtr {
// A pointer to an array. Includes a size. Like any pointer, it doesn't own the target data,
// and passing by value only copies the pointer, not the target.
public:
inline ArrayPtr(): ptr(nullptr), size_(0) {}
inline ArrayPtr(std::nullptr_t): ptr(nullptr), size_(0) {}
inline ArrayPtr(T* ptr, std::size_t size): ptr(ptr), size_(size) {}
inline ArrayPtr(T* begin, T* end): ptr(begin), size_(end - begin) {}
inline operator ArrayPtr<const T>() {
return ArrayPtr<const T>(ptr, size_);
}
inline std::size_t size() const { return size_; }
inline T& operator[](std::size_t index) const {
CAPNPROTO_DEBUG_ASSERT(index < size_, "Out-of-bounds ArrayPtr access.");
return ptr[index];
}
inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; }
inline ArrayPtr slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice().");
return ArrayPtr(ptr + start, end - start);
}
private:
T* ptr;
std::size_t size_;
};
template <typename T>
inline ArrayPtr<T> arrayPtr(T* ptr, size_t size) {
// Use this function to construct ArrayPtrs without writing out the type name.
return ArrayPtr<T>(ptr, size);
}
template <typename T>
inline ArrayPtr<T> arrayPtr(T* begin, T* end) {
// Use this function to construct ArrayPtrs without writing out the type name.
return ArrayPtr<T>(begin, end);
}
// ======================================================================================= // =======================================================================================
// IDs // IDs
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "wire-format.h" #include "wire-format.h"
#include "descriptor.h" #include "descriptor.h"
#include "message.h" #include "message.h"
#include "arena.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace capnproto { namespace capnproto {
...@@ -263,8 +264,8 @@ static void checkStruct(StructReader reader) { ...@@ -263,8 +264,8 @@ static void checkStruct(StructReader reader) {
} }
TEST(WireFormat, StructRoundTrip_OneSegment) { TEST(WireFormat, StructRoundTrip_OneSegment) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(512 * WORDS); BuilderArena arena(MallocAllocator::getDefaultInstance());
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS); SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words); StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words);
...@@ -286,7 +287,9 @@ TEST(WireFormat, StructRoundTrip_OneSegment) { ...@@ -286,7 +287,9 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
// 6 sub-lists (4x 1 word, 1x 2 words) // 6 sub-lists (4x 1 word, 1x 2 words)
// ----- // -----
// 34 // 34
EXPECT_EQ(34 * WORDS, segment->getSize()); ArrayPtr<const ArrayPtr<const word>> segments = arena.getSegmentsForOutput();
ASSERT_EQ(1u, segments.size());
EXPECT_EQ(34u, segments[0].size());
checkStruct(builder); checkStruct(builder);
checkStruct(builder.asReader()); checkStruct(builder.asReader());
...@@ -295,34 +298,35 @@ TEST(WireFormat, StructRoundTrip_OneSegment) { ...@@ -295,34 +298,35 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
} }
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(1 * WORDS); MallocAllocator allocator(1);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS); BuilderArena arena(&allocator);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words); StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words);
setupStruct(builder); setupStruct(builder);
// Verify that we made 15 segments. // Verify that we made 15 segments.
ASSERT_TRUE(message->tryGetSegment(SegmentId(14)) != nullptr); ArrayPtr<const ArrayPtr<const word>> segments = arena.getSegmentsForOutput();
EXPECT_EQ(nullptr, message->tryGetSegment(SegmentId(15))); ASSERT_EQ(15u, segments.size());
// Check that each segment has the expected size. Recall that the first word of each segment will // Check that each segment has the expected size. Recall that the first word of each segment will
// actually be a reference to the first thing allocated within that segment. // actually be a reference to the first thing allocated within that segment.
EXPECT_EQ( 1 * WORDS, message->getSegment(SegmentId( 0))->getSize()); // root ref EXPECT_EQ( 1u, segments[ 0].size()); // root ref
EXPECT_EQ( 7 * WORDS, message->getSegment(SegmentId( 1))->getSize()); // root struct EXPECT_EQ( 7u, segments[ 1].size()); // root struct
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 2))->getSize()); // sub-struct EXPECT_EQ( 2u, segments[ 2].size()); // sub-struct
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId( 3))->getSize()); // 3-element int32 list EXPECT_EQ( 3u, segments[ 3].size()); // 3-element int32 list
EXPECT_EQ(10 * WORDS, message->getSegment(SegmentId( 4))->getSize()); // struct list EXPECT_EQ(10u, segments[ 4].size()); // struct list
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 5))->getSize()); // struct list substruct 1 EXPECT_EQ( 2u, segments[ 5].size()); // struct list substruct 1
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 6))->getSize()); // struct list substruct 2 EXPECT_EQ( 2u, segments[ 6].size()); // struct list substruct 2
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 7))->getSize()); // struct list substruct 3 EXPECT_EQ( 2u, segments[ 7].size()); // struct list substruct 3
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 8))->getSize()); // struct list substruct 4 EXPECT_EQ( 2u, segments[ 8].size()); // struct list substruct 4
EXPECT_EQ( 6 * WORDS, message->getSegment(SegmentId( 9))->getSize()); // list list EXPECT_EQ( 6u, segments[ 9].size()); // list list
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(10))->getSize()); // list list sublist 1 EXPECT_EQ( 2u, segments[10].size()); // list list sublist 1
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(11))->getSize()); // list list sublist 2 EXPECT_EQ( 2u, segments[11].size()); // list list sublist 2
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(12))->getSize()); // list list sublist 3 EXPECT_EQ( 2u, segments[12].size()); // list list sublist 3
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(13))->getSize()); // list list sublist 4 EXPECT_EQ( 2u, segments[13].size()); // list list sublist 4
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId(14))->getSize()); // list list sublist 5 EXPECT_EQ( 3u, segments[14].size()); // list list sublist 5
checkStruct(builder); checkStruct(builder);
checkStruct(builder.asReader()); checkStruct(builder.asReader());
...@@ -330,25 +334,26 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { ...@@ -330,25 +334,26 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
} }
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) { TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(8 * WORDS); MallocAllocator allocator(8);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS); BuilderArena arena(&allocator);
SegmentBuilder* segment = arena.getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS); word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words); StructBuilder builder = StructBuilder::initRoot(segment, rootLocation, STRUCT_DEFAULT.words);
setupStruct(builder); setupStruct(builder);
// Verify that we made 6 segments. // Verify that we made 6 segments.
ASSERT_TRUE(message->tryGetSegment(SegmentId(5)) != nullptr); ArrayPtr<const ArrayPtr<const word>> segments = arena.getSegmentsForOutput();
EXPECT_EQ(nullptr, message->tryGetSegment(SegmentId(6))); ASSERT_EQ(6u, segments.size());
// Check that each segment has the expected size. Recall that each object will be prefixed by an // Check that each segment has the expected size. Recall that each object will be prefixed by an
// extra word if its parent is in a different segment. // extra word if its parent is in a different segment.
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(0))->getSize()); // root ref + struct + sub EXPECT_EQ( 8u, segments[0].size()); // root ref + struct + sub
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId(1))->getSize()); // 3-element int32 list EXPECT_EQ( 3u, segments[1].size()); // 3-element int32 list
EXPECT_EQ(10 * WORDS, message->getSegment(SegmentId(2))->getSize()); // struct list EXPECT_EQ(10u, segments[2].size()); // struct list
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(3))->getSize()); // struct list substructs EXPECT_EQ( 8u, segments[3].size()); // struct list substructs
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(4))->getSize()); // list list + sublist 1,2 EXPECT_EQ( 8u, segments[4].size()); // list list + sublist 1,2
EXPECT_EQ( 7 * WORDS, message->getSegment(SegmentId(5))->getSize()); // list list sublist 3,4,5 EXPECT_EQ( 7u, segments[5].size()); // list list sublist 3,4,5
checkStruct(builder); checkStruct(builder);
checkStruct(builder.asReader()); checkStruct(builder.asReader());
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "wire-format.h" #include "wire-format.h"
#include "message.h" #include "arena.h"
#include "descriptor.h" #include "descriptor.h"
#include <string.h> #include <string.h>
#include <limits> #include <limits>
...@@ -203,7 +203,7 @@ struct WireHelpers { ...@@ -203,7 +203,7 @@ struct WireHelpers {
// space to act as the landing pad for a far reference. // space to act as the landing pad for a far reference.
WordCount amountPlusRef = amount + REFERENCE_SIZE_IN_WORDS; WordCount amountPlusRef = amount + REFERENCE_SIZE_IN_WORDS;
segment = segment->getMessage()->getSegmentWithAvailable(amountPlusRef); segment = segment->getArena()->getSegmentWithAvailable(amountPlusRef);
ptr = segment->allocate(amountPlusRef); ptr = segment->allocate(amountPlusRef);
// Set up the original reference to be a far reference to the new segment. // Set up the original reference to be a far reference to the new segment.
...@@ -224,12 +224,12 @@ struct WireHelpers { ...@@ -224,12 +224,12 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(word* followFars(WireReference*& ref, SegmentBuilder*& segment)) { static CAPNPROTO_ALWAYS_INLINE(word* followFars(WireReference*& ref, SegmentBuilder*& segment)) {
if (ref->kind() == WireReference::FAR) { if (ref->kind() == WireReference::FAR) {
segment = segment->getMessage()->getSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
ref = reinterpret_cast<WireReference*>(segment->getPtrUnchecked(ref->positionInSegment())); ref = reinterpret_cast<WireReference*>(segment->getPtrUnchecked(ref->positionInSegment()));
if (ref->landingPadIsFollowedByAnotherReference()) { if (ref->landingPadIsFollowedByAnotherReference()) {
// Target lives elsewhere. Another far reference follows. // Target lives elsewhere. Another far reference follows.
WireReference* far2 = ref + 1; WireReference* far2 = ref + 1;
segment = segment->getMessage()->getSegment(far2->farRef.segmentId.get()); segment = segment->getArena()->getSegment(far2->farRef.segmentId.get());
return segment->getPtrUnchecked(far2->positionInSegment()); return segment->getPtrUnchecked(far2->positionInSegment());
} else { } else {
// Target immediately follows landing pad. // Target immediately follows landing pad.
...@@ -243,7 +243,7 @@ struct WireHelpers { ...@@ -243,7 +243,7 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE( static CAPNPROTO_ALWAYS_INLINE(
const word* followFars(const WireReference*& ref, SegmentReader*& segment)) { const word* followFars(const WireReference*& ref, SegmentReader*& segment)) {
if (ref->kind() == WireReference::FAR) { if (ref->kind() == WireReference::FAR) {
segment = segment->getMessage()->tryGetSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) {
return nullptr; return nullptr;
} }
...@@ -259,7 +259,7 @@ struct WireHelpers { ...@@ -259,7 +259,7 @@ struct WireHelpers {
if (ref->landingPadIsFollowedByAnotherReference()) { if (ref->landingPadIsFollowedByAnotherReference()) {
// Target is in another castle. Another far reference follows. // Target is in another castle. Another far reference follows.
const WireReference* far2 = ref + 1; const WireReference* far2 = ref + 1;
segment = segment->getMessage()->tryGetSegment(far2->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(far2->farRef.segmentId.get());
if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(segment == nullptr)) {
return nullptr; return nullptr;
} }
...@@ -622,26 +622,26 @@ struct WireHelpers { ...@@ -622,26 +622,26 @@ struct WireHelpers {
ptr = ref->target(); ptr = ref->target();
} else if (segment != nullptr) { } else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) { if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles.");
goto useDefault; goto useDefault;
} }
ptr = followFars(ref, segment); ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains invalid far reference."); "Message contains invalid far reference.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::STRUCT)) { if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::STRUCT)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains non-struct reference where struct reference was expected."); "Message contains non-struct reference where struct reference was expected.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + ref->structRef.wordSize()))){ if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + ref->structRef.wordSize()))){
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contained out-of-bounds struct reference."); "Message contained out-of-bounds struct reference.");
goto useDefault; goto useDefault;
} }
...@@ -672,20 +672,20 @@ struct WireHelpers { ...@@ -672,20 +672,20 @@ struct WireHelpers {
ptr = ref->target(); ptr = ref->target();
} else if (segment != nullptr) { } else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) { if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles.");
goto useDefault; goto useDefault;
} }
ptr = followFars(ref, segment); ptr = followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains invalid far reference."); "Message contains invalid far reference.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) { if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains non-list reference where list reference was expected."); "Message contains non-list reference where list reference was expected.");
goto useDefault; goto useDefault;
} }
...@@ -707,13 +707,13 @@ struct WireHelpers { ...@@ -707,13 +707,13 @@ struct WireHelpers {
if (segment != nullptr) { if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval( if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(
ptr - REFERENCE_SIZE_IN_WORDS, ptr + wordCount))) { ptr - REFERENCE_SIZE_IN_WORDS, ptr + wordCount))) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains out-of-bounds list reference."); "Message contains out-of-bounds list reference.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(tag->kind() != WireReference::STRUCT)) { if (CAPNPROTO_EXPECT_FALSE(tag->kind() != WireReference::STRUCT)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"INLINE_COMPOSITE lists of non-STRUCT type are not supported."); "INLINE_COMPOSITE lists of non-STRUCT type are not supported.");
goto useDefault; goto useDefault;
} }
...@@ -722,7 +722,7 @@ struct WireHelpers { ...@@ -722,7 +722,7 @@ struct WireHelpers {
wordsPerElement = tag->structRef.wordSize() / ELEMENTS; wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(size * wordsPerElement > wordCount)) { if (CAPNPROTO_EXPECT_FALSE(size * wordsPerElement > wordCount)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"INLINE_COMPOSITE list's elements overrun its word count."); "INLINE_COMPOSITE list's elements overrun its word count.");
goto useDefault; goto useDefault;
} }
...@@ -761,7 +761,7 @@ struct WireHelpers { ...@@ -761,7 +761,7 @@ struct WireHelpers {
} }
if (CAPNPROTO_EXPECT_FALSE(!compatible)) { if (CAPNPROTO_EXPECT_FALSE(!compatible)) {
segment->getMessage()->reportInvalidData("A list had incompatible element type."); segment->getArena()->reportInvalidData("A list had incompatible element type.");
goto useDefault; goto useDefault;
} }
...@@ -789,7 +789,7 @@ struct WireHelpers { ...@@ -789,7 +789,7 @@ struct WireHelpers {
if (segment != nullptr) { if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)))) { roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)))) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contained out-of-bounds list reference."); "Message contained out-of-bounds list reference.");
goto useDefault; goto useDefault;
} }
...@@ -835,7 +835,7 @@ struct WireHelpers { ...@@ -835,7 +835,7 @@ struct WireHelpers {
dataSize, referenceCount, recursionLimit - 1); dataSize, referenceCount, recursionLimit - 1);
} else { } else {
CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type."); CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type.");
segment->getMessage()->reportInvalidData("A list had incompatible element type."); segment->getArena()->reportInvalidData("A list had incompatible element type.");
goto useDefault; goto useDefault;
} }
} }
...@@ -859,26 +859,26 @@ struct WireHelpers { ...@@ -859,26 +859,26 @@ struct WireHelpers {
uint size = ref->listRef.elementCount() / ELEMENTS; uint size = ref->listRef.elementCount() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains invalid far reference."); "Message contains invalid far reference.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) { if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains non-list reference where text was expected."); "Message contains non-list reference where text was expected.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) { if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains list reference of non-bytes where text was expected."); "Message contains list reference of non-bytes where text was expected.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) { roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contained out-of-bounds text reference."); "Message contained out-of-bounds text reference.");
goto useDefault; goto useDefault;
} }
...@@ -887,7 +887,7 @@ struct WireHelpers { ...@@ -887,7 +887,7 @@ struct WireHelpers {
--size; // NUL terminator --size; // NUL terminator
if (CAPNPROTO_EXPECT_FALSE(cptr[size] != '\0')) { if (CAPNPROTO_EXPECT_FALSE(cptr[size] != '\0')) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains text that is not NUL-terminated."); "Message contains text that is not NUL-terminated.");
goto useDefault; goto useDefault;
} }
...@@ -911,26 +911,26 @@ struct WireHelpers { ...@@ -911,26 +911,26 @@ struct WireHelpers {
uint size = ref->listRef.elementCount() / ELEMENTS; uint size = ref->listRef.elementCount() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) { if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains invalid far reference."); "Message contains invalid far reference.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) { if (CAPNPROTO_EXPECT_FALSE(ref->kind() != WireReference::LIST)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains non-list reference where data was expected."); "Message contains non-list reference where data was expected.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) { if (CAPNPROTO_EXPECT_FALSE(ref->listRef.elementSize() != FieldSize::BYTE)) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contains list reference of non-bytes where data was expected."); "Message contains list reference of non-bytes where data was expected.");
goto useDefault; goto useDefault;
} }
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) { roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))))) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message contained out-of-bounds data reference."); "Message contained out-of-bounds data reference.");
goto useDefault; goto useDefault;
} }
...@@ -948,6 +948,12 @@ StructBuilder StructBuilder::initRoot( ...@@ -948,6 +948,12 @@ StructBuilder StructBuilder::initRoot(
reinterpret_cast<WireReference*>(location), segment, defaultValue); reinterpret_cast<WireReference*>(location), segment, defaultValue);
} }
StructBuilder StructBuilder::getRoot(
SegmentBuilder* segment, word* location, const word* defaultValue) {
return WireHelpers::getWritableStructReference(
reinterpret_cast<WireReference*>(location), segment, defaultValue);
}
StructBuilder StructBuilder::initStructField( StructBuilder StructBuilder::initStructField(
WireReferenceCount refIndex, const word* typeDefaultValue) const { WireReferenceCount refIndex, const word* typeDefaultValue) const {
return WireHelpers::initStructReference(references + refIndex, segment, typeDefaultValue); return WireHelpers::initStructReference(references + refIndex, segment, typeDefaultValue);
...@@ -1024,7 +1030,7 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def ...@@ -1024,7 +1030,7 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def
StructReader StructReader::readRoot(const word* location, const word* defaultValue, StructReader StructReader::readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit) { SegmentReader* segment, int recursionLimit) {
if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) { if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) {
segment->getMessage()->reportInvalidData("Root location out-of-bounds."); segment->getArena()->reportInvalidData("Root location out-of-bounds.");
location = nullptr; location = nullptr;
} }
...@@ -1126,7 +1132,7 @@ ListReader ListBuilder::asReader(FieldNumber fieldCount, WordCount dataSize, ...@@ -1126,7 +1132,7 @@ ListReader ListBuilder::asReader(FieldNumber fieldCount, WordCount dataSize,
StructReader ListReader::getStructElement(ElementCount index, const word* defaultValue) const { StructReader ListReader::getStructElement(ElementCount index, const word* defaultValue) const {
if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (recursionLimit == 0))) { if (CAPNPROTO_EXPECT_FALSE((segment != nullptr) & (recursionLimit == 0))) {
segment->getMessage()->reportInvalidData( segment->getArena()->reportInvalidData(
"Message is too deeply-nested or contains cycles."); "Message is too deeply-nested or contains cycles.");
return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, recursionLimit); return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, recursionLimit);
} else { } else {
......
...@@ -34,11 +34,6 @@ ...@@ -34,11 +34,6 @@
#include "type-safety.h" #include "type-safety.h"
#include "blob.h" #include "blob.h"
namespace capnproto {
class SegmentReader;
class SegmentBuilder;
}
namespace capnproto { namespace capnproto {
namespace internal { namespace internal {
...@@ -52,6 +47,8 @@ class ListBuilder; ...@@ -52,6 +47,8 @@ class ListBuilder;
class ListReader; class ListReader;
struct WireReference; struct WireReference;
struct WireHelpers; struct WireHelpers;
class SegmentReader;
class SegmentBuilder;
// ------------------------------------------------------------------- // -------------------------------------------------------------------
...@@ -85,6 +82,7 @@ public: ...@@ -85,6 +82,7 @@ public:
inline StructBuilder(): segment(nullptr), data(nullptr), references(nullptr) {} inline StructBuilder(): segment(nullptr), data(nullptr), references(nullptr) {}
static StructBuilder initRoot(SegmentBuilder* segment, word* location, const word* defaultValue); static StructBuilder initRoot(SegmentBuilder* segment, word* location, const word* defaultValue);
static StructBuilder getRoot(SegmentBuilder* segment, word* location, const word* defaultValue);
template <typename T> template <typename T>
CAPNPROTO_ALWAYS_INLINE(T getDataField(ElementCount offset) const); CAPNPROTO_ALWAYS_INLINE(T getDataField(ElementCount offset) const);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment