Commit 64a0ca66 authored by Kenton Varda's avatar Kenton Varda

Testing and bug-fixing.

parent 648d267d
...@@ -22,10 +22,71 @@ ...@@ -22,10 +22,71 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "message.h" #include "message.h"
#include <vector>
#include <string.h>
#include <iostream>
namespace capnproto { namespace capnproto {
MessageReader::~MessageReader() {} MessageReader::~MessageReader() {}
MessageBuilder::~MessageBuilder() {} MessageBuilder::~MessageBuilder() {}
class MallocMessage: public MessageBuilder {
public:
MallocMessage(WordCount preferredSegmentSize);
~MallocMessage();
SegmentReader* tryGetSegment(SegmentId id);
void reportInvalidData(const char* description);
void reportReadLimitReached();
SegmentBuilder* getSegment(SegmentId id);
SegmentBuilder* getSegmentWithAvailable(WordCount minimumAvailable);
private:
WordCount preferredSegmentSize;
std::vector<std::unique_ptr<SegmentBuilder>> segments;
std::vector<std::unique_ptr<word[]>> memory;
};
MallocMessage::MallocMessage(WordCount preferredSegmentSize)
: preferredSegmentSize(preferredSegmentSize) {}
MallocMessage::~MallocMessage() {}
SegmentReader* MallocMessage::tryGetSegment(SegmentId id) {
if (id.value > segments.size()) {
return nullptr;
} else {
return segments[id.value].get();
}
}
void MallocMessage::reportInvalidData(const char* description) {
// TODO: Better error reporting.
std::cerr << "MallocMessage: Parse error: " << description << std::endl;
}
void MallocMessage::reportReadLimitReached() {
// TODO: Better error reporting.
std::cerr << "MallocMessage: Exceeded read limit." << std::endl;
}
SegmentBuilder* MallocMessage::getSegment(SegmentId id) {
return segments[id.value].get();
}
SegmentBuilder* MallocMessage::getSegmentWithAvailable(WordCount minimumAvailable) {
if (segments.empty() || segments.back()->available() < minimumAvailable) {
WordCount newSize = std::max(minimumAvailable, preferredSegmentSize);
memory.push_back(std::unique_ptr<word[]>(new word[newSize / WORDS]));
memset(memory.back().get(), 0, newSize / WORDS * sizeof(word));
segments.push_back(std::unique_ptr<SegmentBuilder>(new SegmentBuilder(
this, SegmentId(segments.size()), memory.back().get(), newSize)));
}
return segments.back().get();
}
std::unique_ptr<MessageBuilder> newMallocMessage(WordCount preferredSegmentSize) {
return std::unique_ptr<MessageBuilder>(new MallocMessage(preferredSegmentSize));
}
} // namespace capnproto } // namespace capnproto
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cstddef> #include <cstddef>
#include <memory>
#include "macros.h" #include "macros.h"
#include "type-safety.h" #include "type-safety.h"
...@@ -94,6 +95,9 @@ public: ...@@ -94,6 +95,9 @@ public:
// TODO: Methods to deal with bundled capabilities. // TODO: Methods to deal with bundled capabilities.
}; };
std::unique_ptr<MessageBuilder> newMallocMessage(WordCount preferredSegmentSize);
// Returns a simple MessageBuilder implementation that uses standard allocation.
class ReadLimiter { class ReadLimiter {
// Used to keep track of how much data has been processed from a message, and cut off further // Used to keep track of how much data has been processed from a message, and cut off further
// processing if and when a particular limit is reached. This is primarily intended to guard // processing if and when a particular limit is reached. This is primarily intended to guard
...@@ -156,6 +160,8 @@ public: ...@@ -156,6 +160,8 @@ public:
inline MessageBuilder* getMessage(); inline MessageBuilder* getMessage();
inline WordCount available();
private: private:
word* pos; word* pos;
word* end; word* end;
...@@ -234,6 +240,10 @@ inline MessageBuilder* SegmentBuilder::getMessage() { ...@@ -234,6 +240,10 @@ inline MessageBuilder* SegmentBuilder::getMessage() {
return static_cast<MessageBuilder*>(message); return static_cast<MessageBuilder*>(message);
} }
inline WordCount SegmentBuilder::available() {
return intervalLength(pos, end);
}
} // namespace capnproto } // namespace capnproto
#endif // CAPNPROTO_MESSAGE_H_ #endif // CAPNPROTO_MESSAGE_H_
...@@ -317,8 +317,8 @@ inline constexpr auto operator*(UnitRatio<Number1, Unit2, Unit> ratio, ...@@ -317,8 +317,8 @@ inline constexpr auto operator*(UnitRatio<Number1, Unit2, Unit> ratio,
// ======================================================================================= // =======================================================================================
// Raw memory types and measures // Raw memory types and measures
class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); }; class byte { uint8_t content; CAPNPROTO_DISALLOW_COPY(byte); public: byte() = default; };
class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); }; class word { uint64_t content; CAPNPROTO_DISALLOW_COPY(word); public: word() = default; };
// byte and word are opaque types with sizes of 8 and 64 bits, respectively. These types are useful // byte and word are opaque types with sizes of 8 and 64 bits, respectively. These types are useful
// only to make pointer arithmetic clearer. Since the contents are private, the only way to access // only to make pointer arithmetic clearer. Since the contents are private, the only way to access
// them is to first reinterpret_cast to some other pointer type. // them is to first reinterpret_cast to some other pointer type.
...@@ -334,44 +334,20 @@ static_assert(sizeof(word) == 8, "uint64_t is not 8 bytes?"); ...@@ -334,44 +334,20 @@ static_assert(sizeof(word) == 8, "uint64_t is not 8 bytes?");
namespace internal { class BitLabel; class ElementLabel; class WireReference; } namespace internal { class BitLabel; class ElementLabel; class WireReference; }
#ifdef __CDT_PARSER__ #ifndef CAPNPROTO_DEBUG_TYPES
// Eclipse gets confused by decltypes, so we'll feed it these simplified-yet-compatible definitions. #define CAPNPROTO_DEBUG_TYPES 1
// Set this to zero to degrade all the "count" types below to being plain integers. All the code
// should still operate exactly the same, we just lose compile-time checking. Note that this will
// also change symbol names, so it's important that the Cap'n proto library and any clients be
// compiled with the same setting here.
// //
// We could also consider using these definitions in opt builds. Trouble is, the mangled symbol // TODO: Decide policy on this. It may make sense to only use CAPNPROTO_DEBUG_TYPES when compiling
// names of any functions that take these types as inputs would be affected, so it would be // Cap'n Proto's own tests, but disable it for all real builds, as clients may find this safety
// important to compile the Cap'n Proto library and the client app with the same flags. // tiring.
typedef uint BitCount; #endif
typedef uint8_t BitCount8;
typedef uint16_t BitCount16;
typedef uint32_t BitCount32;
typedef uint64_t BitCount64;
typedef uint ByteCount;
typedef uint8_t ByteCount8;
typedef uint16_t ByteCount16;
typedef uint32_t ByteCount32;
typedef uint64_t ByteCount64;
typedef uint WordCount;
typedef uint8_t WordCount8;
typedef uint16_t WordCount16;
typedef uint32_t WordCount32;
typedef uint64_t WordCount64;
typedef uint ElementCount;
typedef uint8_t ElementCount8;
typedef uint16_t ElementCount16;
typedef uint32_t ElementCount32;
typedef uint64_t ElementCount64;
typedef uint WireReferenceCount;
typedef uint8_t WireReferenceCount8;
typedef uint16_t WireReferenceCount16;
typedef uint32_t WireReferenceCount32;
typedef uint64_t WireReferenceCount64;
#else #if CAPNPROTO_DEBUG_TYPES
typedef Quantity<uint, internal::BitLabel> BitCount; typedef Quantity<uint, internal::BitLabel> BitCount;
typedef Quantity<uint8_t, internal::BitLabel> BitCount8; typedef Quantity<uint8_t, internal::BitLabel> BitCount8;
...@@ -403,6 +379,38 @@ typedef Quantity<uint16_t, internal::WireReference> WireReferenceCount16; ...@@ -403,6 +379,38 @@ typedef Quantity<uint16_t, internal::WireReference> WireReferenceCount16;
typedef Quantity<uint32_t, internal::WireReference> WireReferenceCount32; typedef Quantity<uint32_t, internal::WireReference> WireReferenceCount32;
typedef Quantity<uint64_t, internal::WireReference> WireReferenceCount64; typedef Quantity<uint64_t, internal::WireReference> WireReferenceCount64;
#else
typedef uint BitCount;
typedef uint8_t BitCount8;
typedef uint16_t BitCount16;
typedef uint32_t BitCount32;
typedef uint64_t BitCount64;
typedef uint ByteCount;
typedef uint8_t ByteCount8;
typedef uint16_t ByteCount16;
typedef uint32_t ByteCount32;
typedef uint64_t ByteCount64;
typedef uint WordCount;
typedef uint8_t WordCount8;
typedef uint16_t WordCount16;
typedef uint32_t WordCount32;
typedef uint64_t WordCount64;
typedef uint ElementCount;
typedef uint8_t ElementCount8;
typedef uint16_t ElementCount16;
typedef uint32_t ElementCount32;
typedef uint64_t ElementCount64;
typedef uint WireReferenceCount;
typedef uint8_t WireReferenceCount8;
typedef uint16_t WireReferenceCount16;
typedef uint32_t WireReferenceCount32;
typedef uint64_t WireReferenceCount64;
#endif #endif
constexpr BitCount BITS = unit<BitCount>(); constexpr BitCount BITS = unit<BitCount>();
......
...@@ -23,21 +23,27 @@ ...@@ -23,21 +23,27 @@
#include "wire-format.h" #include "wire-format.h"
#include "descriptor.h" #include "descriptor.h"
#include "message.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace capnproto {
template <typename T, typename U>
std::ostream& operator<<(std::ostream& os, Quantity<T, U> value) {
return os << (value / unit<Quantity<T, U>>());
}
}
namespace capnproto { namespace capnproto {
namespace internal { namespace internal {
namespace { namespace {
TEST(StructReader, RawData) { TEST(WireFormat, SimpleRawDataStruct) {
AlignedData<2> data = { AlignedData<2> data = {{
{ // Struct ref, offset = 1, fieldCount = 1, dataSize = 1, referenceCount = 0
// Struct ref, offset = 1, fieldCount = 1, dataSize = 1, referenceCount = 0 0x08, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, // Content for the data segment.
// Content for the data segment. 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef }};
}
};
StructReader reader = StructReader::readRootTrusted(data.words, data.words); StructReader reader = StructReader::readRootTrusted(data.words, data.words);
...@@ -71,6 +77,8 @@ TEST(StructReader, RawData) { ...@@ -71,6 +77,8 @@ TEST(StructReader, RawData) {
EXPECT_FALSE(reader.getDataField<bool>(14 * ELEMENTS, false)); EXPECT_FALSE(reader.getDataField<bool>(14 * ELEMENTS, false));
EXPECT_FALSE(reader.getDataField<bool>(15 * ELEMENTS, false)); EXPECT_FALSE(reader.getDataField<bool>(15 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(63 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(63 * ELEMENTS, true ));
EXPECT_FALSE(reader.getDataField<bool>(64 * ELEMENTS, false)); EXPECT_FALSE(reader.getDataField<bool>(64 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(64 * ELEMENTS, true )); EXPECT_TRUE (reader.getDataField<bool>(64 * ELEMENTS, true ));
...@@ -88,6 +96,191 @@ TEST(StructReader, RawData) { ...@@ -88,6 +96,191 @@ TEST(StructReader, RawData) {
EXPECT_TRUE (reader.getDataFieldCheckingNumber<bool>(FieldNumber(1), 0 * ELEMENTS, true )); EXPECT_TRUE (reader.getDataFieldCheckingNumber<bool>(FieldNumber(1), 0 * ELEMENTS, true ));
} }
static void setupStruct(StructBuilder builder) {
builder.setDataField<uint64_t>(0 * ELEMENTS, 0x1011121314151617ull);
builder.setDataField<uint32_t>(2 * ELEMENTS, 0x20212223u);
builder.setDataField<uint16_t>(6 * ELEMENTS, 0x3031u);
builder.setDataField<uint8_t>(14 * ELEMENTS, 0x40u);
builder.setDataField<bool>(120 * ELEMENTS, false);
builder.setDataField<bool>(121 * ELEMENTS, false);
builder.setDataField<bool>(122 * ELEMENTS, true);
builder.setDataField<bool>(123 * ELEMENTS, false);
builder.setDataField<bool>(124 * ELEMENTS, true);
builder.setDataField<bool>(125 * ELEMENTS, true);
builder.setDataField<bool>(126 * ELEMENTS, true);
builder.setDataField<bool>(127 * ELEMENTS, false);
{
StructBuilder subStruct = builder.getStructField(
0 * REFERENCES, FieldNumber(1), 1 * WORDS, 0 * REFERENCES);
subStruct.setDataField<uint32_t>(0 * ELEMENTS, 123);
}
{
ListBuilder list = builder.initListField(1 * REFERENCES, FieldSize::FOUR_BYTES, 3 * ELEMENTS);
EXPECT_EQ(3 * ELEMENTS, list.size());
list.setDataElement<int32_t>(0 * ELEMENTS, 200);
list.setDataElement<int32_t>(1 * ELEMENTS, 201);
list.setDataElement<int32_t>(2 * ELEMENTS, 202);
}
{
ListBuilder list = builder.initStructListField(
2 * REFERENCES, 4 * ELEMENTS, FieldNumber(2), 1 * WORDS, 1 * REFERENCES);
EXPECT_EQ(4 * ELEMENTS, list.size());
for (int i = 0; i < 4; i++) {
StructBuilder element = list.getStructElement(i * ELEMENTS, 2 * WORDS / ELEMENTS, 1 * WORDS);
element.setDataField<int32_t>(0 * ELEMENTS, 300 + i);
element.getStructField(0 * REFERENCES, FieldNumber(1), 1 * WORDS, 0 * REFERENCES)
.setDataField<int32_t>(0 * ELEMENTS, 400 + i);
}
}
{
ListBuilder list = builder.initListField(3 * REFERENCES, FieldSize::REFERENCE, 5 * ELEMENTS);
EXPECT_EQ(5 * ELEMENTS, list.size());
for (uint i = 0; i < 5; i++) {
ListBuilder element = list.initListElement(
i * REFERENCES, FieldSize::TWO_BYTES, (i + 1) * ELEMENTS);
EXPECT_EQ((i + 1) * ELEMENTS, element.size());
for (uint j = 0; j <= i; j++) {
element.setDataElement<uint16_t>(j * ELEMENTS, 500 + j);
}
}
}
}
static void checkStruct(StructReader reader) {
EXPECT_EQ(0x1011121314151617ull, reader.getDataField<uint64_t>(0 * ELEMENTS, 1616));
EXPECT_EQ(0x20212223u, reader.getDataField<uint32_t>(2 * ELEMENTS, 1616));
EXPECT_EQ(0x3031u, reader.getDataField<uint16_t>(6 * ELEMENTS, 1616));
EXPECT_EQ(0x40u, reader.getDataField<uint8_t>(14 * ELEMENTS, 16));
EXPECT_FALSE(reader.getDataField<bool>(120 * ELEMENTS, false));
EXPECT_FALSE(reader.getDataField<bool>(121 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(122 * ELEMENTS, false));
EXPECT_FALSE(reader.getDataField<bool>(123 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(124 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(125 * ELEMENTS, false));
EXPECT_TRUE (reader.getDataField<bool>(126 * ELEMENTS, false));
EXPECT_FALSE(reader.getDataField<bool>(127 * ELEMENTS, false));
{
// TODO: Use valid default value.
StructReader subStruct = reader.getStructField(0 * REFERENCES, nullptr);
EXPECT_EQ(123u, subStruct.getDataField<uint32_t>(0 * ELEMENTS, 456));
}
{
// TODO: Use valid default value.
ListReader list = reader.getListField(1 * REFERENCES, FieldSize::FOUR_BYTES, nullptr);
ASSERT_EQ(3 * ELEMENTS, list.size());
EXPECT_EQ(200, list.getDataElement<int32_t>(0 * ELEMENTS));
EXPECT_EQ(201, list.getDataElement<int32_t>(1 * ELEMENTS));
EXPECT_EQ(202, list.getDataElement<int32_t>(2 * ELEMENTS));
}
{
// TODO: Use valid default value.
ListReader list = reader.getListField(2 * REFERENCES, FieldSize::STRUCT, nullptr);
ASSERT_EQ(4 * ELEMENTS, list.size());
for (int i = 0; i < 4; i++) {
StructReader element = list.getStructElement(i * ELEMENTS, nullptr);
EXPECT_EQ(300 + i, element.getDataField<int32_t>(0 * ELEMENTS, 1616));
EXPECT_EQ(400 + i,
element.getStructField(0 * REFERENCES, nullptr)
.getDataField<int32_t>(0 * ELEMENTS, 1616));
}
}
{
// TODO: Use valid default value.
ListReader list = reader.getListField(3 * REFERENCES, FieldSize::REFERENCE, nullptr);
ASSERT_EQ(5 * ELEMENTS, list.size());
for (uint i = 0; i < 5; i++) {
ListReader element = list.getListElement(i * REFERENCES, FieldSize::TWO_BYTES, nullptr);
ASSERT_EQ((i + 1) * ELEMENTS, element.size());
for (uint j = 0; j <= i; j++) {
EXPECT_EQ(500u + j, element.getDataElement<uint16_t>(j * ELEMENTS));
}
}
}
}
TEST(WireFormat, StructRoundTrip) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(512 * WORDS);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder =
StructBuilder::initRoot(segment, rootLocation, FieldNumber(16), 2 * WORDS, 4 * REFERENCES);
setupStruct(builder);
// word count:
// 1 root reference
// 6 root struct
// 1 sub message
// 2 3-element int32 list
// 13 struct list
// 1 tag
// 12 4x struct
// 1 data segment
// 1 reference segment
// 1 sub-struct
// 11 list list
// 5 references to sub-lists
// 6 sub-lists (4x 1 word, 1x 2 words)
// -----
// 34
EXPECT_EQ(34 * WORDS, segment->getSize());
checkStruct(builder.asReader());
}
TEST(WireFormat, StructRoundTrip_FarPointers) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(1 * WORDS);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder =
StructBuilder::initRoot(segment, rootLocation, FieldNumber(16), 2 * WORDS, 4 * REFERENCES);
setupStruct(builder);
// word count:
// 1 root reference
// 6 root struct
// 1 sub message
// 2 3-element int32 list
// 13 struct list
// 1 tag
// 12 4x struct
// 1 data segment
// 1 reference segment
// 1 sub-struct
// 11 list list
// 5 references to sub-lists
// 6 sub-lists (4x 1 word, 1x 2 words)
// -----
// 34
EXPECT_EQ( 1 * WORDS, message->getSegment(SegmentId( 0))->getSize()); // root ref
EXPECT_EQ( 7 * WORDS, message->getSegment(SegmentId( 1))->getSize()); // root struct
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 2))->getSize()); // sub-struct
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId( 3))->getSize()); // 3-element int32 list
EXPECT_EQ(10 * WORDS, message->getSegment(SegmentId( 4))->getSize()); // struct list
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 5))->getSize()); // struct list substruct 1
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 6))->getSize()); // struct list substruct 2
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 7))->getSize()); // struct list substruct 3
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId( 8))->getSize()); // struct list substruct 4
EXPECT_EQ( 6 * WORDS, message->getSegment(SegmentId( 9))->getSize()); // list list
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(10))->getSize()); // list list sublist 1
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(11))->getSize()); // list list sublist 2
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(12))->getSize()); // list list sublist 3
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(13))->getSize()); // list list sublist 4
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId(14))->getSize()); // list list sublist 5
checkStruct(builder.asReader());
}
} // namespace } // namespace
} // namespace internal } // namespace internal
} // namespace capnproto } // namespace capnproto
...@@ -101,8 +101,8 @@ struct WireReference { ...@@ -101,8 +101,8 @@ struct WireReference {
} }
CAPNPROTO_ALWAYS_INLINE(void setStruct( CAPNPROTO_ALWAYS_INLINE(void setStruct(
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount refCount, WordCount offset)) { FieldNumber fieldCount, WordCount dataSize, WireReferenceCount refCount, word* target)) {
setTagAndOffset(STRUCT, offset); setTagAndOffset(STRUCT, intervalLength(reinterpret_cast<word*>(this), target));
structRef.fieldCount.set(fieldCount); structRef.fieldCount.set(fieldCount);
structRef.dataSize.set(WordCount8(dataSize)); structRef.dataSize.set(WordCount8(dataSize));
structRef.refCount.set(refCount); structRef.refCount.set(refCount);
...@@ -110,8 +110,8 @@ struct WireReference { ...@@ -110,8 +110,8 @@ struct WireReference {
} }
CAPNPROTO_ALWAYS_INLINE(void setList( CAPNPROTO_ALWAYS_INLINE(void setList(
FieldSize elementSize, ElementCount elementCount, WordCount offset)) { FieldSize elementSize, ElementCount elementCount, word* target)) {
setTagAndOffset(LIST, offset); setTagAndOffset(LIST, intervalLength(reinterpret_cast<word*>(this), target));
CAPNPROTO_DEBUG_ASSERT(elementCount < (1 << 29) * ELEMENTS, CAPNPROTO_DEBUG_ASSERT(elementCount < (1 << 29) * ELEMENTS,
"Lists are limited to 2**29 elements."); "Lists are limited to 2**29 elements.");
listRef.elementSizeAndCount.set( listRef.elementSizeAndCount.set(
...@@ -200,7 +200,7 @@ struct WireHelpers { ...@@ -200,7 +200,7 @@ struct WireHelpers {
word* ptr = allocate(ref, segment, dataSize + referenceCount * WORDS_PER_REFERENCE); word* ptr = allocate(ref, segment, dataSize + referenceCount * WORDS_PER_REFERENCE);
// Initialize the reference. // Initialize the reference.
ref->setStruct(fieldCount, dataSize, referenceCount, segment->getOffsetTo(ptr)); ref->setStruct(fieldCount, dataSize, referenceCount, ptr);
// Build the StructBuilder. // Build the StructBuilder.
return StructBuilder(segment, ptr, reinterpret_cast<WireReference*>(ptr + dataSize)); return StructBuilder(segment, ptr, reinterpret_cast<WireReference*>(ptr + dataSize));
...@@ -235,7 +235,7 @@ struct WireHelpers { ...@@ -235,7 +235,7 @@ struct WireHelpers {
word* ptr = allocate(ref, segment, wordCount); word* ptr = allocate(ref, segment, wordCount);
// Initialize the reference. // Initialize the reference.
ref->setList(elementSize, elementCount, segment->getOffsetTo(ptr)); ref->setList(elementSize, elementCount, ptr);
// Build the ListBuilder. // Build the ListBuilder.
return ListBuilder(segment, ptr, elementCount); return ListBuilder(segment, ptr, elementCount);
...@@ -251,15 +251,15 @@ struct WireHelpers { ...@@ -251,15 +251,15 @@ struct WireHelpers {
1 * REFERENCES * WORDS_PER_REFERENCE + elementCount * wordsPerElement); 1 * REFERENCES * WORDS_PER_REFERENCE + elementCount * wordsPerElement);
// Initialize the reference. // Initialize the reference.
ref->setList(FieldSize::STRUCT, elementCount, segment->getOffsetTo(ptr)); ref->setList(FieldSize::STRUCT, elementCount, ptr);
// The list is prefixed by a struct reference. // The list is prefixed by a struct reference.
WireReference* structRef = reinterpret_cast<WireReference*>(ptr); WireReference* structRef = reinterpret_cast<WireReference*>(ptr);
word* structPtr = ptr + 1 * REFERENCES * WORDS_PER_REFERENCE; word* structPtr = ptr + 1 * REFERENCES * WORDS_PER_REFERENCE;
structRef->setStruct(fieldCount, dataSize, referenceCount, segment->getOffsetTo(structPtr)); structRef->setStruct(fieldCount, dataSize, referenceCount, structPtr);
// Build the ListBuilder. // Build the ListBuilder.
return ListBuilder(segment, ptr, elementCount); return ListBuilder(segment, structPtr, elementCount);
} }
static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableListReference( static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableListReference(
...@@ -287,12 +287,13 @@ struct WireHelpers { ...@@ -287,12 +287,13 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference( static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue, SegmentReader* segment, const WireReference* ref, const word* defaultValue,
int recursionLimit)) { int recursionLimit)) {
const word* ptr = ref->target(); const word* ptr;
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
segment = nullptr; segment = nullptr;
ref = reinterpret_cast<const WireReference*>(defaultValue); ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) { } else if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) { if (CAPNPROTO_EXPECT_FALSE(recursionLimit == 0)) {
segment->getMessage()->reportInvalidData( segment->getMessage()->reportInvalidData(
...@@ -312,6 +313,7 @@ struct WireHelpers { ...@@ -312,6 +313,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
ptr = ref->target();
WordCount size = ref->structRef.dataSize.get() + WordCount size = ref->structRef.dataSize.get() +
ref->structRef.refCount.get() * WORDS_PER_REFERENCE; ref->structRef.refCount.get() * WORDS_PER_REFERENCE;
...@@ -320,6 +322,9 @@ struct WireHelpers { ...@@ -320,6 +322,9 @@ struct WireHelpers {
"Message contained out-of-bounds struct reference."); "Message contained out-of-bounds struct reference.");
goto useDefault; goto useDefault;
} }
} else {
// Trusted messages don't contain far pointers.
ptr = ref->target();
} }
return StructReader(segment, ptr, return StructReader(segment, ptr,
...@@ -387,7 +392,7 @@ struct WireHelpers { ...@@ -387,7 +392,7 @@ struct WireHelpers {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + wordsPerElement * size))) { if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + wordsPerElement * size))) {
segment->getMessage()->reportInvalidData( segment->getMessage()->reportInvalidData(
"Message contained out-of-bounds struct reference."); "Message contained out-of-bounds struct list tag.");
goto useDefault; goto useDefault;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment