Commit e5dc9924 authored by Kenton Varda's avatar Kenton Varda

When upgrading objects in builder, make sure to zero out the old location.

parent 35adf54b
......@@ -612,6 +612,8 @@ TEST(Encoding, UpgradeStructInBuilder) {
MallocMessageBuilder builder;
auto root = builder.initRoot<test::TestObject>();
test::TestOldVersion::Reader oldReader;
{
auto oldVersion = root.initObjectField<test::TestOldVersion>();
oldVersion.setOld1(123);
......@@ -619,6 +621,8 @@ TEST(Encoding, UpgradeStructInBuilder) {
auto sub = oldVersion.initOld3();
sub.setOld1(456);
sub.setOld2("bar");
oldReader = oldVersion;
}
size_t size = builder.getSegmentsForOutput()[0].size();
......@@ -627,6 +631,12 @@ TEST(Encoding, UpgradeStructInBuilder) {
{
auto newVersion = root.getObjectField<test::TestNewVersion>();
// The old instance should have been zero'd.
EXPECT_EQ(0, oldReader.getOld1());
EXPECT_EQ("", oldReader.getOld2());
EXPECT_EQ(0, oldReader.getOld3().getOld1());
EXPECT_EQ("", oldReader.getOld3().getOld2());
// Size should have increased due to re-allocating the struct.
size_t size1 = builder.getSegmentsForOutput()[0].size();
EXPECT_GT(size1, size);
......@@ -857,19 +867,13 @@ TEST(Encoding, UpgradeStructInBuilderDoubleFarPointers) {
EXPECT_EQ(2u, builder.getSegmentsForOutput()[2].size());
}
void checkList(List<test::TestNewVersion>::Builder builder,
void checkList(List<test::TestOldVersion>::Reader reader,
std::initializer_list<int64_t> expectedData,
std::initializer_list<Text::Reader> expectedPointers) {
ASSERT_EQ(expectedData.size(), builder.size());
ASSERT_EQ(expectedData.size(), reader.size());
for (uint i = 0; i < expectedData.size(); i++) {
EXPECT_EQ(expectedData.begin()[i], builder[i].getOld1());
EXPECT_EQ(expectedPointers.begin()[i], builder[i].getOld2());
// Other fields shouldn't be set.
EXPECT_EQ(0, builder[i].asReader().getOld3().getOld1());
EXPECT_EQ("", builder[i].asReader().getOld3().getOld2());
EXPECT_EQ(987, builder[i].getNew1());
EXPECT_EQ("baz", builder[i].getNew2());
EXPECT_EQ(expectedData.begin()[i], reader[i].getOld1());
EXPECT_EQ(expectedPointers.begin()[i], reader[i].getOld2());
}
}
......@@ -939,11 +943,13 @@ TEST(Encoding, UpgradeListInBuilder) {
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(root.getObjectField<List<test::TestNewVersion>>(), {0, 0, 0, 0}, {"", "", "", ""});
checkUpgradedList(root, {0, 0, 0, 0}, {"", "", "", ""});
// -----------------------------------------------------------------
{
root.setObjectField<List<bool>>({true, false, true, true});
auto orig = root.asReader().getObjectField<List<bool>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {true, false, true, true});
EXPECT_ANY_THROW(root.getObjectField<List<uint8_t>>());
......@@ -951,11 +957,17 @@ TEST(Encoding, UpgradeListInBuilder) {
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(orig, {true, false, true, true});
checkUpgradedList(root, {1, 0, 1, 1}, {"", "", "", ""});
checkList(orig, {false, false, false, false}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
root.setObjectField<List<uint8_t>>({0x12, 0x23, 0x33, 0x44});
auto orig = root.asReader().getObjectField<List<uint8_t>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {false, true, true, false});
checkList(root.getObjectField<List<uint8_t>>(), {0x12, 0x23, 0x33, 0x44});
......@@ -963,11 +975,17 @@ TEST(Encoding, UpgradeListInBuilder) {
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(orig, {0x12, 0x23, 0x33, 0x44});
checkUpgradedList(root, {0x12, 0x23, 0x33, 0x44}, {"", "", "", ""});
checkList(orig, {0, 0, 0, 0}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
root.setObjectField<List<uint16_t>>({0x5612, 0x7823, 0xab33, 0xcd44});
auto orig = root.asReader().getObjectField<List<uint16_t>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {false, true, true, false});
checkList(root.getObjectField<List<uint8_t>>(), {0x12, 0x23, 0x33, 0x44});
......@@ -975,11 +993,17 @@ TEST(Encoding, UpgradeListInBuilder) {
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(orig, {0x5612, 0x7823, 0xab33, 0xcd44});
checkUpgradedList(root, {0x5612, 0x7823, 0xab33, 0xcd44}, {"", "", "", ""});
checkList(orig, {0, 0, 0, 0}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
root.setObjectField<List<uint32_t>>({0x17595612, 0x29347823, 0x5923ab32, 0x1a39cd45});
auto orig = root.asReader().getObjectField<List<uint32_t>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {false, true, false, true});
checkList(root.getObjectField<List<uint8_t>>(), {0x12, 0x23, 0x32, 0x45});
......@@ -987,11 +1011,17 @@ TEST(Encoding, UpgradeListInBuilder) {
checkList(root.getObjectField<List<uint32_t>>(), {0x17595612u, 0x29347823u, 0x5923ab32u, 0x1a39cd45u});
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(orig, {0x17595612u, 0x29347823u, 0x5923ab32u, 0x1a39cd45u});
checkUpgradedList(root, {0x17595612, 0x29347823, 0x5923ab32, 0x1a39cd45}, {"", "", "", ""});
checkList(orig, {0u, 0u, 0u, 0u}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
root.setObjectField<List<uint64_t>>({0x1234abcd8735fe21, 0x7173bc0e1923af36});
auto orig = root.asReader().getObjectField<List<uint64_t>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {true, false});
checkList(root.getObjectField<List<uint8_t>>(), {0x21, 0x36});
......@@ -999,11 +1029,17 @@ TEST(Encoding, UpgradeListInBuilder) {
checkList(root.getObjectField<List<uint32_t>>(), {0x8735fe21u, 0x1923af36u});
checkList(root.getObjectField<List<uint64_t>>(), {0x1234abcd8735fe21ull, 0x7173bc0e1923af36ull});
EXPECT_ANY_THROW(root.getObjectField<List<Text>>());
checkList(orig, {0x1234abcd8735fe21ull, 0x7173bc0e1923af36ull});
checkUpgradedList(root, {0x1234abcd8735fe21ull, 0x7173bc0e1923af36ull}, {"", ""});
checkList(orig, {0u, 0u}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
root.setObjectField<List<Text>>({"foo", "bar", "baz"});
auto orig = root.asReader().getObjectField<List<Text>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID});
EXPECT_ANY_THROW(root.getObjectField<List<bool>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint8_t>>());
......@@ -1011,22 +1047,15 @@ TEST(Encoding, UpgradeListInBuilder) {
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
checkList(root.getObjectField<List<Text>>(), {"foo", "bar", "baz"});
checkUpgradedList(root, {0, 0, 0}, {"foo", "bar", "baz"});
// -----------------------------------------------------------------
root.setObjectField<List<Text>>({"foo", "bar", "baz"});
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID});
EXPECT_ANY_THROW(root.getObjectField<List<bool>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint8_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint16_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint32_t>>());
EXPECT_ANY_THROW(root.getObjectField<List<uint64_t>>());
checkList(root.getObjectField<List<Text>>(), {"foo", "bar", "baz"});
checkList(orig, {"foo", "bar", "baz"});
checkUpgradedList(root, {0, 0, 0}, {"foo", "bar", "baz"});
checkList(orig, {"", "", ""}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
{
{
auto l = root.initObjectField<List<test::TestOldVersion>>(3);
l[0].setOld1(0x1234567890abcdef);
......@@ -1036,6 +1065,7 @@ TEST(Encoding, UpgradeListInBuilder) {
l[1].setOld2("bar");
l[2].setOld2("baz");
}
auto orig = root.asReader().getObjectField<List<test::TestOldVersion>>();
checkList(root.getObjectField<List<Void>>(), {Void::VOID, Void::VOID, Void::VOID});
checkList(root.getObjectField<List<bool>>(), {true, true, false});
......@@ -1045,8 +1075,13 @@ TEST(Encoding, UpgradeListInBuilder) {
checkList(root.getObjectField<List<uint64_t>>(),
{0x1234567890abcdefull, 0x234567890abcdef1ull, 0x34567890abcdef12ull});
checkList(root.getObjectField<List<Text>>(), {"foo", "bar", "baz"});
checkList(orig, {0x1234567890abcdefull, 0x234567890abcdef1ull, 0x34567890abcdef12ull},
{"foo", "bar", "baz"});
checkUpgradedList(root, {0x1234567890abcdefull, 0x234567890abcdef1ull, 0x34567890abcdef12ull},
{"foo", "bar", "baz"});
checkList(orig, {0u, 0u, 0u}, {"", "", ""}); // old location zero'd during upgrade
}
// -----------------------------------------------------------------
// OK, now we've tested upgrading every primitive list to every primitive list, every primitive
......@@ -1056,7 +1091,10 @@ TEST(Encoding, UpgradeListInBuilder) {
// Upgrade from bool.
root.setObjectField<List<bool>>({true, false, true, true});
{
auto orig = root.asReader().getObjectField<List<bool>>();
checkList(orig, {true, false, true, true});
auto l = root.getObjectField<List<test::TestLists::Struct16>>();
checkList(orig, {false, false, false, false}); // old location zero'd during upgrade
ASSERT_EQ(4u, l.size());
EXPECT_EQ(1u, l[0].getF());
EXPECT_EQ(0u, l[1].getF());
......@@ -1076,7 +1114,10 @@ TEST(Encoding, UpgradeListInBuilder) {
// Upgrade from multi-byte, sub-word data.
root.setObjectField<List<uint16_t>>({12u, 34u, 56u, 78u});
{
auto orig = root.asReader().getObjectField<List<uint16_t>>();
checkList(orig, {12u, 34u, 56u, 78u});
auto l = root.getObjectField<List<test::TestLists::Struct32>>();
checkList(orig, {0u, 0u, 0u, 0u}); // old location zero'd during upgrade
ASSERT_EQ(4u, l.size());
EXPECT_EQ(12u, l[0].getF());
EXPECT_EQ(34u, l[1].getF());
......
......@@ -476,10 +476,9 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(StructBuilder getWritableStructPointer(
WirePointer* ref, SegmentBuilder* segment, StructSize size, const word* defaultValue)) {
word* ptr;
if (ref->isNull()) {
useDefault:
word* ptr;
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
ptr = allocate(ref, segment, size.total(), WirePointer::STRUCT);
......@@ -514,7 +513,7 @@ struct WireHelpers {
std::max<WirePointerCount>(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
ref->structRef.set(newDataSize, newPointerCount);
// Copy data section.
......@@ -526,6 +525,14 @@ struct WireHelpers {
transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
}
// Zero out old location. This has two purposes:
// 1) We don't want to leak the original contents of the struct when the message is written
// out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire.
memset(oldPtr, 0,
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount, 0 * BITS);
} else {
......@@ -597,7 +604,6 @@ struct WireHelpers {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
memset(origRef, 0, sizeof(*origRef));
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
......@@ -697,7 +703,6 @@ struct WireHelpers {
useDefault:
if (defaultValue == nullptr ||
reinterpret_cast<const WirePointer*>(defaultValue)->isNull()) {
memset(origRef, 0, sizeof(*origRef));
return ListBuilder();
}
word* ptr = copyMessage(origSegment, origRef,
......@@ -791,6 +796,9 @@ struct WireHelpers {
src += oldStep * (1 * ELEMENTS);
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, oldStep * elementCount * BYTES_PER_WORD / BYTES);
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
} else if (oldSize == elementSize.preferredListEncoding) {
......@@ -899,6 +907,9 @@ struct WireHelpers {
}
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount);
......@@ -941,6 +952,9 @@ struct WireHelpers {
}
}
// Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundUpToBytes(oldStep * elementCount) / BYTES);
return ListBuilder(origSegment, newPtr, newDataSize / ELEMENTS, elementCount,
newDataSize, 0 * POINTERS);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment