Commit 7298c7fa authored by Kenton Varda's avatar Kenton Varda

More RPC protocol WIP.

parent 41feb767
......@@ -91,11 +91,11 @@ class Response: public Results::Reader {
// is move-only -- once it goes out-of-scope, the underlying message will be freed.
public:
inline Response(typename Results::Reader reader, kj::Own<ResponseHook>&& hook)
inline Response(typename Results::Reader reader, kj::Own<const ResponseHook>&& hook)
: Results::Reader(reader), hook(kj::mv(hook)) {}
private:
kj::Own<ResponseHook> hook;
kj::Own<const ResponseHook> hook;
template <typename, typename>
friend class Request;
......@@ -347,6 +347,7 @@ public:
};
kj::Own<const ClientHook> newBrokenCap(const char* reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
// =======================================================================================
......
......@@ -494,6 +494,8 @@ public:
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
const void* getLocation() const { return data; }
inline BitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; }
inline Data::Reader getDataSectionAsBlob();
......
......@@ -27,7 +27,9 @@
#include <kj/vector.h>
#include <kj/async.h>
#include <kj/one-of.h>
#include <kj/function.h>
#include <unordered_map>
#include <map>
#include <queue>
#include <capnp/rpc.capnp.h>
......@@ -68,7 +70,7 @@ kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Re
}
Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps(
Orphanage orphanage, const kj::Array<PipelineOp>& ops) {
Orphanage orphanage, kj::ArrayPtr<const PipelineOp> ops) {
auto result = orphanage.newOrphan<List<rpc::PromisedAnswer::Op>>(ops.size());
auto builder = result.get();
for (uint i: kj::indices(ops)) {
......@@ -186,7 +188,7 @@ public:
return low[id];
} else {
auto iter = high.find(id);
if (iter == nullptr) {
if (iter == high.end()) {
return nullptr;
} else {
return iter->second;
......@@ -207,28 +209,29 @@ private:
std::unordered_map<Id, T> high;
};
template <typename ParamCaps, typename RpcPipeline>
template <typename ParamCaps, typename RpcPipeline, typename RpcResponse>
struct Question {
kj::Own<ParamCaps> paramCaps;
// A handle representing the capabilities in the parameter struct. This will be dropped as soon
// as the call returns.
kj::Own<kj::PromiseFulfiller<Response<ObjectPointer>>> fulfiller;
kj::Own<kj::PromiseFulfiller<kj::Own<RpcResponse>>> fulfiller;
// Fulfill with the response.
kj::Maybe<kj::Own<ParamCaps>> paramCaps;
// CapInjector from the parameter struct. This will be released once the `Return` message is
// received and `retainedCaps` processed. (If this is non-null, then the call has not returned
// yet.)
kj::Maybe<RpcPipeline&> pipeline;
// The local pipeline object. The RpcPipeline's own destructor sets this value to null and then
// sends the Finish message.
// The local pipeline object. The RpcPipeline's own destructor sets this value to null.
//
// TODO(cleanup): We only have this pointer here because CapInjectorImpl::getInjectedCap() needs
// it, but perhaps CapInjectorImpl should instead hold on to the ClientHook it got in the first
// place.
bool isStarted = false;
// Is this Question currently in-use?
// Is this Question ID currently in-use? (This is true until both `Return` has been received and
// `Finish` has been sent.)
bool isReturned = false;
// Has the call returned?
bool isFinished = false;
// Has the `Finish` message been sent?
inline bool operator==(decltype(nullptr)) const { return !isStarted; }
inline bool operator!=(decltype(nullptr)) const { return isStarted; }
......@@ -261,6 +264,13 @@ struct Export {
inline bool operator!=(decltype(nullptr)) const { return refcount != 0; }
};
template <typename ImportClient>
struct Import {
ImportClient* client = nullptr;
// Normally I'd want this to be Maybe<ImportClient&>, but GCC's unordered_map doesn't seem to
// like DisableConstCopy types.
};
// =======================================================================================
class RpcConnectionState: public kj::TaskSet::ErrorHandler {
......@@ -285,12 +295,13 @@ private:
class CapExtractorImpl;
class RpcPipeline;
class RpcCallContext;
class RpcResponse;
struct Tables {
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline>> questions;
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline, RpcResponse>> questions;
ImportTable<QuestionId, Answer<RpcCallContext>> answers;
ExportTable<ExportId, Export> exports;
ImportTable<ExportId, kj::Maybe<ImportClient&>> imports;
ImportTable<ExportId, Import<ImportClient>> imports;
};
kj::MutexGuarded<Tables> tables;
......@@ -331,13 +342,20 @@ private:
RpcClient(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
virtual kj::Own<const kj::Refcounted> writeDescriptor(
rpc::CapDescriptor::Builder descriptor) const = 0;
// Writes a CapDescriptor referencing this client. Returns a reference to some object which
// must be held at least until the message containing `descriptor` has been sent.
virtual void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const = 0;
// Writes a CapDescriptor referencing this client. Must be called with the
// RpcConnectionState's table locked -- a reference to them is passed as the second argument.
// The CapDescriptor must be sent before unlocking the tables, as it may become invalid at
// any time once the tables are unlocked.
virtual kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const = 0;
// Writes the appropriate call target for calls to this capability and returns null.
//
// TODO(cleanup): Specialize Own<void> so that we can return it here instead of
// Own<Refcounted>.
// - OR -
//
// If calls have been redirected to some other local ClientHook, returns that hook instead.
// This can happen if the capability represents a promise that has been resolved.
// implements ClientHook -----------------------------------------
......@@ -349,7 +367,7 @@ private:
// TODO(perf): Extend targetSizeInWords() to include a capability count? Here we increase
// the size by 1/16 to deal with cap descriptors possibly expanding. See also below, when
// handling the response.
// handling the response, and in RpcRequest::send().
sizeHint += sizeHint / 16;
// Don't overflow.
......@@ -408,9 +426,9 @@ private:
// which case that other thread will have constructed a new ImportClient and placed it in
// the import table.)
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(ptr, lock->imports[importId]) {
if (ptr == this) {
lock->imports[importId] = nullptr;
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client == this) {
lock->imports.erase(importId);
}
}
}
......@@ -438,22 +456,25 @@ private:
}
}
kj::Own<const kj::Refcounted> writeDescriptor(
rpc::CapDescriptor::Builder descriptor) const override {
void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
descriptor.setReceiverHosted(importId);
return kj::addRef(*this);
}
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
target.setExportedCap(importId);
return nullptr;
}
// implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize);
auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.getTarget().setExportedCap(importId);
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
request->holdRef(writeTarget(callBuilder.getTarget()));
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
......@@ -495,6 +516,9 @@ private:
return true;
}
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves.
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
// We need the returned promise to hold a reference back to this object, so that it doesn't
// disappear while the promise is still outstanding.
......@@ -511,50 +535,117 @@ private:
class PromisedAnswerClient final: public RpcClient {
public:
PromisedAnswerClient(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Array<PipelineOp>&& ops, kj::Own<const RpcPipeline> pipeline);
PromisedAnswerClient(const RpcConnectionState& connectionState,
kj::Own<const RpcPipeline>&& pipeline,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), ops(kj::mv(ops)),
resolveSelfPromise(pipeline->onResponse().then(
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution.
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception));
return kj::READY_NOW; // hack to force eager resolution.
})) {
state.getWithoutLock().init<Waiting>(kj::mv(pipeline));
}
kj::Own<const kj::Refcounted> writeDescriptor(rpc::CapDescriptor::Builder descriptor) const override {
void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(questionId);
promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
} else if (lock->is<Resolved>()) {
return connectionState.writeDescriptor(lock->get<Resolved>()->addRef(), descriptor, tables);
} else {
// TODO(now)
}
}
// Return a ref to the RpcPipeline to ensure that we don't send a Finish message for this
// call before the message containing this CapDescriptor is sent.
return kj::addRef(*lock->get<Waiting>());
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) {
return connectionState.writeTarget(*lock->get<Resolved>(), target);
} else {
// TODO(now): Problem: This won't necessarily be a remote cap!
return connectionState.writeDescriptor(
lock->get<Resolved>().getPipelinedCap(ops), descriptor);
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
}
// implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override;
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override;
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>(
connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()))->newCall(
interfaceId, methodId, firstSegmentWordSize);
}
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->onResponse().then(kj::mvCapture(kj::heapArray(ops.asPtr()),
[](kj::Array<PipelineOp>&& ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
} else if (lock->is<Resolved>()) {
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(lock->get<Broken>()));
}
}
private:
const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Array<PipelineOp> ops;
typedef kj::Own<const RpcPipeline> Waiting;
typedef Response<ObjectPointer> Resolved;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved>> state;
typedef kj::Own<const ClientHook> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Resolved>(response->getResults().getPipelinedCap(ops));
}
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
}
};
kj::Own<const kj::Refcounted> writeDescriptor(
kj::Own<const ClientHook>&& cap, rpc::CapDescriptor::Builder descriptor) const {
// Write a descriptor for the given capability. Returns a reference to something which must
// be held at least until the message containing the descriptor is sent.
void writeDescriptor(kj::Own<const ClientHook> cap, rpc::CapDescriptor::Builder descriptor,
Tables& tables) const {
// Write a descriptor for the given capability. The tables must be locked by the caller and
// passed in as a parameter.
if (cap->getBrand() == this) {
return kj::downcast<const RpcClient>(*cap).writeDescriptor(descriptor);
kj::downcast<const RpcClient>(*cap).writeDescriptor(descriptor, tables);
} else {
// TODO(now): We have to figure out if the client is already in our table.
// TODO(now): We have to add a refcount to the export, and return an object that decrements
......@@ -562,6 +653,25 @@ private:
}
}
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
const ClientHook& cap, rpc::Call::Target::Builder target) const {
// If calls to the given capability should pass over this connection, fill in `target`
// appropriately for such a call and return nullptr. Otherwise, return a `ClientHook` to which
// the call should be forwarded; the caller should then delegate the call to that `ClientHook`.
//
// The main case where this ends up returning non-null is if `cap` is a promise that has
// recently resolved. The application might have started building a request before the promise
// resolved, and so the request may have been built on the assumption that it would be sent over
// this network connection, but then the promise resolved to point somewhere else before the
// request was sent. Now the request has to be redirected to the new target instead.
if (cap.getBrand() == this) {
return kj::downcast<const RpcClient>(cap).writeTarget(target);
} else {
return cap.addRef();
}
}
// =====================================================================================
// CapExtractor / CapInjector implementations
......@@ -600,8 +710,8 @@ private:
auto actualRetained = retainedCaps.begin();
for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports[importId]) {
if (import->tryAddRemoteRef() != nullptr) {
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client != nullptr && import->client->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it.
*actualRetained++ = importId;
}
......@@ -631,10 +741,11 @@ private:
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[importId]) {
auto& import = lock->imports[importId];
if (import.client != nullptr) {
// The import is already on the table, but it could be being deleted in another
// thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*import.client)) {
// We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and
......@@ -652,7 +763,7 @@ private:
} else {
kj::refcounted<SettledImportClient>(connectionState, importId);
}
lock->imports[importId] = *result;
import.client = result;
// Note that we need to retain this import later if it still exists.
retainedCaps.lockExclusive()->add(importId);
......@@ -699,14 +810,33 @@ private:
: connectionState(connectionState) {}
~CapInjectorImpl() {}
void finish(Tables& tables) {
// Finish writing all of the CapDescriptors. Must be called with the tables locked, and the
// message must be sent before the tables are unlocked.
for (auto& entry: caps.getWithoutLock()) {
connectionState.writeDescriptor(kj::mv(entry.second.cap), entry.second.builder, tables);
}
}
// implements CapInjector ----------------------------------------
void injectCap(rpc::CapDescriptor::Builder descriptor,
kj::Own<const ClientHook>&& cap) const override {
auto ref = connectionState.writeDescriptor(kj::mv(cap), descriptor);
refs.lockExclusive()->add(kj::mv(ref));
auto lock = caps.lockExclusive();
auto result = lock->insert(std::make_pair(
identity(descriptor), CapInfo(descriptor, kj::mv(cap))));
KJ_REQUIRE(result.second, "A cap has already been injected at this location.") {
result.first->second.cap = kj::mv(cap);
break;
}
}
kj::Own<const ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) const override {
auto lock = caps.lockExclusive();
auto iter = lock->find(identity(descriptor));
KJ_REQUIRE(iter != lock->end(), "getInjectedCap() called on descriptor I didn't write.");
return iter->second.cap->addRef();
switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: {
auto lock = connectionState.tables.lockExclusive();
......@@ -722,9 +852,11 @@ private:
case rpc::CapDescriptor::RECEIVER_HOSTED: {
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[descriptor.getReceiverHosted()]) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) {
return kj::mv(*ref);
KJ_IF_MAYBE(import, lock->imports.find(descriptor.getReceiverHosted())) {
if (import->client != nullptr) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*import->client)) {
return kj::mv(*ref);
}
}
}
......@@ -741,9 +873,7 @@ private:
KJ_IF_MAYBE(question, lock->questions.find(promisedAnswer.getQuestionId())) {
KJ_IF_MAYBE(pipeline, question->pipeline) {
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
return kj::refcounted<PromisedAnswerClient>(
connectionState, promisedAnswer.getQuestionId(),
kj::mv(*ops), kj::addRef(*pipeline));
return pipeline->getPipelinedCap(kj::mv(*ops));
}
}
}
......@@ -758,24 +888,86 @@ private:
}
}
void dropCap(rpc::CapDescriptor::Reader descriptor) const override {
// TODO(someday): We could implement this by maintaining a map from CapDescriptors to
// the corresponding refs, but is it worth it?
caps.lockExclusive()->erase(identity(descriptor));
}
private:
const RpcConnectionState& connectionState;
kj::MutexGuarded<kj::Vector<kj::Own<const kj::Refcounted>>> refs;
// List of references that need to be held until the message is destroyed.
struct CapInfo {
rpc::CapDescriptor::Builder builder;
kj::Own<const ClientHook> cap;
CapInfo(rpc::CapDescriptor::Builder& builder, kj::Own<const ClientHook>&& cap)
: builder(builder), cap(kj::mv(cap)) {}
CapInfo(const CapInfo& other);
// Work around problem where std::pair complains about the copy constructor requiring a
// non-const argument due to `builder` inheriting kj::DisableConstCopy. The copy constructor
// should never be called anyway.
};
kj::MutexGuarded<std::map<const void*, CapInfo>> caps;
// Maps CapDescriptor locations to embedded caps. The descriptors aren't actually filled in
// until just before the message is sent.
static const void* identity(const rpc::CapDescriptor::Reader& desc) {
// TODO(cleanup): Don't rely on internal APIs here.
return _::PointerHelpers<rpc::CapDescriptor>::getInternalReader(desc).getLocation();
}
};
// =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations
class QuestionRef: public kj::Refcounted {
public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id)
: connectionState(connectionState), id(id), resultCaps(connectionState) {}
~QuestionRef() {
// Send the "Finish" message.
auto message = connectionState.connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true));
auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id);
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
message->send();
// Check if the question has returned and, if so, remove it from the table.
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
// the ID is not re-allocated before the `Finish` message can be sent.
{
auto lock = connectionState.tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL(
lock->questions.find(id), "Question ID no longer on table?");
if (question.paramCaps == nullptr) {
// Call has already returned, so we can now remove it from the table.
KJ_ASSERT(lock->questions.erase(id));
} else {
question.isFinished = true;
}
}
}
inline QuestionId getId() const { return id; }
CapExtractorImpl& getResultCapExtractor() { return resultCaps; }
private:
const RpcConnectionState& connectionState;
QuestionId id;
CapExtractorImpl resultCaps;
};
class RpcRequest final: public RequestHook {
public:
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize)
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize,
kj::Own<const RpcClient>&& target)
: connectionState(connectionState),
target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())),
injector(kj::heap<CapInjectorImpl>(connectionState)),
......@@ -791,39 +983,77 @@ private:
}
RemotePromise<ObjectPointer> send() override {
auto paf = kj::newPromiseAndFulfiller<Response<ObjectPointer>>(connectionState.eventLoop);
QuestionId questionId;
kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
{
auto lock = connectionState.tables.lockExclusive();
auto& question = lock->questions.next(questionId);
KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) {
// Whoops, this capability has been redirected while we were building the request!
// We'll have to make a new request and do a copy. Ick.
lock.release();
size_t sizeHint = paramsBuilder.targetSizeInWords();
// TODO(perf): See TODO in RpcClient::call() about why we need to inflate the size a bit.
sizeHint += sizeHint / 16;
// Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
auto replacement = redirect->get()->newCall(
callBuilder.getInterfaceId(), callBuilder.getMethodId(), sizeHint);
replacement.set(paramsBuilder);
return replacement.send();
} else {
injector->finish(*lock);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState.eventLoop);
auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId);
question.isStarted = true;
question.paramCaps = kj::mv(injector);
question.fulfiller = kj::mv(paf.fulfiller);
callBuilder.setQuestionId(questionId);
question.isStarted = true;
question.paramCaps = kj::mv(injector);
message->send();
question.fulfiller = kj::mv(paf.fulfiller);
promise = kj::mv(paf.promise);
}
}
auto pipeline = kj::refcounted<RpcPipeline>(connectionState, questionId);
auto questionRef = kj::refcounted<QuestionRef>(connectionState, questionId);
// If the caller discards the pipeline without discarding the promise, we need the pipeline
// to stay alive so that we don't cancel the call altogether.
auto promiseWithPipelineRef = paf.promise.then(kj::mvCapture(pipeline->addRef(),
[](kj::Own<const PipelineHook>&&, Response<ObjectPointer>&& response)
-> Response<ObjectPointer> {
auto promiseWithQuestionRef = promise.then(kj::mvCapture(kj::addRef(*questionRef),
[](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response);
}));
message->send();
auto forkedPromise = promiseWithQuestionRef.fork();
auto appPromise = forkedPromise.addBranch().then([](kj::Own<const RpcResponse>&& response) {
auto reader = response->getResults();
return Response<ObjectPointer>(reader, kj::mv(response));
});
auto pipeline = kj::refcounted<RpcPipeline>(
connectionState, questionId, kj::mv(forkedPromise));
return RemotePromise<ObjectPointer>(
kj::mv(promiseWithPipelineRef),
kj::mv(appPromise),
ObjectPointer::Pipeline(kj::mv(pipeline)));
}
private:
const RpcConnectionState& connectionState;
kj::Own<const RpcClient> target;
kj::Own<OutgoingRpcMessage> message;
kj::Own<CapInjectorImpl> injector;
CapBuilderContext context;
......@@ -831,43 +1061,67 @@ private:
ObjectPointer::Builder paramsBuilder;
};
class RpcPipeline: public PipelineHook, public kj::Refcounted {
class RpcPipeline final: public PipelineHook, public kj::Refcounted {
public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId)
: connectionState(connectionState), questionId(questionId) {}
~RpcPipeline() noexcept(false) {
uint sizeHint = messageSizeHint<rpc::Finish>();
KJ_IF_MAYBE(ce, capExtractor) {
sizeHint += ce->retainedListSizeHint(true);
}
auto finishMessage = connectionState.connection->newOutgoingMessage(sizeHint);
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId,
kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam)
: connectionState(connectionState),
redirectLater(kj::mv(redirectLaterParam)),
resolveSelfPromise(redirectLater.addBranch().then(
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution.
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception));
return kj::READY_NOW;
})) {
// Construct a new RpcPipeline.
state.getWithoutLock().init<Waiting>(questionId);
}
rpc::Finish::Builder builder = finishMessage->getBody().initAs<rpc::Message>().initFinish();
kj::Promise<kj::Own<const RpcResponse>> onResponse() const {
return redirectLater.addBranch();
}
builder.setQuestionId(questionId);
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target, kj::ArrayPtr<const PipelineOp> ops) const {
// Initializes `target` to a PromisedAnswer for a pipelined capability, *or* returns the
// specified capability directly.
//
// The caller *should* have a lock on the connection state's tables while calling this, so
// that a Finish message cannot be sent before the caller manages to send its `Call`.
KJ_IF_MAYBE(ce, capExtractor) {
builder.adoptRetainedCaps(ce->finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
auto builder = target.initPromisedAnswer();
builder.setQuestionId(lock->get<Waiting>());
builder.adoptTransform(fromPipelineOps(Orphanage::getForMessageContaining(builder), ops));
return nullptr;
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
}
finishMessage->send();
{
auto lock = connectionState.tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL(lock->questions.find(questionId),
"RpcPipeline had invalid questionId?");
question.pipeline = nullptr;
if (question.isReturned) {
KJ_ASSERT(lock->questions.erase(questionId));
}
void writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables,
kj::ArrayPtr<const PipelineOp> ops) const {
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(lock->get<Waiting>());
promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
} else if (lock->is<Resolved>()) {
connectionState.writeDescriptor(lock->get<Resolved>()->getResults().getPipelinedCap(ops),
descriptor, tables);
} else {
connectionState.writeDescriptor(newBrokenCap(kj::cp(lock->get<Broken>())),
descriptor, tables);
}
}
kj::Promise<Response<ObjectPointer>> getResponse();
// implements PipelineHook ---------------------------------------
kj::Own<const PipelineHook> addRef() const override {
......@@ -883,17 +1137,45 @@ private:
}
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override {
return kj::refcounted<PromisedAnswerClient>(
connectionState, questionId, kj::mv(ops), kj::addRef(*this));
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>(
connectionState, kj::addRef(*this), kj::mv(ops));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
}
private:
const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Maybe<CapExtractorImpl&> capExtractor;
kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater;
typedef QuestionId Waiting;
typedef kj::Own<const RpcResponse> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Resolved>(kj::mv(response));
}
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
}
};
class RpcResponse: public ResponseHook {
class RpcResponse final: public ResponseHook, public kj::Refcounted {
public:
RpcResponse(const RpcConnectionState& connectionState,
kj::Own<IncomingRpcMessage>&& message,
......@@ -903,15 +1185,24 @@ private:
context(extractor),
reader(context.imbue(results)) {}
ObjectPointer::Reader getResults() {
ObjectPointer::Reader getResults() const {
return reader;
}
kj::Own<const RpcResponse> addRef() const {
return kj::addRef(*this);
}
void setQuestionRef(kj::Own<const QuestionRef>&& questionRef) {
this->questionRef = kj::mv(questionRef);
}
private:
kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor;
CapReaderContext context;
ObjectPointer::Reader reader;
kj::Own<const QuestionRef> questionRef;
};
// =====================================================================================
......@@ -922,7 +1213,8 @@ private:
RpcServerResponse(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results)
: message(kj::mv(message)),
: connectionState(connectionState),
message(kj::mv(message)),
injector(connectionState),
context(injector),
builder(context.imbue(results)) {}
......@@ -932,10 +1224,13 @@ private:
}
void send() {
auto lock = connectionState.tables.lockExclusive();
injector.finish(*lock);
message->send();
}
private:
const RpcConnectionState& connectionState;
kj::Own<OutgoingRpcMessage> message;
CapInjectorImpl injector;
CapBuilderContext context;
......@@ -1188,23 +1483,23 @@ private:
auto reader = message->getBody().getAs<rpc::Message>();
switch (reader.which()) {
case rpc::Message::UNIMPLEMENTED:
doUnimplemented(reader.getUnimplemented());
handleUnimplemented(reader.getUnimplemented());
break;
case rpc::Message::ABORT:
doAbort(reader.getAbort());
handleAbort(reader.getAbort());
break;
case rpc::Message::CALL:
doCall(kj::mv(message), reader.getCall());
handleCall(kj::mv(message), reader.getCall());
break;
case rpc::Message::RETURN:
doReturn(kj::mv(message), reader.getReturn());
handleReturn(kj::mv(message), reader.getReturn());
break;
case rpc::Message::FINISH:
doFinish(reader.getFinish());
handleFinish(reader.getFinish());
break;
default: {
......@@ -1217,15 +1512,15 @@ private:
}
}
void doUnimplemented(const rpc::Message::Reader& message) {
void handleUnimplemented(const rpc::Message::Reader& message) {
// TODO(now)
}
void doAbort(const rpc::Exception::Reader& exception) {
void handleAbort(const rpc::Exception::Reader& exception) {
kj::throwRecoverableException(toException(exception));
}
void doCall(kj::Own<IncomingRpcMessage>&& message, const rpc::Call::Reader& call) {
void handleCall(kj::Own<IncomingRpcMessage>&& message, const rpc::Call::Reader& call) {
kj::Own<const ClientHook> capability;
auto target = call.getTarget();
......@@ -1320,13 +1615,19 @@ private:
}
}
void doReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
void handleReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
kj::Own<CapInjectorImpl> paramCapsToRelease;
auto lock = tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
KJ_REQUIRE(!question->isReturned, "Duplicate Return.") { return; }
question->isReturned = true;
KJ_REQUIRE(question->paramCaps != nullptr, "Duplicate Return.") { return; }
KJ_IF_MAYBE(pc, question->paramCaps) {
// Release these later, after unlocking.
paramCapsToRelease = kj::mv(*pc);
} else {
KJ_FAIL_REQUIRE("Duplicate return.") { return; }
}
for (ExportId retained: ret.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) {
......@@ -1336,15 +1637,10 @@ private:
}
}
// We can now release the caps in the parameter message. Well... we can release them once
// we release the lock, at least.
paramCapsToRelease = kj::mv(question->paramCaps);
// TODO(now): Handle exception/cancel response
auto response = kj::heap<RpcResponse>(*this, kj::mv(message), ret.getAnswer());
auto imbuedResults = response->getResults();
question->fulfiller->fulfill(Response<ObjectPointer>(imbuedResults, kj::mv(response)));
auto response = kj::refcounted<RpcResponse>(*this, kj::mv(message), ret.getAnswer());
question->fulfiller->fulfill(kj::mv(response));
if (question->pipeline == nullptr) {
lock->questions.erase(ret.getQuestionId());
......@@ -1354,7 +1650,7 @@ private:
}
}
void doFinish(const rpc::Finish::Reader& finish) {
void handleFinish(const rpc::Finish::Reader& finish) {
kj::Maybe<kj::Own<const PipelineHook>> pipelineToRelease;
auto lock = tables.lockExclusive();
......@@ -1399,6 +1695,10 @@ public:
Impl(VatNetworkBase& network, SturdyRefRestorerBase& restorer, const kj::EventLoop& eventLoop)
: network(network), restorer(restorer), eventLoop(eventLoop) {}
Capability::Client connect(_::StructReader reader) {
// TODO(now)
}
private:
VatNetworkBase& network;
SturdyRefRestorerBase& restorer;
......@@ -1411,7 +1711,7 @@ RpcSystemBase::RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& res
RpcSystemBase::~RpcSystemBase() noexcept(false) {}
Capability::Client RpcSystemBase::baseConnect(_::StructReader reader) {
impl->connect(reader);
return impl->connect(reader);
}
} // namespace _ (private)
......
......@@ -156,6 +156,9 @@ public:
inline ArrayPtr<T> asPtr() {
return ArrayPtr<T>(ptr, size_);
}
inline ArrayPtr<const T> asPtr() const {
return ArrayPtr<T>(ptr, size_);
}
inline size_t size() const { return size_; }
inline T& operator[](size_t index) const {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment