Commit 5db2c8f8 authored by Matthew Maurer's avatar Matthew Maurer

Add Canonicalization

The user facing API is in MessageReader and MessageBuilder

{MessageBuilder,MessageReader}::isCanonical verifies the canonicity of a
message. This is both useful for debugging and for knowing if a received
message can be used for hashes, bytewise equality, etc.

MessageBuilder::canonicalRoot(Reader) can be used to write a canonical
message on a best effort basis, and checks itself using isCanonical.
It should succeed as long as the MessageBuilder in question:
* Has a first segment which is long enough to contain the message
* Has not been used before

Tests have been added in canonicalize-test.c++ which verify that for
crafted examples of canonicalization errors, isCanonical will reject,
and for a canonicalized version of the standard test message, it will
accept.
parent 2c59e520
...@@ -200,6 +200,7 @@ if(BUILD_TESTING) ...@@ -200,6 +200,7 @@ if(BUILD_TESTING)
orphan-test.c++ orphan-test.c++
serialize-test.c++ serialize-test.c++
serialize-packed-test.c++ serialize-packed-test.c++
canonicalize-test.c++
fuzz-test.c++ fuzz-test.c++
test-util.c++ test-util.c++
${test_capnp_cpp_files} ${test_capnp_cpp_files}
......
...@@ -221,9 +221,14 @@ struct AnyPointer { ...@@ -221,9 +221,14 @@ struct AnyPointer {
inline void setAs(std::initializer_list<ReaderFor<ListElementType<T>>> list); inline void setAs(std::initializer_list<ReaderFor<ListElementType<T>>> list);
// Valid for T = List<?>. // Valid for T = List<?>.
template <typename T>
inline void setCanonicalAs(ReaderFor<T> value);
inline void set(Reader value) { builder.copyFrom(value.reader); } inline void set(Reader value) { builder.copyFrom(value.reader); }
// Set to a copy of another AnyPointer. // Set to a copy of another AnyPointer.
inline void setCanonical(Reader value) { builder.copyFrom(value.reader, true); }
template <typename T> template <typename T>
inline void adopt(Orphan<T>&& orphan); inline void adopt(Orphan<T>&& orphan);
// Valid for T = any generated struct type, List<U>, Text, Data, DynamicList, DynamicStruct, // Valid for T = any generated struct type, List<U>, Text, Data, DynamicList, DynamicStruct,
...@@ -793,6 +798,11 @@ inline void AnyPointer::Builder::setAs(ReaderFor<T> value) { ...@@ -793,6 +798,11 @@ inline void AnyPointer::Builder::setAs(ReaderFor<T> value) {
return _::PointerHelpers<T>::set(builder, value); return _::PointerHelpers<T>::set(builder, value);
} }
template <typename T>
inline void AnyPointer::Builder::setCanonicalAs(ReaderFor<T> value) {
return _::PointerHelpers<T>::setCanonical(builder, value);
}
template <typename T> template <typename T>
inline void AnyPointer::Builder::setAs( inline void AnyPointer::Builder::setAs(
std::initializer_list<ReaderFor<ListElementType<T>>> list) { std::initializer_list<ReaderFor<ListElementType<T>>> list) {
......
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#include "message.h"
#include "any.h"
#include <kj/debug.h>
#include <kj/test.h>
#include "test-util.h"
namespace capnp {
namespace _ { // private
namespace {
KJ_TEST("canonicalize yields cannonical message") {
MallocMessageBuilder builder;
auto root = builder.initRoot<TestAllTypes>();
initTestMessage(root);
MallocMessageBuilder canonicalMessage;
canonicalMessage.canonicalRoot(builder.getRoot<AnyPointer>().asReader());
KJ_ASSERT(canonicalMessage.isCanonical());
}
KJ_TEST("isCanonical requires pointer preorder") {
AlignedData<5> misorderedSegment = {{
//Struct pointer, data immediately follows, two pointer fields, no data
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00,
//Pointer field 1, pointing to the last entry, data size 1, no pointer
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
//Pointer field 2, pointing to the next entry, data size 2, no pointer
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
//Data for field 2
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
//Data for field 1
0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00
}};
kj::ArrayPtr<const word> segments[1] = {kj::arrayPtr(misorderedSegment.words,
3)};
SegmentArrayMessageReader outOfOrder(kj::arrayPtr(segments, 1));
KJ_ASSERT(!outOfOrder.isCanonical());
}
KJ_TEST("isCanonical requires dense packing") {
AlignedData<3> gapSegment = {{
//Struct pointer, data after a gap
0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
//The gap
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
//Data for field 1
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
}};
kj::ArrayPtr<const word> segments[1] = {kj::arrayPtr(gapSegment.words,
3)};
SegmentArrayMessageReader gap(kj::arrayPtr(segments, 1));
KJ_ASSERT(!gap.isCanonical());
}
KJ_TEST("isCanonical rejects multi-segment messages") {
AlignedData<1> farPtr = {{
//Far pointer to next segment
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
}};
AlignedData<2> farTarget = {{
//Struct pointer (needed to make the far pointer legal)
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
//Dummy data
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
}};
kj::ArrayPtr<const word> segments[2] = {
kj::arrayPtr(farPtr.words, 1),
kj::arrayPtr(farTarget.words, 2)
};
SegmentArrayMessageReader multiSegmentMessage(kj::arrayPtr(segments, 2));
KJ_ASSERT(!multiSegmentMessage.isCanonical());
}
KJ_TEST("isCanonical rejects zero segment messages") {
SegmentArrayMessageReader zero(kj::arrayPtr((kj::ArrayPtr<const word>*)NULL,
0));
KJ_ASSERT(!zero.isCanonical());
}
KJ_TEST("isCanonical requires truncation of 0-valued struct fields") {
AlignedData<2> nonTruncatedSegment = {{
//Struct pointer, data immediately follows
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
//Default data value, should have been truncated
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}};
kj::ArrayPtr<const word> segments[1] = {
kj::arrayPtr(nonTruncatedSegment.words, 3)
};
SegmentArrayMessageReader nonTruncated(kj::arrayPtr(segments, 1));
KJ_ASSERT(!nonTruncated.isCanonical());
}
KJ_TEST("isCanonical requires truncation of 0-valued struct fields\
in all list members") {
AlignedData<6> nonTruncatedList = {{
//List pointer, composite,
0x01, 0x00, 0x00, 0x00, 0x27, 0x00, 0x00, 0x00,
//Struct tag word, 2 structs, 2 data words per struct
0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
//Data word non-null
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
//Null trailing word
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
//Data word non-null
0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00,
//Null trailing word
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
}};
kj::ArrayPtr<const word> segments[1] = {
kj::arrayPtr(nonTruncatedList.words, 6)
};
SegmentArrayMessageReader nonTruncated(kj::arrayPtr(segments, 1));
KJ_ASSERT(!nonTruncated.isCanonical());
}
} // namespace
} // namespace _ (private)
} // namespace capnp
...@@ -323,7 +323,8 @@ struct WireHelpers { ...@@ -323,7 +323,8 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(word* allocate( static KJ_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, CapTableBuilder* capTable, WordCount amount, WirePointer*& ref, SegmentBuilder*& segment, CapTableBuilder* capTable, WordCount amount,
WirePointer::Kind kind, BuilderArena* orphanArena)) { WirePointer::Kind kind, BuilderArena* orphanArena,
bool canonical = false)) {
// Allocate space in the message for a new object, creating far pointers if necessary. // Allocate space in the message for a new object, creating far pointers if necessary.
// //
// * `ref` starts out being a reference to the pointer which shall be assigned to point at the // * `ref` starts out being a reference to the pointer which shall be assigned to point at the
...@@ -344,7 +345,8 @@ struct WireHelpers { ...@@ -344,7 +345,8 @@ struct WireHelpers {
// segment belonging to the arena. `ref` will be initialized as a non-far pointer, but its // segment belonging to the arena. `ref` will be initialized as a non-far pointer, but its
// target offset will be set to zero. // target offset will be set to zero.
if (orphanArena == nullptr) { // Canonical messages can't be crafted from orphans, don't try
if ((orphanArena == nullptr) || canonical) {
if (!ref->isNull()) zeroObject(segment, capTable, ref); if (!ref->isNull()) zeroObject(segment, capTable, ref);
if (amount == 0 * WORDS && kind == WirePointer::STRUCT) { if (amount == 0 * WORDS && kind == WirePointer::STRUCT) {
...@@ -354,9 +356,14 @@ struct WireHelpers { ...@@ -354,9 +356,14 @@ struct WireHelpers {
return reinterpret_cast<word*>(ref); return reinterpret_cast<word*>(ref);
} }
// Hope this generates sequential memory, validate with isCanonical
word* ptr = segment->allocate(amount); word* ptr = segment->allocate(amount);
if (ptr == nullptr) { if (ptr == nullptr) {
if (canonical) {
KJ_FAIL_REQUIRE("segment0 must hold entire canonical message");
}
// Need to allocate in a new segment. We'll need to allocate an extra pointer worth of // Need to allocate in a new segment. We'll need to allocate an extra pointer worth of
// space to act as the landing pad for a far pointer. // space to act as the landing pad for a far pointer.
...@@ -1545,23 +1552,43 @@ struct WireHelpers { ...@@ -1545,23 +1552,43 @@ struct WireHelpers {
static SegmentAnd<word*> setStructPointer( static SegmentAnd<word*> setStructPointer(
SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, StructReader value, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, StructReader value,
BuilderArena* orphanArena = nullptr) { BuilderArena* orphanArena = nullptr, bool canonical = false) {
WordCount dataSize = roundBitsUpToWords(value.dataSize); WordCount dataSize = roundBitsUpToWords(value.dataSize);
WordCount totalSize = dataSize + value.pointerCount * WORDS_PER_POINTER; WirePointerCount ptrCount = value.pointerCount;
word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::STRUCT, orphanArena);
ref->structRef.set(dataSize, value.pointerCount); if (canonical) {
// Truncate the data section
while ((dataSize != 0) &&
(value.getDataField<uint64_t>(dataSize - 1) == 0)) {
dataSize--;
}
// Truncate pointer section
while ((ptrCount != 0) &&
value.getPointerField(ptrCount - 1).isNull()) {
ptrCount--;
}
}
WordCount totalSize = dataSize + ptrCount * WORDS_PER_POINTER;
word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::STRUCT, orphanArena, canonical);
ref->structRef.set(dataSize, ptrCount);
if (value.dataSize == 1 * BITS) { if (value.dataSize == 1 * BITS) {
*reinterpret_cast<char*>(ptr) = value.getDataField<bool>(0 * ELEMENTS); // Data size could be made 0 by truncation
if (dataSize != 0) {
*reinterpret_cast<char*>(ptr) = value.getDataField<bool>(0 * ELEMENTS);
}
} else { } else {
memcpy(ptr, value.data, value.dataSize / BITS_PER_BYTE / BYTES); memcpy(ptr, value.data, dataSize * BYTES_PER_WORD);
} }
WirePointer* pointerSection = reinterpret_cast<WirePointer*>(ptr + dataSize); WirePointer* pointerSection = reinterpret_cast<WirePointer*>(ptr + dataSize);
for (uint i = 0; i < value.pointerCount / POINTERS; i++) { for (uint i = 0; i < ptrCount; i++) {
copyPointer(segment, capTable, pointerSection + i, copyPointer(segment, capTable, pointerSection + i,
value.segment, value.capTable, value.pointers + i, value.nestingLimit); value.segment, value.capTable, value.pointers + i,
value.nestingLimit, nullptr, canonical);
} }
return { segment, ptr }; return { segment, ptr };
...@@ -1584,12 +1611,12 @@ struct WireHelpers { ...@@ -1584,12 +1611,12 @@ struct WireHelpers {
static SegmentAnd<word*> setListPointer( static SegmentAnd<word*> setListPointer(
SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, ListReader value, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, ListReader value,
BuilderArena* orphanArena = nullptr) { BuilderArena* orphanArena = nullptr, bool canonical = false) {
WordCount totalSize = roundBitsUpToWords(value.elementCount * value.step); WordCount totalSize = roundBitsUpToWords(value.elementCount * value.step);
if (value.elementSize != ElementSize::INLINE_COMPOSITE) { if (value.elementSize != ElementSize::INLINE_COMPOSITE) {
// List of non-structs. // List of non-structs.
word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::LIST, orphanArena); word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::LIST, orphanArena, canonical);
if (value.elementSize == ElementSize::POINTER) { if (value.elementSize == ElementSize::POINTER) {
// List of pointers. // List of pointers.
...@@ -1598,7 +1625,7 @@ struct WireHelpers { ...@@ -1598,7 +1625,7 @@ struct WireHelpers {
copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(ptr) + i, copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(ptr) + i,
value.segment, value.capTable, value.segment, value.capTable,
reinterpret_cast<const WirePointer*>(value.ptr) + i, reinterpret_cast<const WirePointer*>(value.ptr) + i,
value.nestingLimit); value.nestingLimit, nullptr, canonical);
} }
} else { } else {
// List of data. // List of data.
...@@ -1609,31 +1636,61 @@ struct WireHelpers { ...@@ -1609,31 +1636,61 @@ struct WireHelpers {
return { segment, ptr }; return { segment, ptr };
} else { } else {
// List of structs. // List of structs.
WordCount declDataSize = roundBitsUpToWords(value.structDataSize);
WirePointerCount declPointerCount = value.structPointerCount;
WordCount dataSize = 0 * WORDS;
WirePointerCount ptrCount = 0 * POINTERS;
if (canonical) {
for (auto ec = ElementCount(0); ec < value.elementCount; ec++) {
auto se = value.getStructElement(ec);
WordCount localDataSize = declDataSize;
while ((localDataSize != 0 * WORDS) &&
(se.getDataField<uint64_t>(localDataSize - 1) == 0)) {
localDataSize--;
}
if (localDataSize > dataSize) {
dataSize = localDataSize;
}
WirePointerCount localPtrCount = declPointerCount;
while ((localPtrCount != 0 * POINTERS) &&
se.getPointerField(localPtrCount - 1).isNull()) {
localPtrCount--;
}
if (localPtrCount > ptrCount) {
ptrCount = localPtrCount;
}
}
totalSize = (dataSize + ptrCount * WORDS_PER_POINTER) * value.elementCount;
} else {
dataSize = declDataSize;
ptrCount = declPointerCount;
}
word* ptr = allocate(ref, segment, capTable, totalSize + POINTER_SIZE_IN_WORDS, word* ptr = allocate(ref, segment, capTable, totalSize + POINTER_SIZE_IN_WORDS,
WirePointer::LIST, orphanArena); WirePointer::LIST, orphanArena, canonical);
ref->listRef.setInlineComposite(totalSize); ref->listRef.setInlineComposite(totalSize);
WordCount dataSize = roundBitsUpToWords(value.structDataSize);
WirePointerCount pointerCount = value.structPointerCount;
WirePointer* tag = reinterpret_cast<WirePointer*>(ptr); WirePointer* tag = reinterpret_cast<WirePointer*>(ptr);
tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, value.elementCount); tag->setKindAndInlineCompositeListElementCount(WirePointer::STRUCT, value.elementCount);
tag->structRef.set(dataSize, pointerCount); tag->structRef.set(dataSize, ptrCount);
word* dst = ptr + POINTER_SIZE_IN_WORDS; word* dst = ptr + POINTER_SIZE_IN_WORDS;
const word* src = reinterpret_cast<const word*>(value.ptr); const word* src = reinterpret_cast<const word*>(value.ptr);
for (uint i = 0; i < value.elementCount / ELEMENTS; i++) { for (uint i = 0; i < value.elementCount / ELEMENTS; i++) {
memcpy(dst, src, value.structDataSize / BITS_PER_BYTE / BYTES); memcpy(dst, src, dataSize * BYTES_PER_WORD);
dst += dataSize; dst += dataSize;
src += dataSize; src += declDataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) { for (uint j = 0; j < ptrCount / POINTERS; j++) {
copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(dst), copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(dst),
value.segment, value.capTable, reinterpret_cast<const WirePointer*>(src), value.segment, value.capTable, reinterpret_cast<const WirePointer*>(src),
value.nestingLimit); value.nestingLimit, nullptr, canonical);
dst += POINTER_SIZE_IN_WORDS; dst += POINTER_SIZE_IN_WORDS;
src += POINTER_SIZE_IN_WORDS; src += POINTER_SIZE_IN_WORDS;
} }
src += (declPointerCount - ptrCount) * POINTER_SIZE_IN_WORDS;
} }
return { segment, ptr }; return { segment, ptr };
...@@ -1643,16 +1700,18 @@ struct WireHelpers { ...@@ -1643,16 +1700,18 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(SegmentAnd<word*> copyPointer( static KJ_ALWAYS_INLINE(SegmentAnd<word*> copyPointer(
SegmentBuilder* dstSegment, CapTableBuilder* dstCapTable, WirePointer* dst, SegmentBuilder* dstSegment, CapTableBuilder* dstCapTable, WirePointer* dst,
SegmentReader* srcSegment, CapTableReader* srcCapTable, const WirePointer* src, SegmentReader* srcSegment, CapTableReader* srcCapTable, const WirePointer* src,
int nestingLimit, BuilderArena* orphanArena = nullptr)) { int nestingLimit, BuilderArena* orphanArena = nullptr,
bool canonical = false)) {
return copyPointer(dstSegment, dstCapTable, dst, return copyPointer(dstSegment, dstCapTable, dst,
srcSegment, srcCapTable, src, src->target(), srcSegment, srcCapTable, src, src->target(),
nestingLimit, orphanArena); nestingLimit, orphanArena, canonical);
} }
static SegmentAnd<word*> copyPointer( static SegmentAnd<word*> copyPointer(
SegmentBuilder* dstSegment, CapTableBuilder* dstCapTable, WirePointer* dst, SegmentBuilder* dstSegment, CapTableBuilder* dstCapTable, WirePointer* dst,
SegmentReader* srcSegment, CapTableReader* srcCapTable, const WirePointer* src, SegmentReader* srcSegment, CapTableReader* srcCapTable, const WirePointer* src,
const word* srcTarget, int nestingLimit, BuilderArena* orphanArena = nullptr) { const word* srcTarget, int nestingLimit,
BuilderArena* orphanArena = nullptr, bool canonical = false) {
// Deep-copy the object pointed to by src into dst. It turns out we can't reuse // Deep-copy the object pointed to by src into dst. It turns out we can't reuse
// readStructPointer(), etc. because they do type checking whereas here we want to accept any // readStructPointer(), etc. because they do type checking whereas here we want to accept any
// valid pointer. // valid pointer.
...@@ -1689,7 +1748,7 @@ struct WireHelpers { ...@@ -1689,7 +1748,7 @@ struct WireHelpers {
src->structRef.dataSize.get() * BITS_PER_WORD, src->structRef.dataSize.get() * BITS_PER_WORD,
src->structRef.ptrCount.get(), src->structRef.ptrCount.get(),
nestingLimit - 1), nestingLimit - 1),
orphanArena); orphanArena, canonical);
case WirePointer::LIST: { case WirePointer::LIST: {
ElementSize elementSize = src->listRef.elementSize(); ElementSize elementSize = src->listRef.elementSize();
...@@ -1737,7 +1796,7 @@ struct WireHelpers { ...@@ -1737,7 +1796,7 @@ struct WireHelpers {
tag->structRef.dataSize.get() * BITS_PER_WORD, tag->structRef.dataSize.get() * BITS_PER_WORD,
tag->structRef.ptrCount.get(), ElementSize::INLINE_COMPOSITE, tag->structRef.ptrCount.get(), ElementSize::INLINE_COMPOSITE,
nestingLimit - 1), nestingLimit - 1),
orphanArena); orphanArena, canonical);
} else { } else {
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS; BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS; WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS;
...@@ -1762,7 +1821,7 @@ struct WireHelpers { ...@@ -1762,7 +1821,7 @@ struct WireHelpers {
return setListPointer(dstSegment, dstCapTable, dst, return setListPointer(dstSegment, dstCapTable, dst,
ListReader(srcSegment, srcCapTable, ptr, elementCount, step, dataSize, pointerCount, ListReader(srcSegment, srcCapTable, ptr, elementCount, step, dataSize, pointerCount,
elementSize, nestingLimit - 1), elementSize, nestingLimit - 1),
orphanArena); orphanArena, canonical);
} }
} }
...@@ -1777,6 +1836,11 @@ struct WireHelpers { ...@@ -1777,6 +1836,11 @@ struct WireHelpers {
} }
#if !CAPNP_LITE #if !CAPNP_LITE
if (canonical) {
KJ_FAIL_REQUIRE("Cannot create a canonical message with a capability") {
goto useDefault;
}
}
KJ_IF_MAYBE(cap, srcCapTable->extractCap(src->capRef.index.get())) { KJ_IF_MAYBE(cap, srcCapTable->extractCap(src->capRef.index.get())) {
setCapabilityPointer(dstSegment, dstCapTable, dst, kj::mv(*cap)); setCapabilityPointer(dstSegment, dstCapTable, dst, kj::mv(*cap));
// Return dummy non-null pointer so OrphanBuilder doesn't end up null. // Return dummy non-null pointer so OrphanBuilder doesn't end up null.
...@@ -2258,7 +2322,7 @@ Text::Builder PointerBuilder::initBlob<Text>(ByteCount size) { ...@@ -2258,7 +2322,7 @@ Text::Builder PointerBuilder::initBlob<Text>(ByteCount size) {
return WireHelpers::initTextPointer(pointer, segment, capTable, size).value; return WireHelpers::initTextPointer(pointer, segment, capTable, size).value;
} }
template <> template <>
void PointerBuilder::setBlob<Text>(Text::Reader value) { void PointerBuilder::setBlob<Text>(Text::Reader value, bool canonical) {
WireHelpers::setTextPointer(pointer, segment, capTable, value); WireHelpers::setTextPointer(pointer, segment, capTable, value);
} }
template <> template <>
...@@ -2271,7 +2335,7 @@ Data::Builder PointerBuilder::initBlob<Data>(ByteCount size) { ...@@ -2271,7 +2335,7 @@ Data::Builder PointerBuilder::initBlob<Data>(ByteCount size) {
return WireHelpers::initDataPointer(pointer, segment, capTable, size).value; return WireHelpers::initDataPointer(pointer, segment, capTable, size).value;
} }
template <> template <>
void PointerBuilder::setBlob<Data>(Data::Reader value) { void PointerBuilder::setBlob<Data>(Data::Reader value, bool canonical) {
WireHelpers::setDataPointer(pointer, segment, capTable, value); WireHelpers::setDataPointer(pointer, segment, capTable, value);
} }
template <> template <>
...@@ -2279,12 +2343,12 @@ Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount ...@@ -2279,12 +2343,12 @@ Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount
return WireHelpers::getWritableDataPointer(pointer, segment, capTable, defaultValue, defaultSize); return WireHelpers::getWritableDataPointer(pointer, segment, capTable, defaultValue, defaultSize);
} }
void PointerBuilder::setStruct(const StructReader& value) { void PointerBuilder::setStruct(const StructReader& value, bool canonical) {
WireHelpers::setStructPointer(segment, capTable, pointer, value); WireHelpers::setStructPointer(segment, capTable, pointer, value, nullptr, canonical);
} }
void PointerBuilder::setList(const ListReader& value) { void PointerBuilder::setList(const ListReader& value, bool canonical) {
WireHelpers::setListPointer(segment, capTable, pointer, value); WireHelpers::setListPointer(segment, capTable, pointer, value, nullptr, canonical);
} }
#if !CAPNP_LITE #if !CAPNP_LITE
...@@ -2341,7 +2405,7 @@ void PointerBuilder::transferFrom(PointerBuilder other) { ...@@ -2341,7 +2405,7 @@ void PointerBuilder::transferFrom(PointerBuilder other) {
memset(other.pointer, 0, sizeof(*other.pointer)); memset(other.pointer, 0, sizeof(*other.pointer));
} }
void PointerBuilder::copyFrom(PointerReader other) { void PointerBuilder::copyFrom(PointerReader other, bool canonical) {
if (other.pointer == nullptr) { if (other.pointer == nullptr) {
if (!pointer->isNull()) { if (!pointer->isNull()) {
WireHelpers::zeroObject(segment, capTable, pointer); WireHelpers::zeroObject(segment, capTable, pointer);
...@@ -2349,7 +2413,9 @@ void PointerBuilder::copyFrom(PointerReader other) { ...@@ -2349,7 +2413,9 @@ void PointerBuilder::copyFrom(PointerReader other) {
} }
} else { } else {
WireHelpers::copyPointer(segment, capTable, pointer, WireHelpers::copyPointer(segment, capTable, pointer,
other.segment, other.capTable, other.pointer, other.nestingLimit); other.segment, other.capTable, other.pointer, other.nestingLimit,
nullptr,
canonical);
} }
} }
...@@ -2468,6 +2534,36 @@ PointerReader PointerReader::imbue(CapTableReader* capTable) const { ...@@ -2468,6 +2534,36 @@ PointerReader PointerReader::imbue(CapTableReader* capTable) const {
return result; return result;
} }
bool PointerReader::isCanonical(const word **readHead) {
if (!this->pointer) {
// The pointer is null, so we are canonical and do not read
return true;
}
if (!this->pointer->isPositional()) {
// The pointer is a FAR or OTHER pointer, and is non-canonical
return false;
}
switch (this->getPointerType()) {
case PointerType::NULL_:
// The pointer is null, we are canonical and do not read
return true;
case PointerType::STRUCT:
bool dataTrunc, ptrTrunc;
return (this->getStruct(nullptr).isCanonical(readHead,
readHead,
&dataTrunc,
&ptrTrunc)
&& dataTrunc && ptrTrunc);
case PointerType::LIST:
return this->getListAnySize(nullptr).isCanonical(readHead);
case PointerType::CAPABILITY:
KJ_FAIL_ASSERT("Capabilities are not positional");
}
KJ_UNREACHABLE;
}
// ======================================================================================= // =======================================================================================
// StructBuilder // StructBuilder
...@@ -2609,6 +2705,49 @@ StructReader StructReader::imbue(CapTableReader* capTable) const { ...@@ -2609,6 +2705,49 @@ StructReader StructReader::imbue(CapTableReader* capTable) const {
return result; return result;
} }
bool StructReader::isCanonical(const word **readHead,
const word **ptrHead,
bool *dataTrunc,
bool *ptrTrunc) {
if (this->getLocation() != *readHead) {
// Our target area is not at the readHead, preorder fails
return false;
}
if (this->getDataSectionSize() % BITS_PER_WORD != 0) {
// Using legacy non-word-size structs, reject
return false;
}
WordCount32 dataSize = this->getDataSectionSize() / BITS_PER_WORD;
// Mark whether the struct is properly truncated
if (dataSize != 0) {
*dataTrunc = this->getDataField<uint64_t>(dataSize - 1) != 0;
} else {
*dataTrunc = true;
}
if (this->pointerCount != 0) {
*ptrTrunc = !this->getPointerField(this->pointerCount - 1).isNull();
} else {
*ptrTrunc = true;
}
// Advance the read head
*readHead += dataSize + this->pointerCount;
// Check each pointer field for canonicity
for (WirePointerCount16 ptrIndex = 0;
ptrIndex < this->pointerCount;
ptrIndex++) {
if (!this->getPointerField(ptrIndex).isCanonical(ptrHead)) {
return false;
}
}
return true;
}
// ======================================================================================= // =======================================================================================
// ListBuilder // ListBuilder
...@@ -2748,6 +2887,74 @@ ListReader ListReader::imbue(CapTableReader* capTable) const { ...@@ -2748,6 +2887,74 @@ ListReader ListReader::imbue(CapTableReader* capTable) const {
return result; return result;
} }
bool ListReader::isCanonical(const word **readHead) {
switch (this->getElementSize()) {
case ElementSize::INLINE_COMPOSITE: {
*readHead += 1;
if (reinterpret_cast<const word*>(this->ptr) != *readHead) {
// The next word to read is the tag word, but the pointer is in
// front of it, so our check is slightly different
return false;
}
if (this->structDataSize % BITS_PER_WORD != 0) {
return false;
}
auto structSize = (this->structDataSize / BITS_PER_WORD) +
(this->structPointerCount * WORDS_PER_POINTER);
auto listEnd = *readHead + this->elementCount * structSize;
auto pointerHead = listEnd;
bool listDataTrunc = false;
bool listPtrTrunc = false;
for (ElementCount ec = ElementCount(0);
ec < this->elementCount;
ec++) {
bool dataTrunc, ptrTrunc;
if (!this->getStructElement(ec).isCanonical(readHead,
&pointerHead,
&dataTrunc,
&ptrTrunc)) {
return false;
}
listDataTrunc |= dataTrunc;
listPtrTrunc |= ptrTrunc;
}
KJ_REQUIRE(*readHead == listEnd, *readHead, listEnd);
*readHead = pointerHead;
return listDataTrunc && listPtrTrunc;
}
case ElementSize::POINTER: {
if (reinterpret_cast<const word*>(this->ptr) != *readHead) {
return false;
}
*readHead += this->elementCount;
for (ElementCount ec = ElementCount(0);
ec < this->elementCount;
ec++) {
if (!this->getPointerElement(ec).isCanonical(readHead)) {
return false;
}
}
return true;
}
default: {
if (reinterpret_cast<const word*>(this->ptr) != *readHead) {
return false;
}
auto bitSize = this->elementCount *
dataBitsPerElement(this->elementSize);
auto wordSize = bitSize / BITS_PER_WORD;
if (bitSize % BITS_PER_WORD != 0) {
wordSize++;
}
*readHead += wordSize;
return true;
}
}
KJ_UNREACHABLE;
}
// ======================================================================================= // =======================================================================================
// OrphanBuilder // OrphanBuilder
......
...@@ -339,9 +339,9 @@ public: ...@@ -339,9 +339,9 @@ public:
// Init methods: Initialize the pointer to a newly-allocated object, discarding the existing // Init methods: Initialize the pointer to a newly-allocated object, discarding the existing
// object. // object.
void setStruct(const StructReader& value); void setStruct(const StructReader& value, bool canonical = false);
void setList(const ListReader& value); void setList(const ListReader& value, bool canonical = false);
template <typename T> void setBlob(typename T::Reader value); template <typename T> void setBlob(typename T::Reader value, bool canonical = false);
#if !CAPNP_LITE #if !CAPNP_LITE
void setCapability(kj::Own<ClientHook>&& cap); void setCapability(kj::Own<ClientHook>&& cap);
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
...@@ -360,8 +360,10 @@ public: ...@@ -360,8 +360,10 @@ public:
void transferFrom(PointerBuilder other); void transferFrom(PointerBuilder other);
// Equivalent to `adopt(other.disown())`. // Equivalent to `adopt(other.disown())`.
void copyFrom(PointerReader other); void copyFrom(PointerReader other, bool canonical = false);
// Equivalent to `set(other.get())`. // Equivalent to `set(other.get())`.
// If you set the canonical flag, it will attempt to lay the target out
// canonically, provided enough space is available.
PointerReader asReader() const; PointerReader asReader() const;
...@@ -436,6 +438,13 @@ public: ...@@ -436,6 +438,13 @@ public:
PointerReader imbue(CapTableReader* capTable) const; PointerReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context. // Return a copy of this reader except using the given capability context.
bool isCanonical(const word **readHead);
// Validate this pointer's canonicity, subject to the conditions:
// * All data to the left of readHead has been read thus far (for pointer
// ordering)
// * All pointers in preorder have already been checked
// * This pointer is in the first and only segment of the message
private: private:
SegmentReader* segment; // Memory segment in which the pointer resides. SegmentReader* segment; // Memory segment in which the pointer resides.
CapTableReader* capTable; // Table of capability indexes. CapTableReader* capTable; // Table of capability indexes.
...@@ -595,6 +604,21 @@ public: ...@@ -595,6 +604,21 @@ public:
StructReader imbue(CapTableReader* capTable) const; StructReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context. // Return a copy of this reader except using the given capability context.
bool isCanonical(const word **readHead, const word **ptrHead,
bool *dataTrunc, bool *ptrTrunc);
// Validate this pointer's canonicity, subject to the conditions:
// * All data to the left of readHead has been read thus far (for pointer
// ordering)
// * All pointers in preorder have already been checked
// * This pointer is in the first and only segment of the message
//
// If this function returns false, the struct is non-canonical. If it
// returns true, then:
// * If it is a composite in a list, it is canonical if at least one struct
// in the list outputs dataTrunc = 1, and at least one outputs ptrTrunc = 1
// * If it is derived from a struct pointer, it is canonical if
// dataTrunc = 1 AND ptrTrunc = 1
private: private:
SegmentReader* segment; // Memory segment in which the struct resides. SegmentReader* segment; // Memory segment in which the struct resides.
CapTableReader* capTable; // Table of capability indexes. CapTableReader* capTable; // Table of capability indexes.
...@@ -758,6 +782,13 @@ public: ...@@ -758,6 +782,13 @@ public:
ListReader imbue(CapTableReader* capTable) const; ListReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context. // Return a copy of this reader except using the given capability context.
bool isCanonical(const word **readHead);
// Validate this pointer's canonicity, subject to the conditions:
// * All data to the left of readHead has been read thus far (for pointer
// ordering)
// * All pointers in preorder have already been checked
// * This pointer is in the first and only segment of the message
private: private:
SegmentReader* segment; // Memory segment in which the list resides. SegmentReader* segment; // Memory segment in which the list resides.
CapTableReader* capTable; // Table of capability indexes. CapTableReader* capTable; // Table of capability indexes.
...@@ -915,12 +946,12 @@ private: ...@@ -915,12 +946,12 @@ private:
// These are defined in the source file. // These are defined in the source file.
template <> typename Text::Builder PointerBuilder::initBlob<Text>(ByteCount size); template <> typename Text::Builder PointerBuilder::initBlob<Text>(ByteCount size);
template <> void PointerBuilder::setBlob<Text>(typename Text::Reader value); template <> void PointerBuilder::setBlob<Text>(typename Text::Reader value, bool canonical);
template <> typename Text::Builder PointerBuilder::getBlob<Text>(const void* defaultValue, ByteCount defaultSize); template <> typename Text::Builder PointerBuilder::getBlob<Text>(const void* defaultValue, ByteCount defaultSize);
template <> typename Text::Reader PointerReader::getBlob<Text>(const void* defaultValue, ByteCount defaultSize) const; template <> typename Text::Reader PointerReader::getBlob<Text>(const void* defaultValue, ByteCount defaultSize) const;
template <> typename Data::Builder PointerBuilder::initBlob<Data>(ByteCount size); template <> typename Data::Builder PointerBuilder::initBlob<Data>(ByteCount size);
template <> void PointerBuilder::setBlob<Data>(typename Data::Reader value); template <> void PointerBuilder::setBlob<Data>(typename Data::Reader value, bool canonical);
template <> typename Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize); template <> typename Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize);
template <> typename Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const; template <> typename Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const;
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <errno.h> #include <errno.h>
#include <climits>
namespace capnp { namespace capnp {
...@@ -51,6 +52,34 @@ MessageReader::~MessageReader() noexcept(false) { ...@@ -51,6 +52,34 @@ MessageReader::~MessageReader() noexcept(false) {
} }
} }
bool MessageReader::isCanonical() {
if (!allocatedArena) {
static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) _::ReaderArena(this);
allocatedArena = true;
}
_::SegmentReader *segment = arena()->tryGetSegment(_::SegmentId(0));
if (segment == NULL) {
// The message has no segments
return false;
}
if (arena()->tryGetSegment(_::SegmentId(1))) {
// The message has more than one segment
return false;
}
const word* readHead = segment->getStartPtr() + 1;
return _::PointerReader::getRoot(segment, nullptr, segment->getStartPtr(),
this->getOptions().nestingLimit)
.isCanonical(&readHead);
}
AnyPointer::Reader MessageReader::getRootInternal() { AnyPointer::Reader MessageReader::getRootInternal() {
if (!allocatedArena) { if (!allocatedArena) {
static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace), static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace),
...@@ -130,6 +159,25 @@ Orphanage MessageBuilder::getOrphanage() { ...@@ -130,6 +159,25 @@ Orphanage MessageBuilder::getOrphanage() {
return Orphanage(arena(), arena()->getLocalCapTable()); return Orphanage(arena(), arena()->getLocalCapTable());
} }
bool MessageBuilder::isCanonical() {
_::SegmentReader *segment = getRootSegment();
if (segment == NULL) {
// The message has no segments
return false;
}
if (arena()->tryGetSegment(_::SegmentId(1))) {
// The message has more than one segment
return false;
}
const word* readHead = segment->getStartPtr() + 1;
return _::PointerReader::getRoot(segment, nullptr, segment->getStartPtr(),
INT_MAX)
.isCanonical(&readHead);
}
// ======================================================================================= // =======================================================================================
SegmentArrayMessageReader::SegmentArrayMessageReader( SegmentArrayMessageReader::SegmentArrayMessageReader(
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <kj/common.h> #include <kj/common.h>
#include <kj/memory.h> #include <kj/memory.h>
#include <kj/mutex.h> #include <kj/mutex.h>
#include <kj/debug.h>
#include "common.h" #include "common.h"
#include "layout.h" #include "layout.h"
#include "any.h" #include "any.h"
...@@ -118,6 +119,9 @@ public: ...@@ -118,6 +119,9 @@ public:
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to // RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this. // use this.
bool isCanonical();
// Returns whether the message encoded in the reader is in canonical form.
private: private:
ReaderOptions options; ReaderOptions options;
...@@ -195,6 +199,12 @@ public: ...@@ -195,6 +199,12 @@ public:
void setRoot(Reader&& value); void setRoot(Reader&& value);
// Set the root struct to a deep copy of the given struct. // Set the root struct to a deep copy of the given struct.
template <typename Reader>
void canonicalRoot(Reader&& value);
// Set the root to a canonical deep copy of the struct.
// will likely only work if the builder has not yet been used, and has
// been set up with an arena with a segment as big as value.targetSize();
template <typename RootType> template <typename RootType>
typename RootType::Builder getRoot(); typename RootType::Builder getRoot();
// Get the root struct of the message, interpreting it as the given struct type. // Get the root struct of the message, interpreting it as the given struct type.
...@@ -220,6 +230,9 @@ public: ...@@ -220,6 +230,9 @@ public:
Orphanage getOrphanage(); Orphanage getOrphanage();
bool isCanonical();
// Check whether the message builder is in canonical form
private: private:
void* arenaSpace[22]; void* arenaSpace[22];
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here // Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
...@@ -461,6 +474,13 @@ typename RootType::Builder MessageBuilder::initRoot(SchemaType schema) { ...@@ -461,6 +474,13 @@ typename RootType::Builder MessageBuilder::initRoot(SchemaType schema) {
return getRootInternal().initAs<RootType>(schema); return getRootInternal().initAs<RootType>(schema);
} }
template <typename Reader>
void MessageBuilder::canonicalRoot(Reader&& value) {
auto target = initRoot<AnyPointer>();
target.setCanonical(value);
KJ_ASSERT(isCanonical());
}
template <typename RootType> template <typename RootType>
typename RootType::Reader readMessageUnchecked(const word* data) { typename RootType::Reader readMessageUnchecked(const word* data) {
return AnyPointer::Reader(_::PointerReader::getRootUnchecked(data)).getAs<RootType>(); return AnyPointer::Reader(_::PointerReader::getRootUnchecked(data)).getAs<RootType>();
......
...@@ -48,6 +48,9 @@ struct PointerHelpers<T, Kind::STRUCT> { ...@@ -48,6 +48,9 @@ struct PointerHelpers<T, Kind::STRUCT> {
static inline void set(PointerBuilder builder, typename T::Reader value) { static inline void set(PointerBuilder builder, typename T::Reader value) {
builder.setStruct(value._reader); builder.setStruct(value._reader);
} }
static inline void setCanonical(PointerBuilder builder, typename T::Reader value) {
builder.setStruct(value._reader, true);
}
static inline typename T::Builder init(PointerBuilder builder) { static inline typename T::Builder init(PointerBuilder builder) {
return typename T::Builder(builder.initStruct(structSize<T>())); return typename T::Builder(builder.initStruct(structSize<T>()));
} }
...@@ -78,6 +81,9 @@ struct PointerHelpers<List<T>, Kind::LIST> { ...@@ -78,6 +81,9 @@ struct PointerHelpers<List<T>, Kind::LIST> {
static inline void set(PointerBuilder builder, typename List<T>::Reader value) { static inline void set(PointerBuilder builder, typename List<T>::Reader value) {
builder.setList(value.reader); builder.setList(value.reader);
} }
static inline void setCanonical(PointerBuilder builder, typename List<T>::Reader value) {
builder.setList(value.reader, true);
}
static void set(PointerBuilder builder, kj::ArrayPtr<const ReaderFor<T>> value) { static void set(PointerBuilder builder, kj::ArrayPtr<const ReaderFor<T>> value) {
auto l = init(builder, value.size()); auto l = init(builder, value.size());
uint i = 0; uint i = 0;
...@@ -117,6 +123,9 @@ struct PointerHelpers<T, Kind::BLOB> { ...@@ -117,6 +123,9 @@ struct PointerHelpers<T, Kind::BLOB> {
static inline void set(PointerBuilder builder, typename T::Reader value) { static inline void set(PointerBuilder builder, typename T::Reader value) {
builder.setBlob<T>(value); builder.setBlob<T>(value);
} }
static inline void setCanonical(PointerBuilder builder, typename T::Reader value) {
builder.setBlob<T>(value);
}
static inline typename T::Builder init(PointerBuilder builder, uint size) { static inline typename T::Builder init(PointerBuilder builder, uint size) {
return builder.initBlob<T>(size * BYTES); return builder.initBlob<T>(size * BYTES);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment