Commit 7a958a11 authored by Kenton Varda's avatar Kenton Varda

Extend totalSizeInWords() to also return a count of capabilities, which helps…

Extend totalSizeInWords() to also return a count of capabilities, which helps when a separate capability table needs to be allocated as well.  Use this in the RPC system.
parent 94f43211
...@@ -52,8 +52,8 @@ struct AnyPointer { ...@@ -52,8 +52,8 @@ struct AnyPointer {
Reader() = default; Reader() = default;
inline Reader(_::PointerReader reader): reader(reader) {} inline Reader(_::PointerReader reader): reader(reader) {}
inline size_t targetSizeInWords() const; inline MessageSize targetSize() const;
// Get the total size, in words, of the target object and all its children. // Get the total size of the target object and all its children.
inline bool isNull() const; inline bool isNull() const;
...@@ -92,8 +92,8 @@ struct AnyPointer { ...@@ -92,8 +92,8 @@ struct AnyPointer {
inline Builder(decltype(nullptr)) {} inline Builder(decltype(nullptr)) {}
inline Builder(_::PointerBuilder builder): builder(builder) {} inline Builder(_::PointerBuilder builder): builder(builder) {}
inline size_t targetSizeInWords() const; inline MessageSize targetSize() const;
// Get the total size, in words, of the target object and all its children. // Get the total size of the target object and all its children.
inline bool isNull(); inline bool isNull();
...@@ -325,8 +325,8 @@ public: ...@@ -325,8 +325,8 @@ public:
// ======================================================================================= // =======================================================================================
// Inline implementation details // Inline implementation details
inline size_t AnyPointer::Reader::targetSizeInWords() const { inline MessageSize AnyPointer::Reader::targetSize() const {
return reader.targetSize() / WORDS; return reader.targetSize().asPublic();
} }
inline bool AnyPointer::Reader::isNull() const { inline bool AnyPointer::Reader::isNull() const {
...@@ -338,8 +338,8 @@ inline ReaderFor<T> AnyPointer::Reader::getAs() const { ...@@ -338,8 +338,8 @@ inline ReaderFor<T> AnyPointer::Reader::getAs() const {
return _::PointerHelpers<T>::get(reader); return _::PointerHelpers<T>::get(reader);
} }
inline size_t AnyPointer::Builder::targetSizeInWords() const { inline MessageSize AnyPointer::Builder::targetSize() const {
return asReader().targetSizeInWords(); return asReader().targetSize();
} }
inline bool AnyPointer::Builder::isNull() { inline bool AnyPointer::Builder::isNull() {
......
...@@ -90,8 +90,21 @@ kj::ArrayPtr<kj::Own<ClientHook>> CapBuilderContext::getCapTable() { ...@@ -90,8 +90,21 @@ kj::ArrayPtr<kj::Own<ClientHook>> CapBuilderContext::getCapTable() {
// ======================================================================================= // =======================================================================================
LocalMessage::LocalMessage(uint firstSegmentWords, AllocationStrategy allocationStrategy) namespace {
: message(firstSegmentWords, allocationStrategy),
uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
KJ_IF_MAYBE(s, sizeHint) {
// 1 for the root pointer. We don't store caps in the message so we don't count those here.
return s->wordCount + 1;
} else {
return SUGGESTED_FIRST_SEGMENT_WORDS;
}
}
} // namespace
LocalMessage::LocalMessage(kj::Maybe<MessageSize> sizeHint)
: message(firstSegmentSize(sizeHint)),
root(capContext.imbue(message.getRoot<AnyPointer>())) {} root(capContext.imbue(message.getRoot<AnyPointer>())) {}
// ======================================================================================= // =======================================================================================
...@@ -114,9 +127,8 @@ private: ...@@ -114,9 +127,8 @@ private:
class BrokenRequest final: public RequestHook { class BrokenRequest final: public RequestHook {
public: public:
BrokenRequest(const kj::Exception& exception, uint firstSegmentWordSize) BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
: exception(exception), : exception(exception), message(sizeHint) {}
message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
RemotePromise<AnyPointer> send() override { RemotePromise<AnyPointer> send() override {
return RemotePromise<AnyPointer>(kj::cp(exception), return RemotePromise<AnyPointer>(kj::cp(exception),
...@@ -139,8 +151,8 @@ public: ...@@ -139,8 +151,8 @@ public:
"", 0, kj::str(description)) {} "", 0, kj::str(description)) {}
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, firstSegmentWordSize); auto hook = kj::heap<BrokenRequest>(exception, sizeHint);
auto root = hook->message.getRoot(); auto root = hook->message.getRoot();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
......
...@@ -105,8 +105,7 @@ class LocalMessage final { ...@@ -105,8 +105,7 @@ class LocalMessage final {
// know how to properly serialize its capabilities. // know how to properly serialize its capabilities.
public: public:
LocalMessage(uint firstSegmentWords = SUGGESTED_FIRST_SEGMENT_WORDS, LocalMessage(kj::Maybe<MessageSize> sizeHint = nullptr);
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
inline AnyPointer::Builder getRoot() { return root; } inline AnyPointer::Builder getRoot() { return root; }
inline AnyPointer::Reader getRootReader() const { return root.asReader(); } inline AnyPointer::Reader getRootReader() const { return root.asReader(); }
......
...@@ -88,8 +88,8 @@ kj::Promise<void> ClientHook::whenResolved() { ...@@ -88,8 +88,8 @@ kj::Promise<void> ClientHook::whenResolved() {
class LocalResponse final: public ResponseHook, public kj::Refcounted { class LocalResponse final: public ResponseHook, public kj::Refcounted {
public: public:
LocalResponse(uint sizeHint) LocalResponse(kj::Maybe<MessageSize> sizeHint)
: message(sizeHint == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : sizeHint) {} : message(sizeHint) {}
LocalMessage message; LocalMessage message;
}; };
...@@ -111,9 +111,9 @@ public: ...@@ -111,9 +111,9 @@ public:
void releaseParams() override { void releaseParams() override {
request = nullptr; request = nullptr;
} }
AnyPointer::Builder getResults(uint firstSegmentWordSize) override { AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
if (response == nullptr) { if (response == nullptr) {
auto localResponse = kj::refcounted<LocalResponse>(firstSegmentWordSize); auto localResponse = kj::refcounted<LocalResponse>(sizeHint);
responseBuilder = localResponse->message.getRoot(); responseBuilder = localResponse->message.getRoot();
response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse)); response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse));
} }
...@@ -162,9 +162,8 @@ public: ...@@ -162,9 +162,8 @@ public:
class LocalRequest final: public RequestHook { class LocalRequest final: public RequestHook {
public: public:
inline LocalRequest(uint64_t interfaceId, uint16_t methodId, inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
uint firstSegmentWordSize, kj::Own<ClientHook> client) kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client)
: message(kj::heap<LocalMessage>( : message(kj::heap<LocalMessage>(sizeHint)),
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<AnyPointer> send() override { RemotePromise<AnyPointer> send() override {
...@@ -194,7 +193,7 @@ public: ...@@ -194,7 +193,7 @@ public:
// Now the other branch returns the response from the context. // Now the other branch returns the response from the context.
auto promise = forked.addBranch().then(kj::mvCapture(context, auto promise = forked.addBranch().then(kj::mvCapture(context,
[](kj::Own<LocalCallContext>&& context) { [](kj::Own<LocalCallContext>&& context) {
context->getResults(1); // force response allocation context->getResults(MessageSize { 0, 0 }); // force response allocation
return kj::mv(KJ_ASSERT_NONNULL(context->response)); return kj::mv(KJ_ASSERT_NONNULL(context->response));
})); }));
...@@ -274,9 +273,9 @@ public: ...@@ -274,9 +273,9 @@ public:
promiseForClientResolution(promise.addBranch().fork()) {} promiseForClientResolution(promise.addBranch().fork()) {}
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first.
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
...@@ -403,7 +402,7 @@ class LocalPipeline final: public PipelineHook, public kj::Refcounted { ...@@ -403,7 +402,7 @@ class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public: public:
inline LocalPipeline(kj::Own<CallContextHook>&& contextParam) inline LocalPipeline(kj::Own<CallContextHook>&& contextParam)
: context(kj::mv(contextParam)), : context(kj::mv(contextParam)),
results(context->getResults(1)) {} results(context->getResults(MessageSize { 0, 0 })) {}
kj::Own<PipelineHook> addRef() { kj::Own<PipelineHook> addRef() {
return kj::addRef(*this); return kj::addRef(*this);
...@@ -424,9 +423,9 @@ public: ...@@ -424,9 +423,9 @@ public:
: server(kj::mv(server)) {} : server(kj::mv(server)) {}
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first.
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
......
...@@ -166,7 +166,7 @@ protected: ...@@ -166,7 +166,7 @@ protected:
template <typename Params, typename Results> template <typename Params, typename Results>
Request<Params, Results> newCall(uint64_t interfaceId, uint16_t methodId, Request<Params, Results> newCall(uint64_t interfaceId, uint16_t methodId,
uint firstSegmentWordSize); kj::Maybe<MessageSize> sizeHint);
private: private:
kj::Own<ClientHook> hook; kj::Own<ClientHook> hook;
...@@ -209,19 +209,22 @@ public: ...@@ -209,19 +209,22 @@ public:
// requests. Long-running asynchronous methods should try to call this as early as is // requests. Long-running asynchronous methods should try to call this as early as is
// convenient. // convenient.
typename Results::Builder getResults(uint firstSegmentWordSize = 0); typename Results::Builder getResults(kj::Maybe<MessageSize> sizeHint = nullptr);
typename Results::Builder initResults(uint firstSegmentWordSize = 0); typename Results::Builder initResults(kj::Maybe<MessageSize> sizeHint = nullptr);
void setResults(typename Results::Reader value); void setResults(typename Results::Reader value);
void adoptResults(Orphan<Results>&& value); void adoptResults(Orphan<Results>&& value);
Orphanage getResultsOrphanage(uint firstSegmentWordSize = 0); Orphanage getResultsOrphanage(kj::Maybe<MessageSize> sizeHint = nullptr);
// Manipulate the results payload. The "Return" message (part of the RPC protocol) will // Manipulate the results payload. The "Return" message (part of the RPC protocol) will
// typically be allocated the first time one of these is called. Some RPC systems may // typically be allocated the first time one of these is called. Some RPC systems may
// allocate these messages in a limited space (such as a shared memory segment), therefore the // allocate these messages in a limited space (such as a shared memory segment), therefore the
// application should delay calling these as long as is convenient to do so (but don't delay // application should delay calling these as long as is convenient to do so (but don't delay
// if doing so would require extra copies later). // if doing so would require extra copies later).
// //
// `firstSegmentWordSize` indicates the suggested size of the message's first segment. This // `sizeHint` indicates a guess at the message size. This will usually be used to decide how
// is a hint only. If not specified, the system will decide on its own. // much space to allocate for the first message segment (don't worry: only space that is actually
// used will be sent on the wire). If omitted, the system decides. The message root pointer
// should not be included in the size. So, if you are simply going to copy some existing message
// directly into the results, just call `.totalSize()` and pass that in.
template <typename SubParams> template <typename SubParams>
kj::Promise<void> tailCall(Request<SubParams, Results>&& tailRequest); kj::Promise<void> tailCall(Request<SubParams, Results>&& tailRequest);
...@@ -335,7 +338,7 @@ public: ...@@ -335,7 +338,7 @@ public:
class ClientHook { class ClientHook {
public: public:
virtual Request<AnyPointer, AnyPointer> newCall( virtual Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) = 0; uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) = 0;
// Start a new call, allowing the client to allocate request/response objects as it sees fit. // Start a new call, allowing the client to allocate request/response objects as it sees fit.
// This version is used when calls are made from application code in the local process. // This version is used when calls are made from application code in the local process.
...@@ -391,7 +394,7 @@ class CallContextHook { ...@@ -391,7 +394,7 @@ class CallContextHook {
public: public:
virtual AnyPointer::Reader getParams() = 0; virtual AnyPointer::Reader getParams() = 0;
virtual void releaseParams() = 0; virtual void releaseParams() = 0;
virtual AnyPointer::Builder getResults(uint firstSegmentWordSize) = 0; virtual AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) = 0;
virtual kj::Promise<void> tailCall(kj::Own<RequestHook>&& request) = 0; virtual kj::Promise<void> tailCall(kj::Own<RequestHook>&& request) = 0;
virtual void allowCancellation() = 0; virtual void allowCancellation() = 0;
...@@ -577,8 +580,8 @@ inline kj::Promise<void> Capability::Client::whenResolved() { ...@@ -577,8 +580,8 @@ inline kj::Promise<void> Capability::Client::whenResolved() {
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline Request<Params, Results> Capability::Client::newCall( inline Request<Params, Results> Capability::Client::newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) {
auto typeless = hook->newCall(interfaceId, methodId, firstSegmentWordSize); auto typeless = hook->newCall(interfaceId, methodId, sizeHint);
return Request<Params, Results>(typeless.template getAs<Params>(), kj::mv(typeless.hook)); return Request<Params, Results>(typeless.template getAs<Params>(), kj::mv(typeless.hook));
} }
...@@ -594,27 +597,28 @@ inline void CallContext<Params, Results>::releaseParams() { ...@@ -594,27 +597,28 @@ inline void CallContext<Params, Results>::releaseParams() {
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::getResults( inline typename Results::Builder CallContext<Params, Results>::getResults(
uint firstSegmentWordSize) { kj::Maybe<MessageSize> sizeHint) {
// `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401 // `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401
return hook->getResults(firstSegmentWordSize).template getAs<Results>(); return hook->getResults(sizeHint).template getAs<Results>();
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::initResults( inline typename Results::Builder CallContext<Params, Results>::initResults(
uint firstSegmentWordSize) { kj::Maybe<MessageSize> sizeHint) {
// `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401 // `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401
return hook->getResults(firstSegmentWordSize).template initAs<Results>(); return hook->getResults(sizeHint).template initAs<Results>();
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline void CallContext<Params, Results>::setResults(typename Results::Reader value) { inline void CallContext<Params, Results>::setResults(typename Results::Reader value) {
hook->getResults(value.totalSizeInWords() + 1).set(value); hook->getResults(value.totalSize()).set(value);
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline void CallContext<Params, Results>::adoptResults(Orphan<Results>&& value) { inline void CallContext<Params, Results>::adoptResults(Orphan<Results>&& value) {
hook->getResults(0).adopt(kj::mv(value)); hook->getResults(nullptr).adopt(kj::mv(value));
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline Orphanage CallContext<Params, Results>::getResultsOrphanage(uint firstSegmentWordSize) { inline Orphanage CallContext<Params, Results>::getResultsOrphanage(
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize)); kj::Maybe<MessageSize> sizeHint) {
return Orphanage::getForMessageContaining(hook->getResults(sizeHint));
} }
template <typename Params, typename Results> template <typename Params, typename Results>
template <typename SubParams> template <typename SubParams>
......
...@@ -166,6 +166,12 @@ template <typename T, Kind k = kind<T>()> ...@@ -166,6 +166,12 @@ template <typename T, Kind k = kind<T>()>
struct PointerHelpers; struct PointerHelpers;
} // namespace _ (private) } // namespace _ (private)
struct MessageSize {
// Size of a message. Every struct type has a method `.totalSize()` that returns this.
uint64_t wordCount;
uint capCount;
};
// ======================================================================================= // =======================================================================================
// Raw memory types and measures // Raw memory types and measures
......
...@@ -1196,7 +1196,7 @@ private: ...@@ -1196,7 +1196,7 @@ private:
void writeFlat(DynamicStruct::Reader value, kj::BufferedOutputStream& output) { void writeFlat(DynamicStruct::Reader value, kj::BufferedOutputStream& output) {
// Always copy the message to a flat array so that the output is predictable (one segment, // Always copy the message to a flat array so that the output is predictable (one segment,
// in canonical order). // in canonical order).
size_t size = value.totalSizeInWords() + 1; size_t size = value.totalSize().wordCount + 1;
kj::Array<word> space = kj::heapArray<word>(size); kj::Array<word> space = kj::heapArray<word>(size);
memset(space.begin(), 0, size * sizeof(word)); memset(space.begin(), 0, size * sizeof(word));
FlatMessageBuilder flatMessage(space); FlatMessageBuilder flatMessage(space);
......
...@@ -1048,8 +1048,8 @@ private: ...@@ -1048,8 +1048,8 @@ private:
" Reader() = default;\n" " Reader() = default;\n"
" inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}\n" " inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}\n"
"\n" "\n"
" inline size_t totalSizeInWords() const {\n" " inline ::capnp::MessageSize totalSize() const {\n"
" return _reader.totalSize() / ::capnp::WORDS;\n" " return _reader.totalSize().asPublic();\n"
" }\n" " }\n"
"\n", "\n",
isUnion ? kj::strTree(" inline Which which() const;\n") : kj::strTree(), isUnion ? kj::strTree(" inline Which which() const;\n") : kj::strTree(),
...@@ -1087,7 +1087,7 @@ private: ...@@ -1087,7 +1087,7 @@ private:
" inline operator Reader() const { return Reader(_builder.asReader()); }\n" " inline operator Reader() const { return Reader(_builder.asReader()); }\n"
" inline Reader asReader() const { return *this; }\n" " inline Reader asReader() const { return *this; }\n"
"\n" "\n"
" inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); }\n" " inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }\n"
"\n", "\n",
isUnion ? kj::strTree(" inline Which which();\n") : kj::strTree(), isUnion ? kj::strTree(" inline Which which();\n") : kj::strTree(),
kj::mv(methodDecls), kj::mv(methodDecls),
...@@ -1219,7 +1219,7 @@ private: ...@@ -1219,7 +1219,7 @@ private:
return MethodText { return MethodText {
kj::strTree( kj::strTree(
" ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n" " ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n"
" unsigned int firstSegmentWordSize = 0);\n"), " ::kj::Maybe< ::capnp::MessageSize> sizeHint = nullptr);\n"),
kj::strTree( kj::strTree(
paramProto.getScopeId() != 0 ? kj::strTree() : kj::strTree( paramProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
...@@ -1234,9 +1234,9 @@ private: ...@@ -1234,9 +1234,9 @@ private:
kj::strTree( kj::strTree(
"::capnp::Request<", paramType, ", ", resultType, ">\n", "::capnp::Request<", paramType, ", ", resultType, ">\n",
interfaceName, "::Client::", name, "Request(unsigned int firstSegmentWordSize) {\n" interfaceName, "::Client::", name, "Request(::kj::Maybe< ::capnp::MessageSize> sizeHint) {\n"
" return newCall<", paramType, ", ", resultType, ">(\n" " return newCall<", paramType, ", ", resultType, ">(\n"
" 0x", interfaceIdHex, "ull, ", methodId, ", firstSegmentWordSize);\n" " 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint);\n"
"}\n" "}\n"
"::kj::Promise<void> ", interfaceName, "::Server::", name, "(", titleCase, "Context) {\n" "::kj::Promise<void> ", interfaceName, "::Server::", name, "(", titleCase, "Context) {\n"
" return ::capnp::Capability::Server::internalUnimplemented(\n" " return ::capnp::Capability::Server::internalUnimplemented(\n"
......
This diff is collapsed.
...@@ -103,8 +103,8 @@ public: ...@@ -103,8 +103,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline Which which() const; inline Which which() const;
...@@ -166,7 +166,7 @@ public: ...@@ -166,7 +166,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline Which which(); inline Which which();
inline bool isIdentifier(); inline bool isIdentifier();
...@@ -258,8 +258,8 @@ public: ...@@ -258,8 +258,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline Which which() const; inline Which which() const;
...@@ -308,7 +308,7 @@ public: ...@@ -308,7 +308,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline Which which(); inline Which which();
inline bool hasTokens(); inline bool hasTokens();
...@@ -376,8 +376,8 @@ public: ...@@ -376,8 +376,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline bool hasTokens() const; inline bool hasTokens() const;
...@@ -411,7 +411,7 @@ public: ...@@ -411,7 +411,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline bool hasTokens(); inline bool hasTokens();
inline ::capnp::List< ::capnp::compiler::Token>::Builder getTokens(); inline ::capnp::List< ::capnp::compiler::Token>::Builder getTokens();
...@@ -453,8 +453,8 @@ public: ...@@ -453,8 +453,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline bool hasStatements() const; inline bool hasStatements() const;
...@@ -488,7 +488,7 @@ public: ...@@ -488,7 +488,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline bool hasStatements(); inline bool hasStatements();
inline ::capnp::List< ::capnp::compiler::Statement>::Builder getStatements(); inline ::capnp::List< ::capnp::compiler::Statement>::Builder getStatements();
......
...@@ -34,7 +34,7 @@ DynamicCapability::Client DynamicCapability::Client::upcast(InterfaceSchema requ ...@@ -34,7 +34,7 @@ DynamicCapability::Client DynamicCapability::Client::upcast(InterfaceSchema requ
} }
Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest( Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest(
InterfaceSchema::Method method, uint firstSegmentWordSize) { InterfaceSchema::Method method, kj::Maybe<MessageSize> sizeHint) {
auto methodInterface = method.getContainingInterface(); auto methodInterface = method.getContainingInterface();
KJ_REQUIRE(schema.extends(methodInterface), "Interface does not implement this method."); KJ_REQUIRE(schema.extends(methodInterface), "Interface does not implement this method.");
...@@ -44,15 +44,15 @@ Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest( ...@@ -44,15 +44,15 @@ Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest(
auto resultType = methodInterface.getDependency(proto.getResultStructType()).asStruct(); auto resultType = methodInterface.getDependency(proto.getResultStructType()).asStruct();
auto typeless = hook->newCall( auto typeless = hook->newCall(
methodInterface.getProto().getId(), method.getIndex(), firstSegmentWordSize); methodInterface.getProto().getId(), method.getIndex(), sizeHint);
return Request<DynamicStruct, DynamicStruct>( return Request<DynamicStruct, DynamicStruct>(
typeless.getAs<DynamicStruct>(paramType), kj::mv(typeless.hook), resultType); typeless.getAs<DynamicStruct>(paramType), kj::mv(typeless.hook), resultType);
} }
Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest( Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest(
kj::StringPtr methodName, uint firstSegmentWordSize) { kj::StringPtr methodName, kj::Maybe<MessageSize> sizeHint) {
return newRequest(schema.getMethodByName(methodName), firstSegmentWordSize); return newRequest(schema.getMethodByName(methodName), sizeHint);
} }
kj::Promise<void> DynamicCapability::Server::dispatchCall( kj::Promise<void> DynamicCapability::Server::dispatchCall(
......
...@@ -162,7 +162,7 @@ public: ...@@ -162,7 +162,7 @@ public:
template <typename T, typename = kj::EnableIf<kind<FromReader<T>>() == Kind::STRUCT>> template <typename T, typename = kj::EnableIf<kind<FromReader<T>>() == Kind::STRUCT>>
inline Reader(T&& value): Reader(toDynamic(value)) {} inline Reader(T&& value): Reader(toDynamic(value)) {}
inline size_t totalSizeInWords() const { return reader.totalSize() / ::capnp::WORDS; } inline MessageSize totalSize() const { return reader.totalSize().asPublic(); }
template <typename T> template <typename T>
typename T::Reader as() const; typename T::Reader as() const;
...@@ -228,7 +228,7 @@ public: ...@@ -228,7 +228,7 @@ public:
template <typename T, typename = kj::EnableIf<kind<FromBuilder<T>>() == Kind::STRUCT>> template <typename T, typename = kj::EnableIf<kind<FromBuilder<T>>() == Kind::STRUCT>>
inline Builder(T&& value): Builder(toDynamic(value)) {} inline Builder(T&& value): Builder(toDynamic(value)) {}
inline size_t totalSizeInWords() const { return asReader().totalSizeInWords(); } inline MessageSize totalSize() const { return asReader().totalSize(); }
template <typename T> template <typename T>
typename T::Builder as(); typename T::Builder as();
...@@ -455,9 +455,9 @@ public: ...@@ -455,9 +455,9 @@ public:
inline InterfaceSchema getSchema() { return schema; } inline InterfaceSchema getSchema() { return schema; }
Request<DynamicStruct, DynamicStruct> newRequest( Request<DynamicStruct, DynamicStruct> newRequest(
InterfaceSchema::Method method, uint firstSegmentWordSize = 0); InterfaceSchema::Method method, kj::Maybe<MessageSize> sizeHint = nullptr);
Request<DynamicStruct, DynamicStruct> newRequest( Request<DynamicStruct, DynamicStruct> newRequest(
kj::StringPtr methodName, uint firstSegmentWordSize = 0); kj::StringPtr methodName, kj::Maybe<MessageSize> sizeHint = nullptr);
private: private:
InterfaceSchema schema; InterfaceSchema schema;
...@@ -532,11 +532,11 @@ public: ...@@ -532,11 +532,11 @@ public:
DynamicStruct::Reader getParams(); DynamicStruct::Reader getParams();
void releaseParams(); void releaseParams();
DynamicStruct::Builder getResults(uint firstSegmentWordSize = 0); DynamicStruct::Builder getResults(kj::Maybe<MessageSize> sizeHint = nullptr);
DynamicStruct::Builder initResults(uint firstSegmentWordSize = 0); DynamicStruct::Builder initResults(kj::Maybe<MessageSize> sizeHint = nullptr);
void setResults(DynamicStruct::Reader value); void setResults(DynamicStruct::Reader value);
void adoptResults(Orphan<DynamicStruct>&& value); void adoptResults(Orphan<DynamicStruct>&& value);
Orphanage getResultsOrphanage(uint firstSegmentWordSize = 0); Orphanage getResultsOrphanage(kj::Maybe<MessageSize> sizeHint = nullptr);
template <typename SubParams> template <typename SubParams>
kj::Promise<void> tailCall(Request<SubParams, DynamicStruct>&& tailRequest); kj::Promise<void> tailCall(Request<SubParams, DynamicStruct>&& tailRequest);
void allowCancellation(); void allowCancellation();
...@@ -1517,22 +1517,22 @@ inline void CallContext<DynamicStruct, DynamicStruct>::releaseParams() { ...@@ -1517,22 +1517,22 @@ inline void CallContext<DynamicStruct, DynamicStruct>::releaseParams() {
hook->releaseParams(); hook->releaseParams();
} }
inline DynamicStruct::Builder CallContext<DynamicStruct, DynamicStruct>::getResults( inline DynamicStruct::Builder CallContext<DynamicStruct, DynamicStruct>::getResults(
uint firstSegmentWordSize) { kj::Maybe<MessageSize> sizeHint) {
return hook->getResults(firstSegmentWordSize).getAs<DynamicStruct>(resultType); return hook->getResults(sizeHint).getAs<DynamicStruct>(resultType);
} }
inline DynamicStruct::Builder CallContext<DynamicStruct, DynamicStruct>::initResults( inline DynamicStruct::Builder CallContext<DynamicStruct, DynamicStruct>::initResults(
uint firstSegmentWordSize) { kj::Maybe<MessageSize> sizeHint) {
return hook->getResults(firstSegmentWordSize).initAs<DynamicStruct>(resultType); return hook->getResults(sizeHint).initAs<DynamicStruct>(resultType);
} }
inline void CallContext<DynamicStruct, DynamicStruct>::setResults(DynamicStruct::Reader value) { inline void CallContext<DynamicStruct, DynamicStruct>::setResults(DynamicStruct::Reader value) {
hook->getResults(value.totalSizeInWords() + 1).setAs<DynamicStruct>(value); hook->getResults(value.totalSize()).setAs<DynamicStruct>(value);
} }
inline void CallContext<DynamicStruct, DynamicStruct>::adoptResults(Orphan<DynamicStruct>&& value) { inline void CallContext<DynamicStruct, DynamicStruct>::adoptResults(Orphan<DynamicStruct>&& value) {
hook->getResults(0).adopt(kj::mv(value)); hook->getResults(MessageSize { 0, 0 }).adopt(kj::mv(value));
} }
inline Orphanage CallContext<DynamicStruct, DynamicStruct>::getResultsOrphanage( inline Orphanage CallContext<DynamicStruct, DynamicStruct>::getResultsOrphanage(
uint firstSegmentWordSize) { kj::Maybe<MessageSize> sizeHint) {
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize)); return Orphanage::getForMessageContaining(hook->getResults(sizeHint));
} }
template <typename SubParams> template <typename SubParams>
inline kj::Promise<void> CallContext<DynamicStruct, DynamicStruct>::tailCall( inline kj::Promise<void> CallContext<DynamicStruct, DynamicStruct>::tailCall(
......
...@@ -72,7 +72,7 @@ TEST(Encoding, AllTypes) { ...@@ -72,7 +72,7 @@ TEST(Encoding, AllTypes) {
checkTestMessage(readMessageUnchecked<TestAllTypes>(builder.getSegmentsForOutput()[0].begin())); checkTestMessage(readMessageUnchecked<TestAllTypes>(builder.getSegmentsForOutput()[0].begin()));
EXPECT_EQ(builder.getSegmentsForOutput()[0].size() - 1, // -1 for root pointer EXPECT_EQ(builder.getSegmentsForOutput()[0].size() - 1, // -1 for root pointer
reader.getRoot<TestAllTypes>().totalSizeInWords()); reader.getRoot<TestAllTypes>().totalSize().wordCount);
} }
TEST(Encoding, AllTypesMultiSegment) { TEST(Encoding, AllTypesMultiSegment) {
...@@ -1624,7 +1624,7 @@ TEST(Encoding, HasEmptyStruct) { ...@@ -1624,7 +1624,7 @@ TEST(Encoding, HasEmptyStruct) {
MallocMessageBuilder message; MallocMessageBuilder message;
auto root = message.initRoot<test::TestAnyPointer>(); auto root = message.initRoot<test::TestAnyPointer>();
EXPECT_EQ(1, root.totalSizeInWords()); EXPECT_EQ(1, root.totalSize().wordCount);
EXPECT_FALSE(root.asReader().hasAnyPointerField()); EXPECT_FALSE(root.asReader().hasAnyPointerField());
EXPECT_FALSE(root.hasAnyPointerField()); EXPECT_FALSE(root.hasAnyPointerField());
...@@ -1632,14 +1632,14 @@ TEST(Encoding, HasEmptyStruct) { ...@@ -1632,14 +1632,14 @@ TEST(Encoding, HasEmptyStruct) {
EXPECT_TRUE(root.asReader().hasAnyPointerField()); EXPECT_TRUE(root.asReader().hasAnyPointerField());
EXPECT_TRUE(root.hasAnyPointerField()); EXPECT_TRUE(root.hasAnyPointerField());
EXPECT_EQ(1, root.totalSizeInWords()); EXPECT_EQ(1, root.totalSize().wordCount);
} }
TEST(Encoding, HasEmptyList) { TEST(Encoding, HasEmptyList) {
MallocMessageBuilder message; MallocMessageBuilder message;
auto root = message.initRoot<test::TestAnyPointer>(); auto root = message.initRoot<test::TestAnyPointer>();
EXPECT_EQ(1, root.totalSizeInWords()); EXPECT_EQ(1, root.totalSize().wordCount);
EXPECT_FALSE(root.asReader().hasAnyPointerField()); EXPECT_FALSE(root.asReader().hasAnyPointerField());
EXPECT_FALSE(root.hasAnyPointerField()); EXPECT_FALSE(root.hasAnyPointerField());
...@@ -1647,14 +1647,14 @@ TEST(Encoding, HasEmptyList) { ...@@ -1647,14 +1647,14 @@ TEST(Encoding, HasEmptyList) {
EXPECT_TRUE(root.asReader().hasAnyPointerField()); EXPECT_TRUE(root.asReader().hasAnyPointerField());
EXPECT_TRUE(root.hasAnyPointerField()); EXPECT_TRUE(root.hasAnyPointerField());
EXPECT_EQ(1, root.totalSizeInWords()); EXPECT_EQ(1, root.totalSize().wordCount);
} }
TEST(Encoding, HasEmptyStructList) { TEST(Encoding, HasEmptyStructList) {
MallocMessageBuilder message; MallocMessageBuilder message;
auto root = message.initRoot<test::TestAnyPointer>(); auto root = message.initRoot<test::TestAnyPointer>();
EXPECT_EQ(1, root.totalSizeInWords()); EXPECT_EQ(1, root.totalSize().wordCount);
EXPECT_FALSE(root.asReader().hasAnyPointerField()); EXPECT_FALSE(root.asReader().hasAnyPointerField());
EXPECT_FALSE(root.hasAnyPointerField()); EXPECT_FALSE(root.hasAnyPointerField());
...@@ -1662,7 +1662,7 @@ TEST(Encoding, HasEmptyStructList) { ...@@ -1662,7 +1662,7 @@ TEST(Encoding, HasEmptyStructList) {
EXPECT_TRUE(root.asReader().hasAnyPointerField()); EXPECT_TRUE(root.asReader().hasAnyPointerField());
EXPECT_TRUE(root.hasAnyPointerField()); EXPECT_TRUE(root.hasAnyPointerField());
EXPECT_EQ(2, root.totalSizeInWords()); EXPECT_EQ(2, root.totalSize().wordCount);
} }
} // namespace } // namespace
......
...@@ -567,29 +567,30 @@ struct WireHelpers { ...@@ -567,29 +567,30 @@ struct WireHelpers {
// ----------------------------------------------------------------- // -----------------------------------------------------------------
static WordCount64 totalSize(SegmentReader* segment, const WirePointer* ref, int nestingLimit) { static MessageSizeCounts totalSize(
SegmentReader* segment, const WirePointer* ref, int nestingLimit) {
// Compute the total size of the object pointed to, not counting far pointer overhead. // Compute the total size of the object pointed to, not counting far pointer overhead.
MessageSizeCounts result = { 0 * WORDS, 0 };
if (ref->isNull()) { if (ref->isNull()) {
return 0 * WORDS; return result;
} }
KJ_REQUIRE(nestingLimit > 0, "Message is too deeply-nested.") { KJ_REQUIRE(nestingLimit > 0, "Message is too deeply-nested.") {
return 0 * WORDS; return result;
} }
--nestingLimit; --nestingLimit;
const word* ptr = followFars(ref, ref->target(), segment); const word* ptr = followFars(ref, ref->target(), segment);
WordCount64 result = 0 * WORDS;
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
return result; return result;
} }
result += ref->structRef.wordSize(); result.wordCount += ref->structRef.wordSize();
const WirePointer* pointerSection = const WirePointer* pointerSection =
reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get()); reinterpret_cast<const WirePointer*>(ptr + ref->structRef.dataSize.get());
...@@ -616,7 +617,7 @@ struct WireHelpers { ...@@ -616,7 +617,7 @@ struct WireHelpers {
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
return result; return result;
} }
result += totalWords; result.wordCount += totalWords;
break; break;
} }
case FieldSize::POINTER: { case FieldSize::POINTER: {
...@@ -627,7 +628,7 @@ struct WireHelpers { ...@@ -627,7 +628,7 @@ struct WireHelpers {
return result; return result;
} }
result += count * WORDS_PER_POINTER; result.wordCount += count * WORDS_PER_POINTER;
for (uint i = 0; i < count / POINTERS; i++) { for (uint i = 0; i < count / POINTERS; i++) {
result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i, result += totalSize(segment, reinterpret_cast<const WirePointer*>(ptr) + i,
...@@ -642,7 +643,7 @@ struct WireHelpers { ...@@ -642,7 +643,7 @@ struct WireHelpers {
return result; return result;
} }
result += wordCount + POINTER_SIZE_IN_WORDS; result.wordCount += wordCount + POINTER_SIZE_IN_WORDS;
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
...@@ -681,7 +682,11 @@ struct WireHelpers { ...@@ -681,7 +682,11 @@ struct WireHelpers {
} }
break; break;
case WirePointer::OTHER: case WirePointer::OTHER:
KJ_REQUIRE(ref->isCapability(), "Unknown pointer type.") { break; } if (ref->isCapability()) {
result.capCount++;
} else {
KJ_FAIL_REQUIRE("Unknown pointer type.") { break; }
}
break; break;
} }
...@@ -2246,7 +2251,7 @@ const word* PointerReader::getUnchecked() const { ...@@ -2246,7 +2251,7 @@ const word* PointerReader::getUnchecked() const {
return reinterpret_cast<const word*>(pointer); return reinterpret_cast<const word*>(pointer);
} }
WordCount64 PointerReader::targetSize() const { MessageSizeCounts PointerReader::targetSize() const {
return WireHelpers::totalSize(segment, pointer, nestingLimit); return WireHelpers::totalSize(segment, pointer, nestingLimit);
} }
...@@ -2366,8 +2371,9 @@ BuilderArena* StructBuilder::getArena() { ...@@ -2366,8 +2371,9 @@ BuilderArena* StructBuilder::getArena() {
// ======================================================================================= // =======================================================================================
// StructReader // StructReader
WordCount64 StructReader::totalSize() const { MessageSizeCounts StructReader::totalSize() const {
WordCount64 result = WireHelpers::roundBitsUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER; MessageSizeCounts result = {
WireHelpers::roundBitsUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER, 0 };
for (uint i = 0; i < pointerCount / POINTERS; i++) { for (uint i = 0; i < pointerCount / POINTERS; i++) {
result += WireHelpers::totalSize(segment, pointers + i, nestingLimit); result += WireHelpers::totalSize(segment, pointers + i, nestingLimit);
...@@ -2376,7 +2382,7 @@ WordCount64 StructReader::totalSize() const { ...@@ -2376,7 +2382,7 @@ WordCount64 StructReader::totalSize() const {
if (segment != nullptr) { if (segment != nullptr) {
// This traversal should not count against the read limit, because it's highly likely that // This traversal should not count against the read limit, because it's highly likely that
// the caller is going to traverse the object again, e.g. to copy it. // the caller is going to traverse the object again, e.g. to copy it.
segment->unread(result); segment->unread(result.wordCount);
} }
return result; return result;
......
...@@ -166,11 +166,22 @@ inline constexpr FieldSize elementSizeForType() { ...@@ -166,11 +166,22 @@ inline constexpr FieldSize elementSizeForType() {
return ElementSizeForType<T>::value; return ElementSizeForType<T>::value;
} }
} // namespace _ (private) struct MessageSizeCounts {
WordCount64 wordCount;
uint capCount;
MessageSizeCounts& operator+=(const MessageSizeCounts& other) {
wordCount += other.wordCount;
capCount += other.capCount;
return *this;
}
// ============================================================================= MessageSize asPublic() {
return MessageSize { wordCount / WORDS, capCount };
}
};
namespace _ { // private // =============================================================================
template <int wordCount> template <int wordCount>
union AlignedData { union AlignedData {
...@@ -355,7 +366,7 @@ public: ...@@ -355,7 +366,7 @@ public:
static inline PointerReader getRootUnchecked(const word* location); static inline PointerReader getRootUnchecked(const word* location);
// Get a PointerReader for an unchecked message. // Get a PointerReader for an unchecked message.
WordCount64 targetSize() const; MessageSizeCounts targetSize() const;
// Return the total size of the target object and everything to which it points. Does not count // Return the total size of the target object and everything to which it points. Does not count
// far pointer overhead. This is useful for deciding how much space is needed to copy the object // far pointer overhead. This is useful for deciding how much space is needed to copy the object
// into a flat array. However, the caller is advised NOT to treat this value as secure. Instead, // into a flat array. However, the caller is advised NOT to treat this value as secure. Instead,
...@@ -524,7 +535,7 @@ public: ...@@ -524,7 +535,7 @@ public:
// Get a reader for a pointer field given the index within the pointer section. If the index // Get a reader for a pointer field given the index within the pointer section. If the index
// is out-of-bounds, returns a null pointer. // is out-of-bounds, returns a null pointer.
WordCount64 totalSize() const; MessageSizeCounts totalSize() const;
// Return the total size of the struct and everything to which it points. Does not count far // Return the total size of the struct and everything to which it points. Does not count far
// pointer overhead. This is useful for deciding how much space is needed to copy the struct // pointer overhead. This is useful for deciding how much space is needed to copy the struct
// into a flat array. However, the caller is advised NOT to treat this value as secure. Instead, // into a flat array. However, the caller is advised NOT to treat this value as secure. Instead,
......
...@@ -335,8 +335,17 @@ private: ...@@ -335,8 +335,17 @@ private:
}; };
class FlatMessageBuilder: public MessageBuilder { class FlatMessageBuilder: public MessageBuilder {
// A message builder implementation which allocates from a single flat array, throwing an // THIS IS NOT THE CLASS YOU'RE LOOKING FOR.
// exception if it runs out of space. The array must be zero'd before use. //
// If you want to write a message into already-existing scratch space, use `MallocMessageBuilder`
// and pass the scratch space to its constructor. It will then only fall back to malloc() if
// the scratch space is not large enough.
//
// Do NOT use this class unless you really know what you're doing. This class is problematic
// because it requires advance knowledge of the size of your message, which is usually impossible
// to determine without actually building the message. The class was created primarily to
// implement `copyToUnchecked()`, which itself exists only to support other internal parts of
// the Cap'n Proto implementation.
public: public:
explicit FlatMessageBuilder(kj::ArrayPtr<word> array); explicit FlatMessageBuilder(kj::ArrayPtr<word> array);
......
...@@ -125,8 +125,8 @@ public: ...@@ -125,8 +125,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline ::capnp::rpc::twoparty::Side getSide() const; inline ::capnp::rpc::twoparty::Side getSide() const;
...@@ -159,7 +159,7 @@ public: ...@@ -159,7 +159,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline ::capnp::rpc::twoparty::Side getSide(); inline ::capnp::rpc::twoparty::Side getSide();
inline void setSide( ::capnp::rpc::twoparty::Side value); inline void setSide( ::capnp::rpc::twoparty::Side value);
...@@ -197,8 +197,8 @@ public: ...@@ -197,8 +197,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline ::uint32_t getJoinId() const; inline ::uint32_t getJoinId() const;
...@@ -231,7 +231,7 @@ public: ...@@ -231,7 +231,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline ::uint32_t getJoinId(); inline ::uint32_t getJoinId();
inline void setJoinId( ::uint32_t value); inline void setJoinId( ::uint32_t value);
...@@ -269,8 +269,8 @@ public: ...@@ -269,8 +269,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
private: private:
...@@ -301,7 +301,7 @@ public: ...@@ -301,7 +301,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
private: private:
::capnp::_::StructBuilder _builder; ::capnp::_::StructBuilder _builder;
...@@ -336,8 +336,8 @@ public: ...@@ -336,8 +336,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
private: private:
...@@ -368,7 +368,7 @@ public: ...@@ -368,7 +368,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
private: private:
::capnp::_::StructBuilder _builder; ::capnp::_::StructBuilder _builder;
...@@ -403,8 +403,8 @@ public: ...@@ -403,8 +403,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline ::uint32_t getJoinId() const; inline ::uint32_t getJoinId() const;
...@@ -441,7 +441,7 @@ public: ...@@ -441,7 +441,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline ::uint32_t getJoinId(); inline ::uint32_t getJoinId();
inline void setJoinId( ::uint32_t value); inline void setJoinId( ::uint32_t value);
...@@ -485,8 +485,8 @@ public: ...@@ -485,8 +485,8 @@ public:
Reader() = default; Reader() = default;
inline explicit Reader(::capnp::_::StructReader base): _reader(base) {} inline explicit Reader(::capnp::_::StructReader base): _reader(base) {}
inline size_t totalSizeInWords() const { inline ::capnp::MessageSize totalSize() const {
return _reader.totalSize() / ::capnp::WORDS; return _reader.totalSize().asPublic();
} }
inline ::uint32_t getJoinId() const; inline ::uint32_t getJoinId() const;
...@@ -524,7 +524,7 @@ public: ...@@ -524,7 +524,7 @@ public:
inline operator Reader() const { return Reader(_builder.asReader()); } inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return *this; } inline Reader asReader() const { return *this; }
inline size_t totalSizeInWords() { return asReader().totalSizeInWords(); } inline ::capnp::MessageSize totalSize() const { return asReader().totalSize(); }
inline ::uint32_t getJoinId(); inline ::uint32_t getJoinId();
inline void setJoinId( ::uint32_t value); inline void setJoinId( ::uint32_t value);
......
...@@ -50,6 +50,24 @@ inline constexpr uint messageSizeHint<void>() { ...@@ -50,6 +50,24 @@ inline constexpr uint messageSizeHint<void>() {
constexpr const uint MESSAGE_TARGET_SIZE_HINT = sizeInWords<rpc::MessageTarget>() + constexpr const uint MESSAGE_TARGET_SIZE_HINT = sizeInWords<rpc::MessageTarget>() +
sizeInWords<rpc::PromisedAnswer>() + 16; // +16 for ops; hope that's enough sizeInWords<rpc::PromisedAnswer>() + 16; // +16 for ops; hope that's enough
constexpr const uint CAP_DESCRIPTOR_SIZE_HINT = sizeInWords<rpc::CapDescriptor>() +
sizeInWords<rpc::PromisedAnswer>();
constexpr const uint64_t MAX_SIZE_HINT = 1 << 20;
uint copySizeHint(MessageSize size) {
uint64_t sizeHint = size.wordCount + size.capCount * CAP_DESCRIPTOR_SIZE_HINT;
return kj::min(MAX_SIZE_HINT, sizeHint);
}
uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint, uint additional) {
KJ_IF_MAYBE(s, sizeHint) {
return copySizeHint(*s) + additional;
} else {
return 0;
}
}
kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Reader ops) { kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Reader ops) {
auto result = kj::heapArrayBuilder<PipelineOp>(ops.size()); auto result = kj::heapArrayBuilder<PipelineOp>(ops.size());
for (auto opReader: ops) { for (auto opReader: ops) {
...@@ -274,7 +292,7 @@ public: ...@@ -274,7 +292,7 @@ public:
{ {
auto message = connection->newOutgoingMessage( auto message = connection->newOutgoingMessage(
objectId.targetSizeInWords() + messageSizeHint<rpc::Restore>()); objectId.targetSize().wordCount + messageSizeHint<rpc::Restore>());
auto builder = message->getBody().initAs<rpc::Message>().initRestore(); auto builder = message->getBody().initAs<rpc::Message>().initRestore();
builder.setQuestionId(questionId); builder.setQuestionId(questionId);
...@@ -537,9 +555,9 @@ private: ...@@ -537,9 +555,9 @@ private:
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto request = kj::heap<RpcRequest>( auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this)); *connectionState, sizeHint, kj::addRef(*this));
auto callBuilder = request->getCall(); auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId); callBuilder.setInterfaceId(interfaceId);
...@@ -554,21 +572,7 @@ private: ...@@ -554,21 +572,7 @@ private:
// Implement call() by copying params and results messages. // Implement call() by copying params and results messages.
auto params = context->getParams(); auto params = context->getParams();
auto request = newCall(interfaceId, methodId, params.targetSize());
size_t sizeHint = params.targetSizeInWords();
// TODO(perf): Extend targetSizeInWords() to include a capability count? Here we increase
// the size by 1/16 to deal with cap descriptors possibly expanding. See also in
// RpcRequest::send() and RpcCallContext::directTailCall().
// TODO(now): This is a problem, deal with it.
sizeHint += sizeHint / 16;
// Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
auto request = newCall(interfaceId, methodId, sizeHint);
request.set(params); request.set(params);
context->releaseParams(); context->releaseParams();
...@@ -770,9 +774,9 @@ private: ...@@ -770,9 +774,9 @@ private:
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
receivedCall = true; receivedCall = true;
return cap->newCall(interfaceId, methodId, firstSegmentWordSize); return cap->newCall(interfaceId, methodId, sizeHint);
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
...@@ -1171,14 +1175,13 @@ private: ...@@ -1171,14 +1175,13 @@ private:
class RpcRequest final: public RequestHook { class RpcRequest final: public RequestHook {
public: public:
RpcRequest(RpcConnectionState& connectionState, uint firstSegmentWordSize, RpcRequest(RpcConnectionState& connectionState, kj::Maybe<MessageSize> sizeHint,
kj::Own<RpcClient>&& target) kj::Own<RpcClient>&& target)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
target(kj::mv(target)), target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage( message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() +
firstSegmentWordSize + messageSizeHint<rpc::Call>() + sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))),
sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT)),
callBuilder(message->getBody().getAs<rpc::Message>().initCall()), callBuilder(message->getBody().getAs<rpc::Message>().initCall()),
paramsBuilder(context.imbue(callBuilder.getParams().getContent())) {} paramsBuilder(context.imbue(callBuilder.getParams().getContent())) {}
...@@ -1201,18 +1204,8 @@ private: ...@@ -1201,18 +1204,8 @@ private:
// Whoops, this capability has been redirected while we were building the request! // Whoops, this capability has been redirected while we were building the request!
// We'll have to make a new request and do a copy. Ick. // We'll have to make a new request and do a copy. Ick.
size_t sizeHint = paramsBuilder.targetSizeInWords();
// TODO(perf): See TODO in RpcClient::call() about why we need to inflate the size a bit.
sizeHint += sizeHint / 16;
// Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
auto replacement = redirect->get()->newCall( auto replacement = redirect->get()->newCall(
callBuilder.getInterfaceId(), callBuilder.getMethodId(), sizeHint); callBuilder.getInterfaceId(), callBuilder.getMethodId(), paramsBuilder.targetSize());
replacement.set(paramsBuilder); replacement.set(paramsBuilder);
return replacement.send(); return replacement.send();
} else { } else {
...@@ -1501,9 +1494,8 @@ private: ...@@ -1501,9 +1494,8 @@ private:
class LocallyRedirectedRpcResponse final class LocallyRedirectedRpcResponse final
: public RpcResponse, public RpcServerResponse, public kj::Refcounted{ : public RpcResponse, public RpcServerResponse, public kj::Refcounted{
public: public:
LocallyRedirectedRpcResponse(uint firstSegmentWordSize) LocallyRedirectedRpcResponse(kj::Maybe<MessageSize> sizeHint)
: message(firstSegmentWordSize == 0 ? : message(sizeHint) {}
SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize + 1) {}
AnyPointer::Builder getResultsBuilder() override { AnyPointer::Builder getResultsBuilder() override {
return message.getRoot(); return message.getRoot();
...@@ -1568,7 +1560,7 @@ private: ...@@ -1568,7 +1560,7 @@ private:
kj::Own<RpcResponse> consumeRedirectedResponse() { kj::Own<RpcResponse> consumeRedirectedResponse() {
KJ_ASSERT(redirectResults); KJ_ASSERT(redirectResults);
if (response == nullptr) getResults(1); // force initialization of response if (response == nullptr) getResults(MessageSize{0, 0}); // force initialization of response
// Note that the context needs to keep its own reference to the response so that it doesn't // Note that the context needs to keep its own reference to the response so that it doesn't
// get GC'd until the PipelineHook drops its reference to the context. // get GC'd until the PipelineHook drops its reference to the context.
...@@ -1581,7 +1573,7 @@ private: ...@@ -1581,7 +1573,7 @@ private:
// Avoid sending results if canceled so that we don't have to figure out whether or not // Avoid sending results if canceled so that we don't have to figure out whether or not
// `releaseResultCaps` was set in the already-received `Finish`. // `releaseResultCaps` was set in the already-received `Finish`.
if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) { if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) {
if (response == nullptr) getResults(1); // force initialization of response if (response == nullptr) getResults(MessageSize{0, 0}); // force initialization of response
returnMessage.setQuestionId(questionId); returnMessage.setQuestionId(questionId);
returnMessage.setReleaseParamCaps(false); returnMessage.setReleaseParamCaps(false);
...@@ -1641,19 +1633,18 @@ private: ...@@ -1641,19 +1633,18 @@ private:
void releaseParams() override { void releaseParams() override {
request = nullptr; request = nullptr;
} }
AnyPointer::Builder getResults(uint firstSegmentWordSize) override { AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
return r->get()->getResultsBuilder(); return r->get()->getResultsBuilder();
} else { } else {
kj::Own<RpcServerResponse> response; kj::Own<RpcServerResponse> response;
if (redirectResults) { if (redirectResults) {
response = kj::refcounted<LocallyRedirectedRpcResponse>(firstSegmentWordSize); response = kj::refcounted<LocallyRedirectedRpcResponse>(sizeHint);
} else { } else {
// TODO(now): Expand size hint for cap count...
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentSize(sizeHint, messageSizeHint<rpc::Return>() +
firstSegmentWordSize + messageSizeHint<rpc::Return>() + sizeInWords<rpc::Payload>()); sizeInWords<rpc::Payload>()));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn(); returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
response = kj::heap<RpcServerResponseImpl>( response = kj::heap<RpcServerResponseImpl>(
*connectionState, kj::mv(message), returnMessage.getResults()); *connectionState, kj::mv(message), returnMessage.getResults());
...@@ -1709,9 +1700,7 @@ private: ...@@ -1709,9 +1700,7 @@ private:
// Copy the response. // Copy the response.
// TODO(perf): It would be nice if we could somehow make the response get built in-place // TODO(perf): It would be nice if we could somehow make the response get built in-place
// but requires some refactoring. // but requires some refactoring.
size_t sizeHint = tailResponse.targetSizeInWords(); getResults(tailResponse.targetSize()).set(tailResponse);
sizeHint += sizeHint / 16; // see TODO in RpcClient::call().
getResults(sizeHint).set(tailResponse);
}); });
return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) };
...@@ -1881,7 +1870,7 @@ private: ...@@ -1881,7 +1870,7 @@ private:
default: { default: {
auto message = connection->newOutgoingMessage( auto message = connection->newOutgoingMessage(
reader.totalSizeInWords() + messageSizeHint<void>()); firstSegmentSize(reader.totalSize(), messageSizeHint<void>()));
message->getBody().initAs<rpc::Message>().setUnimplemented(reader); message->getBody().initAs<rpc::Message>().setUnimplemented(reader);
message->send(); message->send();
break; break;
......
This diff is collapsed.
...@@ -1328,7 +1328,7 @@ void SchemaLoader::Impl::requireStructSize(uint64_t id, uint dataWordCount, uint ...@@ -1328,7 +1328,7 @@ void SchemaLoader::Impl::requireStructSize(uint64_t id, uint dataWordCount, uint
} }
kj::ArrayPtr<word> SchemaLoader::Impl::makeUncheckedNode(schema::Node::Reader node) { kj::ArrayPtr<word> SchemaLoader::Impl::makeUncheckedNode(schema::Node::Reader node) {
size_t size = node.totalSizeInWords() + 1; size_t size = node.totalSize().wordCount + 1;
kj::ArrayPtr<word> result = arena.allocateArray<word>(size); kj::ArrayPtr<word> result = arena.allocateArray<word>(size);
memset(result.begin(), 0, size * sizeof(word)); memset(result.begin(), 0, size * sizeof(word));
copyToUnchecked(node, result); copyToUnchecked(node, result);
......
This diff is collapsed.
...@@ -181,7 +181,7 @@ TEST(Async, DeepChain) { ...@@ -181,7 +181,7 @@ TEST(Async, DeepChain) {
// Create a ridiculous chain of promises. // Create a ridiculous chain of promises.
for (uint i = 0; i < 1000; i++) { for (uint i = 0; i < 1000; i++) {
promise = evalLater(mvCapture(promise, [&,i](Promise<void> promise) { promise = evalLater(mvCapture(promise, [](Promise<void> promise) {
return kj::mv(promise); return kj::mv(promise);
})); }));
} }
...@@ -218,7 +218,7 @@ TEST(Async, DeepChain2) { ...@@ -218,7 +218,7 @@ TEST(Async, DeepChain2) {
// Create a ridiculous chain of promises. // Create a ridiculous chain of promises.
for (uint i = 0; i < 1000; i++) { for (uint i = 0; i < 1000; i++) {
promise = evalLater(mvCapture(promise, [&](Promise<void> promise) { promise = evalLater(mvCapture(promise, [](Promise<void> promise) {
return kj::mv(promise); return kj::mv(promise);
})); }));
} }
......
linux-gcc-4.7 1773 ./super-test.sh tmpdir capnp-gcc-4.7 quick linux-gcc-4.7 1779 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1776 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8 linux-gcc-4.8 1782 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1796 ./super-test.sh tmpdir capnp-clang quick clang linux-clang 1802 ./super-test.sh tmpdir capnp-clang quick clang
mac 807 ./super-test.sh remote beat caffeinate quick mac 807 ./super-test.sh remote beat caffeinate quick
cygwin 812 ./super-test.sh remote Kenton@flashman quick cygwin 812 ./super-test.sh remote Kenton@flashman quick
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment