Commit 9e3eb675 authored by Kenton Varda's avatar Kenton Varda

Starting on dynamic message manipulation (manipluating messages without knowing…

Starting on dynamic message manipulation (manipluating messages without knowing the schema, or without compiling in the schema).
parent de33baaf
......@@ -79,7 +79,7 @@ public:
inline SegmentReader(Arena* arena, SegmentId id, ArrayPtr<const word> ptr,
ReadLimiter* readLimiter);
CAPNPROTO_ALWAYS_INLINE(bool containsInterval(const word* from, const word* to));
CAPNPROTO_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to));
inline Arena* getArena();
inline SegmentId getSegmentId();
......@@ -223,9 +223,12 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, ArrayPtr<const w
ReadLimiter* readLimiter)
: arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {}
inline bool SegmentReader::containsInterval(const word* from, const word* to) {
inline bool SegmentReader::containsInterval(const void* from, const void* to) {
return from >= this->ptr.begin() && to <= this->ptr.end() &&
readLimiter->canRead(intervalLength(from, to), arena);
readLimiter->canRead(
intervalLength(reinterpret_cast<const byte*>(from),
reinterpret_cast<const byte*>(to)) / BYTES_PER_WORD,
arena);
}
inline Arena* SegmentReader::getArena() { return arena; }
......
......@@ -610,6 +610,49 @@ struct WireHelpers {
}
}
static CAPNPROTO_ALWAYS_INLINE(ObjectBuilder getWritableObjectReference(
SegmentBuilder* segment, WireReference* ref, const word* defaultValue)) {
word* ptr;
if (ref->isNull()) {
if (defaultValue == nullptr) {
return ObjectBuilder();
} else {
ptr = copyMessage(segment, ref, reinterpret_cast<const WireReference*>(defaultValue));
}
} else {
ptr = followFars(ref, segment);
}
if (ref->kind() == WireReference::LIST) {
if (ref->listRef.elementSize() == FieldSize::INLINE_COMPOSITE) {
// Read the tag to get the actual element count.
WireReference* tag = reinterpret_cast<WireReference*>(ptr);
PRECOND(tag->kind() == WireReference::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.");
// First list element is at tag + 1 reference.
return ObjectBuilder(
ListBuilder(segment, tag + 1, tag->structRef.wordSize() * BYTES_PER_WORD / ELEMENTS,
tag->inlineCompositeListElementCount()),
FieldSize::INLINE_COMPOSITE);
} else {
auto step = bytesPerElement(ref->listRef.elementSize());
return ObjectBuilder(
ListBuilder(segment, ptr, step, ref->listRef.elementCount()),
ref->listRef.elementSize());
}
} else {
return ObjectBuilder(
StructBuilder(segment, ptr,
reinterpret_cast<WireReference*>(ptr + ref->structRef.dataSize.get())),
ref->structRef.dataSize.get() * BYTES_PER_WORD,
ref->structRef.refCount.get());
}
}
// -----------------------------------------------------------------
static CAPNPROTO_ALWAYS_INLINE(StructReader readStructReference(
SegmentReader* segment, const WireReference* ref, const word* defaultValue,
int nestingLimit)) {
......@@ -783,7 +826,7 @@ struct WireHelpers {
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contained out-of-bounds list reference.") {
"Message contains out-of-bounds list reference.") {
goto useDefault;
}
}
......@@ -917,6 +960,118 @@ struct WireHelpers {
return Data::Reader(reinterpret_cast<const char*>(ptr), size);
}
}
static CAPNPROTO_ALWAYS_INLINE(ObjectReader readObjectReference(
SegmentReader* segment, const WireReference* ref,
const word* defaultValue, int nestingLimit)) {
// We can't really reuse readStructReference() and readListReference() because they are designed
// for the case where we are expecting a specific type, and they do validation around that,
// whereas this method is for the case where we accept any pointer.
const word* ptr;
if (ref == nullptr || ref->isNull()) {
useDefault:
if (defaultValue == nullptr) {
return ObjectReader();
}
segment = nullptr;
ref = reinterpret_cast<const WireReference*>(defaultValue);
ptr = ref->target();
} else if (segment != nullptr) {
ptr = WireHelpers::followFars(ref, segment);
if (CAPNPROTO_EXPECT_FALSE(ptr == nullptr)) {
// Already reported the error.
goto useDefault;
}
} else {
ptr = ref->target();
}
switch (ref->kind()) {
case WireReference::STRUCT:
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct reference.") {
goto useDefault;
}
}
return ObjectReader(
StructReader(segment, ptr,
reinterpret_cast<const WireReference*>(ptr + ref->structRef.dataSize.get()),
ref->structRef.dataSize.get() * BYTES_PER_WORD,
ref->structRef.refCount.get(),
nestingLimit - 1));
case WireReference::LIST: {
FieldSize elementSize = ref->listRef.elementSize();
if (segment != nullptr) {
VALIDATE_INPUT(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault;
}
}
if (elementSize == FieldSize::INLINE_COMPOSITE) {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
const WireReference* tag = reinterpret_cast<const WireReference*>(ptr);
ptr += REFERENCE_SIZE_IN_WORDS;
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr - REFERENCE_SIZE_IN_WORDS,
ptr + wordCount),
"Message contains out-of-bounds list reference.") {
goto useDefault;
}
VALIDATE_INPUT(tag->kind() == WireReference::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault;
}
}
ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
if (segment != nullptr) {
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.");
}
return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BYTES_PER_WORD,
tag->structRef.dataSize.get() * BYTES_PER_WORD,
tag->structRef.refCount.get(), nestingLimit - 1),
elementSize);
} else {
decltype(BITS / ELEMENTS) step = bitsPerElement(elementSize);
ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + wordCount),
"Message contains out-of-bounds list reference.") {
goto useDefault;
}
}
return ObjectReader(
ListReader(segment, ptr, elementCount, step / BITS_PER_BYTE,
elementSize == FieldSize::REFERENCE ? 0 * BYTES : step * ELEMENTS / BITS_PER_BYTE,
elementSize == FieldSize::REFERENCE ? 1 * REFERENCES : 0 * REFERENCES,
nestingLimit - 1),
elementSize);
}
}
default:
FAIL_VALIDATE_INPUT("Message contained invalid pointer.") {}
goto useDefault;
}
}
};
// =======================================================================================
......@@ -988,6 +1143,11 @@ Data::Builder StructBuilder::getDataField(
references + refIndex, segment, defaultValue, defaultSize);
}
ObjectBuilder StructBuilder::getObjectField(
WireReferenceCount refIndex, const word* defaultValue) const {
return WireHelpers::getWritableObjectReference(segment, references + refIndex, defaultValue);
}
StructReader StructBuilder::asReader() const {
// HACK: We just give maxed-out data size and reference counts because they are only
// used for checking for field presence.
......@@ -1048,6 +1208,12 @@ Data::Reader StructReader::getDataField(
return WireHelpers::readDataReference(segment, ref, defaultValue, defaultSize);
}
ObjectReader StructReader::getObjectField(
WireReferenceCount refIndex, const word* defaultValue) const {
return WireHelpers::readObjectReference(
segment, references + refIndex, defaultValue, nestingLimit);
}
// =======================================================================================
// ListBuilder
......@@ -1103,6 +1269,11 @@ Data::Builder ListBuilder::getDataElement(ElementCount index) const {
reinterpret_cast<WireReference*>(ptr + index * stepBytes), segment, nullptr, 0 * BYTES);
}
ObjectBuilder ListBuilder::getObjectElement(ElementCount index, const word* defaultValue) const {
return WireHelpers::getWritableObjectReference(
segment, reinterpret_cast<WireReference*>(ptr + index * stepBytes), defaultValue);
}
ListReader ListBuilder::asReader(FieldSize elementSize) const {
// TODO: For INLINE_COMPOSITE I suppose we could just check the tag?
PRECOND(elementSize != FieldSize::INLINE_COMPOSITE,
......@@ -1173,5 +1344,10 @@ Data::Reader ListReader::getDataElement(ElementCount index) const {
nullptr, 0 * BYTES);
}
ObjectReader ListReader::getObjectElement(ElementCount index, const word* defaultValue) const {
return WireHelpers::readObjectReference(
segment, checkAlignment(ptr + index * stepBytes), defaultValue, nestingLimit);
}
} // namespace internal
} // namespace capnproto
......@@ -42,6 +42,8 @@ class StructBuilder;
class StructReader;
class ListBuilder;
class ListReader;
class ObjectBuilder;
class ObjectReader;
struct WireReference;
struct WireHelpers;
class SegmentReader;
......@@ -49,7 +51,6 @@ class SegmentBuilder;
class FieldDescriptor;
typedef Id<uint8_t, FieldDescriptor> FieldNumber;
enum class FieldSize: uint8_t;
enum class FieldSize: uint8_t {
// TODO: Rename to FieldLayout or maybe ValueLayout.
......@@ -319,6 +320,9 @@ public:
const void* defaultValue, ByteCount defaultSize) const;
// Same as *Text*, but for data blobs.
ObjectBuilder getObjectField(WireReferenceCount refIndex, const word* defaultValue) const;
// Read a pointer of arbitrary type.
StructReader asReader() const;
// Gets a StructReader pointing at the same memory.
......@@ -375,6 +379,9 @@ public:
const void* defaultValue, ByteCount defaultSize) const;
// Gets the data field, or the given default value if not initialized.
ObjectReader getObjectField(WireReferenceCount refIndex, const word* defaultValue) const;
// Read a pointer of arbitrary type.
WireReferenceCount getReferenceCount() { return referenceCount; }
private:
......@@ -457,6 +464,9 @@ public:
Data::Builder getDataElement(ElementCount index) const;
// Like *Text*() but for Data.
ObjectBuilder getObjectElement(ElementCount index, const word* defaultValue) const;
// Gets a pointer element of arbitrary type.
ListReader asReader(FieldSize elementSize) const;
// Get a ListReader pointing at the same memory. Use this version only for non-struct lists.
......@@ -508,6 +518,9 @@ public:
Data::Reader getDataElement(ElementCount index) const;
// Get the data element. If it is not initialized, returns an empty Data::Reader.
ObjectReader getObjectElement(ElementCount index, const word* defaultValue) const;
// Gets a pointer element of arbitrary type.
private:
SegmentReader* segment; // Memory segment in which the list resides.
......@@ -546,6 +559,69 @@ private:
friend struct WireHelpers;
};
// -------------------------------------------------------------------
struct ObjectBuilder {
// A reader for any kind of object.
enum Kind {
NULL_POINTER, // Object was read from a null pointer.
STRUCT,
LIST
};
Kind kind;
FieldSize listElementSize;
// Only set if kind == LIST. This would be part of the union, except that then ObjectReader would
// end up larger overall.
WireReferenceCount16 structPointerSectionSize;
ByteCount32 structDataSectionSize;
// For kind == STRUCT, the size of the struct. For kind == LIST, the size of each element of the
// list, unless listElementSize == BIT.
union {
StructBuilder structBuilder;
ListBuilder listBuilder;
};
ObjectBuilder(): kind(NULL_POINTER), structBuilder() {}
ObjectBuilder(StructBuilder structBuilder, ByteCount32 dataSectionSize,
WireReferenceCount16 pointerSectionSize)
: kind(STRUCT), structPointerSectionSize(pointerSectionSize),
structDataSectionSize(dataSectionSize), structBuilder(structBuilder) {}
ObjectBuilder(ListBuilder listBuilderBuilder, FieldSize elementSize)
: kind(LIST), listElementSize(elementSize), listBuilder(listBuilder) {}
};
struct ObjectReader {
// A reader for any kind of object.
enum Kind {
NULL_POINTER, // Object was read from a null pointer.
STRUCT,
LIST
};
Kind kind;
FieldSize listElementSize;
// Only set if kind == LIST. This would be part of the union, except that then ObjectReader would
// end up larger overall.
union {
StructReader structReader;
ListReader listReader;
};
ObjectReader(): kind(NULL_POINTER), structReader() {}
ObjectReader(StructReader structReader)
: kind(STRUCT), structReader(structReader) {}
ObjectReader(ListReader listReader, FieldSize elementSize)
: kind(LIST), listElementSize(elementSize), listReader(listReader) {}
};
// =======================================================================================
// Internal implementation details...
......
......@@ -43,14 +43,38 @@ public:
static constexpr bool value = sizeof(test<T>(nullptr)) == sizeof(yes);
};
template <typename T>
constexpr bool isPrimitive() { return IsPrimitive<T>::value; }
template <typename T, bool isPrimitive = isPrimitive<T>()>
struct MaybeReaderBuilder {};
template <typename T>
struct MaybeReaderBuilder<T, true> {
typedef T Reader;
typedef T Builder;
};
template <typename T>
struct MaybeReaderBuilder<T, false> {
typedef typename T::Reader Reader;
typedef typename T::Builder Builder;
};
template <typename t>
struct PointerHelpers;
} // namespace internal
template <typename T, bool isPrimitive = internal::IsPrimitive<T>::value>
template <typename T, bool isPrimitive = internal::isPrimitive<T>()>
struct List;
template <typename T>
using ReaderFor = typename internal::MaybeReaderBuilder<T>::Reader;
// The type returned by List<T>::Reader::operator[].
template <typename T>
using BuilderFor = typename internal::MaybeReaderBuilder<T>::Reader;
// The type returned by List<T>::Builder::operator[].
namespace internal {
template <size_t size> struct FieldSizeForByteSize;
......@@ -60,7 +84,7 @@ template <> struct FieldSizeForByteSize<4> { static constexpr FieldSize value =
template <> struct FieldSizeForByteSize<8> { static constexpr FieldSize value = FieldSize::EIGHT_BYTES; };
template <typename T> struct FieldSizeForType {
static constexpr FieldSize value = IsPrimitive<T>::value ?
static constexpr FieldSize value = isPrimitive<T>() ?
// Primitive types that aren't special-cased below can be determined from sizeof().
FieldSizeForByteSize<sizeof(T)>::value :
......
......@@ -54,8 +54,8 @@ struct NoInfer {
template <typename T> struct RemoveReference { typedef T Type; };
template <typename T> struct RemoveReference<T&> { typedef T Type; };
template<typename> struct IsLvalueReference { static constexpr bool value = false; };
template<typename T> struct IsLvalueReference<T&> { static constexpr bool value = true; };
template <typename> struct IsLvalueReference { static constexpr bool value = false; };
template <typename T> struct IsLvalueReference<T&> { static constexpr bool value = true; };
// #including <utility> just for std::move() and std::forward() is excessive. Instead, we
// re-define them here.
......@@ -71,6 +71,164 @@ template<typename T> constexpr T&& forward(typename RemoveReference<T>::Type&& t
return static_cast<T&&>(t);
}
template <typename T>
T instance() noexcept;
// Like std::declval, but doesn't transform T into an rvalue reference. If you want that, specify
// instance<T&&>().
// =======================================================================================
// Maybe
template <typename T>
class Maybe {
public:
Maybe(): isSet(false) {}
Maybe(T&& t)
: isSet(true) {
new (&value) T(move(t));
}
Maybe(const T& t)
: isSet(true) {
new (&value) T(t);
}
Maybe(Maybe&& other) noexcept(noexcept(T(capnproto::move(other.value))))
: isSet(other.isSet) {
if (isSet) {
new (&value) T(move(other.value));
}
}
Maybe(const Maybe& other)
: isSet(other.isSet) {
if (isSet) {
new (&value) T(other.value);
}
}
Maybe(std::nullptr_t): isSet(false) {}
~Maybe() {
if (isSet) {
value.~T();
}
}
template <typename... Params>
inline void init(Params&&... params) {
if (isSet) {
value.~T();
}
isSet = true;
new (&value) T(capnproto::forward(params)...);
}
inline T& operator*() { return value; }
inline const T& operator*() const { return value; }
inline T* operator->() { return &value; }
inline const T* operator->() const { return &value; }
inline Maybe& operator=(Maybe&& other) {
if (&other != this) {
if (isSet) {
value.~T();
}
isSet = other.isSet;
if (isSet) {
new (&value) T(move(other.value));
}
}
return *this;
}
inline Maybe& operator=(const Maybe& other) {
if (&other != this) {
if (isSet) {
value.~T();
}
isSet = other.isSet;
if (isSet) {
new (&value) T(other.value);
}
}
return *this;
}
bool operator==(const Maybe& other) const {
if (isSet == other.isSet) {
if (isSet) {
return value == other.value;
} else {
return true;
}
}
return false;
}
inline bool operator!=(const Maybe& other) const { return !(*this == other); }
inline bool operator==(std::nullptr_t) const { return !isSet; }
inline bool operator!=(std::nullptr_t) const { return isSet; }
template <typename Func>
auto map(const Func& func) const -> Maybe<decltype(func(instance<const T&>()))> {
// Construct a new Maybe by applying the given function to the Maybe's value.
if (isSet) {
return func(value);
} else {
return nullptr;
}
}
template <typename Func>
auto map(const Func& func) -> Maybe<decltype(func(instance<T&>()))> {
// Construct a new Maybe by applying the given function to the Maybe's value.
if (isSet) {
return func(value);
} else {
return nullptr;
}
}
template <typename Func>
auto moveMap(const Func& func) -> Maybe<decltype(func(instance<T&&>()))> {
// Like map() but allows the function to take an rvalue reference to the value.
if (isSet) {
return func(capnproto::move(value));
} else {
return nullptr;
}
}
private:
bool isSet;
union {
T value;
};
};
template <typename T>
class Maybe<T&> {
public:
Maybe(): ptr(nullptr) {}
Maybe(T& t): ptr(&t) {}
Maybe(std::nullptr_t): ptr(nullptr) {}
~Maybe() noexcept {}
inline T& operator*() { return *ptr; }
inline const T& operator*() const { return *ptr; }
inline T* operator->() { return ptr; }
inline const T* operator->() const { return ptr; }
inline bool operator==(const Maybe& other) const { return ptr == other.ptr; }
inline bool operator!=(const Maybe& other) const { return ptr != other.ptr; }
inline bool operator==(std::nullptr_t) const { return ptr == nullptr; }
inline bool operator!=(std::nullptr_t) const { return ptr != nullptr; }
private:
T* ptr;
};
// =======================================================================================
// ArrayPtr
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment