Commit e2cbc1b5 authored by Kenton Varda's avatar Kenton Varda

WIP integer overflow detection via template metaprogramming.

See: https://capnproto.org/news/2015-03-02-security-advisory-and-integer-overflow-protection.html

This commit as-is is the result of wading through two years of merge conflicts. It does not build as-is because new code added in that time hasn't been converted over.
parent a37a0cc7
...@@ -45,7 +45,7 @@ kj::Own<ClientHook> AnyPointer::Reader::getPipelinedCap( ...@@ -45,7 +45,7 @@ kj::Own<ClientHook> AnyPointer::Reader::getPipelinedCap(
break; break;
case PipelineOp::Type::GET_POINTER_FIELD: case PipelineOp::Type::GET_POINTER_FIELD:
pointer = pointer.getStruct(nullptr).getPointerField(op.pointerIndex * POINTERS); pointer = pointer.getStruct(nullptr).getPointerField(guarded(op.pointerIndex) * POINTERS);
break; break;
} }
} }
......
...@@ -208,9 +208,9 @@ struct AnyPointer { ...@@ -208,9 +208,9 @@ struct AnyPointer {
// Note: Does not accept INLINE_COMPOSITE for elementSize. // Note: Does not accept INLINE_COMPOSITE for elementSize.
inline List<AnyStruct>::Builder initAsListOfAnyStruct( inline List<AnyStruct>::Builder initAsListOfAnyStruct(
uint dataWordCount, uint pointerCount, uint elementCount); uint16_t dataWordCount, uint16_t pointerCount, uint elementCount);
inline AnyStruct::Builder initAsAnyStruct(uint dataWordCount, uint pointerCount); inline AnyStruct::Builder initAsAnyStruct(uint16_t dataWordCount, uint16_t pointerCount);
template <typename T> template <typename T>
inline void setAs(ReaderFor<T> value); inline void setAs(ReaderFor<T> value);
...@@ -398,10 +398,10 @@ struct List<AnyPointer, Kind::OTHER> { ...@@ -398,10 +398,10 @@ struct List<AnyPointer, Kind::OTHER> {
inline Reader(): reader(ElementSize::POINTER) {} inline Reader(): reader(ElementSize::POINTER) {}
inline explicit Reader(_::ListReader reader): reader(reader) {} inline explicit Reader(_::ListReader reader): reader(reader) {}
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return unguard(reader.size() / ELEMENTS); }
inline AnyPointer::Reader operator[](uint index) const { inline AnyPointer::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return AnyPointer::Reader(reader.getPointerElement(index * ELEMENTS)); return AnyPointer::Reader(reader.getPointerElement(guarded(index) * ELEMENTS));
} }
typedef _::IndexingIterator<const Reader, typename AnyPointer::Reader> Iterator; typedef _::IndexingIterator<const Reader, typename AnyPointer::Reader> Iterator;
...@@ -430,10 +430,10 @@ struct List<AnyPointer, Kind::OTHER> { ...@@ -430,10 +430,10 @@ struct List<AnyPointer, Kind::OTHER> {
inline operator Reader() const { return Reader(builder.asReader()); } inline operator Reader() const { return Reader(builder.asReader()); }
inline Reader asReader() const { return Reader(builder.asReader()); } inline Reader asReader() const { return Reader(builder.asReader()); }
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return unguard(builder.size() / ELEMENTS); }
inline AnyPointer::Builder operator[](uint index) { inline AnyPointer::Builder operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return AnyPointer::Builder(builder.getPointerElement(index * ELEMENTS)); return AnyPointer::Builder(builder.getPointerElement(guarded(index) * ELEMENTS));
} }
typedef _::IndexingIterator<Builder, typename AnyPointer::Builder> Iterator; typedef _::IndexingIterator<Builder, typename AnyPointer::Builder> Iterator;
...@@ -563,10 +563,10 @@ public: ...@@ -563,10 +563,10 @@ public:
inline Reader(): reader(ElementSize::INLINE_COMPOSITE) {} inline Reader(): reader(ElementSize::INLINE_COMPOSITE) {}
inline explicit Reader(_::ListReader reader): reader(reader) {} inline explicit Reader(_::ListReader reader): reader(reader) {}
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return unguard(reader.size() / ELEMENTS); }
inline AnyStruct::Reader operator[](uint index) const { inline AnyStruct::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return AnyStruct::Reader(reader.getStructElement(index * ELEMENTS)); return AnyStruct::Reader(reader.getStructElement(guarded(index) * ELEMENTS));
} }
typedef _::IndexingIterator<const Reader, typename AnyStruct::Reader> Iterator; typedef _::IndexingIterator<const Reader, typename AnyStruct::Reader> Iterator;
...@@ -595,10 +595,10 @@ public: ...@@ -595,10 +595,10 @@ public:
inline operator Reader() const { return Reader(builder.asReader()); } inline operator Reader() const { return Reader(builder.asReader()); }
inline Reader asReader() const { return Reader(builder.asReader()); } inline Reader asReader() const { return Reader(builder.asReader()); }
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return unguard(builder.size() / ELEMENTS); }
inline AnyStruct::Builder operator[](uint index) { inline AnyStruct::Builder operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return AnyStruct::Builder(builder.getStructElement(index * ELEMENTS)); return AnyStruct::Builder(builder.getStructElement(guarded(index) * ELEMENTS));
} }
typedef _::IndexingIterator<Builder, typename AnyStruct::Builder> Iterator; typedef _::IndexingIterator<Builder, typename AnyStruct::Builder> Iterator;
...@@ -628,7 +628,7 @@ public: ...@@ -628,7 +628,7 @@ public:
#endif #endif
inline ElementSize getElementSize() { return _reader.getElementSize(); } inline ElementSize getElementSize() { return _reader.getElementSize(); }
inline uint size() { return _reader.size() / ELEMENTS; } inline uint size() { return unguard(_reader.size() / ELEMENTS); }
inline kj::ArrayPtr<const byte> getRawBytes() { return _reader.asRawBytes(); } inline kj::ArrayPtr<const byte> getRawBytes() { return _reader.asRawBytes(); }
...@@ -664,7 +664,7 @@ public: ...@@ -664,7 +664,7 @@ public:
#endif #endif
inline ElementSize getElementSize() { return _builder.getElementSize(); } inline ElementSize getElementSize() { return _builder.getElementSize(); }
inline uint size() { return _builder.size() / ELEMENTS; } inline uint size() { return unguard(_builder.size() / ELEMENTS); }
Equality equals(AnyList::Reader right); Equality equals(AnyList::Reader right);
inline bool operator==(AnyList::Reader right) { inline bool operator==(AnyList::Reader right) {
...@@ -781,18 +781,21 @@ inline BuilderFor<T> AnyPointer::Builder::initAs(uint elementCount) { ...@@ -781,18 +781,21 @@ inline BuilderFor<T> AnyPointer::Builder::initAs(uint elementCount) {
inline AnyList::Builder AnyPointer::Builder::initAsAnyList( inline AnyList::Builder AnyPointer::Builder::initAsAnyList(
ElementSize elementSize, uint elementCount) { ElementSize elementSize, uint elementCount) {
return AnyList::Builder(builder.initList(elementSize, elementCount * ELEMENTS)); return AnyList::Builder(builder.initList(elementSize, guarded(elementCount) * ELEMENTS));
} }
inline List<AnyStruct>::Builder AnyPointer::Builder::initAsListOfAnyStruct( inline List<AnyStruct>::Builder AnyPointer::Builder::initAsListOfAnyStruct(
uint dataWordCount, uint pointerCount, uint elementCount) { uint16_t dataWordCount, uint16_t pointerCount, uint elementCount) {
return List<AnyStruct>::Builder(builder.initStructList(elementCount * ELEMENTS, return List<AnyStruct>::Builder(builder.initStructList(guarded(elementCount) * ELEMENTS,
_::StructSize(dataWordCount * WORDS, pointerCount * POINTERS))); _::StructSize(guarded(dataWordCount) * WORDS,
guarded(pointerCount) * POINTERS)));
} }
inline AnyStruct::Builder AnyPointer::Builder::initAsAnyStruct(uint dataWordCount, uint pointerCount) { inline AnyStruct::Builder AnyPointer::Builder::initAsAnyStruct(
uint16_t dataWordCount, uint16_t pointerCount) {
return AnyStruct::Builder(builder.initStruct( return AnyStruct::Builder(builder.initStruct(
_::StructSize(dataWordCount * WORDS, pointerCount * POINTERS))); _::StructSize(guarded(dataWordCount) * WORDS,
guarded(pointerCount) * POINTERS)));
} }
template <typename T> template <typename T>
...@@ -960,15 +963,16 @@ struct PointerHelpers<AnyStruct, Kind::OTHER> { ...@@ -960,15 +963,16 @@ struct PointerHelpers<AnyStruct, Kind::OTHER> {
PointerBuilder builder, const word* defaultValue = nullptr) { PointerBuilder builder, const word* defaultValue = nullptr) {
// TODO(someday): Allow specifying the size somehow? // TODO(someday): Allow specifying the size somehow?
return AnyStruct::Builder(builder.getStruct( return AnyStruct::Builder(builder.getStruct(
_::StructSize(0 * WORDS, 0 * POINTERS), defaultValue)); _::StructSize(ZERO * WORDS, ZERO * POINTERS), defaultValue));
} }
static inline void set(PointerBuilder builder, AnyStruct::Reader value) { static inline void set(PointerBuilder builder, AnyStruct::Reader value) {
builder.setStruct(value._reader); builder.setStruct(value._reader);
} }
static inline AnyStruct::Builder init( static inline AnyStruct::Builder init(
PointerBuilder builder, uint dataWordCount, uint pointerCount) { PointerBuilder builder, uint16_t dataWordCount, uint16_t pointerCount) {
return AnyStruct::Builder(builder.initStruct( return AnyStruct::Builder(builder.initStruct(
StructSize(dataWordCount * WORDS, pointerCount * POINTERS))); StructSize(guarded(dataWordCount) * WORDS,
guarded(pointerCount) * POINTERS)));
} }
// TODO(soon): implement these // TODO(soon): implement these
...@@ -991,12 +995,15 @@ struct PointerHelpers<AnyList, Kind::OTHER> { ...@@ -991,12 +995,15 @@ struct PointerHelpers<AnyList, Kind::OTHER> {
} }
static inline AnyList::Builder init( static inline AnyList::Builder init(
PointerBuilder builder, ElementSize elementSize, uint elementCount) { PointerBuilder builder, ElementSize elementSize, uint elementCount) {
return AnyList::Builder(builder.initList(elementSize, elementCount * ELEMENTS)); return AnyList::Builder(builder.initList(
elementSize, guarded(elementCount) * ELEMENTS));
} }
static inline AnyList::Builder init( static inline AnyList::Builder init(
PointerBuilder builder, uint dataWordCount, uint pointerCount, uint elementCount) { PointerBuilder builder, uint16_t dataWordCount, uint16_t pointerCount, uint elementCount) {
return AnyList::Builder(builder.initStructList( return AnyList::Builder(builder.initStructList(
elementCount * ELEMENTS, StructSize(dataWordCount * WORDS, pointerCount * POINTERS))); guarded(elementCount) * ELEMENTS,
StructSize(guarded(dataWordCount) * WORDS,
guarded(pointerCount) * POINTERS)));
} }
// TODO(soon): implement these // TODO(soon): implement these
......
...@@ -42,7 +42,7 @@ void ReadLimiter::unread(WordCount64 amount) { ...@@ -42,7 +42,7 @@ void ReadLimiter::unread(WordCount64 amount) {
// the limit value was not updated correctly for one or more reads, and therefore unread() could // the limit value was not updated correctly for one or more reads, and therefore unread() could
// overflow it even if it is only unreading bytes that were actually read. // overflow it even if it is only unreading bytes that were actually read.
uint64_t oldValue = limit; uint64_t oldValue = limit;
uint64_t newValue = oldValue + amount / WORDS; uint64_t newValue = oldValue + unguard(amount / WORDS);
if (newValue > oldValue) { if (newValue > oldValue) {
limit = newValue; limit = newValue;
} }
...@@ -57,10 +57,24 @@ void SegmentBuilder::throwNotWritable() { ...@@ -57,10 +57,24 @@ void SegmentBuilder::throwNotWritable() {
// ======================================================================================= // =======================================================================================
ReaderArena::ReaderArena(MessageReader* message) static SegmentWordCount verifySegmentSize(size_t size) {
auto gsize = guarded(size) * WORDS;
return assertMaxBits<SEGMENT_WORD_COUNT_BITS>(gsize, [&]() {
KJ_FAIL_REQUIRE("segment is too large", size);
});
}
inline ReaderArena::ReaderArena(MessageReader* message, const word* firstSegment,
SegmentWordCount firstSegmentSize)
: message(message), : message(message),
readLimiter(message->getOptions().traversalLimitInWords * WORDS), readLimiter(guarded(message->getOptions().traversalLimitInWords) * WORDS),
segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {} segment0(this, SegmentId(0), firstSegment, firstSegmentSize, &readLimiter) {}
inline ReaderArena::ReaderArena(MessageReader* message, kj::ArrayPtr<const word> firstSegment)
: ReaderArena(message, firstSegment.begin(), verifySegmentSize(firstSegment.size())) {}
ReaderArena::ReaderArena(MessageReader* message)
: ReaderArena(message, message->getSegment(0)) {}
ReaderArena::~ReaderArena() noexcept(false) {} ReaderArena::~ReaderArena() noexcept(false) {}
...@@ -89,6 +103,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -89,6 +103,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
return nullptr; return nullptr;
} }
SegmentWordCount newSegmentSize = verifySegmentSize(newSegment.size());
if (*lock == nullptr) { if (*lock == nullptr) {
// OK, the segment exists, so allocate the map. // OK, the segment exists, so allocate the map.
auto s = kj::heap<SegmentMap>(); auto s = kj::heap<SegmentMap>();
...@@ -96,7 +112,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -96,7 +112,8 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
*lock = kj::mv(s); *lock = kj::mv(s);
} }
auto segment = kj::heap<SegmentReader>(this, id, newSegment, &readLimiter); auto segment = kj::heap<SegmentReader>(
this, id, newSegment.begin(), newSegmentSize, &readLimiter);
SegmentReader* result = segment; SegmentReader* result = segment;
segments->insert(std::make_pair(id.value, mv(segment))); segments->insert(std::make_pair(id.value, mv(segment)));
return result; return result;
...@@ -116,14 +133,17 @@ BuilderArena::BuilderArena(MessageBuilder* message) ...@@ -116,14 +133,17 @@ BuilderArena::BuilderArena(MessageBuilder* message)
BuilderArena::BuilderArena(MessageBuilder* message, BuilderArena::BuilderArena(MessageBuilder* message,
kj::ArrayPtr<MessageBuilder::SegmentInit> segments) kj::ArrayPtr<MessageBuilder::SegmentInit> segments)
: message(message), : message(message),
segment0(this, SegmentId(0), segments[0].space, &this->dummyLimiter, segments[0].wordsUsed) { segment0(this, SegmentId(0), segments[0].space.begin(),
verifySegmentSize(segments[0].space.size()),
&this->dummyLimiter, verifySegmentSize(segments[0].wordsUsed)) {
if (segments.size() > 1) { if (segments.size() > 1) {
kj::Vector<kj::Own<SegmentBuilder>> builders(segments.size() - 1); kj::Vector<kj::Own<SegmentBuilder>> builders(segments.size() - 1);
uint i = 1; uint i = 1;
for (auto& segment: segments.slice(1, segments.size())) { for (auto& segment: segments.slice(1, segments.size())) {
builders.add(kj::heap<SegmentBuilder>( builders.add(kj::heap<SegmentBuilder>(
this, SegmentId(i++), segment.space, &this->dummyLimiter, segment.wordsUsed)); this, SegmentId(i++), segment.space.begin(), verifySegmentSize(segment.space.size()),
&this->dummyLimiter, verifySegmentSize(segment.wordsUsed)));
} }
kj::Vector<kj::ArrayPtr<const word>> forOutput; kj::Vector<kj::ArrayPtr<const word>> forOutput;
...@@ -155,15 +175,16 @@ SegmentBuilder* BuilderArena::getSegment(SegmentId id) { ...@@ -155,15 +175,16 @@ SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
} }
} }
BuilderArena::AllocateResult BuilderArena::allocate(WordCount amount) { BuilderArena::AllocateResult BuilderArena::allocate(SegmentWordCount amount) {
if (segment0.getArena() == nullptr) { if (segment0.getArena() == nullptr) {
// We're allocating the first segment. // We're allocating the first segment.
kj::ArrayPtr<word> ptr = message->allocateSegment(amount / WORDS); kj::ArrayPtr<word> ptr = message->allocateSegment(unguard(amount / WORDS));
auto actualSize = verifySegmentSize(ptr.size());
// Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any // Re-allocate segment0 in-place. This is a bit of a hack, but we have not returned any
// pointers to this segment yet, so it should be fine. // pointers to this segment yet, so it should be fine.
kj::dtor(segment0); kj::dtor(segment0);
kj::ctor(segment0, this, SegmentId(0), ptr, &this->dummyLimiter); kj::ctor(segment0, this, SegmentId(0), ptr.begin(), actualSize, &this->dummyLimiter);
segmentWithSpace = &segment0; segmentWithSpace = &segment0;
return AllocateResult { &segment0, segment0.allocate(amount) }; return AllocateResult { &segment0, segment0.allocate(amount) };
...@@ -183,7 +204,7 @@ BuilderArena::AllocateResult BuilderArena::allocate(WordCount amount) { ...@@ -183,7 +204,7 @@ BuilderArena::AllocateResult BuilderArena::allocate(WordCount amount) {
} }
// Need to allocate a new segment. // Need to allocate a new segment.
SegmentBuilder* result = addSegmentInternal(message->allocateSegment(amount / WORDS)); SegmentBuilder* result = addSegmentInternal(message->allocateSegment(unguard(amount / WORDS)));
// Check this new segment first the next time we need to allocate. // Check this new segment first the next time we need to allocate.
segmentWithSpace = result; segmentWithSpace = result;
...@@ -204,6 +225,8 @@ SegmentBuilder* BuilderArena::addSegmentInternal(kj::ArrayPtr<T> content) { ...@@ -204,6 +225,8 @@ SegmentBuilder* BuilderArena::addSegmentInternal(kj::ArrayPtr<T> content) {
KJ_REQUIRE(segment0.getArena() != nullptr, KJ_REQUIRE(segment0.getArena() != nullptr,
"Can't allocate external segments before allocating the root segment."); "Can't allocate external segments before allocating the root segment.");
auto contentSize = verifySegmentSize(content.size());
MultiSegmentState* segmentState; MultiSegmentState* segmentState;
KJ_IF_MAYBE(s, moreSegments) { KJ_IF_MAYBE(s, moreSegments) {
segmentState = *s; segmentState = *s;
...@@ -214,7 +237,8 @@ SegmentBuilder* BuilderArena::addSegmentInternal(kj::ArrayPtr<T> content) { ...@@ -214,7 +237,8 @@ SegmentBuilder* BuilderArena::addSegmentInternal(kj::ArrayPtr<T> content) {
} }
kj::Own<SegmentBuilder> newBuilder = kj::heap<SegmentBuilder>( kj::Own<SegmentBuilder> newBuilder = kj::heap<SegmentBuilder>(
this, SegmentId(segmentState->builders.size() + 1), content, &this->dummyLimiter); this, SegmentId(segmentState->builders.size() + 1),
content.begin(), contentSize, &this->dummyLimiter);
SegmentBuilder* result = newBuilder.get(); SegmentBuilder* result = newBuilder.get();
segmentState->builders.add(kj::mv(newBuilder)); segmentState->builders.add(kj::mv(newBuilder));
......
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
inline void reset(WordCount64 limit); inline void reset(WordCount64 limit);
KJ_ALWAYS_INLINE(bool canRead(WordCount amount, Arena* arena)); KJ_ALWAYS_INLINE(bool canRead(WordCount64 amount, Arena* arena));
void unread(WordCount64 amount); void unread(WordCount64 amount);
// Adds back some words to the limit. Useful when the caller knows they are double-reading // Adds back some words to the limit. Useful when the caller knows they are double-reading
...@@ -113,7 +113,7 @@ public: ...@@ -113,7 +113,7 @@ public:
class SegmentReader { class SegmentReader {
public: public:
inline SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<const word> ptr, inline SegmentReader(Arena* arena, SegmentId id, const word* ptr, SegmentWordCount size,
ReadLimiter* readLimiter); ReadLimiter* readLimiter);
KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to)); KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to));
...@@ -129,8 +129,8 @@ public: ...@@ -129,8 +129,8 @@ public:
inline SegmentId getSegmentId(); inline SegmentId getSegmentId();
inline const word* getStartPtr(); inline const word* getStartPtr();
inline WordCount getOffsetTo(const word* ptr); inline SegmentWordCount getOffsetTo(const word* ptr);
inline WordCount getSize(); inline SegmentWordCount getSize();
inline kj::ArrayPtr<const word> getArray(); inline kj::ArrayPtr<const word> getArray();
...@@ -140,7 +140,7 @@ public: ...@@ -140,7 +140,7 @@ public:
private: private:
Arena* arena; Arena* arena;
SegmentId id; SegmentId id;
kj::ArrayPtr<const word> ptr; kj::ArrayPtr<const word> ptr; // size guaranteed to fit in SEGMENT_WORD_COUNT_BITS bits
ReadLimiter* readLimiter; ReadLimiter* readLimiter;
KJ_DISALLOW_COPY(SegmentReader); KJ_DISALLOW_COPY(SegmentReader);
...@@ -150,19 +150,19 @@ private: ...@@ -150,19 +150,19 @@ private:
class SegmentBuilder: public SegmentReader { class SegmentBuilder: public SegmentReader {
public: public:
inline SegmentBuilder(BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, inline SegmentBuilder(BuilderArena* arena, SegmentId id, word* ptr, SegmentWordCount size,
ReadLimiter* readLimiter, size_t wordsUsed = 0); ReadLimiter* readLimiter, SegmentWordCount wordsUsed = ZERO * WORDS);
inline SegmentBuilder(BuilderArena* arena, SegmentId id, kj::ArrayPtr<const word> ptr, inline SegmentBuilder(BuilderArena* arena, SegmentId id, const word* ptr, SegmentWordCount size,
ReadLimiter* readLimiter); ReadLimiter* readLimiter);
inline SegmentBuilder(BuilderArena* arena, SegmentId id, decltype(nullptr), inline SegmentBuilder(BuilderArena* arena, SegmentId id, decltype(nullptr),
ReadLimiter* readLimiter); ReadLimiter* readLimiter);
KJ_ALWAYS_INLINE(word* allocate(WordCount amount)); KJ_ALWAYS_INLINE(word* allocate(SegmentWordCount amount));
KJ_ALWAYS_INLINE(void checkWritable()); KJ_ALWAYS_INLINE(void checkWritable());
// Throw an exception if the segment is read-only (meaning it is a reference to external data). // Throw an exception if the segment is read-only (meaning it is a reference to external data).
KJ_ALWAYS_INLINE(word* getPtrUnchecked(WordCount offset)); KJ_ALWAYS_INLINE(word* getPtrUnchecked(SegmentWordCount offset));
// Get a writable pointer into the segment. Throws an exception if the segment is read-only (i.e. // Get a writable pointer into the segment. Throws an exception if the segment is read-only (i.e.
// a reference to external immutable data). // a reference to external immutable data).
...@@ -210,7 +210,7 @@ public: ...@@ -210,7 +210,7 @@ public:
class ReaderArena final: public Arena { class ReaderArena final: public Arena {
public: public:
ReaderArena(MessageReader* message); explicit ReaderArena(MessageReader* message);
~ReaderArena() noexcept(false); ~ReaderArena() noexcept(false);
KJ_DISALLOW_COPY(ReaderArena); KJ_DISALLOW_COPY(ReaderArena);
...@@ -234,6 +234,9 @@ private: ...@@ -234,6 +234,9 @@ private:
// TODO(perf): Thread-local thing instead? Some kind of lockless map? Or do sharing of data // TODO(perf): Thread-local thing instead? Some kind of lockless map? Or do sharing of data
// in a different way, where you have to construct a new MessageReader in each thread (but // in a different way, where you have to construct a new MessageReader in each thread (but
// possibly backed by the same data)? // possibly backed by the same data)?
ReaderArena(MessageReader* message, kj::ArrayPtr<const word> firstSegment);
ReaderArena(MessageReader* message, const word* firstSegment, SegmentWordCount firstSegmentSize);
}; };
class BuilderArena final: public Arena { class BuilderArena final: public Arena {
...@@ -277,7 +280,7 @@ public: ...@@ -277,7 +280,7 @@ public:
word* words; word* words;
}; };
AllocateResult allocate(WordCount amount); AllocateResult allocate(SegmentWordCount amount);
// Find a segment with at least the given amount of space available and allocate the space. // Find a segment with at least the given amount of space available and allocate the space.
// Note that allocating directly from a particular segment is much faster, but allocating from // Note that allocating directly from a particular segment is much faster, but allocating from
// the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific // the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific
...@@ -339,34 +342,36 @@ private: ...@@ -339,34 +342,36 @@ private:
inline ReadLimiter::ReadLimiter() inline ReadLimiter::ReadLimiter()
: limit(kj::maxValue) {} : limit(kj::maxValue) {}
inline ReadLimiter::ReadLimiter(WordCount64 limit): limit(limit / WORDS) {} inline ReadLimiter::ReadLimiter(WordCount64 limit): limit(unguard(limit / WORDS)) {}
inline void ReadLimiter::reset(WordCount64 limit) { this->limit = limit / WORDS; } inline void ReadLimiter::reset(WordCount64 limit) { this->limit = unguard(limit / WORDS); }
inline bool ReadLimiter::canRead(WordCount amount, Arena* arena) { inline bool ReadLimiter::canRead(WordCount64 amount, Arena* arena) {
// Be careful not to store an underflowed value into `limit`, even if multiple threads are // Be careful not to store an underflowed value into `limit`, even if multiple threads are
// decrementing it. // decrementing it.
uint64_t current = limit; uint64_t current = limit;
if (KJ_UNLIKELY(amount / WORDS > current)) { if (KJ_UNLIKELY(unguard(amount / WORDS) > current)) {
arena->reportReadLimitReached(); arena->reportReadLimitReached();
return false; return false;
} else { } else {
limit = current - amount / WORDS; limit = current - unguard(amount / WORDS);
return true; return true;
} }
} }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<const word> ptr, inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, const word* ptr,
ReadLimiter* readLimiter) SegmentWordCount size, ReadLimiter* readLimiter)
: arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {} : arena(arena), id(id), ptr(kj::arrayPtr(ptr, unguard(size / WORDS))),
readLimiter(readLimiter) {}
inline bool SegmentReader::containsInterval(const void* from, const void* to) { inline bool SegmentReader::containsInterval(const void* from, const void* to) {
return from >= this->ptr.begin() && to <= this->ptr.end() && from <= to && return from >= this->ptr.begin() && to <= this->ptr.end() && from <= to &&
readLimiter->canRead( readLimiter->canRead(
intervalLength(reinterpret_cast<const byte*>(from), intervalLength(reinterpret_cast<const byte*>(from),
reinterpret_cast<const byte*>(to)) / BYTES_PER_WORD, reinterpret_cast<const byte*>(to))
/ BYTES_PER_WORD,
arena); arena);
} }
...@@ -377,31 +382,36 @@ inline bool SegmentReader::amplifiedRead(WordCount virtualAmount) { ...@@ -377,31 +382,36 @@ inline bool SegmentReader::amplifiedRead(WordCount virtualAmount) {
inline Arena* SegmentReader::getArena() { return arena; } inline Arena* SegmentReader::getArena() { return arena; }
inline SegmentId SegmentReader::getSegmentId() { return id; } inline SegmentId SegmentReader::getSegmentId() { return id; }
inline const word* SegmentReader::getStartPtr() { return ptr.begin(); } inline const word* SegmentReader::getStartPtr() { return ptr.begin(); }
inline WordCount SegmentReader::getOffsetTo(const word* ptr) { inline SegmentWordCount SegmentReader::getOffsetTo(const word* ptr) {
return intervalLength(this->ptr.begin(), ptr); KJ_IREQUIRE(this->ptr.begin() <= ptr && ptr < this->ptr.end());
return assumeBits<SEGMENT_WORD_COUNT_BITS>(intervalLength(this->ptr.begin(), ptr));
}
inline SegmentWordCount SegmentReader::getSize() {
return assumeBits<SEGMENT_WORD_COUNT_BITS>(ptr.size()) * WORDS;
} }
inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; } inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); } inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder( inline SegmentBuilder::SegmentBuilder(
BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, ReadLimiter* readLimiter, BuilderArena* arena, SegmentId id, word* ptr, SegmentWordCount size,
size_t wordsUsed) ReadLimiter* readLimiter, SegmentWordCount wordsUsed)
: SegmentReader(arena, id, ptr, readLimiter), pos(ptr.begin() + wordsUsed), readOnly(false) {} : SegmentReader(arena, id, ptr, size, readLimiter),
pos(ptr + wordsUsed), readOnly(false) {}
inline SegmentBuilder::SegmentBuilder( inline SegmentBuilder::SegmentBuilder(
BuilderArena* arena, SegmentId id, kj::ArrayPtr<const word> ptr, ReadLimiter* readLimiter) BuilderArena* arena, SegmentId id, const word* ptr, SegmentWordCount size,
: SegmentReader(arena, id, ptr, readLimiter), ReadLimiter* readLimiter)
: SegmentReader(arena, id, ptr, size, readLimiter),
// const_cast is safe here because the member won't ever be dereferenced because it appears // const_cast is safe here because the member won't ever be dereferenced because it appears
// to point to the end of the segment anyway. // to point to the end of the segment anyway.
pos(const_cast<word*>(ptr.end())), pos(const_cast<word*>(ptr + size)), readOnly(true) {}
readOnly(true) {}
inline SegmentBuilder::SegmentBuilder(BuilderArena* arena, SegmentId id, decltype(nullptr), inline SegmentBuilder::SegmentBuilder(BuilderArena* arena, SegmentId id, decltype(nullptr),
ReadLimiter* readLimiter) ReadLimiter* readLimiter)
: SegmentReader(arena, id, nullptr, readLimiter), pos(nullptr), readOnly(false) {} : SegmentReader(arena, id, nullptr, ZERO * WORDS, readLimiter),
pos(nullptr), readOnly(false) {}
inline word* SegmentBuilder::allocate(WordCount amount) { inline word* SegmentBuilder::allocate(SegmentWordCount amount) {
if (intervalLength(pos, ptr.end()) < amount) { if (intervalLength(pos, ptr.end()) < amount) {
// Not enough space in the segment for this allocation. // Not enough space in the segment for this allocation.
return nullptr; return nullptr;
...@@ -417,7 +427,7 @@ inline void SegmentBuilder::checkWritable() { ...@@ -417,7 +427,7 @@ inline void SegmentBuilder::checkWritable() {
if (KJ_UNLIKELY(readOnly)) throwNotWritable(); if (KJ_UNLIKELY(readOnly)) throwNotWritable();
} }
inline word* SegmentBuilder::getPtrUnchecked(WordCount offset) { inline word* SegmentBuilder::getPtrUnchecked(SegmentWordCount offset) {
return const_cast<word*>(ptr.begin() + offset); return const_cast<word*>(ptr.begin() + offset);
} }
...@@ -432,7 +442,7 @@ inline kj::ArrayPtr<const word> SegmentBuilder::currentlyAllocated() { ...@@ -432,7 +442,7 @@ inline kj::ArrayPtr<const word> SegmentBuilder::currentlyAllocated() {
} }
inline void SegmentBuilder::reset() { inline void SegmentBuilder::reset() {
word* start = getPtrUnchecked(0 * WORDS); word* start = getPtrUnchecked(ZERO * WORDS);
memset(start, 0, (pos - start) * sizeof(word)); memset(start, 0, (pos - start) * sizeof(word));
pos = start; pos = start;
} }
......
...@@ -653,7 +653,8 @@ struct List<T, Kind::INTERFACE> { ...@@ -653,7 +653,8 @@ struct List<T, Kind::INTERFACE> {
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return reader.size() / ELEMENTS; }
inline typename T::Client operator[](uint index) const { inline typename T::Client operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename T::Client(reader.getPointerElement(index * ELEMENTS).getCapability()); return typename T::Client(reader.getPointerElement(
guarded(index) * ELEMENTS).getCapability());
} }
typedef _::IndexingIterator<const Reader, typename T::Client> Iterator; typedef _::IndexingIterator<const Reader, typename T::Client> Iterator;
...@@ -685,19 +686,20 @@ struct List<T, Kind::INTERFACE> { ...@@ -685,19 +686,20 @@ struct List<T, Kind::INTERFACE> {
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return builder.size() / ELEMENTS; }
inline typename T::Client operator[](uint index) { inline typename T::Client operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename T::Client(builder.getPointerElement(index * ELEMENTS).getCapability()); return typename T::Client(builder.getPointerElement(
guarded(index) * ELEMENTS).getCapability());
} }
inline void set(uint index, typename T::Client value) { inline void set(uint index, typename T::Client value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).setCapability(kj::mv(value.hook)); builder.getPointerElement(guarded(index) * ELEMENTS).setCapability(kj::mv(value.hook));
} }
inline void adopt(uint index, Orphan<T>&& value) { inline void adopt(uint index, Orphan<T>&& value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).adopt(kj::mv(value)); builder.getPointerElement(guarded(index) * ELEMENTS).adopt(kj::mv(value));
} }
inline Orphan<T> disown(uint index) { inline Orphan<T> disown(uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return Orphan<T>(builder.getPointerElement(index * ELEMENTS).disown()); return Orphan<T>(builder.getPointerElement(guarded(index) * ELEMENTS).disown());
} }
typedef _::IndexingIterator<Builder, typename T::Client> Iterator; typedef _::IndexingIterator<Builder, typename T::Client> Iterator;
...@@ -713,7 +715,7 @@ struct List<T, Kind::INTERFACE> { ...@@ -713,7 +715,7 @@ struct List<T, Kind::INTERFACE> {
private: private:
inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) { inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) {
return builder.initList(ElementSize::POINTER, size * ELEMENTS); return builder.initList(ElementSize::POINTER, guarded(size) * ELEMENTS);
} }
inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) { inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) {
return builder.getList(ElementSize::POINTER, defaultValue); return builder.getList(ElementSize::POINTER, defaultValue);
......
...@@ -345,138 +345,176 @@ static_assert(sizeof(word) == 8, "uint64_t is not 8 bytes?"); ...@@ -345,138 +345,176 @@ static_assert(sizeof(word) == 8, "uint64_t is not 8 bytes?");
namespace _ { class BitLabel; class ElementLabel; struct WirePointer; } namespace _ { class BitLabel; class ElementLabel; struct WirePointer; }
typedef kj::Quantity<uint, _::BitLabel> BitCount; template <uint width, typename T = uint>
typedef kj::Quantity<uint8_t, _::BitLabel> BitCount8; using BitCountN = kj::Quantity<kj::Guarded<kj::maxValueForBits<width>(), T>, _::BitLabel>;
typedef kj::Quantity<uint16_t, _::BitLabel> BitCount16; template <uint width, typename T = uint>
typedef kj::Quantity<uint32_t, _::BitLabel> BitCount32; using ByteCountN = kj::Quantity<kj::Guarded<kj::maxValueForBits<width>(), T>, byte>;
typedef kj::Quantity<uint64_t, _::BitLabel> BitCount64; template <uint width, typename T = uint>
using WordCountN = kj::Quantity<kj::Guarded<kj::maxValueForBits<width>(), T>, word>;
typedef kj::Quantity<uint, byte> ByteCount; template <uint width, typename T = uint>
typedef kj::Quantity<uint8_t, byte> ByteCount8; using ElementCountN = kj::Quantity<kj::Guarded<kj::maxValueForBits<width>(), T>, _::ElementLabel>;
typedef kj::Quantity<uint16_t, byte> ByteCount16; template <uint width, typename T = uint>
typedef kj::Quantity<uint32_t, byte> ByteCount32; using WirePointerCountN = kj::Quantity<kj::Guarded<kj::maxValueForBits<width>(), T>, _::WirePointer>;
typedef kj::Quantity<uint64_t, byte> ByteCount64;
typedef BitCountN<8, uint8_t> BitCount8;
typedef kj::Quantity<uint, word> WordCount; typedef BitCountN<16, uint16_t> BitCount16;
typedef kj::Quantity<uint8_t, word> WordCount8; typedef BitCountN<32, uint32_t> BitCount32;
typedef kj::Quantity<uint16_t, word> WordCount16; typedef BitCountN<64, uint64_t> BitCount64;
typedef kj::Quantity<uint32_t, word> WordCount32; typedef BitCountN<sizeof(uint) * 8, uint> BitCount;
typedef kj::Quantity<uint64_t, word> WordCount64;
typedef ByteCountN<8, uint8_t> ByteCount8;
typedef kj::Quantity<uint, _::ElementLabel> ElementCount; typedef ByteCountN<16, uint16_t> ByteCount16;
typedef kj::Quantity<uint8_t, _::ElementLabel> ElementCount8; typedef ByteCountN<32, uint32_t> ByteCount32;
typedef kj::Quantity<uint16_t, _::ElementLabel> ElementCount16; typedef ByteCountN<64, uint64_t> ByteCount64;
typedef kj::Quantity<uint32_t, _::ElementLabel> ElementCount32; typedef ByteCountN<sizeof(uint) * 8, uint> ByteCount;
typedef kj::Quantity<uint64_t, _::ElementLabel> ElementCount64;
typedef WordCountN<8, uint8_t> WordCount8;
typedef kj::Quantity<uint, _::WirePointer> WirePointerCount; typedef WordCountN<16, uint16_t> WordCount16;
typedef kj::Quantity<uint8_t, _::WirePointer> WirePointerCount8; typedef WordCountN<32, uint32_t> WordCount32;
typedef kj::Quantity<uint16_t, _::WirePointer> WirePointerCount16; typedef WordCountN<64, uint64_t> WordCount64;
typedef kj::Quantity<uint32_t, _::WirePointer> WirePointerCount32; typedef WordCountN<sizeof(uint) * 8, uint> WordCount;
typedef kj::Quantity<uint64_t, _::WirePointer> WirePointerCount64;
typedef ElementCountN<8, uint8_t> ElementCount8;
typedef ElementCountN<16, uint16_t> ElementCount16;
typedef ElementCountN<32, uint32_t> ElementCount32;
typedef ElementCountN<64, uint64_t> ElementCount64;
typedef ElementCountN<sizeof(uint) * 8, uint> ElementCount;
typedef WirePointerCountN<8, uint8_t> WirePointerCount8;
typedef WirePointerCountN<16, uint16_t> WirePointerCount16;
typedef WirePointerCountN<32, uint32_t> WirePointerCount32;
typedef WirePointerCountN<64, uint64_t> WirePointerCount64;
typedef WirePointerCountN<sizeof(uint) * 8, uint> WirePointerCount;
template <uint width>
using BitsPerElementN = decltype(BitCountN<width>() / ElementCountN<width>());
template <uint width>
using BytesPerElementN = decltype(ByteCountN<width>() / ElementCountN<width>());
template <uint width>
using WordsPerElementN = decltype(WordCountN<width>() / ElementCountN<width>());
template <uint width>
using PointersPerElementN = decltype(WirePointerCountN<width>() / ElementCountN<width>());
using kj::guarded;
using kj::unguard;
using kj::unguardAs;
using kj::unguardMax;
using kj::unguardMaxBits;
using kj::assertMax;
using kj::assertMaxBits;
using kj::upgradeGuard;
using kj::ThrowOverflow;
using kj::assumeBits;
using kj::subtractChecked;
template <typename T, typename U> template <typename T, typename U>
inline constexpr U* operator+(U* ptr, kj::Quantity<T, U> offset) { inline constexpr U* operator+(U* ptr, kj::Quantity<T, U> offset) {
return ptr + offset / kj::unit<kj::Quantity<T, U>>(); return ptr + unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr const U* operator+(const U* ptr, kj::Quantity<T, U> offset) { inline constexpr const U* operator+(const U* ptr, kj::Quantity<T, U> offset) {
return ptr + offset / kj::unit<kj::Quantity<T, U>>(); return ptr + unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr U* operator+=(U*& ptr, kj::Quantity<T, U> offset) { inline constexpr U* operator+=(U*& ptr, kj::Quantity<T, U> offset) {
return ptr = ptr + offset / kj::unit<kj::Quantity<T, U>>(); return ptr = ptr + unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr const U* operator+=(const U*& ptr, kj::Quantity<T, U> offset) { inline constexpr const U* operator+=(const U*& ptr, kj::Quantity<T, U> offset) {
return ptr = ptr + offset / kj::unit<kj::Quantity<T, U>>(); return ptr = ptr + unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr U* operator-(U* ptr, kj::Quantity<T, U> offset) { inline constexpr U* operator-(U* ptr, kj::Quantity<T, U> offset) {
return ptr - offset / kj::unit<kj::Quantity<T, U>>(); return ptr - unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr const U* operator-(const U* ptr, kj::Quantity<T, U> offset) { inline constexpr const U* operator-(const U* ptr, kj::Quantity<T, U> offset) {
return ptr - offset / kj::unit<kj::Quantity<T, U>>(); return ptr - unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr U* operator-=(U*& ptr, kj::Quantity<T, U> offset) { inline constexpr U* operator-=(U*& ptr, kj::Quantity<T, U> offset) {
return ptr = ptr - offset / kj::unit<kj::Quantity<T, U>>(); return ptr = ptr - unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
template <typename T, typename U> template <typename T, typename U>
inline constexpr const U* operator-=(const U*& ptr, kj::Quantity<T, U> offset) { inline constexpr const U* operator-=(const U*& ptr, kj::Quantity<T, U> offset) {
return ptr = ptr - offset / kj::unit<kj::Quantity<T, U>>(); return ptr = ptr - unguard(offset / kj::unit<kj::Quantity<T, U>>());
} }
#else constexpr auto BITS = kj::unit<BitCountN<1>>();
constexpr auto BYTES = kj::unit<ByteCountN<1>>();
typedef uint BitCount; constexpr auto WORDS = kj::unit<WordCountN<1>>();
typedef uint8_t BitCount8; constexpr auto ELEMENTS = kj::unit<ElementCountN<1>>();
typedef uint16_t BitCount16; constexpr auto POINTERS = kj::unit<WirePointerCountN<1>>();
typedef uint32_t BitCount32;
typedef uint64_t BitCount64;
typedef uint ByteCount;
typedef uint8_t ByteCount8;
typedef uint16_t ByteCount16;
typedef uint32_t ByteCount32;
typedef uint64_t ByteCount64;
typedef uint WordCount;
typedef uint8_t WordCount8;
typedef uint16_t WordCount16;
typedef uint32_t WordCount32;
typedef uint64_t WordCount64;
typedef uint ElementCount;
typedef uint8_t ElementCount8;
typedef uint16_t ElementCount16;
typedef uint32_t ElementCount32;
typedef uint64_t ElementCount64;
typedef uint WirePointerCount;
typedef uint8_t WirePointerCount8;
typedef uint16_t WirePointerCount16;
typedef uint32_t WirePointerCount32;
typedef uint64_t WirePointerCount64;
#endif
constexpr BitCount BITS = kj::unit<BitCount>(); constexpr auto ZERO = kj::guarded<0>();
constexpr ByteCount BYTES = kj::unit<ByteCount>(); constexpr auto ONE = kj::guarded<1>();
constexpr WordCount WORDS = kj::unit<WordCount>();
constexpr ElementCount ELEMENTS = kj::unit<ElementCount>();
constexpr WirePointerCount POINTERS = kj::unit<WirePointerCount>();
// GCC 4.7 actually gives unused warnings on these constants in opt mode... // GCC 4.7 actually gives unused warnings on these constants in opt mode...
constexpr auto BITS_PER_BYTE KJ_UNUSED = 8 * BITS / BYTES; constexpr auto BITS_PER_BYTE KJ_UNUSED = guarded<8>() * BITS / BYTES;
constexpr auto BITS_PER_WORD KJ_UNUSED = 64 * BITS / WORDS; constexpr auto BITS_PER_WORD KJ_UNUSED = guarded<64>() * BITS / WORDS;
constexpr auto BYTES_PER_WORD KJ_UNUSED = 8 * BYTES / WORDS; constexpr auto BYTES_PER_WORD KJ_UNUSED = guarded<8>() * BYTES / WORDS;
constexpr auto BITS_PER_POINTER KJ_UNUSED = 64 * BITS / POINTERS; constexpr auto BITS_PER_POINTER KJ_UNUSED = guarded<64>() * BITS / POINTERS;
constexpr auto BYTES_PER_POINTER KJ_UNUSED = 8 * BYTES / POINTERS; constexpr auto BYTES_PER_POINTER KJ_UNUSED = guarded<8>() * BYTES / POINTERS;
constexpr auto WORDS_PER_POINTER KJ_UNUSED = 1 * WORDS / POINTERS; constexpr auto WORDS_PER_POINTER KJ_UNUSED = ONE * WORDS / POINTERS;
constexpr WordCount POINTER_SIZE_IN_WORDS = 1 * POINTERS * WORDS_PER_POINTER; constexpr auto POINTER_SIZE_IN_WORDS = ONE * POINTERS * WORDS_PER_POINTER;
constexpr uint SEGMENT_WORD_COUNT_BITS = 29; // Number of words in a segment.
constexpr uint LIST_ELEMENT_COUNT_BITS = 29; // Number of elements in a list.
constexpr uint STRUCT_DATA_WORD_COUNT_BITS = 16; // Number of words in a Struct data section.
constexpr uint STRUCT_POINTER_COUNT_BITS = 16; // Number of pointers in a Struct pointer section.
constexpr uint BLOB_SIZE_BITS = 29; // Number of bytes in a blob.
typedef WordCountN<SEGMENT_WORD_COUNT_BITS> SegmentWordCount;
typedef ElementCountN<LIST_ELEMENT_COUNT_BITS> ListElementCount;
typedef WordCountN<STRUCT_DATA_WORD_COUNT_BITS, uint16_t> StructDataWordCount;
typedef WirePointerCountN<STRUCT_POINTER_COUNT_BITS, uint16_t> StructPointerCount;
typedef ByteCountN<BLOB_SIZE_BITS> BlobSize;
constexpr auto MAX_SEGMENT_WORDS =
guarded<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>()>() * WORDS;
constexpr auto MAX_LIST_ELEMENTS =
guarded<kj::maxValueForBits<LIST_ELEMENT_COUNT_BITS>()>() * ELEMENTS;
constexpr auto MAX_STUCT_DATA_WORDS =
guarded<kj::maxValueForBits<STRUCT_DATA_WORD_COUNT_BITS>()>() * WORDS;
constexpr auto MAX_STRUCT_POINTER_COUNT =
guarded<kj::maxValueForBits<STRUCT_POINTER_COUNT_BITS>()>() *POINTERS;
using StructDataBitCount = decltype(WordCountN<STRUCT_POINTER_COUNT_BITS>() * BITS_PER_WORD);
using StructDataElementOffset = decltype(StructDataBitCount() * (ONE * ELEMENTS / BITS));
// Number of bits in a Struct data segment (should come out to BitCountN<22>).
constexpr uint MAX_TEXT_SIZE = kj::maxValueForBits<BLOB_SIZE_BITS>() - 1;
typedef kj::Quantity<kj::Guarded<MAX_TEXT_SIZE, uint>, byte> TextSize;
// Not including NUL terminator.
template <typename T> template <typename T>
inline KJ_CONSTEXPR() decltype(BYTES / ELEMENTS) bytesPerElement() { inline KJ_CONSTEXPR() decltype(BYTES / ELEMENTS) bytesPerElement() {
return sizeof(T) * BYTES / ELEMENTS; return guarded<sizeof(T)>() * BYTES / ELEMENTS;
} }
template <typename T> template <typename T>
inline KJ_CONSTEXPR() decltype(BITS / ELEMENTS) bitsPerElement() { inline KJ_CONSTEXPR() decltype(BITS / ELEMENTS) bitsPerElement() {
return sizeof(T) * 8 * BITS / ELEMENTS; return guarded<sizeof(T)>() * 8 * BITS / ELEMENTS;
} }
inline constexpr ByteCount intervalLength(const byte* a, const byte* b) { inline constexpr ByteCountN<sizeof(size_t) * 8, size_t>
return uint(b - a) * BYTES; intervalLength(const byte* a, const byte* b) {
return kj::guarded(b - a) * BYTES;
} }
inline constexpr WordCount intervalLength(const word* a, const word* b) { inline constexpr WordCountN<sizeof(size_t) * 8, size_t>
return uint(b - a) * WORDS; intervalLength(const word* a, const word* b) {
return kj::guarded(b - a) * WORDS;
} }
#else
#error TODO
#endif
} // namespace capnp } // namespace capnp
#endif // CAPNP_COMMON_H_ #endif // CAPNP_COMMON_H_
...@@ -59,6 +59,8 @@ namespace _ { // private ...@@ -59,6 +59,8 @@ namespace _ { // private
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
#define G(n) guarded<n>()
// ======================================================================================= // =======================================================================================
struct WirePointer { struct WirePointer {
...@@ -159,26 +161,28 @@ struct WirePointer { ...@@ -159,26 +161,28 @@ struct WirePointer {
offsetAndKind.set(kind | 0xfffffffc); offsetAndKind.set(kind | 0xfffffffc);
} }
KJ_ALWAYS_INLINE(ElementCount inlineCompositeListElementCount() const) { KJ_ALWAYS_INLINE(ListElementCount inlineCompositeListElementCount() const) {
return (offsetAndKind.get() >> 2) * ELEMENTS; return ((guarded(offsetAndKind.get()) >> G(2))
& G(kj::maxValueForBits<LIST_ELEMENT_COUNT_BITS>())) * ELEMENTS;
} }
KJ_ALWAYS_INLINE(void setKindAndInlineCompositeListElementCount( KJ_ALWAYS_INLINE(void setKindAndInlineCompositeListElementCount(
Kind kind, ElementCount elementCount)) { Kind kind, ListElementCount elementCount)) {
offsetAndKind.set(((elementCount / ELEMENTS) << 2) | kind); offsetAndKind.set(unguardAs<uint32_t>((elementCount / ELEMENTS) << G(2)) | kind);
} }
KJ_ALWAYS_INLINE(WordCount farPositionInSegment() const) { KJ_ALWAYS_INLINE(SegmentWordCount farPositionInSegment() const) {
KJ_DREQUIRE(kind() == FAR, KJ_DREQUIRE(kind() == FAR,
"positionInSegment() should only be called on FAR pointers."); "positionInSegment() should only be called on FAR pointers.");
return (offsetAndKind.get() >> 3) * WORDS; return (guarded(offsetAndKind.get()) >> G(3)) * WORDS;
} }
KJ_ALWAYS_INLINE(bool isDoubleFar() const) { KJ_ALWAYS_INLINE(bool isDoubleFar() const) {
KJ_DREQUIRE(kind() == FAR, KJ_DREQUIRE(kind() == FAR,
"isDoubleFar() should only be called on FAR pointers."); "isDoubleFar() should only be called on FAR pointers.");
return (offsetAndKind.get() >> 2) & 1; return unguard((guarded(offsetAndKind.get()) >> G(2)) & G(1));
} }
KJ_ALWAYS_INLINE(void setFar(bool isDoubleFar, WordCount pos)) { KJ_ALWAYS_INLINE(void setFar(bool isDoubleFar, WordCountN<29> pos)) {
offsetAndKind.set(((pos / WORDS) << 3) | (static_cast<uint32_t>(isDoubleFar) << 2) | offsetAndKind.set(unguardAs<uint32_t>((pos / WORDS) << G(3)) |
(static_cast<uint32_t>(isDoubleFar) << 2) |
static_cast<uint32_t>(Kind::FAR)); static_cast<uint32_t>(Kind::FAR));
} }
KJ_ALWAYS_INLINE(void setCap(uint index)) { KJ_ALWAYS_INLINE(void setCap(uint index)) {
...@@ -196,11 +200,11 @@ struct WirePointer { ...@@ -196,11 +200,11 @@ struct WirePointer {
WireValue<WordCount16> dataSize; WireValue<WordCount16> dataSize;
WireValue<WirePointerCount16> ptrCount; WireValue<WirePointerCount16> ptrCount;
inline WordCount wordSize() const { inline WordCountN<17> wordSize() const {
return dataSize.get() + ptrCount.get() * WORDS_PER_POINTER; return upgradeGuard<uint32_t>(dataSize.get()) + ptrCount.get() * WORDS_PER_POINTER;
} }
KJ_ALWAYS_INLINE(void set(WordCount ds, WirePointerCount rc)) { KJ_ALWAYS_INLINE(void set(WordCount16 ds, WirePointerCount16 rc)) {
dataSize.set(ds); dataSize.set(ds);
ptrCount.set(rc); ptrCount.set(rc);
} }
...@@ -216,21 +220,20 @@ struct WirePointer { ...@@ -216,21 +220,20 @@ struct WirePointer {
KJ_ALWAYS_INLINE(ElementSize elementSize() const) { KJ_ALWAYS_INLINE(ElementSize elementSize() const) {
return static_cast<ElementSize>(elementSizeAndCount.get() & 7); return static_cast<ElementSize>(elementSizeAndCount.get() & 7);
} }
KJ_ALWAYS_INLINE(ElementCount elementCount() const) { KJ_ALWAYS_INLINE(ElementCountN<29> elementCount() const) {
return (elementSizeAndCount.get() >> 3) * ELEMENTS; return (guarded(elementSizeAndCount.get()) >> G(3)) * ELEMENTS;
} }
KJ_ALWAYS_INLINE(WordCount inlineCompositeWordCount() const) { KJ_ALWAYS_INLINE(WordCountN<29> inlineCompositeWordCount() const) {
return elementCount() * (1 * WORDS / ELEMENTS); return elementCount() * (ONE * WORDS / ELEMENTS);
} }
KJ_ALWAYS_INLINE(void set(ElementSize es, ElementCount ec)) { KJ_ALWAYS_INLINE(void set(ElementSize es, ElementCountN<29> ec)) {
KJ_DREQUIRE(ec < (1 << 29) * ELEMENTS, "Lists are limited to 2**29 elements."); elementSizeAndCount.set(unguardAs<uint32_t>((ec / ELEMENTS) << G(3)) |
elementSizeAndCount.set(((ec / ELEMENTS) << 3) | static_cast<int>(es)); static_cast<int>(es));
} }
KJ_ALWAYS_INLINE(void setInlineComposite(WordCount wc)) { KJ_ALWAYS_INLINE(void setInlineComposite(WordCountN<29> wc)) {
KJ_DREQUIRE(wc < (1 << 29) * WORDS, "Inline composite lists are limited to 2**29 words."); elementSizeAndCount.set(unguardAs<uint32_t>((wc / WORDS) << G(3)) |
elementSizeAndCount.set(((wc / WORDS) << 3) |
static_cast<int>(ElementSize::INLINE_COMPOSITE)); static_cast<int>(ElementSize::INLINE_COMPOSITE));
} }
}; };
...@@ -269,17 +272,19 @@ struct WirePointer { ...@@ -269,17 +272,19 @@ struct WirePointer {
}; };
static_assert(sizeof(WirePointer) == sizeof(word), static_assert(sizeof(WirePointer) == sizeof(word),
"capnp::WirePointer is not exactly one word. This will probably break everything."); "capnp::WirePointer is not exactly one word. This will probably break everything.");
static_assert(POINTERS * WORDS_PER_POINTER * BYTES_PER_WORD / BYTES == sizeof(WirePointer), static_assert(unguardAs<size_t>(POINTERS * WORDS_PER_POINTER * BYTES_PER_WORD / BYTES) ==
sizeof(WirePointer),
"WORDS_PER_POINTER is wrong."); "WORDS_PER_POINTER is wrong.");
static_assert(POINTERS * BYTES_PER_POINTER / BYTES == sizeof(WirePointer), static_assert(unguardAs<size_t>(POINTERS * BYTES_PER_POINTER / BYTES) == sizeof(WirePointer),
"BYTES_PER_POINTER is wrong."); "BYTES_PER_POINTER is wrong.");
static_assert(POINTERS * BITS_PER_POINTER / BITS_PER_BYTE / BYTES == sizeof(WirePointer), static_assert(unguardAs<size_t>(POINTERS * BITS_PER_POINTER / BITS_PER_BYTE / BYTES) ==
sizeof(WirePointer),
"BITS_PER_POINTER is wrong."); "BITS_PER_POINTER is wrong.");
namespace { namespace {
static const union { static const union {
AlignedData<POINTER_SIZE_IN_WORDS / WORDS> word; AlignedData<unguard(POINTER_SIZE_IN_WORDS / WORDS)> word;
WirePointer pointer; WirePointer pointer;
} zero = {{{0}}}; } zero = {{{0}}};
...@@ -298,22 +303,72 @@ struct SegmentAnd { ...@@ -298,22 +303,72 @@ struct SegmentAnd {
} // namespace } // namespace
struct WireHelpers { struct WireHelpers {
#if CAPNP_DEBUG_TYPES
template <uint64_t maxN, typename T>
static KJ_ALWAYS_INLINE(
kj::Quantity<kj::Guarded<(maxN + 7) / 8, T>, word> roundBytesUpToWords(
kj::Quantity<kj::Guarded<maxN, T>, byte> bytes)) {
static_assert(sizeof(word) == 8, "This code assumes 64-bit words.");
return (bytes + G(7) * BYTES) / BYTES_PER_WORD;
}
template <uint64_t maxN, typename T>
static KJ_ALWAYS_INLINE(
kj::Quantity<kj::Guarded<(maxN + 7) / 8, T>, byte> roundBitsUpToBytes(
kj::Quantity<kj::Guarded<maxN, T>, BitLabel> bits)) {
return (bits + G(7) * BITS) / BITS_PER_BYTE;
}
template <uint64_t maxN, typename T>
static KJ_ALWAYS_INLINE(
kj::Quantity<kj::Guarded<(maxN + 63) / 64, T>, word> roundBitsUpToWords(
kj::Quantity<kj::Guarded<maxN, T>, BitLabel> bits)) {
static_assert(sizeof(word) == 8, "This code assumes 64-bit words.");
return (bits + G(63) * BITS) / BITS_PER_WORD;
}
#else
static KJ_ALWAYS_INLINE(WordCount roundBytesUpToWords(ByteCount bytes)) { static KJ_ALWAYS_INLINE(WordCount roundBytesUpToWords(ByteCount bytes)) {
static_assert(sizeof(word) == 8, "This code assumes 64-bit words."); static_assert(sizeof(word) == 8, "This code assumes 64-bit words.");
return (bytes + 7 * BYTES) / BYTES_PER_WORD; return (bytes + G(7) * BYTES) / BYTES_PER_WORD;
} }
static KJ_ALWAYS_INLINE(ByteCount roundBitsUpToBytes(BitCount bits)) { static KJ_ALWAYS_INLINE(ByteCount roundBitsUpToBytes(BitCount bits)) {
return (bits + 7 * BITS) / BITS_PER_BYTE; return (bits + G(7) * BITS) / BITS_PER_BYTE;
} }
static KJ_ALWAYS_INLINE(WordCount64 roundBitsUpToWords(BitCount64 bits)) { static KJ_ALWAYS_INLINE(WordCount64 roundBitsUpToWords(BitCount64 bits)) {
static_assert(sizeof(word) == 8, "This code assumes 64-bit words."); static_assert(sizeof(word) == 8, "This code assumes 64-bit words.");
return (bits + 63 * BITS) / BITS_PER_WORD; return (bits + G(63) * BITS) / BITS_PER_WORD;
} }
static KJ_ALWAYS_INLINE(ByteCount64 roundBitsUpToBytes(BitCount64 bits)) { static KJ_ALWAYS_INLINE(ByteCount64 roundBitsUpToBytes(BitCount64 bits)) {
return (bits + 7 * BITS) / BITS_PER_BYTE; return (bits + G(7) * BITS) / BITS_PER_BYTE;
}
#endif
static KJ_ALWAYS_INLINE(void zeroMemory(byte* ptr, ByteCount32 count)) {
memset(ptr, 0, unguard(count / BYTES));
}
static KJ_ALWAYS_INLINE(void zeroMemory(word* ptr, WordCountN<29> count)) {
memset(ptr, 0, unguard(count * BYTES_PER_WORD / BYTES));
}
static KJ_ALWAYS_INLINE(void zeroMemory(WirePointer* ptr, WirePointerCountN<29> count)) {
memset(ptr, 0, unguard(count * BYTES_PER_POINTER / BYTES));
}
static KJ_ALWAYS_INLINE(void copyMemory(byte* to, const byte* from, ByteCount32 count)) {
memcpy(to, from, unguard(count / BYTES));
}
static KJ_ALWAYS_INLINE(void copyMemory(word* to, const word* from, WordCountN<29> count)) {
memcpy(to, from, unguard(count * BYTES_PER_WORD / BYTES));
}
static KJ_ALWAYS_INLINE(void copyMemory(WirePointer* to, const WirePointer* from,
WirePointerCountN<29> count)) {
memcpy(to, from, unguard(count * BYTES_PER_POINTER / BYTES));
} }
static KJ_ALWAYS_INLINE(bool boundsCheck( static KJ_ALWAYS_INLINE(bool boundsCheck(
...@@ -328,8 +383,8 @@ struct WireHelpers { ...@@ -328,8 +383,8 @@ struct WireHelpers {
} }
static KJ_ALWAYS_INLINE(word* allocate( static KJ_ALWAYS_INLINE(word* allocate(
WirePointer*& ref, SegmentBuilder*& segment, CapTableBuilder* capTable, WordCount amount, WirePointer*& ref, SegmentBuilder*& segment, CapTableBuilder* capTable,
WirePointer::Kind kind, BuilderArena* orphanArena)) { SegmentWordCount amount, WirePointer::Kind kind, BuilderArena* orphanArena)) {
// Allocate space in the message for a new object, creating far pointers if necessary. // Allocate space in the message for a new object, creating far pointers if necessary.
// //
// * `ref` starts out being a reference to the pointer which shall be assigned to point at the // * `ref` starts out being a reference to the pointer which shall be assigned to point at the
...@@ -353,7 +408,7 @@ struct WireHelpers { ...@@ -353,7 +408,7 @@ struct WireHelpers {
if (orphanArena == nullptr) { if (orphanArena == nullptr) {
if (!ref->isNull()) zeroObject(segment, capTable, ref); if (!ref->isNull()) zeroObject(segment, capTable, ref);
if (amount == 0 * WORDS && kind == WirePointer::STRUCT) { if (amount == ZERO * WORDS && kind == WirePointer::STRUCT) {
// Note that the check for kind == WirePointer::STRUCT will hopefully cause this whole // Note that the check for kind == WirePointer::STRUCT will hopefully cause this whole
// branch to be optimized away from all the call sites that are allocating non-structs. // branch to be optimized away from all the call sites that are allocating non-structs.
ref->setKindAndTargetForEmptyStruct(); ref->setKindAndTargetForEmptyStruct();
...@@ -368,7 +423,10 @@ struct WireHelpers { ...@@ -368,7 +423,10 @@ struct WireHelpers {
// space to act as the landing pad for a far pointer. // space to act as the landing pad for a far pointer.
WordCount amountPlusRef = amount + POINTER_SIZE_IN_WORDS; WordCount amountPlusRef = amount + POINTER_SIZE_IN_WORDS;
auto allocation = segment->getArena()->allocate(amountPlusRef); auto allocation = segment->getArena()->allocate(
assertMaxBits<SEGMENT_WORD_COUNT_BITS>(amountPlusRef, []() {
KJ_FAIL_REQUIRE("requested object size exceeds maximum segment size");
}));
segment = allocation.segment; segment = allocation.segment;
ptr = allocation.words; ptr = allocation.words;
...@@ -448,7 +506,7 @@ struct WireHelpers { ...@@ -448,7 +506,7 @@ struct WireHelpers {
// Find the landing pad and check that it is within bounds. // Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment(); const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS; WordCount padWords = guarded(1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + padWords), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") { "Message contains out-of-bounds far pointer.") {
return nullptr; return nullptr;
...@@ -503,10 +561,10 @@ struct WireHelpers { ...@@ -503,10 +561,10 @@ struct WireHelpers {
zeroObject(segment, capTable, zeroObject(segment, capTable,
pad + 1, segment->getPtrUnchecked(pad->farPositionInSegment())); pad + 1, segment->getPtrUnchecked(pad->farPositionInSegment()));
} }
memset(pad, 0, sizeof(WirePointer) * 2); zeroMemory(pad, G(2) * POINTERS);
} else { } else {
zeroObject(segment, capTable, pad); zeroObject(segment, capTable, pad);
memset(pad, 0, sizeof(WirePointer)); zeroMemory(pad, ONE * POINTERS);
} }
} }
break; break;
...@@ -534,11 +592,10 @@ struct WireHelpers { ...@@ -534,11 +592,10 @@ struct WireHelpers {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
WirePointer* pointerSection = WirePointer* pointerSection =
reinterpret_cast<WirePointer*>(ptr + tag->structRef.dataSize.get()); reinterpret_cast<WirePointer*>(ptr + tag->structRef.dataSize.get());
uint count = tag->structRef.ptrCount.get() / POINTERS; for (auto i: kj::zeroTo(tag->structRef.ptrCount.get())) {
for (uint i = 0; i < count; i++) {
zeroObject(segment, capTable, pointerSection + i); zeroObject(segment, capTable, pointerSection + i);
} }
memset(ptr, 0, tag->structRef.wordSize() * BYTES_PER_WORD / BYTES); zeroMemory(ptr, tag->structRef.wordSize());
break; break;
} }
case WirePointer::LIST: { case WirePointer::LIST: {
...@@ -550,18 +607,19 @@ struct WireHelpers { ...@@ -550,18 +607,19 @@ struct WireHelpers {
case ElementSize::BYTE: case ElementSize::BYTE:
case ElementSize::TWO_BYTES: case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES: case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES: case ElementSize::EIGHT_BYTES: {
memset(ptr, 0, zeroMemory(ptr, roundBitsUpToWords(
roundBitsUpToWords(ElementCount64(tag->listRef.elementCount()) * upgradeGuard<uint64_t>(tag->listRef.elementCount()) *
dataBitsPerElement(tag->listRef.elementSize())) dataBitsPerElement(tag->listRef.elementSize())));
* BYTES_PER_WORD / BYTES);
break; break;
}
case ElementSize::POINTER: { case ElementSize::POINTER: {
uint count = tag->listRef.elementCount() / ELEMENTS; WirePointer* typedPtr = reinterpret_cast<WirePointer*>(ptr);
for (uint i = 0; i < count; i++) { auto count = tag->listRef.elementCount() * (ONE * POINTERS / ELEMENTS);
zeroObject(segment, capTable, reinterpret_cast<WirePointer*>(ptr) + i); for (auto i: kj::zeroTo(count)) {
zeroObject(segment, capTable, typedPtr + i);
} }
memset(ptr, 0, POINTER_SIZE_IN_WORDS * count * BYTES_PER_WORD / BYTES); zeroMemory(typedPtr, count);
break; break;
} }
case ElementSize::INLINE_COMPOSITE: { case ElementSize::INLINE_COMPOSITE: {
...@@ -572,21 +630,25 @@ struct WireHelpers { ...@@ -572,21 +630,25 @@ struct WireHelpers {
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
uint count = elementTag->inlineCompositeListElementCount() / ELEMENTS; auto count = elementTag->inlineCompositeListElementCount();
if (pointerCount > 0 * POINTERS) { if (pointerCount > 0 * POINTERS) {
word* pos = ptr + POINTER_SIZE_IN_WORDS; word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count; i++) { for (auto i KJ_UNUSED: kj::zeroTo(count)) {
pos += dataSize; pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) { for (auto j KJ_UNUSED: kj::zeroTo(pointerCount)) {
zeroObject(segment, capTable, reinterpret_cast<WirePointer*>(pos)); zeroObject(segment, capTable, reinterpret_cast<WirePointer*>(pos));
pos += POINTER_SIZE_IN_WORDS; pos += POINTER_SIZE_IN_WORDS;
} }
} }
} }
memset(ptr, 0, (elementTag->structRef.wordSize() * count + POINTER_SIZE_IN_WORDS) auto wordsPerElement = elementTag->structRef.wordSize() / ELEMENTS;
* BYTES_PER_WORD / BYTES); zeroMemory(ptr, assertMaxBits<SEGMENT_WORD_COUNT_BITS>(POINTER_SIZE_IN_WORDS +
upgradeGuard<uint64_t>(count) * wordsPerElement, []() {
KJ_FAIL_ASSERT("encountered list pointer in builder which is too large to "
"possibly fit in a segment. Bug in builder code?");
}));
break; break;
} }
} }
...@@ -627,7 +689,7 @@ struct WireHelpers { ...@@ -627,7 +689,7 @@ struct WireHelpers {
SegmentReader* segment, const WirePointer* ref, int nestingLimit) { SegmentReader* segment, const WirePointer* ref, int nestingLimit) {
// Compute the total size of the object pointed to, not counting far pointer overhead. // Compute the total size of the object pointed to, not counting far pointer overhead.
MessageSizeCounts result = { 0 * WORDS, 0 }; MessageSizeCounts result = { ZERO * WORDS, 0 };
if (ref->isNull()) { if (ref->isNull()) {
return result; return result;
...@@ -646,12 +708,11 @@ struct WireHelpers { ...@@ -646,12 +708,11 @@ struct WireHelpers {
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
return result; return result;
} }
result.wordCount += ref->structRef.wordSize(); result.addWords(ref->structRef.wordSize());
const WirePointer* pointerSection = const WirePointer* pointerSection =
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()); reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get());
uint count = ref->structRef.ptrCount.get() / POINTERS; for (auto i: kj::zeroTo(ref->structRef.ptrCount.get())) {
for (uint i = 0; i < count; i++) {
result += totalSize(segment, pointerSection + i, nestingLimit); result += totalSize(segment, pointerSection + i, nestingLimit);
} }
break; break;
...@@ -666,14 +727,14 @@ struct WireHelpers { ...@@ -666,14 +727,14 @@ struct WireHelpers {
case ElementSize::TWO_BYTES: case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES: case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES: { case ElementSize::EIGHT_BYTES: {
WordCount64 totalWords = roundBitsUpToWords( auto totalWords = roundBitsUpToWords(
ElementCount64(ref->listRef.elementCount()) * upgradeGuard<uint64_t>(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize())); dataBitsPerElement(ref->listRef.elementSize()));
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
result.wordCount += totalWords; result.addWords(totalWords);
break; break;
} }
case ElementSize::POINTER: { case ElementSize::POINTER: {
...@@ -684,30 +745,31 @@ struct WireHelpers { ...@@ -684,30 +745,31 @@ struct WireHelpers {
return result; return result;
} }
result.wordCount += count * WORDS_PER_POINTER; result.addWords(count * WORDS_PER_POINTER);
for (uint i = 0; i < count / POINTERS; i++) { for (auto i: kj::zeroTo(count)) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i, result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i,
nestingLimit); nestingLimit);
} }
break; break;
} }
case ElementSize::INLINE_COMPOSITE: { case ElementSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount(); auto wordCount = ref->listRef.inlineCompositeWordCount();
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); auto count = elementTag->inlineCompositeListElementCount();
KJ_REQUIRE(elementTag->kind() == WirePointer::STRUCT, KJ_REQUIRE(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") { "Don't know how to handle non-STRUCT inline composite.") {
return result; return result;
} }
auto actualSize = elementTag->structRef.wordSize() / ELEMENTS * ElementCount64(count); auto actualSize = elementTag->structRef.wordSize() / ELEMENTS *
upgradeGuard<uint64_t>(count);
KJ_REQUIRE(actualSize <= wordCount, KJ_REQUIRE(actualSize <= wordCount,
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
return result; return result;
...@@ -715,17 +777,17 @@ struct WireHelpers { ...@@ -715,17 +777,17 @@ struct WireHelpers {
// We count the actual size rather than the claimed word count because that's what // We count the actual size rather than the claimed word count because that's what
// we'll end up with if we make a copy. // we'll end up with if we make a copy.
result.wordCount += actualSize + POINTER_SIZE_IN_WORDS; result.addWords(wordCount + POINTER_SIZE_IN_WORDS);
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get(); WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
if (pointerCount > 0 * POINTERS) { if (pointerCount > 0 * POINTERS) {
const word* pos = ptr + POINTER_SIZE_IN_WORDS; const word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count / ELEMENTS; i++) { for (auto i KJ_UNUSED: kj::zeroTo(count)) {
pos += dataSize; pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) { for (auto j KJ_UNUSED: kj::zeroTo(pointerCount)) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos), result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos),
nestingLimit); nestingLimit);
pos += POINTER_SIZE_IN_WORDS; pos += POINTER_SIZE_IN_WORDS;
...@@ -760,13 +822,13 @@ struct WireHelpers { ...@@ -760,13 +822,13 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE( static KJ_ALWAYS_INLINE(
void copyStruct(SegmentBuilder* segment, CapTableBuilder* capTable, void copyStruct(SegmentBuilder* segment, CapTableBuilder* capTable,
word* dst, const word* src, word* dst, const word* src,
WordCount dataSize, WirePointerCount pointerCount)) { StructDataWordCount dataSize, StructPointerCount pointerCount)) {
memcpy(dst, src, dataSize * BYTES_PER_WORD / BYTES); copyMemory(dst, src, dataSize);
const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src + dataSize); const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src + dataSize);
WirePointer* dstRefs = reinterpret_cast<WirePointer*>(dst + dataSize); WirePointer* dstRefs = reinterpret_cast<WirePointer*>(dst + dataSize);
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(pointerCount)) {
SegmentBuilder* subSegment = segment; SegmentBuilder* subSegment = segment;
WirePointer* dstRef = dstRefs + i; WirePointer* dstRef = dstRefs + i;
copyMessage(subSegment, capTable, dstRef, srcRefs + i); copyMessage(subSegment, capTable, dstRef, srcRefs + i);
...@@ -803,12 +865,12 @@ struct WireHelpers { ...@@ -803,12 +865,12 @@ struct WireHelpers {
case ElementSize::TWO_BYTES: case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES: case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES: { case ElementSize::EIGHT_BYTES: {
WordCount wordCount = roundBitsUpToWords( auto wordCount = roundBitsUpToWords(
ElementCount64(src->listRef.elementCount()) * upgradeGuard<uint64_t>(src->listRef.elementCount()) *
dataBitsPerElement(src->listRef.elementSize())); dataBitsPerElement(src->listRef.elementSize()));
const word* srcPtr = src->target(); const word* srcPtr = src->target();
word* dstPtr = allocate(dst, segment, capTable, wordCount, WirePointer::LIST, nullptr); word* dstPtr = allocate(dst, segment, capTable, wordCount, WirePointer::LIST, nullptr);
memcpy(dstPtr, srcPtr, wordCount * BYTES_PER_WORD / BYTES); copyMemory(dstPtr, srcPtr, wordCount);
dst->listRef.set(src->listRef.elementSize(), src->listRef.elementCount()); dst->listRef.set(src->listRef.elementSize(), src->listRef.elementCount());
return dstPtr; return dstPtr;
...@@ -818,11 +880,10 @@ struct WireHelpers { ...@@ -818,11 +880,10 @@ struct WireHelpers {
const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target()); const WirePointer* srcRefs = reinterpret_cast<const WirePointer*>(src->target());
WirePointer* dstRefs = reinterpret_cast<WirePointer*>( WirePointer* dstRefs = reinterpret_cast<WirePointer*>(
allocate(dst, segment, capTable, src->listRef.elementCount() * allocate(dst, segment, capTable, src->listRef.elementCount() *
(1 * POINTERS / ELEMENTS) * WORDS_PER_POINTER, (ONE * POINTERS / ELEMENTS) * WORDS_PER_POINTER,
WirePointer::LIST, nullptr)); WirePointer::LIST, nullptr));
uint n = src->listRef.elementCount() / ELEMENTS; for (auto i: kj::zeroTo(src->listRef.elementCount() * (ONE * POINTERS / ELEMENTS))) {
for (uint i = 0; i < n; i++) {
SegmentBuilder* subSegment = segment; SegmentBuilder* subSegment = segment;
WirePointer* dstRef = dstRefs + i; WirePointer* dstRef = dstRefs + i;
copyMessage(subSegment, capTable, dstRef, srcRefs + i); copyMessage(subSegment, capTable, dstRef, srcRefs + i);
...@@ -835,7 +896,9 @@ struct WireHelpers { ...@@ -835,7 +896,9 @@ struct WireHelpers {
case ElementSize::INLINE_COMPOSITE: { case ElementSize::INLINE_COMPOSITE: {
const word* srcPtr = src->target(); const word* srcPtr = src->target();
word* dstPtr = allocate(dst, segment, capTable, word* dstPtr = allocate(dst, segment, capTable,
src->listRef.inlineCompositeWordCount() + POINTER_SIZE_IN_WORDS, assertMaxBits<SEGMENT_WORD_COUNT_BITS>(
src->listRef.inlineCompositeWordCount() + POINTER_SIZE_IN_WORDS,
[]() { KJ_FAIL_ASSERT("list too big to fit in a segment"); }),
WirePointer::LIST, nullptr); WirePointer::LIST, nullptr);
dst->listRef.setInlineComposite(src->listRef.inlineCompositeWordCount()); dst->listRef.setInlineComposite(src->listRef.inlineCompositeWordCount());
...@@ -849,8 +912,7 @@ struct WireHelpers { ...@@ -849,8 +912,7 @@ struct WireHelpers {
KJ_ASSERT(srcTag->kind() == WirePointer::STRUCT, KJ_ASSERT(srcTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE of lists is not yet supported."); "INLINE_COMPOSITE of lists is not yet supported.");
uint n = srcTag->inlineCompositeListElementCount() / ELEMENTS; for (auto i KJ_UNUSED: kj::zeroTo(srcTag->inlineCompositeListElementCount())) {
for (uint i = 0; i < n; i++) {
copyStruct(segment, capTable, dstElement, srcElement, copyStruct(segment, capTable, dstElement, srcElement,
srcTag->structRef.dataSize.get(), srcTag->structRef.ptrCount.get()); srcTag->structRef.dataSize.get(), srcTag->structRef.ptrCount.get());
srcElement += srcTag->structRef.wordSize(); srcElement += srcTag->structRef.wordSize();
...@@ -917,10 +979,10 @@ struct WireHelpers { ...@@ -917,10 +979,10 @@ struct WireHelpers {
// that it doesn't need to be a double-far. // that it doesn't need to be a double-far.
WirePointer* landingPad = WirePointer* landingPad =
reinterpret_cast<WirePointer*>(srcSegment->allocate(1 * WORDS)); reinterpret_cast<WirePointer*>(srcSegment->allocate(G(1) * WORDS));
if (landingPad == nullptr) { if (landingPad == nullptr) {
// Darn, need a double-far. // Darn, need a double-far.
auto allocation = srcSegment->getArena()->allocate(2 * WORDS); auto allocation = srcSegment->getArena()->allocate(G(2) * WORDS);
SegmentBuilder* farSegment = allocation.segment; SegmentBuilder* farSegment = allocation.segment;
landingPad = reinterpret_cast<WirePointer*>(allocation.words); landingPad = reinterpret_cast<WirePointer*>(allocation.words);
...@@ -988,8 +1050,8 @@ struct WireHelpers { ...@@ -988,8 +1050,8 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
WordCount oldDataSize = oldRef->structRef.dataSize.get(); auto oldDataSize = oldRef->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldRef->structRef.ptrCount.get(); auto oldPointerCount = oldRef->structRef.ptrCount.get();
WirePointer* oldPointerSection = WirePointer* oldPointerSection =
reinterpret_cast<WirePointer*>(oldPtr + oldDataSize); reinterpret_cast<WirePointer*>(oldPtr + oldDataSize);
...@@ -998,9 +1060,9 @@ struct WireHelpers { ...@@ -998,9 +1060,9 @@ struct WireHelpers {
// run with it and do bounds checks at access time, because how would we handle writes? // run with it and do bounds checks at access time, because how would we handle writes?
// Instead, we have to copy the struct to a new space now. // Instead, we have to copy the struct to a new space now.
WordCount newDataSize = kj::max(oldDataSize, size.data); auto newDataSize = kj::max(oldDataSize, size.data);
WirePointerCount newPointerCount = kj::max(oldPointerCount, size.pointers); auto newPointerCount = kj::max(oldPointerCount, size.pointers);
WordCount totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER; auto totalSize = newDataSize + newPointerCount * WORDS_PER_POINTER;
// Don't let allocate() zero out the object just yet. // Don't let allocate() zero out the object just yet.
zeroPointerAndFars(segment, ref); zeroPointerAndFars(segment, ref);
...@@ -1009,11 +1071,11 @@ struct WireHelpers { ...@@ -1009,11 +1071,11 @@ struct WireHelpers {
ref->structRef.set(newDataSize, newPointerCount); ref->structRef.set(newDataSize, newPointerCount);
// Copy data section. // Copy data section.
memcpy(ptr, oldPtr, oldDataSize * BYTES_PER_WORD / BYTES); copyMemory(ptr, oldPtr, oldDataSize);
// Copy pointer section. // Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize); WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(ptr + newDataSize);
for (uint i = 0; i < oldPointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(oldPointerCount)) {
transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i); transferPointer(segment, newPointerSection + i, oldSegment, oldPointerSection + i);
} }
...@@ -1022,8 +1084,7 @@ struct WireHelpers { ...@@ -1022,8 +1084,7 @@ struct WireHelpers {
// out as it may contain secrets that the caller intends to remove from the new copy. // out as it may contain secrets that the caller intends to remove from the new copy.
// 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever // 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever
// hits the wire. // hits the wire.
memset(oldPtr, 0, zeroMemory(oldPtr, oldDataSize + oldPointerCount * WORDS_PER_POINTER);
(oldDataSize + oldPointerCount * WORDS_PER_POINTER) * BYTES_PER_WORD / BYTES);
return StructBuilder(segment, capTable, ptr, newPointerSection, newDataSize * BITS_PER_WORD, return StructBuilder(segment, capTable, ptr, newPointerSection, newDataSize * BITS_PER_WORD,
newPointerCount); newPointerCount);
...@@ -1039,31 +1100,40 @@ struct WireHelpers { ...@@ -1039,31 +1100,40 @@ struct WireHelpers {
KJ_DREQUIRE(elementSize != ElementSize::INLINE_COMPOSITE, KJ_DREQUIRE(elementSize != ElementSize::INLINE_COMPOSITE,
"Should have called initStructListPointer() instead."); "Should have called initStructListPointer() instead.");
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS; auto checkedElementCount = assertMaxBits<LIST_ELEMENT_COUNT_BITS>(elementCount,
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS; []() { KJ_FAIL_REQUIRE("tried to allocate list with too many elements"); });
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
auto dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
auto pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = bitsPerElementIncludingPointers(elementSize);
KJ_DASSERT(step * ELEMENTS == (dataSize + pointerCount * BITS_PER_POINTER));
// Calculate size of the list. // Calculate size of the list.
WordCount wordCount = roundBitsUpToWords(ElementCount64(elementCount) * step); auto wordCount = roundBitsUpToWords(upgradeGuard<uint64_t>(checkedElementCount) * step);
// Allocate the list. // Allocate the list.
word* ptr = allocate(ref, segment, capTable, wordCount, WirePointer::LIST, orphanArena); word* ptr = allocate(ref, segment, capTable, wordCount, WirePointer::LIST, orphanArena);
// Initialize the pointer. // Initialize the pointer.
ref->listRef.set(elementSize, elementCount); ref->listRef.set(elementSize, checkedElementCount);
// Build the ListBuilder. // Build the ListBuilder.
return ListBuilder(segment, capTable, ptr, step, elementCount, dataSize, return ListBuilder(segment, capTable, ptr, step, checkedElementCount,
pointerCount, elementSize); dataSize, pointerCount, elementSize);
} }
static KJ_ALWAYS_INLINE(ListBuilder initStructListPointer( static KJ_ALWAYS_INLINE(ListBuilder initStructListPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable,
ElementCount elementCount, StructSize elementSize, BuilderArena* orphanArena = nullptr)) { ElementCount elementCount, StructSize elementSize, BuilderArena* orphanArena = nullptr)) {
auto wordsPerElement = elementSize.total() / ELEMENTS; auto checkedElementCount = assertMaxBits<LIST_ELEMENT_COUNT_BITS>(elementCount,
[]() { KJ_FAIL_REQUIRE("tried to allocate list with too many elements"); });
WordsPerElementN<17> wordsPerElement = elementSize.total() / ELEMENTS;
// Allocate the list, prefixed by a single WirePointer. // Allocate the list, prefixed by a single WirePointer.
WordCount wordCount = elementCount * wordsPerElement; auto wordCount = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
upgradeGuard<uint64_t>(checkedElementCount) * wordsPerElement,
[]() { KJ_FAIL_REQUIRE("total size of struct list is larger than max segment size"); });
word* ptr = allocate(ref, segment, capTable, POINTER_SIZE_IN_WORDS + wordCount, word* ptr = allocate(ref, segment, capTable, POINTER_SIZE_IN_WORDS + wordCount,
WirePointer::LIST, orphanArena); WirePointer::LIST, orphanArena);
...@@ -1073,12 +1143,12 @@ struct WireHelpers { ...@@ -1073,12 +1143,12 @@ struct WireHelpers {
// Initialize the list tag. // Initialize the list tag.
reinterpret_cast<WirePointer*>(ptr)->setKindAndInlineCompositeListElementCount( reinterpret_cast<WirePointer*>(ptr)->setKindAndInlineCompositeListElementCount(
WirePointer::STRUCT, elementCount); WirePointer::STRUCT, checkedElementCount);
reinterpret_cast<WirePointer*>(ptr)->structRef.set(elementSize); reinterpret_cast<WirePointer*>(ptr)->structRef.set(elementSize);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
// Build the ListBuilder. // Build the ListBuilder.
return ListBuilder(segment, capTable, ptr, wordsPerElement * BITS_PER_WORD, elementCount, return ListBuilder(segment, capTable, ptr, wordsPerElement * BITS_PER_WORD, checkedElementCount,
elementSize.data * BITS_PER_WORD, elementSize.pointers, elementSize.data * BITS_PER_WORD, elementSize.pointers,
ElementSize::INLINE_COMPOSITE); ElementSize::INLINE_COMPOSITE);
} }
...@@ -1136,8 +1206,8 @@ struct WireHelpers { ...@@ -1136,8 +1206,8 @@ struct WireHelpers {
"INLINE_COMPOSITE list with non-STRUCT elements not supported."); "INLINE_COMPOSITE list with non-STRUCT elements not supported.");
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
WordCount dataSize = tag->structRef.dataSize.get(); auto dataSize = tag->structRef.dataSize.get();
WirePointerCount pointerCount = tag->structRef.ptrCount.get(); auto pointerCount = tag->structRef.ptrCount.get();
switch (elementSize) { switch (elementSize) {
case ElementSize::VOID: case ElementSize::VOID:
...@@ -1156,14 +1226,14 @@ struct WireHelpers { ...@@ -1156,14 +1226,14 @@ struct WireHelpers {
case ElementSize::TWO_BYTES: case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES: case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES: case ElementSize::EIGHT_BYTES:
KJ_REQUIRE(dataSize >= 1 * WORDS, KJ_REQUIRE(dataSize >= ONE * WORDS,
"Existing list value is incompatible with expected type.") { "Existing list value is incompatible with expected type.") {
goto useDefault; goto useDefault;
} }
break; break;
case ElementSize::POINTER: case ElementSize::POINTER:
KJ_REQUIRE(pointerCount >= 1 * POINTERS, KJ_REQUIRE(pointerCount >= ONE * POINTERS,
"Existing list value is incompatible with expected type.") { "Existing list value is incompatible with expected type.") {
goto useDefault; goto useDefault;
} }
...@@ -1182,8 +1252,8 @@ struct WireHelpers { ...@@ -1182,8 +1252,8 @@ struct WireHelpers {
tag->inlineCompositeListElementCount(), tag->inlineCompositeListElementCount(),
dataSize * BITS_PER_WORD, pointerCount, ElementSize::INLINE_COMPOSITE); dataSize * BITS_PER_WORD, pointerCount, ElementSize::INLINE_COMPOSITE);
} else { } else {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS; auto dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS; auto pointerCount = pointersPerElement(oldSize) * ELEMENTS;
if (elementSize == ElementSize::BIT) { if (elementSize == ElementSize::BIT) {
KJ_REQUIRE(oldSize == ElementSize::BIT, KJ_REQUIRE(oldSize == ElementSize::BIT,
...@@ -1257,8 +1327,8 @@ struct WireHelpers { ...@@ -1257,8 +1327,8 @@ struct WireHelpers {
tag->structRef.dataSize.get() * BITS_PER_WORD, tag->structRef.dataSize.get() * BITS_PER_WORD,
tag->structRef.ptrCount.get(), ElementSize::INLINE_COMPOSITE); tag->structRef.ptrCount.get(), ElementSize::INLINE_COMPOSITE);
} else { } else {
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS; auto dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS; auto pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(segment, capTable, ptr, step, ref->listRef.elementCount(), return ListBuilder(segment, capTable, ptr, step, ref->listRef.elementCount(),
...@@ -1310,10 +1380,11 @@ struct WireHelpers { ...@@ -1310,10 +1380,11 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
WordCount oldDataSize = oldTag->structRef.dataSize.get(); auto oldDataSize = oldTag->structRef.dataSize.get();
WirePointerCount oldPointerCount = oldTag->structRef.ptrCount.get(); auto oldPointerCount = oldTag->structRef.ptrCount.get();
auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS; auto oldStep = (oldDataSize + oldPointerCount * WORDS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldTag->inlineCompositeListElementCount();
auto elementCount = oldTag->inlineCompositeListElementCount();
if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) { if (oldDataSize >= elementSize.data && oldPointerCount >= elementSize.pointers) {
// Old size is at least as large as we need. Ship it. // Old size is at least as large as we need. Ship it.
...@@ -1325,10 +1396,13 @@ struct WireHelpers { ...@@ -1325,10 +1396,13 @@ struct WireHelpers {
// The structs in this list are smaller than expected, probably written using an older // The structs in this list are smaller than expected, probably written using an older
// version of the protocol. We need to make a copy and expand them. // version of the protocol. We need to make a copy and expand them.
WordCount newDataSize = kj::max(oldDataSize, elementSize.data); auto newDataSize = kj::max(oldDataSize, elementSize.data);
WirePointerCount newPointerCount = kj::max(oldPointerCount, elementSize.pointers); auto newPointerCount = kj::max(oldPointerCount, elementSize.pointers);
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS; auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalSize = newStep * elementCount;
auto totalSize = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
newStep * upgradeGuard<uint64_t>(elementCount),
[]() { KJ_FAIL_REQUIRE("total size of struct list is larger than max segment size"); });
// Don't let allocate() zero out the object just yet. // Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef); zeroPointerAndFars(origSegment, origRef);
...@@ -1344,35 +1418,39 @@ struct WireHelpers { ...@@ -1344,35 +1418,39 @@ struct WireHelpers {
word* src = oldPtr; word* src = oldPtr;
word* dst = newPtr; word* dst = newPtr;
for (uint i = 0; i < elementCount / ELEMENTS; i++) { for (auto i KJ_UNUSED: kj::zeroTo(elementCount)) {
// Copy data section. // Copy data section.
memcpy(dst, src, oldDataSize * BYTES_PER_WORD / BYTES); copyMemory(dst, src, oldDataSize);
// Copy pointer section. // Copy pointer section.
WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize); WirePointer* newPointerSection = reinterpret_cast<WirePointer*>(dst + newDataSize);
WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize); WirePointer* oldPointerSection = reinterpret_cast<WirePointer*>(src + oldDataSize);
for (uint j = 0; j < oldPointerCount / POINTERS; j++) { for (auto j: kj::zeroTo(oldPointerCount)) {
transferPointer(origSegment, newPointerSection + j, oldSegment, oldPointerSection + j); transferPointer(origSegment, newPointerSection + j, oldSegment, oldPointerSection + j);
} }
dst += newStep * (1 * ELEMENTS); dst += newStep * (ONE * ELEMENTS);
src += oldStep * (1 * ELEMENTS); src += oldStep * (ONE * ELEMENTS);
} }
auto oldSize = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
oldStep * upgradeGuard<uint64_t>(elementCount),
[]() { KJ_FAIL_ASSERT("old size overflows but new size doesn't?"); });
// Zero out old location. See explanation in getWritableStructPointer(). // Zero out old location. See explanation in getWritableStructPointer().
// Make sure to include the tag word. // Make sure to include the tag word.
memset(oldPtr - POINTER_SIZE_IN_WORDS, 0, zeroMemory(oldPtr - POINTER_SIZE_IN_WORDS, oldSize + POINTER_SIZE_IN_WORDS);
(POINTER_SIZE_IN_WORDS + oldStep * elementCount) * BYTES_PER_WORD / BYTES);
return ListBuilder(origSegment, capTable, newPtr, newStep * BITS_PER_WORD, elementCount, return ListBuilder(origSegment, capTable, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount, ElementSize::INLINE_COMPOSITE); newDataSize * BITS_PER_WORD, newPointerCount,
ElementSize::INLINE_COMPOSITE);
} else { } else {
// We're upgrading from a non-struct list. // We're upgrading from a non-struct list.
BitCount oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS; auto oldDataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount oldPointerCount = pointersPerElement(oldSize) * ELEMENTS; auto oldPointerCount = pointersPerElement(oldSize) * ELEMENTS;
auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS; auto oldStep = (oldDataSize + oldPointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = oldRef->listRef.elementCount(); auto elementCount = oldRef->listRef.elementCount();
if (oldSize == ElementSize::VOID) { if (oldSize == ElementSize::VOID) {
// Nothing to copy, just allocate a new list. // Nothing to copy, just allocate a new list.
...@@ -1386,18 +1464,20 @@ struct WireHelpers { ...@@ -1386,18 +1464,20 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
WordCount newDataSize = elementSize.data; auto newDataSize = elementSize.data;
WirePointerCount newPointerCount = elementSize.pointers; auto newPointerCount = elementSize.pointers;
if (oldSize == ElementSize::POINTER) { if (oldSize == ElementSize::POINTER) {
newPointerCount = kj::max(newPointerCount, 1 * POINTERS); newPointerCount = kj::max(newPointerCount, ONE * POINTERS);
} else { } else {
// Old list contains data elements, so we need at least 1 word of data. // Old list contains data elements, so we need at least 1 word of data.
newDataSize = kj::max(newDataSize, 1 * WORDS); newDataSize = kj::max(newDataSize, ONE * WORDS);
} }
auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS; auto newStep = (newDataSize + newPointerCount * WORDS_PER_POINTER) / ELEMENTS;
WordCount totalWords = elementCount * newStep; auto totalWords = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
newStep * upgradeGuard<uint64_t>(elementCount),
[]() {KJ_FAIL_REQUIRE("total size of struct list is larger than max segment size");});
// Don't let allocate() zero out the object just yet. // Don't let allocate() zero out the object just yet.
zeroPointerAndFars(origSegment, origRef); zeroPointerAndFars(origSegment, origRef);
...@@ -1414,24 +1494,29 @@ struct WireHelpers { ...@@ -1414,24 +1494,29 @@ struct WireHelpers {
if (oldSize == ElementSize::POINTER) { if (oldSize == ElementSize::POINTER) {
WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize); WirePointer* dst = reinterpret_cast<WirePointer*>(newPtr + newDataSize);
WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr); WirePointer* src = reinterpret_cast<WirePointer*>(oldPtr);
for (uint i = 0; i < elementCount / ELEMENTS; i++) { for (auto i KJ_UNUSED: kj::zeroTo(elementCount)) {
transferPointer(origSegment, dst, oldSegment, src); transferPointer(origSegment, dst, oldSegment, src);
dst += newStep / WORDS_PER_POINTER * (1 * ELEMENTS); dst += newStep / WORDS_PER_POINTER * (ONE * ELEMENTS);
++src; ++src;
} }
} else { } else {
word* dst = newPtr; byte* dst = reinterpret_cast<byte*>(newPtr);
char* src = reinterpret_cast<char*>(oldPtr); byte* src = reinterpret_cast<byte*>(oldPtr);
ByteCount oldByteStep = oldDataSize / BITS_PER_BYTE; auto newByteStep = newStep * (ONE * ELEMENTS) * BYTES_PER_WORD;
for (uint i = 0; i < elementCount / ELEMENTS; i++) { auto oldByteStep = oldDataSize / BITS_PER_BYTE;
memcpy(dst, src, oldByteStep / BYTES); for (auto i KJ_UNUSED: kj::zeroTo(elementCount)) {
src += oldByteStep / BYTES; copyMemory(dst, src, oldByteStep);
dst += newStep * (1 * ELEMENTS); src += oldByteStep;
dst += newByteStep;
} }
} }
auto oldSize = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
roundBitsUpToWords(oldStep * upgradeGuard<uint64_t>(elementCount)),
[]() { KJ_FAIL_ASSERT("old size overflows but new size doesn't?"); });
// Zero out old location. See explanation in getWritableStructPointer(). // Zero out old location. See explanation in getWritableStructPointer().
memset(oldPtr, 0, roundBitsUpToBytes(oldStep * elementCount) / BYTES); zeroMemory(oldPtr, oldSize);
return ListBuilder(origSegment, capTable, newPtr, newStep * BITS_PER_WORD, elementCount, return ListBuilder(origSegment, capTable, newPtr, newStep * BITS_PER_WORD, elementCount,
newDataSize * BITS_PER_WORD, newPointerCount, newDataSize * BITS_PER_WORD, newPointerCount,
...@@ -1441,46 +1526,50 @@ struct WireHelpers { ...@@ -1441,46 +1526,50 @@ struct WireHelpers {
} }
static KJ_ALWAYS_INLINE(SegmentAnd<Text::Builder> initTextPointer( static KJ_ALWAYS_INLINE(SegmentAnd<Text::Builder> initTextPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, ByteCount size, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, TextSize size,
BuilderArena* orphanArena = nullptr)) { BuilderArena* orphanArena = nullptr)) {
// The byte list must include a NUL terminator. // The byte list must include a NUL terminator.
ByteCount byteSize = size + 1 * BYTES; auto byteSize = size + ONE * BYTES;
// Allocate the space. // Allocate the space.
word* ptr = allocate( word* ptr = allocate(
ref, segment, capTable, roundBytesUpToWords(byteSize), WirePointer::LIST, orphanArena); ref, segment, capTable, roundBytesUpToWords(byteSize), WirePointer::LIST, orphanArena);
// Initialize the pointer. // Initialize the pointer.
ref->listRef.set(ElementSize::BYTE, byteSize * (1 * ELEMENTS / BYTES)); ref->listRef.set(ElementSize::BYTE, byteSize * (ONE * ELEMENTS / BYTES));
// Build the Text::Builder. This will initialize the NUL terminator. // Build the Text::Builder. This will initialize the NUL terminator.
return { segment, Text::Builder(reinterpret_cast<char*>(ptr), size / BYTES) }; return { segment, Text::Builder(reinterpret_cast<char*>(ptr), unguard(size / BYTES)) };
} }
static KJ_ALWAYS_INLINE(SegmentAnd<Text::Builder> setTextPointer( static KJ_ALWAYS_INLINE(SegmentAnd<Text::Builder> setTextPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, Text::Reader value, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, Text::Reader value,
BuilderArena* orphanArena = nullptr)) { BuilderArena* orphanArena = nullptr)) {
auto allocation = initTextPointer(ref, segment, capTable, value.size() * BYTES, orphanArena); TextSize size = assertMax<MAX_TEXT_SIZE>(guarded(value.size()),
[]() { KJ_FAIL_REQUIRE("text blob too big"); }) * BYTES;
auto allocation = initTextPointer(ref, segment, capTable, size, orphanArena);
memcpy(allocation.value.begin(), value.begin(), value.size()); memcpy(allocation.value.begin(), value.begin(), value.size());
return allocation; return allocation;
} }
static KJ_ALWAYS_INLINE(Text::Builder getWritableTextPointer( static KJ_ALWAYS_INLINE(Text::Builder getWritableTextPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, TextSize defaultSize)) {
return getWritableTextPointer(ref, ref->target(), segment, capTable, defaultValue, defaultSize); return getWritableTextPointer(ref, ref->target(), segment,capTable, defaultValue, defaultSize);
} }
static KJ_ALWAYS_INLINE(Text::Builder getWritableTextPointer( static KJ_ALWAYS_INLINE(Text::Builder getWritableTextPointer(
WirePointer* ref, word* refTarget, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, word* refTarget, SegmentBuilder* segment, CapTableBuilder* capTable,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, TextSize defaultSize)) {
if (ref->isNull()) { if (ref->isNull()) {
useDefault: useDefault:
if (defaultSize == 0 * BYTES) { if (defaultSize == ZERO * BYTES) {
return nullptr; return nullptr;
} else { } else {
Text::Builder builder = initTextPointer(ref, segment, capTable, defaultSize).value; Text::Builder builder = initTextPointer(ref, segment, capTable, defaultSize).value;
memcpy(builder.begin(), defaultValue, defaultSize / BYTES); copyMemory(builder.asBytes().begin(), reinterpret_cast<const byte*>(defaultValue),
defaultSize);
return builder; return builder;
} }
} else { } else {
...@@ -1492,52 +1581,56 @@ struct WireHelpers { ...@@ -1492,52 +1581,56 @@ struct WireHelpers {
KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE,
"Called getText{Field,Element}() but existing list pointer is not byte-sized."); "Called getText{Field,Element}() but existing list pointer is not byte-sized.");
size_t size = ref->listRef.elementCount() / ELEMENTS; size_t size = unguard(subtractChecked(ref->listRef.elementCount() / ELEMENTS, ONE,
KJ_REQUIRE(size > 0 && cptr[size-1] == '\0', "Text blob missing NUL terminator.") { []() { KJ_FAIL_REQUIRE("zero-size blob can't be text (need NUL terminator)"); }));
KJ_REQUIRE(cptr[size] == '\0', "Text blob missing NUL terminator.") {
goto useDefault; goto useDefault;
} }
return Text::Builder(cptr, size - 1); return Text::Builder(cptr, size);
} }
} }
static KJ_ALWAYS_INLINE(SegmentAnd<Data::Builder> initDataPointer( static KJ_ALWAYS_INLINE(SegmentAnd<Data::Builder> initDataPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, ByteCount size, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, BlobSize size,
BuilderArena* orphanArena = nullptr)) { BuilderArena* orphanArena = nullptr)) {
// Allocate the space. // Allocate the space.
word* ptr = allocate(ref, segment, capTable, roundBytesUpToWords(size), word* ptr = allocate(ref, segment, capTable, roundBytesUpToWords(size),
WirePointer::LIST, orphanArena); WirePointer::LIST, orphanArena);
// Initialize the pointer. // Initialize the pointer.
ref->listRef.set(ElementSize::BYTE, size * (1 * ELEMENTS / BYTES)); ref->listRef.set(ElementSize::BYTE, size * (ONE * ELEMENTS / BYTES));
// Build the Data::Builder. // Build the Data::Builder.
return { segment, Data::Builder(reinterpret_cast<byte*>(ptr), size / BYTES) }; return { segment, Data::Builder(reinterpret_cast<byte*>(ptr), unguard(size / BYTES)) };
} }
static KJ_ALWAYS_INLINE(SegmentAnd<Data::Builder> setDataPointer( static KJ_ALWAYS_INLINE(SegmentAnd<Data::Builder> setDataPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, Data::Reader value, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, Data::Reader value,
BuilderArena* orphanArena = nullptr)) { BuilderArena* orphanArena = nullptr)) {
auto allocation = initDataPointer(ref, segment, capTable, value.size() * BYTES, orphanArena); BlobSize size = assertMaxBits<BLOB_SIZE_BITS>(guarded(value.size()),
[]() { KJ_FAIL_REQUIRE("text blob too big"); }) * BYTES;
auto allocation = initDataPointer(ref, segment, capTable, size, orphanArena);
memcpy(allocation.value.begin(), value.begin(), value.size()); memcpy(allocation.value.begin(), value.begin(), value.size());
return allocation; return allocation;
} }
static KJ_ALWAYS_INLINE(Data::Builder getWritableDataPointer( static KJ_ALWAYS_INLINE(Data::Builder getWritableDataPointer(
WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, SegmentBuilder* segment, CapTableBuilder* capTable,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, BlobSize defaultSize)) {
return getWritableDataPointer(ref, ref->target(), segment, capTable, defaultValue, defaultSize); return getWritableDataPointer(ref, ref->target(), segment, capTable, defaultValue, defaultSize);
} }
static KJ_ALWAYS_INLINE(Data::Builder getWritableDataPointer( static KJ_ALWAYS_INLINE(Data::Builder getWritableDataPointer(
WirePointer* ref, word* refTarget, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, word* refTarget, SegmentBuilder* segment, CapTableBuilder* capTable,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, BlobSize defaultSize)) {
if (ref->isNull()) { if (ref->isNull()) {
if (defaultSize == 0 * BYTES) { if (defaultSize == ZERO * BYTES) {
return nullptr; return nullptr;
} else { } else {
Data::Builder builder = initDataPointer(ref, segment, capTable, defaultSize).value; Data::Builder builder = initDataPointer(ref, segment, capTable, defaultSize).value;
memcpy(builder.begin(), defaultValue, defaultSize / BYTES); copyMemory(builder.begin(), reinterpret_cast<const byte*>(defaultValue), defaultSize);
return builder; return builder;
} }
} else { } else {
...@@ -1548,15 +1641,17 @@ struct WireHelpers { ...@@ -1548,15 +1641,17 @@ struct WireHelpers {
KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE, KJ_REQUIRE(ref->listRef.elementSize() == ElementSize::BYTE,
"Called getData{Field,Element}() but existing list pointer is not byte-sized."); "Called getData{Field,Element}() but existing list pointer is not byte-sized.");
return Data::Builder(reinterpret_cast<byte*>(ptr), ref->listRef.elementCount() / ELEMENTS); return Data::Builder(reinterpret_cast<byte*>(ptr),
unguard(ref->listRef.elementCount() / ELEMENTS));
} }
} }
static SegmentAnd<word*> setStructPointer( static SegmentAnd<word*> setStructPointer(
SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, StructReader value, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, StructReader value,
BuilderArena* orphanArena = nullptr, bool canonical = false) { BuilderArena* orphanArena = nullptr, bool canonical = false) {
ByteCount dataSize = roundBitsUpToBytes(value.dataSize); // TODO(now): Function may have been damaged in merge conflict resolution.
WirePointerCount ptrCount = value.pointerCount; auto dataSize = roundBitsUpToBytes(value.dataSize);
auto ptrCount = value.pointerCount;
if (canonical) { if (canonical) {
// StructReaders should not have bitwidths other than 1, but let's be safe // StructReaders should not have bitwidths other than 1, but let's be safe
...@@ -1604,17 +1699,19 @@ struct WireHelpers { ...@@ -1604,17 +1699,19 @@ struct WireHelpers {
word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::STRUCT, orphanArena); word* ptr = allocate(ref, segment, capTable, totalSize, WirePointer::STRUCT, orphanArena);
ref->structRef.set(dataWords, ptrCount); ref->structRef.set(dataWords, ptrCount);
if (value.dataSize == 1 * BITS) { if (value.dataSize == ONE * BITS) {
// Data size could be made 0 by truncation // Data size could be made 0 by truncation
if (dataSize != 0 * BYTES) { if (dataSize != ZERO * BYTES) {
*reinterpret_cast<char*>(ptr) = value.getDataField<bool>(0 * ELEMENTS); *reinterpret_cast<char*>(ptr) = value.getDataField<bool>(ZERO * ELEMENTS);
} }
} else { } else {
memcpy(ptr, value.data, dataSize / BYTES); copyMemory(reinterpret_cast<byte*>(ptr),
reinterpret_cast<const byte*>(value.data),
dataSize);
} }
WirePointer* pointerSection = reinterpret_cast<WirePointer*>(ptr + dataWords); WirePointer* pointerSection = reinterpret_cast<WirePointer*>(ptr + dataWords);
for (uint i = 0; i < ptrCount / POINTERS; i++) { for (auto i: kj::zeroTo(ptrCount)) {
copyPointer(segment, capTable, pointerSection + i, copyPointer(segment, capTable, pointerSection + i,
value.segment, value.capTable, value.pointers + i, value.segment, value.capTable, value.pointers + i,
value.nestingLimit, nullptr, canonical); value.nestingLimit, nullptr, canonical);
...@@ -1641,7 +1738,10 @@ struct WireHelpers { ...@@ -1641,7 +1738,10 @@ struct WireHelpers {
static SegmentAnd<word*> setListPointer( static SegmentAnd<word*> setListPointer(
SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, ListReader value, SegmentBuilder* segment, CapTableBuilder* capTable, WirePointer* ref, ListReader value,
BuilderArena* orphanArena = nullptr, bool canonical = false) { BuilderArena* orphanArena = nullptr, bool canonical = false) {
WordCount totalSize = roundBitsUpToWords(value.elementCount * value.step); // TODO(now): Function may have been damaged in merge conflict resolution.
auto totalSize = assertMax<kj::maxValueForBits<SEGMENT_WORD_COUNT_BITS>() - 1>(
roundBitsUpToWords(upgradeGuard<uint64_t>(value.elementCount) * value.step),
[]() { KJ_FAIL_ASSERT("encountered impossibly long struct list ListReader"); });
if (value.elementSize != ElementSize::INLINE_COMPOSITE) { if (value.elementSize != ElementSize::INLINE_COMPOSITE) {
// List of non-structs. // List of non-structs.
...@@ -1650,7 +1750,7 @@ struct WireHelpers { ...@@ -1650,7 +1750,7 @@ struct WireHelpers {
if (value.elementSize == ElementSize::POINTER) { if (value.elementSize == ElementSize::POINTER) {
// List of pointers. // List of pointers.
ref->listRef.set(ElementSize::POINTER, value.elementCount); ref->listRef.set(ElementSize::POINTER, value.elementCount);
for (uint i = 0; i < value.elementCount / ELEMENTS; i++) { for (auto i: zeroTo(value.elementCount * (ONE * POINTERS / ELEMENTS))) {
copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(ptr) + i, copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(ptr) + i,
value.segment, value.capTable, value.segment, value.capTable,
reinterpret_cast<const WirePointer*>(value.ptr) + i, reinterpret_cast<const WirePointer*>(value.ptr) + i,
...@@ -1659,14 +1759,12 @@ struct WireHelpers { ...@@ -1659,14 +1759,12 @@ struct WireHelpers {
} else { } else {
// List of data. // List of data.
ref->listRef.set(value.elementSize, value.elementCount); ref->listRef.set(value.elementSize, value.elementCount);
memcpy(ptr, value.ptr, totalSize * BYTES_PER_WORD / BYTES); copyMemory(ptr, reinterpret_cast<const word*>(value.ptr), totalSize);
} }
return { segment, ptr }; return { segment, ptr };
} else { } else {
// List of structs. // List of structs.
KJ_DASSERT(value.structDataSize % BITS_PER_WORD == 0 * BITS);
WordCount declDataSize = value.structDataSize / BITS_PER_WORD; WordCount declDataSize = value.structDataSize / BITS_PER_WORD;
WirePointerCount declPointerCount = value.structPointerCount; WirePointerCount declPointerCount = value.structPointerCount;
...@@ -1711,6 +1809,7 @@ struct WireHelpers { ...@@ -1711,6 +1809,7 @@ struct WireHelpers {
ptrCount = declPointerCount; ptrCount = declPointerCount;
} }
KJ_DASSERT(value.structDataSize % BITS_PER_WORD == 0 * BITS);
word* ptr = allocate(ref, segment, capTable, totalSize + POINTER_SIZE_IN_WORDS, word* ptr = allocate(ref, segment, capTable, totalSize + POINTER_SIZE_IN_WORDS,
WirePointer::LIST, orphanArena); WirePointer::LIST, orphanArena);
ref->listRef.setInlineComposite(totalSize); ref->listRef.setInlineComposite(totalSize);
...@@ -1721,12 +1820,12 @@ struct WireHelpers { ...@@ -1721,12 +1820,12 @@ struct WireHelpers {
word* dst = ptr + POINTER_SIZE_IN_WORDS; word* dst = ptr + POINTER_SIZE_IN_WORDS;
const word* src = reinterpret_cast<const word*>(value.ptr); const word* src = reinterpret_cast<const word*>(value.ptr);
for (uint i = 0; i < value.elementCount / ELEMENTS; i++) { for (auto i KJ_UNUSED: kj::zeroTo(value.elementCount)) {
memcpy(dst, src, dataSize * BYTES_PER_WORD / BYTES); copyMemory(dst, src, dataSize);
dst += dataSize; dst += dataSize;
src += declDataSize; src += declDataSize;
for (uint j = 0; j < ptrCount / POINTERS; j++) { for (auto j KJ_UNUSED: kj::zeroTo(ptrCount)) {
copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(dst), copyPointer(segment, capTable, reinterpret_cast<WirePointer*>(dst),
value.segment, value.capTable, reinterpret_cast<const WirePointer*>(src), value.segment, value.capTable, reinterpret_cast<const WirePointer*>(src),
value.nestingLimit, nullptr, canonical); value.nestingLimit, nullptr, canonical);
...@@ -1802,7 +1901,7 @@ struct WireHelpers { ...@@ -1802,7 +1901,7 @@ struct WireHelpers {
} }
if (elementSize == ElementSize::INLINE_COMPOSITE) { if (elementSize == ElementSize::INLINE_COMPOSITE) {
WordCount wordCount = src->listRef.inlineCompositeWordCount(); auto wordCount = src->listRef.inlineCompositeWordCount();
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
...@@ -1816,10 +1915,10 @@ struct WireHelpers { ...@@ -1816,10 +1915,10 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
ElementCount elementCount = tag->inlineCompositeListElementCount(); auto elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS; auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
KJ_REQUIRE(wordsPerElement * ElementCount64(elementCount) <= wordCount, KJ_REQUIRE(wordsPerElement * upgradeGuard<uint64_t>(elementCount) <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") { "INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault; goto useDefault;
} }
...@@ -1841,11 +1940,11 @@ struct WireHelpers { ...@@ -1841,11 +1940,11 @@ struct WireHelpers {
nestingLimit - 1), nestingLimit - 1),
orphanArena, canonical); orphanArena, canonical);
} else { } else {
BitCount dataSize = dataBitsPerElement(elementSize) * ELEMENTS; auto dataSize = dataBitsPerElement(elementSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(elementSize) * ELEMENTS; auto pointerCount = pointersPerElement(elementSize) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
ElementCount elementCount = src->listRef.elementCount(); auto elementCount = src->listRef.elementCount();
WordCount64 wordCount = roundBitsUpToWords(ElementCount64(elementCount) * step); auto wordCount = roundBitsUpToWords(upgradeGuard<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + wordCount), KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
...@@ -2081,15 +2180,7 @@ struct WireHelpers { ...@@ -2081,15 +2180,7 @@ struct WireHelpers {
ElementSize elementSize = ref->listRef.elementSize(); ElementSize elementSize = ref->listRef.elementSize();
if (elementSize == ElementSize::INLINE_COMPOSITE) { if (elementSize == ElementSize::INLINE_COMPOSITE) {
#if _MSC_VER auto wordCount = ref->listRef.inlineCompositeWordCount();
// TODO(msvc): MSVC thinks decltype(WORDS/ELEMENTS) is a const type. /eyeroll
uint wordsPerElement;
#else
decltype(WORDS/ELEMENTS) wordsPerElement;
#endif
ElementCount size;
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
// An INLINE_COMPOSITE list points to a tag, which is formatted like a pointer. // An INLINE_COMPOSITE list points to a tag, which is formatted like a pointer.
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
...@@ -2105,18 +2196,18 @@ struct WireHelpers { ...@@ -2105,18 +2196,18 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
size = tag->inlineCompositeListElementCount(); auto size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS; auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
KJ_REQUIRE(ElementCount64(size) * wordsPerElement <= wordCount, KJ_REQUIRE(upgradeGuard<uint64_t>(size) * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") { "INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault; goto useDefault;
} }
if (wordsPerElement * (1 * ELEMENTS) == 0 * WORDS) { if (wordsPerElement * (ONE * ELEMENTS) == ZERO * WORDS) {
// Watch out for lists of zero-sized structs, which can claim to be arbitrarily large // Watch out for lists of zero-sized structs, which can claim to be arbitrarily large
// without having sent actual data. // without having sent actual data.
KJ_REQUIRE(amplifiedRead(segment, size * (1 * WORDS / ELEMENTS)), KJ_REQUIRE(amplifiedRead(segment, size * (ONE * WORDS / ELEMENTS)),
"Message contains amplified list pointer.") { "Message contains amplified list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2145,7 +2236,7 @@ struct WireHelpers { ...@@ -2145,7 +2236,7 @@ struct WireHelpers {
case ElementSize::TWO_BYTES: case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES: case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES: case ElementSize::EIGHT_BYTES:
KJ_REQUIRE(tag->structRef.dataSize.get() > 0 * WORDS, KJ_REQUIRE(tag->structRef.dataSize.get() > ZERO * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") { "Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault; goto useDefault;
} }
...@@ -2156,7 +2247,7 @@ struct WireHelpers { ...@@ -2156,7 +2247,7 @@ struct WireHelpers {
// in the struct is the pointer we were looking for, we want to munge the pointer to // in the struct is the pointer we were looking for, we want to munge the pointer to
// point at the first element's pointer section. // point at the first element's pointer section.
ptr += tag->structRef.dataSize.get(); ptr += tag->structRef.dataSize.get();
KJ_REQUIRE(tag->structRef.ptrCount.get() > 0 * POINTERS, KJ_REQUIRE(tag->structRef.ptrCount.get() > ZERO * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") { "Expected a pointer list, but got a list of data-only structs.") {
goto useDefault; goto useDefault;
} }
...@@ -2176,22 +2267,21 @@ struct WireHelpers { ...@@ -2176,22 +2267,21 @@ struct WireHelpers {
} else { } else {
// This is a primitive or pointer list, but all such lists can also be interpreted as struct // This is a primitive or pointer list, but all such lists can also be interpreted as struct
// lists. We need to compute the data size and pointer count for such structs. // lists. We need to compute the data size and pointer count for such structs.
BitCount dataSize = dataBitsPerElement(ref->listRef.elementSize()) * ELEMENTS; auto dataSize = dataBitsPerElement(ref->listRef.elementSize()) * ELEMENTS;
WirePointerCount pointerCount = auto pointerCount = pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS; auto elementCount = ref->listRef.elementCount();
ElementCount elementCount = ref->listRef.elementCount();
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
WordCount wordCount = roundBitsUpToWords(ElementCount64(elementCount) * step); auto wordCount = roundBitsUpToWords(upgradeGuard<uint64_t>(elementCount) * step);
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
if (elementSize == ElementSize::VOID) { if (elementSize == ElementSize::VOID) {
// Watch out for lists of void, which can claim to be arbitrarily large without having sent // Watch out for lists of void, which can claim to be arbitrarily large without having sent
// actual data. // actual data.
KJ_REQUIRE(amplifiedRead(segment, elementCount * (1 * WORDS / ELEMENTS)), KJ_REQUIRE(amplifiedRead(segment, elementCount * (ONE * WORDS / ELEMENTS)),
"Message contains amplified list pointer.") { "Message contains amplified list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -2243,7 +2333,8 @@ struct WireHelpers { ...@@ -2243,7 +2333,8 @@ struct WireHelpers {
if (ref->isNull()) { if (ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr) defaultValue = ""; if (defaultValue == nullptr) defaultValue = "";
return Text::Reader(reinterpret_cast<const char*>(defaultValue), defaultSize / BYTES); return Text::Reader(reinterpret_cast<const char*>(defaultValue),
unguard(defaultSize / BYTES));
} else { } else {
const word* ptr = followFars(ref, refTarget, segment); const word* ptr = followFars(ref, refTarget, segment);
...@@ -2252,7 +2343,7 @@ struct WireHelpers { ...@@ -2252,7 +2343,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
uint size = ref->listRef.elementCount() / ELEMENTS; auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS);
KJ_REQUIRE(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where text was expected.") { "Message contains non-list pointer where text was expected.") {
...@@ -2264,39 +2355,39 @@ struct WireHelpers { ...@@ -2264,39 +2355,39 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)),
roundBytesUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") { "Message contained out-of-bounds text pointer.") {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > ZERO * BYTES, "Message contains text that is not NUL-terminated.") {
goto useDefault; goto useDefault;
} }
const char* cptr = reinterpret_cast<const char*>(ptr); const char* cptr = reinterpret_cast<const char*>(ptr);
--size; // NUL terminator uint unguardedSize = unguard(size / BYTES) - 1;
KJ_REQUIRE(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(cptr[unguardedSize] == '\0', "Message contains text that is not NUL-terminated.") {
goto useDefault; goto useDefault;
} }
return Text::Reader(cptr, size); return Text::Reader(cptr, unguardedSize);
} }
} }
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer( static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
SegmentReader* segment, const WirePointer* ref, SegmentReader* segment, const WirePointer* ref,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, BlobSize defaultSize)) {
return readDataPointer(segment, ref, ref->target(), defaultValue, defaultSize); return readDataPointer(segment, ref, ref->target(), defaultValue, defaultSize);
} }
static KJ_ALWAYS_INLINE(Data::Reader readDataPointer( static KJ_ALWAYS_INLINE(Data::Reader readDataPointer(
SegmentReader* segment, const WirePointer* ref, const word* refTarget, SegmentReader* segment, const WirePointer* ref, const word* refTarget,
const void* defaultValue, ByteCount defaultSize)) { const void* defaultValue, BlobSize defaultSize)) {
if (ref->isNull()) { if (ref->isNull()) {
useDefault: useDefault:
return Data::Reader(reinterpret_cast<const byte*>(defaultValue), defaultSize / BYTES); return Data::Reader(reinterpret_cast<const byte*>(defaultValue),
unguard(defaultSize / BYTES));
} else { } else {
const word* ptr = followFars(ref, refTarget, segment); const word* ptr = followFars(ref, refTarget, segment);
...@@ -2305,7 +2396,7 @@ struct WireHelpers { ...@@ -2305,7 +2396,7 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
uint size = ref->listRef.elementCount() / ELEMENTS; auto size = ref->listRef.elementCount() * (ONE * BYTES / ELEMENTS);
KJ_REQUIRE(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where data was expected.") { "Message contains non-list pointer where data was expected.") {
...@@ -2317,13 +2408,12 @@ struct WireHelpers { ...@@ -2317,13 +2408,12 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + KJ_REQUIRE(boundsCheck(segment, ptr, ptr + roundBytesUpToWords(size)),
roundBytesUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") { "Message contained out-of-bounds data pointer.") {
goto useDefault; goto useDefault;
} }
return Data::Reader(reinterpret_cast<const byte*>(ptr), size); return Data::Reader(reinterpret_cast<const byte*>(ptr), unguard(size / BYTES));
} }
} }
}; };
...@@ -2362,7 +2452,8 @@ ListBuilder PointerBuilder::getListAnySize(const word* defaultValue) { ...@@ -2362,7 +2452,8 @@ ListBuilder PointerBuilder::getListAnySize(const word* defaultValue) {
template <> template <>
Text::Builder PointerBuilder::initBlob<Text>(ByteCount size) { Text::Builder PointerBuilder::initBlob<Text>(ByteCount size) {
return WireHelpers::initTextPointer(pointer, segment, capTable, size).value; return WireHelpers::initTextPointer(pointer, segment, capTable,
assertMax<MAX_TEXT_SIZE>(size, ThrowOverflow())).value;
} }
template <> template <>
void PointerBuilder::setBlob<Text>(Text::Reader value) { void PointerBuilder::setBlob<Text>(Text::Reader value) {
...@@ -2370,12 +2461,14 @@ void PointerBuilder::setBlob<Text>(Text::Reader value) { ...@@ -2370,12 +2461,14 @@ void PointerBuilder::setBlob<Text>(Text::Reader value) {
} }
template <> template <>
Text::Builder PointerBuilder::getBlob<Text>(const void* defaultValue, ByteCount defaultSize) { Text::Builder PointerBuilder::getBlob<Text>(const void* defaultValue, ByteCount defaultSize) {
return WireHelpers::getWritableTextPointer(pointer, segment, capTable, defaultValue, defaultSize); return WireHelpers::getWritableTextPointer(pointer, segment, capTable, defaultValue,
assertMax<MAX_TEXT_SIZE>(defaultSize, ThrowOverflow()));
} }
template <> template <>
Data::Builder PointerBuilder::initBlob<Data>(ByteCount size) { Data::Builder PointerBuilder::initBlob<Data>(ByteCount size) {
return WireHelpers::initDataPointer(pointer, segment, capTable, size).value; return WireHelpers::initDataPointer(pointer, segment, capTable,
assertMaxBits<BLOB_SIZE_BITS>(size, ThrowOverflow())).value;
} }
template <> template <>
void PointerBuilder::setBlob<Data>(Data::Reader value) { void PointerBuilder::setBlob<Data>(Data::Reader value) {
...@@ -2383,7 +2476,8 @@ void PointerBuilder::setBlob<Data>(Data::Reader value) { ...@@ -2383,7 +2476,8 @@ void PointerBuilder::setBlob<Data>(Data::Reader value) {
} }
template <> template <>
Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) { Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) {
return WireHelpers::getWritableDataPointer(pointer, segment, capTable, defaultValue, defaultSize); return WireHelpers::getWritableDataPointer(pointer, segment, capTable, defaultValue,
assertMaxBits<BLOB_SIZE_BITS>(defaultSize, ThrowOverflow()));
} }
void PointerBuilder::setStruct(const StructReader& value, bool canonical) { void PointerBuilder::setStruct(const StructReader& value, bool canonical) {
...@@ -2521,7 +2615,8 @@ Text::Reader PointerReader::getBlob<Text>(const void* defaultValue, ByteCount de ...@@ -2521,7 +2615,8 @@ Text::Reader PointerReader::getBlob<Text>(const void* defaultValue, ByteCount de
template <> template <>
Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const { Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const {
const WirePointer* ref = pointer == nullptr ? &zero.pointer : pointer; const WirePointer* ref = pointer == nullptr ? &zero.pointer : pointer;
return WireHelpers::readDataPointer(segment, ref, defaultValue, defaultSize); return WireHelpers::readDataPointer(segment, ref, defaultValue,
assertMaxBits<BLOB_SIZE_BITS>(defaultSize, ThrowOverflow()));
} }
#if !CAPNP_LITE #if !CAPNP_LITE
...@@ -2537,7 +2632,7 @@ const word* PointerReader::getUnchecked() const { ...@@ -2537,7 +2632,7 @@ const word* PointerReader::getUnchecked() const {
} }
MessageSizeCounts PointerReader::targetSize() const { MessageSizeCounts PointerReader::targetSize() const {
return pointer == nullptr ? MessageSizeCounts { 0 * WORDS, 0 } return pointer == nullptr ? MessageSizeCounts { ZERO * WORDS, 0 }
: WireHelpers::totalSize(segment, pointer, nestingLimit); : WireHelpers::totalSize(segment, pointer, nestingLimit);
} }
...@@ -2615,89 +2710,96 @@ bool PointerReader::isCanonical(const word **readHead) { ...@@ -2615,89 +2710,96 @@ bool PointerReader::isCanonical(const word **readHead) {
// StructBuilder // StructBuilder
void StructBuilder::clearAll() { void StructBuilder::clearAll() {
if (dataSize == 1 * BITS) { if (dataSize == ONE * BITS) {
setDataField<bool>(1 * ELEMENTS, false); setDataField<bool>(ONE * ELEMENTS, false);
} else { } else {
memset(data, 0, dataSize / BITS_PER_BYTE / BYTES); WireHelpers::zeroMemory(reinterpret_cast<byte*>(data), dataSize / BITS_PER_BYTE);
} }
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(pointerCount)) {
WireHelpers::zeroObject(segment, capTable, pointers + i); WireHelpers::zeroObject(segment, capTable, pointers + i);
} }
memset(pointers, 0, pointerCount * BYTES_PER_POINTER / BYTES); WireHelpers::zeroMemory(pointers, pointerCount);
} }
void StructBuilder::transferContentFrom(StructBuilder other) { void StructBuilder::transferContentFrom(StructBuilder other) {
// Determine the amount of data the builders have in common. // Determine the amount of data the builders have in common.
BitCount sharedDataSize = kj::min(dataSize, other.dataSize); auto sharedDataSize = kj::min(dataSize, other.dataSize);
if (dataSize > sharedDataSize) { if (dataSize > sharedDataSize) {
// Since the target is larger than the source, make sure to zero out the extra bits that the // Since the target is larger than the source, make sure to zero out the extra bits that the
// source doesn't have. // source doesn't have.
if (dataSize == 1 * BITS) { if (dataSize == ONE * BITS) {
setDataField<bool>(0 * ELEMENTS, false); setDataField<bool>(ZERO * ELEMENTS, false);
} else { } else {
byte* unshared = reinterpret_cast<byte*>(data) + sharedDataSize / BITS_PER_BYTE / BYTES; byte* unshared = reinterpret_cast<byte*>(data) + sharedDataSize / BITS_PER_BYTE;
memset(unshared, 0, (dataSize - sharedDataSize) / BITS_PER_BYTE / BYTES); // Note: this subtraction can't fail due to the if() above
WireHelpers::zeroMemory(unshared,
subtractChecked(dataSize, sharedDataSize, []() {}) / BITS_PER_BYTE);
} }
} }
// Copy over the shared part. // Copy over the shared part.
if (sharedDataSize == 1 * BITS) { if (sharedDataSize == ONE * BITS) {
setDataField<bool>(0 * ELEMENTS, other.getDataField<bool>(0 * ELEMENTS)); setDataField<bool>(ZERO * ELEMENTS, other.getDataField<bool>(ZERO * ELEMENTS));
} else { } else {
memcpy(data, other.data, sharedDataSize / BITS_PER_BYTE / BYTES); WireHelpers::copyMemory(reinterpret_cast<byte*>(data),
reinterpret_cast<byte*>(other.data),
sharedDataSize / BITS_PER_BYTE);
} }
// Zero out all pointers in the target. // Zero out all pointers in the target.
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(pointerCount)) {
WireHelpers::zeroObject(segment, capTable, pointers + i); WireHelpers::zeroObject(segment, capTable, pointers + i);
} }
memset(pointers, 0, pointerCount * BYTES_PER_POINTER / BYTES); WireHelpers::zeroMemory(pointers, pointerCount);
// Transfer the pointers. // Transfer the pointers.
WirePointerCount sharedPointerCount = kj::min(pointerCount, other.pointerCount); auto sharedPointerCount = kj::min(pointerCount, other.pointerCount);
for (uint i = 0; i < sharedPointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(sharedPointerCount)) {
WireHelpers::transferPointer(segment, pointers + i, other.segment, other.pointers + i); WireHelpers::transferPointer(segment, pointers + i, other.segment, other.pointers + i);
} }
// Zero out the pointers that were transferred in the source because it no longer has ownership. // Zero out the pointers that were transferred in the source because it no longer has ownership.
// If the source had any extra pointers that the destination didn't have space for, we // If the source had any extra pointers that the destination didn't have space for, we
// intentionally leave them be, so that they'll be cleaned up later. // intentionally leave them be, so that they'll be cleaned up later.
memset(other.pointers, 0, sharedPointerCount * BYTES_PER_POINTER / BYTES); WireHelpers::zeroMemory(other.pointers, sharedPointerCount);
} }
void StructBuilder::copyContentFrom(StructReader other) { void StructBuilder::copyContentFrom(StructReader other) {
// Determine the amount of data the builders have in common. // Determine the amount of data the builders have in common.
BitCount sharedDataSize = kj::min(dataSize, other.dataSize); auto sharedDataSize = kj::min(dataSize, other.dataSize);
if (dataSize > sharedDataSize) { if (dataSize > sharedDataSize) {
// Since the target is larger than the source, make sure to zero out the extra bits that the // Since the target is larger than the source, make sure to zero out the extra bits that the
// source doesn't have. // source doesn't have.
if (dataSize == 1 * BITS) { if (dataSize == ONE * BITS) {
setDataField<bool>(0 * ELEMENTS, false); setDataField<bool>(ZERO * ELEMENTS, false);
} else { } else {
byte* unshared = reinterpret_cast<byte*>(data) + sharedDataSize / BITS_PER_BYTE / BYTES; byte* unshared = reinterpret_cast<byte*>(data) + sharedDataSize / BITS_PER_BYTE;
memset(unshared, 0, (dataSize - sharedDataSize) / BITS_PER_BYTE / BYTES); WireHelpers::zeroMemory(unshared,
subtractChecked(dataSize, sharedDataSize, []() {}) / BITS_PER_BYTE);
} }
} }
// Copy over the shared part. // Copy over the shared part.
if (sharedDataSize == 1 * BITS) { if (sharedDataSize == ONE * BITS) {
setDataField<bool>(0 * ELEMENTS, other.getDataField<bool>(0 * ELEMENTS)); setDataField<bool>(ZERO * ELEMENTS, other.getDataField<bool>(ZERO * ELEMENTS));
} else { } else {
memcpy(data, other.data, sharedDataSize / BITS_PER_BYTE / BYTES); WireHelpers::copyMemory(reinterpret_cast<byte*>(data),
reinterpret_cast<const byte*>(other.data),
sharedDataSize / BITS_PER_BYTE);
} }
// Zero out all pointers in the target. // Zero out all pointers in the target.
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(pointerCount)) {
WireHelpers::zeroObject(segment, capTable, pointers + i); WireHelpers::zeroObject(segment, capTable, pointers + i);
} }
memset(pointers, 0, pointerCount * BYTES_PER_POINTER / BYTES); WireHelpers::zeroMemory(pointers, pointerCount);
// Copy the pointers. // Copy the pointers.
WirePointerCount sharedPointerCount = kj::min(pointerCount, other.pointerCount); auto sharedPointerCount = kj::min(pointerCount, other.pointerCount);
for (uint i = 0; i < sharedPointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(sharedPointerCount)) {
WireHelpers::copyPointer(segment, capTable, pointers + i, WireHelpers::copyPointer(segment, capTable, pointers + i,
other.segment, other.capTable, other.pointers + i, other.nestingLimit); other.segment, other.capTable, other.pointers + i, other.nestingLimit);
} }
...@@ -2729,7 +2831,7 @@ MessageSizeCounts StructReader::totalSize() const { ...@@ -2729,7 +2831,7 @@ MessageSizeCounts StructReader::totalSize() const {
MessageSizeCounts result = { MessageSizeCounts result = {
WireHelpers::roundBitsUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER, 0 }; WireHelpers::roundBitsUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER, 0 };
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (auto i: kj::zeroTo(pointerCount)) {
result += WireHelpers::totalSize(segment, pointers + i, nestingLimit); result += WireHelpers::totalSize(segment, pointers + i, nestingLimit);
} }
...@@ -2812,12 +2914,12 @@ bool StructReader::isCanonical(const word **readHead, ...@@ -2812,12 +2914,12 @@ bool StructReader::isCanonical(const word **readHead,
// ListBuilder // ListBuilder
Text::Builder ListBuilder::asText() { Text::Builder ListBuilder::asText() {
KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Text::Builder(); return Text::Builder();
} }
size_t size = elementCount / ELEMENTS; size_t size = unguard(elementCount / ELEMENTS);
KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") {
return Text::Builder(); return Text::Builder();
...@@ -2834,18 +2936,18 @@ Text::Builder ListBuilder::asText() { ...@@ -2834,18 +2936,18 @@ Text::Builder ListBuilder::asText() {
} }
Data::Builder ListBuilder::asData() { Data::Builder ListBuilder::asData() {
KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Data::Builder(); return Data::Builder();
} }
return Data::Builder(reinterpret_cast<byte*>(ptr), elementCount / ELEMENTS); return Data::Builder(reinterpret_cast<byte*>(ptr), unguard(elementCount / ELEMENTS));
} }
StructBuilder ListBuilder::getStructElement(ElementCount index) { StructBuilder ListBuilder::getStructElement(ElementCount index) {
BitCount64 indexBit = ElementCount64(index) * step; auto indexBit = upgradeGuard<uint64_t>(index) * step;
byte* structData = ptr + indexBit / BITS_PER_BYTE; byte* structData = ptr + indexBit / BITS_PER_BYTE;
KJ_DASSERT(indexBit % BITS_PER_BYTE == 0 * BITS); KJ_DASSERT(indexBit % BITS_PER_BYTE == ZERO * BITS);
return StructBuilder(segment, capTable, structData, return StructBuilder(segment, capTable, structData,
reinterpret_cast<WirePointer*>(structData + structDataSize / BITS_PER_BYTE), reinterpret_cast<WirePointer*>(structData + structDataSize / BITS_PER_BYTE),
structDataSize, structPointerCount); structDataSize, structPointerCount);
...@@ -2874,12 +2976,12 @@ ListBuilder ListBuilder::imbue(CapTableBuilder* capTable) { ...@@ -2874,12 +2976,12 @@ ListBuilder ListBuilder::imbue(CapTableBuilder* capTable) {
// ListReader // ListReader
Text::Reader ListReader::asText() { Text::Reader ListReader::asText() {
KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Text::Reader(); return Text::Reader();
} }
size_t size = elementCount / ELEMENTS; size_t size = unguard(elementCount / ELEMENTS);
KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") {
return Text::Reader(); return Text::Reader();
...@@ -2896,12 +2998,12 @@ Text::Reader ListReader::asText() { ...@@ -2896,12 +2998,12 @@ Text::Reader ListReader::asText() {
} }
Data::Reader ListReader::asData() { Data::Reader ListReader::asData() {
KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == G(8) * BITS && structPointerCount == ZERO * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Data::Reader(); return Data::Reader();
} }
return Data::Reader(reinterpret_cast<const byte*>(ptr), elementCount / ELEMENTS); return Data::Reader(reinterpret_cast<const byte*>(ptr), unguard(elementCount / ELEMENTS));
} }
kj::ArrayPtr<const byte> ListReader::asRawBytes() { kj::ArrayPtr<const byte> ListReader::asRawBytes() {
...@@ -2920,17 +3022,17 @@ StructReader ListReader::getStructElement(ElementCount index) const { ...@@ -2920,17 +3022,17 @@ StructReader ListReader::getStructElement(ElementCount index) const {
return StructReader(); return StructReader();
} }
BitCount64 indexBit = ElementCount64(index) * step; auto indexBit = upgradeGuard<uint64_t>(index) * step;
const byte* structData = ptr + indexBit / BITS_PER_BYTE; const byte* structData = ptr + indexBit / BITS_PER_BYTE;
const WirePointer* structPointers = const WirePointer* structPointers =
reinterpret_cast<const WirePointer*>(structData + structDataSize / BITS_PER_BYTE); reinterpret_cast<const WirePointer*>(structData + structDataSize / BITS_PER_BYTE);
// This check should pass if there are no bugs in the list pointer validation code. // This check should pass if there are no bugs in the list pointer validation code.
KJ_DASSERT(structPointerCount == 0 * POINTERS || KJ_DASSERT(structPointerCount == ZERO * POINTERS ||
(uintptr_t)structPointers % sizeof(void*) == 0, (uintptr_t)structPointers % sizeof(void*) == 0,
"Pointer section of struct list element not aligned."); "Pointer section of struct list element not aligned.");
KJ_DASSERT(indexBit % BITS_PER_BYTE == 0 * BITS); KJ_DASSERT(indexBit % BITS_PER_BYTE == ZERO * BITS);
return StructReader( return StructReader(
segment, capTable, structData, structPointers, segment, capTable, structData, structPointers,
structDataSize, structPointerCount, structDataSize, structPointerCount,
...@@ -3062,7 +3164,8 @@ OrphanBuilder OrphanBuilder::initStructList( ...@@ -3062,7 +3164,8 @@ OrphanBuilder OrphanBuilder::initStructList(
OrphanBuilder OrphanBuilder::initText( OrphanBuilder OrphanBuilder::initText(
BuilderArena* arena, CapTableBuilder* capTable, ByteCount size) { BuilderArena* arena, CapTableBuilder* capTable, ByteCount size) {
OrphanBuilder result; OrphanBuilder result;
auto allocation = WireHelpers::initTextPointer(result.tagAsPtr(), nullptr, capTable, size, arena); auto allocation = WireHelpers::initTextPointer(result.tagAsPtr(), nullptr, capTable,
assertMax<MAX_TEXT_SIZE>(size, ThrowOverflow()), arena);
result.segment = allocation.segment; result.segment = allocation.segment;
result.capTable = capTable; result.capTable = capTable;
result.location = reinterpret_cast<word*>(allocation.value.begin()); result.location = reinterpret_cast<word*>(allocation.value.begin());
...@@ -3072,7 +3175,8 @@ OrphanBuilder OrphanBuilder::initText( ...@@ -3072,7 +3175,8 @@ OrphanBuilder OrphanBuilder::initText(
OrphanBuilder OrphanBuilder::initData( OrphanBuilder OrphanBuilder::initData(
BuilderArena* arena, CapTableBuilder* capTable, ByteCount size) { BuilderArena* arena, CapTableBuilder* capTable, ByteCount size) {
OrphanBuilder result; OrphanBuilder result;
auto allocation = WireHelpers::initDataPointer(result.tagAsPtr(), nullptr, capTable, size, arena); auto allocation = WireHelpers::initDataPointer(result.tagAsPtr(), nullptr, capTable,
assertMaxBits<BLOB_SIZE_BITS>(size), arena);
result.segment = allocation.segment; result.segment = allocation.segment;
result.capTable = capTable; result.capTable = capTable;
result.location = reinterpret_cast<word*>(allocation.value.begin()); result.location = reinterpret_cast<word*>(allocation.value.begin());
...@@ -3236,12 +3340,14 @@ OrphanBuilder OrphanBuilder::referenceExternalData(BuilderArena* arena, Data::Re ...@@ -3236,12 +3340,14 @@ OrphanBuilder OrphanBuilder::referenceExternalData(BuilderArena* arena, Data::Re
KJ_REQUIRE(reinterpret_cast<uintptr_t>(data.begin()) % sizeof(void*) == 0, KJ_REQUIRE(reinterpret_cast<uintptr_t>(data.begin()) % sizeof(void*) == 0,
"Cannot referenceExternalData() that is not aligned."); "Cannot referenceExternalData() that is not aligned.");
auto wordCount = WireHelpers::roundBytesUpToWords(data.size() * BYTES); auto checkedSize = assertMaxBits<BLOB_SIZE_BITS>(guarded(data.size()));
kj::ArrayPtr<const word> words(reinterpret_cast<const word*>(data.begin()), wordCount / WORDS); auto wordCount = WireHelpers::roundBytesUpToWords(checkedSize * BYTES);
kj::ArrayPtr<const word> words(reinterpret_cast<const word*>(data.begin()),
unguard(wordCount / WORDS));
OrphanBuilder result; OrphanBuilder result;
result.tagAsPtr()->setKindForOrphan(WirePointer::LIST); result.tagAsPtr()->setKindForOrphan(WirePointer::LIST);
result.tagAsPtr()->listRef.set(ElementSize::BYTE, data.size() * ELEMENTS); result.tagAsPtr()->listRef.set(ElementSize::BYTE, checkedSize * ELEMENTS);
result.segment = arena->addExternalSegment(words); result.segment = arena->addExternalSegment(words);
// External data cannot possibly contain capabilities. // External data cannot possibly contain capabilities.
...@@ -3297,7 +3403,7 @@ Text::Builder OrphanBuilder::asText() { ...@@ -3297,7 +3403,7 @@ Text::Builder OrphanBuilder::asText() {
// Never relocates. // Never relocates.
return WireHelpers::getWritableTextPointer( return WireHelpers::getWritableTextPointer(
tagAsPtr(), location, segment, capTable, nullptr, 0 * BYTES); tagAsPtr(), location, segment, capTable, nullptr, ZERO * BYTES);
} }
Data::Builder OrphanBuilder::asData() { Data::Builder OrphanBuilder::asData() {
...@@ -3305,7 +3411,7 @@ Data::Builder OrphanBuilder::asData() { ...@@ -3305,7 +3411,7 @@ Data::Builder OrphanBuilder::asData() {
// Never relocates. // Never relocates.
return WireHelpers::getWritableDataPointer( return WireHelpers::getWritableDataPointer(
tagAsPtr(), location, segment, capTable, nullptr, 0 * BYTES); tagAsPtr(), location, segment, capTable, nullptr, ZERO * BYTES);
} }
StructReader OrphanBuilder::asStructReader(StructSize size) const { StructReader OrphanBuilder::asStructReader(StructSize size) const {
...@@ -3328,12 +3434,12 @@ kj::Own<ClientHook> OrphanBuilder::asCapability() const { ...@@ -3328,12 +3434,12 @@ kj::Own<ClientHook> OrphanBuilder::asCapability() const {
Text::Reader OrphanBuilder::asTextReader() const { Text::Reader OrphanBuilder::asTextReader() const {
KJ_DASSERT(tagAsPtr()->isNull() == (location == nullptr)); KJ_DASSERT(tagAsPtr()->isNull() == (location == nullptr));
return WireHelpers::readTextPointer(segment, tagAsPtr(), location, nullptr, 0 * BYTES); return WireHelpers::readTextPointer(segment, tagAsPtr(), location, nullptr, ZERO * BYTES);
} }
Data::Reader OrphanBuilder::asDataReader() const { Data::Reader OrphanBuilder::asDataReader() const {
KJ_DASSERT(tagAsPtr()->isNull() == (location == nullptr)); KJ_DASSERT(tagAsPtr()->isNull() == (location == nullptr));
return WireHelpers::readDataPointer(segment, tagAsPtr(), location, nullptr, 0 * BYTES); return WireHelpers::readDataPointer(segment, tagAsPtr(), location, nullptr, ZERO * BYTES);
} }
bool OrphanBuilder::truncate(ElementCount size, bool isText) { bool OrphanBuilder::truncate(ElementCount size, bool isText) {
......
...@@ -82,26 +82,48 @@ class BuilderArena; ...@@ -82,26 +82,48 @@ class BuilderArena;
// ============================================================================= // =============================================================================
typedef decltype(BITS / ELEMENTS) BitsPerElement; #if CAPNP_DEBUG_TYPES
typedef decltype(POINTERS / ELEMENTS) PointersPerElement; typedef kj::UnitRatio<kj::Guarded<64, uint>, BitLabel, ElementLabel> BitsPerElementTableType;
#else
static constexpr BitsPerElement BITS_PER_ELEMENT_TABLE[8] = { typedef uint BitsPerElementTableType;
0 * BITS / ELEMENTS, #endif
1 * BITS / ELEMENTS,
8 * BITS / ELEMENTS, static constexpr BitsPerElementTableType BITS_PER_ELEMENT_TABLE[8] = {
16 * BITS / ELEMENTS, guarded< 0>() * BITS / ELEMENTS,
32 * BITS / ELEMENTS, guarded< 1>() * BITS / ELEMENTS,
64 * BITS / ELEMENTS, guarded< 8>() * BITS / ELEMENTS,
0 * BITS / ELEMENTS, guarded<16>() * BITS / ELEMENTS,
0 * BITS / ELEMENTS guarded<32>() * BITS / ELEMENTS,
guarded<64>() * BITS / ELEMENTS,
guarded< 0>() * BITS / ELEMENTS,
guarded< 0>() * BITS / ELEMENTS
}; };
inline KJ_CONSTEXPR() BitsPerElement dataBitsPerElement(ElementSize size) { inline KJ_CONSTEXPR() BitsPerElementTableType dataBitsPerElement(ElementSize size) {
return _::BITS_PER_ELEMENT_TABLE[static_cast<int>(size)]; return _::BITS_PER_ELEMENT_TABLE[static_cast<int>(size)];
} }
inline constexpr PointersPerElement pointersPerElement(ElementSize size) { inline constexpr PointersPerElementN<1> pointersPerElement(ElementSize size) {
return size == ElementSize::POINTER ? 1 * POINTERS / ELEMENTS : 0 * POINTERS / ELEMENTS; if (size == ElementSize::POINTER) {
return ONE * POINTERS / ELEMENTS;
} else {
return ZERO * POINTERS / ELEMENTS;
}
}
static constexpr BitsPerElementTableType BITS_PER_ELEMENT_INCLUDING_PONITERS_TABLE[8] = {
guarded< 0>() * BITS / ELEMENTS,
guarded< 1>() * BITS / ELEMENTS,
guarded< 8>() * BITS / ELEMENTS,
guarded<16>() * BITS / ELEMENTS,
guarded<32>() * BITS / ELEMENTS,
guarded<64>() * BITS / ELEMENTS,
guarded<64>() * BITS / ELEMENTS,
guarded< 0>() * BITS / ELEMENTS
};
inline KJ_CONSTEXPR() BitsPerElementTableType bitsPerElementIncludingPointers(ElementSize size) {
return _::BITS_PER_ELEMENT_INCLUDING_PONITERS_TABLE[static_cast<int>(size)];
} }
template <size_t size> struct ElementSizeForByteSize; template <size_t size> struct ElementSizeForByteSize;
...@@ -142,17 +164,23 @@ inline constexpr ElementSize elementSizeForType() { ...@@ -142,17 +164,23 @@ inline constexpr ElementSize elementSizeForType() {
} }
struct MessageSizeCounts { struct MessageSizeCounts {
WordCount64 wordCount; WordCountN<61, uint64_t> wordCount; // 2^64 bytes
uint capCount; uint capCount;
MessageSizeCounts& operator+=(const MessageSizeCounts& other) { MessageSizeCounts& operator+=(const MessageSizeCounts& other) {
wordCount += other.wordCount; // OK to truncate unchecked because this class is used to count actual stuff in memory, and
// we couldn't possibly have anywhere near 2^61 words.
wordCount = assumeBits<61>(wordCount + other.wordCount);
capCount += other.capCount; capCount += other.capCount;
return *this; return *this;
} }
void addWords(WordCountN<61, uint64_t> other) {
wordCount = assumeBits<61>(wordCount + other);
}
MessageSize asPublic() { MessageSize asPublic() {
return MessageSize { wordCount / WORDS, capCount }; return MessageSize { unguard(wordCount / WORDS), capCount };
} }
}; };
...@@ -168,13 +196,13 @@ union AlignedData { ...@@ -168,13 +196,13 @@ union AlignedData {
}; };
struct StructSize { struct StructSize {
WordCount16 data; StructDataWordCount data;
WirePointerCount16 pointers; StructPointerCount pointers;
inline constexpr WordCount total() const { return data + pointers * WORDS_PER_POINTER; } inline constexpr WordCountN<17> total() const { return data + pointers * WORDS_PER_POINTER; }
StructSize() = default; StructSize() = default;
inline constexpr StructSize(WordCount data, WirePointerCount pointers) inline constexpr StructSize(StructDataWordCount data, StructPointerCount pointers)
: data(data), pointers(pointers) {} : data(data), pointers(pointers) {}
}; };
...@@ -324,7 +352,8 @@ public: ...@@ -324,7 +352,8 @@ public:
ListBuilder getList(ElementSize elementSize, const word* defaultValue); ListBuilder getList(ElementSize elementSize, const word* defaultValue);
ListBuilder getStructList(StructSize elementSize, const word* defaultValue); ListBuilder getStructList(StructSize elementSize, const word* defaultValue);
ListBuilder getListAnySize(const word* defaultValue); ListBuilder getListAnySize(const word* defaultValue);
template <typename T> typename T::Builder getBlob(const void* defaultValue,ByteCount defaultSize); template <typename T> typename T::Builder getBlob(
const void* defaultValue, ByteCount defaultSize);
#if !CAPNP_LITE #if !CAPNP_LITE
kj::Own<ClientHook> getCapability(); kj::Own<ClientHook> getCapability();
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
...@@ -474,37 +503,36 @@ public: ...@@ -474,37 +503,36 @@ public:
// Get the object's location. Only valid for independently-allocated objects (i.e. not list // Get the object's location. Only valid for independently-allocated objects (i.e. not list
// elements). // elements).
inline BitCount getDataSectionSize() const { return dataSize; } inline StructDataBitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; } inline StructPointerCount getPointerSectionSize() const { return pointerCount; }
inline kj::ArrayPtr<byte> getDataSectionAsBlob(); inline kj::ArrayPtr<byte> getDataSectionAsBlob();
inline _::ListBuilder getPointerSectionAsList(); inline _::ListBuilder getPointerSectionAsList();
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(bool hasDataField(ElementCount offset)); KJ_ALWAYS_INLINE(bool hasDataField(StructDataElementOffset offset));
// Return true if the field is set to something other than its default value. // Return true if the field is set to something other than its default value.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(T getDataField(ElementCount offset)); KJ_ALWAYS_INLINE(T getDataField(StructDataElementOffset offset));
// Gets the data field value of the given type at the given offset. The offset is measured in // Gets the data field value of the given type at the given offset. The offset is measured in
// multiples of the field size, determined by the type. // multiples of the field size, determined by the type.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(T getDataField(ElementCount offset, Mask<T> mask)); KJ_ALWAYS_INLINE(T getDataField(StructDataElementOffset offset, Mask<T> mask));
// Like getDataField() but applies the given XOR mask to the data on load. Used for reading // Like getDataField() but applies the given XOR mask to the data on load. Used for reading
// fields with non-zero default values. // fields with non-zero default values.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(void setDataField( KJ_ALWAYS_INLINE(void setDataField(StructDataElementOffset offset, kj::NoInfer<T> value));
ElementCount offset, kj::NoInfer<T> value));
// Sets the data field value at the given offset. // Sets the data field value at the given offset.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(void setDataField( KJ_ALWAYS_INLINE(void setDataField(StructDataElementOffset offset,
ElementCount offset, kj::NoInfer<T> value, Mask<T> mask)); kj::NoInfer<T> value, Mask<T> mask));
// Like setDataField() but applies the given XOR mask before storing. Used for writing fields // Like setDataField() but applies the given XOR mask before storing. Used for writing fields
// with non-zero default values. // with non-zero default values.
KJ_ALWAYS_INLINE(PointerBuilder getPointerField(WirePointerCount ptrIndex)); KJ_ALWAYS_INLINE(PointerBuilder getPointerField(StructPointerCount ptrIndex));
// Get a builder for a pointer field given the index within the pointer section. // Get a builder for a pointer field given the index within the pointer section.
void clearAll(); void clearAll();
...@@ -538,15 +566,15 @@ private: ...@@ -538,15 +566,15 @@ private:
void* data; // Pointer to the encoded data. void* data; // Pointer to the encoded data.
WirePointer* pointers; // Pointer to the encoded pointers. WirePointer* pointers; // Pointer to the encoded pointers.
BitCount32 dataSize; StructDataBitCount dataSize;
// Size of data section. We use a bit count rather than a word count to more easily handle the // Size of data section. We use a bit count rather than a word count to more easily handle the
// case of struct lists encoded with less than a word per element. // case of struct lists encoded with less than a word per element.
WirePointerCount16 pointerCount; // Size of the pointer section. StructPointerCount pointerCount; // Size of the pointer section.
inline StructBuilder(SegmentBuilder* segment, CapTableBuilder* capTable, inline StructBuilder(SegmentBuilder* segment, CapTableBuilder* capTable,
void* data, WirePointer* pointers, void* data, WirePointer* pointers,
BitCount dataSize, WirePointerCount pointerCount) StructDataBitCount dataSize, StructPointerCount pointerCount)
: segment(segment), capTable(capTable), data(data), pointers(pointers), : segment(segment), capTable(capTable), data(data), pointers(pointers),
dataSize(dataSize), pointerCount(pointerCount) {} dataSize(dataSize), pointerCount(pointerCount) {}
...@@ -558,38 +586,38 @@ private: ...@@ -558,38 +586,38 @@ private:
class StructReader { class StructReader {
public: public:
inline StructReader() inline StructReader()
: segment(nullptr), capTable(nullptr), data(nullptr), pointers(nullptr), dataSize(0), : segment(nullptr), capTable(nullptr), data(nullptr), pointers(nullptr),
pointerCount(0), nestingLimit(0x7fffffff) {} dataSize(ZERO * BITS), pointerCount(ZERO * POINTERS), nestingLimit(0x7fffffff) {}
inline StructReader(kj::ArrayPtr<const word> data) inline StructReader(kj::ArrayPtr<const word> data)
: segment(nullptr), capTable(nullptr), data(data.begin()), pointers(nullptr), : segment(nullptr), capTable(nullptr), data(data.begin()), pointers(nullptr),
dataSize(data.size() * WORDS * BITS_PER_WORD), pointerCount(0), nestingLimit(0x7fffffff) {} dataSize(data.size() * WORDS * BITS_PER_WORD), pointerCount(ZERO * POINTERS),
nestingLimit(0x7fffffff) {}
const void* getLocation() const { return data; } const void* getLocation() const { return data; }
inline BitCount getDataSectionSize() const { return dataSize; } inline StructDataBitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; } inline StructPointerCount getPointerSectionSize() const { return pointerCount; }
inline kj::ArrayPtr<const byte> getDataSectionAsBlob(); inline kj::ArrayPtr<const byte> getDataSectionAsBlob();
inline _::ListReader getPointerSectionAsList(); inline _::ListReader getPointerSectionAsList();
kj::Array<word> canonicalize(); kj::Array<word> canonicalize();
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(bool hasDataField(ElementCount offset) const); KJ_ALWAYS_INLINE(bool hasDataField(StructDataElementOffset offset) const);
// Return true if the field is set to something other than its default value. // Return true if the field is set to something other than its default value.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(T getDataField(ElementCount offset) const); KJ_ALWAYS_INLINE(T getDataField(StructDataElementOffset offset) const);
// Get the data field value of the given type at the given offset. The offset is measured in // Get the data field value of the given type at the given offset. The offset is measured in
// multiples of the field size, determined by the type. Returns zero if the offset is past the // multiples of the field size, determined by the type. Returns zero if the offset is past the
// end of the struct's data section. // end of the struct's data section.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE( KJ_ALWAYS_INLINE(T getDataField(StructDataElementOffset offset, Mask<T> mask) const);
T getDataField(ElementCount offset, Mask<T> mask) const);
// Like getDataField(offset), but applies the given XOR mask to the result. Used for reading // Like getDataField(offset), but applies the given XOR mask to the result. Used for reading
// fields with non-zero default values. // fields with non-zero default values.
KJ_ALWAYS_INLINE(PointerReader getPointerField(WirePointerCount ptrIndex) const); KJ_ALWAYS_INLINE(PointerReader getPointerField(StructPointerCount ptrIndex) const);
// Get a reader for a pointer field given the index within the pointer section. If the index // Get a reader for a pointer field given the index within the pointer section. If the index
// is out-of-bounds, returns a null pointer. // is out-of-bounds, returns a null pointer.
...@@ -628,11 +656,11 @@ private: ...@@ -628,11 +656,11 @@ private:
const void* data; const void* data;
const WirePointer* pointers; const WirePointer* pointers;
BitCount32 dataSize; StructDataBitCount dataSize;
// Size of data section. We use a bit count rather than a word count to more easily handle the // Size of data section. We use a bit count rather than a word count to more easily handle the
// case of struct lists encoded with less than a word per element. // case of struct lists encoded with less than a word per element.
WirePointerCount16 pointerCount; // Size of the pointer section. StructPointerCount pointerCount; // Size of the pointer section.
int nestingLimit; int nestingLimit;
// Limits the depth of message structures to guard against stack-overflow-based DoS attacks. // Limits the depth of message structures to guard against stack-overflow-based DoS attacks.
...@@ -641,7 +669,8 @@ private: ...@@ -641,7 +669,8 @@ private:
inline StructReader(SegmentReader* segment, CapTableReader* capTable, inline StructReader(SegmentReader* segment, CapTableReader* capTable,
const void* data, const WirePointer* pointers, const void* data, const WirePointer* pointers,
BitCount dataSize, WirePointerCount pointerCount, int nestingLimit) StructDataBitCount dataSize, StructPointerCount pointerCount,
int nestingLimit)
: segment(segment), capTable(capTable), data(data), pointers(pointers), : segment(segment), capTable(capTable), data(data), pointers(pointers),
dataSize(dataSize), pointerCount(pointerCount), dataSize(dataSize), pointerCount(pointerCount),
nestingLimit(nestingLimit) {} nestingLimit(nestingLimit) {}
...@@ -657,8 +686,8 @@ class ListBuilder: public kj::DisallowConstCopy { ...@@ -657,8 +686,8 @@ class ListBuilder: public kj::DisallowConstCopy {
public: public:
inline explicit ListBuilder(ElementSize elementSize) inline explicit ListBuilder(ElementSize elementSize)
: segment(nullptr), capTable(nullptr), ptr(nullptr), elementCount(0 * ELEMENTS), : segment(nullptr), capTable(nullptr), ptr(nullptr), elementCount(0 * ELEMENTS),
step(0 * BITS / ELEMENTS), structDataSize(0 * BITS), structPointerCount(0 * POINTERS), step(ZERO * BITS / ELEMENTS), structDataSize(0 * BITS),
elementSize(elementSize) {} structPointerCount(ZERO * POINTERS), elementSize(elementSize) {}
inline word* getLocation() { inline word* getLocation() {
// Get the object's location. // Get the object's location.
...@@ -672,7 +701,7 @@ public: ...@@ -672,7 +701,7 @@ public:
inline ElementSize getElementSize() const { return elementSize; } inline ElementSize getElementSize() const { return elementSize; }
inline ElementCount size() const; inline ListElementCount size() const;
// The number of elements in the list. // The number of elements in the list.
Text::Builder asText(); Text::Builder asText();
...@@ -684,8 +713,7 @@ public: ...@@ -684,8 +713,7 @@ public:
// Get the element of the given type at the given index. // Get the element of the given type at the given index.
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(void setDataElement( KJ_ALWAYS_INLINE(void setDataElement(ElementCount index, kj::NoInfer<T> value));
ElementCount index, kj::NoInfer<T> value));
// Set the element at the given index. // Set the element at the given index.
KJ_ALWAYS_INLINE(PointerBuilder getPointerElement(ElementCount index)); KJ_ALWAYS_INLINE(PointerBuilder getPointerElement(ElementCount index));
...@@ -710,13 +738,14 @@ private: ...@@ -710,13 +738,14 @@ private:
byte* ptr; // Pointer to list content. byte* ptr; // Pointer to list content.
ElementCount elementCount; // Number of elements in the list. ListElementCount elementCount; // Number of elements in the list.
decltype(BITS / ELEMENTS) step; BitsPerElementN<23> step;
// The distance between elements. // The distance between elements. The maximum value occurs when a struct contains 2^16-1 data
// words and 2^16-1 pointers, i.e. 2^17 - 2 words, or 2^23 - 128 bits.
BitCount32 structDataSize; StructDataBitCount structDataSize;
WirePointerCount16 structPointerCount; StructPointerCount structPointerCount;
// The struct properties to use when interpreting the elements as structs. All lists can be // The struct properties to use when interpreting the elements as structs. All lists can be
// interpreted as struct lists, so these are always filled in. // interpreted as struct lists, so these are always filled in.
...@@ -726,7 +755,7 @@ private: ...@@ -726,7 +755,7 @@ private:
inline ListBuilder(SegmentBuilder* segment, CapTableBuilder* capTable, void* ptr, inline ListBuilder(SegmentBuilder* segment, CapTableBuilder* capTable, void* ptr,
decltype(BITS / ELEMENTS) step, ElementCount size, decltype(BITS / ELEMENTS) step, ElementCount size,
BitCount structDataSize, WirePointerCount structPointerCount, StructDataBitCount structDataSize, StructPointerCount structPointerCount,
ElementSize elementSize) ElementSize elementSize)
: segment(segment), capTable(capTable), ptr(reinterpret_cast<byte*>(ptr)), : segment(segment), capTable(capTable), ptr(reinterpret_cast<byte*>(ptr)),
elementCount(size), step(step), structDataSize(structDataSize), elementCount(size), step(step), structDataSize(structDataSize),
...@@ -740,11 +769,11 @@ private: ...@@ -740,11 +769,11 @@ private:
class ListReader { class ListReader {
public: public:
inline explicit ListReader(ElementSize elementSize) inline explicit ListReader(ElementSize elementSize)
: segment(nullptr), capTable(nullptr), ptr(nullptr), elementCount(0), : segment(nullptr), capTable(nullptr), ptr(nullptr), elementCount(ZERO * ELEMENTS),
step(0 * BITS / ELEMENTS), structDataSize(0), structPointerCount(0), step(ZERO * BITS / ELEMENTS), structDataSize(ZERO * BITS),
elementSize(elementSize), nestingLimit(0x7fffffff) {} structPointerCount(ZERO * POINTERS), elementSize(elementSize), nestingLimit(0x7fffffff) {}
inline ElementCount size() const; inline ListElementCount size() const;
// The number of elements in the list. // The number of elements in the list.
inline ElementSize getElementSize() const { return elementSize; } inline ElementSize getElementSize() const { return elementSize; }
...@@ -782,13 +811,14 @@ private: ...@@ -782,13 +811,14 @@ private:
const byte* ptr; // Pointer to list content. const byte* ptr; // Pointer to list content.
ElementCount elementCount; // Number of elements in the list. ListElementCount elementCount; // Number of elements in the list.
decltype(BITS / ELEMENTS) step; BitsPerElementN<23> step;
// The distance between elements. // The distance between elements. The maximum value occurs when a struct contains 2^16-1 data
// words and 2^16-1 pointers, i.e. 2^17 - 2 words, or 2^23 - 2 bits.
BitCount32 structDataSize; StructDataBitCount structDataSize;
WirePointerCount16 structPointerCount; StructPointerCount structPointerCount;
// The struct properties to use when interpreting the elements as structs. All lists can be // The struct properties to use when interpreting the elements as structs. All lists can be
// interpreted as struct lists, so these are always filled in. // interpreted as struct lists, so these are always filled in.
...@@ -801,8 +831,8 @@ private: ...@@ -801,8 +831,8 @@ private:
// Once this reaches zero, further pointers will be pruned. // Once this reaches zero, further pointers will be pruned.
inline ListReader(SegmentReader* segment, CapTableReader* capTable, const void* ptr, inline ListReader(SegmentReader* segment, CapTableReader* capTable, const void* ptr,
ElementCount elementCount, decltype(BITS / ELEMENTS) step, ElementCount elementCount, BitsPerElementN<23> step,
BitCount structDataSize, WirePointerCount structPointerCount, StructDataBitCount structDataSize, StructPointerCount structPointerCount,
ElementSize elementSize, int nestingLimit) ElementSize elementSize, int nestingLimit)
: segment(segment), capTable(capTable), ptr(reinterpret_cast<const byte*>(ptr)), : segment(segment), capTable(capTable), ptr(reinterpret_cast<const byte*>(ptr)),
elementCount(elementCount), step(step), structDataSize(structDataSize), elementCount(elementCount), step(step), structDataSize(structDataSize),
...@@ -888,7 +918,7 @@ public: ...@@ -888,7 +918,7 @@ public:
// Versions of truncate() that know how to allocate a new list if needed. // Versions of truncate() that know how to allocate a new list if needed.
private: private:
static_assert(1 * POINTERS * WORDS_PER_POINTER == 1 * WORDS, static_assert(ONE * POINTERS * WORDS_PER_POINTER == ONE * WORDS,
"This struct assumes a pointer is one word."); "This struct assumes a pointer is one word.");
word tag; word tag;
// Contains an encoded WirePointer representing this object. WirePointer is defined in // Contains an encoded WirePointer representing this object. WirePointer is defined in
...@@ -934,13 +964,17 @@ private: ...@@ -934,13 +964,17 @@ private:
// These are defined in the source file. // These are defined in the source file.
template <> typename Text::Builder PointerBuilder::initBlob<Text>(ByteCount size); template <> typename Text::Builder PointerBuilder::initBlob<Text>(ByteCount size);
template <> void PointerBuilder::setBlob<Text>(typename Text::Reader value); template <> void PointerBuilder::setBlob<Text>(typename Text::Reader value);
template <> typename Text::Builder PointerBuilder::getBlob<Text>(const void* defaultValue, ByteCount defaultSize); template <> typename Text::Builder PointerBuilder::getBlob<Text>(
template <> typename Text::Reader PointerReader::getBlob<Text>(const void* defaultValue, ByteCount defaultSize) const; const void* defaultValue, ByteCount defaultSize);
template <> typename Text::Reader PointerReader::getBlob<Text>(
const void* defaultValue, ByteCount defaultSize) const;
template <> typename Data::Builder PointerBuilder::initBlob<Data>(ByteCount size); template <> typename Data::Builder PointerBuilder::initBlob<Data>(ByteCount size);
template <> void PointerBuilder::setBlob<Data>(typename Data::Reader value); template <> void PointerBuilder::setBlob<Data>(typename Data::Reader value);
template <> typename Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize); template <> typename Data::Builder PointerBuilder::getBlob<Data>(
template <> typename Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const; const void* defaultValue, ByteCount defaultSize);
template <> typename Data::Reader PointerReader::getBlob<Data>(
const void* defaultValue, ByteCount defaultSize) const;
inline PointerBuilder PointerBuilder::getRoot( inline PointerBuilder PointerBuilder::getRoot(
SegmentBuilder* segment, CapTableBuilder* capTable, word* location) { SegmentBuilder* segment, CapTableBuilder* capTable, word* location) {
...@@ -955,82 +989,85 @@ inline PointerReader PointerReader::getRootUnchecked(const word* location) { ...@@ -955,82 +989,85 @@ inline PointerReader PointerReader::getRootUnchecked(const word* location) {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline kj::ArrayPtr<byte> StructBuilder::getDataSectionAsBlob() { inline kj::ArrayPtr<byte> StructBuilder::getDataSectionAsBlob() {
return kj::ArrayPtr<byte>(reinterpret_cast<byte*>(data), dataSize / BITS_PER_BYTE / BYTES); return kj::ArrayPtr<byte>(reinterpret_cast<byte*>(data),
unguard(dataSize / BITS_PER_BYTE / BYTES));
} }
inline _::ListBuilder StructBuilder::getPointerSectionAsList() { inline _::ListBuilder StructBuilder::getPointerSectionAsList() {
return _::ListBuilder(segment, capTable, pointers, 1 * POINTERS * BITS_PER_POINTER / ELEMENTS, return _::ListBuilder(segment, capTable, pointers, ONE * POINTERS * BITS_PER_POINTER / ELEMENTS,
pointerCount * (1 * ELEMENTS / POINTERS), pointerCount * (ONE * ELEMENTS / POINTERS),
0 * BITS, 1 * POINTERS, ElementSize::POINTER); ZERO * BITS, ONE * POINTERS, ElementSize::POINTER);
} }
template <typename T> template <typename T>
inline bool StructBuilder::hasDataField(ElementCount offset) { inline bool StructBuilder::hasDataField(StructDataElementOffset offset) {
return getDataField<Mask<T>>(offset) != 0; return getDataField<Mask<T>>(offset) != 0;
} }
template <> template <>
inline bool StructBuilder::hasDataField<Void>(ElementCount offset) { inline bool StructBuilder::hasDataField<Void>(StructDataElementOffset offset) {
return false; return false;
} }
template <typename T> template <typename T>
inline T StructBuilder::getDataField(ElementCount offset) { inline T StructBuilder::getDataField(StructDataElementOffset offset) {
return reinterpret_cast<WireValue<T>*>(data)[offset / ELEMENTS].get(); return reinterpret_cast<WireValue<T>*>(data)[offset / ELEMENTS].get();
} }
template <> template <>
inline bool StructBuilder::getDataField<bool>(ElementCount offset) { inline bool StructBuilder::getDataField<bool>(StructDataElementOffset offset) {
BitCount boffset = offset * (1 * BITS / ELEMENTS); BitCountN<22> boffset = offset * (ONE * BITS / ELEMENTS);
byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE; byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0; return (*reinterpret_cast<uint8_t*>(b) &
unguard(ONE << (boffset % BITS_PER_BYTE / BITS))) != 0;
} }
template <> template <>
inline Void StructBuilder::getDataField<Void>(ElementCount offset) { inline Void StructBuilder::getDataField<Void>(StructDataElementOffset offset) {
return VOID; return VOID;
} }
template <typename T> template <typename T>
inline T StructBuilder::getDataField(ElementCount offset, Mask<T> mask) { inline T StructBuilder::getDataField(StructDataElementOffset offset, Mask<T> mask) {
return unmask<T>(getDataField<Mask<T> >(offset), mask); return unmask<T>(getDataField<Mask<T> >(offset), mask);
} }
template <typename T> template <typename T>
inline void StructBuilder::setDataField(ElementCount offset, kj::NoInfer<T> value) { inline void StructBuilder::setDataField(StructDataElementOffset offset, kj::NoInfer<T> value) {
reinterpret_cast<WireValue<T>*>(data)[offset / ELEMENTS].set(value); reinterpret_cast<WireValue<T>*>(data)[offset / ELEMENTS].set(value);
} }
#if CAPNP_CANONICALIZE_NAN #if CAPNP_CANONICALIZE_NAN
// Use mask() on floats and doubles to make sure we canonicalize NaNs. // Use mask() on floats and doubles to make sure we canonicalize NaNs.
template <> template <>
inline void StructBuilder::setDataField<float>(ElementCount offset, float value) { inline void StructBuilder::setDataField<float>(StructDataElementOffset offset, float value) {
setDataField<uint32_t>(offset, mask<float>(value, 0)); setDataField<uint32_t>(offset, mask<float>(value, 0));
} }
template <> template <>
inline void StructBuilder::setDataField<double>(ElementCount offset, double value) { inline void StructBuilder::setDataField<double>(StructDataElementOffset offset, double value) {
setDataField<uint64_t>(offset, mask<double>(value, 0)); setDataField<uint64_t>(offset, mask<double>(value, 0));
} }
#endif #endif
template <> template <>
inline void StructBuilder::setDataField<bool>(ElementCount offset, bool value) { inline void StructBuilder::setDataField<bool>(StructDataElementOffset offset, bool value) {
BitCount boffset = offset * (1 * BITS / ELEMENTS); auto boffset = offset * (ONE * BITS / ELEMENTS);
byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE; byte* b = reinterpret_cast<byte*>(data) + boffset / BITS_PER_BYTE;
uint bitnum = boffset % BITS_PER_BYTE / BITS; uint bitnum = unguardMaxBits<3>(boffset % BITS_PER_BYTE / BITS);
*reinterpret_cast<uint8_t*>(b) = (*reinterpret_cast<uint8_t*>(b) & ~(1 << bitnum)) *reinterpret_cast<uint8_t*>(b) = (*reinterpret_cast<uint8_t*>(b) & ~(1 << bitnum))
| (static_cast<uint8_t>(value) << bitnum); | (static_cast<uint8_t>(value) << bitnum);
} }
template <> template <>
inline void StructBuilder::setDataField<Void>(ElementCount offset, Void value) {} inline void StructBuilder::setDataField<Void>(StructDataElementOffset offset, Void value) {}
template <typename T> template <typename T>
inline void StructBuilder::setDataField(ElementCount offset, kj::NoInfer<T> value, Mask<T> m) { inline void StructBuilder::setDataField(StructDataElementOffset offset,
kj::NoInfer<T> value, Mask<T> m) {
setDataField<Mask<T> >(offset, mask<T>(value, m)); setDataField<Mask<T> >(offset, mask<T>(value, m));
} }
inline PointerBuilder StructBuilder::getPointerField(WirePointerCount ptrIndex) { inline PointerBuilder StructBuilder::getPointerField(StructPointerCount ptrIndex) {
// Hacky because WirePointer is defined in the .c++ file (so is incomplete here). // Hacky because WirePointer is defined in the .c++ file (so is incomplete here).
return PointerBuilder(segment, capTable, reinterpret_cast<WirePointer*>( return PointerBuilder(segment, capTable, reinterpret_cast<WirePointer*>(
reinterpret_cast<word*>(pointers) + ptrIndex * WORDS_PER_POINTER)); reinterpret_cast<word*>(pointers) + ptrIndex * WORDS_PER_POINTER));
...@@ -1039,28 +1076,29 @@ inline PointerBuilder StructBuilder::getPointerField(WirePointerCount ptrIndex) ...@@ -1039,28 +1076,29 @@ inline PointerBuilder StructBuilder::getPointerField(WirePointerCount ptrIndex)
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline kj::ArrayPtr<const byte> StructReader::getDataSectionAsBlob() { inline kj::ArrayPtr<const byte> StructReader::getDataSectionAsBlob() {
return kj::ArrayPtr<const byte>(reinterpret_cast<const byte*>(data), dataSize / BITS_PER_BYTE / BYTES); return kj::ArrayPtr<const byte>(reinterpret_cast<const byte*>(data),
unguard(dataSize / BITS_PER_BYTE / BYTES));
} }
inline _::ListReader StructReader::getPointerSectionAsList() { inline _::ListReader StructReader::getPointerSectionAsList() {
return _::ListReader(segment, capTable, pointers, pointerCount * (1 * ELEMENTS / POINTERS), return _::ListReader(segment, capTable, pointers, pointerCount * (ONE * ELEMENTS / POINTERS),
1 * POINTERS * BITS_PER_POINTER / ELEMENTS, 0 * BITS, 1 * POINTERS, ONE * POINTERS * BITS_PER_POINTER / ELEMENTS, ZERO * BITS, ONE * POINTERS,
ElementSize::POINTER, nestingLimit); ElementSize::POINTER, nestingLimit);
} }
template <typename T> template <typename T>
inline bool StructReader::hasDataField(ElementCount offset) const { inline bool StructReader::hasDataField(StructDataElementOffset offset) const {
return getDataField<Mask<T>>(offset) != 0; return getDataField<Mask<T>>(offset) != 0;
} }
template <> template <>
inline bool StructReader::hasDataField<Void>(ElementCount offset) const { inline bool StructReader::hasDataField<Void>(StructDataElementOffset offset) const {
return false; return false;
} }
template <typename T> template <typename T>
inline T StructReader::getDataField(ElementCount offset) const { inline T StructReader::getDataField(StructDataElementOffset offset) const {
if ((offset + 1 * ELEMENTS) * capnp::bitsPerElement<T>() <= dataSize) { if ((offset + ONE * ELEMENTS) * capnp::bitsPerElement<T>() <= dataSize) {
return reinterpret_cast<const WireValue<T>*>(data)[offset / ELEMENTS].get(); return reinterpret_cast<const WireValue<T>*>(data)[offset / ELEMENTS].get();
} else { } else {
return static_cast<T>(0); return static_cast<T>(0);
...@@ -1068,27 +1106,28 @@ inline T StructReader::getDataField(ElementCount offset) const { ...@@ -1068,27 +1106,28 @@ inline T StructReader::getDataField(ElementCount offset) const {
} }
template <> template <>
inline bool StructReader::getDataField<bool>(ElementCount offset) const { inline bool StructReader::getDataField<bool>(StructDataElementOffset offset) const {
BitCount boffset = offset * (1 * BITS / ELEMENTS); auto boffset = offset * (ONE * BITS / ELEMENTS);
if (boffset < dataSize) { if (boffset < dataSize) {
const byte* b = reinterpret_cast<const byte*>(data) + boffset / BITS_PER_BYTE; const byte* b = reinterpret_cast<const byte*>(data) + boffset / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (boffset % BITS_PER_BYTE / BITS))) != 0; return (*reinterpret_cast<const uint8_t*>(b) &
unguard(ONE << (boffset % BITS_PER_BYTE / BITS))) != 0;
} else { } else {
return false; return false;
} }
} }
template <> template <>
inline Void StructReader::getDataField<Void>(ElementCount offset) const { inline Void StructReader::getDataField<Void>(StructDataElementOffset offset) const {
return VOID; return VOID;
} }
template <typename T> template <typename T>
T StructReader::getDataField(ElementCount offset, Mask<T> mask) const { T StructReader::getDataField(StructDataElementOffset offset, Mask<T> mask) const {
return unmask<T>(getDataField<Mask<T> >(offset), mask); return unmask<T>(getDataField<Mask<T> >(offset), mask);
} }
inline PointerReader StructReader::getPointerField(WirePointerCount ptrIndex) const { inline PointerReader StructReader::getPointerField(StructPointerCount ptrIndex) const {
if (ptrIndex < pointerCount) { if (ptrIndex < pointerCount) {
// Hacky because WirePointer is defined in the .c++ file (so is incomplete here). // Hacky because WirePointer is defined in the .c++ file (so is incomplete here).
return PointerReader(segment, capTable, reinterpret_cast<const WirePointer*>( return PointerReader(segment, capTable, reinterpret_cast<const WirePointer*>(
...@@ -1100,11 +1139,12 @@ inline PointerReader StructReader::getPointerField(WirePointerCount ptrIndex) co ...@@ -1100,11 +1139,12 @@ inline PointerReader StructReader::getPointerField(WirePointerCount ptrIndex) co
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline ElementCount ListBuilder::size() const { return elementCount; } inline ListElementCount ListBuilder::size() const { return elementCount; }
template <typename T> template <typename T>
inline T ListBuilder::getDataElement(ElementCount index) { inline T ListBuilder::getDataElement(ElementCount index) {
return reinterpret_cast<WireValue<T>*>(ptr + index * step / BITS_PER_BYTE)->get(); return reinterpret_cast<WireValue<T>*>(
ptr + upgradeGuard<uint64_t>(index) * step / BITS_PER_BYTE)->get();
// TODO(perf): Benchmark this alternate implementation, which I suspect may make better use of // TODO(perf): Benchmark this alternate implementation, which I suspect may make better use of
// the x86 SIB byte. Also use it for all the other getData/setData implementations below, and // the x86 SIB byte. Also use it for all the other getData/setData implementations below, and
...@@ -1117,9 +1157,10 @@ inline T ListBuilder::getDataElement(ElementCount index) { ...@@ -1117,9 +1157,10 @@ inline T ListBuilder::getDataElement(ElementCount index) {
template <> template <>
inline bool ListBuilder::getDataElement<bool>(ElementCount index) { inline bool ListBuilder::getDataElement<bool>(ElementCount index) {
// Ignore step for bit lists because bit lists cannot be upgraded to struct lists. // Ignore step for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * (1 * BITS / ELEMENTS); auto bindex = index * (ONE * BITS / ELEMENTS);
byte* b = ptr + bindex / BITS_PER_BYTE; byte* b = ptr + bindex / BITS_PER_BYTE;
return (*reinterpret_cast<uint8_t*>(b) & (1 << (bindex % BITS_PER_BYTE / BITS))) != 0; return (*reinterpret_cast<uint8_t*>(b) &
unguard(ONE << (bindex % BITS_PER_BYTE / BITS))) != 0;
} }
template <> template <>
...@@ -1129,7 +1170,8 @@ inline Void ListBuilder::getDataElement<Void>(ElementCount index) { ...@@ -1129,7 +1170,8 @@ inline Void ListBuilder::getDataElement<Void>(ElementCount index) {
template <typename T> template <typename T>
inline void ListBuilder::setDataElement(ElementCount index, kj::NoInfer<T> value) { inline void ListBuilder::setDataElement(ElementCount index, kj::NoInfer<T> value) {
reinterpret_cast<WireValue<T>*>(ptr + index * step / BITS_PER_BYTE)->set(value); reinterpret_cast<WireValue<T>*>(
ptr + upgradeGuard<uint64_t>(index) * step / BITS_PER_BYTE)->set(value);
} }
#if CAPNP_CANONICALIZE_NAN #if CAPNP_CANONICALIZE_NAN
...@@ -1147,36 +1189,38 @@ inline void ListBuilder::setDataElement<double>(ElementCount index, double value ...@@ -1147,36 +1189,38 @@ inline void ListBuilder::setDataElement<double>(ElementCount index, double value
template <> template <>
inline void ListBuilder::setDataElement<bool>(ElementCount index, bool value) { inline void ListBuilder::setDataElement<bool>(ElementCount index, bool value) {
// Ignore stepBytes for bit lists because bit lists cannot be upgraded to struct lists. // Ignore stepBytes for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * (1 * BITS / ELEMENTS); auto bindex = index * (ONE * BITS / ELEMENTS);
byte* b = ptr + bindex / BITS_PER_BYTE; byte* b = ptr + bindex / BITS_PER_BYTE;
uint bitnum = bindex % BITS_PER_BYTE / BITS; auto bitnum = bindex % BITS_PER_BYTE / BITS;
*reinterpret_cast<uint8_t*>(b) = (*reinterpret_cast<uint8_t*>(b) & ~(1 << bitnum)) *reinterpret_cast<uint8_t*>(b) = (*reinterpret_cast<uint8_t*>(b) & ~(1 << unguard(bitnum)))
| (static_cast<uint8_t>(value) << bitnum); | (static_cast<uint8_t>(value) << unguard(bitnum));
} }
template <> template <>
inline void ListBuilder::setDataElement<Void>(ElementCount index, Void value) {} inline void ListBuilder::setDataElement<Void>(ElementCount index, Void value) {}
inline PointerBuilder ListBuilder::getPointerElement(ElementCount index) { inline PointerBuilder ListBuilder::getPointerElement(ElementCount index) {
return PointerBuilder(segment, capTable, return PointerBuilder(segment, capTable, reinterpret_cast<WirePointer*>(ptr +
reinterpret_cast<WirePointer*>(ptr + index * step / BITS_PER_BYTE)); upgradeGuard<uint64_t>(index) * step / BITS_PER_BYTE));
} }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline ElementCount ListReader::size() const { return elementCount; } inline ListElementCount ListReader::size() const { return elementCount; }
template <typename T> template <typename T>
inline T ListReader::getDataElement(ElementCount index) const { inline T ListReader::getDataElement(ElementCount index) const {
return reinterpret_cast<const WireValue<T>*>(ptr + index * step / BITS_PER_BYTE)->get(); return reinterpret_cast<const WireValue<T>*>(
ptr + upgradeGuard<uint64_t>(index) * step / BITS_PER_BYTE)->get();
} }
template <> template <>
inline bool ListReader::getDataElement<bool>(ElementCount index) const { inline bool ListReader::getDataElement<bool>(ElementCount index) const {
// Ignore step for bit lists because bit lists cannot be upgraded to struct lists. // Ignore step for bit lists because bit lists cannot be upgraded to struct lists.
BitCount bindex = index * (1 * BITS / ELEMENTS); auto bindex = index * (ONE * BITS / ELEMENTS);
const byte* b = ptr + bindex / BITS_PER_BYTE; const byte* b = ptr + bindex / BITS_PER_BYTE;
return (*reinterpret_cast<const uint8_t*>(b) & (1 << (bindex % BITS_PER_BYTE / BITS))) != 0; return (*reinterpret_cast<const uint8_t*>(b) &
unguard(ONE << (bindex % BITS_PER_BYTE / BITS))) != 0;
} }
template <> template <>
...@@ -1185,8 +1229,8 @@ inline Void ListReader::getDataElement<Void>(ElementCount index) const { ...@@ -1185,8 +1229,8 @@ inline Void ListReader::getDataElement<Void>(ElementCount index) const {
} }
inline PointerReader ListReader::getPointerElement(ElementCount index) const { inline PointerReader ListReader::getPointerElement(ElementCount index) const {
return PointerReader(segment, capTable, return PointerReader(segment, capTable, reinterpret_cast<const WirePointer*>(
reinterpret_cast<const WirePointer*>(ptr + index * step / BITS_PER_BYTE), nestingLimit); ptr + upgradeGuard<uint64_t>(index) * step / BITS_PER_BYTE), nestingLimit);
} }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
......
...@@ -117,7 +117,7 @@ struct List<T, Kind::PRIMITIVE> { ...@@ -117,7 +117,7 @@ struct List<T, Kind::PRIMITIVE> {
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return reader.size() / ELEMENTS; }
inline T operator[](uint index) const { inline T operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return reader.template getDataElement<T>(index * ELEMENTS); return reader.template getDataElement<T>(guarded(index) * ELEMENTS);
} }
typedef _::IndexingIterator<const Reader, T> Iterator; typedef _::IndexingIterator<const Reader, T> Iterator;
...@@ -149,7 +149,7 @@ struct List<T, Kind::PRIMITIVE> { ...@@ -149,7 +149,7 @@ struct List<T, Kind::PRIMITIVE> {
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return builder.size() / ELEMENTS; }
inline T operator[](uint index) { inline T operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return builder.template getDataElement<T>(index * ELEMENTS); return builder.template getDataElement<T>(guarded(index) * ELEMENTS);
} }
inline void set(uint index, T value) { inline void set(uint index, T value) {
// Alas, it is not possible to make operator[] return a reference to which you can assign, // Alas, it is not possible to make operator[] return a reference to which you can assign,
...@@ -158,7 +158,7 @@ struct List<T, Kind::PRIMITIVE> { ...@@ -158,7 +158,7 @@ struct List<T, Kind::PRIMITIVE> {
// operator=() because it will lead to surprising behavior when using type inference (e.g. // operator=() because it will lead to surprising behavior when using type inference (e.g.
// calling a template function with inferred argument types, or using "auto" or "decltype"). // calling a template function with inferred argument types, or using "auto" or "decltype").
builder.template setDataElement<T>(index * ELEMENTS, value); builder.template setDataElement<T>(guarded(index) * ELEMENTS, value);
} }
typedef _::IndexingIterator<Builder, T> Iterator; typedef _::IndexingIterator<Builder, T> Iterator;
...@@ -178,7 +178,7 @@ struct List<T, Kind::PRIMITIVE> { ...@@ -178,7 +178,7 @@ struct List<T, Kind::PRIMITIVE> {
private: private:
inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) { inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) {
return builder.initList(_::elementSizeForType<T>(), size * ELEMENTS); return builder.initList(_::elementSizeForType<T>(), guarded(size) * ELEMENTS);
} }
inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) { inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) {
return builder.getList(_::elementSizeForType<T>(), defaultValue); return builder.getList(_::elementSizeForType<T>(), defaultValue);
...@@ -213,7 +213,7 @@ struct List<T, Kind::STRUCT> { ...@@ -213,7 +213,7 @@ struct List<T, Kind::STRUCT> {
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return reader.size() / ELEMENTS; }
inline typename T::Reader operator[](uint index) const { inline typename T::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename T::Reader(reader.getStructElement(index * ELEMENTS)); return typename T::Reader(reader.getStructElement(guarded(index) * ELEMENTS));
} }
typedef _::IndexingIterator<const Reader, typename T::Reader> Iterator; typedef _::IndexingIterator<const Reader, typename T::Reader> Iterator;
...@@ -245,7 +245,7 @@ struct List<T, Kind::STRUCT> { ...@@ -245,7 +245,7 @@ struct List<T, Kind::STRUCT> {
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return builder.size() / ELEMENTS; }
inline typename T::Builder operator[](uint index) { inline typename T::Builder operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename T::Builder(builder.getStructElement(index * ELEMENTS)); return typename T::Builder(builder.getStructElement(guarded(index) * ELEMENTS));
} }
inline void adoptWithCaveats(uint index, Orphan<T>&& orphan) { inline void adoptWithCaveats(uint index, Orphan<T>&& orphan) {
...@@ -263,8 +263,8 @@ struct List<T, Kind::STRUCT> { ...@@ -263,8 +263,8 @@ struct List<T, Kind::STRUCT> {
// We pass a zero-valued StructSize to asStruct() because we do not want the struct to be // We pass a zero-valued StructSize to asStruct() because we do not want the struct to be
// expanded under any circumstances. We're just going to throw it away anyway, and // expanded under any circumstances. We're just going to throw it away anyway, and
// transferContentFrom() already carefully compares the struct sizes before transferring. // transferContentFrom() already carefully compares the struct sizes before transferring.
builder.getStructElement(index * ELEMENTS).transferContentFrom( builder.getStructElement(guarded(index) * ELEMENTS).transferContentFrom(
orphan.builder.asStruct(_::StructSize(0 * WORDS, 0 * POINTERS))); orphan.builder.asStruct(_::StructSize(ZERO * WORDS, ZERO * POINTERS)));
} }
inline void setWithCaveats(uint index, const typename T::Reader& reader) { inline void setWithCaveats(uint index, const typename T::Reader& reader) {
// Mostly behaves like you'd expect `set` to behave, but with a caveat originating from // Mostly behaves like you'd expect `set` to behave, but with a caveat originating from
...@@ -278,7 +278,7 @@ struct List<T, Kind::STRUCT> { ...@@ -278,7 +278,7 @@ struct List<T, Kind::STRUCT> {
// protocol. (Plus, it's easier to use anyhow.) // protocol. (Plus, it's easier to use anyhow.)
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getStructElement(index * ELEMENTS).copyContentFrom(reader._reader); builder.getStructElement(guarded(index) * ELEMENTS).copyContentFrom(reader._reader);
} }
// There are no init(), set(), adopt(), or disown() methods for lists of structs because the // There are no init(), set(), adopt(), or disown() methods for lists of structs because the
...@@ -303,7 +303,7 @@ struct List<T, Kind::STRUCT> { ...@@ -303,7 +303,7 @@ struct List<T, Kind::STRUCT> {
private: private:
inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) { inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) {
return builder.initStructList(size * ELEMENTS, _::structSize<T>()); return builder.initStructList(guarded(size) * ELEMENTS, _::structSize<T>());
} }
inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) { inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) {
return builder.getStructList(_::structSize<T>(), defaultValue); return builder.getStructList(_::structSize<T>(), defaultValue);
...@@ -335,8 +335,8 @@ struct List<List<T>, Kind::LIST> { ...@@ -335,8 +335,8 @@ struct List<List<T>, Kind::LIST> {
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return reader.size() / ELEMENTS; }
inline typename List<T>::Reader operator[](uint index) const { inline typename List<T>::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename List<T>::Reader( return typename List<T>::Reader(_::PointerHelpers<List<T>>::get(
_::PointerHelpers<List<T>>::get(reader.getPointerElement(index * ELEMENTS))); reader.getPointerElement(guarded(index) * ELEMENTS)));
} }
typedef _::IndexingIterator<const Reader, typename List<T>::Reader> Iterator; typedef _::IndexingIterator<const Reader, typename List<T>::Reader> Iterator;
...@@ -368,17 +368,17 @@ struct List<List<T>, Kind::LIST> { ...@@ -368,17 +368,17 @@ struct List<List<T>, Kind::LIST> {
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return builder.size() / ELEMENTS; }
inline typename List<T>::Builder operator[](uint index) { inline typename List<T>::Builder operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return typename List<T>::Builder( return typename List<T>::Builder(_::PointerHelpers<List<T>>::get(
_::PointerHelpers<List<T>>::get(builder.getPointerElement(index * ELEMENTS))); builder.getPointerElement(guarded(index) * ELEMENTS)));
} }
inline typename List<T>::Builder init(uint index, uint size) { inline typename List<T>::Builder init(uint index, uint size) {
KJ_IREQUIRE(index < this->size()); KJ_IREQUIRE(index < this->size());
return typename List<T>::Builder( return typename List<T>::Builder(_::PointerHelpers<List<T>>::init(
_::PointerHelpers<List<T>>::init(builder.getPointerElement(index * ELEMENTS), size)); builder.getPointerElement(guarded(index) * ELEMENTS), size));
} }
inline void set(uint index, typename List<T>::Reader value) { inline void set(uint index, typename List<T>::Reader value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).setList(value.reader); builder.getPointerElement(guarded(index) * ELEMENTS).setList(value.reader);
} }
void set(uint index, std::initializer_list<ReaderFor<T>> value) { void set(uint index, std::initializer_list<ReaderFor<T>> value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
...@@ -390,11 +390,11 @@ struct List<List<T>, Kind::LIST> { ...@@ -390,11 +390,11 @@ struct List<List<T>, Kind::LIST> {
} }
inline void adopt(uint index, Orphan<T>&& value) { inline void adopt(uint index, Orphan<T>&& value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).adopt(kj::mv(value.builder)); builder.getPointerElement(guarded(index) * ELEMENTS).adopt(kj::mv(value.builder));
} }
inline Orphan<T> disown(uint index) { inline Orphan<T> disown(uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return Orphan<T>(builder.getPointerElement(index * ELEMENTS).disown()); return Orphan<T>(builder.getPointerElement(guarded(index) * ELEMENTS).disown());
} }
typedef _::IndexingIterator<Builder, typename List<T>::Builder> Iterator; typedef _::IndexingIterator<Builder, typename List<T>::Builder> Iterator;
...@@ -414,7 +414,7 @@ struct List<List<T>, Kind::LIST> { ...@@ -414,7 +414,7 @@ struct List<List<T>, Kind::LIST> {
private: private:
inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) { inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) {
return builder.initList(ElementSize::POINTER, size * ELEMENTS); return builder.initList(ElementSize::POINTER, guarded(size) * ELEMENTS);
} }
inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) { inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) {
return builder.getList(ElementSize::POINTER, defaultValue); return builder.getList(ElementSize::POINTER, defaultValue);
...@@ -444,7 +444,8 @@ struct List<T, Kind::BLOB> { ...@@ -444,7 +444,8 @@ struct List<T, Kind::BLOB> {
inline uint size() const { return reader.size() / ELEMENTS; } inline uint size() const { return reader.size() / ELEMENTS; }
inline typename T::Reader operator[](uint index) const { inline typename T::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return reader.getPointerElement(index * ELEMENTS).template getBlob<T>(nullptr, 0 * BYTES); return reader.getPointerElement(guarded(index) * ELEMENTS)
.template getBlob<T>(nullptr, ZERO * BYTES);
} }
typedef _::IndexingIterator<const Reader, typename T::Reader> Iterator; typedef _::IndexingIterator<const Reader, typename T::Reader> Iterator;
...@@ -476,23 +477,25 @@ struct List<T, Kind::BLOB> { ...@@ -476,23 +477,25 @@ struct List<T, Kind::BLOB> {
inline uint size() const { return builder.size() / ELEMENTS; } inline uint size() const { return builder.size() / ELEMENTS; }
inline typename T::Builder operator[](uint index) { inline typename T::Builder operator[](uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return builder.getPointerElement(index * ELEMENTS).template getBlob<T>(nullptr, 0 * BYTES); return builder.getPointerElement(guarded(index) * ELEMENTS)
.template getBlob<T>(nullptr, ZERO * BYTES);
} }
inline void set(uint index, typename T::Reader value) { inline void set(uint index, typename T::Reader value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).template setBlob<T>(value); builder.getPointerElement(guarded(index) * ELEMENTS).template setBlob<T>(value);
} }
inline typename T::Builder init(uint index, uint size) { inline typename T::Builder init(uint index, uint size) {
KJ_IREQUIRE(index < this->size()); KJ_IREQUIRE(index < this->size());
return builder.getPointerElement(index * ELEMENTS).template initBlob<T>(size * BYTES); return builder.getPointerElement(guarded(index) * ELEMENTS)
.template initBlob<T>(guarded(size) * BYTES);
} }
inline void adopt(uint index, Orphan<T>&& value) { inline void adopt(uint index, Orphan<T>&& value) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
builder.getPointerElement(index * ELEMENTS).adopt(kj::mv(value.builder)); builder.getPointerElement(guarded(index) * ELEMENTS).adopt(kj::mv(value.builder));
} }
inline Orphan<T> disown(uint index) { inline Orphan<T> disown(uint index) {
KJ_IREQUIRE(index < size()); KJ_IREQUIRE(index < size());
return Orphan<T>(builder.getPointerElement(index * ELEMENTS).disown()); return Orphan<T>(builder.getPointerElement(guarded(index) * ELEMENTS).disown());
} }
typedef _::IndexingIterator<Builder, typename T::Builder> Iterator; typedef _::IndexingIterator<Builder, typename T::Builder> Iterator;
...@@ -512,7 +515,7 @@ struct List<T, Kind::BLOB> { ...@@ -512,7 +515,7 @@ struct List<T, Kind::BLOB> {
private: private:
inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) { inline static _::ListBuilder initPointer(_::PointerBuilder builder, uint size) {
return builder.initList(ElementSize::POINTER, size * ELEMENTS); return builder.initList(ElementSize::POINTER, guarded(size) * ELEMENTS);
} }
inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) { inline static _::ListBuilder getFromPointer(_::PointerBuilder builder, const word* defaultValue) {
return builder.getList(ElementSize::POINTER, defaultValue); return builder.getList(ElementSize::POINTER, defaultValue);
......
...@@ -135,7 +135,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() { ...@@ -135,7 +135,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() {
KJ_ASSERT(allocation.segment->getSegmentId() == _::SegmentId(0), KJ_ASSERT(allocation.segment->getSegmentId() == _::SegmentId(0),
"First allocated word of new arena was not in segment ID 0."); "First allocated word of new arena was not in segment ID 0.");
KJ_ASSERT(allocation.words == allocation.segment->getPtrUnchecked(0 * WORDS), KJ_ASSERT(allocation.words == allocation.segment->getPtrUnchecked(ZERO * WORDS),
"First allocated word of new arena was not the first word in its segment."); "First allocated word of new arena was not the first word in its segment.");
return allocation.segment; return allocation.segment;
} }
...@@ -144,7 +144,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() { ...@@ -144,7 +144,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() {
AnyPointer::Builder MessageBuilder::getRootInternal() { AnyPointer::Builder MessageBuilder::getRootInternal() {
_::SegmentBuilder* rootSegment = getRootSegment(); _::SegmentBuilder* rootSegment = getRootSegment();
return AnyPointer::Builder(_::PointerBuilder::getRoot( return AnyPointer::Builder(_::PointerBuilder::getRoot(
rootSegment, arena()->getLocalCapTable(), rootSegment->getPtrUnchecked(0 * WORDS))); rootSegment, arena()->getLocalCapTable(), rootSegment->getPtrUnchecked(ZERO * WORDS)));
} }
kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() { kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() {
......
...@@ -307,17 +307,17 @@ inline ReaderFor<T> Orphan<T>::getReader() const { ...@@ -307,17 +307,17 @@ inline ReaderFor<T> Orphan<T>::getReader() const {
template <typename T> template <typename T>
inline void Orphan<T>::truncate(uint size) { inline void Orphan<T>::truncate(uint size) {
_::OrphanGetImpl<ListElementType<T>>::truncateListOf(builder, size * ELEMENTS); _::OrphanGetImpl<ListElementType<T>>::truncateListOf(builder, guarded(size) * ELEMENTS);
} }
template <> template <>
inline void Orphan<Text>::truncate(uint size) { inline void Orphan<Text>::truncate(uint size) {
builder.truncateText(size * ELEMENTS); builder.truncateText(guarded(size) * ELEMENTS);
} }
template <> template <>
inline void Orphan<Data>::truncate(uint size) { inline void Orphan<Data>::truncate(uint size) {
builder.truncate(size * ELEMENTS, ElementSize::BYTE); builder.truncate(guarded(size) * ELEMENTS, ElementSize::BYTE);
} }
template <typename T> template <typename T>
...@@ -350,7 +350,7 @@ struct Orphanage::NewOrphanListImpl<List<T, k>> { ...@@ -350,7 +350,7 @@ struct Orphanage::NewOrphanListImpl<List<T, k>> {
static inline _::OrphanBuilder apply( static inline _::OrphanBuilder apply(
_::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) { _::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) {
return _::OrphanBuilder::initList( return _::OrphanBuilder::initList(
arena, capTable, size * ELEMENTS, _::ElementSizeForType<T>::value); arena, capTable, guarded(size) * ELEMENTS, _::ElementSizeForType<T>::value);
} }
}; };
...@@ -359,7 +359,7 @@ struct Orphanage::NewOrphanListImpl<List<T, Kind::STRUCT>> { ...@@ -359,7 +359,7 @@ struct Orphanage::NewOrphanListImpl<List<T, Kind::STRUCT>> {
static inline _::OrphanBuilder apply( static inline _::OrphanBuilder apply(
_::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) { _::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) {
return _::OrphanBuilder::initStructList( return _::OrphanBuilder::initStructList(
arena, capTable, size * ELEMENTS, _::structSize<T>()); arena, capTable, guarded(size) * ELEMENTS, _::structSize<T>());
} }
}; };
...@@ -367,7 +367,7 @@ template <> ...@@ -367,7 +367,7 @@ template <>
struct Orphanage::NewOrphanListImpl<Text> { struct Orphanage::NewOrphanListImpl<Text> {
static inline _::OrphanBuilder apply( static inline _::OrphanBuilder apply(
_::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) { _::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) {
return _::OrphanBuilder::initText(arena, capTable, size * BYTES); return _::OrphanBuilder::initText(arena, capTable, guarded(size) * BYTES);
} }
}; };
...@@ -375,7 +375,7 @@ template <> ...@@ -375,7 +375,7 @@ template <>
struct Orphanage::NewOrphanListImpl<Data> { struct Orphanage::NewOrphanListImpl<Data> {
static inline _::OrphanBuilder apply( static inline _::OrphanBuilder apply(
_::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) { _::BuilderArena* arena, _::CapTableBuilder* capTable, uint size) {
return _::OrphanBuilder::initData(arena, capTable, size * BYTES); return _::OrphanBuilder::initData(arena, capTable, guarded(size) * BYTES);
} }
}; };
......
...@@ -113,12 +113,12 @@ struct PointerHelpers<T, Kind::BLOB> { ...@@ -113,12 +113,12 @@ struct PointerHelpers<T, Kind::BLOB> {
static inline typename T::Reader get(PointerReader reader, static inline typename T::Reader get(PointerReader reader,
const void* defaultValue = nullptr, const void* defaultValue = nullptr,
uint defaultBytes = 0) { uint defaultBytes = 0) {
return reader.getBlob<T>(defaultValue, defaultBytes * BYTES); return reader.getBlob<T>(defaultValue, guarded(defaultBytes) * BYTES);
} }
static inline typename T::Builder get(PointerBuilder builder, static inline typename T::Builder get(PointerBuilder builder,
const void* defaultValue = nullptr, const void* defaultValue = nullptr,
uint defaultBytes = 0) { uint defaultBytes = 0) {
return builder.getBlob<T>(defaultValue, defaultBytes * BYTES); return builder.getBlob<T>(defaultValue, guarded(defaultBytes) * BYTES);
} }
static inline void set(PointerBuilder builder, typename T::Reader value) { static inline void set(PointerBuilder builder, typename T::Reader value) {
builder.setBlob<T>(value); builder.setBlob<T>(value);
...@@ -127,7 +127,7 @@ struct PointerHelpers<T, Kind::BLOB> { ...@@ -127,7 +127,7 @@ struct PointerHelpers<T, Kind::BLOB> {
builder.setBlob<T>(value); builder.setBlob<T>(value);
} }
static inline typename T::Builder init(PointerBuilder builder, uint size) { static inline typename T::Builder init(PointerBuilder builder, uint size) {
return builder.initBlob<T>(size * BYTES); return builder.initBlob<T>(guarded(size) * BYTES);
} }
static inline void adopt(PointerBuilder builder, Orphan<T>&& value) { static inline void adopt(PointerBuilder builder, Orphan<T>&& value) {
builder.adopt(kj::mv(value.builder)); builder.adopt(kj::mv(value.builder));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment