Commit d7081cca authored by Kenton Varda's avatar Kenton Varda

totalSizeInWords() method on generated struct readers/builders returns total…

totalSizeInWords() method on generated struct readers/builders returns total size of the struct, not counting far pointer overhead.  Useful for allocating space for a flat copy.  Also a couple misc tweaks.
parent ecdd2ae4
......@@ -34,6 +34,16 @@ namespace internal {
Arena::~Arena() {}
void ReadLimiter::unread(WordCount64 amount) {
// Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that
// the limit value was not updated correctly for one or more reads, and therefore unread() could
// overflow it even if it is only unreading bytes that were acutally read.
WordCount64 newValue = limit + amount;
if (newValue > limit) {
limit = newValue;
}
}
// =======================================================================================
ReaderArena::ReaderArena(MessageReader* message)
......
......@@ -71,6 +71,10 @@ public:
CAPNPROTO_ALWAYS_INLINE(bool canRead(WordCount amount, Arena* arena));
void unread(WordCount64 amount);
// Adds back some words to the limit. Useful when the caller knows they are double-reading
// some data.
private:
WordCount64 limit;
......@@ -93,6 +97,9 @@ public:
inline ArrayPtr<const word> getArray();
inline void unread(WordCount64 amount);
// Add back some words to the ReadLimiter.
private:
Arena* arena;
SegmentId id;
......@@ -242,6 +249,7 @@ inline WordCount SegmentReader::getOffsetTo(const word* ptr) {
}
inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
// -------------------------------------------------------------------
......
......@@ -70,6 +70,9 @@ TEST(Encoding, AllTypes) {
ASSERT_EQ(1u, builder.getSegmentsForOutput().size());
checkTestMessage(readMessageTrusted<TestAllTypes>(builder.getSegmentsForOutput()[0].begin()));
EXPECT_EQ(builder.getSegmentsForOutput()[0].size() - 1, // -1 for root pointer
reader.getRoot<TestAllTypes>().totalSizeInWords());
}
TEST(Encoding, AllTypesMultiSegment) {
......
......@@ -141,7 +141,7 @@ void ExceptionCallback::useProcessWide() {
globalCallback = this;
}
ExceptionCallback::ScopedRegistration::ScopedRegistration(ExceptionCallback* callback)
ExceptionCallback::ScopedRegistration::ScopedRegistration(ExceptionCallback& callback)
: callback(callback) {
old = threadLocalCallback;
threadLocalCallback = this;
......@@ -151,11 +151,11 @@ ExceptionCallback::ScopedRegistration::~ScopedRegistration() {
threadLocalCallback = old;
}
ExceptionCallback* getExceptionCallback() {
ExceptionCallback& getExceptionCallback() {
static ExceptionCallback defaultCallback;
ExceptionCallback::ScopedRegistration* scoped = threadLocalCallback;
return scoped != nullptr ? scoped->getCallback() :
globalCallback != nullptr ? globalCallback : &defaultCallback;
globalCallback != nullptr ? *globalCallback : defaultCallback;
}
} // namespace capnproto
......@@ -131,18 +131,18 @@ public:
// callback will be restored.
public:
ScopedRegistration(ExceptionCallback* callback);
ScopedRegistration(ExceptionCallback& callback);
~ScopedRegistration();
inline ExceptionCallback* getCallback() { return callback; }
inline ExceptionCallback& getCallback() { return callback; }
private:
ExceptionCallback* callback;
ExceptionCallback& callback;
ScopedRegistration* old;
};
};
ExceptionCallback* getExceptionCallback();
ExceptionCallback& getExceptionCallback();
// Returns the current exception callback.
} // namespace capnproto
......
......@@ -418,6 +418,148 @@ struct WireHelpers {
memset(ref, 0, sizeof(*ref));
}
// -----------------------------------------------------------------
static WordCount64 totalSize(SegmentReader* segment, const WirePointer* ref, uint nestingLimit) {
// Compute the total size of the object pointed to, not counting far pointer overhead.
if (ref->isNull()) {
return 0 * WORDS;
}
VALIDATE_INPUT(nestingLimit > 0, "Message is too deeply-nested.") {
return 0 * WORDS;
}
--nestingLimit;
const word* ptr;
if (segment == nullptr) {
ptr = ref->target();
} else {
ptr = followFars(ref, segment);
}
WordCount64 result = 0 * WORDS;
switch (ref->kind()) {
case WirePointer::STRUCT: {
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
break;
}
}
result += ref->structRef.wordSize();
const WirePointer* pointerSection =
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get());
uint count = ref->structRef.ptrCount.get() / POINTERS;
for (uint i = 0; i < count; i++) {
result += totalSize(segment, pointerSection + i, nestingLimit);
}
break;
}
case WirePointer::LIST: {
switch (ref->listRef.elementSize()) {
case FieldSize::VOID:
// Nothing.
break;
case FieldSize::BIT:
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: {
WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize()));
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += totalWords;
break;
}
case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
if (segment != nullptr) {
VALIDATE_INPUT(segment->containsInterval(ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += count * WORDS_PER_POINTER;
for (uint i = 0; i < count / POINTERS; i++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i,
nestingLimit);
}
break;
}
case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount();
if (segment != nullptr) {
VALIDATE_INPUT(
segment->containsInterval(ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
"Message contained out-of-bounds list pointer.") {
break;
}
}
result += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount();
if (segment != nullptr) {
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") {
break;
}
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") {
break;
}
}
WordCount dataSize = elementTag->structRef.dataSize.get();
WirePointerCount pointerCount = elementTag->structRef.ptrCount.get();
const word* pos = ptr + POINTER_SIZE_IN_WORDS;
for (uint i = 0; i < count / ELEMENTS; i++) {
pos += dataSize;
for (uint j = 0; j < pointerCount / POINTERS; j++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(pos),
nestingLimit);
pos += POINTER_SIZE_IN_WORDS;
}
}
break;
}
}
break;
}
case WirePointer::FAR:
FAIL_RECOVERABLE_CHECK("Unexpected FAR pointer.") {
break;
}
break;
case WirePointer::RESERVED_3:
FAIL_VALIDATE_INPUT("Don't know how to handle RESERVED_3.") {
break;
}
break;
}
return result;
}
// -----------------------------------------------------------------
static CAPNPROTO_ALWAYS_INLINE(
......@@ -1906,6 +2048,22 @@ bool StructReader::isPointerFieldNull(WirePointerCount ptrIndex) const {
return (pointers + ptrIndex)->isNull();
}
WordCount64 StructReader::totalSize() const {
WordCount64 result = WireHelpers::roundUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER;
for (uint i = 0; i < pointerCount / POINTERS; i++) {
result += WireHelpers::totalSize(segment, pointers + i, nestingLimit);
}
if (segment != nullptr) {
// This traversal should not count against the read limit, because it's highly likely that
// the caller is going to traverse the object again, e.g. to copy it.
segment->unread(result);
}
return result;
}
// =======================================================================================
// ListBuilder
......
......@@ -463,6 +463,13 @@ public:
bool isPointerFieldNull(WirePointerCount ptrIndex) const;
WordCount64 totalSize() const;
// Return the total size of the struct and everything to which it points. Does not count far
// pointer overhead. This is useful for deciding how much space is needed to copy the struct
// into a flat array. However, the caller is advised NOT to treat this value as secure. Instead,
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
private:
SegmentReader* segment; // Memory segment in which the struct resides.
......
......@@ -86,7 +86,7 @@ std::string fileLine(std::string file, int line) {
TEST(Logging, Log) {
MockExceptionCallback mockCallback;
MockExceptionCallback::ScopedRegistration reg(&mockCallback);
MockExceptionCallback::ScopedRegistration reg(mockCallback);
int line;
LOG(WARNING, "Hello world!"); line = __LINE__;
......@@ -137,7 +137,7 @@ TEST(Logging, Log) {
TEST(Logging, Syscall) {
MockExceptionCallback mockCallback;
MockExceptionCallback::ScopedRegistration reg(&mockCallback);
MockExceptionCallback::ScopedRegistration reg(mockCallback);
int line;
int i = 123;
......
......@@ -93,7 +93,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int
++index;
if (index != argValues.size()) {
getExceptionCallback()->logMessage(
getExceptionCallback().logMessage(
str(__FILE__, ":", __LINE__, ": Failed to parse logging macro args into ",
argValues.size(), " names: ", macroArgs, '\n'));
}
......@@ -191,7 +191,7 @@ static Array<char> makeDescription(DescriptionStyle style, const char* code, int
void Log::logInternal(const char* file, int line, Severity severity, const char* macroArgs,
ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->logMessage(
getExceptionCallback().logMessage(
str(severity, ": ", file, ":", line, ": ",
makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n'));
}
......@@ -199,7 +199,7 @@ void Log::logInternal(const char* file, int line, Severity severity, const char*
void Log::recoverableFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onRecoverableException(
getExceptionCallback().onRecoverableException(
Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
}
......@@ -207,7 +207,7 @@ void Log::recoverableFaultInternal(
void Log::fatalFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onFatalException(
getExceptionCallback().onFatalException(
Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
abort();
......@@ -216,7 +216,7 @@ void Log::fatalFaultInternal(
void Log::recoverableFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onRecoverableException(
getExceptionCallback().onRecoverableException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
}
......@@ -224,7 +224,7 @@ void Log::recoverableFailedSyscallInternal(
void Log::fatalFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<Array<char>> argValues) {
getExceptionCallback()->onFatalException(
getExceptionCallback().onFatalException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
abort();
......
......@@ -170,6 +170,8 @@ private:
internal::StructBuilder initRoot(internal::StructSize size);
void setRootInternal(internal::StructReader reader);
internal::StructBuilder getRoot(internal::StructSize size);
friend struct SchemaLoader; // for a dirty hack, see schema-loader.c++.
};
template <typename RootType>
......
......@@ -76,6 +76,7 @@ struct {{typeFullName}} {
};
{{/typeUnion}}
private:
{{#typeFields}}
{{#fieldDefaultBytes}}
static const ::capnproto::internal::AlignedData<{{defaultWordCount}}> DEFAULT_{{fieldUpperCase}};
......@@ -144,9 +145,12 @@ public:
inline explicit Reader(::capnproto::internal::StructReader base): _reader(base) {}
{{#typeStruct}}
::capnproto::String debugString() {
inline ::capnproto::String debugString() {
return ::capnproto::internal::debugString<{{typeName}}>(_reader);
}
inline size_t totalSizeInWords() {
return _reader.totalSize() / ::capnproto::WORDS;
}
{{#structUnions}}
// {{unionDecl}}
......@@ -194,7 +198,8 @@ public:
inline Reader asReader() { return *this; }
{{#typeStruct}}
::capnproto::String debugString() { return asReader().debugString(); }
inline ::capnproto::String debugString() { return asReader().debugString(); }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); }
{{#structUnions}}
// {{unionDecl}}
......
......@@ -62,7 +62,6 @@ static const ::capnproto::internal::RawSchema* const d_{{schemaId}}[] = {
{{#schemaDependencies}}
&s_{{dependencyId}},
{{/schemaDependencies}}
nullptr
};
static const ::capnproto::internal::RawSchema::MemberInfo m_{{schemaId}}[] = {
{{#schemaMembersByName}}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment