Commit 2a37f18c authored by Kenton Varda's avatar Kenton Varda

When init()ing or set()ing a pointer that is already initialized, zero out the old value.

parent e5dc9924
......@@ -1286,6 +1286,24 @@ TEST(Encoding, ListSetters) {
}
}
TEST(Encoding, ZeroOldObject) {
MallocMessageBuilder builder;
auto root = builder.initRoot<TestAllTypes>();
initTestMessage(root);
auto oldRoot = root.asReader();
checkTestMessage(oldRoot);
auto oldSub = oldRoot.getStructField();
auto oldSub2 = oldRoot.getStructList()[0];
root = builder.initRoot<TestAllTypes>();
checkTestMessageAllZero(oldRoot);
checkTestMessageAllZero(oldSub);
checkTestMessageAllZero(oldSub2);
}
} // namespace
} // namespace internal
} // namespace capnproto
......@@ -212,6 +212,8 @@ struct WireHelpers {
static CAPNPROTO_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, WordCount amount,
WirePointer::Kind kind)) {
if (!ref->isNull()) zeroObject(segment, ref);
word* ptr = segment->allocate(amount);
if (ptr == nullptr) {
......@@ -302,6 +304,122 @@ struct WireHelpers {
// -----------------------------------------------------------------
static void zeroObject(SegmentBuilder* segment, WirePointer* ref) {
// Zero out the pointed-to object. Use when the pointer is about to be overwritten making the
// target object no longer reachable.
switch (ref->kind()) {
case WirePointer::STRUCT:
zeroObject(segment, ref, ref->target());
break;
case WirePointer::LIST:
zeroObject(segment, ref, ref->target());
break;
case WirePointer::FAR: {
segment = segment->getArena()->getSegment(ref->farRef.segmentId.get());
WirePointer* pad =
reinterpret_cast<WirePointer*>(segment->getPtrUnchecked(ref->farPositionInSegment()));
if (ref->isDoubleFar()) {
segment = segment->getArena()->getSegment(pad->farRef.segmentId.get());
zeroObject(segment, pad + 1, segment->getPtrUnchecked(pad->farPositionInSegment()));
memset(pad, 0, sizeof(WirePointer) * 2);
} else {
zeroObject(segment, pad);
memset(pad, 0, sizeof(WirePointer));
}
break;
}
case WirePointer::RESERVED_3:
FAIL_CHECK("Don't know how to handle RESERVED_3.");
break;
}
}
static void zeroObject(SegmentBuilder* segment, WirePointer* tag, word* ptr) {
switch (tag->kind()) {
case WirePointer::STRUCT: {
WirePointer* pointerSection =
reinterpret_cast<WirePointer*>(ptr + tag->structRef.dataSize.get());
uint count = tag->structRef.ptrCount.get() / POINTERS;
for (uint i = 0; i < count; i++) {
zeroObject(segment, pointerSection + i);
}
memset(ptr, 0, tag->structRef.wordSize() * BYTES_PER_WORD / BYTES);
break;
}
case WirePointer::LIST: {
switch (tag->listRef.elementSize()) {
case FieldSize::VOID:
// Nothing.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES:
memset(ptr, 0,
roundUpToWords(ElementCount64(tag->listRef.elementCount()) *
dataBitsPerElement(tag->listRef.elementSize()))
* BYTES_PER_WORD / BYTES);
break;
case FieldSize::POINTER: {
uint count = tag->listRef.elementCount() / ELEMENTS;
for (uint i = 0; i < count; i++) {
zeroObject(segment, reinterpret_cast<WirePointer*>(ptr) + i);
}
break;
}
case FieldSize::INLINE_COMPOSITE: {
WirePointer* elementTag = reinterpret_cast<WirePointer*>(ptr);
CHECK(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.");
WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
word* pos = ptr + POINTER_SIZE_IN_WORDS;
uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS;
for (uint i = 0; i < count; i++) {
pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) {
zeroObject(segment, reinterpret_cast<WirePointer*>(pos));
pos += POINTER_SIZE_IN_WORDS;
}
}
memset(ptr, 0, (elementTag->structRef.wordSize() + POINTER_SIZE_IN_WORDS)
* BYTES_PER_WORD / BYTES);
break;
}
}
break;
}
case WirePointer::FAR:
FAIL_CHECK("Unexpected FAR pointer.");
break;
case WirePointer::RESERVED_3:
FAIL_CHECK("Don't know how to handle RESERVED_3.");
break;
}
}
static CAPNPROTO_ALWAYS_INLINE(
void zeroPointerAndFars(SegmentBuilder* segment, WirePointer* ref)) {
// Zero out the pointer itself and, if it is a far pointer, zero the landing pad as well, but
// do not zero the object body. Used when upgrading.
if (ref->kind() == WirePointer::FAR) {
word* pad = segment->getArena()->getSegment(ref->farRef.segmentId.get())
->getPtrUnchecked(ref->farPositionInSegment());
memset(pad, 0, sizeof(WirePointer) * (1 + ref->isDoubleFar()));
}
memset(ref, 0, sizeof(*ref));
}
// -----------------------------------------------------------------
static CAPNPROTO_ALWAYS_INLINE(
void copyStruct(SegmentBuilder* segment, word* dst, const word* src,
WordCount dataSize, WirePointerCount pointerCount)) {
......@@ -513,6 +631,9 @@ struct WireHelpers {
std::max<WirePointerCount>(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(segment, ref);
word* ptr = allocate(ref, segment, totalSize, WirePointer::STRUCT);
ref->structRef.set(newDataSize, newPointerCount);
......@@ -770,6 +891,9 @@ struct WireHelpers {
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalSize = newStep * elementCount;
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalSize + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalSize);
......@@ -872,6 +996,9 @@ struct WireHelpers {
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalWords = elementCount * newStep;
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords + POINTER_SIZE_IN_WORDS,
WirePointer::LIST);
origRef->listRef.setInlineComposite(totalWords);
......@@ -932,6 +1059,9 @@ struct WireHelpers {
WordCount totalWords =
roundUpToWords(BitCount64(newDataSize) * (elementCount / ELEMENTS));
// Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef);
word* newPtr = allocate(origRef, origSegment, totalWords, WirePointer::LIST);
origRef->listRef.set(elementSize.preferredListEncoding, elementCount);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment