Commit 44a56220 authored by Kenton Varda's avatar Kenton Varda

Remove hacks that support List(Bool) -> List(struct) upgrades (no longer allowed).

Extended discussion:
https://groups.google.com/d/msg/capnproto/lRlWBOglQv4/8-Qo96AcZQIJ
parent b469079d
......@@ -606,12 +606,14 @@ TEST(Encoding, ListUpgrade) {
}
TEST(Encoding, BitListDowngrade) {
// NO LONGER SUPPORTED -- We check for exceptions thrown.
MallocMessageBuilder builder;
auto root = builder.initRoot<test::TestAnyPointer>();
root.getAnyPointerField().setAs<List<uint16_t>>({0x1201u, 0x3400u, 0x5601u, 0x7801u});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, false, true, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
{
auto l = root.getAnyPointerField().getAs<List<test::TestLists::Struct1>>();
......@@ -627,7 +629,7 @@ TEST(Encoding, BitListDowngrade) {
auto reader = root.asReader();
checkList(reader.getAnyPointerField().getAs<List<bool>>(), {true, false, true, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
{
auto l = reader.getAnyPointerField().getAs<List<test::TestLists::Struct1>>();
......@@ -654,7 +656,7 @@ TEST(Encoding, BitListDowngradeFromStruct) {
list[3].setF(true);
}
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, false, true, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
{
auto l = root.getAnyPointerField().getAs<List<test::TestLists::Struct1>>();
......@@ -667,7 +669,7 @@ TEST(Encoding, BitListDowngradeFromStruct) {
auto reader = root.asReader();
checkList(reader.getAnyPointerField().getAs<List<bool>>(), {true, false, true, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
{
auto l = reader.getAnyPointerField().getAs<List<test::TestLists::Struct1>>();
......@@ -1054,7 +1056,7 @@ TEST(Encoding, UpgradeListInBuilder) {
{
root.getAnyPointerField().setAs<List<bool>>({true, false, true, true});
auto orig = root.asReader().getAnyPointerField().getAs<List<bool>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID, VOID, VOID});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<Void>>());
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, false, true, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<uint8_t>>());
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<uint16_t>>());
......@@ -1074,7 +1076,7 @@ TEST(Encoding, UpgradeListInBuilder) {
root.getAnyPointerField().setAs<List<uint8_t>>({0x12, 0x23, 0x33, 0x44});
auto orig = root.asReader().getAnyPointerField().getAs<List<uint8_t>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID, VOID, VOID});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {false, true, true, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0x12, 0x23, 0x33, 0x44});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<uint16_t>>());
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<uint32_t>>());
......@@ -1092,7 +1094,7 @@ TEST(Encoding, UpgradeListInBuilder) {
root.getAnyPointerField().setAs<List<uint16_t>>({0x5612, 0x7823, 0xab33, 0xcd44});
auto orig = root.asReader().getAnyPointerField().getAs<List<uint16_t>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID, VOID, VOID});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {false, true, true, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0x12, 0x23, 0x33, 0x44});
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {0x5612, 0x7823, 0xab33, 0xcd44});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<uint32_t>>());
......@@ -1110,7 +1112,7 @@ TEST(Encoding, UpgradeListInBuilder) {
root.getAnyPointerField().setAs<List<uint32_t>>({0x17595612, 0x29347823, 0x5923ab32, 0x1a39cd45});
auto orig = root.asReader().getAnyPointerField().getAs<List<uint32_t>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID, VOID, VOID});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {false, true, false, true});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0x12, 0x23, 0x32, 0x45});
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {0x5612, 0x7823, 0xab32, 0xcd45});
checkList(root.getAnyPointerField().getAs<List<uint32_t>>(), {0x17595612u, 0x29347823u, 0x5923ab32u, 0x1a39cd45u});
......@@ -1128,7 +1130,7 @@ TEST(Encoding, UpgradeListInBuilder) {
root.getAnyPointerField().setAs<List<uint64_t>>({0x1234abcd8735fe21, 0x7173bc0e1923af36});
auto orig = root.asReader().getAnyPointerField().getAs<List<uint64_t>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0x21, 0x36});
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {0xfe21, 0xaf36});
checkList(root.getAnyPointerField().getAs<List<uint32_t>>(), {0x8735fe21u, 0x1923af36u});
......@@ -1173,7 +1175,7 @@ TEST(Encoding, UpgradeListInBuilder) {
auto orig = root.asReader().getAnyPointerField().getAs<List<test::TestOldVersion>>();
checkList(root.getAnyPointerField().getAs<List<Void>>(), {VOID, VOID, VOID});
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, true, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0xefu, 0xf1u, 0x12u});
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {0xcdefu, 0xdef1u, 0xef12u});
checkList(root.getAnyPointerField().getAs<List<uint32_t>>(), {0x90abcdefu, 0x0abcdef1u, 0xabcdef12u});
......@@ -1210,7 +1212,7 @@ TEST(Encoding, UpgradeListInBuilder) {
l[2].setF(0x33423082u);
l[3].setF(0x12988948u);
}
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, true, false, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint8_t>>(), {0x35u, 0x79u, 0x82u, 0x48u});
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {0x1235u, 0x2879u, 0x3082u, 0x8948u});
checkList(root.getAnyPointerField().getAs<List<uint32_t>>(),
......@@ -1233,7 +1235,7 @@ TEST(Encoding, UpgradeListInBuilder) {
l[2].setF(9238);
l[3].setF(5832);
}
checkList(root.getAnyPointerField().getAs<List<bool>>(), {true, true, false, false});
EXPECT_NONFATAL_FAILURE(root.getAnyPointerField().getAs<List<bool>>());
checkList(root.getAnyPointerField().getAs<List<uint16_t>>(), {12573u, 3251u, 9238u, 5832u});
checkList(root.getAnyPointerField().getAs<List<uint32_t>>(), {12573u, 3251u, 9238u, 5832u});
checkList(root.getAnyPointerField().getAs<List<uint64_t>>(), {12573u, 3251u, 9238u, 5832u});
......
......@@ -292,16 +292,12 @@ struct WireHelpers {
return (bits + 7 * BITS) / BITS_PER_BYTE;
}
// The maximum object size is 4GB - 1 byte. If measured in bits, this would overflow a 32-bit
// counter, so we need to accept BitCount64. However, 32 bits is enough for the returned
// ByteCounts and WordCounts.
static KJ_ALWAYS_INLINE(WordCount roundBitsUpToWords(BitCount64 bits)) {
static KJ_ALWAYS_INLINE(WordCount64 roundBitsUpToWords(BitCount64 bits)) {
static_assert(sizeof(word) == 8, "This code assumes 64-bit words.");
return (bits + 63 * BITS) / BITS_PER_WORD;
}
static KJ_ALWAYS_INLINE(ByteCount roundBitsUpToBytes(BitCount64 bits)) {
static KJ_ALWAYS_INLINE(ByteCount64 roundBitsUpToBytes(BitCount64 bits)) {
return (bits + 7 * BITS) / BITS_PER_BYTE;
}
......@@ -645,7 +641,7 @@ struct WireHelpers {
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: {
WordCount totalWords = roundBitsUpToWords(
WordCount64 totalWords = roundBitsUpToWords(
ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords),
......@@ -923,7 +919,7 @@ struct WireHelpers {
// Build the StructBuilder.
return StructBuilder(segment, ptr, reinterpret_cast<WirePointer*>(ptr + size.data),
size.data * BITS_PER_WORD, size.pointers, 0 * BITS);
size.data * BITS_PER_WORD, size.pointers);
}
static KJ_ALWAYS_INLINE(StructBuilder getWritableStructPointer(
......@@ -992,10 +988,10 @@ struct WireHelpers {
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount, 0 * BITS);
newPointerCount);
} else {
return StructBuilder(oldSegment, oldPtr, oldPointerSection, oldDataSize * BITS_PER_WORD,
oldPointerCount, 0 * BITS);
oldPointerCount);
}
}
......@@ -1109,6 +1105,13 @@ struct WireHelpers {
break;
case FieldSize::BIT:
KJ_FAIL_REQUIRE(
"Found struct list where bit list was expected; upgrading boolean lists to structs "
"is no longer supported.") {
goto useDefault;
}
break;
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
......@@ -1142,13 +1145,24 @@ struct WireHelpers {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS;
KJ_REQUIRE(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.") {
goto useDefault;
}
KJ_REQUIRE(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.") {
goto useDefault;
if (elementSize == FieldSize::BIT) {
KJ_REQUIRE(oldSize == FieldSize::BIT,
"Found non-bit list where bit list was expected.") {
goto useDefault;
}
} else {
KJ_REQUIRE(oldSize != FieldSize::BIT,
"Found bit list where non-bit list was expected.") {
goto useDefault;
}
KJ_REQUIRE(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.") {
goto useDefault;
}
KJ_REQUIRE(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.") {
goto useDefault;
}
}
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
......@@ -1576,7 +1590,7 @@ struct WireHelpers {
reinterpret_cast<const WirePointer*>(ptr + src->structRef.dataSize.get()),
src->structRef.dataSize.get() * BITS_PER_WORD,
src->structRef.ptrCount.get(),
0 * BITS, nestingLimit - 1),
nestingLimit - 1),
orphanArena);
case WirePointer::LIST: {
......@@ -1621,7 +1635,7 @@ struct WireHelpers {
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = src->listRef.elementCount();
WordCount wordCount = roundBitsUpToWords(ElementCount64(elementCount) * step);
WordCount64 wordCount = roundBitsUpToWords(ElementCount64(elementCount) * step);
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") {
......@@ -1760,7 +1774,7 @@ struct WireHelpers {
segment, ptr, reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()),
ref->structRef.dataSize.get() * BITS_PER_WORD,
ref->structRef.ptrCount.get(),
0 * BITS, nestingLimit - 1);
nestingLimit - 1);
}
#if !CAPNP_LITE
......@@ -1871,6 +1885,13 @@ struct WireHelpers {
break;
case FieldSize::BIT:
KJ_FAIL_REQUIRE(
"Found struct list where bit list was expected; upgrading boolean lists to structs "
"is no longer supported.") {
goto useDefault;
}
break;
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
......@@ -2304,7 +2325,7 @@ void StructBuilder::copyContentFrom(StructReader other) {
StructReader StructBuilder::asReader() const {
return StructReader(segment, data, pointers,
dataSize, pointerCount, bit0Offset, kj::maxValue);
dataSize, pointerCount, kj::maxValue);
}
BuilderArena* StructBuilder::getArena() {
......@@ -2368,9 +2389,10 @@ Data::Builder ListBuilder::asData() {
StructBuilder ListBuilder::getStructElement(ElementCount index) {
BitCount64 indexBit = ElementCount64(index) * step;
byte* structData = ptr + indexBit / BITS_PER_BYTE;
KJ_DASSERT(indexBit % BITS_PER_BYTE == 0 * BITS);
return StructBuilder(segment, structData,
reinterpret_cast<WirePointer*>(structData + structDataSize / BITS_PER_BYTE),
structDataSize, structPointerCount, indexBit % BITS_PER_BYTE);
structDataSize, structPointerCount);
}
ListReader ListBuilder::asReader() const {
......@@ -2432,10 +2454,11 @@ StructReader ListReader::getStructElement(ElementCount index) const {
(uintptr_t)structPointers % sizeof(void*) == 0,
"Pointer section of struct list element not aligned.");
KJ_DASSERT(indexBit % BITS_PER_BYTE == 0 * BITS);
return StructReader(
segment, structData, structPointers,
structDataSize, structPointerCount,
indexBit % BITS_PER_BYTE, nestingLimit - 1);
nestingLimit - 1);
}
// =======================================================================================
......
......@@ -443,7 +443,7 @@ private:
class StructBuilder: public kj::DisallowConstCopy {
public:
inline StructBuilder(): segment(nullptr), data(nullptr), pointers(nullptr), bit0Offset(0) {}
inline StructBuilder(): segment(nullptr), data(nullptr), pointers(nullptr) {}
inline word* getLocation() { return reinterpret_cast<word*>(data); }
// Get the object's location. Only valid for independently-allocated objects (i.e. not list
......@@ -511,15 +511,10 @@ private:
WirePointerCount16 pointerCount; // Size of the pointer section.
BitCount8 bit0Offset;
// A special hack: If dataSize == 1 bit, then bit0Offset is the offset of that bit within the
// byte pointed to by `data`. In all other cases, this is zero. This is needed to implement
// struct lists where each struct is one bit.
inline StructBuilder(SegmentBuilder* segment, void* data, WirePointer* pointers,
BitCount dataSize, WirePointerCount pointerCount, BitCount8 bit0Offset)
BitCount dataSize, WirePointerCount pointerCount)
: segment(segment), data(data), pointers(pointers),
dataSize(dataSize), pointerCount(pointerCount), bit0Offset(bit0Offset) {}
dataSize(dataSize), pointerCount(pointerCount) {}
friend class ListBuilder;
friend struct WireHelpers;
......@@ -530,7 +525,7 @@ class StructReader {
public:
inline StructReader()
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
pointerCount(0), nestingLimit(0x7fffffff) {}
const void* getLocation() const { return data; }
......@@ -577,26 +572,15 @@ private:
WirePointerCount16 pointerCount; // Size of the pointer section.
BitCount8 bit0Offset;
// A special hack: If dataSize == 1 bit, then bit0Offset is the offset of that bit within the
// byte pointed to by `data`. In all other cases, this is zero. This is needed to implement
// struct lists where each struct is one bit.
//
// TODO(someday): Consider packing this together with dataSize, since we have 10 extra bits
// there doing nothing -- or arguably 12 bits, if you consider that 2-bit and 4-bit sizes
// aren't allowed. Consider that we could have a method like getDataSizeIn<T>() which is
// specialized to perform the correct shifts for each size.
int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
// Once this reaches zero, further pointers will be pruned.
// TODO(perf): Limit to 8 bits for better alignment?
// TODO(perf): Limit to 16 bits for better packing?
inline StructReader(SegmentReader* segment, const void* data, const WirePointer* pointers,
BitCount dataSize, WirePointerCount pointerCount, BitCount8 bit0Offset,
int nestingLimit)
BitCount dataSize, WirePointerCount pointerCount, int nestingLimit)
: segment(segment), data(data), pointers(pointers),
dataSize(dataSize), pointerCount(pointerCount), bit0Offset(bit0Offset),
dataSize(dataSize), pointerCount(pointerCount),
nestingLimit(nestingLimit) {}
friend class ListReader;
......@@ -879,9 +863,7 @@ inline T StructBuilder::getDataField(ElementCount offset) {
template <>
inline bool StructBuilder::getDataField<bool>(ElementCount offset) {
// This branch should be compiled out whenever this is inlined with a constant offset.
BitCount boffset = (offset == 0 * ELEMENTS) ?
BitCount(bit0Offset) : offset * (1 * BITS / ELEMENTS);
BitCount boffset = offset * (1 * BITS / ELEMENTS);
byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0;
}
......@@ -915,9 +897,7 @@ inline void StructBuilder::setDataField<double>(ElementCount offset, double valu
template <>
inline void StructBuilder::setDataField<bool>(ElementCount offset, bool value) {
// This branch should be compiled out whenever this is inlined with a constant offset.
BitCount boffset = (offset == 0 * ELEMENTS) ?
BitCount(bit0Offset) : offset * (1 * BITS / ELEMENTS);
BitCount boffset = offset * (1 * BITS / ELEMENTS);
byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE;
uint bitnum = boffset % BITS_PER_BYTE / BITS;
*reinterpret_cast<uint8_t*>(b) = (*reinterpret_cast<uint8_t*>(b) & ~(1 << bitnum))
......@@ -967,10 +947,6 @@ template <>
inline bool StructReader::getDataField<bool>(ElementCount offset) const {
BitCount boffset = offset * (1 * BITS / ELEMENTS);
if (boffset < dataSize) {
// This branch should be compiled out whenever this is inlined with a constant offset.
if (offset == 0 * ELEMENTS) {
boffset = bit0Offset;
}
const byte* b = reinterpret_cast<const byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0;
} else {
......@@ -1016,8 +992,8 @@ inline T ListBuilder::getDataElement(ElementCount index) {
template <>
inline bool ListBuilder::getDataElement<bool>(ElementCount index) {
// Ignore stepBytes for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * step;
// Ignore step for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * (1 * BITS / ELEMENTS);
byte* b = ptr + bindex / BITS_PER_BYTE;
return (*reinterpret_cast<uint8_t*>(b) & (1 << (bindex % BITS_PER_BYTE / BITS))) != 0;
}
......@@ -1073,8 +1049,8 @@ inline T ListReader::getDataElement(ElementCount index) const {
template <>
inline bool ListReader::getDataElement<bool>(ElementCount index) const {
// Ignore stepBytes for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * step;
// Ignore step for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * (1 * BITS / ELEMENTS);
const byte* b = ptr + bindex / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (bindex % BITS_PER_BYTE / BITS))) != 0;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment