Commit d7081cca authored by Kenton Varda's avatar Kenton Varda

totalSizeInWords() method on generated struct readers/builders returns total…

totalSizeInWords() method on generated struct readers/builders returns total size of the struct, not counting far pointer overhead.  Useful for allocating space for a flat copy.  Also a couple misc tweaks.
parent ecdd2ae4
...@@ -34,6 +34,16 @@ namespace internal { ...@@ -34,6 +34,16 @@ namespace internal {
Arena::~Arena() {} Arena::~Arena() {}
void ReadLimiter::unread(WordCount64 amount) {
// Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that
// the limit value was not updated correctly for one or more reads, and therefore unread() could
// overflow it even if it is only unreading bytes that were acutally read.
WordCount64 newValue = limit + amount;
if (newValue > limit) {
limit = newValue;
}
}
// ======================================================================================= // =======================================================================================
ReaderArena::ReaderArena(MessageReader* message) ReaderArena::ReaderArena(MessageReader* message)
......
...@@ -71,6 +71,10 @@ public: ...@@ -71,6 +71,10 @@ public:
CAPNPROTO_ALWAYS_INLINE(bool canRead(WordCount amount, Arena* arena)); CAPNPROTO_ALWAYS_INLINE(bool canRead(WordCount amount, Arena* arena));
void unread(WordCount64 amount);
// Adds back some words to the limit. Useful when the caller knows they are double-reading
// some data.
private: private:
WordCount64 limit; WordCount64 limit;
...@@ -93,6 +97,9 @@ public: ...@@ -93,6 +97,9 @@ public:
inline ArrayPtr<const word> getArray(); inline ArrayPtr<const word> getArray();
inline void unread(WordCount64 amount);
// Add back some words to the ReadLimiter.
private: private:
Arena* arena; Arena* arena;
SegmentId id; SegmentId id;
...@@ -242,6 +249,7 @@ inline WordCount SegmentReader::getOffsetTo(const word* ptr) { ...@@ -242,6 +249,7 @@ inline WordCount SegmentReader::getOffsetTo(const word* ptr) {
} }
inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; } inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline ArrayPtr<const word> SegmentReader::getArray() { return ptr; } inline ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
......
...@@ -70,6 +70,9 @@ TEST(Encoding, AllTypes) { ...@@ -70,6 +70,9 @@ TEST(Encoding, AllTypes) {
ASSERT_EQ(1u, builder.getSegmentsForOutput().size()); ASSERT_EQ(1u, builder.getSegmentsForOutput().size());
checkTestMessage(readMessageTrusted<TestAllTypes>(builder.getSegmentsForOutput()[0].begin())); checkTestMessage(readMessageTrusted<TestAllTypes>(builder.getSegmentsForOutput()[0].begin()));
EXPECT_EQ(builder.getSegmentsForOutput()[0].size() - 1, // -1 for root pointer
reader.getRoot<TestAllTypes>().totalSizeInWords());
} }
TEST(Encoding, AllTypesMultiSegment) { TEST(Encoding, AllTypesMultiSegment) {
......
...@@ -141,7 +141,7 @@ void ExceptionCallback::useProcessWide() { ...@@ -141,7 +141,7 @@ void ExceptionCallback::useProcessWide() {
globalCallback = this; globalCallback = this;
} }
ExceptionCallback::ScopedRegistration::ScopedRegistration(ExceptionCallback* callback) ExceptionCallback::ScopedRegistration::ScopedRegistration(ExceptionCallback& callback)
: callback(callback) { : callback(callback) {
old = threadLocalCallback; old = threadLocalCallback;
threadLocalCallback = this; threadLocalCallback = this;
...@@ -151,11 +151,11 @@ ExceptionCallback::ScopedRegistration::~ScopedRegistration() { ...@@ -151,11 +151,11 @@ ExceptionCallback::ScopedRegistration::~ScopedRegistration() {
threadLocalCallback = old; threadLocalCallback = old;
} }
ExceptionCallback* getExceptionCallback() { ExceptionCallback& getExceptionCallback() {
static ExceptionCallback defaultCallback; static ExceptionCallback defaultCallback;
ExceptionCallback::ScopedRegistration* scoped = threadLocalCallback; ExceptionCallback::ScopedRegistration* scoped = threadLocalCallback;
return scoped != nullptr ? scoped->getCallback() : return scoped != nullptr ? scoped->getCallback() :
globalCallback != nullptr ? globalCallback : &defaultCallback; globalCallback != nullptr ? *globalCallback : defaultCallback;
} }
} // namespace capnproto } // namespace capnproto
...@@ -131,18 +131,18 @@ public: ...@@ -131,18 +131,18 @@ public:
// callback will be restored. // callback will be restored.
public: public:
ScopedRegistration(ExceptionCallback* callback); ScopedRegistration(ExceptionCallback& callback);
~ScopedRegistration(); ~ScopedRegistration();
inline ExceptionCallback* getCallback() { return callback; } inline ExceptionCallback& getCallback() { return callback; }
private: private:
ExceptionCallback* callback; ExceptionCallback& callback;
ScopedRegistration* old; ScopedRegistration* old;
}; };
}; };
ExceptionCallback* getExceptionCallback(); ExceptionCallback& getExceptionCallback();
// Returns the current exception callback. // Returns the current exception callback.
} // namespace capnproto } // namespace capnproto
......
...@@ -418,6 +418,148 @@ struct WireHelpers { ...@@ -418,6 +418,148 @@ struct WireHelpers {
memset(ref, 0, sizeof(*ref)); memset(ref, 0, sizeof(*ref));
} }
// -----------------------------------------------------------------
static WordCount64 totalSize(SegmentReader* segment, const WirePointer* ref, uint nestingLimit) {
// Compute the total size of the object pointed to, not counting far pointer overhead.
if (ref->isNull()) {
return 0 * WORDS;
}
VALIDATE_INPUT(nestingLimit > 0, "Message is too deeply-nested.") {
return 0 * WORDS;
}
--nestingLimit;
const word* ptr;
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
WordCount64 result = 0 * WORDS;
switch (ref->kind()) {
case WirePointer::STRUCT: {
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
break;
}
}
result += ref->structRef.wordSize();
const WirePointer* pointerSection =
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get());
uint count = ref->structRef.ptrCount.get() / POINTERS;
for (uint i = 0; i < count; i++) {
result += totalSize(segment, pointerSection + i, nestingLimit);
}
break;
}
case WirePointer::LIST: {
switch (ref->listRef.elementSize()) {
case FieldSize::VOID:
// Nothing.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: {
WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += totalWords;
break;
}
case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += count * WORDS_PER_POINTER;
for (uint i = 0; i < count / POINTERS; i++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i,
nestingLimit);
}
break;
}
case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) {
VALIDATE_INPUT(
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) {
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") {
break;
}
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") {
break;
}
}
WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
const word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count / ELEMENTS; i++) {
pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos),
nestingLimit);
pos += POINTER_SIZE_IN_WORDS;
}
}
break;
}
}
break;
}
case WirePointer::FAR:
FAIL_RECOVERABLE_CHECK("Unexpected FAR pointer.") {
break;
}
break;
case WirePointer::RESERVED_3:
FAIL_VALIDATE_INPUT("Don't know how to handle RESERVED_3.") {
break;
}
break;
}
return result;
}
// ----------------------------------------------------------------- // -----------------------------------------------------------------
static CAPNPROTO_ALWAYS_INLINE( static CAPNPROTO_ALWAYS_INLINE(
...@@ -1906,6 +2048,22 @@ bool StructReader::isPointerFieldNull(WirePointerCount ptrIndex) const { ...@@ -1906,6 +2048,22 @@ bool StructReader::isPointerFieldNull(WirePointerCount ptrIndex) const {
return (pointers + ptrIndex)->isNull(); return (pointers + ptrIndex)->isNull();
} }
WordCount64 StructReader::totalSize() const {
WordCount64 result = WireHelpers::roundUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER;
for (uint i = 0; i < pointerCount / POINTERS; i++) {
result += WireHelpers::totalSize(segment, pointers + i, nestingLimit);
}
if (segment != nullptr) {
// This traversal should not count against the read limit, because it's highly likely that
// the caller is going to traverse the object again, e.g. to copy it.
segment->unread(result);
}
return result;
}
// ======================================================================================= // =======================================================================================
// ListBuilder // ListBuilder
......
...@@ -463,6 +463,13 @@ public: ...@@ -463,6 +463,13 @@ public:
bool isPointerFieldNull(WirePointerCount ptrIndex) const; bool isPointerFieldNull(WirePointerCount ptrIndex) const;
WordCount64 totalSize() const;
// Return the total size of the struct and everything to which it points. Does not count far
// pointer overhead. This is useful for deciding how much space is needed to copy the struct
// into a flat array. However, the caller is advised NOT to treat this value as secure. Instead,
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
private: private:
SegmentReader* segment; // Memory segment in which the struct resides. SegmentReader* segment; // Memory segment in which the struct resides.
......
...@@ -86,7 +86,7 @@ std::string fileLine(std::string file, int line) { ...@@ -86,7 +86,7 @@ std::string fileLine(std::string file, int line) {
TEST(Logging, Log) { TEST(Logging, Log) {
MockExceptionCallback mockCallback; MockExceptionCallback mockCallback;
MockExceptionCallback::ScopedRegistration reg(&mockCallback); MockExceptionCallback::ScopedRegistration reg(mockCallback);
int line; int line;
LOG(WARNING, "Hello world!"); line = __LINE__; LOG(WARNING, "Hello world!"); line = __LINE__;
...@@ -137,7 +137,7 @@ TEST(Logging, Log) { ...@@ -137,7 +137,7 @@ TEST(Logging, Log) {
TEST(Logging, Syscall) { TEST(Logging, Syscall) {
MockExceptionCallback mockCallback; MockExceptionCallback mockCallback;
MockExceptionCallback::ScopedRegistration reg(&mockCallback); MockExceptionCallback::ScopedRegistration reg(mockCallback);
int line; int line;
int i = 123; int i = 123;
......
...@@ -93,7 +93,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int ...@@ -93,7 +93,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int
++index; ++index;
if (index != argValues.size()) { if (index != argValues.size()) {
getExceptionCallback()->logMessage( getExceptionCallback().logMessage(
str(__FILE__, ":", __LINE__, ": Failed to parse logging macro args into ", str(__FILE__, ":", __LINE__, ": Failed to parse logging macro args into ",
argValues.size(), " names: ", macroArgs, '\n')); argValues.size(), " names: ", macroArgs, '\n'));
} }
...@@ -191,7 +191,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int ...@@ -191,7 +191,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int
void Log::logInternal(const char* file, int line, Severity severity, const char* macroArgs, void Log::logInternal(const char* file, int line, Severity severity, const char* macroArgs,
ArrayPtr<Array<char>> argValues) { ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->logMessage( getExceptionCallback().logMessage(
str(severity, ": ", file, ":", line, ": ", str(severity, ": ", file, ":", line, ": ",
makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n')); makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n'));
} }
...@@ -199,7 +199,7 @@ void Log::logInternal(const char* file, int line, Severity severity, const char* ...@@ -199,7 +199,7 @@ void Log::logInternal(const char* file, int line, Severity severity, const char*
void Log::recoverableFaultInternal( void Log::recoverableFaultInternal(
const char* file, int line, Exception::Nature nature, const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) { const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onRecoverableException( getExceptionCallback().onRecoverableException(
Exception(nature, Exception::Durability::PERMANENT, file, line, Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues))); makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
} }
...@@ -207,7 +207,7 @@ void Log::recoverableFaultInternal( ...@@ -207,7 +207,7 @@ void Log::recoverableFaultInternal(
void Log::fatalFaultInternal( void Log::fatalFaultInternal(
const char* file, int line, Exception::Nature nature, const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) { const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onFatalException( getExceptionCallback().onFatalException(
Exception(nature, Exception::Durability::PERMANENT, file, line, Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues))); makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
abort(); abort();
...@@ -216,7 +216,7 @@ void Log::fatalFaultInternal( ...@@ -216,7 +216,7 @@ void Log::fatalFaultInternal(
void Log::recoverableFailedSyscallInternal( void Log::recoverableFailedSyscallInternal(
const char* file, int line, const char* call, const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) { int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onRecoverableException( getExceptionCallback().onRecoverableException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line, Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues))); makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
} }
...@@ -224,7 +224,7 @@ void Log::recoverableFailedSyscallInternal( ...@@ -224,7 +224,7 @@ void Log::recoverableFailedSyscallInternal(
void Log::fatalFailedSyscallInternal( void Log::fatalFailedSyscallInternal(
const char* file, int line, const char* call, const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) { int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onFatalException( getExceptionCallback().onFatalException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line, Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues))); makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
abort(); abort();
......
...@@ -170,6 +170,8 @@ private: ...@@ -170,6 +170,8 @@ private:
internal::StructBuilder initRoot(internal::StructSize size); internal::StructBuilder initRoot(internal::StructSize size);
void setRootInternal(internal::StructReader reader); void setRootInternal(internal::StructReader reader);
internal::StructBuilder getRoot(internal::StructSize size); internal::StructBuilder getRoot(internal::StructSize size);
friend struct SchemaLoader; // for a dirty hack, see schema-loader.c++.
}; };
template <typename RootType> template <typename RootType>
......
...@@ -76,6 +76,7 @@ struct {{typeFullName}} { ...@@ -76,6 +76,7 @@ struct {{typeFullName}} {
}; };
{{/typeUnion}} {{/typeUnion}}
private:
{{#typeFields}} {{#typeFields}}
{{#fieldDefaultBytes}} {{#fieldDefaultBytes}}
static const ::capnproto::internal::AlignedData<{{defaultWordCount}}> DEFAULT_{{fieldUpperCase}}; static const ::capnproto::internal::AlignedData<{{defaultWordCount}}> DEFAULT_{{fieldUpperCase}};
...@@ -144,9 +145,12 @@ public: ...@@ -144,9 +145,12 @@ public:
inline explicit Reader(::capnproto::internal::StructReader base): _reader(base) {} inline explicit Reader(::capnproto::internal::StructReader base): _reader(base) {}
{{#typeStruct}} {{#typeStruct}}
::capnproto::String debugString() { inline ::capnproto::String debugString() {
return ::capnproto::internal::debugString<{{typeName}}>(_reader); return ::capnproto::internal::debugString<{{typeName}}>(_reader);
} }
inline size_t totalSizeInWords() {
return _reader.totalSize() / ::capnproto::WORDS;
}
{{#structUnions}} {{#structUnions}}
// {{unionDecl}} // {{unionDecl}}
...@@ -194,7 +198,8 @@ public: ...@@ -194,7 +198,8 @@ public:
inline Reader asReader() { return *this; } inline Reader asReader() { return *this; }
{{#typeStruct}} {{#typeStruct}}
::capnproto::String debugString() { return asReader().debugString(); } inline ::capnproto::String debugString() { return asReader().debugString(); }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); }
{{#structUnions}} {{#structUnions}}
// {{unionDecl}} // {{unionDecl}}
......
...@@ -62,7 +62,6 @@ static const ::capnproto::internal::RawSchema* const d_{{schemaId}}[] = { ...@@ -62,7 +62,6 @@ static const ::capnproto::internal::RawSchema* const d_{{schemaId}}[] = {
{{#schemaDependencies}} {{#schemaDependencies}}
&s_{{dependencyId}}, &s_{{dependencyId}},
{{/schemaDependencies}} {{/schemaDependencies}}
nullptr
}; };
static const ::capnproto::internal::RawSchema::MemberInfo m_{{schemaId}}[] = { static const ::capnproto::internal::RawSchema::MemberInfo m_{{schemaId}}[] = {
{{#schemaMembersByName}} {{#schemaMembersByName}}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment