Commit 80149744 authored by Kenton Varda's avatar Kenton Varda

SECURITY: Additional CPU amplification case.

Unfortunately, commit 10487060 missed a case of CPU amplification via struct lists with zero-sized elements.

See advisory: https://github.com/sandstorm-io/capnproto/blob/master/security-advisories/2015-03-05-0-c++-addl-cpu-amplification.md
parent 63187bd4
...@@ -1427,17 +1427,24 @@ TEST(Encoding, VoidListAmplification) { ...@@ -1427,17 +1427,24 @@ TEST(Encoding, VoidListAmplification) {
} }
TEST(Encoding, EmptyStructListAmplification) { TEST(Encoding, EmptyStructListAmplification) {
MallocMessageBuilder builder; MallocMessageBuilder builder(1024);
builder.initRoot<test::TestAnyPointer>().getAnyPointerField() auto listList = builder.initRoot<test::TestAnyPointer>().getAnyPointerField()
.initAs<List<test::TestEmptyStruct>>(1u << 28); .initAs<List<List<test::TestEmptyStruct>>>(500);
for (uint i = 0; i < listList.size(); i++) {
listList.init(i, 1u << 28);
}
auto segments = builder.getSegmentsForOutput(); auto segments = builder.getSegmentsForOutput();
EXPECT_EQ(1, segments.size()); ASSERT_EQ(1, segments.size());
EXPECT_LT(segments[0].size(), 16); // quite small for such a big list!
SegmentArrayMessageReader reader(builder.getSegmentsForOutput()); SegmentArrayMessageReader reader(builder.getSegmentsForOutput());
auto root = reader.getRoot<test::TestAnyPointer>().getAnyPointerField(); auto root = reader.getRoot<test::TestAnyPointer>();
EXPECT_NONFATAL_FAILURE(root.getAs<List<TestAllTypes>>()); auto listListReader = root.getAnyPointerField().getAs<List<List<TestAllTypes>>>();
EXPECT_NONFATAL_FAILURE(listListReader[0]);
EXPECT_NONFATAL_FAILURE(listListReader[10]);
EXPECT_EQ(segments[0].size() - 1, root.totalSize().wordCount);
} }
TEST(Encoding, Constants) { TEST(Encoding, Constants) {
......
...@@ -555,8 +555,9 @@ struct WireHelpers { ...@@ -555,8 +555,9 @@ struct WireHelpers {
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
word* pos = ptr + POINTER_SIZE_IN_WORDS;
uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS; uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS;
if (pointerCount > 0 * POINTERS) {
word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count; i++) { for (uint i = 0; i < count; i++) {
pos += dataSize; pos += dataSize;
...@@ -565,6 +566,7 @@ struct WireHelpers { ...@@ -565,6 +566,7 @@ struct WireHelpers {
pos += POINTER_SIZE_IN_WORDS; pos += POINTER_SIZE_IN_WORDS;
} }
} }
}
memset(ptr, 0, (elementTag->structRef.wordSize() * count + POINTER_SIZE_IN_WORDS) memset(ptr, 0, (elementTag->structRef.wordSize() * count + POINTER_SIZE_IN_WORDS)
* BYTES_PER_WORD / BYTES); * BYTES_PER_WORD / BYTES);
...@@ -680,8 +682,6 @@ struct WireHelpers { ...@@ -680,8 +682,6 @@ struct WireHelpers {
return result; return result;
} }
result.wordCount += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
...@@ -690,15 +690,20 @@ struct WireHelpers { ...@@ -690,15 +690,20 @@ struct WireHelpers {
return result; return result;
} }
KJ_REQUIRE(elementTag->structRef.wordSize() / ELEMENTS * auto actualSize = elementTag->structRef.wordSize() / ELEMENTS * ElementCount64(count);
ElementCount64(count) <= wordCount, KJ_REQUIRE(actualSize <= wordCount,
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
return result; return result;
} }
// We count the actual size rather than the claimed word count because that's what
// we'll end up with if we make a copy.
result.wordCount += actualSize + POINTER_SIZE_IN_WORDS;
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
if (pointerCount > 0 * POINTERS) {
const word* pos = ptr + POINTER_SIZE_IN_WORDS; const word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count / ELEMENTS; i++) { for (uint i = 0; i < count / ELEMENTS; i++) {
pos += dataSize; pos += dataSize;
...@@ -709,6 +714,7 @@ struct WireHelpers { ...@@ -709,6 +714,7 @@ struct WireHelpers {
pos += POINTER_SIZE_IN_WORDS; pos += POINTER_SIZE_IN_WORDS;
} }
} }
}
break; break;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment