Commit 80149744 authored by Kenton Varda's avatar Kenton Varda

SECURITY: Additional CPU amplification case.

Unfortunately, commit 10487060 missed a case of CPU amplification via struct lists with zero-sized elements.

See advisory: https://github.com/sandstorm-io/capnproto/blob/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md
parent 63187bd4
...@@ -1427,17 +1427,24 @@ TEST(Encoding, VoidListAmplification) { ...@@ -1427,17 +1427,24 @@ TEST(Encoding, VoidListAmplification) {
} }
TEST(Encoding, EmptyStructListAmplification) { TEST(Encoding, EmptyStructListAmplification) {
MallocMessageBuilder builder; MallocMessageBuilder builder(1024);
builder.initRoot<test::TestAnyPointer>().getAnyPointerField() auto listList = builder.initRoot<test::TestAnyPointer>().getAnyPointerField()
.initAs<List<test::TestEmptyStruct>>(1u << 28); .initAs<List<List<test::TestEmptyStruct>>>(500);
for (uint i = 0; i < listList.size(); i++) {
listList.init(i, 1u << 28);
}
auto segments = builder.getSegmentsForOutput(); auto segments = builder.getSegmentsForOutput();
EXPECT_EQ(1, segments.size()); ASSERT_EQ(1, segments.size());
EXPECT_LT(segments[0].size(), 16); // quite small for such a big list!
SegmentArrayMessageReader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
auto root = reader.getRoot<test::TestAnyPointer>().getAnyPointerField(); auto root = reader.getRoot<test::TestAnyPointer>();
EXPECT_NONFATAL_FAILURE(root.getAs<List<TestAllTypes>>()); auto listListReader = root.getAnyPointerField().getAs<List<List<TestAllTypes>>>();
EXPECT_NONFATAL_FAILURE(listListReader[0]);
EXPECT_NONFATAL_FAILURE(listListReader[10]);
EXPECT_EQ(segments[0].size() - 1, root.totalSize().wordCount);
} }
TEST(Encoding, Constants) { TEST(Encoding, Constants) {
......
...@@ -555,14 +555,16 @@ struct WireHelpers { ...@@ -555,14 +555,16 @@ struct WireHelpers {
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
word* pos = ptr + POINTER_SIZE_IN_WORDS;
uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS; uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS;
for (uint i = 0; i < count; i++) { if (pointerCount > 0 * POINTERS) {
pos += dataSize; word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count; i++) {
for (uint j = 0; j < pointerCount / POINTERS; j++) { pos += dataSize;
zeroObject(segment, reinterpret_cast<WirePointer*>(pos));
pos += POINTER_SIZE_IN_WORDS; for (uint j = 0; j < pointerCount / POINTERS; j++) {
zeroObject(segment, reinterpret_cast<WirePointer*>(pos));
pos += POINTER_SIZE_IN_WORDS;
}
} }
} }
...@@ -680,8 +682,6 @@ struct WireHelpers { ...@@ -680,8 +682,6 @@ struct WireHelpers {
return result; return result;
} }
result.wordCount += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
...@@ -690,23 +690,29 @@ struct WireHelpers { ...@@ -690,23 +690,29 @@ struct WireHelpers {
return result; return result;
} }
KJ_REQUIRE(elementTag->structRef.wordSize() / ELEMENTS * auto actualSize = elementTag->structRef.wordSize() / ELEMENTS * ElementCount64(count);
ElementCount64(count) <= wordCount, KJ_REQUIRE(actualSize <= wordCount,
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
return result; return result;
} }
// We count the actual size rather than the claimed word count because that's what
// we'll end up with if we make a copy.
result.wordCount += actualSize + POINTER_SIZE_IN_WORDS;
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
const word* pos = ptr + POINTER_SIZE_IN_WORDS; if (pointerCount > 0 * POINTERS) {
for (uint i = 0; i < count / ELEMENTS; i++) { const word* pos = ptr + POINTER_SIZE_IN_WORDS;
pos += dataSize; for (uint i = 0; i < count / ELEMENTS; i++) {
pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) { for (uint j = 0; j < pointerCount / POINTERS; j++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos), result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos),
nestingLimit); nestingLimit);
pos += POINTER_SIZE_IN_WORDS; pos += POINTER_SIZE_IN_WORDS;
}
} }
} }
break; break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment