Commit 7298c7fa authored by Kenton Varda's avatar Kenton Varda

More RPC protocol WIP.

parent 41feb767
...@@ -91,11 +91,11 @@ class Response: public Results::Reader { ...@@ -91,11 +91,11 @@ class Response: public Results::Reader {
// is move-only -- once it goes out-of-scope, the underlying message will be freed. // is move-only -- once it goes out-of-scope, the underlying message will be freed.
public: public:
inline Response(typename Results::Reader reader, kj::Own<ResponseHook>&& hook) inline Response(typename Results::Reader reader, kj::Own<const ResponseHook>&& hook)
: Results::Reader(reader), hook(kj::mv(hook)) {} : Results::Reader(reader), hook(kj::mv(hook)) {}
private: private:
kj::Own<ResponseHook> hook; kj::Own<const ResponseHook> hook;
template <typename, typename> template <typename, typename>
friend class Request; friend class Request;
...@@ -347,6 +347,7 @@ public: ...@@ -347,6 +347,7 @@ public:
}; };
kj::Own<const ClientHook> newBrokenCap(const char* reason); kj::Own<const ClientHook> newBrokenCap(const char* reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called. // Helper function that creates a capability which simply throws exceptions when called.
// ======================================================================================= // =======================================================================================
......
...@@ -494,6 +494,8 @@ public: ...@@ -494,6 +494,8 @@ public:
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0), : segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {} pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
const void* getLocation() const { return data; }
inline BitCount getDataSectionSize() const { return dataSize; } inline BitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; } inline WirePointerCount getPointerSectionSize() const { return pointerCount; }
inline Data::Reader getDataSectionAsBlob(); inline Data::Reader getDataSectionAsBlob();
......
...@@ -27,7 +27,9 @@ ...@@ -27,7 +27,9 @@
#include <kj/vector.h> #include <kj/vector.h>
#include <kj/async.h> #include <kj/async.h>
#include <kj/one-of.h> #include <kj/one-of.h>
#include <kj/function.h>
#include <unordered_map> #include <unordered_map>
#include <map>
#include <queue> #include <queue>
#include <capnp/rpc.capnp.h> #include <capnp/rpc.capnp.h>
...@@ -68,7 +70,7 @@ kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Re ...@@ -68,7 +70,7 @@ kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Re
} }
Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps( Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps(
Orphanage orphanage, const kj::Array<PipelineOp>& ops) { Orphanage orphanage, kj::ArrayPtr<const PipelineOp> ops) {
auto result = orphanage.newOrphan<List<rpc::PromisedAnswer::Op>>(ops.size()); auto result = orphanage.newOrphan<List<rpc::PromisedAnswer::Op>>(ops.size());
auto builder = result.get(); auto builder = result.get();
for (uint i: kj::indices(ops)) { for (uint i: kj::indices(ops)) {
...@@ -186,7 +188,7 @@ public: ...@@ -186,7 +188,7 @@ public:
return low[id]; return low[id];
} else { } else {
auto iter = high.find(id); auto iter = high.find(id);
if (iter == nullptr) { if (iter == high.end()) {
return nullptr; return nullptr;
} else { } else {
return iter->second; return iter->second;
...@@ -207,28 +209,29 @@ private: ...@@ -207,28 +209,29 @@ private:
std::unordered_map<Id, T> high; std::unordered_map<Id, T> high;
}; };
template <typename ParamCaps, typename RpcPipeline> template <typename ParamCaps, typename RpcPipeline, typename RpcResponse>
struct Question { struct Question {
kj::Own<ParamCaps> paramCaps; kj::Own<kj::PromiseFulfiller<kj::Own<RpcResponse>>> fulfiller;
// A handle representing the capabilities in the parameter struct. This will be dropped as soon
// as the call returns.
kj::Own<kj::PromiseFulfiller<Response<ObjectPointer>>> fulfiller;
// Fulfill with the response. // Fulfill with the response.
kj::Maybe<kj::Own<ParamCaps>> paramCaps;
// CapInjector from the parameter struct. This will be released once the `Return` message is
// received and `retainedCaps` processed. (If this is non-null, then the call has not returned
// yet.)
kj::Maybe<RpcPipeline&> pipeline; kj::Maybe<RpcPipeline&> pipeline;
// The local pipeline object. The RpcPipeline's own destructor sets this value to null and then // The local pipeline object. The RpcPipeline's own destructor sets this value to null.
// sends the Finish message.
// //
// TODO(cleanup): We only have this pointer here because CapInjectorImpl::getInjectedCap() needs // TODO(cleanup): We only have this pointer here because CapInjectorImpl::getInjectedCap() needs
// it, but perhaps CapInjectorImpl should instead hold on to the ClientHook it got in the first // it, but perhaps CapInjectorImpl should instead hold on to the ClientHook it got in the first
// place. // place.
bool isStarted = false; bool isStarted = false;
// Is this Question currently in-use? // Is this Question ID currently in-use? (This is true until both `Return` has been received and
// `Finish` has been sent.)
bool isReturned = false; bool isFinished = false;
// Has the call returned? // Has the `Finish` message been sent?
inline bool operator==(decltype(nullptr)) const { return !isStarted; } inline bool operator==(decltype(nullptr)) const { return !isStarted; }
inline bool operator!=(decltype(nullptr)) const { return isStarted; } inline bool operator!=(decltype(nullptr)) const { return isStarted; }
...@@ -261,6 +264,13 @@ struct Export { ...@@ -261,6 +264,13 @@ struct Export {
inline bool operator!=(decltype(nullptr)) const { return refcount != 0; } inline bool operator!=(decltype(nullptr)) const { return refcount != 0; }
}; };
template <typename ImportClient>
struct Import {
ImportClient* client = nullptr;
// Normally I'd want this to be Maybe<ImportClient&>, but GCC's unordered_map doesn't seem to
// like DisableConstCopy types.
};
// ======================================================================================= // =======================================================================================
class RpcConnectionState: public kj::TaskSet::ErrorHandler { class RpcConnectionState: public kj::TaskSet::ErrorHandler {
...@@ -285,12 +295,13 @@ private: ...@@ -285,12 +295,13 @@ private:
class CapExtractorImpl; class CapExtractorImpl;
class RpcPipeline; class RpcPipeline;
class RpcCallContext; class RpcCallContext;
class RpcResponse;
struct Tables { struct Tables {
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline>> questions; ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline, RpcResponse>> questions;
ImportTable<QuestionId, Answer<RpcCallContext>> answers; ImportTable<QuestionId, Answer<RpcCallContext>> answers;
ExportTable<ExportId, Export> exports; ExportTable<ExportId, Export> exports;
ImportTable<ExportId, kj::Maybe<ImportClient&>> imports; ImportTable<ExportId, Import<ImportClient>> imports;
}; };
kj::MutexGuarded<Tables> tables; kj::MutexGuarded<Tables> tables;
...@@ -331,13 +342,20 @@ private: ...@@ -331,13 +342,20 @@ private:
RpcClient(const RpcConnectionState& connectionState) RpcClient(const RpcConnectionState& connectionState)
: connectionState(connectionState) {} : connectionState(connectionState) {}
virtual kj::Own<const kj::Refcounted> writeDescriptor( virtual void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const = 0;
rpc::CapDescriptor::Builder descriptor) const = 0; // Writes a CapDescriptor referencing this client. Must be called with the
// Writes a CapDescriptor referencing this client. Returns a reference to some object which // RpcConnectionState's table locked -- a reference to them is passed as the second argument.
// must be held at least until the message containing `descriptor` has been sent. // The CapDescriptor must be sent before unlocking the tables, as it may become invalid at
// any time once the tables are unlocked.
virtual kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const = 0;
// Writes the appropriate call target for calls to this capability and returns null.
// //
// TODO(cleanup): Specialize Own<void> so that we can return it here instead of // - OR -
// Own<Refcounted>. //
// If calls have been redirected to some other local ClientHook, returns that hook instead.
// This can happen if the capability represents a promise that has been resolved.
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
...@@ -349,7 +367,7 @@ private: ...@@ -349,7 +367,7 @@ private:
// TODO(perf): Extend targetSizeInWords() to include a capability count? Here we increase // TODO(perf): Extend targetSizeInWords() to include a capability count? Here we increase
// the size by 1/16 to deal with cap descriptors possibly expanding. See also below, when // the size by 1/16 to deal with cap descriptors possibly expanding. See also below, when
// handling the response. // handling the response, and in RpcRequest::send().
sizeHint += sizeHint / 16; sizeHint += sizeHint / 16;
// Don't overflow. // Don't overflow.
...@@ -408,9 +426,9 @@ private: ...@@ -408,9 +426,9 @@ private:
// which case that other thread will have constructed a new ImportClient and placed it in // which case that other thread will have constructed a new ImportClient and placed it in
// the import table.) // the import table.)
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(ptr, lock->imports[importId]) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (ptr == this) { if (import->client == this) {
lock->imports[importId] = nullptr; lock->imports.erase(importId);
} }
} }
} }
...@@ -438,22 +456,25 @@ private: ...@@ -438,22 +456,25 @@ private:
} }
} }
kj::Own<const kj::Refcounted> writeDescriptor( void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
rpc::CapDescriptor::Builder descriptor) const override {
descriptor.setReceiverHosted(importId); descriptor.setReceiverHosted(importId);
return kj::addRef(*this); }
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
target.setExportedCap(importId);
return nullptr;
} }
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize); auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall(); auto callBuilder = request->getCall();
callBuilder.getTarget().setExportedCap(importId); callBuilder.getTarget().setExportedCap(importId);
callBuilder.setInterfaceId(interfaceId); callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId); callBuilder.setMethodId(methodId);
request->holdRef(writeTarget(callBuilder.getTarget()));
auto root = request->getRoot(); auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request)); return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
...@@ -495,6 +516,9 @@ private: ...@@ -495,6 +516,9 @@ private:
return true; return true;
} }
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves.
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
// We need the returned promise to hold a reference back to this object, so that it doesn't // We need the returned promise to hold a reference back to this object, so that it doesn't
// disappear while the promise is still outstanding. // disappear while the promise is still outstanding.
...@@ -511,50 +535,117 @@ private: ...@@ -511,50 +535,117 @@ private:
class PromisedAnswerClient final: public RpcClient { class PromisedAnswerClient final: public RpcClient {
public: public:
PromisedAnswerClient(const RpcConnectionState& connectionState, QuestionId questionId, PromisedAnswerClient(const RpcConnectionState& connectionState,
kj::Array<PipelineOp>&& ops, kj::Own<const RpcPipeline> pipeline); kj::Own<const RpcPipeline>&& pipeline,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), ops(kj::mv(ops)),
resolveSelfPromise(pipeline->onResponse().then(
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution.
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception));
return kj::READY_NOW; // hack to force eager resolution.
})) {
state.getWithoutLock().init<Waiting>(kj::mv(pipeline));
}
kj::Own<const kj::Refcounted> writeDescriptor(rpc::CapDescriptor::Builder descriptor) const override { void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto lock = state.lockShared(); auto lock = state.lockShared();
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
auto promisedAnswer = descriptor.initReceiverAnswer(); return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
promisedAnswer.setQuestionId(questionId); } else if (lock->is<Resolved>()) {
promisedAnswer.adoptTransform(fromPipelineOps( return connectionState.writeDescriptor(lock->get<Resolved>()->addRef(), descriptor, tables);
Orphanage::getForMessageContaining(descriptor), ops)); } else {
// TODO(now)
}
}
// Return a ref to the RpcPipeline to ensure that we don't send a Finish message for this kj::Maybe<kj::Own<const ClientHook>> writeTarget(
// call before the message containing this CapDescriptor is sent. rpc::Call::Target::Builder target) const override {
return kj::addRef(*lock->get<Waiting>()); auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) {
return connectionState.writeTarget(*lock->get<Resolved>(), target);
} else { } else {
// TODO(now): Problem: This won't necessarily be a remote cap! return newBrokenCap(kj::cp(lock->get<Broken>()));
return connectionState.writeDescriptor(
lock->get<Resolved>().getPipelinedCap(ops), descriptor);
} }
} }
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override; uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override; auto lock = state.lockShared();
if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>(
connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()))->newCall(
interfaceId, methodId, firstSegmentWordSize);
}
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->onResponse().then(kj::mvCapture(kj::heapArray(ops.asPtr()),
[](kj::Array<PipelineOp>&& ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
} else if (lock->is<Resolved>()) {
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(lock->get<Broken>()));
}
}
private: private:
const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Array<PipelineOp> ops; kj::Array<PipelineOp> ops;
typedef kj::Own<const RpcPipeline> Waiting; typedef kj::Own<const RpcPipeline> Waiting;
typedef Response<ObjectPointer> Resolved; typedef kj::Own<const ClientHook> Resolved;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved>> state; typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Resolved>(response->getResults().getPipelinedCap(ops));
}
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
}
}; };
kj::Own<const kj::Refcounted> writeDescriptor( void writeDescriptor(kj::Own<const ClientHook> cap, rpc::CapDescriptor::Builder descriptor,
kj::Own<const ClientHook>&& cap, rpc::CapDescriptor::Builder descriptor) const { Tables& tables) const {
// Write a descriptor for the given capability. Returns a reference to something which must // Write a descriptor for the given capability. The tables must be locked by the caller and
// be held at least until the message containing the descriptor is sent. // passed in as a parameter.
if (cap->getBrand() == this) { if (cap->getBrand() == this) {
return kj::downcast<const RpcClient>(*cap).writeDescriptor(descriptor); kj::downcast<const RpcClient>(*cap).writeDescriptor(descriptor, tables);
} else { } else {
// TODO(now): We have to figure out if the client is already in our table. // TODO(now): We have to figure out if the client is already in our table.
// TODO(now): We have to add a refcount to the export, and return an object that decrements // TODO(now): We have to add a refcount to the export, and return an object that decrements
...@@ -562,6 +653,25 @@ private: ...@@ -562,6 +653,25 @@ private:
} }
} }
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
const ClientHook& cap, rpc::Call::Target::Builder target) const {
// If calls to the given capability should pass over this connection, fill in `target`
// appropriately for such a call and return nullptr. Otherwise, return a `ClientHook` to which
// the call should be forwarded; the caller should then delegate the call to that `ClientHook`.
//
// The main case where this ends up returning non-null is if `cap` is a promise that has
// recently resolved. The application might have started building a request before the promise
// resolved, and so the request may have been built on the assumption that it would be sent over
// this network connection, but then the promise resolved to point somewhere else before the
// request was sent. Now the request has to be redirected to the new target instead.
if (cap.getBrand() == this) {
return kj::downcast<const RpcClient>(cap).writeTarget(target);
} else {
return cap.addRef();
}
}
// ===================================================================================== // =====================================================================================
// CapExtractor / CapInjector implementations // CapExtractor / CapInjector implementations
...@@ -600,8 +710,8 @@ private: ...@@ -600,8 +710,8 @@ private:
auto actualRetained = retainedCaps.begin(); auto actualRetained = retainedCaps.begin();
for (ExportId importId: retainedCaps) { for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID. // Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports[importId]) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->tryAddRemoteRef() != nullptr) { if (import->client != nullptr && import->client->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it. // Import indeed still exists! We are responsible for retaining it.
*actualRetained++ = importId; *actualRetained++ = importId;
} }
...@@ -631,10 +741,11 @@ private: ...@@ -631,10 +741,11 @@ private:
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[importId]) { auto& import = lock->imports[importId];
if (import.client != nullptr) {
// The import is already on the table, but it could be being deleted in another // The import is already on the table, but it could be being deleted in another
// thread. // thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) { KJ_IF_MAYBE(ref, kj::tryAddRef(*import.client)) {
// We successfully grabbed a reference to the import without it being deleted in // We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take // another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and // responsibility for retaining it. We can just return the existing object and
...@@ -652,7 +763,7 @@ private: ...@@ -652,7 +763,7 @@ private:
} else { } else {
kj::refcounted<SettledImportClient>(connectionState, importId); kj::refcounted<SettledImportClient>(connectionState, importId);
} }
lock->imports[importId] = *result; import.client = result;
// Note that we need to retain this import later if it still exists. // Note that we need to retain this import later if it still exists.
retainedCaps.lockExclusive()->add(importId); retainedCaps.lockExclusive()->add(importId);
...@@ -699,14 +810,33 @@ private: ...@@ -699,14 +810,33 @@ private:
: connectionState(connectionState) {} : connectionState(connectionState) {}
~CapInjectorImpl() {} ~CapInjectorImpl() {}
void finish(Tables& tables) {
// Finish writing all of the CapDescriptors. Must be called with the tables locked, and the
// message must be sent before the tables are unlocked.
for (auto& entry: caps.getWithoutLock()) {
connectionState.writeDescriptor(kj::mv(entry.second.cap), entry.second.builder, tables);
}
}
// implements CapInjector ---------------------------------------- // implements CapInjector ----------------------------------------
void injectCap(rpc::CapDescriptor::Builder descriptor, void injectCap(rpc::CapDescriptor::Builder descriptor,
kj::Own<const ClientHook>&& cap) const override { kj::Own<const ClientHook>&& cap) const override {
auto ref = connectionState.writeDescriptor(kj::mv(cap), descriptor); auto lock = caps.lockExclusive();
refs.lockExclusive()->add(kj::mv(ref)); auto result = lock->insert(std::make_pair(
identity(descriptor), CapInfo(descriptor, kj::mv(cap))));
KJ_REQUIRE(result.second, "A cap has already been injected at this location.") {
result.first->second.cap = kj::mv(cap);
break;
}
} }
kj::Own<const ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) const override { kj::Own<const ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) const override {
auto lock = caps.lockExclusive();
auto iter = lock->find(identity(descriptor));
KJ_REQUIRE(iter != lock->end(), "getInjectedCap() called on descriptor I didn't write.");
return iter->second.cap->addRef();
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: { case rpc::CapDescriptor::SENDER_HOSTED: {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
...@@ -722,9 +852,11 @@ private: ...@@ -722,9 +852,11 @@ private:
case rpc::CapDescriptor::RECEIVER_HOSTED: { case rpc::CapDescriptor::RECEIVER_HOSTED: {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[descriptor.getReceiverHosted()]) { KJ_IF_MAYBE(import, lock->imports.find(descriptor.getReceiverHosted())) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) { if (import->client != nullptr) {
return kj::mv(*ref); KJ_IF_MAYBE(ref, kj::tryAddRef(*import->client)) {
return kj::mv(*ref);
}
} }
} }
...@@ -741,9 +873,7 @@ private: ...@@ -741,9 +873,7 @@ private:
KJ_IF_MAYBE(question, lock->questions.find(promisedAnswer.getQuestionId())) { KJ_IF_MAYBE(question, lock->questions.find(promisedAnswer.getQuestionId())) {
KJ_IF_MAYBE(pipeline, question->pipeline) { KJ_IF_MAYBE(pipeline, question->pipeline) {
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
return kj::refcounted<PromisedAnswerClient>( return pipeline->getPipelinedCap(kj::mv(*ops));
connectionState, promisedAnswer.getQuestionId(),
kj::mv(*ops), kj::addRef(*pipeline));
} }
} }
} }
...@@ -758,24 +888,86 @@ private: ...@@ -758,24 +888,86 @@ private:
} }
} }
void dropCap(rpc::CapDescriptor::Reader descriptor) const override { void dropCap(rpc::CapDescriptor::Reader descriptor) const override {
// TODO(someday): We could implement this by maintaining a map from CapDescriptors to caps.lockExclusive()->erase(identity(descriptor));
// the corresponding refs, but is it worth it?
} }
private: private:
const RpcConnectionState& connectionState; const RpcConnectionState& connectionState;
kj::MutexGuarded<kj::Vector<kj::Own<const kj::Refcounted>>> refs; struct CapInfo {
// List of references that need to be held until the message is destroyed. rpc::CapDescriptor::Builder builder;
kj::Own<const ClientHook> cap;
CapInfo(rpc::CapDescriptor::Builder& builder, kj::Own<const ClientHook>&& cap)
: builder(builder), cap(kj::mv(cap)) {}
CapInfo(const CapInfo& other);
// Work around problem where std::pair complains about the copy constructor requiring a
// non-const argument due to `builder` inheriting kj::DisableConstCopy. The copy constructor
// should never be called anyway.
};
kj::MutexGuarded<std::map<const void*, CapInfo>> caps;
// Maps CapDescriptor locations to embedded caps. The descriptors aren't actually filled in
// until just before the message is sent.
static const void* identity(const rpc::CapDescriptor::Reader& desc) {
// TODO(cleanup): Don't rely on internal APIs here.
return _::PointerHelpers<rpc::CapDescriptor>::getInternalReader(desc).getLocation();
}
}; };
// ===================================================================================== // =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations // RequestHook/PipelineHook/ResponseHook implementations
class QuestionRef: public kj::Refcounted {
public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id)
: connectionState(connectionState), id(id), resultCaps(connectionState) {}
~QuestionRef() {
// Send the "Finish" message.
auto message = connectionState.connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true));
auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id);
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
message->send();
// Check if the question has returned and, if so, remove it from the table.
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
// the ID is not re-allocated before the `Finish` message can be sent.
{
auto lock = connectionState.tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL(
lock->questions.find(id), "Question ID no longer on table?");
if (question.paramCaps == nullptr) {
// Call has already returned, so we can now remove it from the table.
KJ_ASSERT(lock->questions.erase(id));
} else {
question.isFinished = true;
}
}
}
inline QuestionId getId() const { return id; }
CapExtractorImpl& getResultCapExtractor() { return resultCaps; }
private:
const RpcConnectionState& connectionState;
QuestionId id;
CapExtractorImpl resultCaps;
};
class RpcRequest final: public RequestHook { class RpcRequest final: public RequestHook {
public: public:
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize) RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize,
kj::Own<const RpcClient>&& target)
: connectionState(connectionState), : connectionState(connectionState),
target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage( message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())), firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())),
injector(kj::heap<CapInjectorImpl>(connectionState)), injector(kj::heap<CapInjectorImpl>(connectionState)),
...@@ -791,39 +983,77 @@ private: ...@@ -791,39 +983,77 @@ private:
} }
RemotePromise<ObjectPointer> send() override { RemotePromise<ObjectPointer> send() override {
auto paf = kj::newPromiseAndFulfiller<Response<ObjectPointer>>(connectionState.eventLoop);
QuestionId questionId; QuestionId questionId;
kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
{ {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
auto& question = lock->questions.next(questionId); KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) {
// Whoops, this capability has been redirected while we were building the request!
// We'll have to make a new request and do a copy. Ick.
lock.release();
size_t sizeHint = paramsBuilder.targetSizeInWords();
// TODO(perf): See TODO in RpcClient::call() about why we need to inflate the size a bit.
sizeHint += sizeHint / 16;
// Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
auto replacement = redirect->get()->newCall(
callBuilder.getInterfaceId(), callBuilder.getMethodId(), sizeHint);
replacement.set(paramsBuilder);
return replacement.send();
} else {
injector->finish(*lock);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState.eventLoop);
auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId);
question.isStarted = true;
question.paramCaps = kj::mv(injector);
question.fulfiller = kj::mv(paf.fulfiller);
callBuilder.setQuestionId(questionId); message->send();
question.isStarted = true;
question.paramCaps = kj::mv(injector);
question.fulfiller = kj::mv(paf.fulfiller); promise = kj::mv(paf.promise);
}
} }
auto pipeline = kj::refcounted<RpcPipeline>(connectionState, questionId); auto questionRef = kj::refcounted<QuestionRef>(connectionState, questionId);
// If the caller discards the pipeline without discarding the promise, we need the pipeline auto promiseWithQuestionRef = promise.then(kj::mvCapture(kj::addRef(*questionRef),
// to stay alive so that we don't cancel the call altogether. [](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
auto promiseWithPipelineRef = paf.promise.then(kj::mvCapture(pipeline->addRef(), -> kj::Own<const RpcResponse> {
[](kj::Own<const PipelineHook>&&, Response<ObjectPointer>&& response) response->setQuestionRef(kj::mv(questionRef));
-> Response<ObjectPointer> {
return kj::mv(response); return kj::mv(response);
})); }));
message->send(); auto forkedPromise = promiseWithQuestionRef.fork();
auto appPromise = forkedPromise.addBranch().then([](kj::Own<const RpcResponse>&& response) {
auto reader = response->getResults();
return Response<ObjectPointer>(reader, kj::mv(response));
});
auto pipeline = kj::refcounted<RpcPipeline>(
connectionState, questionId, kj::mv(forkedPromise));
return RemotePromise<ObjectPointer>( return RemotePromise<ObjectPointer>(
kj::mv(promiseWithPipelineRef), kj::mv(appPromise),
ObjectPointer::Pipeline(kj::mv(pipeline))); ObjectPointer::Pipeline(kj::mv(pipeline)));
} }
private: private:
const RpcConnectionState& connectionState; const RpcConnectionState& connectionState;
kj::Own<const RpcClient> target;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
kj::Own<CapInjectorImpl> injector; kj::Own<CapInjectorImpl> injector;
CapBuilderContext context; CapBuilderContext context;
...@@ -831,43 +1061,67 @@ private: ...@@ -831,43 +1061,67 @@ private:
ObjectPointer::Builder paramsBuilder; ObjectPointer::Builder paramsBuilder;
}; };
class RpcPipeline: public PipelineHook, public kj::Refcounted { class RpcPipeline final: public PipelineHook, public kj::Refcounted {
public: public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId) RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId,
: connectionState(connectionState), questionId(questionId) {} kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam)
: connectionState(connectionState),
~RpcPipeline() noexcept(false) { redirectLater(kj::mv(redirectLaterParam)),
uint sizeHint = messageSizeHint<rpc::Finish>(); resolveSelfPromise(redirectLater.addBranch().then(
KJ_IF_MAYBE(ce, capExtractor) { [this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
sizeHint += ce->retainedListSizeHint(true); resolve(kj::mv(response));
} return kj::READY_NOW; // hack to force eager resolution.
auto finishMessage = connectionState.connection->newOutgoingMessage(sizeHint); }, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception));
return kj::READY_NOW;
})) {
// Construct a new RpcPipeline.
state.getWithoutLock().init<Waiting>(questionId);
}
rpc::Finish::Builder builder = finishMessage->getBody().initAs<rpc::Message>().initFinish(); kj::Promise<kj::Own<const RpcResponse>> onResponse() const {
return redirectLater.addBranch();
}
builder.setQuestionId(questionId); kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target, kj::ArrayPtr<const PipelineOp> ops) const {
// Initializes `target` to a PromisedAnswer for a pipelined capability, *or* returns the
// specified capability directly.
//
// The caller *should* have a lock on the connection state's tables while calling this, so
// that a Finish message cannot be sent before the caller manages to send its `Call`.
KJ_IF_MAYBE(ce, capExtractor) { auto lock = state.lockExclusive();
builder.adoptRetainedCaps(ce->finalizeRetainedCaps( if (lock->is<Waiting>()) {
Orphanage::getForMessageContaining(builder))); auto builder = target.initPromisedAnswer();
builder.setQuestionId(lock->get<Waiting>());
builder.adoptTransform(fromPipelineOps(Orphanage::getForMessageContaining(builder), ops));
return nullptr;
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
} }
}
finishMessage->send(); void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables,
kj::ArrayPtr<const PipelineOp> ops) const {
{ auto lock = state.lockExclusive();
auto lock = connectionState.tables.lockExclusive(); if (lock->is<Waiting>()) {
auto& question = KJ_ASSERT_NONNULL(lock->questions.find(questionId), auto promisedAnswer = descriptor.initReceiverAnswer();
"RpcPipeline had invalid questionId?"); promisedAnswer.setQuestionId(lock->get<Waiting>());
question.pipeline = nullptr; promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
if (question.isReturned) { } else if (lock->is<Resolved>()) {
KJ_ASSERT(lock->questions.erase(questionId)); connectionState.writeDescriptor(lock->get<Resolved>()->getResults().getPipelinedCap(ops),
} descriptor, tables);
} else {
connectionState.writeDescriptor(newBrokenCap(kj::cp(lock->get<Broken>())),
descriptor, tables);
} }
} }
kj::Promise<Response<ObjectPointer>> getResponse();
// implements PipelineHook --------------------------------------- // implements PipelineHook ---------------------------------------
kj::Own<const PipelineHook> addRef() const override { kj::Own<const PipelineHook> addRef() const override {
...@@ -883,17 +1137,45 @@ private: ...@@ -883,17 +1137,45 @@ private:
} }
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override { kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override {
return kj::refcounted<PromisedAnswerClient>( auto lock = state.lockExclusive();
connectionState, questionId, kj::mv(ops), kj::addRef(*this)); if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>(
connectionState, kj::addRef(*this), kj::mv(ops));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
} }
private: private:
const RpcConnectionState& connectionState; const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Maybe<CapExtractorImpl&> capExtractor; kj::Maybe<CapExtractorImpl&> capExtractor;
kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater;
typedef QuestionId Waiting;
typedef kj::Own<const RpcResponse> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Resolved>(kj::mv(response));
}
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
}
}; };
class RpcResponse: public ResponseHook { class RpcResponse final: public ResponseHook, public kj::Refcounted {
public: public:
RpcResponse(const RpcConnectionState& connectionState, RpcResponse(const RpcConnectionState& connectionState,
kj::Own<IncomingRpcMessage>&& message, kj::Own<IncomingRpcMessage>&& message,
...@@ -903,15 +1185,24 @@ private: ...@@ -903,15 +1185,24 @@ private:
context(extractor), context(extractor),
reader(context.imbue(results)) {} reader(context.imbue(results)) {}
ObjectPointer::Reader getResults() { ObjectPointer::Reader getResults() const {
return reader; return reader;
} }
kj::Own<const RpcResponse> addRef() const {
return kj::addRef(*this);
}
void setQuestionRef(kj::Own<const QuestionRef>&& questionRef) {
this->questionRef = kj::mv(questionRef);
}
private: private:
kj::Own<IncomingRpcMessage> message; kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor; CapExtractorImpl extractor;
CapReaderContext context; CapReaderContext context;
ObjectPointer::Reader reader; ObjectPointer::Reader reader;
kj::Own<const QuestionRef> questionRef;
}; };
// ===================================================================================== // =====================================================================================
...@@ -922,7 +1213,8 @@ private: ...@@ -922,7 +1213,8 @@ private:
RpcServerResponse(const RpcConnectionState& connectionState, RpcServerResponse(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message, kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results) ObjectPointer::Builder results)
: message(kj::mv(message)), : connectionState(connectionState),
message(kj::mv(message)),
injector(connectionState), injector(connectionState),
context(injector), context(injector),
builder(context.imbue(results)) {} builder(context.imbue(results)) {}
...@@ -932,10 +1224,13 @@ private: ...@@ -932,10 +1224,13 @@ private:
} }
void send() { void send() {
auto lock = connectionState.tables.lockExclusive();
injector.finish(*lock);
message->send(); message->send();
} }
private: private:
const RpcConnectionState& connectionState;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
CapInjectorImpl injector; CapInjectorImpl injector;
CapBuilderContext context; CapBuilderContext context;
...@@ -1188,23 +1483,23 @@ private: ...@@ -1188,23 +1483,23 @@ private:
auto reader = message->getBody().getAs<rpc::Message>(); auto reader = message->getBody().getAs<rpc::Message>();
switch (reader.which()) { switch (reader.which()) {
case rpc::Message::UNIMPLEMENTED: case rpc::Message::UNIMPLEMENTED:
doUnimplemented(reader.getUnimplemented()); handleUnimplemented(reader.getUnimplemented());
break; break;
case rpc::Message::ABORT: case rpc::Message::ABORT:
doAbort(reader.getAbort()); handleAbort(reader.getAbort());
break; break;
case rpc::Message::CALL: case rpc::Message::CALL:
doCall(kj::mv(message), reader.getCall()); handleCall(kj::mv(message), reader.getCall());
break; break;
case rpc::Message::RETURN: case rpc::Message::RETURN:
doReturn(kj::mv(message), reader.getReturn()); handleReturn(kj::mv(message), reader.getReturn());
break; break;
case rpc::Message::FINISH: case rpc::Message::FINISH:
doFinish(reader.getFinish()); handleFinish(reader.getFinish());
break; break;
default: { default: {
...@@ -1217,15 +1512,15 @@ private: ...@@ -1217,15 +1512,15 @@ private:
} }
} }
void doUnimplemented(const rpc::Message::Reader& message) { void handleUnimplemented(const rpc::Message::Reader& message) {
// TODO(now) // TODO(now)
} }
void doAbort(const rpc::Exception::Reader& exception) { void handleAbort(const rpc::Exception::Reader& exception) {
kj::throwRecoverableException(toException(exception)); kj::throwRecoverableException(toException(exception));
} }
void doCall(kj::Own<IncomingRpcMessage>&& message, const rpc::Call::Reader& call) { void handleCall(kj::Own<IncomingRpcMessage>&& message, const rpc::Call::Reader& call) {
kj::Own<const ClientHook> capability; kj::Own<const ClientHook> capability;
auto target = call.getTarget(); auto target = call.getTarget();
...@@ -1320,13 +1615,19 @@ private: ...@@ -1320,13 +1615,19 @@ private:
} }
} }
void doReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) { void handleReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
kj::Own<CapInjectorImpl> paramCapsToRelease; kj::Own<CapInjectorImpl> paramCapsToRelease;
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) { KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
KJ_REQUIRE(!question->isReturned, "Duplicate Return.") { return; } KJ_REQUIRE(question->paramCaps != nullptr, "Duplicate Return.") { return; }
question->isReturned = true;
KJ_IF_MAYBE(pc, question->paramCaps) {
// Release these later, after unlocking.
paramCapsToRelease = kj::mv(*pc);
} else {
KJ_FAIL_REQUIRE("Duplicate return.") { return; }
}
for (ExportId retained: ret.getRetainedCaps()) { for (ExportId retained: ret.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) { KJ_IF_MAYBE(exp, lock->exports.find(retained)) {
...@@ -1336,15 +1637,10 @@ private: ...@@ -1336,15 +1637,10 @@ private:
} }
} }
// We can now release the caps in the parameter message. Well... we can release them once
// we release the lock, at least.
paramCapsToRelease = kj::mv(question->paramCaps);
// TODO(now): Handle exception/cancel response // TODO(now): Handle exception/cancel response
auto response = kj::heap<RpcResponse>(*this, kj::mv(message), ret.getAnswer()); auto response = kj::refcounted<RpcResponse>(*this, kj::mv(message), ret.getAnswer());
auto imbuedResults = response->getResults(); question->fulfiller->fulfill(kj::mv(response));
question->fulfiller->fulfill(Response<ObjectPointer>(imbuedResults, kj::mv(response)));
if (question->pipeline == nullptr) { if (question->pipeline == nullptr) {
lock->questions.erase(ret.getQuestionId()); lock->questions.erase(ret.getQuestionId());
...@@ -1354,7 +1650,7 @@ private: ...@@ -1354,7 +1650,7 @@ private:
} }
} }
void doFinish(const rpc::Finish::Reader& finish) { void handleFinish(const rpc::Finish::Reader& finish) {
kj::Maybe<kj::Own<const PipelineHook>> pipelineToRelease; kj::Maybe<kj::Own<const PipelineHook>> pipelineToRelease;
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
...@@ -1399,6 +1695,10 @@ public: ...@@ -1399,6 +1695,10 @@ public:
Impl(VatNetworkBase& network, SturdyRefRestorerBase& restorer, const kj::EventLoop& eventLoop) Impl(VatNetworkBase& network, SturdyRefRestorerBase& restorer, const kj::EventLoop& eventLoop)
: network(network), restorer(restorer), eventLoop(eventLoop) {} : network(network), restorer(restorer), eventLoop(eventLoop) {}
Capability::Client connect(_::StructReader reader) {
// TODO(now)
}
private: private:
VatNetworkBase& network; VatNetworkBase& network;
SturdyRefRestorerBase& restorer; SturdyRefRestorerBase& restorer;
...@@ -1411,7 +1711,7 @@ RpcSystemBase::RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& res ...@@ -1411,7 +1711,7 @@ RpcSystemBase::RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& res
RpcSystemBase::~RpcSystemBase() noexcept(false) {} RpcSystemBase::~RpcSystemBase() noexcept(false) {}
Capability::Client RpcSystemBase::baseConnect(_::StructReader reader) { Capability::Client RpcSystemBase::baseConnect(_::StructReader reader) {
impl->connect(reader); return impl->connect(reader);
} }
} // namespace _ (private) } // namespace _ (private)
......
...@@ -156,6 +156,9 @@ public: ...@@ -156,6 +156,9 @@ public:
inline ArrayPtr<T> asPtr() { inline ArrayPtr<T> asPtr() {
return ArrayPtr<T>(ptr, size_); return ArrayPtr<T>(ptr, size_);
} }
inline ArrayPtr<const T> asPtr() const {
return ArrayPtr<T>(ptr, size_);
}
inline size_t size() const { return size_; } inline size_t size() const { return size_; }
inline T& operator[](size_t index) const { inline T& operator[](size_t index) const {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment