Commit f4ab3d2d authored by Kenton Varda's avatar Kenton Varda

Refactor how promise imports are represented in the RPC implementation.

parent 2aa37a8a
......@@ -280,9 +280,17 @@ struct Export {
template <typename ImportClient>
struct Import {
ImportClient* client = nullptr;
// Normally I'd want this to be Maybe<ImportClient&>, but GCC's unordered_map doesn't seem to
// like DisableConstCopy types.
Import() = default;
Import(const Import&) = delete;
Import(Import&&) = default;
Import& operator=(Import&&) = default;
// If we don't explicitly write all this, we get some stupid error deep in STL.
kj::Maybe<ImportClient&> client;
// Becomes null when the import is destroyed.
kj::Maybe<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> promiseFulfiller;
// If non-null, the import is a promise.
};
// =======================================================================================
......@@ -339,21 +347,26 @@ public:
auto pipeline = kj::refcounted<RpcPipeline>(
*this, questionId, eventLoop.fork(kj::mv(promiseWithQuestionRef)));
return kj::refcounted<PromisedAnswerClient>(*this, kj::mv(pipeline), nullptr);
return pipeline->getPipelinedCap(kj::Array<const PipelineOp>(nullptr));
}
void taskFailed(kj::Exception&& exception) override {
{
kj::Exception networkException(
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
"", 0, kj::str("Disconnected: ", exception.getDescription()));
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
auto lock = tables.lockExclusive();
if (lock->networkException != nullptr) {
// Oops, already disconnected.
return;
}
kj::Exception networkException(
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription()));
// All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id,
Question<CapInjectorImpl, RpcPipeline, RpcResponse>& question) {
......@@ -379,8 +392,8 @@ public:
});
lock->imports.forEach([&](ExportId id, Import<ImportClient>& import) {
if (import.client != nullptr) {
import.client->disconnect(kj::cp(networkException));
KJ_IF_MAYBE(f, import.promiseFulfiller) {
f->get()->reject(kj::cp(networkException));
}
});
......@@ -407,6 +420,7 @@ private:
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
class ImportClient;
class PromiseClient;
class CapInjectorImpl;
class CapExtractorImpl;
class RpcPipeline;
......@@ -414,10 +428,11 @@ private:
class RpcResponse;
struct Tables {
ExportTable<ExportId, Export> exports;
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline, RpcResponse>> questions;
ImportTable<QuestionId, Answer<RpcCallContext>> answers;
ExportTable<ExportId, Export> exports;
ImportTable<ExportId, Import<ImportClient>> imports;
// The order of the tables is important for correct destruction.
std::unordered_map<const ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table.
......@@ -455,6 +470,7 @@ private:
const RpcConnectionState& connectionState;
};
// TODO(now): unused?
ExportDisposer exportDisposer;
// =====================================================================================
......@@ -546,12 +562,13 @@ private:
kj::Own<const RpcConnectionState> connectionState;
};
class ImportClient: public RpcClient {
protected:
class ImportClient final: public RpcClient {
// A ClientHook that wraps an entry in the import table.
public:
ImportClient(const RpcConnectionState& connectionState, ExportId importId)
: RpcClient(connectionState), importId(importId) {}
public:
~ImportClient() noexcept(false) {
{
// Remove self from the import table, if the table is still pointing at us. (It's possible
......@@ -560,8 +577,10 @@ private:
// the import table.)
auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client == this) {
lock->imports.erase(importId);
KJ_IF_MAYBE(i, import->client) {
if (i == this) {
lock->imports.erase(importId);
}
}
}
}
......@@ -572,13 +591,6 @@ private:
}
}
virtual bool settle(kj::Own<const ClientHook> replacement) = 0;
// Replace the PromiseImportClient with its resolution. Returns false if this is not a promise
// (i.e. it is a SettledImportClient).
virtual void disconnect(kj::Exception&& exception) = 0;
// Cause whenMoreResolved() to fail.
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should
......@@ -620,6 +632,10 @@ private:
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
private:
ExportId importId;
......@@ -627,167 +643,113 @@ private:
// Number of times we've received this import from the peer.
};
class SettledImportClient final: public ImportClient {
class PipelineClient final: public RpcClient {
// A ClientHook representing a pipelined promise. Always wrapped in PromiseClient.
public:
inline SettledImportClient(const RpcConnectionState& connectionState, ExportId importId)
: ImportClient(connectionState, importId) {}
PipelineClient(const RpcConnectionState& connectionState,
kj::Own<const RpcPipeline>&& pipeline,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), pipeline(kj::mv(pipeline)), ops(kj::mv(ops)) {}
bool settle(kj::Own<const ClientHook> replacement) override {
return false;
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
return pipeline->writeDescriptor(descriptor, tables, ops);
}
void disconnect(kj::Exception&& exception) override {
// nothing
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
// TODO(now): The pipeline may redirect to the resolution before PromiseClient has resolved.
// This could lead to a race condition if PromiseClient implements embargoes.
return pipeline->writeTarget(target, ops);
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
};
// implements ClientHook -----------------------------------------
class PromiseImportClient final: public ImportClient {
public:
PromiseImportClient(const RpcConnectionState& connectionState, ExportId importId)
: ImportClient(connectionState, importId),
fork(nullptr) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>(connectionState.eventLoop);
fulfiller = kj::mv(paf.fulfiller);
fork = connectionState.eventLoop.fork(kj::mv(paf.promise));
}
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
bool settle(kj::Own<const ClientHook> replacement) override {
fulfiller->fulfill(kj::mv(replacement));
return true;
}
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
void disconnect(kj::Exception&& exception) override {
fulfiller->reject(kj::mv(exception));
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
}
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves.
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
// We need the returned promise to hold a reference back to this object, so that it doesn't
// disappear while the promise is still outstanding.
return fork.addBranch().thenInAnyThread(kj::mvCapture(kj::addRef(*this),
[](kj::Own<const PromiseImportClient>&&, kj::Own<const ClientHook>&& replacement) {
return kj::mv(replacement);
}));
return nullptr;
}
private:
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller;
kj::ForkedPromise<kj::Own<const ClientHook>> fork;
kj::Own<const RpcPipeline> pipeline;
kj::Array<PipelineOp> ops;
};
class PromisedAnswerClient final: public RpcClient {
class PromiseClient final: public RpcClient {
// A ClientHook that initially wraps one client (in practice, an ImportClient or a
// PipelineClient) and then, later on, redirects to some other client.
public:
PromisedAnswerClient(const RpcConnectionState& connectionState,
kj::Own<const RpcPipeline>&& pipeline,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), ops(kj::mv(ops)),
resolveSelfPromise(connectionState.eventLoop.there(pipeline->onResponse(),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution.
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
resolve(kj::mv(exception));
return kj::READY_NOW; // hack to force eager resolution.
PromiseClient(const RpcConnectionState& connectionState,
kj::Own<const ClientHook> initial,
kj::Promise<kj::Own<const ClientHook>> eventual)
: RpcClient(connectionState),
inner(kj::mv(initial)),
fork(connectionState.eventLoop.fork(kj::mv(eventual))),
resolveSelfPromise(connectionState.eventLoop.there(fork.addBranch(),
[this](kj::Own<const ClientHook>&& resolution) {
resolve(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
resolve(newBrokenCap(kj::mv(exception)));
})) {
state.getWithoutLock().init<Waiting>(kj::mv(pipeline));
// Create a client that starts out forwarding all calls to `initial` but, once `eventual`
// resolves, will forward there instead. In addition, `whenMoreResolved()` will return a fork
// of `eventual`. Note that this means the application could hold on to `eventual` even after
// the `PromiseClient` is destroyed; `eventual` must therefore make sure to hold references to
// anything that needs to stay alive in order to resolve it correctly (such as making sure the
// import ID is not released).
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
}
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
} else if (lock->is<Resolved>()) {
return connectionState->writeDescriptor(
lock->get<Resolved>()->addRef(), descriptor, tables);
} else {
return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
}
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
auto cap = inner.lockExclusive()->get()->addRef();
return connectionState->writeDescriptor(kj::mv(cap), descriptor, tables);
}
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) {
return connectionState->writeTarget(*lock->get<Resolved>(), target);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
return connectionState->writeTarget(**inner.lockExclusive(), target);
}
// implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()))->newCall(
interfaceId, methodId, firstSegmentWordSize);
}
return inner.lockExclusive()->get()->newCall(interfaceId, methodId, firstSegmentWordSize);
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->onResponse().thenInAnyThread(
kj::mvCapture(kj::heapArray(ops.asPtr()),
[](kj::Array<PipelineOp>&& ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
} else if (lock->is<Resolved>()) {
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(lock->get<Broken>()));
}
return fork.addBranch();
}
private:
kj::Array<PipelineOp> ops;
typedef kj::Own<const RpcPipeline> Waiting;
typedef kj::Own<const ClientHook> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
kj::MutexGuarded<kj::Own<const ClientHook>> inner;
kj::ForkedPromise<kj::Own<const ClientHook>> fork;
// Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<const RpcResponse>&& response) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Resolved>(response->getResults().getPipelinedCap(ops));
}
void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive();
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?");
lock->init<Broken>(kj::mv(exception));
void resolve(kj::Own<const ClientHook> replacement) {
// Careful to make sure the old client is not destroyed until we release the lock.
kj::Own<const ClientHook> old;
auto lock = inner.lockExclusive();
old = kj::mv(*lock);
*lock = replacement->addRef();
}
};
......@@ -813,6 +775,7 @@ private:
exp.refcount = 1;
exp.clientHook = kj::mv(cap);
descriptor.setSenderHosted(exportId);
KJ_DBG(this, exportId);
return exportId;
}
}
......@@ -876,9 +839,12 @@ private:
for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client != nullptr && import->client->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it.
*actualRetained++ = importId;
KJ_IF_MAYBE(i, import->client) {
if (i->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it.
// TODO(now): Do we need to hold on to the ref that tryAddRemoteRef() returned?
*actualRetained++ = importId;
}
}
}
}
......@@ -907,10 +873,10 @@ private:
auto lock = connectionState.tables.lockExclusive();
auto& import = lock->imports[importId];
if (import.client != nullptr) {
KJ_IF_MAYBE(i, import.client) {
// The import is already on the table, but it could be being deleted in another
// thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*import.client)) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*i)) {
// We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and
......@@ -920,15 +886,23 @@ private:
}
// No import for this ID exists currently, so create one.
kj::Own<ImportClient> result;
kj::Own<ImportClient> importClient =
kj::refcounted<ImportClient>(connectionState, importId);
import.client = *importClient;
kj::Own<ClientHook> result;
if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) {
// TODO(now): Check for pending `Resolve` messages replacing this import ID, and if
// one exists, use that client instead.
result = kj::refcounted<PromiseImportClient>(connectionState, importId);
auto paf = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
import.promiseFulfiller = kj::mv(paf.fulfiller);
paf.promise.attach(kj::addRef(*importClient));
result = kj::refcounted<PromiseClient>(
connectionState, kj::mv(importClient), kj::mv(paf.promise));
} else {
result = kj::refcounted<SettledImportClient>(connectionState, importId);
result = kj::mv(importClient);
}
import.client = result;
// Note that we need to retain this import later if it still exists.
retainedCaps.lockExclusive()->add(importId);
......@@ -993,6 +967,7 @@ private:
if (lock->networkException == nullptr) {
for (auto exportId: exports) {
KJ_DBG(&connectionState, exportId);
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook));
......@@ -1021,6 +996,8 @@ private:
auto maybeExportId = connectionState.writeDescriptor(
entry.second.cap->addRef(), entry.second.builder, tables);
KJ_IF_MAYBE(exportId, maybeExportId) {
KJ_ASSERT(tables.exports.find(*exportId) != nullptr);
KJ_DBG(&connectionState, *exportId);
exports.add(*exportId);
}
}
......@@ -1245,15 +1222,14 @@ private:
: connectionState(kj::addRef(connectionState)),
redirectLater(kj::mv(redirectLaterParam)),
resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
[this](kj::Own<const RpcResponse>&& response) {
resolve(kj::mv(response));
return kj::READY_NOW; // hack to force eager resolution.
}, [this](kj::Exception&& exception) -> kj::Promise<void> {
}, [this](kj::Exception&& exception) {
resolve(kj::mv(exception));
return kj::READY_NOW;
})) {
// Construct a new RpcPipeline.
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
state.getWithoutLock().init<Waiting>(questionId);
}
......@@ -1318,8 +1294,18 @@ private:
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override {
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>(
*connectionState, kj::addRef(*this), kj::mv(ops));
// Wrap a PipelineClient in a PromiseClient.
auto pipelineClient = kj::refcounted<PipelineClient>(
*connectionState, kj::addRef(*this), kj::heapArray(ops.asPtr()));
auto resolutionPromise = connectionState->eventLoop.there(redirectLater.addBranch(),
kj::mvCapture(ops,
[](kj::Array<PipelineOp> ops, kj::Own<const RpcResponse>&& response) {
return response->getResults().getPipelinedCap(ops);
}));
return kj::refcounted<PromiseClient>(
*connectionState, kj::mv(pipelineClient), kj::mv(resolutionPromise));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
......@@ -1633,11 +1619,9 @@ private:
// actual deletion asynchronously. But we have to remove it from the table *now*, while
// we still hold the lock, because once we send the return message the answer ID is free
// for reuse.
connectionState->tasks.add(connectionState->eventLoop.evalLater(
kj::mvCapture(lock->answers[questionId],
[](Answer<RpcCallContext>&& answer) {
// Just let the answer be deleted.
})));
auto promise = connectionState->eventLoop.evalLater([]() {});
promise.attach(kj::mv(lock->answers[questionId]));
connectionState->tasks.add(kj::mv(promise));
// Erase from the table.
lock->answers.erase(questionId);
......@@ -1830,17 +1814,14 @@ private:
// refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context;
// TODO(cleanup): We have the continuations return Promise<void> rather than void because
// this tricks the promise framework into making sure the continuations actually run
// without anyone waiting on them. We should find a cleaner way to do this.
answer.asyncOp = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) -> kj::Promise<void> {
context->sendReturn();
return kj::READY_NOW;
}), [contextPtr](kj::Exception&& exception) -> kj::Promise<void> {
[contextPtr]() {
contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception));
return kj::READY_NOW;
});
answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop);
}
}
......@@ -2033,7 +2014,11 @@ public:
auto& connectionMap = connections.getWithoutLock();
if (!connectionMap.empty()) {
kj::Vector<kj::Own<RpcConnectionState>> deleteMe(connectionMap.size());
kj::Exception shutdownException(
kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("RpcSystem was destroyed."));
for (auto& entry: connectionMap) {
entry.second->taskFailed(kj::cp(shutdownException));
deleteMe.add(kj::mv(entry.second));
}
}
......@@ -2053,9 +2038,7 @@ public:
}
void taskFailed(kj::Exception&& exception) override {
// TODO(now): What do we do?
kj::throwRecoverableException(kj::mv(exception));
KJ_LOG(ERROR, exception);
}
private:
......
......@@ -428,7 +428,7 @@ struct Resolve {
# When a promise ID is first sent over the wire (e.g. in a `CapDescriptor`), the sender (exporter)
# guarantees that it will follow up at some point with exactly one `Resolve` message. If the
# same `promiseId` is sent again before `Resolve`, still only one `Resolve` is sent. If the
# same ID is reused again later _after_ a `Resolve`, it can only be because the export's
# same ID is sent again later _after_ a `Resolve`, it can only be because the export's
# reference count hit zero in the meantime and the ID was re-assigned to a new export, therefore
# this later promise does _not_ correspond to the earlier `Resolve`.
#
......@@ -816,10 +816,10 @@ struct CapDescriptor {
senderPromise @1 :ExportId;
# A promise which the sender will resolve later. The sender will send exactly one Resolve
# message at a future point in time to replace this promise.
#
# TODO(soon): Can we merge this with senderHosted? Change `Resolve` to be allowed on any
# export (i.e. it can be delivered zero or one times). Maybe rename it to `Replace`.
# message at a future point in time to replace this promise. Note that even if the same
# `senderPromise` is received multiple times, only one `Resolve` is sent to cover all of
# them. The `Resolve` is delivered even if `senderPromise` is not retained, or is retained
# but then released before the `Resolve` is sent.
receiverHosted @2 :ExportId;
# A capability (or promise) previously exported by the receiver.
......
......@@ -259,6 +259,7 @@ inline Array<T> heapArray(size_t size) {
}
template <typename T> Array<T> heapArray(const T* content, size_t size);
template <typename T> Array<T> heapArray(ArrayPtr<T> content);
template <typename T> Array<T> heapArray(ArrayPtr<const T> content);
template <typename T, typename Iterator> Array<T> heapArray(Iterator begin, Iterator end);
template <typename T> Array<T> heapArray(std::initializer_list<T> init);
......@@ -654,6 +655,13 @@ Array<T> heapArray(const T* content, size_t size) {
return builder.finish();
}
template <typename T>
Array<T> heapArray(ArrayPtr<T> content) {
ArrayBuilder<T> builder = heapArrayBuilder<T>(content.size());
builder.addAll(content);
return builder.finish();
}
template <typename T>
Array<T> heapArray(ArrayPtr<const T> content) {
ArrayBuilder<T> builder = heapArrayBuilder<T>(content.size());
......
......@@ -536,5 +536,59 @@ TEST(Async, EventLoopGuarded) {
}
}
class DestructorDetector {
public:
DestructorDetector(bool& setTrue): setTrue(setTrue) {}
~DestructorDetector() { setTrue = true; }
private:
bool& setTrue;
};
TEST(Async, Attach) {
bool destroyed = false;
SimpleEventLoop loop;
Promise<int> promise = loop.evalLater([&]() {
EXPECT_FALSE(destroyed);
return 123;
});
promise.attach(kj::heap<DestructorDetector>(destroyed));
promise = loop.there(kj::mv(promise), [&](int i) {
EXPECT_TRUE(destroyed);
return i + 321;
});
EXPECT_FALSE(destroyed);
EXPECT_EQ(444, loop.wait(kj::mv(promise)));
EXPECT_TRUE(destroyed);
}
TEST(Async, EagerlyEvaluate) {
bool called = false;
SimpleEventLoop loop;
Promise<void> promise = nullptr;
loop.wait(loop.evalLater([&]() {
promise = Promise<void>(READY_NOW).then([&]() {
called = true;
});
}));
loop.wait(loop.evalLater([]() {}));
EXPECT_FALSE(called);
promise.eagerlyEvaluate(loop);
loop.wait(loop.evalLater([]() {}));
EXPECT_TRUE(called);
}
} // namespace
} // namespace kj
......@@ -350,6 +350,27 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept {
// -------------------------------------------------------------------
AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own<PromiseNode>&& dependency)
: dependency(kj::mv(dependency)) {}
bool AttachmentPromiseNodeBase::onReady(EventLoop::Event& event) noexcept {
return dependency->onReady(event);
}
void AttachmentPromiseNodeBase::get(ExceptionOrValue& output) noexcept {
dependency->get(output);
}
Maybe<const EventLoop&> AttachmentPromiseNodeBase::getSafeEventLoop() noexcept {
return dependency->getSafeEventLoop();
}
void AttachmentPromiseNodeBase::dropDependency() {
dependency = nullptr;
}
// -------------------------------------------------------------------
TransformPromiseNodeBase::TransformPromiseNodeBase(
Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency)
: loop(loop), dependency(kj::mv(dependency)) {}
......@@ -375,6 +396,15 @@ void TransformPromiseNodeBase::dropDependency() {
dependency = nullptr;
}
void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) {
dependency->get(output);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
dependency = nullptr;
})) {
output.addException(kj::mv(*exception));
}
}
// -------------------------------------------------------------------
ForkBranchBase::ForkBranchBase(Own<const ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
......
......@@ -28,6 +28,7 @@
#include "mutex.h"
#include "refcount.h"
#include "work-queue.h"
#include "tuple.h"
namespace kj {
......@@ -631,6 +632,18 @@ public:
// `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference
// to the same (or an equivalent) object (probably implemented via reference counting).
template <typename... Attachments>
void attach(Attachments&&... attachments);
// "Attaches" one or more movable objects (often, Own<T>s) to the promise, such that they will
// be destroyed when the promise resolves. This is useful when a promise's callback contains
// pointers into some object and you want to make sure the object still exists when the callback
// runs -- after calling then(), use attach() to add necessary objects to the result.
void eagerlyEvaluate(const EventLoop& eventLoop = EventLoop::current());
// Force eager evaluation of this promise. Use this if you are going to hold on to the promise
// for awhile without consuming the result, but you want to make sure that the system actually
// processes it.
private:
Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
// Second parameter prevent ambiguity with immediate-value constructor.
......@@ -996,6 +1009,45 @@ private:
// -------------------------------------------------------------------
class AttachmentPromiseNodeBase: public PromiseNode {
public:
AttachmentPromiseNodeBase(Own<PromiseNode>&& dependency);
bool onReady(EventLoop::Event& event) noexcept override;
void get(ExceptionOrValue& output) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
private:
Own<PromiseNode> dependency;
void dropDependency();
template <typename>
friend class AttachmentPromiseNode;
};
template <typename Attachment>
class AttachmentPromiseNode final: public AttachmentPromiseNodeBase {
// A PromiseNode that holds on to some object (usually, an Own<T>, but could be any movable
// object) until the promise resolves.
public:
AttachmentPromiseNode(Own<PromiseNode>&& dependency, Attachment&& attachment)
: AttachmentPromiseNodeBase(kj::mv(dependency)),
attachment(kj::mv<Attachment>(attachment)) {}
~AttachmentPromiseNode() noexcept(false) {
// We need to make sure the dependency is deleted before we delete the attachment because the
// dependency may be using the attachment.
dropDependency();
}
private:
Attachment attachment;
};
// -------------------------------------------------------------------
class TransformPromiseNodeBase: public PromiseNode {
public:
TransformPromiseNodeBase(Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency);
......@@ -1009,6 +1061,7 @@ private:
Own<PromiseNode> dependency;
void dropDependency();
void getDepResult(ExceptionOrValue& output);
virtual void getImpl(ExceptionOrValue& output) = 0;
......@@ -1040,7 +1093,7 @@ private:
void getImpl(ExceptionOrValue& output) override {
ExceptionOr<DepT> depResult;
dependency->get(depResult);
getDepResult(depResult);
KJ_IF_MAYBE(depException, depResult.exception) {
output.as<T>() = handle(
MaybeVoidCaller<Exception, FixVoid<ReturnType<ErrorFunc, Exception>>>::apply(
......@@ -1452,6 +1505,18 @@ Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const {
return hub->addBranch();
}
template <typename T>
template <typename... Attachments>
void Promise<T>::attach(Attachments&&... attachments) {
node = kj::heap<_::AttachmentPromiseNode<Tuple<Attachments...>>>(
kj::mv(node), kj::tuple(kj::fwd<Attachments>(attachments)...));
}
template <typename T>
void Promise<T>::eagerlyEvaluate(const EventLoop& eventLoop) {
node = _::spark<_::FixVoid<T>>(kj::mv(node), eventLoop);
}
// =======================================================================================
namespace _ { // private
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment