Commit 90eae25e authored by Kenton Varda's avatar Kenton Varda

Improve how struct lists work.

parent e9fa0685
......@@ -79,30 +79,58 @@ struct Descriptor {
enum class FieldSize: uint8_t {
// TODO: Rename to FieldLayout or maybe ValueLayout.
BIT = 0,
BYTE = 1,
TWO_BYTES = 2,
FOUR_BYTES = 3,
EIGHT_BYTES = 4,
REFERENCE = 5, // Indicates that the field lives in the reference segment, not the data segment.
KEY_REFERENCE = 6, // A 64-bit key, 64-bit reference pair. Valid only in lists.
STRUCT = 7 // An arbitrary-sized inlined struct. Used only for list elements, not struct
// fields, since a struct cannot embed another struct inline.
VOID = 0,
BIT = 1,
BYTE = 2,
TWO_BYTES = 3,
FOUR_BYTES = 4,
EIGHT_BYTES = 5,
REFERENCE = 6, // Indicates that the field lives in the reference segment, not the data segment.
INLINE_COMPOSITE = 7
// A composite type of fixed width. This serves two purposes:
// 1) For lists of composite types where all the elements would have the exact same width,
// allocating a list of references which in turn point at the elements would waste space. We
// can avoid a layer of indirection by placing all the elements in a flat sequence, and only
// indicating the element properties (e.g. field count for structs) once.
//
// Specifically, a list reference indicating INLINE_COMPOSITE element size actually points to
// a "tag" describing one element. This tag is formatted like a wire reference, but the
// "offset" instead stores the element count of the list. The flat list of elements appears
// immediately after the tag. In the list reference itself, the element count is replaced with
// a word count for the whole list (excluding tag). This allows the tag and elements to be
// precached in a single step rather than two sequential steps.
//
// It is NOT intended to be possible to substitute an INLINE_COMPOSITE list for a REFERENCE
// list or vice-versa without breaking recipients. Recipients expect one or the other
// depending on the message definition.
//
// However, it IS allowed to substitute an INLINE_COMPOSITE list -- specifically, of structs --
// when a list was expected, or vice versa, with the assumption that the first field of the
// struct (field number zero) correspond to the element type. This allows a list of
// primitives to be upgraded to a list of structs, avoiding the need to use parallel arrays
// when you realize that you need to attach some extra information to each element of some
// primitive list.
//
// 2) For struct fields of composite types where the field's total size is known at compile time,
// we can embed the field directly into the parent struct to avoid indirection through a
// reference. However, this means that the field size can never change -- e.g. if it is a
// struct, new fields cannot be added to it. It's unclear if this is really useful so at this
// time it is not supported.
};
typedef decltype(BITS / ELEMENTS) BitsPerElement;
namespace internal {
static constexpr BitsPerElement BITS_PER_ELEMENT_TABLE[] = {
static constexpr BitsPerElement BITS_PER_ELEMENT_TABLE[8] = {
0 * BITS / ELEMENTS,
1 * BITS / ELEMENTS,
8 * BITS / ELEMENTS,
16 * BITS / ELEMENTS,
32 * BITS / ELEMENTS,
64 * BITS / ELEMENTS,
64 * BITS / ELEMENTS,
128 * BITS / ELEMENTS,
0 * BITS / ELEMENTS
};
}
......@@ -190,7 +218,7 @@ struct FieldDescriptor {
// If the field is a reference field (size == REFERENCE), then this is the index within the
// reference array at which the field is located.
//
// A value of INVALID_FIELD_OFFSET means that this is a void field.
// For void fields, the offset is irrelevant and may be INVALID_FIELD_OFFSET.
ByteCount16 unionTagOffset;
// Offset within the data segment at which a union tag exists deciding whether this field is
......
......@@ -53,7 +53,7 @@ MallocMessage::MallocMessage(WordCount preferredSegmentSize)
MallocMessage::~MallocMessage() {}
SegmentReader* MallocMessage::tryGetSegment(SegmentId id) {
if (id.value > segments.size()) {
if (id.value >= segments.size()) {
return nullptr;
} else {
return segments[id.value].get();
......
......@@ -427,6 +427,8 @@ constexpr auto BITS_PER_REFERENCE = 64 * BITS / REFERENCES;
constexpr auto BYTES_PER_REFERENCE = 8 * BYTES / REFERENCES;
constexpr auto WORDS_PER_REFERENCE = 1 * WORDS / REFERENCES;
constexpr WordCount REFERENCE_SIZE_IN_WORDS = 1 * REFERENCES * WORDS_PER_REFERENCE;
template <typename T>
inline constexpr decltype(BYTES / ELEMENTS) bytesPerElement() {
return sizeof(T) * BYTES / ELEMENTS;
......
......@@ -150,6 +150,59 @@ static void setupStruct(StructBuilder builder) {
}
}
static void checkStruct(StructBuilder builder) {
EXPECT_EQ(0x1011121314151617ull, builder.getDataField<uint64_t>(0 * ELEMENTS));
EXPECT_EQ(0x20212223u, builder.getDataField<uint32_t>(2 * ELEMENTS));
EXPECT_EQ(0x3031u, builder.getDataField<uint16_t>(6 * ELEMENTS));
EXPECT_EQ(0x40u, builder.getDataField<uint8_t>(14 * ELEMENTS));
EXPECT_FALSE(builder.getDataField<bool>(120 * ELEMENTS));
EXPECT_FALSE(builder.getDataField<bool>(121 * ELEMENTS));
EXPECT_TRUE (builder.getDataField<bool>(122 * ELEMENTS));
EXPECT_FALSE(builder.getDataField<bool>(123 * ELEMENTS));
EXPECT_TRUE (builder.getDataField<bool>(124 * ELEMENTS));
EXPECT_TRUE (builder.getDataField<bool>(125 * ELEMENTS));
EXPECT_TRUE (builder.getDataField<bool>(126 * ELEMENTS));
EXPECT_FALSE(builder.getDataField<bool>(127 * ELEMENTS));
{
StructBuilder subStruct = builder.getStructField(
0 * REFERENCES, FieldNumber(1), 1 * WORDS, 0 * REFERENCES);
EXPECT_EQ(123u, subStruct.getDataField<uint32_t>(0 * ELEMENTS));
}
{
ListBuilder list = builder.getListField(1 * REFERENCES, FieldSize::FOUR_BYTES);
ASSERT_EQ(3 * ELEMENTS, list.size());
EXPECT_EQ(200, list.getDataElement<int32_t>(0 * ELEMENTS));
EXPECT_EQ(201, list.getDataElement<int32_t>(1 * ELEMENTS));
EXPECT_EQ(202, list.getDataElement<int32_t>(2 * ELEMENTS));
}
{
ListBuilder list = builder.getListField(2 * REFERENCES, FieldSize::INLINE_COMPOSITE);
ASSERT_EQ(4 * ELEMENTS, list.size());
for (int i = 0; i < 4; i++) {
StructBuilder element = list.getStructElement(i * ELEMENTS, 2 * WORDS / ELEMENTS, 1 * WORDS);
EXPECT_EQ(300 + i, element.getDataField<int32_t>(0 * ELEMENTS));
EXPECT_EQ(400 + i,
element.getStructField(0 * REFERENCES, FieldNumber(1), 1 * WORDS, 0 * REFERENCES)
.getDataField<int32_t>(0 * ELEMENTS));
}
}
{
ListBuilder list = builder.getListField(3 * REFERENCES, FieldSize::REFERENCE);
ASSERT_EQ(5 * ELEMENTS, list.size());
for (uint i = 0; i < 5; i++) {
ListBuilder element = list.getListElement(i * REFERENCES, FieldSize::TWO_BYTES);
ASSERT_EQ((i + 1) * ELEMENTS, element.size());
for (uint j = 0; j <= i; j++) {
EXPECT_EQ(500u + j, element.getDataElement<uint16_t>(j * ELEMENTS));
}
}
}
}
static void checkStruct(StructReader reader) {
EXPECT_EQ(0x1011121314151617ull, reader.getDataField<uint64_t>(0 * ELEMENTS, 1616));
EXPECT_EQ(0x20212223u, reader.getDataField<uint32_t>(2 * ELEMENTS, 1616));
......@@ -181,7 +234,7 @@ static void checkStruct(StructReader reader) {
{
// TODO: Use valid default value.
ListReader list = reader.getListField(2 * REFERENCES, FieldSize::STRUCT, nullptr);
ListReader list = reader.getListField(2 * REFERENCES, FieldSize::INLINE_COMPOSITE, nullptr);
ASSERT_EQ(4 * ELEMENTS, list.size());
for (int i = 0; i < 4; i++) {
StructReader element = list.getStructElement(i * ELEMENTS, nullptr);
......@@ -206,7 +259,7 @@ static void checkStruct(StructReader reader) {
}
}
TEST(WireFormat, StructRoundTrip) {
TEST(WireFormat, StructRoundTrip_OneSegment) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(512 * WORDS);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
......@@ -233,12 +286,13 @@ TEST(WireFormat, StructRoundTrip) {
// 34
EXPECT_EQ(34 * WORDS, segment->getSize());
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRootTrusted(segment->getStartPtr(), nullptr));
checkStruct(StructReader::readRoot(segment->getStartPtr(), nullptr, segment, 4));
}
TEST(WireFormat, StructRoundTrip_MultipleSegments) {
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(1 * WORDS);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
......@@ -269,7 +323,36 @@ TEST(WireFormat, StructRoundTrip_MultipleSegments) {
EXPECT_EQ( 2 * WORDS, message->getSegment(SegmentId(13))->getSize()); // list list sublist 4
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId(14))->getSize()); // list list sublist 5
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRoot(segment->getStartPtr(), nullptr, segment, 4));
}
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
std::unique_ptr<MessageBuilder> message = newMallocMessage(8 * WORDS);
SegmentBuilder* segment = message->getSegmentWithAvailable(1 * WORDS);
word* rootLocation = segment->allocate(1 * WORDS);
StructBuilder builder =
StructBuilder::initRoot(segment, rootLocation, FieldNumber(16), 2 * WORDS, 4 * REFERENCES);
setupStruct(builder);
// Verify that we made 6 segments.
ASSERT_TRUE(message->tryGetSegment(SegmentId(5)) != nullptr);
EXPECT_EQ(nullptr, message->tryGetSegment(SegmentId(6)));
// Check that each segment has the expected size. Recall that each object will be prefixed by an
// extra word if its parent is in a different segment.
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(0))->getSize()); // root ref + struct + sub
EXPECT_EQ( 3 * WORDS, message->getSegment(SegmentId(1))->getSize()); // 3-element int32 list
EXPECT_EQ(10 * WORDS, message->getSegment(SegmentId(2))->getSize()); // struct list
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(3))->getSize()); // struct list substructs
EXPECT_EQ( 8 * WORDS, message->getSegment(SegmentId(4))->getSize()); // list list + sublist 1,2
EXPECT_EQ( 7 * WORDS, message->getSegment(SegmentId(5))->getSize()); // list list sublist 3,4,5
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRoot(segment->getStartPtr(), nullptr, segment, 4));
}
} // namespace
......
......@@ -75,6 +75,9 @@ struct WireReference {
CAPNPROTO_ALWAYS_INLINE(ElementCount elementCount() const) {
return (elementSizeAndCount.get() & 0x1fffffffu) * ELEMENTS;
}
CAPNPROTO_ALWAYS_INLINE(WordCount inlineCompositeWordCount() const) {
return elementCount() * (1 * WORDS / ELEMENTS);
}
} listRef;
struct {
......@@ -86,6 +89,9 @@ struct WireReference {
CAPNPROTO_ALWAYS_INLINE(WordCount offset() const) {
return (offsetAndTag.get() >> 3) * WORDS;
}
CAPNPROTO_ALWAYS_INLINE(ElementCount tagElementCount() const) {
return (offsetAndTag.get() >> 3) * ELEMENTS;
}
CAPNPROTO_ALWAYS_INLINE(word* target()) {
return reinterpret_cast<word*>(this) + offset();
}
......@@ -100,6 +106,10 @@ struct WireReference {
offsetAndTag.set(((offset / WORDS) << 3) | tag);
}
CAPNPROTO_ALWAYS_INLINE(void setTagAndElementCount(Tag tag, ElementCount elementCount)) {
offsetAndTag.set(((elementCount / ELEMENTS) << 3) | tag);
}
CAPNPROTO_ALWAYS_INLINE(void setStruct(
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount refCount, word* target)) {
setTagAndOffset(STRUCT, intervalLength(reinterpret_cast<word*>(this), target));
......@@ -109,6 +119,16 @@ struct WireReference {
structRef.reserved0.set(0);
}
CAPNPROTO_ALWAYS_INLINE(void setStructTag(
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount refCount,
ElementCount elementCount)) {
setTagAndElementCount(STRUCT, elementCount);
structRef.fieldCount.set(fieldCount);
structRef.dataSize.set(WordCount8(dataSize));
structRef.refCount.set(refCount);
structRef.reserved0.set(0);
}
CAPNPROTO_ALWAYS_INLINE(void setList(
FieldSize elementSize, ElementCount elementCount, word* target)) {
setTagAndOffset(LIST, intervalLength(reinterpret_cast<word*>(this), target));
......@@ -118,6 +138,23 @@ struct WireReference {
(static_cast<int>(elementSize) << 29) | (elementCount / ELEMENTS));
}
CAPNPROTO_ALWAYS_INLINE(void setInlineCompositeList(WordCount wordCount, word* target)) {
setTagAndOffset(LIST, intervalLength(reinterpret_cast<word*>(this), target));
CAPNPROTO_DEBUG_ASSERT(wordCount < (1 << 29) * WORDS,
"Inline composite lists are limited to 2**29 words.");
listRef.elementSizeAndCount.set(
(static_cast<int>(FieldSize::INLINE_COMPOSITE) << 29) | (wordCount / WORDS));
}
CAPNPROTO_ALWAYS_INLINE(void setListTag(FieldSize elementSize, ElementCount listCount,
ElementCount elementsPerList)) {
setTagAndElementCount(LIST, listCount);
CAPNPROTO_DEBUG_ASSERT(elementsPerList < (1 << 29) * ELEMENTS,
"Lists are limited to 2**29 elements.");
listRef.elementSizeAndCount.set(
(static_cast<int>(elementSize) << 29) | (elementsPerList / ELEMENTS));
}
CAPNPROTO_ALWAYS_INLINE(void setFar(SegmentId segmentId, WordCount offset)) {
setTagAndOffset(FAR, offset);
farRef.segmentId.set(segmentId);
......@@ -152,7 +189,7 @@ struct WireHelpers {
// thread could have grabbed the space between when we asked the message for the segment and
// when we asked the segment to allocate space.
do {
WordCount amountPlusRef = amount + 1 * REFERENCES * WORDS_PER_REFERENCE;
WordCount amountPlusRef = amount + REFERENCE_SIZE_IN_WORDS;
segment = segment->getMessage()->getSegmentWithAvailable(amountPlusRef);
ptr = segment->allocate(amountPlusRef);
} while (CAPNPROTO_EXPECT_FALSE(ptr == nullptr));
......@@ -161,7 +198,7 @@ struct WireHelpers {
ref = reinterpret_cast<WireReference*>(ptr);
// Allocated space follows new reference.
return ptr + 1 * REFERENCES * WORDS_PER_REFERENCE;
return ptr + REFERENCE_SIZE_IN_WORDS;
} else {
return ptr;
}
......@@ -182,12 +219,12 @@ struct WireHelpers {
return false;
}
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(segment->getStartPtr(),
segment->getStartPtr() + 1 * REFERENCES * WORDS_PER_REFERENCE))) {
const word* ptr = segment->getStartPtr() + ref->offset();
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + REFERENCE_SIZE_IN_WORDS))) {
return false;
}
ref = reinterpret_cast<const WireReference*>(segment->getStartPtr());
ref = reinterpret_cast<const WireReference*>(ptr);
}
return true;
}
......@@ -216,7 +253,7 @@ struct WireHelpers {
CAPNPROTO_DEBUG_ASSERT(ref->structRef.refCount.get() == referenceCount,
"Trying to update struct with incorrect reference count.");
word* ptr = segment->getPtrUnchecked(ref->offset());
word* ptr = ref->target();
return StructBuilder(segment, ptr, reinterpret_cast<WireReference*>(ptr + dataSize));
}
}
......@@ -224,7 +261,7 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(ListBuilder initListReference(
WireReference* ref, SegmentBuilder* segment, ElementCount elementCount,
FieldSize elementSize)) {
CAPNPROTO_DEBUG_ASSERT(elementSize != FieldSize::STRUCT,
CAPNPROTO_DEBUG_ASSERT(elementSize != FieldSize::INLINE_COMPOSITE,
"Should have called initStructListReference() instead.");
// Calculate size of the list.
......@@ -247,19 +284,19 @@ struct WireHelpers {
auto wordsPerElement = (dataSize + referenceCount * WORDS_PER_REFERENCE) / ELEMENTS;
// Allocate the list, prefixed by a single WireReference.
word* ptr = allocate(ref, segment,
1 * REFERENCES * WORDS_PER_REFERENCE + elementCount * wordsPerElement);
WordCount wordCount = elementCount * wordsPerElement;
word* ptr = allocate(ref, segment, REFERENCE_SIZE_IN_WORDS + wordCount);
// Initialize the reference.
ref->setList(FieldSize::STRUCT, elementCount, ptr);
// INLINE_COMPOSITE lists replace the element count with the word count.
ref->setInlineCompositeList(wordCount, ptr);
// The list is prefixed by a struct reference.
WireReference* structRef = reinterpret_cast<WireReference*>(ptr);
word* structPtr = ptr + 1 * REFERENCES * WORDS_PER_REFERENCE;
structRef->setStruct(fieldCount, dataSize, referenceCount, structPtr);
// Initialize the list tag.
reinterpret_cast<WireReference*>(ptr)->setStructTag(
fieldCount, dataSize, referenceCount, elementCount);
// Build the ListBuilder.
return ListBuilder(segment, structPtr, elementCount);
return ListBuilder(segment, ptr + REFERENCE_SIZE_IN_WORDS, elementCount);
}
static CAPNPROTO_ALWAYS_INLINE(ListBuilder getWritableListReference(
......@@ -273,14 +310,17 @@ struct WireHelpers {
CAPNPROTO_ASSERT(ref->tag() == WireReference::LIST,
"Called getList{Field,Element}() but existing reference is not a list.");
if (elementSize == FieldSize::STRUCT) {
WireReference* structRef = reinterpret_cast<WireReference*>(
segment->getPtrUnchecked(ref->offset()));
return ListBuilder(segment,
segment->getPtrUnchecked(structRef->offset()), ref->listRef.elementCount());
if (elementSize == FieldSize::INLINE_COMPOSITE) {
// Read the tag to get the actual element count.
WireReference* tag = reinterpret_cast<WireReference*>(ref->target());
CAPNPROTO_ASSERT(tag->tag() == WireReference::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.");
ElementCount elementCount = tag->tagElementCount();
// First list element is at ptr + 1 reference.
return ListBuilder(segment, reinterpret_cast<word*>(tag + 1), elementCount);
} else {
return ListBuilder(segment,
segment->getPtrUnchecked(ref->offset()), ref->listRef.elementCount());
return ListBuilder(segment, ref->target(), ref->listRef.elementCount());
}
}
......@@ -314,10 +354,8 @@ struct WireHelpers {
}
ptr = ref->target();
WordCount size = ref->structRef.dataSize.get() +
ref->structRef.refCount.get() * WORDS_PER_REFERENCE;
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + size))) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + ref->structRef.wordSize()))){
segment->getMessage()->reportInvalidData(
"Message contained out-of-bounds struct reference.");
goto useDefault;
......@@ -327,11 +365,12 @@ struct WireHelpers {
ptr = ref->target();
}
return StructReader(segment, ptr,
ref->structRef.fieldCount.get(),
ref->structRef.dataSize.get(),
ref->structRef.refCount.get(),
0 * BITS, recursionLimit - 1);
return StructReader(
segment, ptr, reinterpret_cast<const WireReference*>(ptr + ref->structRef.dataSize.get()),
ref->structRef.fieldCount.get(),
ref->structRef.dataSize.get(),
ref->structRef.refCount.get(),
0 * BITS, recursionLimit - 1);
}
static CAPNPROTO_ALWAYS_INLINE(ListReader readListReference(
......@@ -361,38 +400,37 @@ struct WireHelpers {
}
}
if (ref->listRef.elementSize() == FieldSize::STRUCT) {
ElementCount size = ref->listRef.elementCount();
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
decltype(WORDS/ELEMENTS) wordsPerElement;
ElementCount size;
// A struct list reference actually points to a struct reference which in turn points to the
// first struct in the list.
const word* ptrPtr = ref->target();
ref = reinterpret_cast<const WireReference*>(ptrPtr);
const word* ptr = ref->target();
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
const word* ptr;
// An INLINE_COMPOSITE list points to a tag, which is formatted like a reference.
const WireReference* tag = reinterpret_cast<const WireReference*>(ptr);
ptr += REFERENCE_SIZE_IN_WORDS;
if (segment != nullptr) {
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(
ptrPtr, ptrPtr + 1 * REFERENCES * WORDS_PER_REFERENCE))) {
ptr - REFERENCE_SIZE_IN_WORDS, ptr + wordCount))) {
segment->getMessage()->reportInvalidData(
"Message contains out-of-bounds list reference.");
goto useDefault;
}
if (CAPNPROTO_EXPECT_FALSE(ref->tag() != WireReference::STRUCT)) {
if (CAPNPROTO_EXPECT_FALSE(tag->tag() != WireReference::STRUCT)) {
segment->getMessage()->reportInvalidData(
"Message contains struct list reference that does not point to a struct reference.");
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.");
goto useDefault;
}
wordsPerElement = (ref->structRef.dataSize.get() +
ref->structRef.refCount.get() * WORDS_PER_REFERENCE) / ELEMENTS;
ptr = ref->target();
size = tag->tagElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (CAPNPROTO_EXPECT_FALSE(!segment->containsInterval(ptr, ptr + wordsPerElement * size))) {
if (CAPNPROTO_EXPECT_FALSE(size * wordsPerElement > wordCount)) {
segment->getMessage()->reportInvalidData(
"Message contained out-of-bounds struct list tag.");
"INLINE_COMPOSITE list's elements overrun its word count.");
goto useDefault;
}
......@@ -404,24 +442,27 @@ struct WireHelpers {
// Check whether the size is compatible.
bool compatible = false;
switch (expectedElementSize) {
case FieldSize::VOID:
compatible = true;
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
compatible = ref->structRef.dataSize.get() > 0 * WORDS;
compatible = tag->structRef.dataSize.get() > 0 * WORDS;
break;
case FieldSize::REFERENCE:
ptr += ref->structRef.dataSize.get();
compatible = ref->structRef.refCount.get() > 0 * REFERENCES;
break;
case FieldSize::KEY_REFERENCE:
compatible = false;
// We expected a list of references but got a list of structs. Assuming the first field
// in the struct is the reference we were looking for, we want to munge the pointer to
// point at the first element's reference segment.
ptr += tag->structRef.dataSize.get();
compatible = tag->structRef.refCount.get() > 0 * REFERENCES;
break;
case FieldSize::STRUCT:
case FieldSize::INLINE_COMPOSITE:
compatible = true;
break;
}
......@@ -432,20 +473,20 @@ struct WireHelpers {
}
} else {
// Trusted message.
// This logic is equivalent to the other branch, above, but skipping all the checks.
ptr = ref->target();
wordsPerElement = (ref->structRef.dataSize.get() +
ref->structRef.refCount.get() * WORDS_PER_REFERENCE) / ELEMENTS;
size = tag->tagElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (expectedElementSize == FieldSize::REFERENCE) {
ptr += ref->structRef.dataSize.get();
ptr += tag->structRef.dataSize.get();
}
}
return ListReader(segment, ptr, size, wordsPerElement * BITS_PER_WORD,
ref->structRef.fieldCount.get(),
ref->structRef.dataSize.get(),
ref->structRef.refCount.get(),
tag->structRef.fieldCount.get(),
tag->structRef.dataSize.get(),
tag->structRef.refCount.get(),
recursionLimit - 1);
} else {
......@@ -465,7 +506,7 @@ struct WireHelpers {
if (ref->listRef.elementSize() == expectedElementSize) {
return ListReader(segment, ptr, ref->listRef.elementCount(), step, recursionLimit - 1);
} else if (expectedElementSize == FieldSize::STRUCT) {
} else if (expectedElementSize == FieldSize::INLINE_COMPOSITE) {
// We were expecting a struct list, but we received a list of some other type. Perhaps a
// non-struct list was recently upgraded to a struct list, but the sender is using the
// old version of the protocol. We need to verify that the struct's first field matches
......@@ -475,6 +516,11 @@ struct WireHelpers {
WireReferenceCount referenceCount;
switch (ref->listRef.elementSize()) {
case FieldSize::VOID:
dataSize = 0 * WORDS;
referenceCount = 0 * REFERENCES;
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
......@@ -489,12 +535,7 @@ struct WireHelpers {
referenceCount = 1 * REFERENCES;
break;
case FieldSize::KEY_REFERENCE:
dataSize = 1 * WORDS;
referenceCount = 1 * REFERENCES;
break;
case FieldSize::STRUCT:
case FieldSize::INLINE_COMPOSITE:
CAPNPROTO_ASSERT(false, "can't get here");
break;
}
......@@ -502,8 +543,7 @@ struct WireHelpers {
return ListReader(segment, ptr, ref->listRef.elementCount(), step, FieldNumber(1),
dataSize, referenceCount, recursionLimit - 1);
} else {
// If segment is null, then we're parsing a trusted message that was invalid. Crashing is
// within contract.
CAPNPROTO_ASSERT(segment != nullptr, "Trusted message had incompatible list element type.");
segment->getMessage()->reportInvalidData("A list had incompatible element type.");
goto useDefault;
}
......@@ -549,14 +589,16 @@ ListBuilder StructBuilder::getListField(WireReferenceCount refIndex, FieldSize e
}
StructReader StructBuilder::asReader() const {
// HACK: We just give maxed-out field and reference counts because they are only used for
// checking for field presence.
// HACK: We just give maxed-out field, data size, and reference counts because they are only
// used for checking for field presence.
static_assert(sizeof(WireReference::structRef.fieldCount) == 1,
"Has the maximum field count changed?");
static_assert(sizeof(WireReference::structRef.dataSize) == 1,
"Has the maximum data size changed?");
static_assert(sizeof(WireReference::structRef.refCount) == 1,
"Has the maximum reference count changed?");
return StructReader(segment, data, FieldNumber(0xff),
intervalLength(data, reinterpret_cast<word*>(references)), 0xff * REFERENCES,
return StructReader(segment, data, references,
FieldNumber(0xff), 0xff * WORDS, 0xff * REFERENCES,
0 * BITS, std::numeric_limits<int>::max());
}
......@@ -567,7 +609,7 @@ StructReader StructReader::readRootTrusted(const word* location, const word* def
StructReader StructReader::readRoot(const word* location, const word* defaultValue,
SegmentReader* segment, int recursionLimit) {
if (!segment->containsInterval(location, location + 1 * REFERENCES * WORDS_PER_REFERENCE)) {
if (!segment->containsInterval(location, location + REFERENCE_SIZE_IN_WORDS)) {
segment->getMessage()->reportInvalidData("Root location out-of-bounds.");
location = nullptr;
}
......@@ -578,17 +620,13 @@ StructReader StructReader::readRoot(const word* location, const word* defaultVal
StructReader StructReader::getStructField(
WireReferenceCount refIndex, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr :
reinterpret_cast<const WireReference*>(
reinterpret_cast<const word*>(ptr) + dataSize) + refIndex;
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readStructReference(segment, ref, defaultValue, recursionLimit);
}
ListReader StructReader::getListField(
WireReferenceCount refIndex, FieldSize expectedElementSize, const word* defaultValue) const {
const WireReference* ref = refIndex >= referenceCount ? nullptr :
reinterpret_cast<const WireReference*>(
reinterpret_cast<const word*>(ptr) + dataSize) + refIndex;
const WireReference* ref = refIndex >= referenceCount ? nullptr : references + refIndex;
return WireHelpers::readListReference(
segment, ref, defaultValue, expectedElementSize, recursionLimit);
}
......@@ -621,6 +659,9 @@ ListBuilder ListBuilder::getListElement(WireReferenceCount index, FieldSize elem
}
ListReader ListBuilder::asReader(FieldSize elementSize) const {
// TODO: For INLINE_COMPOSITE I suppose we could just check the tag?
CAPNPROTO_ASSERT(elementSize != FieldSize::INLINE_COMPOSITE,
"Need to call the other asReader() overload for INLINE_COMPOSITE lists.");
return ListReader(segment, ptr, elementCount, bitsPerElement(elementSize),
std::numeric_limits<int>::max());
}
......@@ -639,8 +680,10 @@ StructReader ListReader::getStructElement(ElementCount index, const word* defaul
return WireHelpers::readStructReference(nullptr, nullptr, defaultValue, recursionLimit);
} else {
BitCount64 indexBit = ElementCount64(index) * stepBits;
const byte* structPtr = reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE;
return StructReader(
segment, reinterpret_cast<const byte*>(ptr) + indexBit / BITS_PER_BYTE,
segment, structPtr,
reinterpret_cast<const WireReference*>(structPtr + structDataSize * BYTES_PER_WORD),
structFieldCount, structDataSize, structReferenceCount, indexBit % BITS_PER_BYTE,
recursionLimit - 1);
}
......
......@@ -132,8 +132,8 @@ private:
class StructReader {
public:
inline StructReader()
: segment(nullptr), ptr(nullptr), fieldCount(0), dataSize(0), referenceCount(0),
bit0Offset(0 * BITS), recursionLimit(0) {}
: segment(nullptr), data(nullptr), references(nullptr), fieldCount(0), dataSize(0),
referenceCount(0), bit0Offset(0 * BITS), recursionLimit(0) {}
static StructReader readRootTrusted(const word* location, const word* defaultValue);
static StructReader readRoot(const word* location, const word* defaultValue,
......@@ -167,17 +167,8 @@ public:
private:
SegmentReader* segment; // Memory segment in which the struct resides.
const void* ptr;
// ptr[0] points to the location between the struct's data and reference segments.
// ptr[1] points to the end of the *default* data segment.
// We put these in an array so we can choose between them without a branch.
// These pointers are not necessarily word-aligned -- they are aligned as well as necessary for
// the data they might point at. So if the struct has only one field that we know of, and it is
// of type Int16, then the pointers only need to be 16-bit aligned. Or if the struct has fields
// of type Int16 and Int64 (in that order), but the struct reference on the wire self-reported
// as having only one field (therefore, only the Int16), then ptr[0] need only be 16-bit aligned
// while ptr[1] must be 64-bit aligned. This relaxation of alignment is needed to handle the
// case where a list of primitives is upgraded to a list of structs.
const void* data;
const WireReference* references;
FieldNumber fieldCount; // Number of fields the struct is reported to have.
WordCount8 dataSize; // Size of data segment.
......@@ -192,11 +183,12 @@ private:
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned.
inline StructReader(SegmentReader* segment, const void* ptr, FieldNumber fieldCount,
WordCount dataSize, WireReferenceCount referenceCount,
inline StructReader(SegmentReader* segment, const void* data, const WireReference* references,
FieldNumber fieldCount, WordCount dataSize, WireReferenceCount referenceCount,
BitCount bit0Offset, int recursionLimit)
: segment(segment), ptr(ptr), fieldCount(fieldCount), dataSize(dataSize),
referenceCount(referenceCount), bit0Offset(bit0Offset), recursionLimit(recursionLimit) {}
: segment(segment), data(data), references(references), fieldCount(fieldCount),
dataSize(dataSize), referenceCount(referenceCount), bit0Offset(bit0Offset),
recursionLimit(recursionLimit) {}
friend class ListReader;
friend class StructBuilder;
......@@ -354,7 +346,7 @@ inline void StructBuilder::setDataField<bool>(ElementCount offset, bool value) c
template <typename T>
T StructReader::getDataField(ElementCount offset, typename NoInfer<T>::Type defaultValue) const {
if (offset * bytesPerElement<T>() < dataSize * BYTES_PER_WORD) {
return reinterpret_cast<const WireValue<T>*>(ptr)[offset / ELEMENTS].get();
return reinterpret_cast<const WireValue<T>*>(data)[offset / ELEMENTS].get();
} else {
return defaultValue;
}
......@@ -368,7 +360,7 @@ inline bool StructReader::getDataField<bool>(ElementCount offset, bool defaultVa
if (boffset == 0 * BITS) boffset = bit0Offset;
if (boffset < dataSize * BITS_PER_WORD) {
const byte* b = reinterpret_cast<const byte*>(ptr) + boffset / BITS_PER_BYTE;
const byte* b = reinterpret_cast<const byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0;
} else {
return defaultValue;
......@@ -381,7 +373,7 @@ T StructReader::getDataFieldCheckingNumber(
// Intentionally use & rather than && to reduce branches.
if ((fieldNumber < fieldCount) &
(offset * bytesPerElement<T>() < dataSize * BYTES_PER_WORD)) {
return reinterpret_cast<const WireValue<T>*>(ptr)[offset / ELEMENTS].get();
return reinterpret_cast<const WireValue<T>*>(data)[offset / ELEMENTS].get();
} else {
return defaultValue;
}
......@@ -397,7 +389,7 @@ inline bool StructReader::getDataFieldCheckingNumber<bool>(
// Intentionally use & rather than && to reduce branches.
if ((fieldNumber < fieldCount) & (boffset < dataSize * BITS_PER_WORD)) {
const byte* b = reinterpret_cast<const byte*>(ptr) + boffset / BITS_PER_BYTE;
const byte* b = reinterpret_cast<const byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0;
} else {
return defaultValue;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment