Commit 3c7efbb4 authored by Kenton Varda's avatar Kenton Varda

Eliminate the concept of imbuing messages in favor of the simpler concept of…

Eliminate the concept of imbuing messages in favor of the simpler concept of setting a cap table directly on MessageReader / getting one from MessageBuilder.  This eliminates capability-context entirely.  This was made possible by the earlier change which moved capabilities to a separate table rather than storing CapDescriptors inline, but I didn't realize it at the time.
parent bf5dbebf
...@@ -160,7 +160,6 @@ includecapnp_HEADERS = \ ...@@ -160,7 +160,6 @@ includecapnp_HEADERS = \
src/capnp/any.h \ src/capnp/any.h \
src/capnp/message.h \ src/capnp/message.h \
src/capnp/capability.h \ src/capnp/capability.h \
src/capnp/capability-context.h \
src/capnp/schema.capnp.h \ src/capnp/schema.capnp.h \
src/capnp/schema.h \ src/capnp/schema.h \
src/capnp/schema-loader.h \ src/capnp/schema-loader.h \
...@@ -235,7 +234,6 @@ libcapnp_rpc_la_LDFLAGS = -release $(VERSION) -no-undefined ...@@ -235,7 +234,6 @@ libcapnp_rpc_la_LDFLAGS = -release $(VERSION) -no-undefined
libcapnp_rpc_la_SOURCES= \ libcapnp_rpc_la_SOURCES= \
src/capnp/serialize-async.c++ \ src/capnp/serialize-async.c++ \
src/capnp/capability.c++ \ src/capnp/capability.c++ \
src/capnp/capability-context.c++ \
src/capnp/dynamic-capability.c++ \ src/capnp/dynamic-capability.c++ \
src/capnp/rpc.c++ \ src/capnp/rpc.c++ \
src/capnp/rpc.capnp.c++ \ src/capnp/rpc.capnp.c++ \
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This sample code appears in the documentation for the C++ implementation. // This sample code appears in the documentation for the C++ implementation.
// //
// Compile with: // Compile with:
......
# Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@0x9eb32e19f86ee174; @0x9eb32e19f86ee174;
using Cxx = import "/capnp/c++.capnp"; using Cxx = import "/capnp/c++.capnp";
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "calculator.capnp.h" #include "calculator.capnp.h"
#include <capnp/ez-rpc.h> #include <capnp/ez-rpc.h>
#include <kj/debug.h> #include <kj/debug.h>
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "calculator.capnp.h" #include "calculator.capnp.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <capnp/ez-rpc.h> #include <capnp/ez-rpc.h>
#include <capnp/capability-context.h> // for LocalMessage #include <capnp/message.h>
#include <iostream> #include <iostream>
typedef unsigned int uint; typedef unsigned int uint;
...@@ -90,13 +113,15 @@ class FunctionImpl final: public Calculator::Function::Server { ...@@ -90,13 +113,15 @@ class FunctionImpl final: public Calculator::Function::Server {
public: public:
FunctionImpl(uint paramCount, Calculator::Expression::Reader body) FunctionImpl(uint paramCount, Calculator::Expression::Reader body)
: paramCount(paramCount), body(body) {} : paramCount(paramCount) {
this->body.setRoot(body);
}
kj::Promise<void> call(CallContext context) { kj::Promise<void> call(CallContext context) {
auto params = context.getParams().getParams(); auto params = context.getParams().getParams();
KJ_REQUIRE(params.size() == paramCount, "Wrong number of parameters."); KJ_REQUIRE(params.size() == paramCount, "Wrong number of parameters.");
return evaluateImpl(body.getRoot().getAs<Calculator::Expression>(), params) return evaluateImpl(body.getRoot<Calculator::Expression>(), params)
.then([context](double value) mutable { .then([context](double value) mutable {
context.getResults().setValue(value); context.getResults().setValue(value);
}); });
...@@ -106,10 +131,8 @@ private: ...@@ -106,10 +131,8 @@ private:
uint paramCount; uint paramCount;
// The function's arity. // The function's arity.
capnp::LocalMessage body; capnp::MallocMessageBuilder body;
// LocalMessage holds a message that might contain capabilities (interface // Stores a permanent copy of the function body.
// references). Here we're using it to hold a Calculator.Expression, which
// might contain Calculator.Function and/or Calculator.Value capabilities.
}; };
class OperatorImpl final: public Calculator::Function::Server { class OperatorImpl final: public Calculator::Function::Server {
......
# Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@0x85150b117366d14b; @0x85150b117366d14b;
interface Calculator { interface Calculator {
......
This diff is collapsed.
This diff is collapsed.
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define CAPNP_PRIVATE
#include "capability-context.h"
#include "capability.h"
#include "arena.h"
#include <kj/debug.h>
namespace capnp {
namespace _ {
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory);
// Defined in layout.c++.
} // namespace _
namespace {
class BrokenCapFactoryImpl: public _::BrokenCapFactory {
public:
kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override {
return capnp::newBrokenCap(description);
}
};
static BrokenCapFactoryImpl brokenCapFactory;
} // namespace
CapReaderContext::CapReaderContext(kj::Array<kj::Own<ClientHook>>&& capTable)
: capTable(kj::mv(capTable)) {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
CapReaderContext::~CapReaderContext() noexcept(false) {
if (capTable == nullptr) {
kj::dtor(arena());
}
}
AnyPointer::Reader CapReaderContext::imbue(AnyPointer::Reader base) {
KJ_IF_MAYBE(oldArena, base.reader.getArena()) {
static_assert(sizeof(arena()) <= sizeof(arenaSpace),
"arenaSpace is too small. Please increase it.");
kj::ctor(arena(), oldArena, brokenCapFactory,
kj::mv(KJ_REQUIRE_NONNULL(capTable, "imbue() can only be called once.")));
} else {
KJ_FAIL_REQUIRE("Cannot imbue unchecked message.");
}
capTable = nullptr;
return AnyPointer::Reader(base.reader.imbue(arena()));
}
CapBuilderContext::CapBuilderContext() {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
CapBuilderContext::~CapBuilderContext() noexcept(false) {
if (arenaAllocated) {
kj::dtor(arena());
}
}
AnyPointer::Builder CapBuilderContext::imbue(AnyPointer::Builder base) {
KJ_REQUIRE(!arenaAllocated, "imbue() can only be called once.");
static_assert(sizeof(arena()) <= sizeof(arenaSpace),
"arenaSpace is too small. Please increase it.");
kj::ctor(arena(), base.builder.getArena(), brokenCapFactory);
arenaAllocated = true;
return AnyPointer::Builder(base.builder.imbue(arena()));
}
kj::ArrayPtr<kj::Own<ClientHook>> CapBuilderContext::getCapTable() {
if (arenaAllocated) {
return arena().getCapTable();
} else {
return nullptr;
}
}
// =======================================================================================
namespace {
uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
KJ_IF_MAYBE(s, sizeHint) {
// 1 for the root pointer. We don't store caps in the message so we don't count those here.
return s->wordCount + 1;
} else {
return SUGGESTED_FIRST_SEGMENT_WORDS;
}
}
} // namespace
LocalMessage::LocalMessage(kj::Maybe<MessageSize> sizeHint)
: message(firstSegmentSize(sizeHint)),
root(capContext.imbue(message.getRoot<AnyPointer>())) {}
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<PipelineHook> addRef() override {
return kj::addRef(*this);
}
kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override;
private:
kj::Exception exception;
};
class BrokenRequest final: public RequestHook {
public:
BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
: exception(exception), message(sizeHint) {}
RemotePromise<AnyPointer> send() override {
return RemotePromise<AnyPointer>(kj::cp(exception),
AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
const void* getBrand() {
return nullptr;
}
kj::Exception exception;
LocalMessage message;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
BrokenClient(const kj::StringPtr description)
: exception(kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
"", 0, kj::str(description)) {}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, sizeHint);
auto root = hook->message.getRoot();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<ClientHook&> getResolved() {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
}
kj::Own<ClientHook> addRef() override {
return kj::addRef(*this);
}
const void* getBrand() override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
return kj::refcounted<BrokenClient>(reason);
}
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Classes for imbuing message readers/builders with a capability context.
//
// These classes are for use by RPC implementations. Application code need not know about them.
//
// Normally, MessageReader and MessageBuilder do not support interface pointers because they
// are not RPC-aware and so have no idea how to convert between a serialized CapabilityDescriptor
// and a live capability. To fix this, a reader/builder object needs to be "imbued" with a
// capability context. This creates a new reader/builder which points at the same object but has
// the ability to deal with interface fields. Use `CapReaderContext` and `CapBuilderContext` to
// accomplish this.
#ifndef CAPNP_CAPABILITY_CONTEXT_H_
#define CAPNP_CAPABILITY_CONTEXT_H_
#include "layout.h"
#include "any.h"
#include "message.h"
#include <kj/mutex.h>
#include <kj/vector.h>
namespace kj { class Exception; }
namespace capnp {
class ClientHook;
// -------------------------------------------------------------------
class CapReaderContext {
// Class which can "imbue" reader objects from some other message with a capability context,
// so that interface pointers found in the message can be extracted and called.
//
// `imbue()` can only be called once per context.
public:
CapReaderContext(kj::Array<kj::Own<ClientHook>>&& capTable);
// `capTable` is the list of capabilities for this message.
~CapReaderContext() noexcept(false);
AnyPointer::Reader imbue(AnyPointer::Reader base);
private:
kj::Maybe<kj::Array<kj::Own<ClientHook>>> capTable; // becomes null once arena is allocated
void* arenaSpace[12 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)];
_::ImbuedReaderArena& arena() { return *reinterpret_cast<_::ImbuedReaderArena*>(arenaSpace); }
friend class _::ImbuedReaderArena;
};
class CapBuilderContext {
// Class which can "imbue" reader objects from some other message with a capability context,
// so that interface pointers found in the message can be set to point at live capabilities.
//
// `imbue()` can only be called once per context.
public:
CapBuilderContext();
~CapBuilderContext() noexcept(false);
AnyPointer::Builder imbue(AnyPointer::Builder base);
kj::ArrayPtr<kj::Own<ClientHook>> getCapTable();
// Return the table of capabilities injected into the message.
private:
bool arenaAllocated = false;
void* arenaSpace[15];
_::ImbuedBuilderArena& arena() { return *reinterpret_cast<_::ImbuedBuilderArena*>(arenaSpace); }
friend class _::ImbuedBuilderArena;
};
// -------------------------------------------------------------------
class LocalMessage final {
// An in-process message which can contain capabilities. Use in place of MallocMessageBuilder
// when you need to be able to construct a message in-memory that contains capabilities, and this
// message will never leave the process. You cannot serialize this message, since it doesn't
// know how to properly serialize its capabilities.
public:
explicit LocalMessage(kj::Maybe<MessageSize> sizeHint = nullptr);
template <typename T, typename = FromReader<T>>
inline LocalMessage(T&& reader): LocalMessage(reader.totalSize()) {
// Create a LocalMessage that is a copy of a given reader.
getRoot().setAs<FromReader<T>>(kj::fwd<T>(reader));
}
inline AnyPointer::Builder getRoot() { return root; }
inline AnyPointer::Reader getRootReader() const { return root.asReader(); }
private:
MallocMessageBuilder message;
CapBuilderContext capContext;
AnyPointer::Builder root;
};
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
} // namespace capnp
#endif // CAPNP_CAPABILITY_CONTEXT_H_
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#define CAPNP_PRIVATE #define CAPNP_PRIVATE
#include "capability.h" #include "capability.h"
#include "capability-context.h"
#include "message.h" #include "message.h"
#include "arena.h" #include "arena.h"
#include <kj/refcount.h> #include <kj/refcount.h>
...@@ -34,6 +33,37 @@ ...@@ -34,6 +33,37 @@
namespace capnp { namespace capnp {
namespace _ {
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory);
// Defined in layout.c++.
} // namespace _
namespace {
class BrokenCapFactoryImpl: public _::BrokenCapFactory {
public:
kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override {
return capnp::newBrokenCap(description);
}
};
static BrokenCapFactoryImpl brokenCapFactory;
} // namespace
ClientHook::ClientHook() {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}
void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable) {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
arena()->initCapTable(kj::mv(capTable));
}
// =======================================================================================
Capability::Client::Client(decltype(nullptr)) Capability::Client::Client(decltype(nullptr))
: hook(newBrokenCap("Called null capability.")) {} : hook(newBrokenCap("Called null capability.")) {}
...@@ -86,24 +116,32 @@ kj::Promise<void> ClientHook::whenResolved() { ...@@ -86,24 +116,32 @@ kj::Promise<void> ClientHook::whenResolved() {
// ======================================================================================= // =======================================================================================
static inline uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
KJ_IF_MAYBE(s, sizeHint) {
return s->wordCount;
} else {
return SUGGESTED_FIRST_SEGMENT_WORDS;
}
}
class LocalResponse final: public ResponseHook, public kj::Refcounted { class LocalResponse final: public ResponseHook, public kj::Refcounted {
public: public:
LocalResponse(kj::Maybe<MessageSize> sizeHint) LocalResponse(kj::Maybe<MessageSize> sizeHint)
: message(sizeHint) {} : message(firstSegmentSize(sizeHint)) {}
LocalMessage message; MallocMessageBuilder message;
}; };
class LocalCallContext final: public CallContextHook, public kj::Refcounted { class LocalCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<ClientHook> clientRef, LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<ClientHook> clientRef,
kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller) kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)), : request(kj::mv(request)), clientRef(kj::mv(clientRef)),
cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {} cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {}
AnyPointer::Reader getParams() override { AnyPointer::Reader getParams() override {
KJ_IF_MAYBE(r, request) { KJ_IF_MAYBE(r, request) {
return r->get()->getRoot(); return r->get()->getRoot<AnyPointer>();
} else { } else {
KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams()."); KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams().");
} }
...@@ -114,7 +152,7 @@ public: ...@@ -114,7 +152,7 @@ public:
AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override { AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
if (response == nullptr) { if (response == nullptr) {
auto localResponse = kj::refcounted<LocalResponse>(sizeHint); auto localResponse = kj::refcounted<LocalResponse>(sizeHint);
responseBuilder = localResponse->message.getRoot(); responseBuilder = localResponse->message.getRoot<AnyPointer>();
response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse)); response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse));
} }
return responseBuilder; return responseBuilder;
...@@ -149,7 +187,7 @@ public: ...@@ -149,7 +187,7 @@ public:
return kj::addRef(*this); return kj::addRef(*this);
} }
kj::Maybe<kj::Own<LocalMessage>> request; kj::Maybe<kj::Own<MallocMessageBuilder>> request;
kj::Maybe<Response<AnyPointer>> response; kj::Maybe<Response<AnyPointer>> response;
AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null
kj::Own<ClientHook> clientRef; kj::Own<ClientHook> clientRef;
...@@ -161,7 +199,7 @@ class LocalRequest final: public RequestHook { ...@@ -161,7 +199,7 @@ class LocalRequest final: public RequestHook {
public: public:
inline LocalRequest(uint64_t interfaceId, uint16_t methodId, inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client) kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client)
: message(kj::heap<LocalMessage>(sizeHint)), : message(kj::heap<MallocMessageBuilder>(firstSegmentSize(sizeHint))),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<AnyPointer> send() override { RemotePromise<AnyPointer> send() override {
...@@ -204,7 +242,7 @@ public: ...@@ -204,7 +242,7 @@ public:
return nullptr; return nullptr;
} }
kj::Own<LocalMessage> message; kj::Own<MallocMessageBuilder> message;
private: private:
uint64_t interfaceId; uint64_t interfaceId;
...@@ -274,7 +312,7 @@ public: ...@@ -274,7 +312,7 @@ public:
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, sizeHint, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
...@@ -424,7 +462,7 @@ public: ...@@ -424,7 +462,7 @@ public:
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, sizeHint, kj::addRef(*this)); interfaceId, methodId, sizeHint, kj::addRef(*this));
auto root = hook->message->getRoot(); // Do not inline `root` -- kj::mv may happen first. auto root = hook->message->getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
...@@ -495,4 +533,97 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro ...@@ -495,4 +533,97 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
return kj::refcounted<QueuedClient>(kj::mv(promise)); return kj::refcounted<QueuedClient>(kj::mv(promise));
} }
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<PipelineHook> addRef() override {
return kj::addRef(*this);
}
kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override;
private:
kj::Exception exception;
};
class BrokenRequest final: public RequestHook {
public:
BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
: exception(exception), message(firstSegmentSize(sizeHint)) {}
RemotePromise<AnyPointer> send() override {
return RemotePromise<AnyPointer>(kj::cp(exception),
AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
const void* getBrand() {
return nullptr;
}
kj::Exception exception;
MallocMessageBuilder message;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
BrokenClient(const kj::StringPtr description)
: exception(kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
"", 0, kj::str(description)) {}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, sizeHint);
auto root = hook->message.getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<ClientHook&> getResolved() {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
}
kj::Own<ClientHook> addRef() override {
return kj::addRef(*this);
}
const void* getBrand() override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
return kj::refcounted<BrokenClient>(reason);
}
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp } // namespace capnp
...@@ -332,6 +332,8 @@ public: ...@@ -332,6 +332,8 @@ public:
class ClientHook { class ClientHook {
public: public:
ClientHook();
virtual Request<AnyPointer, AnyPointer> newCall( virtual Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) = 0; uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) = 0;
// Start a new call, allowing the client to allocate request/response objects as it sees fit. // Start a new call, allowing the client to allocate request/response objects as it sees fit.
...@@ -410,6 +412,13 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro ...@@ -410,6 +412,13 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
// the new client. This hook's `getResolved()` and `whenMoreResolved()` methods will reflect the // the new client. This hook's `getResolved()` and `whenMoreResolved()` methods will reflect the
// redirection to the eventual replacement client. // redirection to the eventual replacement client.
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// ======================================================================================= // =======================================================================================
// Extend PointerHelpers for interfaces // Extend PointerHelpers for interfaces
......
...@@ -79,6 +79,16 @@ class EzRpcClient { ...@@ -79,6 +79,16 @@ class EzRpcClient {
// EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more // EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more
// than one EventLoop.) // than one EventLoop.)
// - These classes only support simple two-party connections, not multilateral VatNetworks. // - These classes only support simple two-party connections, not multilateral VatNetworks.
// - These classes only support communication over a raw, unencrypted socket. If you want to
// build on an abstract stream (perhaps one which supports encryption), you must use the
// lower-level interfaces.
//
// Some of these restrictions will probably be lifted in future versions, but some things will
// always require using the low-level interfaces directly. If you are interested in working
// at a lower level, start by looking at these interfaces:
// - `kj::startAsyncIo()` in `kj/async-io.h`.
// - `RpcSystem` in `capnp/rpc.h`.
// - `TwoPartyVatNetwork` in `capnp/rpc-twoparty.h`.
public: public:
explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0); explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0);
......
...@@ -279,7 +279,7 @@ static void checkStruct(StructReader reader) { ...@@ -279,7 +279,7 @@ static void checkStruct(StructReader reader) {
TEST(WireFormat, StructRoundTrip_OneSegment) { TEST(WireFormat, StructRoundTrip_OneSegment) {
MallocMessageBuilder message; MallocMessageBuilder message;
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
...@@ -316,7 +316,7 @@ TEST(WireFormat, StructRoundTrip_OneSegment) { ...@@ -316,7 +316,7 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE); MallocMessageBuilder message(0, AllocationStrategy::FIXED_SIZE);
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
...@@ -354,7 +354,7 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) { ...@@ -354,7 +354,7 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) { TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE); MallocMessageBuilder message(8, AllocationStrategy::FIXED_SIZE);
BasicBuilderArena arena(&message); BuilderArena arena(&message);
auto allocation = arena.allocate(1 * WORDS); auto allocation = arena.allocate(1 * WORDS);
SegmentBuilder* segment = allocation.segment; SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words; word* rootLocation = allocation.words;
......
...@@ -37,8 +37,8 @@ static BrokenCapFactory* brokenCapFactory = nullptr; ...@@ -37,8 +37,8 @@ static BrokenCapFactory* brokenCapFactory = nullptr;
// but we can't have a link-time dependency on libcapnp-rpc. // but we can't have a link-time dependency on libcapnp-rpc.
void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) { void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) {
// Called from capability-context.c++ when a capability context is created. May be called // Called from capability.c++ when the capability API is used, to make sure that layout.c++
// multiple times but always with the same value. // is ready for it. May be called multiple times but always with the same value.
__atomic_store_n(&brokenCapFactory, &factory, __ATOMIC_RELAXED); __atomic_store_n(&brokenCapFactory, &factory, __ATOMIC_RELAXED);
} }
...@@ -1725,8 +1725,7 @@ struct WireHelpers { ...@@ -1725,8 +1725,7 @@ struct WireHelpers {
setCapabilityPointer(dstSegment, dst, kj::mv(*cap), orphanArena); setCapabilityPointer(dstSegment, dst, kj::mv(*cap), orphanArena);
return { dstSegment, nullptr }; return { dstSegment, nullptr };
} else { } else {
KJ_FAIL_REQUIRE("Message contained capability pointer but is not imbued with a " KJ_FAIL_REQUIRE("Message contained invalid capability pointer.") {
"capability context.") {
goto useDefault; goto useDefault;
} }
} }
...@@ -1846,28 +1845,21 @@ struct WireHelpers { ...@@ -1846,28 +1845,21 @@ struct WireHelpers {
"use the Cap'n Proto RPC system."); "use the Cap'n Proto RPC system.");
if (ref->isNull()) { if (ref->isNull()) {
maybeCap = brokenCapFactory->newBrokenCap("Calling null capability pointer."); return brokenCapFactory->newBrokenCap("Calling null capability pointer.");
} else if (!ref->isCapability()) { } else if (!ref->isCapability()) {
KJ_FAIL_REQUIRE( KJ_FAIL_REQUIRE(
"Message contains non-capability pointer where capability pointer was expected.") { "Message contains non-capability pointer where capability pointer was expected.") {
break; break;
} }
maybeCap = brokenCapFactory->newBrokenCap( return brokenCapFactory->newBrokenCap(
"Calling capability extracted from a non-capability pointer."); "Calling capability extracted from a non-capability pointer.");
} else { } else KJ_IF_MAYBE(cap, segment->getArena()->extractCap(ref->capRef.index.get())) {
maybeCap = segment->getArena()->extractCap(ref->capRef.index.get());
}
KJ_IF_MAYBE(cap, maybeCap) {
return kj::mv(*cap); return kj::mv(*cap);
} else { } else {
// The message is not imbued with a capability context. We can't really recover from this, KJ_FAIL_REQUIRE("Message contains invalid capability pointer.") {
// because we have no way to construct a ClientHook in this case -- capability.c++ may not break;
// even be linked in. Luckily, this is the caller's error, not the message sender's -- }
// it's the message reader who is calling a capability getter on a message they should know return brokenCapFactory->newBrokenCap("Calling invalid capability pointer.");
// they have not imbued properly.
KJ_FAIL_REQUIRE("Tried to read a capability out of a message that doesn't have a "
"capability context.");
} }
} }
...@@ -2216,10 +2208,6 @@ BuilderArena* PointerBuilder::getArena() const { ...@@ -2216,10 +2208,6 @@ BuilderArena* PointerBuilder::getArena() const {
return segment->getArena(); return segment->getArena();
} }
PointerBuilder PointerBuilder::imbue(ImbuedBuilderArena& newArena) const {
return PointerBuilder(newArena.imbue(segment), pointer);
}
// ======================================================================================= // =======================================================================================
// PointerReader // PointerReader
...@@ -2278,10 +2266,6 @@ kj::Maybe<Arena&> PointerReader::getArena() const { ...@@ -2278,10 +2266,6 @@ kj::Maybe<Arena&> PointerReader::getArena() const {
return segment == nullptr ? nullptr : segment->getArena(); return segment == nullptr ? nullptr : segment->getArena();
} }
PointerReader PointerReader::imbue(ImbuedReaderArena& newArena) const {
return PointerReader(newArena.imbue(segment), pointer, nestingLimit);
}
// ======================================================================================= // =======================================================================================
// StructBuilder // StructBuilder
......
...@@ -56,8 +56,6 @@ class SegmentReader; ...@@ -56,8 +56,6 @@ class SegmentReader;
class SegmentBuilder; class SegmentBuilder;
class Arena; class Arena;
class BuilderArena; class BuilderArena;
class ImbuedReaderArena;
class ImbuedBuilderArena;
// ============================================================================= // =============================================================================
...@@ -341,9 +339,6 @@ public: ...@@ -341,9 +339,6 @@ public:
BuilderArena* getArena() const; BuilderArena* getArena() const;
// Get the arena containing this pointer. // Get the arena containing this pointer.
PointerBuilder imbue(ImbuedBuilderArena& newArena) const;
// Imbue the pointer with a capability context, returning the imbued pointer.
private: private:
SegmentBuilder* segment; // Memory segment in which the pointer resides. SegmentBuilder* segment; // Memory segment in which the pointer resides.
WirePointer* pointer; // Pointer to the pointer. WirePointer* pointer; // Pointer to the pointer.
...@@ -392,9 +387,6 @@ public: ...@@ -392,9 +387,6 @@ public:
kj::Maybe<Arena&> getArena() const; kj::Maybe<Arena&> getArena() const;
// Get the arena containing this pointer. // Get the arena containing this pointer.
PointerReader imbue(ImbuedReaderArena& newArena) const;
// Imbue the pointer with a capability context, returning the imbued pointer.
private: private:
SegmentReader* segment; // Memory segment in which the pointer resides. SegmentReader* segment; // Memory segment in which the pointer resides.
const WirePointer* pointer; // Pointer to the pointer. null = treat as null pointer. const WirePointer* pointer; // Pointer to the pointer. null = treat as null pointer.
...@@ -473,10 +465,6 @@ public: ...@@ -473,10 +465,6 @@ public:
BuilderArena* getArena(); BuilderArena* getArena();
// Gets the arena in which this object is allocated. // Gets the arena in which this object is allocated.
void unimbue();
// Removes the capability context from the builder. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentBuilder -- with the non-imbued base segment.
private: private:
SegmentBuilder* segment; // Memory segment in which the struct resides. SegmentBuilder* segment; // Memory segment in which the struct resides.
void* data; // Pointer to the encoded data. void* data; // Pointer to the encoded data.
...@@ -542,10 +530,6 @@ public: ...@@ -542,10 +530,6 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an // use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns. // exception if it overruns.
void unimbue();
// Removes the capability context from the reader. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentReader -- with the non-imbued base segment.
private: private:
SegmentReader* segment; // Memory segment in which the struct resides. SegmentReader* segment; // Memory segment in which the struct resides.
......
...@@ -38,16 +38,16 @@ namespace capnp { ...@@ -38,16 +38,16 @@ namespace capnp {
MessageReader::MessageReader(ReaderOptions options): options(options), allocatedArena(false) {} MessageReader::MessageReader(ReaderOptions options): options(options), allocatedArena(false) {}
MessageReader::~MessageReader() noexcept(false) { MessageReader::~MessageReader() noexcept(false) {
if (allocatedArena) { if (allocatedArena) {
arena()->~BasicReaderArena(); arena()->~ReaderArena();
} }
} }
AnyPointer::Reader MessageReader::getRootInternal() { AnyPointer::Reader MessageReader::getRootInternal() {
if (!allocatedArena) { if (!allocatedArena) {
static_assert(sizeof(_::BasicReaderArena) <= sizeof(arenaSpace), static_assert(sizeof(_::ReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BasicReaderArena. Please increase it. This will break " "arenaSpace is too small to hold a ReaderArena. Please increase it. This will break "
"ABI compatibility."); "ABI compatibility.");
new(arena()) _::BasicReaderArena(this); new(arena()) _::ReaderArena(this);
allocatedArena = true; allocatedArena = true;
} }
...@@ -75,7 +75,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() { ...@@ -75,7 +75,7 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() {
if (allocatedArena) { if (allocatedArena) {
return arena()->getSegment(_::SegmentId(0)); return arena()->getSegment(_::SegmentId(0));
} else { } else {
static_assert(sizeof(_::BasicBuilderArena) <= sizeof(arenaSpace), static_assert(sizeof(_::BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it."); "arenaSpace is too small to hold a BuilderArena. Please increase it.");
kj::ctor(*arena(), this); kj::ctor(*arena(), this);
allocatedArena = true; allocatedArena = true;
...@@ -104,6 +104,14 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutpu ...@@ -104,6 +104,14 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutpu
} }
} }
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> MessageBuilder::getCapTable() {
if (allocatedArena) {
return arena()->getCapTable();
} else {
return nullptr;
}
}
Orphanage MessageBuilder::getOrphanage() { Orphanage MessageBuilder::getOrphanage() {
// We must ensure that the arena and root pointer have been allocated before the Orphanage // We must ensure that the arena and root pointer have been allocated before the Orphanage
// can be used. // can be used.
......
...@@ -34,8 +34,8 @@ ...@@ -34,8 +34,8 @@
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
class BasicReaderArena; class ReaderArena;
class BasicBuilderArena; class BuilderArena;
} }
class StructSchema; class StructSchema;
...@@ -118,6 +118,16 @@ public: ...@@ -118,6 +118,16 @@ public:
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to // RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this. // use this.
void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable);
// Sets the table of capabilities embedded in this message. Capability pointers found in the
// message content contain indexes into this table. You must call this before attempting to
// read any capability pointers (interface pointers) from the message. The table is not passed
// to the constructor because often (as in the RPC system) the cap table is actually constructed
// based on a list read from the message itself.
//
// You must link against libcapnp-rpc to call this method (the rest of MessageBuilder is in
// regular libcapnp).
private: private:
ReaderOptions options; ReaderOptions options;
...@@ -128,7 +138,7 @@ private: ...@@ -128,7 +138,7 @@ private:
void* arenaSpace[15 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)]; void* arenaSpace[15 + sizeof(kj::MutexGuarded<void*>) / sizeof(void*)];
bool allocatedArena; bool allocatedArena;
_::BasicReaderArena* arena() { return reinterpret_cast<_::BasicReaderArena*>(arenaSpace); } _::ReaderArena* arena() { return reinterpret_cast<_::ReaderArena*>(arenaSpace); }
AnyPointer::Reader getRootInternal(); AnyPointer::Reader getRootInternal();
}; };
...@@ -186,11 +196,18 @@ public: ...@@ -186,11 +196,18 @@ public:
// Like setRoot() but adopts the orphan without copying. // Like setRoot() but adopts the orphan without copying.
kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput(); kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput();
// Get the raw data that makes up the message.
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable();
// Get the table of capabilities (interface pointers) that have been added to this message.
// When you later parse this message, you must call `initCapTable()` on the `MessageReader` and
// give it an equivalent set of capabilities, otherwise cap pointers in the message will be
// unusable.
Orphanage getOrphanage(); Orphanage getOrphanage();
private: private:
void* arenaSpace[16]; void* arenaSpace[17];
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here // Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
// because we don't want clients to have to #include arena.h, which itself includes a bunch of // because we don't want clients to have to #include arena.h, which itself includes a bunch of
// big STL headers. We don't use a pointer to a BuilderArena because that would require an // big STL headers. We don't use a pointer to a BuilderArena because that would require an
...@@ -203,7 +220,7 @@ private: ...@@ -203,7 +220,7 @@ private:
// isn't constructed yet. This is kind of annoying because it means that getOrphanage() is // isn't constructed yet. This is kind of annoying because it means that getOrphanage() is
// not thread-safe, but that shouldn't be a huge deal... // not thread-safe, but that shouldn't be a huge deal...
_::BasicBuilderArena* arena() { return reinterpret_cast<_::BasicBuilderArena*>(arenaSpace); } _::BuilderArena* arena() { return reinterpret_cast<_::BuilderArena*>(arenaSpace); }
_::SegmentBuilder* getRootSegment(); _::SegmentBuilder* getRootSegment();
AnyPointer::Builder getRootInternal(); AnyPointer::Builder getRootInternal();
}; };
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc.h" #include "rpc.h"
#include "capability-context.h"
#include "test-util.h" #include "test-util.h"
#include "schema.h" #include "schema.h"
#include "serialize.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/string-tree.h> #include <kj/string-tree.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -83,14 +83,7 @@ public: ...@@ -83,14 +83,7 @@ public:
} }
auto payload = call.getParams(); auto payload = call.getParams();
auto capTableReader = payload.getCapTable(); auto params = kj::str(payload.getContent().getAs<DynamicStruct>(paramType));
auto capTable = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTableReader.size());
for (uint i = 0; i < capTableReader.size(); i++) {
capTable.add(newBrokenCap("fake cap"));
}
CapReaderContext context(capTable.finish());
auto params = kj::str(context.imbue(payload.getContent()).getAs<DynamicStruct>(paramType));
auto sendResultsTo = call.getSendResultsTo(); auto sendResultsTo = call.getSendResultsTo();
...@@ -120,22 +113,14 @@ public: ...@@ -120,22 +113,14 @@ public:
} }
auto payload = ret.getResults(); auto payload = ret.getResults();
auto capTableReader = payload.getCapTable();
auto capTable = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTableReader.size());
for (uint i = 0; i < capTableReader.size(); i++) {
capTable.add(newBrokenCap("fake cap"));
}
CapReaderContext context(capTable.finish());
auto imbued = context.imbue(payload.getContent());
if (schema.getProto().isStruct()) { if (schema.getProto().isStruct()) {
auto results = kj::str(imbued.getAs<DynamicStruct>(schema.asStruct())); auto results = kj::str(payload.getContent().getAs<DynamicStruct>(schema.asStruct()));
return kj::str(senderName, "(", ret.getAnswerId(), "): return ", results, return kj::str(senderName, "(", ret.getAnswerId(), "): return ", results,
" caps:[", kj::strArray(payload.getCapTable(), ", "), "]"); " caps:[", kj::strArray(payload.getCapTable(), ", "), "]");
} else if (schema.getProto().isInterface()) { } else if (schema.getProto().isInterface()) {
imbued.getAs<DynamicCapability>(schema.asInterface()); payload.getContent().getAs<DynamicCapability>(schema.asInterface());
return kj::str(senderName, "(", ret.getAnswerId(), "): return cap ", return kj::str(senderName, "(", ret.getAnswerId(), "): return cap ",
kj::strArray(payload.getCapTable(), ", ")); kj::strArray(payload.getCapTable(), ", "));
} else { } else {
...@@ -246,26 +231,37 @@ public: ...@@ -246,26 +231,37 @@ public:
class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted { class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted {
public: public:
IncomingRpcMessageImpl(uint firstSegmentWordSize) IncomingRpcMessageImpl(kj::Array<word> data)
: message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : data(kj::mv(data)),
: firstSegmentWordSize) {} message(this->data) {}
AnyPointer::Reader getBody() override { AnyPointer::Reader getBody() override {
return message.getRoot<AnyPointer>().asReader(); return message.getRoot<AnyPointer>();
} }
MallocMessageBuilder message; void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) override {
message.initCapTable(kj::mv(capTable));
}
kj::Array<word> data;
FlatArrayMessageReader message;
}; };
class OutgoingRpcMessageImpl final: public OutgoingRpcMessage { class OutgoingRpcMessageImpl final: public OutgoingRpcMessage {
public: public:
OutgoingRpcMessageImpl(ConnectionImpl& connection, uint firstSegmentWordSize) OutgoingRpcMessageImpl(ConnectionImpl& connection, uint firstSegmentWordSize)
: connection(connection), : connection(connection),
message(kj::refcounted<IncomingRpcMessageImpl>(firstSegmentWordSize)) {} message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS
: firstSegmentWordSize) {}
AnyPointer::Builder getBody() override { AnyPointer::Builder getBody() override {
return message->message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
}
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() override {
return message.getCapTable();
} }
void send() override { void send() override {
if (connection.networkException != nullptr) { if (connection.networkException != nullptr) {
return; return;
...@@ -275,11 +271,13 @@ public: ...@@ -275,11 +271,13 @@ public:
// Uncomment to get a debug dump. // Uncomment to get a debug dump.
// kj::String msg = connection.network.network.dumper.dump( // kj::String msg = connection.network.network.dumper.dump(
// message->message.getRoot<rpc::Message>(), connection.sender); // message.getRoot<rpc::Message>(), connection.sender);
// KJ_ DBG(msg); // KJ_ DBG(msg);
auto incomingMessage = kj::heap<IncomingRpcMessageImpl>(messageToFlatArray(message));
auto connectionPtr = &connection; auto connectionPtr = &connection;
connection.tasks->add(kj::evalLater(kj::mvCapture(kj::addRef(*message), connection.tasks->add(kj::evalLater(kj::mvCapture(incomingMessage,
[connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) { [connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
KJ_IF_MAYBE(p, connectionPtr->partner) { KJ_IF_MAYBE(p, connectionPtr->partner) {
if (p->fulfillers.empty()) { if (p->fulfillers.empty()) {
...@@ -296,7 +294,7 @@ public: ...@@ -296,7 +294,7 @@ public:
private: private:
ConnectionImpl& connection; ConnectionImpl& connection;
kj::Own<IncomingRpcMessageImpl> message; MallocMessageBuilder message;
}; };
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override { kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "capability-context.h"
#include "test-util.h" #include "test-util.h"
#include <kj/async-unix.h> #include <kj/async-unix.h>
#include <kj/debug.h> #include <kj/debug.h>
......
...@@ -76,6 +76,10 @@ public: ...@@ -76,6 +76,10 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() override {
return message.getCapTable();
}
void send() override { void send() override {
network.previousWrite = network.previousWrite.then([&]() { network.previousWrite = network.previousWrite.then([&]() {
auto promise = writeMessage(network.stream, message).then([]() { auto promise = writeMessage(network.stream, message).then([]() {
...@@ -101,6 +105,10 @@ public: ...@@ -101,6 +105,10 @@ public:
return message->getRoot<AnyPointer>(); return message->getRoot<AnyPointer>();
} }
void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) override {
message->initCapTable(kj::mv(capTable));
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
}; };
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rpc.h" #include "rpc.h"
#include "capability-context.h" #include "message.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/vector.h> #include <kj/vector.h>
#include <kj/async.h> #include <kj/async.h>
...@@ -910,13 +910,17 @@ private: ...@@ -910,13 +910,17 @@ private:
} }
} }
kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Own<ClientHook>> capTable, kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> capTable,
rpc::Payload::Builder payload) { rpc::Payload::Builder payload) {
auto capTableBuilder = payload.initCapTable(capTable.size()); auto capTableBuilder = payload.initCapTable(capTable.size());
kj::Vector<ExportId> exports(capTable.size()); kj::Vector<ExportId> exports(capTable.size());
for (uint i: kj::indices(capTable)) { for (uint i: kj::indices(capTable)) {
KJ_IF_MAYBE(exportId, writeDescriptor(*capTable[i], capTableBuilder[i])) { KJ_IF_MAYBE(cap, capTable[i]) {
exports.add(*exportId); KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i])) {
exports.add(*exportId);
}
} else {
capTableBuilder[i].setNone();
} }
} }
return exports.releaseAsArray(); return exports.releaseAsArray();
...@@ -1070,10 +1074,10 @@ private: ...@@ -1070,10 +1074,10 @@ private:
} }
} }
kj::Own<ClientHook> receiveCap(rpc::CapDescriptor::Reader descriptor) { kj::Maybe<kj::Own<ClientHook>> receiveCap(rpc::CapDescriptor::Reader descriptor) {
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::NONE: case rpc::CapDescriptor::NONE:
return newBrokenCap("Called a `CapDescriptor.none`."); return nullptr;
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
return import(descriptor.getSenderHosted(), false); return import(descriptor.getSenderHosted(), false);
...@@ -1115,8 +1119,8 @@ private: ...@@ -1115,8 +1119,8 @@ private:
} }
} }
kj::Array<kj::Own<ClientHook>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable) { kj::Array<kj::Maybe<kj::Own<ClientHook>>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable) {
auto result = kj::heapArrayBuilder<kj::Own<ClientHook>>(capTable.size()); auto result = kj::heapArrayBuilder<kj::Maybe<kj::Own<ClientHook>>>(capTable.size());
for (auto cap: capTable) { for (auto cap: capTable) {
result.add(receiveCap(cap)); result.add(receiveCap(cap));
} }
...@@ -1194,7 +1198,7 @@ private: ...@@ -1194,7 +1198,7 @@ private:
firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() + firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() +
sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))), sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))),
callBuilder(message->getBody().getAs<rpc::Message>().initCall()), callBuilder(message->getBody().getAs<rpc::Message>().initCall()),
paramsBuilder(context.imbue(callBuilder.getParams().getContent())) {} paramsBuilder(callBuilder.getParams().getContent()) {}
inline AnyPointer::Builder getRoot() { inline AnyPointer::Builder getRoot() {
return paramsBuilder; return paramsBuilder;
...@@ -1288,7 +1292,6 @@ private: ...@@ -1288,7 +1292,6 @@ private:
kj::Own<RpcClient> target; kj::Own<RpcClient> target;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
CapBuilderContext context;
rpc::Call::Builder callBuilder; rpc::Call::Builder callBuilder;
AnyPointer::Builder paramsBuilder; AnyPointer::Builder paramsBuilder;
...@@ -1300,7 +1303,7 @@ private: ...@@ -1300,7 +1303,7 @@ private:
SendInternalResult sendInternal(bool isTailCall) { SendInternalResult sendInternal(bool isTailCall) {
// Build the cap table. // Build the cap table.
auto exports = connectionState->writeDescriptors( auto exports = connectionState->writeDescriptors(
context.getCapTable(), callBuilder.getParams()); message->getCapTable(), callBuilder.getParams());
// Init the question table. Do this after writing descriptors to avoid interference. // Init the question table. Do this after writing descriptors to avoid interference.
QuestionId questionId; QuestionId questionId;
...@@ -1432,12 +1435,10 @@ private: ...@@ -1432,12 +1435,10 @@ private:
RpcResponseImpl(RpcConnectionState& connectionState, RpcResponseImpl(RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef, kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message, kj::Own<IncomingRpcMessage>&& message,
AnyPointer::Reader results, AnyPointer::Reader results)
kj::Array<kj::Own<ClientHook>>&& capTable)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
message(kj::mv(message)), message(kj::mv(message)),
context(kj::mv(capTable)), reader(results),
reader(context.imbue(results)),
questionRef(kj::mv(questionRef)) {} questionRef(kj::mv(questionRef)) {}
AnyPointer::Reader getResults() override { AnyPointer::Reader getResults() override {
...@@ -1451,7 +1452,6 @@ private: ...@@ -1451,7 +1452,6 @@ private:
private: private:
kj::Own<RpcConnectionState> connectionState; kj::Own<RpcConnectionState> connectionState;
kj::Own<IncomingRpcMessage> message; kj::Own<IncomingRpcMessage> message;
CapReaderContext context;
AnyPointer::Reader reader; AnyPointer::Reader reader;
kj::Own<QuestionRef> questionRef; kj::Own<QuestionRef> questionRef;
}; };
...@@ -1471,11 +1471,10 @@ private: ...@@ -1471,11 +1471,10 @@ private:
rpc::Payload::Builder payload) rpc::Payload::Builder payload)
: connectionState(connectionState), : connectionState(connectionState),
message(kj::mv(message)), message(kj::mv(message)),
payload(payload), payload(payload) {}
builder(context.imbue(payload.getContent())) {}
AnyPointer::Builder getResultsBuilder() override { AnyPointer::Builder getResultsBuilder() override {
return builder; return payload.getContent();
} }
kj::Maybe<kj::Array<ExportId>> send() { kj::Maybe<kj::Array<ExportId>> send() {
...@@ -1483,7 +1482,7 @@ private: ...@@ -1483,7 +1482,7 @@ private:
// (Could return a non-null empty array if there were caps but none of them were exports.) // (Could return a non-null empty array if there were caps but none of them were exports.)
// Build the cap table. // Build the cap table.
auto capTable = context.getCapTable(); auto capTable = message->getCapTable();
auto exports = connectionState.writeDescriptors(capTable, payload); auto exports = connectionState.writeDescriptors(capTable, payload);
message->send(); message->send();
...@@ -1497,23 +1496,22 @@ private: ...@@ -1497,23 +1496,22 @@ private:
private: private:
RpcConnectionState& connectionState; RpcConnectionState& connectionState;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
CapBuilderContext context;
rpc::Payload::Builder payload; rpc::Payload::Builder payload;
AnyPointer::Builder builder;
}; };
class LocallyRedirectedRpcResponse final class LocallyRedirectedRpcResponse final
: public RpcResponse, public RpcServerResponse, public kj::Refcounted{ : public RpcResponse, public RpcServerResponse, public kj::Refcounted{
public: public:
LocallyRedirectedRpcResponse(kj::Maybe<MessageSize> sizeHint) LocallyRedirectedRpcResponse(kj::Maybe<MessageSize> sizeHint)
: message(sizeHint) {} : message(sizeHint.map([](MessageSize size) { return size.wordCount; })
.orDefault(SUGGESTED_FIRST_SEGMENT_WORDS)) {}
AnyPointer::Builder getResultsBuilder() override { AnyPointer::Builder getResultsBuilder() override {
return message.getRoot(); return message.getRoot<AnyPointer>();
} }
AnyPointer::Reader getResults() override { AnyPointer::Reader getResults() override {
return message.getRootReader(); return message.getRoot<AnyPointer>();
} }
kj::Own<RpcResponse> addRef() override { kj::Own<RpcResponse> addRef() override {
...@@ -1521,20 +1519,18 @@ private: ...@@ -1521,20 +1519,18 @@ private:
} }
private: private:
LocalMessage message; MallocMessageBuilder message;
}; };
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId, RpcCallContext(RpcConnectionState& connectionState, AnswerId answerId,
kj::Own<IncomingRpcMessage>&& request, const AnyPointer::Reader& params, kj::Own<IncomingRpcMessage>&& request, const AnyPointer::Reader& params,
kj::Array<kj::Own<ClientHook>>&& requestCapTable, bool redirectResults, bool redirectResults, kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
answerId(answerId), answerId(answerId),
request(kj::mv(request)), request(kj::mv(request)),
requestCapContext(kj::mv(requestCapTable)), params(params),
params(requestCapContext.imbue(params)),
returnMessage(nullptr), returnMessage(nullptr),
redirectResults(redirectResults), redirectResults(redirectResults),
cancelFulfiller(kj::mv(cancelFulfiller)) {} cancelFulfiller(kj::mv(cancelFulfiller)) {}
...@@ -1740,7 +1736,6 @@ private: ...@@ -1740,7 +1736,6 @@ private:
// Request --------------------------------------------- // Request ---------------------------------------------
kj::Maybe<kj::Own<IncomingRpcMessage>> request; kj::Maybe<kj::Own<IncomingRpcMessage>> request;
CapReaderContext requestCapContext;
AnyPointer::Reader params; AnyPointer::Reader params;
// Response -------------------------------------------- // Response --------------------------------------------
...@@ -1938,14 +1933,14 @@ private: ...@@ -1938,14 +1933,14 @@ private:
} }
auto payload = call.getParams(); auto payload = call.getParams();
auto capTable = receiveCaps(payload.getCapTable()); message->initCapTable(receiveCaps(payload.getCapTable()));
auto cancelPaf = kj::newPromiseAndFulfiller<void>(); auto cancelPaf = kj::newPromiseAndFulfiller<void>();
AnswerId answerId = call.getQuestionId(); AnswerId answerId = call.getQuestionId();
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, answerId, kj::mv(message), payload.getContent(), *this, answerId, kj::mv(message), payload.getContent(),
kj::mv(capTable), redirectResults, kj::mv(cancelPaf.fulfiller)); redirectResults, kj::mv(cancelPaf.fulfiller));
// No more using `call` after this point, as it now belongs to the context. // No more using `call` after this point, as it now belongs to the context.
...@@ -2080,9 +2075,9 @@ private: ...@@ -2080,9 +2075,9 @@ private:
} }
auto payload = ret.getResults(); auto payload = ret.getResults();
message->initCapTable(receiveCaps(payload.getCapTable()));
questionRef->fulfill(kj::refcounted<RpcResponseImpl>( questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::addRef(*questionRef), kj::mv(message), payload.getContent(), *this, kj::addRef(*questionRef), kj::mv(message), payload.getContent()));
receiveCaps(payload.getCapTable())));
break; break;
} }
...@@ -2184,7 +2179,11 @@ private: ...@@ -2184,7 +2179,11 @@ private:
// Extract the replacement capability. // Extract the replacement capability.
switch (resolve.which()) { switch (resolve.which()) {
case rpc::Resolve::CAP: case rpc::Resolve::CAP:
replacement = receiveCap(resolve.getCap()); KJ_IF_MAYBE(cap, receiveCap(resolve.getCap())) {
replacement = kj::mv(*cap);
} else {
KJ_FAIL_REQUIRE("'Resolve' contained 'CapDescriptor.none'.") { return; }
}
break; break;
case rpc::Resolve::EXCEPTION: case rpc::Resolve::EXCEPTION:
...@@ -2347,8 +2346,6 @@ private: ...@@ -2347,8 +2346,6 @@ private:
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn(); rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
ret.setAnswerId(answerId); ret.setAnswerId(answerId);
CapBuilderContext context;
kj::Own<ClientHook> capHook; kj::Own<ClientHook> capHook;
kj::Array<ExportId> resultExports; kj::Array<ExportId> resultExports;
KJ_DEFER(releaseExports(resultExports)); // in case something goes wrong KJ_DEFER(releaseExports(resultExports)); // in case something goes wrong
...@@ -2358,13 +2355,12 @@ private: ...@@ -2358,13 +2355,12 @@ private:
KJ_IF_MAYBE(r, restorer) { KJ_IF_MAYBE(r, restorer) {
Capability::Client cap = r->baseRestore(restore.getObjectId()); Capability::Client cap = r->baseRestore(restore.getObjectId());
auto payload = ret.initResults(); auto payload = ret.initResults();
auto results = context.imbue(payload.getContent()); payload.getContent().setAs<Capability>(kj::mv(cap));
results.setAs<Capability>(cap);
auto capTable = context.getCapTable(); auto capTable = response->getCapTable();
KJ_DASSERT(capTable.size() == 1); KJ_DASSERT(capTable.size() == 1);
resultExports = writeDescriptors(capTable, payload); resultExports = writeDescriptors(capTable, payload);
capHook = capTable[0]->addRef(); capHook = KJ_ASSERT_NONNULL(capTable[0])->addRef();
} else { } else {
KJ_FAIL_REQUIRE("This vat cannot restore this SturdyRef.") { break; } KJ_FAIL_REQUIRE("This vat cannot restore this SturdyRef.") { break; }
} }
......
...@@ -126,6 +126,9 @@ public: ...@@ -126,6 +126,9 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC // Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.) // implementation initializes it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() = 0;
// Calls getCapTable() on the underlying MessageBuilder.
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
...@@ -138,6 +141,9 @@ public: ...@@ -138,6 +141,9 @@ public:
virtual AnyPointer::Reader getBody() = 0; virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation // Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.) // interprets it as a Message as defined in rpc.capnp.)
virtual void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) = 0;
// Calls initCapTable() on the underlying MessageReader.
}; };
template <typename SturdyRefHostId, typename ProvisionId, typename RecipientId, template <typename SturdyRefHostId, typename ProvisionId, typename RecipientId,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment