Commit f4ab3d2d authored by Kenton Varda's avatar Kenton Varda

Refactor how promise imports are represented in the RPC implementation.

parent 2aa37a8a
...@@ -280,9 +280,17 @@ struct Export { ...@@ -280,9 +280,17 @@ struct Export {
template <typename ImportClient> template <typename ImportClient>
struct Import { struct Import {
ImportClient* client = nullptr; Import() = default;
// Normally I'd want this to be Maybe<ImportClient&>, but GCC's unordered_map doesn't seem to Import(const Import&) = delete;
// like DisableConstCopy types. Import(Import&&) = default;
Import& operator=(Import&&) = default;
// If we don't explicitly write all this, we get some stupid error deep in STL.
kj::Maybe<ImportClient&> client;
// Becomes null when the import is destroyed.
kj::Maybe<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> promiseFulfiller;
// If non-null, the import is a promise.
}; };
// ======================================================================================= // =======================================================================================
...@@ -339,21 +347,26 @@ public: ...@@ -339,21 +347,26 @@ public:
auto pipeline = kj::refcounted<RpcPipeline>( auto pipeline = kj::refcounted<RpcPipeline>(
*this, questionId, eventLoop.fork(kj::mv(promiseWithQuestionRef))); *this, questionId, eventLoop.fork(kj::mv(promiseWithQuestionRef)));
return kj::refcounted<PromisedAnswerClient>(*this, kj::mv(pipeline), nullptr); return pipeline->getPipelinedCap(kj::Array<const PipelineOp>(nullptr));
} }
void taskFailed(kj::Exception&& exception) override { void taskFailed(kj::Exception&& exception) override {
{ {
kj::Exception networkException(
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
"", 0, kj::str("Disconnected: ", exception.getDescription()));
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease; kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease; kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease; kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
if (lock->networkException != nullptr) {
// Oops, already disconnected.
return;
}
kj::Exception networkException(
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription()));
// All current questions complete with exceptions. // All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id, lock->questions.forEach([&](QuestionId id,
Question<CapInjectorImpl, RpcPipeline, RpcResponse>& question) { Question<CapInjectorImpl, RpcPipeline, RpcResponse>& question) {
...@@ -379,8 +392,8 @@ public: ...@@ -379,8 +392,8 @@ public:
}); });
lock->imports.forEach([&](ExportId id, Import<ImportClient>& import) { lock->imports.forEach([&](ExportId id, Import<ImportClient>& import) {
if (import.client != nullptr) { KJ_IF_MAYBE(f, import.promiseFulfiller) {
import.client->disconnect(kj::cp(networkException)); f->get()->reject(kj::cp(networkException));
} }
}); });
...@@ -407,6 +420,7 @@ private: ...@@ -407,6 +420,7 @@ private:
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller; kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
class ImportClient; class ImportClient;
class PromiseClient;
class CapInjectorImpl; class CapInjectorImpl;
class CapExtractorImpl; class CapExtractorImpl;
class RpcPipeline; class RpcPipeline;
...@@ -414,10 +428,11 @@ private: ...@@ -414,10 +428,11 @@ private:
class RpcResponse; class RpcResponse;
struct Tables { struct Tables {
ExportTable<ExportId, Export> exports;
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline, RpcResponse>> questions; ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline, RpcResponse>> questions;
ImportTable<QuestionId, Answer<RpcCallContext>> answers; ImportTable<QuestionId, Answer<RpcCallContext>> answers;
ExportTable<ExportId, Export> exports;
ImportTable<ExportId, Import<ImportClient>> imports; ImportTable<ExportId, Import<ImportClient>> imports;
// The order of the tables is important for correct destruction.
std::unordered_map<const ClientHook*, ExportId> exportsByCap; std::unordered_map<const ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table. // Maps already-exported ClientHook objects to their ID in the export table.
...@@ -455,6 +470,7 @@ private: ...@@ -455,6 +470,7 @@ private:
const RpcConnectionState& connectionState; const RpcConnectionState& connectionState;
}; };
// TODO(now): unused?
ExportDisposer exportDisposer; ExportDisposer exportDisposer;
// ===================================================================================== // =====================================================================================
...@@ -546,12 +562,13 @@ private: ...@@ -546,12 +562,13 @@ private:
kj::Own<const RpcConnectionState> connectionState; kj::Own<const RpcConnectionState> connectionState;
}; };
class ImportClient: public RpcClient { class ImportClient final: public RpcClient {
protected: // A ClientHook that wraps an entry in the import table.
public:
ImportClient(const RpcConnectionState& connectionState, ExportId importId) ImportClient(const RpcConnectionState& connectionState, ExportId importId)
: RpcClient(connectionState), importId(importId) {} : RpcClient(connectionState), importId(importId) {}
public:
~ImportClient() noexcept(false) { ~ImportClient() noexcept(false) {
{ {
// Remove self from the import table, if the table is still pointing at us. (It's possible // Remove self from the import table, if the table is still pointing at us. (It's possible
...@@ -560,11 +577,13 @@ private: ...@@ -560,11 +577,13 @@ private:
// the import table.) // the import table.)
auto lock = connectionState->tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client == this) { KJ_IF_MAYBE(i, import->client) {
if (i == this) {
lock->imports.erase(importId); lock->imports.erase(importId);
} }
} }
} }
}
// Send a message releasing our remote references. // Send a message releasing our remote references.
if (remoteRefcount > 0) { if (remoteRefcount > 0) {
...@@ -572,13 +591,6 @@ private: ...@@ -572,13 +591,6 @@ private:
} }
} }
virtual bool settle(kj::Own<const ClientHook> replacement) = 0;
// Replace the PromiseImportClient with its resolution. Returns false if this is not a promise
// (i.e. it is a SettledImportClient).
virtual void disconnect(kj::Exception&& exception) = 0;
// Cause whenMoreResolved() to fail.
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() { kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null // Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should // if this client is being deleted in another thread, in which case the caller should
...@@ -620,6 +632,10 @@ private: ...@@ -620,6 +632,10 @@ private:
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request)); return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} }
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
private: private:
ExportId importId; ExportId importId;
...@@ -627,167 +643,113 @@ private: ...@@ -627,167 +643,113 @@ private:
// Number of times we've received this import from the peer. // Number of times we've received this import from the peer.
}; };
class SettledImportClient final: public ImportClient { class PipelineClient final: public RpcClient {
// A ClientHook representing a pipelined promise. Always wrapped in PromiseClient.
public: public:
inline SettledImportClient(const RpcConnectionState& connectionState, ExportId importId) PipelineClient(const RpcConnectionState& connectionState,
: ImportClient(connectionState, importId) {} kj::Own<const RpcPipeline>&& pipeline,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), pipeline(kj::mv(pipeline)), ops(kj::mv(ops)) {}
bool settle(kj::Own<const ClientHook> replacement) override { kj::Maybe<ExportId> writeDescriptor(
return false; rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
return pipeline->writeDescriptor(descriptor, tables, ops);
} }
void disconnect(kj::Exception&& exception) override { kj::Maybe<kj::Own<const ClientHook>> writeTarget(
// nothing rpc::Call::Target::Builder target) const override {
// TODO(now): The pipeline may redirect to the resolution before PromiseClient has resolved.
// This could lead to a race condition if PromiseClient implements embargoes.
return pipeline->writeTarget(target, ops);
} }
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { // implements ClientHook -----------------------------------------
return nullptr;
}
};
class PromiseImportClient final: public ImportClient { Request<ObjectPointer, ObjectPointer> newCall(
public: uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
PromiseImportClient(const RpcConnectionState& connectionState, ExportId importId) auto request = kj::heap<RpcRequest>(
: ImportClient(connectionState, importId), *connectionState, firstSegmentWordSize, kj::addRef(*this));
fork(nullptr) { auto callBuilder = request->getCall();
auto paf = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>(connectionState.eventLoop);
fulfiller = kj::mv(paf.fulfiller);
fork = connectionState.eventLoop.fork(kj::mv(paf.promise));
}
bool settle(kj::Own<const ClientHook> replacement) override { callBuilder.setInterfaceId(interfaceId);
fulfiller->fulfill(kj::mv(replacement)); callBuilder.setMethodId(methodId);
return true;
}
void disconnect(kj::Exception&& exception) override { auto root = request->getRoot();
fulfiller->reject(kj::mv(exception)); return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} }
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves.
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
// We need the returned promise to hold a reference back to this object, so that it doesn't return nullptr;
// disappear while the promise is still outstanding.
return fork.addBranch().thenInAnyThread(kj::mvCapture(kj::addRef(*this),
[](kj::Own<const PromiseImportClient>&&, kj::Own<const ClientHook>&& replacement) {
return kj::mv(replacement);
}));
} }
private: private:
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller; kj::Own<const RpcPipeline> pipeline;
kj::ForkedPromise<kj::Own<const ClientHook>> fork; kj::Array<PipelineOp> ops;
}; };
class PromisedAnswerClient final: public RpcClient { class PromiseClient final: public RpcClient {
// A ClientHook that initially wraps one client (in practice, an ImportClient or a
// PipelineClient) and then, later on, redirects to some other client.
public: public:
PromisedAnswerClient(const RpcConnectionState& connectionState, PromiseClient(const RpcConnectionState& connectionState,
kj::Own<const RpcPipeline>&& pipeline, kj::Own<const ClientHook> initial,
kj::Array<PipelineOp>&& ops) kj::Promise<kj::Own<const ClientHook>> eventual)
: RpcClient(connectionState), ops(kj::mv(ops)), : RpcClient(connectionState),
resolveSelfPromise(connectionState.eventLoop.there(pipeline->onResponse(), inner(kj::mv(initial)),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> { fork(connectionState.eventLoop.fork(kj::mv(eventual))),
resolve(kj::mv(response)); resolveSelfPromise(connectionState.eventLoop.there(fork.addBranch(),
return kj::READY_NOW; // hack to force eager resolution. [this](kj::Own<const ClientHook>&& resolution) {
}, [this](kj::Exception&& exception) -> kj::Promise<void> { resolve(kj::mv(resolution));
resolve(kj::mv(exception)); }, [this](kj::Exception&& exception) {
return kj::READY_NOW; // hack to force eager resolution. resolve(newBrokenCap(kj::mv(exception)));
})) { })) {
state.getWithoutLock().init<Waiting>(kj::mv(pipeline)); // Create a client that starts out forwarding all calls to `initial` but, once `eventual`
// resolves, will forward there instead. In addition, `whenMoreResolved()` will return a fork
// of `eventual`. Note that this means the application could hold on to `eventual` even after
// the `PromiseClient` is destroyed; `eventual` must therefore make sure to hold references to
// anything that needs to stay alive in order to resolve it correctly (such as making sure the
// import ID is not released).
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
} }
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override { rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto lock = state.lockShared(); auto cap = inner.lockExclusive()->get()->addRef();
return connectionState->writeDescriptor(kj::mv(cap), descriptor, tables);
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
} else if (lock->is<Resolved>()) {
return connectionState->writeDescriptor(
lock->get<Resolved>()->addRef(), descriptor, tables);
} else {
return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
}
} }
kj::Maybe<kj::Own<const ClientHook>> writeTarget( kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override { rpc::Call::Target::Builder target) const override {
auto lock = state.lockShared(); return connectionState->writeTarget(**inner.lockExclusive(), target);
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) {
return connectionState->writeTarget(*lock->get<Resolved>(), target);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
} }
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockShared(); return inner.lockExclusive()->get()->newCall(interfaceId, methodId, firstSegmentWordSize);
if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()))->newCall(
interfaceId, methodId, firstSegmentWordSize);
}
} }
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockShared(); return fork.addBranch();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->onResponse().thenInAnyThread(
kj::mvCapture(kj::heapArray(ops.asPtr()),
[](kj::Array<PipelineOp>&& ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
} else if (lock->is<Resolved>()) {
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(lock->get<Broken>()));
}
} }
private: private:
kj::Array<PipelineOp> ops; kj::MutexGuarded<kj::Own<const ClientHook>> inner;
kj::ForkedPromise<kj::Own<const ClientHook>> fork;
typedef kj::Own<const RpcPipeline> Waiting;
typedef kj::Own<const ClientHook> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to // Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running. // ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise; kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) { void resolve(kj::Own<const ClientHook> replacement) {
auto lock = state.lockExclusive(); // Careful to make sure the old client is not destroyed until we release the lock.
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?"); kj::Own<const ClientHook> old;
lock->init<Resolved>(response->getResults().getPipelinedCap(ops)); auto lock = inner.lockExclusive();
} old = kj::mv(*lock);
*lock = replacement->addRef();
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
} }
}; };
...@@ -813,6 +775,7 @@ private: ...@@ -813,6 +775,7 @@ private:
exp.refcount = 1; exp.refcount = 1;
exp.clientHook = kj::mv(cap); exp.clientHook = kj::mv(cap);
descriptor.setSenderHosted(exportId); descriptor.setSenderHosted(exportId);
KJ_DBG(this, exportId);
return exportId; return exportId;
} }
} }
...@@ -876,12 +839,15 @@ private: ...@@ -876,12 +839,15 @@ private:
for (ExportId importId: retainedCaps) { for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID. // Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client != nullptr && import->client->tryAddRemoteRef() != nullptr) { KJ_IF_MAYBE(i, import->client) {
if (i->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it. // Import indeed still exists! We are responsible for retaining it.
// TODO(now): Do we need to hold on to the ref that tryAddRemoteRef() returned?
*actualRetained++ = importId; *actualRetained++ = importId;
} }
} }
} }
}
uint count = actualRetained - retainedCaps.begin(); uint count = actualRetained - retainedCaps.begin();
...@@ -907,10 +873,10 @@ private: ...@@ -907,10 +873,10 @@ private:
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
auto& import = lock->imports[importId]; auto& import = lock->imports[importId];
if (import.client != nullptr) { KJ_IF_MAYBE(i, import.client) {
// The import is already on the table, but it could be being deleted in another // The import is already on the table, but it could be being deleted in another
// thread. // thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*import.client)) { KJ_IF_MAYBE(ref, kj::tryAddRef(*i)) {
// We successfully grabbed a reference to the import without it being deleted in // We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take // another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and // responsibility for retaining it. We can just return the existing object and
...@@ -920,15 +886,23 @@ private: ...@@ -920,15 +886,23 @@ private:
} }
// No import for this ID exists currently, so create one. // No import for this ID exists currently, so create one.
kj::Own<ImportClient> result; kj::Own<ImportClient> importClient =
kj::refcounted<ImportClient>(connectionState, importId);
import.client = *importClient;
kj::Own<ClientHook> result;
if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) { if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) {
// TODO(now): Check for pending `Resolve` messages replacing this import ID, and if // TODO(now): Check for pending `Resolve` messages replacing this import ID, and if
// one exists, use that client instead. // one exists, use that client instead.
result = kj::refcounted<PromiseImportClient>(connectionState, importId);
auto paf = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
import.promiseFulfiller = kj::mv(paf.fulfiller);
paf.promise.attach(kj::addRef(*importClient));
result = kj::refcounted<PromiseClient>(
connectionState, kj::mv(importClient), kj::mv(paf.promise));
} else { } else {
result = kj::refcounted<SettledImportClient>(connectionState, importId); result = kj::mv(importClient);
} }
import.client = result;
// Note that we need to retain this import later if it still exists. // Note that we need to retain this import later if it still exists.
retainedCaps.lockExclusive()->add(importId); retainedCaps.lockExclusive()->add(importId);
...@@ -993,6 +967,7 @@ private: ...@@ -993,6 +967,7 @@ private:
if (lock->networkException == nullptr) { if (lock->networkException == nullptr) {
for (auto exportId: exports) { for (auto exportId: exports) {
KJ_DBG(&connectionState, exportId);
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId)); auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
if (--exp.refcount == 0) { if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook)); clientsToRelease.add(kj::mv(exp.clientHook));
...@@ -1021,6 +996,8 @@ private: ...@@ -1021,6 +996,8 @@ private:
auto maybeExportId = connectionState.writeDescriptor( auto maybeExportId = connectionState.writeDescriptor(
entry.second.cap->addRef(), entry.second.builder, tables); entry.second.cap->addRef(), entry.second.builder, tables);
KJ_IF_MAYBE(exportId, maybeExportId) { KJ_IF_MAYBE(exportId, maybeExportId) {
KJ_ASSERT(tables.exports.find(*exportId) != nullptr);
KJ_DBG(&connectionState, *exportId);
exports.add(*exportId); exports.add(*exportId);
} }
} }
...@@ -1245,15 +1222,14 @@ private: ...@@ -1245,15 +1222,14 @@ private:
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
redirectLater(kj::mv(redirectLaterParam)), redirectLater(kj::mv(redirectLaterParam)),
resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(), resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> { [this](kj::Own<const RpcResponse>&& response) {
resolve(kj::mv(response)); resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution. }, [this](kj::Exception&& exception) {
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception)); resolve(kj::mv(exception));
return kj::READY_NOW;
})) { })) {
// Construct a new RpcPipeline. // Construct a new RpcPipeline.
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
state.getWithoutLock().init<Waiting>(questionId); state.getWithoutLock().init<Waiting>(questionId);
} }
...@@ -1318,8 +1294,18 @@ private: ...@@ -1318,8 +1294,18 @@ private:
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override { kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>( // Wrap a PipelineClient in a PromiseClient.
*connectionState, kj::addRef(*this), kj::mv(ops)); auto pipelineClient = kj::refcounted<PipelineClient>(
*connectionState, kj::addRef(*this), kj::heapArray(ops.asPtr()));
auto resolutionPromise = connectionState->eventLoop.there(redirectLater.addBranch(),
kj::mvCapture(ops,
[](kj::Array<PipelineOp> ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
return kj::refcounted<PromiseClient>(
*connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise));
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops); return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else { } else {
...@@ -1633,11 +1619,9 @@ private: ...@@ -1633,11 +1619,9 @@ private:
// actual deletion asynchronously. But we have to remove it from the table *now*, while // actual deletion asynchronously. But we have to remove it from the table *now*, while
// we still hold the lock, because once we send the return message the answer ID is free // we still hold the lock, because once we send the return message the answer ID is free
// for reuse. // for reuse.
connectionState->tasks.add(connectionState->eventLoop.evalLater( auto promise = connectionState->eventLoop.evalLater([]() {});
kj::mvCapture(lock->answers[questionId], promise.attach(kj::mv(lock->answers[questionId]));
[](Answer<RpcCallContext>&& answer) { connectionState->tasks.add(kj::mv(promise));
// Just let the answer be deleted.
})));
// Erase from the table. // Erase from the table.
lock->answers.erase(questionId); lock->answers.erase(questionId);
...@@ -1830,17 +1814,14 @@ private: ...@@ -1830,17 +1814,14 @@ private:
// refcount, but both will be destroyed at the same time anyway. // refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context; RpcCallContext* contextPtr = context;
// TODO(cleanup): We have the continuations return Promise<void> rather than void because
// this tricks the promise framework into making sure the continuations actually run
// without anyone waiting on them. We should find a cleaner way to do this.
answer.asyncOp = promiseAndPipeline.promise.then( answer.asyncOp = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) -> kj::Promise<void> { [contextPtr]() {
context->sendReturn(); contextPtr->sendReturn();
return kj::READY_NOW; }, [contextPtr](kj::Exception&& exception) {
}), [contextPtr](kj::Exception&& exception) -> kj::Promise<void> {
contextPtr->sendErrorReturn(kj::mv(exception)); contextPtr->sendErrorReturn(kj::mv(exception));
return kj::READY_NOW;
}); });
answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop);
} }
} }
...@@ -2033,7 +2014,11 @@ public: ...@@ -2033,7 +2014,11 @@ public:
auto& connectionMap = connections.getWithoutLock(); auto& connectionMap = connections.getWithoutLock();
if (!connectionMap.empty()) { if (!connectionMap.empty()) {
kj::Vector<kj::Own<RpcConnectionState>> deleteMe(connectionMap.size()); kj::Vector<kj::Own<RpcConnectionState>> deleteMe(connectionMap.size());
kj::Exception shutdownException(
kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("RpcSystem was destroyed."));
for (auto& entry: connectionMap) { for (auto& entry: connectionMap) {
entry.second->taskFailed(kj::cp(shutdownException));
deleteMe.add(kj::mv(entry.second)); deleteMe.add(kj::mv(entry.second));
} }
} }
...@@ -2053,9 +2038,7 @@ public: ...@@ -2053,9 +2038,7 @@ public:
} }
void taskFailed(kj::Exception&& exception) override { void taskFailed(kj::Exception&& exception) override {
// TODO(now): What do we do? KJ_LOG(ERROR, exception);
kj::throwRecoverableException(kj::mv(exception));
} }
private: private:
......
...@@ -428,7 +428,7 @@ struct Resolve { ...@@ -428,7 +428,7 @@ struct Resolve {
# When a promise ID is first sent over the wire (e.g. in a `CapDescriptor`), the sender (exporter) # When a promise ID is first sent over the wire (e.g. in a `CapDescriptor`), the sender (exporter)
# guarantees that it will follow up at some point with exactly one `Resolve` message. If the # guarantees that it will follow up at some point with exactly one `Resolve` message. If the
# same `promiseId` is sent again before `Resolve`, still only one `Resolve` is sent. If the # same `promiseId` is sent again before `Resolve`, still only one `Resolve` is sent. If the
# same ID is reused again later _after_ a `Resolve`, it can only be because the export's # same ID is sent again later _after_ a `Resolve`, it can only be because the export's
# reference count hit zero in the meantime and the ID was re-assigned to a new export, therefore # reference count hit zero in the meantime and the ID was re-assigned to a new export, therefore
# this later promise does _not_ correspond to the earlier `Resolve`. # this later promise does _not_ correspond to the earlier `Resolve`.
# #
...@@ -816,10 +816,10 @@ struct CapDescriptor { ...@@ -816,10 +816,10 @@ struct CapDescriptor {
senderPromise @1 :ExportId; senderPromise @1 :ExportId;
# A promise which the sender will resolve later. The sender will send exactly one Resolve # A promise which the sender will resolve later. The sender will send exactly one Resolve
# message at a future point in time to replace this promise. # message at a future point in time to replace this promise. Note that even if the same
# # `senderPromise` is received multiple times, only one `Resolve` is sent to cover all of
# TODO(soon): Can we merge this with senderHosted? Change `Resolve` to be allowed on any # them. The `Resolve` is delivered even if `senderPromise` is not retained, or is retained
# export (i.e. it can be delivered zero or one times). Maybe rename it to `Replace`. # but then released before the `Resolve` is sent.
receiverHosted @2 :ExportId; receiverHosted @2 :ExportId;
# A capability (or promise) previously exported by the receiver. # A capability (or promise) previously exported by the receiver.
......
...@@ -259,6 +259,7 @@ inline Array<T> heapArray(size_t size) { ...@@ -259,6 +259,7 @@ inline Array<T> heapArray(size_t size) {
} }
template <typename T> Array<T> heapArray(const T* content, size_t size); template <typename T> Array<T> heapArray(const T* content, size_t size);
template <typename T> Array<T> heapArray(ArrayPtr<T> content);
template <typename T> Array<T> heapArray(ArrayPtr<const T> content); template <typename T> Array<T> heapArray(ArrayPtr<const T> content);
template <typename T, typename Iterator> Array<T> heapArray(Iterator begin, Iterator end); template <typename T, typename Iterator> Array<T> heapArray(Iterator begin, Iterator end);
template <typename T> Array<T> heapArray(std::initializer_list<T> init); template <typename T> Array<T> heapArray(std::initializer_list<T> init);
...@@ -654,6 +655,13 @@ Array<T> heapArray(const T* content, size_t size) { ...@@ -654,6 +655,13 @@ Array<T> heapArray(const T* content, size_t size) {
return builder.finish(); return builder.finish();
} }
template <typename T>
Array<T> heapArray(ArrayPtr<T> content) {
ArrayBuilder<T> builder = heapArrayBuilder<T>(content.size());
builder.addAll(content);
return builder.finish();
}
template <typename T> template <typename T>
Array<T> heapArray(ArrayPtr<const T> content) { Array<T> heapArray(ArrayPtr<const T> content) {
ArrayBuilder<T> builder = heapArrayBuilder<T>(content.size()); ArrayBuilder<T> builder = heapArrayBuilder<T>(content.size());
......
...@@ -536,5 +536,59 @@ TEST(Async, EventLoopGuarded) { ...@@ -536,5 +536,59 @@ TEST(Async, EventLoopGuarded) {
} }
} }
class DestructorDetector {
public:
DestructorDetector(bool& setTrue): setTrue(setTrue) {}
~DestructorDetector() { setTrue = true; }
private:
bool& setTrue;
};
TEST(Async, Attach) {
bool destroyed = false;
SimpleEventLoop loop;
Promise<int> promise = loop.evalLater([&]() {
EXPECT_FALSE(destroyed);
return 123;
});
promise.attach(kj::heap<DestructorDetector>(destroyed));
promise = loop.there(kj::mv(promise), [&](int i) {
EXPECT_TRUE(destroyed);
return i + 321;
});
EXPECT_FALSE(destroyed);
EXPECT_EQ(444, loop.wait(kj::mv(promise)));
EXPECT_TRUE(destroyed);
}
TEST(Async, EagerlyEvaluate) {
bool called = false;
SimpleEventLoop loop;
Promise<void> promise = nullptr;
loop.wait(loop.evalLater([&]() {
promise = Promise<void>(READY_NOW).then([&]() {
called = true;
});
}));
loop.wait(loop.evalLater([]() {}));
EXPECT_FALSE(called);
promise.eagerlyEvaluate(loop);
loop.wait(loop.evalLater([]() {}));
EXPECT_TRUE(called);
}
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -350,6 +350,27 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept { ...@@ -350,6 +350,27 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own<PromiseNode>&& dependency)
: dependency(kj::mv(dependency)) {}
bool AttachmentPromiseNodeBase::onReady(EventLoop::Event& event) noexcept {
return dependency->onReady(event);
}
void AttachmentPromiseNodeBase::get(ExceptionOrValue& output) noexcept {
dependency->get(output);
}
Maybe<const EventLoop&> AttachmentPromiseNodeBase::getSafeEventLoop() noexcept {
return dependency->getSafeEventLoop();
}
void AttachmentPromiseNodeBase::dropDependency() {
dependency = nullptr;
}
// -------------------------------------------------------------------
TransformPromiseNodeBase::TransformPromiseNodeBase( TransformPromiseNodeBase::TransformPromiseNodeBase(
Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency) Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency)
: loop(loop), dependency(kj::mv(dependency)) {} : loop(loop), dependency(kj::mv(dependency)) {}
...@@ -375,6 +396,15 @@ void TransformPromiseNodeBase::dropDependency() { ...@@ -375,6 +396,15 @@ void TransformPromiseNodeBase::dropDependency() {
dependency = nullptr; dependency = nullptr;
} }
void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) {
dependency->get(output);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
dependency = nullptr;
})) {
output.addException(kj::mv(*exception));
}
}
// ------------------------------------------------------------------- // -------------------------------------------------------------------
ForkBranchBase::ForkBranchBase(Own<const ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) { ForkBranchBase::ForkBranchBase(Own<const ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "mutex.h" #include "mutex.h"
#include "refcount.h" #include "refcount.h"
#include "work-queue.h" #include "work-queue.h"
#include "tuple.h"
namespace kj { namespace kj {
...@@ -631,6 +632,18 @@ public: ...@@ -631,6 +632,18 @@ public:
// `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference // `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference
// to the same (or an equivalent) object (probably implemented via reference counting). // to the same (or an equivalent) object (probably implemented via reference counting).
template <typename... Attachments>
void attach(Attachments&&... attachments);
// "Attaches" one or more movable objects (often, Own<T>s) to the promise, such that they will
// be destroyed when the promise resolves. This is useful when a promise's callback contains
// pointers into some object and you want to make sure the object still exists when the callback
// runs -- after calling then(), use attach() to add necessary objects to the result.
void eagerlyEvaluate(const EventLoop& eventLoop = EventLoop::current());
// Force eager evaluation of this promise. Use this if you are going to hold on to the promise
// for awhile without consuming the result, but you want to make sure that the system actually
// processes it.
private: private:
Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {} Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
// Second parameter prevent ambiguity with immediate-value constructor. // Second parameter prevent ambiguity with immediate-value constructor.
...@@ -996,6 +1009,45 @@ private: ...@@ -996,6 +1009,45 @@ private:
// ------------------------------------------------------------------- // -------------------------------------------------------------------
class AttachmentPromiseNodeBase: public PromiseNode {
public:
AttachmentPromiseNodeBase(Own<PromiseNode>&& dependency);
bool onReady(EventLoop::Event& event) noexcept override;
void get(ExceptionOrValue& output) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
private:
Own<PromiseNode> dependency;
void dropDependency();
template <typename>
friend class AttachmentPromiseNode;
};
template <typename Attachment>
class AttachmentPromiseNode final: public AttachmentPromiseNodeBase {
// A PromiseNode that holds on to some object (usually, an Own<T>, but could be any movable
// object) until the promise resolves.
public:
AttachmentPromiseNode(Own<PromiseNode>&& dependency, Attachment&& attachment)
: AttachmentPromiseNodeBase(kj::mv(dependency)),
attachment(kj::mv<Attachment>(attachment)) {}
~AttachmentPromiseNode() noexcept(false) {
// We need to make sure the dependency is deleted before we delete the attachment because the
// dependency may be using the attachment.
dropDependency();
}
private:
Attachment attachment;
};
// -------------------------------------------------------------------
class TransformPromiseNodeBase: public PromiseNode { class TransformPromiseNodeBase: public PromiseNode {
public: public:
TransformPromiseNodeBase(Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency); TransformPromiseNodeBase(Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency);
...@@ -1009,6 +1061,7 @@ private: ...@@ -1009,6 +1061,7 @@ private:
Own<PromiseNode> dependency; Own<PromiseNode> dependency;
void dropDependency(); void dropDependency();
void getDepResult(ExceptionOrValue& output);
virtual void getImpl(ExceptionOrValue& output) = 0; virtual void getImpl(ExceptionOrValue& output) = 0;
...@@ -1040,7 +1093,7 @@ private: ...@@ -1040,7 +1093,7 @@ private:
void getImpl(ExceptionOrValue& output) override { void getImpl(ExceptionOrValue& output) override {
ExceptionOr<DepT> depResult; ExceptionOr<DepT> depResult;
dependency->get(depResult); getDepResult(depResult);
KJ_IF_MAYBE(depException, depResult.exception) { KJ_IF_MAYBE(depException, depResult.exception) {
output.as<T>() = handle( output.as<T>() = handle(
MaybeVoidCaller<Exception, FixVoid<ReturnType<ErrorFunc, Exception>>>::apply( MaybeVoidCaller<Exception, FixVoid<ReturnType<ErrorFunc, Exception>>>::apply(
...@@ -1452,6 +1505,18 @@ Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const { ...@@ -1452,6 +1505,18 @@ Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const {
return hub->addBranch(); return hub->addBranch();
} }
template <typename T>
template <typename... Attachments>
void Promise<T>::attach(Attachments&&... attachments) {
node = kj::heap<_::AttachmentPromiseNode<Tuple<Attachments...>>>(
kj::mv(node), kj::tuple(kj::fwd<Attachments>(attachments)...));
}
template <typename T>
void Promise<T>::eagerlyEvaluate(const EventLoop& eventLoop) {
node = _::spark<_::FixVoid<T>>(kj::mv(node), eventLoop);
}
// ======================================================================================= // =======================================================================================
namespace _ { // private namespace _ { // private
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment