Commit 3070f627 authored by Kenton Varda's avatar Kenton Varda

Implement handling of Resolve messages. Though currently they are never sent.

parent e4a5344b
...@@ -278,10 +278,9 @@ TEST_F(RpcTest, Basic) { ...@@ -278,10 +278,9 @@ TEST_F(RpcTest, Basic) {
request1.setJ(true); request1.setJ(true);
auto promise1 = request1.send(); auto promise1 = request1.send();
auto request2 = client.bazRequest(); // We used to call bar() after baz(), hence the numbering, but this masked the case where the
initTestMessage(request2.initS()); // RPC system actually disconnected on bar() (thus returning an exception, which we decided
auto promise2 = request2.send(); // was expected).
bool barFailed = false; bool barFailed = false;
auto request3 = client.barRequest(); auto request3 = client.barRequest();
auto promise3 = loop.there(request3.send(), auto promise3 = loop.there(request3.send(),
...@@ -291,6 +290,10 @@ TEST_F(RpcTest, Basic) { ...@@ -291,6 +290,10 @@ TEST_F(RpcTest, Basic) {
barFailed = true; barFailed = true;
}); });
auto request2 = client.bazRequest();
initTestMessage(request2.initS());
auto promise2 = request2.send();
EXPECT_EQ(0, restorer.callCount); EXPECT_EQ(0, restorer.callCount);
auto response1 = loop.wait(kj::mv(promise1)); auto response1 = loop.wait(kj::mv(promise1));
......
...@@ -241,6 +241,7 @@ public: ...@@ -241,6 +241,7 @@ public:
disconnectFulfiller(kj::mv(disconnectFulfiller)), disconnectFulfiller(kj::mv(disconnectFulfiller)),
tasks(eventLoop, *this) { tasks(eventLoop, *this) {
tasks.add(messageLoop()); tasks.add(messageLoop());
tables.getWithoutLock().resolutionChainTail = kj::refcounted<ResolutionChain>();
} }
kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) { kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) {
...@@ -255,7 +256,8 @@ public: ...@@ -255,7 +256,8 @@ public:
// We need a dummy paramCaps since null normally indicates that the question has completed. // We need a dummy paramCaps since null normally indicates that the question has completed.
question.paramCaps = kj::heap<CapInjectorImpl>(*this); question.paramCaps = kj::heap<CapInjectorImpl>(*this);
questionRef = kj::refcounted<QuestionRef>(*this, questionId, kj::mv(paf.fulfiller)); questionRef = kj::refcounted<QuestionRef>(*this, questionId, kj::mv(paf.fulfiller),
kj::addRef(*lock->resolutionChainTail));
question.selfRef = *questionRef; question.selfRef = *questionRef;
paf.promise.attach(kj::addRef(*questionRef)); paf.promise.attach(kj::addRef(*questionRef));
...@@ -347,6 +349,8 @@ public: ...@@ -347,6 +349,8 @@ public:
} }
private: private:
class ResolutionChain;
class RpcClient;
class ImportClient; class ImportClient;
class PromiseClient; class PromiseClient;
class CapInjectorImpl; class CapInjectorImpl;
...@@ -413,9 +417,14 @@ private: ...@@ -413,9 +417,14 @@ private:
Import& operator=(Import&&) = default; Import& operator=(Import&&) = default;
// If we don't explicitly write all this, we get some stupid error deep in STL. // If we don't explicitly write all this, we get some stupid error deep in STL.
kj::Maybe<ImportClient&> client; kj::Maybe<ImportClient&> importClient;
// Becomes null when the import is destroyed. // Becomes null when the import is destroyed.
kj::Maybe<RpcClient&> appClient;
// Either a copy of importClient, or, in the case of promises, the wrapping PromiseClient.
// Becomes null when it is discarded *or* when the import is destroyed (e.g. the promise is
// resolved and the import is no longer needed).
kj::Maybe<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> promiseFulfiller; kj::Maybe<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> promiseFulfiller;
// If non-null, the import is a promise. // If non-null, the import is a promise.
}; };
...@@ -438,6 +447,12 @@ private: ...@@ -438,6 +447,12 @@ private:
std::unordered_map<const ClientHook*, ExportId> exportsByCap; std::unordered_map<const ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table. // Maps already-exported ClientHook objects to their ID in the export table.
kj::Own<ResolutionChain> resolutionChainTail;
// The end of the resolution chain. This node actually isn't filled in yet, but it will be
// filled in and the chain will be extended with a new node any time a `Resolve` is received.
// CapExtractors need to hold a ref to resolutionChainTail to prevent resolved promises from
// becoming invalid while the app is still processing the message. See `ResolutionChain`.
kj::Maybe<kj::Exception> networkException; kj::Maybe<kj::Exception> networkException;
// If the connection has failed, this is the exception describing the failure. All future // If the connection has failed, this is the exception describing the failure. All future
// calls should throw this exception. // calls should throw this exception.
...@@ -446,6 +461,68 @@ private: ...@@ -446,6 +461,68 @@ private:
kj::TaskSet tasks; kj::TaskSet tasks;
// =====================================================================================
class ResolutionChain: public kj::Refcounted {
// A chain of pending promise resolutions which may affect messages that are still being
// processed.
//
// When a `Resolve` message comes in, we can't just handle it and then release the original
// promise import all at once, because it's possible that the application is still processing
// the `params` or `results` from a previous call, and that it will encounter an instance of
// the promise as it does. We need to hold off on the release until we've actually gotten
// through all outstanding messages.
//
// To that end, we have the resolution chain. Each time a `CapExtractorImpl` is created --
// representing a message to be consumed by the application -- it takes a reference to the
// current end of the chain. When a `Resolve` message arrives, it is added to the end of the
// chain, and thus all `CapExtractorImpl`s that exist at that point now hold a reference to it,
// but new `CapExtractorImpl`s will not. Once all references are dropped, the original promise
// can be released.
//
// The connection state actually holds one instance of ResolutionChain which doesn't yet have
// a promise attached to it, representing the end of the chain. This is what allows a new
// resolution to be "added to the end" and have existing `CapExtractorImpl`s suddenly be
// holding a reference to it.
public:
kj::Own<ResolutionChain> add(ExportId importId,
kj::Own<const ClientHook>&& replacement) {
// Add the a new resolution to the chain. Returns the new end-of-chain.
this->importId = importId;
this->replacement = kj::mv(replacement);
auto result = kj::refcounted<ResolutionChain>();
next = kj::addRef(*result);
filled = true;
return kj::mv(result);
}
kj::Maybe<kj::Own<const ClientHook>> find(ExportId importId) const {
// Look for the given import ID in the resolution chain.
const ResolutionChain* ptr = this;
while (ptr->filled) {
if (ptr->importId == importId) {
return ptr->replacement->addRef();
}
ptr = ptr->next;
}
return nullptr;
}
private:
kj::Own<const ResolutionChain> next;
bool filled = false;
ExportId importId;
kj::Own<const ClientHook> replacement;
};
// ===================================================================================== // =====================================================================================
// ClientHook implementations // ClientHook implementations
...@@ -547,10 +624,11 @@ private: ...@@ -547,10 +624,11 @@ private:
// Remove self from the import table, if the table is still pointing at us. (It's possible // Remove self from the import table, if the table is still pointing at us. (It's possible
// that another thread attempted to obtain this import just as the destructor started, in // that another thread attempted to obtain this import just as the destructor started, in
// which case that other thread will have constructed a new ImportClient and placed it in // which case that other thread will have constructed a new ImportClient and placed it in
// the import table.) // the import table. Therefore, we must actually verify that the import table points at
// this object.)
auto lock = connectionState->tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
KJ_IF_MAYBE(i, import->client) { KJ_IF_MAYBE(i, import->importClient) {
if (i == this) { if (i == this) {
lock->imports.erase(importId); lock->imports.erase(importId);
} }
...@@ -564,17 +642,9 @@ private: ...@@ -564,17 +642,9 @@ private:
} }
} }
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() { void addRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null // Add a new RemoteRef and return a new ref to this client representing it.
// if this client is being deleted in another thread, in which case the caller should ++remoteRefcount;
// construct a new one.
KJ_IF_MAYBE(ref, kj::tryAddRef(*this)) {
++remoteRefcount;
return kj::mv(*ref);
} else {
return nullptr;
}
} }
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(
...@@ -673,9 +743,11 @@ private: ...@@ -673,9 +743,11 @@ private:
public: public:
PromiseClient(const RpcConnectionState& connectionState, PromiseClient(const RpcConnectionState& connectionState,
kj::Own<const ClientHook> initial, kj::Own<const ClientHook> initial,
kj::Promise<kj::Own<const ClientHook>> eventual) kj::Promise<kj::Own<const ClientHook>> eventual,
kj::Maybe<ExportId> importId)
: RpcClient(connectionState), : RpcClient(connectionState),
inner(kj::mv(initial)), inner(kj::mv(initial)),
importId(importId),
fork(connectionState.eventLoop.fork(kj::mv(eventual))), fork(connectionState.eventLoop.fork(kj::mv(eventual))),
resolveSelfPromise(connectionState.eventLoop.there(fork.addBranch(), resolveSelfPromise(connectionState.eventLoop.there(fork.addBranch(),
[this](kj::Own<const ClientHook>&& resolution) { [this](kj::Own<const ClientHook>&& resolution) {
...@@ -692,6 +764,23 @@ private: ...@@ -692,6 +764,23 @@ private:
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop); resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
} }
~PromiseClient() noexcept(false) {
KJ_IF_MAYBE(id, importId) {
// This object is representing an import promise. That means the import table may still
// contain a pointer back to it. Remove that pointer. Note that we have to verify that
// the import still exists and the pointer still points back to this object because this
// object may actually outlive the import.
auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(*id)) {
KJ_IF_MAYBE(c, import->appClient) {
if (c == this) {
import->appClient = nullptr;
}
}
}
}
}
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override { rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto cap = inner.lockExclusive()->get()->addRef(); auto cap = inner.lockExclusive()->get()->addRef();
...@@ -716,6 +805,7 @@ private: ...@@ -716,6 +805,7 @@ private:
private: private:
kj::MutexGuarded<kj::Own<const ClientHook>> inner; kj::MutexGuarded<kj::Own<const ClientHook>> inner;
kj::Maybe<ExportId> importId;
kj::ForkedPromise<kj::Own<const ClientHook>> fork; kj::ForkedPromise<kj::Own<const ClientHook>> fork;
// Keep this last, because the continuation uses *this, so it should be destroyed first to // Keep this last, because the continuation uses *this, so it should be destroyed first to
...@@ -784,8 +874,10 @@ private: ...@@ -784,8 +874,10 @@ private:
// Reads CapDescriptors from a received message. // Reads CapDescriptors from a received message.
public: public:
CapExtractorImpl(const RpcConnectionState& connectionState) CapExtractorImpl(const RpcConnectionState& connectionState,
: connectionState(connectionState) {} kj::Own<const ResolutionChain> resolutionChain)
: connectionState(connectionState),
resolutionChain(kj::mv(resolutionChain)) {}
~CapExtractorImpl() noexcept(false) { ~CapExtractorImpl() noexcept(false) {
KJ_ASSERT(retainedCaps.getWithoutLock().size() == 0, KJ_ASSERT(retainedCaps.getWithoutLock().size() == 0,
...@@ -806,9 +898,27 @@ private: ...@@ -806,9 +898,27 @@ private:
return (count * sizeof(ExportId) + (sizeof(ExportId) - 1)) / sizeof(word); return (count * sizeof(ExportId) + (sizeof(ExportId) - 1)) / sizeof(word);
} }
Orphan<List<ExportId>> finalizeRetainedCaps(Orphanage orphanage) { struct FinalizedRetainedCaps {
// Called on finalization, when the lock is no longer needed. // List of capabilities extracted from this message which are to be retained past the
// message's release.
Orphan<List<ExportId>> exportList;
// List of export IDs, to be placed in the Return/Finish message.
kj::Vector<kj::Own<const ClientHook>> refs;
// List of ClientHooks which need to be kept live until the message is sent, to prevent
// their premature release.
};
FinalizedRetainedCaps finalizeRetainedCaps(Orphanage orphanage) {
// Build the list of export IDs found in this message which are to be retained past the
// message's release.
//
// `capsToKeepUntilSend` will be filled in with
// Called on finalization, when all extractions have ceased, so we can skip the lock.
kj::Vector<ExportId> retainedCaps = kj::mv(this->retainedCaps.getWithoutLock()); kj::Vector<ExportId> retainedCaps = kj::mv(this->retainedCaps.getWithoutLock());
kj::Vector<kj::Own<const ClientHook>> refs(retainedCaps.size());
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
...@@ -816,11 +926,13 @@ private: ...@@ -816,11 +926,13 @@ private:
for (ExportId importId: retainedCaps) { for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID. // Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
KJ_IF_MAYBE(i, import->client) { KJ_IF_MAYBE(ic, import->importClient) {
if (i->tryAddRemoteRef() != nullptr) { KJ_IF_MAYBE(ref, kj::tryAddRef(*ic)) {
// Import indeed still exists! We are responsible for retaining it. // Import indeed still exists! We'll return it in the retained caps, which means it
// TODO(now): Do we need to hold on to the ref that tryAddRemoteRef() returned? // now has a new remote ref.
ic->addRemoteRef();
*actualRetained++ = importId; *actualRetained++ = importId;
refs.add(kj::mv(*ref));
} }
} }
} }
...@@ -836,24 +948,65 @@ private: ...@@ -836,24 +948,65 @@ private:
resultBuilder.set(count++, *iter); resultBuilder.set(count++, *iter);
} }
return kj::mv(result); return FinalizedRetainedCaps { kj::mv(result), kj::mv(refs) };
}
static kj::Own<const ClientHook> extractCapAndAddRef(
const RpcConnectionState& connectionState, Tables& lockedTables,
rpc::CapDescriptor::Reader descriptor) {
// Interpret the given capability descriptor and, if it is an import, immediately give it
// a remote ref. This is called when interpreting messages that have a CapabilityDescriptor
// but do not have a corresponding response message where a list of retained caps is given.
// In these cases, the cap is always assumed retained, and must be explicitly released.
// For example, the 'Resolve' message contains a capability which is presumed to be retained.
return extractCapImpl(connectionState, lockedTables, descriptor,
*lockedTables.resolutionChainTail, nullptr);
} }
// implements CapDescriptor ------------------------------------------------ // implements CapDescriptor ------------------------------------------------
kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const override { kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const override {
return extractCapImpl(connectionState, *connectionState.tables.lockExclusive(), descriptor,
*resolutionChain, retainedCaps);
}
private:
const RpcConnectionState& connectionState;
kj::Own<const ResolutionChain> resolutionChain;
// Reference to the resolution chain, which prevents any promises that might be extracted from
// this message from being invalidated by `Resolve` messages before extraction is finished.
// Simply holding on to the chain keeps the import table entries valid.
kj::MutexGuarded<kj::Vector<ExportId>> retainedCaps;
// Imports which we are responsible for retaining, should they still exist at the time that
// this message is released.
static kj::Own<const ClientHook> extractCapImpl(
const RpcConnectionState& connectionState, Tables& tables,
rpc::CapDescriptor::Reader descriptor,
const ResolutionChain& resolutionChain,
kj::Maybe<const kj::MutexGuarded<kj::Vector<ExportId>>&> retainedCaps) {
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
case rpc::CapDescriptor::SENDER_PROMISE: { case rpc::CapDescriptor::SENDER_PROMISE: {
ExportId importId = descriptor.getSenderHosted(); ExportId importId = descriptor.getSenderHosted();
auto lock = connectionState.tables.lockExclusive(); // First check to see if this import ID is a promise that has resolved since when this
// message was received. In this case, the original import ID will already have been
// dropped and could even have been reused for another capability. Luckily, the
// resolution chain holds the capability we actually want.
KJ_IF_MAYBE(resolution, resolutionChain.find(importId)) {
return kj::mv(*resolution);
}
auto& import = lock->imports[importId]; // No recent resolutions. Check the import table then.
KJ_IF_MAYBE(i, import.client) { auto& import = tables.imports[importId];
KJ_IF_MAYBE(c, import.appClient) {
// The import is already on the table, but it could be being deleted in another // The import is already on the table, but it could be being deleted in another
// thread. // thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*i)) { KJ_IF_MAYBE(ref, kj::tryAddRef(*c)) {
// We successfully grabbed a reference to the import without it being deleted in // We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take // another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and // responsibility for retaining it. We can just return the existing object and
...@@ -865,9 +1018,17 @@ private: ...@@ -865,9 +1018,17 @@ private:
// No import for this ID exists currently, so create one. // No import for this ID exists currently, so create one.
kj::Own<ImportClient> importClient = kj::Own<ImportClient> importClient =
kj::refcounted<ImportClient>(connectionState, importId); kj::refcounted<ImportClient>(connectionState, importId);
import.client = *importClient; import.importClient = *importClient;
KJ_IF_MAYBE(rc, retainedCaps) {
// We need to retain this import later if it still exists.
rc->lockExclusive()->add(importId);
} else {
// Automatically increment the refcount.
importClient->addRemoteRef();
}
kj::Own<ClientHook> result; kj::Own<RpcClient> result;
if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) { if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) {
// TODO(now): Check for pending `Resolve` messages replacing this import ID, and if // TODO(now): Check for pending `Resolve` messages replacing this import ID, and if
// one exists, use that client instead. // one exists, use that client instead.
...@@ -876,29 +1037,26 @@ private: ...@@ -876,29 +1037,26 @@ private:
import.promiseFulfiller = kj::mv(paf.fulfiller); import.promiseFulfiller = kj::mv(paf.fulfiller);
paf.promise.attach(kj::addRef(*importClient)); paf.promise.attach(kj::addRef(*importClient));
result = kj::refcounted<PromiseClient>( result = kj::refcounted<PromiseClient>(
connectionState, kj::mv(importClient), kj::mv(paf.promise)); connectionState, kj::mv(importClient), kj::mv(paf.promise), importId);
} else { } else {
result = kj::mv(importClient); result = kj::mv(importClient);
} }
// Note that we need to retain this import later if it still exists. import.appClient = *result;
retainedCaps.lockExclusive()->add(importId);
return kj::mv(result); return kj::mv(result);
} }
case rpc::CapDescriptor::RECEIVER_HOSTED: { case rpc::CapDescriptor::RECEIVER_HOSTED: {
auto lock = connectionState.tables.lockExclusive(); // TODO(perf): shared? KJ_IF_MAYBE(exp, tables.exports.find(descriptor.getReceiverHosted())) {
KJ_IF_MAYBE(exp, lock->exports.find(descriptor.getReceiverHosted())) {
return exp->clientHook->addRef(); return exp->clientHook->addRef();
} }
return newBrokenCap("invalid 'receiverHosted' export ID"); return newBrokenCap("invalid 'receiverHosted' export ID");
} }
case rpc::CapDescriptor::RECEIVER_ANSWER: { case rpc::CapDescriptor::RECEIVER_ANSWER: {
auto lock = connectionState.tables.lockExclusive();
auto promisedAnswer = descriptor.getReceiverAnswer(); auto promisedAnswer = descriptor.getReceiverAnswer();
KJ_IF_MAYBE(answer, lock->answers.find(promisedAnswer.getQuestionId())) { KJ_IF_MAYBE(answer, tables.answers.find(promisedAnswer.getQuestionId())) {
if (answer->active) { if (answer->active) {
KJ_IF_MAYBE(pipeline, answer->pipeline) { KJ_IF_MAYBE(pipeline, answer->pipeline) {
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
...@@ -920,13 +1078,6 @@ private: ...@@ -920,13 +1078,6 @@ private:
return newBrokenCap("unknown CapDescriptor type"); return newBrokenCap("unknown CapDescriptor type");
} }
} }
private:
const RpcConnectionState& connectionState;
kj::MutexGuarded<kj::Vector<ExportId>> retainedCaps;
// Imports which we are responsible for retaining, should they still exist at the time that
// this message is released.
}; };
// ----------------------------------------------------------------- // -----------------------------------------------------------------
...@@ -944,7 +1095,6 @@ private: ...@@ -944,7 +1095,6 @@ private:
if (lock->networkException == nullptr) { if (lock->networkException == nullptr) {
for (auto exportId: exports) { for (auto exportId: exports) {
KJ_DBG(&connectionState, exportId);
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId)); auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
if (--exp.refcount == 0) { if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook)); clientsToRelease.add(kj::mv(exp.clientHook));
...@@ -1042,21 +1192,24 @@ private: ...@@ -1042,21 +1192,24 @@ private:
public: public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id, inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id,
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller) kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller,
kj::Own<const ResolutionChain> resolutionChain)
: connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)), : connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)),
resultCaps(connectionState) {} resultCaps(connectionState, kj::mv(resolutionChain)) {}
~QuestionRef() { ~QuestionRef() {
// Send the "Finish" message. // Send the "Finish" message.
auto message = connectionState->connection->newOutgoingMessage( {
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true)); auto message = connectionState->connection->newOutgoingMessage(
auto builder = message->getBody().getAs<rpc::Message>().initFinish(); messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true));
builder.setQuestionId(id); auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id);
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
message->send(); auto retainedCaps = resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder));
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
message->send();
}
// Check if the question has returned and, if so, remove it from the table. // Check if the question has returned and, if so, remove it from the table.
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
...@@ -1157,7 +1310,8 @@ private: ...@@ -1157,7 +1310,8 @@ private:
question.paramCaps = kj::mv(injector); question.paramCaps = kj::mv(injector);
questionRef = kj::refcounted<QuestionRef>( questionRef = kj::refcounted<QuestionRef>(
*connectionState, questionId, kj::mv(paf.fulfiller)); *connectionState, questionId, kj::mv(paf.fulfiller),
kj::addRef(*lock->resolutionChainTail));
question.selfRef = *questionRef; question.selfRef = *questionRef;
message->send(); message->send();
...@@ -1170,7 +1324,7 @@ private: ...@@ -1170,7 +1324,7 @@ private:
auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promise)); auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promise));
auto appPromise = forkedPromise.addBranch().thenInAnyThread( auto appPromise = forkedPromise.addBranch().thenInAnyThread(
[](kj::Own<const RpcResponse>&& response) { [=](kj::Own<const RpcResponse>&& response) {
auto reader = response->getResults(); auto reader = response->getResults();
return Response<ObjectPointer>(reader, kj::mv(response)); return Response<ObjectPointer>(reader, kj::mv(response));
}); });
...@@ -1244,7 +1398,7 @@ private: ...@@ -1244,7 +1398,7 @@ private:
})); }));
return kj::refcounted<PromiseClient>( return kj::refcounted<PromiseClient>(
*connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise)); *connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise), nullptr);
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops); return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else { } else {
...@@ -1346,11 +1500,12 @@ private: ...@@ -1346,11 +1500,12 @@ private:
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId, RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params) kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params,
kj::Own<const ResolutionChain> resolutionChain)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
questionId(questionId), questionId(questionId),
request(kj::mv(request)), request(kj::mv(request)),
requestCapExtractor(connectionState), requestCapExtractor(connectionState, kj::mv(resolutionChain)),
requestCapContext(requestCapExtractor), requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)), params(requestCapContext.imbue(params)),
returnMessage(nullptr) {} returnMessage(nullptr) {}
...@@ -1360,8 +1515,9 @@ private: ...@@ -1360,8 +1515,9 @@ private:
if (response == nullptr) getResults(1); // force initialization of response if (response == nullptr) getResults(1); // force initialization of response
returnMessage.setQuestionId(questionId); returnMessage.setQuestionId(questionId);
returnMessage.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps( auto retainedCaps = requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(returnMessage))); Orphanage::getForMessageContaining(returnMessage));
returnMessage.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
KJ_ASSERT_NONNULL(response)->send(); KJ_ASSERT_NONNULL(response)->send();
} }
...@@ -1374,8 +1530,9 @@ private: ...@@ -1374,8 +1530,9 @@ private:
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
builder.setQuestionId(questionId); builder.setQuestionId(questionId);
builder.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps( auto retainedCaps = requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder))); Orphanage::getForMessageContaining(builder));
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
fromException(exception, builder.initException()); fromException(exception, builder.initException());
message->send(); message->send();
...@@ -1388,8 +1545,9 @@ private: ...@@ -1388,8 +1545,9 @@ private:
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
builder.setQuestionId(questionId); builder.setQuestionId(questionId);
builder.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps( auto retainedCaps = requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder))); Orphanage::getForMessageContaining(builder));
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
builder.setCanceled(); builder.setCanceled();
message->send(); message->send();
...@@ -1626,7 +1784,7 @@ private: ...@@ -1626,7 +1784,7 @@ private:
break; break;
case rpc::Message::RESOLVE: case rpc::Message::RESOLVE:
// TODO(now) handleResolve(reader.getResolve());
break; break;
case rpc::Message::RELEASE: case rpc::Message::RELEASE:
...@@ -1720,8 +1878,11 @@ private: ...@@ -1720,8 +1878,11 @@ private:
} }
QuestionId questionId = call.getQuestionId(); QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams()); *this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail));
auto promiseAndPipeline = capability->call( auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef()); call.getInterfaceId(), call.getMethodId(), context->addRef());
...@@ -1753,6 +1914,12 @@ private: ...@@ -1753,6 +1914,12 @@ private:
contextPtr->sendReturn(); contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) { }, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception)); contextPtr->sendErrorReturn(kj::mv(exception));
}).then([]() {
// Success.
}, [&](kj::Exception&& exception) {
// We never actually wait on `asyncOp` so we need to manually report exceptions.
// TODO(cleanup): Perhaps there should be a better, more-automated approach to this?
taskFailed(kj::mv(exception));
}); });
answer.asyncOp.attach(kj::mv(context)); answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop); answer.asyncOp.eagerlyEvaluate(eventLoop);
...@@ -1851,6 +2018,51 @@ private: ...@@ -1851,6 +2018,51 @@ private:
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Level 1 // Level 1
void handleResolve(const rpc::Resolve::Reader& resolve) {
kj::Own<ResolutionChain> oldResolutionChainTail; // must be freed outside of lock
auto lock = tables.lockExclusive();
kj::Own<const ClientHook> replacement;
// Extract the replacement capability.
switch (resolve.which()) {
case rpc::Resolve::CAP:
replacement = CapExtractorImpl::extractCapAndAddRef(*this, *lock, resolve.getCap());
break;
case rpc::Resolve::EXCEPTION:
replacement = newBrokenCap(toException(resolve.getException()));
break;
case rpc::Resolve::CANCELED:
// Right, this can't possibly affect anything, then.
//
// TODO(now): Am I doing something wrong or is this not needed?
return;
default:
KJ_FAIL_REQUIRE("Unknown 'Resolve' type.") { return; }
}
// Extend the resolution chain.
auto oldTail = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = oldTail->add(resolve.getPromiseId(), kj::mv(replacement));
lock.release(); // in case oldTail is destroyed
// If the import is on the table, fulfill it.
KJ_IF_MAYBE(import, lock->imports.find(resolve.getPromiseId())) {
KJ_IF_MAYBE(fulfiller, import->promiseFulfiller) {
// OK, this is in fact an unfulfilled promise!
fulfiller->get()->fulfill(kj::mv(replacement));
} else if (import->importClient != nullptr) {
// It appears this is a valid entry on the import table, but was not expected to be a
// promise.
KJ_FAIL_REQUIRE("Got 'Resolve' for a non-promise import.") { break; }
}
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Level 2 // Level 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment