Commit 3c7efbb4 authored by Kenton Varda's avatar Kenton Varda

Eliminate the concept of imbuing messages in favor of the simpler concept of…

Eliminate the concept of imbuing messages in favor of the simpler concept of setting a cap table directly on MessageReader / getting one from MessageBuilder.  This eliminates capability-context entirely.  This was made possible by the earlier change which moved capabilities to a separate table rather than storing CapDescriptors inline, but I didn't realize it at the time.
parent bf5dbebf
...@@ -160,7 +160,6 @@ includecapnp_HEADERS = \ ...@@ -160,7 +160,6 @@ includecapnp_HEADERS = \
src/capnp/any.h \ src/capnp/any.h \
src/capnp/message.h \ src/capnp/message.h \
src/capnp/capability.h \ src/capnp/capability.h \
src/capnp/capability-context.h \
src/capnp/schema.capnp.h \ src/capnp/schema.capnp.h \
src/capnp/schema.h \ src/capnp/schema.h \
src/capnp/schema-loader.h \ src/capnp/schema-loader.h \
...@@ -235,7 +234,6 @@ libcapnp_rpc_la_LDFLAGS = -release $(VERSION) -no-undefined ...@@ -235,7 +234,6 @@ libcapnp_rpc_la_LDFLAGS = -release $(VERSION) -no-undefined
libcapnp_rpc_la_SOURCES= \ libcapnp_rpc_la_SOURCES= \
src/capnp/serialize-async.c++ \ src/capnp/serialize-async.c++ \
src/capnp/capability.c++ \ src/capnp/capability.c++ \
src/capnp/capability-context.c++ \
src/capnp/dynamic-capability.c++ \ src/capnp/dynamic-capability.c++ \
src/capnp/rpc.c++ \ src/capnp/rpc.c++ \
src/capnp/rpc.capnp.c++ \ src/capnp/rpc.capnp.c++ \
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This sample code appears in the documentation for the C++ implementation. // This sample code appears in the documentation for the C++ implementation.
// //
// Compile with: // Compile with:
......
# Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@0x9eb32e19f86ee174; @0x9eb32e19f86ee174;
using Cxx = import "/capnp/c++.capnp"; using Cxx = import "/capnp/c++.capnp";
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "calculator.capnp.h" #include "calculator.capnp.h"
#include <capnp/ez-rpc.h> #include <capnp/ez-rpc.h>
#include <kj/debug.h> #include <kj/debug.h>
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "calculator.capnp.h" #include "calculator.capnp.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <capnp/ez-rpc.h> #include <capnp/ez-rpc.h>
#include <capnp/capability-context.h> // for LocalMessage #include <capnp/message.h>
#include <iostream> #include <iostream>
typedef unsigned int uint; typedef unsigned int uint;
...@@ -90,13 +113,15 @@ class FunctionImpl final: public Calculator::Function::Server { ...@@ -90,13 +113,15 @@ class FunctionImpl final: public Calculator::Function::Server {
public: public:
FunctionImpl(uint paramCount, Calculator::Expression::Reader body) FunctionImpl(uint paramCount, Calculator::Expression::Reader body)
: paramCount(paramCount), body(body) {} : paramCount(paramCount) {
this->body.setRoot(body);
}
kj::Promise<void> call(CallContext context) { kj::Promise<void> call(CallContext context) {
auto params = context.getParams().getParams(); auto params = context.getParams().getParams();
KJ_REQUIRE(params.size() == paramCount, "Wrong number of parameters."); KJ_REQUIRE(params.size() == paramCount, "Wrong number of parameters.");
return evaluateImpl(body.getRoot().getAs<Calculator::Expression>(), params) return evaluateImpl(body.getRoot<Calculator::Expression>(), params)
.then([context](double value) mutable { .then([context](double value) mutable {
context.getResults().setValue(value); context.getResults().setValue(value);
}); });
...@@ -106,10 +131,8 @@ private: ...@@ -106,10 +131,8 @@ private:
uint paramCount; uint paramCount;
// The function's arity. // The function's arity.
capnp::LocalMessage body; capnp::MallocMessageBuilder body;
// LocalMessage holds a message that might contain capabilities (interface // Stores a permanent copy of the function body.
// references). Here we're using it to hold a Calculator.Expression, which
// might contain Calculator.Function and/or Calculator.Value capabilities.
}; };
class OperatorImpl final: public Calculator::Function::Server { class OperatorImpl final: public Calculator::Function::Server {
......
# Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@0x85150b117366d14b; @0x85150b117366d14b;
interface Calculator { interface Calculator {
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "arena.h" #include "arena.h"
#include "message.h" #include "message.h"
#include "capability.h" #include "capability.h"
#include "capability-context.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/refcount.h> #include <kj/refcount.h>
#include <vector> #include <vector>
...@@ -36,7 +35,6 @@ namespace capnp { ...@@ -36,7 +35,6 @@ namespace capnp {
namespace _ { // private namespace _ { // private
Arena::~Arena() noexcept(false) {} Arena::~Arena() noexcept(false) {}
BuilderArena::~BuilderArena() noexcept(false) {}
void ReadLimiter::unread(WordCount64 amount) { void ReadLimiter::unread(WordCount64 amount) {
// Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that // Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that
...@@ -51,14 +49,14 @@ void ReadLimiter::unread(WordCount64 amount) { ...@@ -51,14 +49,14 @@ void ReadLimiter::unread(WordCount64 amount) {
// ======================================================================================= // =======================================================================================
BasicReaderArena::BasicReaderArena(MessageReader* message) ReaderArena::ReaderArena(MessageReader* message)
: message(message), : message(message),
readLimiter(message->getOptions().traversalLimitInWords * WORDS), readLimiter(message->getOptions().traversalLimitInWords * WORDS),
segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {} segment0(this, SegmentId(0), message->getSegment(0), &readLimiter) {}
BasicReaderArena::~BasicReaderArena() noexcept(false) {} ReaderArena::~ReaderArena() noexcept(false) {}
SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) { SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
if (segment0.getArray() == nullptr) { if (segment0.getArray() == nullptr) {
return nullptr; return nullptr;
...@@ -96,102 +94,41 @@ SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) { ...@@ -96,102 +94,41 @@ SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) {
return result; return result;
} }
void BasicReaderArena::reportReadLimitReached() { void ReaderArena::reportReadLimitReached() {
KJ_FAIL_REQUIRE("Exceeded message traversal limit. See capnp::ReaderOptions.") { KJ_FAIL_REQUIRE("Exceeded message traversal limit. See capnp::ReaderOptions.") {
return; return;
} }
} }
kj::Maybe<kj::Own<ClientHook>> BasicReaderArena::extractCap(uint index) { kj::Maybe<kj::Own<ClientHook>> ReaderArena::extractCap(uint index) {
return nullptr;
}
// =======================================================================================
ImbuedReaderArena::ImbuedReaderArena(Arena* base, BrokenCapFactory& brokenCapFactory,
kj::Array<kj::Own<ClientHook>>&& capTable)
: base(base), brokenCapFactory(brokenCapFactory), capTable(kj::mv(capTable)),
segment0(nullptr) {}
ImbuedReaderArena::~ImbuedReaderArena() noexcept(false) {}
SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
if (baseSegment == nullptr) return nullptr;
if (baseSegment->getSegmentId() == SegmentId(0)) {
if (segment0.getArena() == nullptr) {
kj::dtor(segment0);
kj::ctor(segment0, this, baseSegment);
}
KJ_DASSERT(segment0.getArray().begin() == baseSegment->getArray().begin());
return &segment0;
}
auto lock = moreSegments.lockExclusive();
SegmentMap* segments = nullptr;
KJ_IF_MAYBE(s, *lock) {
auto iter = s->get()->find(baseSegment);
if (iter != s->get()->end()) {
KJ_DASSERT(iter->second->getArray().begin() == baseSegment->getArray().begin());
return iter->second;
}
segments = *s;
} else {
auto newMap = kj::heap<SegmentMap>();
segments = newMap;
*lock = kj::mv(newMap);
}
auto newSegment = kj::heap<ImbuedSegmentReader>(this, baseSegment);
SegmentReader* result = newSegment;
segments->insert(std::make_pair(baseSegment, mv(newSegment)));
return result;
}
SegmentReader* ImbuedReaderArena::tryGetSegment(SegmentId id) {
return imbue(base->tryGetSegment(id));
}
void ImbuedReaderArena::reportReadLimitReached() {
return base->reportReadLimitReached();
}
kj::Maybe<kj::Own<ClientHook>> ImbuedReaderArena::extractCap(uint index) {
if (index < capTable.size()) { if (index < capTable.size()) {
return capTable[index]->addRef(); return capTable[index].map([](kj::Own<ClientHook>& cap) { return cap->addRef(); });
} else { } else {
KJ_FAIL_ASSERT("Invalid capability descriptor in message.") { return nullptr;
// Work around http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 and
// http://llvm.org/bugs/show_bug.cgi?id=12286.
break;
}
return brokenCapFactory.newBrokenCap("Calling capability from invalid descriptor.");
} }
} }
// ======================================================================================= // =======================================================================================
BasicBuilderArena::BasicBuilderArena(MessageBuilder* message) BuilderArena::BuilderArena(MessageBuilder* message)
: message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {} : message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BasicBuilderArena::~BasicBuilderArena() noexcept(false) {} BuilderArena::~BuilderArena() noexcept(false) {}
SegmentBuilder* BasicBuilderArena::getSegment(SegmentId id) { SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
// This method is allowed to fail if the segment ID is not valid. // This method is allowed to fail if the segment ID is not valid.
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
return &segment0; return &segment0;
} else { } else {
KJ_IF_MAYBE(s, moreSegments) { KJ_IF_MAYBE(s, moreSegments) {
KJ_REQUIRE(id.value - 1 < s->get()->builders.size(), "invalid segment id", id.value); KJ_REQUIRE(id.value - 1 < s->get()->builders.size(), "invalid segment id", id.value);
// TODO(cleanup): Return a const SegmentBuilder and tediously constify all SegmentBuilder return const_cast<SegmentBuilder*>(s->get()->builders[id.value - 1].get());
// pointers throughout the codebase.
return const_cast<BasicSegmentBuilder*>(s->get()->builders[id.value - 1].get());
} else { } else {
KJ_FAIL_REQUIRE("invalid segment id", id.value); KJ_FAIL_REQUIRE("invalid segment id", id.value);
} }
} }
} }
BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount) { BuilderArena::AllocateResult BuilderArena::allocate(WordCount amount) {
if (segment0.getArena() == nullptr) { if (segment0.getArena() == nullptr) {
// We're allocating the first segment. // We're allocating the first segment.
kj::ArrayPtr<word> ptr = message->allocateSegment(amount / WORDS); kj::ArrayPtr<word> ptr = message->allocateSegment(amount / WORDS);
...@@ -230,7 +167,7 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount) ...@@ -230,7 +167,7 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount)
moreSegments = kj::mv(newSegmentState); moreSegments = kj::mv(newSegmentState);
} }
kj::Own<BasicSegmentBuilder> newBuilder = kj::heap<BasicSegmentBuilder>( kj::Own<SegmentBuilder> newBuilder = kj::heap<SegmentBuilder>(
this, SegmentId(segmentState->builders.size() + 1), this, SegmentId(segmentState->builders.size() + 1),
message->allocateSegment(amount / WORDS), &this->dummyLimiter); message->allocateSegment(amount / WORDS), &this->dummyLimiter);
SegmentBuilder* result = newBuilder.get(); SegmentBuilder* result = newBuilder.get();
...@@ -245,7 +182,7 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount) ...@@ -245,7 +182,7 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount)
} }
} }
kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOutput() { kj::ArrayPtr<const kj::ArrayPtr<const word>> BuilderArena::getSegmentsForOutput() {
// Although this is a read-only method, we shouldn't need to lock a mutex here because if this // Although this is a read-only method, we shouldn't need to lock a mutex here because if this
// is called multiple times simultaneously, we should only be overwriting the array with the // is called multiple times simultaneously, we should only be overwriting the array with the
// exact same data. If the number or size of segments is actually changing due to an activity // exact same data. If the number or size of segments is actually changing due to an activity
...@@ -276,7 +213,7 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOu ...@@ -276,7 +213,7 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOu
} }
} }
SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) { SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
if (segment0.getArena() == nullptr) { if (segment0.getArena() == nullptr) {
// We haven't allocated any segments yet. // We haven't allocated any segments yet.
...@@ -297,102 +234,21 @@ SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) { ...@@ -297,102 +234,21 @@ SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) {
} }
} }
void BasicBuilderArena::reportReadLimitReached() { void BuilderArena::reportReadLimitReached() {
KJ_FAIL_ASSERT("Read limit reached for BuilderArena, but it should have been unlimited.") { KJ_FAIL_ASSERT("Read limit reached for BuilderArena, but it should have been unlimited.") {
return; return;
} }
} }
kj::Maybe<kj::Own<ClientHook>> BasicBuilderArena::extractCap(uint index) { kj::Maybe<kj::Own<ClientHook>> BuilderArena::extractCap(uint index) {
return nullptr;
}
uint BasicBuilderArena::injectCap(kj::Own<ClientHook>&& cap) {
KJ_FAIL_REQUIRE("Cannot inject capability into a builder that has not been imbued with a "
"capability context.") {
return 0;
}
}
void BasicBuilderArena::dropCap(uint index) {
// They only way we could have a cap in the first place is if the error was already reported...
}
// =======================================================================================
ImbuedBuilderArena::ImbuedBuilderArena(BuilderArena* base, BrokenCapFactory& brokenCapFactory)
: base(base), brokenCapFactory(brokenCapFactory), segment0(nullptr) {}
ImbuedBuilderArena::~ImbuedBuilderArena() noexcept(false) {}
SegmentBuilder* ImbuedBuilderArena::imbue(SegmentBuilder* baseSegment) {
if (baseSegment == nullptr) return nullptr;
SegmentBuilder* result;
if (baseSegment->getSegmentId() == SegmentId(0)) {
if (segment0.getArena() == nullptr) {
kj::dtor(segment0);
kj::ctor(segment0, this, baseSegment);
}
result = &segment0;
} else {
MultiSegmentState* segmentState;
KJ_IF_MAYBE(s, moreSegments) {
segmentState = *s;
} else {
auto newState = kj::heap<MultiSegmentState>();
segmentState = newState;
moreSegments = kj::mv(newState);
}
auto id = baseSegment->getSegmentId().value;
if (id >= segmentState->builders.size()) {
segmentState->builders.resize(id + 1);
}
KJ_IF_MAYBE(segment, segmentState->builders[id]) {
result = *segment;
} else {
auto newBuilder = kj::heap<ImbuedSegmentBuilder>(this, baseSegment);
result = newBuilder;
segmentState->builders[id] = kj::mv(newBuilder);
}
}
KJ_DASSERT(result->getArray().begin() == baseSegment->getArray().begin());
return result;
}
SegmentReader* ImbuedBuilderArena::tryGetSegment(SegmentId id) {
return imbue(static_cast<SegmentBuilder*>(base->tryGetSegment(id)));
}
void ImbuedBuilderArena::reportReadLimitReached() {
base->reportReadLimitReached();
}
kj::Maybe<kj::Own<ClientHook>> ImbuedBuilderArena::extractCap(uint index) {
if (index < capTable.size()) { if (index < capTable.size()) {
return capTable[index]->addRef(); return capTable[index].map([](kj::Own<ClientHook>& cap) { return cap->addRef(); });
} else { } else {
KJ_FAIL_ASSERT("Invalid capability descriptor in message.") { return nullptr;
// Work around http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799 and
// http://llvm.org/bugs/show_bug.cgi?id=12286.
break;
}
return brokenCapFactory.newBrokenCap("Calling capability from invalid descriptor.");
} }
} }
SegmentBuilder* ImbuedBuilderArena::getSegment(SegmentId id) { uint BuilderArena::injectCap(kj::Own<ClientHook>&& cap) {
return imbue(base->getSegment(id));
}
BuilderArena::AllocateResult ImbuedBuilderArena::allocate(WordCount amount) {
auto result = base->allocate(amount);
result.segment = imbue(result.segment);
return result;
}
uint ImbuedBuilderArena::injectCap(kj::Own<ClientHook>&& cap) {
// TODO(perf): Detect if the cap is already on the table and reuse the index? Perhaps this // TODO(perf): Detect if the cap is already on the table and reuse the index? Perhaps this
// doesn't happen enough to be worth the effort. // doesn't happen enough to be worth the effort.
uint result = capTable.size(); uint result = capTable.size();
...@@ -400,7 +256,7 @@ uint ImbuedBuilderArena::injectCap(kj::Own<ClientHook>&& cap) { ...@@ -400,7 +256,7 @@ uint ImbuedBuilderArena::injectCap(kj::Own<ClientHook>&& cap) {
return result; return result;
} }
void ImbuedBuilderArena::dropCap(uint index) { void BuilderArena::dropCap(uint index) {
KJ_ASSERT(index < capTable.size(), "Invalid capability descriptor in message.") { KJ_ASSERT(index < capTable.size(), "Invalid capability descriptor in message.") {
return; return;
} }
......
...@@ -47,11 +47,7 @@ namespace _ { // private ...@@ -47,11 +47,7 @@ namespace _ { // private
class SegmentReader; class SegmentReader;
class SegmentBuilder; class SegmentBuilder;
class Arena; class Arena;
class BasicReaderArena;
class ImbuedReaderArena;
class BuilderArena; class BuilderArena;
class BasicBuilderArena;
class ImbuedBuilderArena;
class ReadLimiter; class ReadLimiter;
class Segment; class Segment;
...@@ -133,20 +129,12 @@ private: ...@@ -133,20 +129,12 @@ private:
KJ_DISALLOW_COPY(SegmentReader); KJ_DISALLOW_COPY(SegmentReader);
friend class SegmentBuilder; friend class SegmentBuilder;
friend class ImbuedSegmentBuilder;
friend class ImbuedSegmentReader;
};
class ImbuedSegmentReader: public SegmentReader {
public:
inline ImbuedSegmentReader(Arena* arena, SegmentReader* base);
inline ImbuedSegmentReader(decltype(nullptr));
}; };
class SegmentBuilder: public SegmentReader { class SegmentBuilder: public SegmentReader {
public: public:
inline SegmentBuilder(BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, inline SegmentBuilder(BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr,
ReadLimiter* readLimiter, word** pos); ReadLimiter* readLimiter);
KJ_ALWAYS_INLINE(word* allocate(WordCount amount)); KJ_ALWAYS_INLINE(word* allocate(WordCount amount));
inline word* getPtrUnchecked(WordCount offset); inline word* getPtrUnchecked(WordCount offset);
...@@ -158,35 +146,13 @@ public: ...@@ -158,35 +146,13 @@ public:
inline void reset(); inline void reset();
private: private:
word** pos; word* pos;
// Pointer to a pointer to the current end point of the segment, i.e. the location where the // Pointer to a pointer to the current end point of the segment, i.e. the location where the
// next object should be allocated. The extra level of indirection allows an // next object should be allocated.
// ImbuedSegmentBuilder to share this pointer with the underlying BasicSegmentBuilder.
friend class ImbuedSegmentBuilder;
KJ_DISALLOW_COPY(SegmentBuilder); KJ_DISALLOW_COPY(SegmentBuilder);
}; };
class BasicSegmentBuilder: public SegmentBuilder {
public:
inline BasicSegmentBuilder(BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr,
ReadLimiter* readLimiter);
private:
word* actualPos;
KJ_DISALLOW_COPY(BasicSegmentBuilder);
};
class ImbuedSegmentBuilder: public SegmentBuilder {
public:
inline ImbuedSegmentBuilder(ImbuedBuilderArena* arena, SegmentBuilder* base);
inline ImbuedSegmentBuilder(decltype(nullptr));
KJ_DISALLOW_COPY(ImbuedSegmentBuilder);
};
class Arena { class Arena {
public: public:
virtual ~Arena() noexcept(false); virtual ~Arena() noexcept(false);
...@@ -200,16 +166,21 @@ public: ...@@ -200,16 +166,21 @@ public:
// will need to continue with default values. // will need to continue with default values.
virtual kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) = 0; virtual kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) = 0;
// Extract the capability at the given index. If the index is invalid, returns a dummy // Extract the capability at the given index. If the index is invalid, returns null.
// capability whose methods all throw. Returns null only if the message is not imbued with a
// capability context.
}; };
class BasicReaderArena final: public Arena { class ReaderArena final: public Arena {
public: public:
BasicReaderArena(MessageReader* message); ReaderArena(MessageReader* message);
~BasicReaderArena() noexcept(false); ~ReaderArena() noexcept(false);
KJ_DISALLOW_COPY(BasicReaderArena); KJ_DISALLOW_COPY(ReaderArena);
inline void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable) {
// Imbues the arena with a capability table. This is not passed to the constructor because the
// table itself may be built based on some other part of the message (as is the case with the
// RPC protocol).
this->capTable = kj::mv(capTable);
}
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
...@@ -219,6 +190,7 @@ public: ...@@ -219,6 +190,7 @@ public:
private: private:
MessageReader* message; MessageReader* message;
ReadLimiter readLimiter; ReadLimiter readLimiter;
kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable;
// Optimize for single-segment messages so that small messages are handled quickly. // Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0; SegmentReader segment0;
...@@ -234,37 +206,25 @@ private: ...@@ -234,37 +206,25 @@ private:
// possibly backed by the same data)? // possibly backed by the same data)?
}; };
class ImbuedReaderArena final: public Arena { class BuilderArena final: public Arena {
public: // A BuilderArena that does not allow the injection of capabilities.
ImbuedReaderArena(Arena* base, BrokenCapFactory& brokenCapFactory,
kj::Array<kj::Own<ClientHook>>&& capTable);
~ImbuedReaderArena() noexcept(false);
KJ_DISALLOW_COPY(ImbuedReaderArena);
SegmentReader* imbue(SegmentReader* base);
// implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override;
kj::Maybe<kj::Own<ClientHook>> extractCap(uint index);
private: public:
Arena* base; BuilderArena(MessageBuilder* message);
BrokenCapFactory& brokenCapFactory; ~BuilderArena() noexcept(false);
kj::Array<kj::Own<ClientHook>> capTable; KJ_DISALLOW_COPY(BuilderArena);
// Optimize for single-segment messages so that small messages are handled quickly. inline SegmentBuilder* getRootSegment() { return &segment0; }
ImbuedSegmentReader segment0;
typedef std::unordered_map<SegmentReader*, kj::Own<ImbuedSegmentReader>> SegmentMap; kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput();
kj::MutexGuarded<kj::Maybe<kj::Own<SegmentMap>>> moreSegments; // Get an array of all the segments, suitable for writing out. This only returns the allocated
}; // portion of each segment, whereas tryGetSegment() returns something that includes
// not-yet-allocated space.
class BuilderArena: public Arena { inline kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() { return capTable; }
public: // Return the capability table.
virtual ~BuilderArena() noexcept(false);
virtual SegmentBuilder* getSegment(SegmentId id) = 0; SegmentBuilder* getSegment(SegmentId id);
// Get the segment with the given id. Crashes or throws an exception if no such segment exists. // Get the segment with the given id. Crashes or throws an exception if no such segment exists.
struct AllocateResult { struct AllocateResult {
...@@ -272,99 +232,40 @@ public: ...@@ -272,99 +232,40 @@ public:
word* words; word* words;
}; };
virtual AllocateResult allocate(WordCount amount) = 0; AllocateResult allocate(WordCount amount);
// Find a segment with at least the given amount of space available and allocate the space. // Find a segment with at least the given amount of space available and allocate the space.
// Note that allocating directly from a particular segment is much faster, but allocating from // Note that allocating directly from a particular segment is much faster, but allocating from
// the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific // the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific
// segment first if there is one, then fall back to the arena. // segment first if there is one, then fall back to the arena.
virtual uint injectCap(kj::Own<ClientHook>&& cap) = 0; uint injectCap(kj::Own<ClientHook>&& cap);
// Add the capability to the message and return its index. If the same ClientHook is injected // Add the capability to the message and return its index. If the same ClientHook is injected
// twice, this may return the same index both times, but in this case dropCap() needs to be // twice, this may return the same index both times, but in this case dropCap() needs to be
// called an equal number of times to actually remove the cap. // called an equal number of times to actually remove the cap.
virtual void dropCap(uint index) = 0; void dropCap(uint index);
// Remove a capability injected earlier. Called when the pointer is overwritten or zero'd out. // Remove a capability injected earlier. Called when the pointer is overwritten or zero'd out.
};
class BasicBuilderArena final: public BuilderArena {
// A BuilderArena that does not allow the injection of capabilities.
public:
BasicBuilderArena(MessageBuilder* message);
~BasicBuilderArena() noexcept(false);
KJ_DISALLOW_COPY(BasicBuilderArena);
inline SegmentBuilder* getRootSegment() { return &segment0; }
kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput();
// Get an array of all the segments, suitable for writing out. This only returns the allocated
// portion of each segment, whereas tryGetSegment() returns something that includes
// not-yet-allocated space.
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
kj::Maybe<kj::Own<ClientHook>> extractCap(uint index); kj::Maybe<kj::Own<ClientHook>> extractCap(uint index);
// implements BuilderArena -----------------------------------------
SegmentBuilder* getSegment(SegmentId id) override;
AllocateResult allocate(WordCount amount) override;
uint injectCap(kj::Own<ClientHook>&& cap);
void dropCap(uint index);
private: private:
MessageBuilder* message; MessageBuilder* message;
ReadLimiter dummyLimiter; ReadLimiter dummyLimiter;
kj::Vector<kj::Maybe<kj::Own<ClientHook>>> capTable;
BasicSegmentBuilder segment0; SegmentBuilder segment0;
kj::ArrayPtr<const word> segment0ForOutput; kj::ArrayPtr<const word> segment0ForOutput;
struct MultiSegmentState { struct MultiSegmentState {
kj::Vector<kj::Own<BasicSegmentBuilder>> builders; kj::Vector<kj::Own<SegmentBuilder>> builders;
kj::Vector<kj::ArrayPtr<const word>> forOutput; kj::Vector<kj::ArrayPtr<const word>> forOutput;
}; };
kj::Maybe<kj::Own<MultiSegmentState>> moreSegments; kj::Maybe<kj::Own<MultiSegmentState>> moreSegments;
}; };
class ImbuedBuilderArena final: public BuilderArena {
// A BuilderArena imbued with the ability to inject capabilities.
public:
ImbuedBuilderArena(BuilderArena* base, BrokenCapFactory& brokenCapFactory);
~ImbuedBuilderArena() noexcept(false);
KJ_DISALLOW_COPY(ImbuedBuilderArena);
SegmentBuilder* imbue(SegmentBuilder* baseSegment);
// Return an imbued SegmentBuilder corresponding to the given segment from the base arena.
inline kj::ArrayPtr<kj::Own<ClientHook>> getCapTable() { return capTable; }
// Release and return the capability table.
// implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override;
kj::Maybe<kj::Own<ClientHook>> extractCap(uint index);
// implements BuilderArena -----------------------------------------
SegmentBuilder* getSegment(SegmentId id) override;
AllocateResult allocate(WordCount amount) override;
uint injectCap(kj::Own<ClientHook>&& cap);
void dropCap(uint index);
private:
BuilderArena* base;
BrokenCapFactory& brokenCapFactory;
kj::Vector<kj::Own<ClientHook>> capTable;
ImbuedSegmentBuilder segment0;
struct MultiSegmentState {
kj::Vector<kj::Maybe<kj::Own<ImbuedSegmentBuilder>>> builders;
};
kj::Maybe<kj::Own<MultiSegmentState>> moreSegments;
};
// ======================================================================================= // =======================================================================================
inline ReadLimiter::ReadLimiter() inline ReadLimiter::ReadLimiter()
...@@ -411,25 +312,20 @@ inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; } ...@@ -411,25 +312,20 @@ inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; } inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); } inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
inline ImbuedSegmentReader::ImbuedSegmentReader(Arena* arena, SegmentReader* base)
: SegmentReader(arena, base->id, base->ptr, base->readLimiter) {}
inline ImbuedSegmentReader::ImbuedSegmentReader(decltype(nullptr))
: SegmentReader(nullptr, SegmentId(0), nullptr, nullptr) {}
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder( inline SegmentBuilder::SegmentBuilder(
BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, ReadLimiter* readLimiter, word** pos) BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, ReadLimiter* readLimiter)
: SegmentReader(arena, id, ptr, readLimiter), pos(pos) {} : SegmentReader(arena, id, ptr, readLimiter), pos(ptr.begin()) {}
inline word* SegmentBuilder::allocate(WordCount amount) { inline word* SegmentBuilder::allocate(WordCount amount) {
if (intervalLength(*pos, ptr.end()) < amount) { if (intervalLength(pos, ptr.end()) < amount) {
// Not enough space in the segment for this allocation. // Not enough space in the segment for this allocation.
return nullptr; return nullptr;
} else { } else {
// Success. // Success.
word* result = *pos; word* result = pos;
*pos = *pos + amount; pos = pos + amount;
return result; return result;
} }
} }
...@@ -447,27 +343,15 @@ inline BuilderArena* SegmentBuilder::getArena() { ...@@ -447,27 +343,15 @@ inline BuilderArena* SegmentBuilder::getArena() {
} }
inline kj::ArrayPtr<const word> SegmentBuilder::currentlyAllocated() { inline kj::ArrayPtr<const word> SegmentBuilder::currentlyAllocated() {
return kj::arrayPtr(ptr.begin(), *pos - ptr.begin()); return kj::arrayPtr(ptr.begin(), pos - ptr.begin());
} }
inline void SegmentBuilder::reset() { inline void SegmentBuilder::reset() {
word* start = getPtrUnchecked(0 * WORDS); word* start = getPtrUnchecked(0 * WORDS);
memset(start, 0, (*pos - start) * sizeof(word)); memset(start, 0, (pos - start) * sizeof(word));
*pos = start; pos = start;
} }
inline BasicSegmentBuilder::BasicSegmentBuilder(
BuilderArena* arena, SegmentId id, kj::ArrayPtr<word> ptr, ReadLimiter* readLimiter)
: SegmentBuilder(arena, id, ptr, readLimiter, &actualPos),
actualPos(ptr.begin()) {}
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(ImbuedBuilderArena* arena, SegmentBuilder* base)
: SegmentBuilder(arena, base->id,
kj::arrayPtr(const_cast<word*>(base->ptr.begin()), base->ptr.size()),
base->readLimiter, base->pos) {}
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(decltype(nullptr))
: SegmentBuilder(nullptr, SegmentId(0), nullptr, nullptr, nullptr) {}
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define CAPNP_PRIVATE
#include "capability-context.h"
#include "capability.h"
#include "arena.h"
#include <kj/debug.h>
namespace capnp {
namespace _ {
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory);
// Defined in layout.c++.
} // namespace _
namespace {
class BrokenCapFactoryImpl: public _::BrokenCapFactory {
public:
kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override {
return capnp::newBrokenCap(description);
}
};
static BrokenCapFactoryImpl brokenCapFactory;
} // namespace
CapReaderContext::CapReaderContext(kj::Array<kj::Own<ClientHook>>&& capTable)
: capTable(kj::mv(capTable)) {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
CapReaderContext::~CapReaderContext() noexcept(false) {
if (capTable == nullptr) {
kj::dtor(arena());
}
}
AnyPointer::Reader CapReaderContext::imbue(AnyPointer::Reader base) {
KJ_IF_MAYBE(oldArena, base.reader.getArena()) {
static_assert(sizeof(arena()) <= sizeof(arenaSpace),
"arenaSpace is too small. Please increase it.");
kj::ctor(arena(), oldArena, brokenCapFactory,
kj::mv(KJ_REQUIRE_NONNULL(capTable, "imbue() can only be called once.")));
} else {
KJ_FAIL_REQUIRE("Cannot imbue unchecked message.");
}
capTable = nullptr;
return AnyPointer::Reader(base.reader.imbue(arena()));
}
CapBuilderContext::CapBuilderContext() {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
CapBuilderContext::~CapBuilderContext() noexcept(false) {
if (arenaAllocated) {
kj::dtor(arena());
}
}
AnyPointer::Builder CapBuilderContext::imbue(AnyPointer::Builder base) {
KJ_REQUIRE(!arenaAllocated, "imbue() can only be called once.");
static_assert(sizeof(arena()) <= sizeof(arenaSpace),
"arenaSpace is too small. Please increase it.");
kj::ctor(arena(), base.builder.getArena(), brokenCapFactory);
arenaAllocated = true;
return AnyPointer::Builder(base.builder.imbue(arena()));
}
kj::ArrayPtr<kj::Own<ClientHook>> CapBuilderContext::getCapTable() {
if (arenaAllocated) {
return arena().getCapTable();
} else {
return nullptr;
}
}
// =======================================================================================
namespace {
uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
KJ_IF_MAYBE(s, sizeHint) {
// 1 for the root pointer. We don't store caps in the message so we don't count those here.
return s->wordCount + 1;
} else {
return SUGGESTED_FIRST_SEGMENT_WORDS;
}
}
} // namespace
LocalMessage::LocalMessage(kj::Maybe<MessageSize> sizeHint)
: message(firstSegmentSize(sizeHint)),
root(capContext.imbue(message.getRoot<AnyPointer>())) {}
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<PipelineHook> addRef() override {
return kj::addRef(*this);
}
kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override;
private:
kj::Exception exception;
};
class BrokenRequest final: public RequestHook {
public:
BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
: exception(exception), message(sizeHint) {}
RemotePromise<AnyPointer> send() override {
return RemotePromise<AnyPointer>(kj::cp(exception),
AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
const void* getBrand() {
return nullptr;
}
kj::Exception exception;
LocalMessage message;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
BrokenClient(const kj::StringPtr description)
: exception(kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
"", 0, kj::str(description)) {}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, sizeHint);
auto root = hook->message.getRoot();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<ClientHook&> getResolved() {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
}
kj::Own<ClientHook> addRef() override {
return kj::addRef(*this);
}
const void* getBrand() override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
return kj::refcounted<BrokenClient>(reason);
}
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Classes for imbuing message readers/builders with a capability context.
//
// These classes are for use by RPC implementations. Application code need not know about them.
//
// Normally, MessageReader and MessageBuilder do not support interface pointers because they
// are not RPC-aware and so have no idea how to convert between a serialized CapabilityDescriptor
// and a live capability. To fix this, a reader/builder object needs to be "imbued" with a
// capability context. This creates a new reader/builder which points at the same object but has
// the ability to deal with interface fields. Use `CapReaderContext` and `CapBuilderContext` to
// accomplish this.
#ifndef CAPNP_CAPABILITY_CONTEXT_H_
#define CAPNP_CAPABILITY_CONTEXT_H_
#include "layout.h"
#include "any.h"
#include "message.h"
#include <kj/mutex.h>
#include <kj/vector.h>
namespace kj { class Exception; }
namespace capnp {
class ClientHook;
// -------------------------------------------------------------------
class CapReaderContext {
// Class which can "imbue" reader objects from some other message with a capability context,
// so that interface pointers found in the message can be extracted and called.
//
// `imbue()` can only be called once per context.
public:
CapReaderContext(kj::Array<kj::Own<ClientHook>>&& capTable);
// `capTable` is the list of capabilities for this message.
~CapReaderContext() noexcept(false);
AnyPointer::Reader imbue(AnyPointer::Reader base);
private:
kj::Maybe<kj::Array<kj::Own<ClientHook>>> capTable; // becomes null once arena is allocated
void* arenaSpace[12 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)];
_::ImbuedReaderArena& arena() { return *reinterpret_cast<_::ImbuedReaderArena*>(arenaSpace); }
friend class _::ImbuedReaderArena;
};
class CapBuilderContext {
// Class which can "imbue" reader objects from some other message with a capability context,
// so that interface pointers found in the message can be set to point at live capabilities.
//
// `imbue()` can only be called once per context.
public:
CapBuilderContext();
~CapBuilderContext() noexcept(false);
AnyPointer::Builder imbue(AnyPointer::Builder base);
kj::ArrayPtr<kj::Own<ClientHook>> getCapTable();
// Return the table of capabilities injected into the message.
private:
bool arenaAllocated = false;
void* arenaSpace[15];
_::ImbuedBuilderArena& arena() { return *reinterpret_cast<_::ImbuedBuilderArena*>(arenaSpace); }
friend class _::ImbuedBuilderArena;
};
// -------------------------------------------------------------------
class LocalMessage final {
// An in-process message which can contain capabilities. Use in place of MallocMessageBuilder
// when you need to be able to construct a message in-memory that contains capabilities, and this
// message will never leave the process. You cannot serialize this message, since it doesn't
// know how to properly serialize its capabilities.
public:
explicit LocalMessage(kj::Maybe<MessageSize> sizeHint = nullptr);
template <typename T, typename = FromReader<T>>
inline LocalMessage(T&& reader): LocalMessage(reader.totalSize()) {
// Create a LocalMessage that is a copy of a given reader.
getRoot().setAs<FromReader<T>>(kj::fwd<T>(reader));
}
inline AnyPointer::Builder getRoot() { return root; }
inline AnyPointer::Reader getRootReader() const { return root.asReader(); }
private:
MallocMessageBuilder message;
CapBuilderContext capContext;
AnyPointer::Builder root;
};
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
} // namespace capnp
#endif // CAPNP_CAPABILITY_CONTEXT_H_
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#define CAPNP_PRIVATE #define CAPNP_PRIVATE
#include "capability.h" #include "capability.h"
#include "capability-context.h"
#include "message.h" #include "message.h"
#include "arena.h" #include "arena.h"
#include <kj/refcount.h> #include <kj/refcount.h>
...@@ -34,6 +33,37 @@ ...@@ -34,6 +33,37 @@
namespace capnp { namespace capnp {
namespace _ {
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory);
// Defined in layout.c++.
} // namespace _
namespace {
class BrokenCapFactoryImpl: public _::BrokenCapFactory {
public:
kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override {
return capnp::newBrokenCap(description);
}
};
static BrokenCapFactoryImpl brokenCapFactory;
} // namespace
ClientHook::ClientHook() {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable) {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
arena()->initCapTable(kj::mv(capTable));
}
// =======================================================================================
Capability::Client::Client(decltype(nullptr)) Capability::Client::Client(decltype(nullptr))
: hook(newBrokenCap("Called null capability.")) {} : hook(newBrokenCap("Called null capability.")) {}
...@@ -86,24 +116,32 @@ kj::Promise<void> ClientHook::whenResolved() { ...@@ -86,24 +116,32 @@ kj::Promise<void> ClientHook::whenResolved() {
// ======================================================================================= // =======================================================================================
static inline uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
KJ_IF_MAYBE(s, sizeHint) {
return s->wordCount;
} else {
return SUGGESTED_FIRST_SEGMENT_WORDS;
}
}
class LocalResponse final: public ResponseHook, public kj::Refcounted { class LocalResponse final: public ResponseHook, public kj::Refcounted {
public: public:
LocalResponse(kj::Maybe<MessageSize> sizeHint) LocalResponse(kj::Maybe<MessageSize> sizeHint)
: message(sizeHint) {} : message(firstSegmentSize(sizeHint)) {}
LocalMessage message; MallocMessageBuilder message;
}; };
class LocalCallContext final: public CallContextHook, public kj::Refcounted { class LocalCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<ClientHook> clientRef, LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<ClientHook> clientRef,
kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller) kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)), : request(kj::mv(request)), clientRef(kj::mv(clientRef)),
cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {} cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {}
AnyPointer::Reader getParams() override { AnyPointer::Reader getParams() override {
KJ_IF_MAYBE(r, request) { KJ_IF_MAYBE(r, request) {
return r->get()->getRoot(); return r->get()->getRoot<AnyPointer>();
} else { } else {
KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams()."); KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams().");
} }
...@@ -114,7 +152,7 @@ public: ...@@ -114,7 +152,7 @@ public:
AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override { AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
if (response == nullptr) { if (response == nullptr) {
auto localResponse = kj::refcounted<LocalResponse>(sizeHint); auto localResponse = kj::refcounted<LocalResponse>(sizeHint);
responseBuilder = localResponse->message.getRoot(); responseBuilder = localResponse->message.getRoot<AnyPointer>();
response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse)); response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse));
} }
return responseBuilder; return responseBuilder;
...@@ -149,7 +187,7 @@ public: ...@@ -149,7 +187,7 @@ public:
return kj::addRef(*this); return kj::addRef(*this);
} }
kj::Maybe<kj::Own<LocalMessage>> request; kj::Maybe<kj::Own<MallocMessageBuilder>> request;
kj::Maybe<Response<AnyPointer>> response; kj::Maybe<Response<AnyPointer>> response;
AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null
kj::Own<ClientHook> clientRef; kj::Own<ClientHook> clientRef;
...@@ -161,7 +199,7 @@ class LocalRequest final: public RequestHook { ...@@ -161,7 +199,7 @@ class LocalRequest final: public RequestHook {
public: public:
inline LocalRequest(uint64_t interfaceId, uint16_t methodId, inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client) kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client)
: message(kj::heap<LocalMessage>(sizeHint)), : message(kj::heap<MallocMessageBuilder>(firstSegmentSize(sizeHint))),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<AnyPointer> send() override { RemotePromise<AnyPointer> send() override {
...@@ -204,7 +242,7 @@ public: ...@@ -204,7 +242,7 @@ public:
return nullptr; return nullptr;
} }
kj::Own<LocalMessage> message; kj::Own<MallocMessageBuilder> message;
private: private:
uint64_t interfaceId; uint64_t interfaceId;
...@@ -274,7 +312,7 @@ public: ...@@ -274,7 +312,7 @@ public:
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, sizeHint, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
...@@ -424,7 +462,7 @@ public: ...@@ -424,7 +462,7 @@ public:
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, sizeHint, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
...@@ -495,4 +533,97 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro ...@@ -495,4 +533,97 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
return kj::refcounted<QueuedClient>(kj::mv(promise)); return kj::refcounted<QueuedClient>(kj::mv(promise));
} }
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<PipelineHook> addRef() override {
return kj::addRef(*this);
}
kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override;
private:
kj::Exception exception;
};
class BrokenRequest final: public RequestHook {
public:
BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
: exception(exception), message(firstSegmentSize(sizeHint)) {}
RemotePromise<AnyPointer> send() override {
return RemotePromise<AnyPointer>(kj::cp(exception),
AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
const void* getBrand() {
return nullptr;
}
kj::Exception exception;
MallocMessageBuilder message;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
BrokenClient(const kj::StringPtr description)
: exception(kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
"", 0, kj::str(description)) {}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, sizeHint);
auto root = hook->message.getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<ClientHook&> getResolved() {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
}
kj::Own<ClientHook> addRef() override {
return kj::addRef(*this);
}
const void* getBrand() override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
return kj::refcounted<BrokenClient>(reason);
}
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp } // namespace capnp
...@@ -332,6 +332,8 @@ public: ...@@ -332,6 +332,8 @@ public:
class ClientHook { class ClientHook {
public: public:
ClientHook();
virtual Request<AnyPointer, AnyPointer> newCall( virtual Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) = 0; uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) = 0;
// Start a new call, allowing the client to allocate request/response objects as it sees fit. // Start a new call, allowing the client to allocate request/response objects as it sees fit.
...@@ -410,6 +412,13 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro ...@@ -410,6 +412,13 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
// the new client. This hook's `getResolved()` and `whenMoreResolved()` methods will reflect the // the new client. This hook's `getResolved()` and `whenMoreResolved()` methods will reflect the
// redirection to the eventual replacement client. // redirection to the eventual replacement client.
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// ======================================================================================= // =======================================================================================
// Extend PointerHelpers for interfaces // Extend PointerHelpers for interfaces
......
...@@ -79,6 +79,16 @@ class EzRpcClient { ...@@ -79,6 +79,16 @@ class EzRpcClient {
// EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more // EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more
// than one EventLoop.) // than one EventLoop.)
// - These classes only support simple two-party connections, not multilateral VatNetworks. // - These classes only support simple two-party connections, not multilateral VatNetworks.
// - These classes only support communication over a raw, unencrypted socket. If you want to
// build on an abstract stream (perhaps one which supports encryption), you must use the
// lower-level interfaces.
//
// Some of these restrictions will probably be lifted in future versions, but some things will
// always require using the low-level interfaces directly. If you are interested in working
// at a lower level, start by looking at these interfaces:
// - `kj::startAsyncIo()` in `kj/async-io.h`.
// - `RpcSystem` in `capnp/rpc.h`.
// - `TwoPartyVatNetwork` in `capnp/rpc-twoparty.h`.
public: public:
explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0); explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0);
......
...@@ -279,7 +279,7 @@ static void checkStruct(StructReader reader) { ...@@ -279,7 +279,7 @@ static void checkStruct(StructReader reader) {
TEST(WireFormat, StructRoundTrip_OneSegment) { TEST(WireFormat, StructRoundTrip_OneSegment) {
MallocMessageBuilder message; MallocMessageBuilder message;
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
...@@ -316,7 +316,7 @@ TEST(WireFormat, StructRoundTrip_OneSegment) { ...@@ -316,7 +316,7 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE); MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE);
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
...@@ -354,7 +354,7 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { ...@@ -354,7 +354,7 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) { TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE); MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE);
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
......
...@@ -37,8 +37,8 @@ static BrokenCapFactory* brokenCapFactory = nullptr; ...@@ -37,8 +37,8 @@ static BrokenCapFactory* brokenCapFactory = nullptr;
// but we can't have a link-time dependency on libcapnp-rpc. // but we can't have a link-time dependency on libcapnp-rpc.
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) { void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) {
// Called from capability-context.c++ when a capability context is created. May be called // Called from capability.c++ when the capability API is used, to make sure that layout.c++
// multiple times but always with the same value. // is ready for it. May be called multiple times but always with the same value.
__atomic_store_n(&brokenCapFactory, &factory, __ATOMIC_RELAXED); __atomic_store_n(&brokenCapFactory, &factory, __ATOMIC_RELAXED);
} }
...@@ -1725,8 +1725,7 @@ struct WireHelpers { ...@@ -1725,8 +1725,7 @@ struct WireHelpers {
setCapabilityPointer(dstSegment, dst, kj::mv(*cap), orphanArena); setCapabilityPointer(dstSegment, dst, kj::mv(*cap), orphanArena);
return { dstSegment, nullptr }; return { dstSegment, nullptr };
} else { } else {
KJ_FAIL_REQUIRE("Message contained capability pointer but is not imbued with a " KJ_FAIL_REQUIRE("Message contained invalid capability pointer.") {
"capability context.") {
goto useDefault; goto useDefault;
} }
} }
...@@ -1846,28 +1845,21 @@ struct WireHelpers { ...@@ -1846,28 +1845,21 @@ struct WireHelpers {
"use the Cap'n Proto RPC system."); "use the Cap'n Proto RPC system.");
if (ref->isNull()) { if (ref->isNull()) {
maybeCap = brokenCapFactory->newBrokenCap("Calling null capability pointer."); return brokenCapFactory->newBrokenCap("Calling null capability pointer.");
} else if (!ref->isCapability()) { } else if (!ref->isCapability()) {
KJ_FAIL_REQUIRE( KJ_FAIL_REQUIRE(
"Message contains non-capability pointer where capability pointer was expected.") { "Message contains non-capability pointer where capability pointer was expected.") {
break; break;
} }
maybeCap = brokenCapFactory->newBrokenCap( return brokenCapFactory->newBrokenCap(
"Calling capability extracted from a non-capability pointer."); "Calling capability extracted from a non-capability pointer.");
} else { } else KJ_IF_MAYBE(cap, segment->getArena()->extractCap(ref->capRef.index.get())) {
maybeCap = segment->getArena()->extractCap(ref->capRef.index.get());
}
KJ_IF_MAYBE(cap, maybeCap) {
return kj::mv(*cap); return kj::mv(*cap);
} else { } else {
// The message is not imbued with a capability context. We can't really recover from this, KJ_FAIL_REQUIRE("Message contains invalid capability pointer.") {
// because we have no way to construct a ClientHook in this case -- capability.c++ may not break;
// even be linked in. Luckily, this is the caller's error, not the message sender's -- }
// it's the message reader who is calling a capability getter on a message they should know return brokenCapFactory->newBrokenCap("Calling invalid capability pointer.");
// they have not imbued properly.
KJ_FAIL_REQUIRE("Tried to read a capability out of a message that doesn't have a "
"capability context.");
} }
} }
...@@ -2216,10 +2208,6 @@ BuilderArena* PointerBuilder::getArena() const { ...@@ -2216,10 +2208,6 @@ BuilderArena* PointerBuilder::getArena() const {
return segment->getArena(); return segment->getArena();
} }
PointerBuilder PointerBuilder::imbue(ImbuedBuilderArena& newArena) const {
return PointerBuilder(newArena.imbue(segment), pointer);
}
// ======================================================================================= // =======================================================================================
// PointerReader // PointerReader
...@@ -2278,10 +2266,6 @@ kj::Maybe<Arena&> PointerReader::getArena() const { ...@@ -2278,10 +2266,6 @@ kj::Maybe<Arena&> PointerReader::getArena() const {
return segment == nullptr ? nullptr : segment->getArena(); return segment == nullptr ? nullptr : segment->getArena();
} }
PointerReader PointerReader::imbue(ImbuedReaderArena& newArena) const {
return PointerReader(newArena.imbue(segment), pointer, nestingLimit);
}
// ======================================================================================= // =======================================================================================
// StructBuilder // StructBuilder
......
...@@ -56,8 +56,6 @@ class SegmentReader; ...@@ -56,8 +56,6 @@ class SegmentReader;
class SegmentBuilder; class SegmentBuilder;
class Arena; class Arena;
class BuilderArena; class BuilderArena;
class ImbuedReaderArena;
class ImbuedBuilderArena;
// ============================================================================= // =============================================================================
...@@ -341,9 +339,6 @@ public: ...@@ -341,9 +339,6 @@ public:
BuilderArena* getArena() const; BuilderArena* getArena() const;
// Get the arena containing this pointer. // Get the arena containing this pointer.
PointerBuilder imbue(ImbuedBuilderArena& newArena) const;
// Imbue the pointer with a capability context, returning the imbued pointer.
private: private:
SegmentBuilder* segment; // Memory segment in which the pointer resides. SegmentBuilder* segment; // Memory segment in which the pointer resides.
WirePointer* pointer; // Pointer to the pointer. WirePointer* pointer; // Pointer to the pointer.
...@@ -392,9 +387,6 @@ public: ...@@ -392,9 +387,6 @@ public:
kj::Maybe<Arena&> getArena() const; kj::Maybe<Arena&> getArena() const;
// Get the arena containing this pointer. // Get the arena containing this pointer.
PointerReader imbue(ImbuedReaderArena& newArena) const;
// Imbue the pointer with a capability context, returning the imbued pointer.
private: private:
SegmentReader* segment; // Memory segment in which the pointer resides. SegmentReader* segment; // Memory segment in which the pointer resides.
const WirePointer* pointer; // Pointer to the pointer. null = treat as null pointer. const WirePointer* pointer; // Pointer to the pointer. null = treat as null pointer.
...@@ -473,10 +465,6 @@ public: ...@@ -473,10 +465,6 @@ public:
BuilderArena* getArena(); BuilderArena* getArena();
// Gets the arena in which this object is allocated. // Gets the arena in which this object is allocated.
void unimbue();
// Removes the capability context from the builder. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentBuilder -- with the non-imbued base segment.
private: private:
SegmentBuilder* segment; // Memory segment in which the struct resides. SegmentBuilder* segment; // Memory segment in which the struct resides.
void* data; // Pointer to the encoded data. void* data; // Pointer to the encoded data.
...@@ -542,10 +530,6 @@ public: ...@@ -542,10 +530,6 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an // use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns. // exception if it overruns.
void unimbue();
// Removes the capability context from the reader. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentReader -- with the non-imbued base segment.
private: private:
SegmentReader* segment; // Memory segment in which the struct resides. SegmentReader* segment; // Memory segment in which the struct resides.
......
...@@ -38,16 +38,16 @@ namespace capnp { ...@@ -38,16 +38,16 @@ namespace capnp {
MessageReader::MessageReader(ReaderOptions options): options(options), allocatedArena(false) {} MessageReader::MessageReader(ReaderOptions options): options(options), allocatedArena(false) {}
MessageReader::~MessageReader() noexcept(false) { MessageReader::~MessageReader() noexcept(false) {
if (allocatedArena) { if (allocatedArena) {
arena()->~BasicReaderArena(); arena()->~ReaderArena();
} }
} }
AnyPointer::Reader MessageReader::getRootInternal() { AnyPointer::Reader MessageReader::getRootInternal() {
if (!allocatedArena) { if (!allocatedArena) {
static_assert(sizeof(_::BasicReaderArena) <= sizeof(arenaSpace), static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BasicReaderArena. Please increase it. This will break " "arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
"ABI compatibility."); "ABI compatibility.");
new(arena()) _::BasicReaderArena(this); new(arena()) _::ReaderArena(this);
allocatedArena = true; allocatedArena = true;
} }
...@@ -75,7 +75,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() { ...@@ -75,7 +75,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() {
if (allocatedArena) { if (allocatedArena) {
return arena()->getSegment(_::SegmentId(0)); return arena()->getSegment(_::SegmentId(0));
} else { } else {
static_assert(sizeof(_::BasicBuilderArena) <= sizeof(arenaSpace), static_assert(sizeof(_::BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it."); "arenaSpace is too small to hold a BuilderArena. Please increase it.");
kj::ctor(*arena(), this); kj::ctor(*arena(), this);
allocatedArena = true; allocatedArena = true;
...@@ -104,6 +104,14 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutpu ...@@ -104,6 +104,14 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutpu
} }
} }
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> MessageBuilder::getCapTable() {
if (allocatedArena) {
return arena()->getCapTable();
} else {
return nullptr;
}
}
Orphanage MessageBuilder::getOrphanage() { Orphanage MessageBuilder::getOrphanage() {
// We must ensure that the arena and root pointer have been allocated before the Orphanage // We must ensure that the arena and root pointer have been allocated before the Orphanage
// can be used. // can be used.
......
...@@ -34,8 +34,8 @@ ...@@ -34,8 +34,8 @@
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
class BasicReaderArena; class ReaderArena;
class BasicBuilderArena; class BuilderArena;
} }
class StructSchema; class StructSchema;
...@@ -118,6 +118,16 @@ public: ...@@ -118,6 +118,16 @@ public:
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to // RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this. // use this.
void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable);
// Sets the table of capabilities embedded in this message. Capability pointers found in the
// message content contain indexes into this table. You must call this before attempting to
// read any capability pointers (interface pointers) from the message. The table is not passed
// to the constructor because often (as in the RPC system) the cap table is actually constructed
// based on a list read from the message itself.
//
// You must link against libcapnp-rpc to call this method (the rest of MessageBuilder is in
// regular libcapnp).
private: private:
ReaderOptions options; ReaderOptions options;
...@@ -128,7 +138,7 @@ private: ...@@ -128,7 +138,7 @@ private:
void* arenaSpace[15 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)]; void* arenaSpace[15 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)];
bool allocatedArena; bool allocatedArena;
_::BasicReaderArena* arena() { return reinterpret_cast<_::BasicReaderArena*>(arenaSpace); } _::ReaderArena* arena() { return reinterpret_cast<_::ReaderArena*>(arenaSpace); }
AnyPointer::Reader getRootInternal(); AnyPointer::Reader getRootInternal();
}; };
...@@ -186,11 +196,18 @@ public: ...@@ -186,11 +196,18 @@ public:
// Like setRoot() but adopts the orphan without copying. // Like setRoot() but adopts the orphan without copying.
kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput(); kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput();
// Get the raw data that makes up the message.
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable();
// Get the table of capabilities (interface pointers) that have been added to this message.
// When you later parse this message, you must call `initCapTable()` on the `MessageReader` and
// give it an equivalent set of capabilities, otherwise cap pointers in the message will be
// unusable.
Orphanage getOrphanage(); Orphanage getOrphanage();
private: private:
void* arenaSpace[16]; void* arenaSpace[17];
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here // Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of // because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a BuilderArena because that would require an // big STL headers. We don't use a pointer to a BuilderArena because that would require an
...@@ -203,7 +220,7 @@ private: ...@@ -203,7 +220,7 @@ private:
// isn't constructed yet. This is kind of annoying because it means that getOrphanage() is // isn't constructed yet. This is kind of annoying because it means that getOrphanage() is
// not thread-safe, but that shouldn't be a huge deal... // not thread-safe, but that shouldn't be a huge deal...
_::BasicBuilderArena* arena() { return reinterpret_cast<_::BasicBuilderArena*>(arenaSpace); } _::BuilderArena* arena() { return reinterpret_cast<_::BuilderArena*>(arenaSpace); }
_::SegmentBuilder* getRootSegment(); _::SegmentBuilder* getRootSegment();
AnyPointer::Builder getRootInternal(); AnyPointer::Builder getRootInternal();
}; };
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc.h" #include "rpc.h"
#include "capability-context.h"
#include "test-util.h" #include "test-util.h"
#include "schema.h" #include "schema.h"
#include "serialize.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/string-tree.h> #include <kj/string-tree.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -83,14 +83,7 @@ public: ...@@ -83,14 +83,7 @@ public:
} }
auto payload = call.getParams(); auto payload = call.getParams();
auto capTableReader = payload.getCapTable(); auto params = kj::str(payload.getContent().getAs<DynamicStruct>(paramType));
auto capTable = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTableReader.size());
for (uint i = 0; i < capTableReader.size(); i++) {
capTable.add(newBrokenCap("fake cap"));
}
CapReaderContext context(capTable.finish());
auto params = kj::str(context.imbue(payload.getContent()).getAs<DynamicStruct>(paramType));
auto sendResultsTo = call.getSendResultsTo(); auto sendResultsTo = call.getSendResultsTo();
...@@ -120,22 +113,14 @@ public: ...@@ -120,22 +113,14 @@ public:
} }
auto payload = ret.getResults(); auto payload = ret.getResults();
auto capTableReader = payload.getCapTable();
auto capTable = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTableReader.size());
for (uint i = 0; i < capTableReader.size(); i++) {
capTable.add(newBrokenCap("fake cap"));
}
CapReaderContext context(capTable.finish());
auto imbued = context.imbue(payload.getContent());
if (schema.getProto().isStruct()) { if (schema.getProto().isStruct()) {
auto results = kj::str(imbued.getAs<DynamicStruct>(schema.asStruct())); auto results = kj::str(payload.getContent().getAs<DynamicStruct>(schema.asStruct()));
return kj::str(senderName, "(", ret.getAnswerId(), "): return ", results, return kj::str(senderName, "(", ret.getAnswerId(), "): return ", results,
" caps:[", kj::strArray(payload.getCapTable(), ", "), "]"); " caps:[", kj::strArray(payload.getCapTable(), ", "), "]");
} else if (schema.getProto().isInterface()) { } else if (schema.getProto().isInterface()) {
imbued.getAs<DynamicCapability>(schema.asInterface()); payload.getContent().getAs<DynamicCapability>(schema.asInterface());
return kj::str(senderName, "(", ret.getAnswerId(), "): return cap ", return kj::str(senderName, "(", ret.getAnswerId(), "): return cap ",
kj::strArray(payload.getCapTable(), ", ")); kj::strArray(payload.getCapTable(), ", "));
} else { } else {
...@@ -246,26 +231,37 @@ public: ...@@ -246,26 +231,37 @@ public:
class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted { class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted {
public: public:
IncomingRpcMessageImpl(uint firstSegmentWordSize) IncomingRpcMessageImpl(kj::Array<word> data)
: message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : data(kj::mv(data)),
: firstSegmentWordSize) {} message(this->data) {}
AnyPointer::Reader getBody() override { AnyPointer::Reader getBody() override {
return message.getRoot<AnyPointer>().asReader(); return message.getRoot<AnyPointer>();
} }
MallocMessageBuilder message; void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) override {
message.initCapTable(kj::mv(capTable));
}
kj::Array<word> data;
FlatArrayMessageReader message;
}; };
class OutgoingRpcMessageImpl final: public OutgoingRpcMessage { class OutgoingRpcMessageImpl final: public OutgoingRpcMessage {
public: public:
OutgoingRpcMessageImpl(ConnectionImpl& connection, uint firstSegmentWordSize) OutgoingRpcMessageImpl(ConnectionImpl& connection, uint firstSegmentWordSize)
: connection(connection), : connection(connection),
message(kj::refcounted<IncomingRpcMessageImpl>(firstSegmentWordSize)) {} message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS
: firstSegmentWordSize) {}
AnyPointer::Builder getBody() override { AnyPointer::Builder getBody() override {
return message->message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
}
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() override {
return message.getCapTable();
} }
void send() override { void send() override {
if (connection.networkException != nullptr) { if (connection.networkException != nullptr) {
return; return;
...@@ -275,11 +271,13 @@ public: ...@@ -275,11 +271,13 @@ public:
// Uncomment to get a debug dump. // Uncomment to get a debug dump.
// kj::String msg = connection.network.network.dumper.dump( // kj::String msg = connection.network.network.dumper.dump(
// message->message.getRoot<rpc::Message>(), connection.sender); // message.getRoot<rpc::Message>(), connection.sender);
// KJ_ DBG(msg); // KJ_ DBG(msg);
auto incomingMessage = kj::heap<IncomingRpcMessageImpl>(messageToFlatArray(message));
auto connectionPtr = &connection; auto connectionPtr = &connection;
connection.tasks->add(kj::evalLater(kj::mvCapture(kj::addRef(*message), connection.tasks->add(kj::evalLater(kj::mvCapture(incomingMessage,
[connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) { [connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
KJ_IF_MAYBE(p, connectionPtr->partner) { KJ_IF_MAYBE(p, connectionPtr->partner) {
if (p->fulfillers.empty()) { if (p->fulfillers.empty()) {
...@@ -296,7 +294,7 @@ public: ...@@ -296,7 +294,7 @@ public:
private: private:
ConnectionImpl& connection; ConnectionImpl& connection;
kj::Own<IncomingRpcMessageImpl> message; MallocMessageBuilder message;
}; };
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override { kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "capability-context.h"
#include "test-util.h" #include "test-util.h"
#include <kj/async-unix.h> #include <kj/async-unix.h>
#include <kj/debug.h> #include <kj/debug.h>
......
...@@ -76,6 +76,10 @@ public: ...@@ -76,6 +76,10 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() override {
return message.getCapTable();
}
void send() override { void send() override {
network.previousWrite = network.previousWrite.then([&]() { network.previousWrite = network.previousWrite.then([&]() {
auto promise = writeMessage(network.stream, message).then([]() { auto promise = writeMessage(network.stream, message).then([]() {
...@@ -101,6 +105,10 @@ public: ...@@ -101,6 +105,10 @@ public:
return message->getRoot<AnyPointer>(); return message->getRoot<AnyPointer>();
} }
void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) override {
message->initCapTable(kj::mv(capTable));
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
}; };
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc.h" #include "rpc.h"
#include "capability-context.h" #include "message.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/vector.h> #include <kj/vector.h>
#include <kj/async.h> #include <kj/async.h>
...@@ -910,13 +910,17 @@ private: ...@@ -910,13 +910,17 @@ private:
} }
} }
kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Own<ClientHook>> capTable, kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> capTable,
rpc::Payload::Builder payload) { rpc::Payload::Builder payload) {
auto capTableBuilder = payload.initCapTable(capTable.size()); auto capTableBuilder = payload.initCapTable(capTable.size());
kj::Vector<ExportId> exports(capTable.size()); kj::Vector<ExportId> exports(capTable.size());
for (uint i: kj::indices(capTable)) { for (uint i: kj::indices(capTable)) {
KJ_IF_MAYBE(exportId, writeDescriptor(*capTable[i], capTableBuilder[i])) { KJ_IF_MAYBE(cap, capTable[i]) {
exports.add(*exportId); KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i])) {
exports.add(*exportId);
}
} else {
capTableBuilder[i].setNone();
} }
} }
return exports.releaseAsArray(); return exports.releaseAsArray();
...@@ -1070,10 +1074,10 @@ private: ...@@ -1070,10 +1074,10 @@ private:
} }
} }
kj::Own<ClientHook> receiveCap(rpc::CapDescriptor::Reader descriptor) { kj::Maybe<kj::Own<ClientHook>> receiveCap(rpc::CapDescriptor::Reader descriptor) {
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::NONE: case rpc::CapDescriptor::NONE:
return newBrokenCap("Called a `CapDescriptor.none`."); return nullptr;
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
return import(descriptor.getSenderHosted(), false); return import(descriptor.getSenderHosted(), false);
...@@ -1115,8 +1119,8 @@ private: ...@@ -1115,8 +1119,8 @@ private:
} }
} }
kj::Array<kj::Own<ClientHook>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable) { kj::Array<kj::Maybe<kj::Own<ClientHook>>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable) {
auto result = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTable.size()); auto result = kj::heapArrayBuilder<kj::Maybe<kj::Own<ClientHook>>>(capTable.size());
for (auto cap: capTable) { for (auto cap: capTable) {
result.add(receiveCap(cap)); result.add(receiveCap(cap));
} }
...@@ -1194,7 +1198,7 @@ private: ...@@ -1194,7 +1198,7 @@ private:
firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() + firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() +
sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))), sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))),
callBuilder(message->getBody().getAs<rpc::Message>().initCall()), callBuilder(message->getBody().getAs<rpc::Message>().initCall()),
paramsBuilder(context.imbue(callBuilder.getParams().getContent())) {} paramsBuilder(callBuilder.getParams().getContent()) {}
inline AnyPointer::Builder getRoot() { inline AnyPointer::Builder getRoot() {
return paramsBuilder; return paramsBuilder;
...@@ -1288,7 +1292,6 @@ private: ...@@ -1288,7 +1292,6 @@ private:
kj::Own<RpcClient> target; kj::Own<RpcClient> target;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
CapBuilderContext context;
rpc::Call::Builder callBuilder; rpc::Call::Builder callBuilder;
AnyPointer::Builder paramsBuilder; AnyPointer::Builder paramsBuilder;
...@@ -1300,7 +1303,7 @@ private: ...@@ -1300,7 +1303,7 @@ private:
SendInternalResult sendInternal(bool isTailCall) { SendInternalResult sendInternal(bool isTailCall) {
// Build the cap table. // Build the cap table.
auto exports = connectionState->writeDescriptors( auto exports = connectionState->writeDescriptors(
context.getCapTable(), callBuilder.getParams()); message->getCapTable(), callBuilder.getParams());
// Init the question table. Do this after writing descriptors to avoid interference. // Init the question table. Do this after writing descriptors to avoid interference.
QuestionId questionId; QuestionId questionId;
...@@ -1432,12 +1435,10 @@ private: ...@@ -1432,12 +1435,10 @@ private:
RpcResponseImpl(RpcConnectionState& connectionState, RpcResponseImpl(RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef, kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message, kj::Own<IncomingRpcMessage>&& message,
AnyPointer::Reader results, AnyPointer::Reader results)
kj::Array<kj::Own<ClientHook>>&& capTable)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
message(kj::mv(message)), message(kj::mv(message)),
context(kj::mv(capTable)), reader(results),
reader(context.imbue(results)),
questionRef(kj::mv(questionRef)) {} questionRef(kj::mv(questionRef)) {}
AnyPointer::Reader getResults() override { AnyPointer::Reader getResults() override {
...@@ -1451,7 +1452,6 @@ private: ...@@ -1451,7 +1452,6 @@ private:
private: private:
kj::Own<RpcConnectionState> connectionState; kj::Own<RpcConnectionState> connectionState;
kj::Own<IncomingRpcMessage> message; kj::Own<IncomingRpcMessage> message;
CapReaderContext context;
AnyPointer::Reader reader; AnyPointer::Reader reader;
kj::Own<QuestionRef> questionRef; kj::Own<QuestionRef> questionRef;
}; };
...@@ -1471,11 +1471,10 @@ private: ...@@ -1471,11 +1471,10 @@ private:
rpc::Payload::Builder payload) rpc::Payload::Builder payload)
: connectionState(connectionState), : connectionState(connectionState),
message(kj::mv(message)), message(kj::mv(message)),
payload(payload), payload(payload) {}
builder(context.imbue(payload.getContent())) {}
AnyPointer::Builder getResultsBuilder() override { AnyPointer::Builder getResultsBuilder() override {
return builder; return payload.getContent();
} }
kj::Maybe<kj::Array<ExportId>> send() { kj::Maybe<kj::Array<ExportId>> send() {
...@@ -1483,7 +1482,7 @@ private: ...@@ -1483,7 +1482,7 @@ private:
// (Could return a non-null empty array if there were caps but none of them were exports.) // (Could return a non-null empty array if there were caps but none of them were exports.)
// Build the cap table. // Build the cap table.
auto capTable = context.getCapTable(); auto capTable = message->getCapTable();
auto exports = connectionState.writeDescriptors(capTable, payload); auto exports = connectionState.writeDescriptors(capTable, payload);
message->send(); message->send();
...@@ -1497,23 +1496,22 @@ private: ...@@ -1497,23 +1496,22 @@ private:
private: private:
RpcConnectionState& connectionState; RpcConnectionState& connectionState;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
CapBuilderContext context;
rpc::Payload::Builder payload; rpc::Payload::Builder payload;
AnyPointer::Builder builder;
}; };
class LocallyRedirectedRpcResponse final class LocallyRedirectedRpcResponse final
: public RpcResponse, public RpcServerResponse, public kj::Refcounted{ : public RpcResponse, public RpcServerResponse, public kj::Refcounted{
public: public:
LocallyRedirectedRpcResponse(kj::Maybe<MessageSize> sizeHint) LocallyRedirectedRpcResponse(kj::Maybe<MessageSize> sizeHint)
: message(sizeHint) {} : message(sizeHint.map([](MessageSize size) { return size.wordCount; })
.orDefault(SUGGESTED_FIRST_SEGMENT_WORDS)) {}
AnyPointer::Builder getResultsBuilder() override { AnyPointer::Builder getResultsBuilder() override {
return message.getRoot(); return message.getRoot<AnyPointer>();
} }
AnyPointer::Reader getResults() override { AnyPointer::Reader getResults() override {
return message.getRootReader(); return message.getRoot<AnyPointer>();
} }
kj::Own<RpcResponse> addRef() override { kj::Own<RpcResponse> addRef() override {
...@@ -1521,20 +1519,18 @@ private: ...@@ -1521,20 +1519,18 @@ private:
} }
private: private:
LocalMessage message; MallocMessageBuilder message;
}; };
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId, RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId,
kj::Own<IncomingRpcMessage>&& request, const AnyPointer::Reader& params, kj::Own<IncomingRpcMessage>&& request, const AnyPointer::Reader& params,
kj::Array<kj::Own<ClientHook>>&& requestCapTable, bool redirectResults, bool redirectResults, kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
answerId(answerId), answerId(answerId),
request(kj::mv(request)), request(kj::mv(request)),
requestCapContext(kj::mv(requestCapTable)), params(params),
params(requestCapContext.imbue(params)),
returnMessage(nullptr), returnMessage(nullptr),
redirectResults(redirectResults), redirectResults(redirectResults),
cancelFulfiller(kj::mv(cancelFulfiller)) {} cancelFulfiller(kj::mv(cancelFulfiller)) {}
...@@ -1740,7 +1736,6 @@ private: ...@@ -1740,7 +1736,6 @@ private:
// Request --------------------------------------------- // Request ---------------------------------------------
kj::Maybe<kj::Own<IncomingRpcMessage>> request; kj::Maybe<kj::Own<IncomingRpcMessage>> request;
CapReaderContext requestCapContext;
AnyPointer::Reader params; AnyPointer::Reader params;
// Response -------------------------------------------- // Response --------------------------------------------
...@@ -1938,14 +1933,14 @@ private: ...@@ -1938,14 +1933,14 @@ private:
} }
auto payload = call.getParams(); auto payload = call.getParams();
auto capTable = receiveCaps(payload.getCapTable()); message->initCapTable(receiveCaps(payload.getCapTable()));
auto cancelPaf = kj::newPromiseAndFulfiller<void>(); auto cancelPaf = kj::newPromiseAndFulfiller<void>();
AnswerId answerId = call.getQuestionId(); AnswerId answerId = call.getQuestionId();
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, answerId, kj::mv(message), payload.getContent(), *this, answerId, kj::mv(message), payload.getContent(),
kj::mv(capTable), redirectResults, kj::mv(cancelPaf.fulfiller)); redirectResults, kj::mv(cancelPaf.fulfiller));
// No more using `call` after this point, as it now belongs to the context. // No more using `call` after this point, as it now belongs to the context.
...@@ -2080,9 +2075,9 @@ private: ...@@ -2080,9 +2075,9 @@ private:
} }
auto payload = ret.getResults(); auto payload = ret.getResults();
message->initCapTable(receiveCaps(payload.getCapTable()));
questionRef->fulfill(kj::refcounted<RpcResponseImpl>( questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::addRef(*questionRef), kj::mv(message), payload.getContent(), *this, kj::addRef(*questionRef), kj::mv(message), payload.getContent()));
receiveCaps(payload.getCapTable())));
break; break;
} }
...@@ -2184,7 +2179,11 @@ private: ...@@ -2184,7 +2179,11 @@ private:
// Extract the replacement capability. // Extract the replacement capability.
switch (resolve.which()) { switch (resolve.which()) {
case rpc::Resolve::CAP: case rpc::Resolve::CAP:
replacement = receiveCap(resolve.getCap()); KJ_IF_MAYBE(cap, receiveCap(resolve.getCap())) {
replacement = kj::mv(*cap);
} else {
KJ_FAIL_REQUIRE("'Resolve' contained 'CapDescriptor.none'.") { return; }
}
break; break;
case rpc::Resolve::EXCEPTION: case rpc::Resolve::EXCEPTION:
...@@ -2347,8 +2346,6 @@ private: ...@@ -2347,8 +2346,6 @@ private:
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn(); rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
ret.setAnswerId(answerId); ret.setAnswerId(answerId);
CapBuilderContext context;
kj::Own<ClientHook> capHook; kj::Own<ClientHook> capHook;
kj::Array<ExportId> resultExports; kj::Array<ExportId> resultExports;
KJ_DEFER(releaseExports(resultExports)); // in case something goes wrong KJ_DEFER(releaseExports(resultExports)); // in case something goes wrong
...@@ -2358,13 +2355,12 @@ private: ...@@ -2358,13 +2355,12 @@ private:
KJ_IF_MAYBE(r, restorer) { KJ_IF_MAYBE(r, restorer) {
Capability::Client cap = r->baseRestore(restore.getObjectId()); Capability::Client cap = r->baseRestore(restore.getObjectId());
auto payload = ret.initResults(); auto payload = ret.initResults();
auto results = context.imbue(payload.getContent()); payload.getContent().setAs<Capability>(kj::mv(cap));
results.setAs<Capability>(cap);
auto capTable = context.getCapTable(); auto capTable = response->getCapTable();
KJ_DASSERT(capTable.size() == 1); KJ_DASSERT(capTable.size() == 1);
resultExports = writeDescriptors(capTable, payload); resultExports = writeDescriptors(capTable, payload);
capHook = capTable[0]->addRef(); capHook = KJ_ASSERT_NONNULL(capTable[0])->addRef();
} else { } else {
KJ_FAIL_REQUIRE("This vat cannot restore this SturdyRef.") { break; } KJ_FAIL_REQUIRE("This vat cannot restore this SturdyRef.") { break; }
} }
......
...@@ -126,6 +126,9 @@ public: ...@@ -126,6 +126,9 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC // Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.) // implementation initializes it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() = 0;
// Calls getCapTable() on the underlying MessageBuilder.
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
...@@ -138,6 +141,9 @@ public: ...@@ -138,6 +141,9 @@ public:
virtual AnyPointer::Reader getBody() = 0; virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation // Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.) // interprets it as a Message as defined in rpc.capnp.)
virtual void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) = 0;
// Calls initCapTable() on the underlying MessageReader.
}; };
template <typename SturdyRefHostId, typename ProvisionId, typename RecipientId, template <typename SturdyRefHostId, typename ProvisionId, typename RecipientId,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment