Commit bdd06585 authored by Kenton Varda's avatar Kenton Varda

Refactor capability code using fork. Still too much refcounting, though. Maybe…

Refactor capability code using fork.  Still too much refcounting, though.  Maybe this calls for a different design for pipelining...
parent fe5b21e8
......@@ -113,7 +113,7 @@ public:
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
kj::Own<CallContextHook>&& context) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
......
......@@ -81,6 +81,8 @@ kj::Promise<void> ClientHook::whenResolved() const {
// =======================================================================================
namespace {
class LocalResponse final: public ResponseHook, public kj::Refcounted {
public:
LocalResponse(uint sizeHint)
......@@ -89,7 +91,7 @@ public:
MallocMessageBuilder message;
};
class LocalCallContext final: public CallContextHook {
class LocalCallContext final: public CallContextHook, public kj::Refcounted {
public:
LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<const ClientHook> clientRef)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)) {}
......@@ -112,9 +114,8 @@ public:
bool isCanceled() override {
return false;
}
Response<ObjectPointer> getResponseForPipeline() override {
auto reader = getResults(1); // Needs to be a separate line since it may allocate the response.
return Response<ObjectPointer>(reader, kj::addRef(*response));
kj::Own<CallContextHook> addRef() override {
return kj::addRef(*this);
}
kj::Own<MallocMessageBuilder> request;
......@@ -124,10 +125,11 @@ public:
class LocalRequest final: public RequestHook {
public:
inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
inline LocalRequest(kj::EventLoop& loop, uint64_t interfaceId, uint16_t methodId,
uint firstSegmentWordSize, kj::Own<const ClientHook> client)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
loop(loop),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<TypelessResults> send() override {
......@@ -135,10 +137,10 @@ public:
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(client));
auto promiseAndPipeline = client->call(interfaceId, methodId, *context);
auto context = kj::refcounted<LocalCallContext>(kj::mv(message), kj::mv(client));
auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context));
auto promise = promiseAndPipeline.promise.then(
auto promise = loop.there(kj::mv(promiseAndPipeline.promise),
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
......@@ -151,67 +153,12 @@ public:
kj::Own<MallocMessageBuilder> message;
private:
kj::EventLoop& loop;
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> client;
};
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override;
private:
kj::Exception exception;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(exception));
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
}
// =======================================================================================
// Call queues
//
......@@ -225,27 +172,7 @@ class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
public:
QueuedPipeline(kj::EventLoop& loop, kj::Promise<kj::Own<const PipelineHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const PipelineHook>&& resolution) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
for (auto& waiter: oldState) {
waiter.fulfiller->fulfill(resolution->getPipelinedCap(kj::mv(waiter.ops)));
}
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
for (auto& waiter: oldState) {
waiter.fulfiller->reject(kj::cp(exception));
}
lock->init<kj::Exception>(kj::mv(exception));
})) {
state.getWithoutLock().init<Waiting>();
}
promise(loop.fork(kj::mv(promise))) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
......@@ -262,17 +189,8 @@ public:
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override;
private:
struct Waiter {
kj::Array<PipelineOp> ops;
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller;
};
typedef kj::Vector<Waiter> Waiting;
typedef kj::Own<const PipelineHook> Resolved;
kj::EventLoop& loop;
kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
kj::ForkedPromise<kj::Own<const PipelineHook>> promise;
};
class QueuedClient final: public ClientHook, public kj::Refcounted {
......@@ -282,107 +200,73 @@ class QueuedClient final: public ClientHook, public kj::Refcounted {
public:
QueuedClient(kj::EventLoop& loop, kj::Promise<kj::Own<const ClientHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const ClientHook>&& resolution) {
// The promised capability has resolved. Forward all queued calls to it.
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
// First we want to initiate all the queued calls, and notify the QueuedPipelines to
// transfer their queues to the new call's own pipeline. It's important that this all
// happen before the application receives any notification that the promise resolved,
// so that any new calls it makes in response to the resolution don't end up being
// delivered before the previously-queued calls.
auto realCallPromises = kj::heapArrayBuilder<kj::Promise<void>>(oldState.pending.size());
for (auto& pendingCall: oldState.pending) {
auto realCall = resolution->call(
pendingCall.interfaceId, pendingCall.methodId, *pendingCall.context);
pendingCall.pipelineFulfiller->fulfill(kj::mv(realCall.pipeline));
realCallPromises.add(kj::mv(realCall.promise));
}
// Fire the "whenMoreResolved" callbacks.
for (auto& notify: oldState.notifyOnResolution) {
notify->fulfill(resolution->addRef());
}
// For each queued call, chain the pipelined promise to the real promise. It's important
// that this happens after the "whenMoreResolved" callbacks because applications may get
// confused if a pipelined call completes before the promise on which it was made
// resolves.
for (uint i: kj::indices(realCallPromises)) {
oldState.pending[i].fulfiller->fulfill(kj::mv(realCallPromises[i]));
}
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
// Reject outer promises before dependent promises.
for (auto& notify: oldState.notifyOnResolution) {
notify->reject(kj::cp(exception));
}
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::cp(exception));
call.pipelineFulfiller->reject(kj::cp(exception));
}
lock->init<kj::Exception>(kj::mv(exception));
})) {
state.getWithoutLock().init<Waiting>();
}
promise(loop.fork(kj::mv(promise))) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockExclusive();
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
loop, interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
auto lock = state.lockExclusive();
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->call(interfaceId, methodId, context);
} else if (lock->is<kj::Exception>()) {
return VoidPromiseAndPipeline { kj::cp(lock->get<kj::Exception>()),
kj::heap<BrokenPipeline>(lock->get<kj::Exception>()) };
} else {
auto pair = kj::newPromiseAndFulfiller<kj::Promise<void>>(loop);
auto pipelinePromise = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>(loop);
auto pipeline = kj::heap<QueuedPipeline>(loop, kj::mv(pipelinePromise.promise));
kj::Own<CallContextHook>&& context) const override {
// This is a bit complicated. We need to initiate this call later on. When we initiate the
// call, we'll get a void promise for its completion and a pipeline object. Right now, we have
// to produce a similar void promise and pipeline that will eventually be chained to those.
// The problem is, these are two independent objects, but they both depend on the result of
// one future call.
//
// So, we need to set up a continuation that will initiate the call later, then we need to
// fork the promise for that continuation in order to send the completion promise and the
// pipeline to their respective places.
//
// TODO(perf): Too much reference counting? Can we do better? Maybe a way to fork
// Promise<Tuple<T, U>> into Tuple<Promise<T>, Promise<U>>?
struct CallResultHolder: public kj::Refcounted {
// Essentially acts as a refcounted \VoidPromiseAndPipeline, so that we can create a promise
// for it and fork that promise.
mutable VoidPromiseAndPipeline content;
// One branch of the fork will use content.promise, the other branch will use
// content.pipeline. Neither branch will touch the other's piece, but each needs to clobber
// its own piece, so we declare this mutable.
inline CallResultHolder(VoidPromiseAndPipeline&& content): content(kj::mv(content)) {}
kj::Own<const CallResultHolder> addRef() const { return kj::addRef(*this); }
};
lock->get<Waiting>().pending.add(PendingCall {
interfaceId, methodId, &context, kj::mv(pair.fulfiller),
kj::mv(pipelinePromise.fulfiller) });
// Create a promise for the call initiation.
kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise = loop.there(
getPromiseForCallForwarding().addBranch(), kj::mvCapture(context,
[=](kj::Own<CallContextHook>&& context, kj::Own<const ClientHook>&& client){
return kj::refcounted<CallResultHolder>(
client->call(interfaceId, methodId, kj::mv(context)));
})).fork();
// Create a promise that extracts the pipeline from the call initiation, and construct our
// QueuedPipeline to chain to it.
auto pipelinePromise = loop.there(callResultPromise.addBranch(),
[](kj::Own<const CallResultHolder>&& callResult){
return kj::mv(callResult->content.pipeline);
});
auto pipeline = kj::refcounted<QueuedPipeline>(loop, kj::mv(pipelinePromise));
// TODO(now): returned promise must hold a reference to this.
return VoidPromiseAndPipeline { kj::mv(pair.promise), kj::mv(pipeline) };
}
// Create a promise that simply chains to the void promise produced by the call initiation.
auto completionPromise = loop.there(callResultPromise.addBranch(),
[](kj::Own<const CallResultHolder>&& callResult){
return kj::mv(callResult->content.promise);
});
// OK, now we can actually return our thing.
return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockExclusive();
if (lock->is<Resolved>()) {
// Already resolved.
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else if (lock->is<kj::Exception>()) {
// Already broken.
return kj::Promise<kj::Own<const ClientHook>>(kj::Own<const ClientHook>(
kj::heap<BrokenClient>(lock->get<kj::Exception>())));
} else {
// Waiting.
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->get<Waiting>().notifyOnResolution.add(kj::mv(pair.fulfiller));
// TODO(now): returned promise must hold a reference to this.
return kj::mv(pair.promise);
}
return getPromiseForClientResolution().addBranch();
}
kj::Own<const ClientHook> addRef() const override {
......@@ -394,55 +278,71 @@ public:
}
private:
struct PendingCall {
uint64_t interfaceId;
uint16_t methodId;
CallContextHook* context;
kj::Own<kj::PromiseFulfiller<kj::Promise<void>>> fulfiller;
kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>> pipelineFulfiller;
};
struct Waiting {
kj::Vector<PendingCall> pending;
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::EventLoop& loop;
typedef kj::Own<const ClientHook> Resolved;
typedef kj::ForkedPromise<kj::Own<const ClientHook>> ClientHookPromiseFork;
ClientHookPromiseFork promise;
// Promise that resolves when we have a new ClientHook to forward to.
//
// This fork shall only have two branches: `promiseForCallForwarding` and
// `promiseForClientResolution`, in that order.
kj::Lazy<ClientHookPromiseFork> promiseForCallForwarding;
// When this promise resolves, each queued call will be forwarded to the real client. This needs
// to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure
// previously-queued calls are delivered before any new calls made in response to the resolution.
kj::Lazy<ClientHookPromiseFork> promiseForClientResolution;
// whenMoreResolved() returns forks of this promise. These must resolve *after* queued calls
// have been initiated (so that any calls made in the whenMoreResolved() handler are correctly
// delivered after calls made earlier), but *before* any queued calls return (because it might
// confuse the application if a queued call returns before the capability on which it was made
// resolves). Luckily, we know that queued calls will involve, at the very least, an
// eventLoop.evalLater.
const ClientHookPromiseFork& getPromiseForCallForwarding() const {
return promiseForCallForwarding.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
return space.construct(promise.addBranch().fork());
});
}
kj::EventLoop& loop;
kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
const kj::ForkedPromise<kj::Own<const ClientHook>>& getPromiseForClientResolution() const {
return promiseForClientResolution.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
getPromiseForCallForwarding(); // must be initialized first.
return space.construct(promise.addBranch().fork());
});
}
};
kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
auto lock = state.lockExclusive();
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getPipelinedCap(ops);
} else if (lock->is<kj::Exception>()) {
return kj::heap<BrokenClient>(lock->get<kj::Exception>());
} else {
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->get<Waiting>().add(Waiter { kj::mv(ops), kj::mv(pair.fulfiller) });
return kj::heap<QueuedClient>(loop, kj::mv(pair.promise));
}
auto clientPromise = loop.there(promise.addBranch(), kj::mvCapture(ops,
[](kj::Array<PipelineOp>&& ops, kj::Own<const PipelineHook> pipeline) {
return pipeline->getPipelinedCap(kj::mv(ops));
}));
return kj::refcounted<QueuedClient>(loop, kj::mv(clientPromise));
}
// =======================================================================================
class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public:
inline LocalPipeline(Response<ObjectPointer> response): response(kj::mv(response)) {}
inline LocalPipeline(kj::Own<CallContextHook>&& context)
: context(kj::mv(context)),
results(context->getResults(1)) {}
kj::Own<const PipelineHook> addRef() const {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const {
return response.getPipelinedCap(ops);
return results.getPipelinedCap(ops);
}
private:
Response<ObjectPointer> response;
kj::Own<CallContextHook> context;
ObjectPointer::Reader results;
};
class LocalClient final: public ClientHook, public kj::Refcounted {
......@@ -453,30 +353,51 @@ public:
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
eventLoop, interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
kj::Own<CallContextHook>&& context) const override {
// We can const-cast the server because we're synchronizing on the event loop.
auto server = const_cast<Capability::Server*>(this->server.get());
auto pipelineFulfiller = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>();
auto promise = eventLoop.evalLater(kj::mvCapture(pipelineFulfiller.fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) mutable {
auto contextPtr = context.get();
// We don't want to actually dispatch the call synchronously, because:
// 1) The server may prefer a different EventLoop.
// 2) If the server is in the same EventLoop, calling it synchronously could be dangerous due
// to risk of deadlocks if it happens to take a mutex that the client already holds. One
// of the main goals of message-passing architectures is to avoid this!
//
// So, we do an evalLater() here.
//
// Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't
// complete before 'whenMoreResolved()' promises resolve.
auto promise = eventLoop.evalLater(
[=]() mutable {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(context))
.then(kj::mvCapture(fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) {
fulfiller->fulfill(kj::heap<LocalPipeline>(context.getResponseForPipeline()));
CallContext<ObjectPointer, ObjectPointer>(*contextPtr));
});
// We have to fork this promise for the pipeline to receive a copy of the answer.
auto forked = eventLoop.fork(kj::mv(promise));
auto pipelinePromise = eventLoop.there(forked.addBranch(), kj::mvCapture(context->addRef(),
[=](kj::Own<CallContextHook>&& context) -> kj::Own<const PipelineHook> {
context->releaseParams();
return kj::refcounted<LocalPipeline>(kj::mv(context));
}));
auto completionPromise = eventLoop.there(forked.addBranch(), kj::mvCapture(context->addRef(),
[=](kj::Own<CallContextHook>&& context) {
// Nothing to do here. We just wanted to make sure to hold on to a reference to the
// context even if the pipeline was discarded.
}));
return VoidPromiseAndPipeline { kj::mv(promise),
kj::heap<QueuedPipeline>(eventLoop, kj::mv(pipelineFulfiller.promise)) };
return VoidPromiseAndPipeline { kj::mv(completionPromise),
kj::refcounted<QueuedPipeline>(eventLoop, kj::mv(pipelinePromise)) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
......
......@@ -342,13 +342,14 @@ public:
};
virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const = 0;
kj::Own<CallContextHook>&& context) const = 0;
// Call the object, but the caller controls allocation of the request/response objects. If the
// callee insists on allocating this objects itself, it must make a copy. This version is used
// when calls come in over the network via an RPC system. During the call, the context object
// may be used from any thread so long as it is only used from one thread at a time. Once the
// returned promise resolves or has been canceled, the context can no longer be used. The caller
// must not allow the ClientHook to be destroyed until the call completes or is canceled.
// may be used from any thread so long as it is only used from one thread at a time. Note that
// even if the returned `Promise<void>` is discarded, the call may continue executing if any
// pipelined calls are waiting for it; the call is only truly done when the CallContextHook is
// destroyed.
//
// The call must not begin synchronously, as the caller may hold arbitrary mutexes.
......@@ -380,10 +381,7 @@ public:
virtual void allowAsyncCancellation(bool allow) = 0;
virtual bool isCanceled() = 0;
virtual Response<ObjectPointer> getResponseForPipeline() = 0;
// Get a copy or reference to the response which will be used to execute pipelined calls. This
// will be called no more than once, just after the server implementation successfully returns
// from the call.
virtual kj::Own<CallContextHook> addRef() = 0;
};
// =======================================================================================
......
......@@ -324,11 +324,11 @@ TEST(Async, Fork) {
auto fork = promise.fork();
auto branch1 = fork->addBranch().then([](int i) {
auto branch1 = fork.addBranch().then([](int i) {
EXPECT_EQ(123, i);
return 456;
});
auto branch2 = fork->addBranch().then([](int i) {
auto branch2 = fork.addBranch().then([](int i) {
EXPECT_EQ(123, i);
return 789;
});
......@@ -360,11 +360,11 @@ TEST(Async, ForkRef) {
auto fork = promise.fork();
auto branch1 = fork->addBranch().then([](Own<const RefcountedInt>&& i) {
auto branch1 = fork.addBranch().then([](Own<const RefcountedInt>&& i) {
EXPECT_EQ(123, i->i);
return 456;
});
auto branch2 = fork->addBranch().then([](Own<const RefcountedInt>&& i) {
auto branch2 = fork.addBranch().then([](Own<const RefcountedInt>&& i) {
EXPECT_EQ(123, i->i);
return 789;
});
......
......@@ -260,7 +260,7 @@ void SimpleEventLoop::wake() const {
// =======================================================================================
void PromiseBase::absolve() {
runCatchingExceptions([this]() { auto deleteMe = kj::mv(node); });
runCatchingExceptions([this]() { node = nullptr; });
}
namespace _ { // private
......@@ -330,9 +330,13 @@ Maybe<const EventLoop&> TransformPromiseNodeBase::getSafeEventLoop() noexcept {
return loop;
}
void TransformPromiseNodeBase::dropDependency() {
dependency = nullptr;
}
// -------------------------------------------------------------------
ForkBranchBase::ForkBranchBase(Own<ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
ForkBranchBase::ForkBranchBase(Own<const ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
auto lock = hub->branchList.lockExclusive();
if (lock->lastPtr == nullptr) {
......@@ -362,7 +366,7 @@ void ForkBranchBase::hubReady() noexcept {
void ForkBranchBase::releaseHub(ExceptionOrValue& output) {
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(hub);
hub = nullptr;
})) {
output.addException(kj::mv(*exception));
}
......@@ -398,7 +402,7 @@ void ForkHubBase::fire() {
// Dependency is ready. Fetch its result and then delete the node.
inner->get(resultRef);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(inner);
inner = nullptr;
})) {
resultRef.addException(kj::mv(*exception));
}
......@@ -525,7 +529,7 @@ void CrossThreadPromiseNodeBase::fire() {
} else {
dependency->get(resultRef);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(dependency);
dependency = nullptr;
})) {
resultRef.addException(kj::mv(*exception));
}
......
......@@ -36,6 +36,8 @@ class SimpleEventLoop;
template <typename T>
class Promise;
template <typename T>
class ForkedPromise;
template <typename T>
class PromiseFulfiller;
template <typename T>
struct PromiseFulfillerPair;
......@@ -272,6 +274,11 @@ public:
// Like `Promise::then()`, but schedules the continuation to be executed on *this* EventLoop
// rather than the thread's current loop. See Promise::then().
template <typename T>
ForkedPromise<T> fork(Promise<T>&& promise);
// Like `Promise::fork()`, but manages the fork on *this* EventLoop rather than the thread's
// current loop. See Promise::fork().
// -----------------------------------------------------------------
// Low-level interface.
......@@ -586,14 +593,7 @@ public:
// After returning, the promise is no longer valid, and cannot be `wait()`ed on or `then()`ed
// again.
class Fork {
public:
virtual Promise<_::Forked<T>> addBranch() = 0;
// Add a new branch to the fork. The branch is equivalent to the original promise, except
// that if T is a reference or owned pointer, the target becomes const.
};
Own<Fork> fork();
ForkedPromise<T> fork();
// Forks the promise, so that multiple different clients can independently wait on the result.
// `T` must be copy-constructable for this to work. Or, in the special case where `T` is
// `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference
......@@ -616,6 +616,27 @@ private:
friend class _::ForkHub;
};
template <typename T>
class ForkedPromise {
// The result of `Promise::fork()` and `EventLoop::fork()`. Allows branches to be created.
// Like `Promise<T>`, this is a pass-by-move type.
public:
inline ForkedPromise(decltype(nullptr)): hub(nullptr) {}
Promise<_::Forked<T>> addBranch() const;
// Add a new branch to the fork. The branch is equivalent to the original promise, except
// that if T is a reference or owned pointer, the target becomes const.
private:
Own<const _::ForkHub<_::FixVoid<T>>> hub;
inline ForkedPromise(bool, Own<const _::ForkHub<_::FixVoid<T>>>&& hub): hub(kj::mv(hub)) {}
friend class Promise<T>;
friend class EventLoop;
};
constexpr _::Void READY_NOW = _::Void();
// Use this when you need a Promise<void> that is already fulfilled -- this value can be implicitly
// cast to `Promise<void>`.
......@@ -881,6 +902,8 @@ private:
const EventLoop& loop;
Own<PromiseNode> dependency;
void dropDependency();
virtual void getImpl(ExceptionOrValue& output) = 0;
template <typename, typename, typename, typename>
......@@ -898,6 +921,13 @@ public:
: TransformPromiseNodeBase(loop, kj::mv(dependency)),
func(kj::fwd<Func>(func)), errorHandler(kj::fwd<ErrorFunc>(errorHandler)) {}
~TransformPromiseNode() noexcept(false) {
// We need to make sure the dependency is deleted before we delete the continuations because it
// is a common pattern for the continuations to hold ownership of objects that might be in-use
// by the dependency.
dropDependency();
}
private:
Func func;
ErrorFunc errorHandler;
......@@ -927,7 +957,7 @@ class ForkHubBase;
class ForkBranchBase: public PromiseNode {
public:
ForkBranchBase(Own<ForkHubBase>&& hub);
ForkBranchBase(Own<const ForkHubBase>&& hub);
~ForkBranchBase();
void hubReady() noexcept;
......@@ -946,7 +976,7 @@ protected:
private:
EventLoop::Event* onReadyEvent = nullptr;
Own<ForkHubBase> hub;
Own<const ForkHubBase> hub;
ForkBranchBase* next = nullptr;
ForkBranchBase** prevPtr = nullptr;
......@@ -963,7 +993,7 @@ class ForkBranch final: public ForkBranchBase {
// a const reference.
public:
ForkBranch(Own<ForkHubBase>&& hub): ForkBranchBase(kj::mv(hub)) {}
ForkBranch(Own<const ForkHubBase>&& hub): ForkBranchBase(kj::mv(hub)) {}
void get(ExceptionOrValue& output) noexcept override {
const ExceptionOr<T>& hubResult = getHubResultRef().template as<T>();
......@@ -1006,7 +1036,7 @@ private:
};
template <typename T>
class ForkHub final: public ForkHubBase, public Promise<T>::Fork {
class ForkHub final: public ForkHubBase {
// A PromiseNode that implements the hub of a fork. The first call to Promise::fork() replaces
// the promise's outer node with a ForkHub, and subsequent calls add branches to that hub (if
// possible).
......@@ -1015,8 +1045,8 @@ public:
ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner)
: ForkHubBase(loop, kj::mv(inner), result) {}
Promise<_::Forked<T>> addBranch() override {
return Promise<_::Forked<T>>(false, kj::heap<ForkBranch<T>>(addRef(*this)));
Promise<_::Forked<_::UnfixVoid<T>>> addBranch() const {
return Promise<_::Forked<_::UnfixVoid<T>>>(false, kj::heap<ForkBranch<T>>(addRef(*this)));
}
private:
......@@ -1261,9 +1291,23 @@ T Promise<T>::wait() {
}
template <typename T>
Own<typename Promise<T>::Fork> Promise<T>::fork() {
ForkedPromise<T> Promise<T>::fork() {
auto& loop = EventLoop::current();
return refcounted<_::ForkHub<T>>(loop, _::makeSafeForLoop<_::FixVoid<T>>(kj::mv(node), loop));
return ForkedPromise<T>(false,
refcounted<_::ForkHub<_::FixVoid<T>>>(
loop, _::makeSafeForLoop<_::FixVoid<T>>(kj::mv(node), loop)));
}
template <typename T>
ForkedPromise<T> EventLoop::fork(Promise<T>&& promise) {
return ForkedPromise<T>(false,
refcounted<_::ForkHub<_::FixVoid<T>>>(*this,
_::makeSafeForLoop<_::FixVoid<T>>(kj::mv(promise.node), *this)));
}
template <typename T>
Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const {
return hub->addBranch();
}
// =======================================================================================
......
......@@ -22,13 +22,22 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "refcount.h"
#include <memory>
namespace kj {
Refcounted::~Refcounted() noexcept(false) {}
void Refcounted::disposeImpl(void* pointer) const {
if (__atomic_sub_fetch(&refcount, 1, __ATOMIC_RELAXED) == 0) {
// The load is a fast-path for the common case where this is the last reference. An acquire-load
// is just a regular load on x86. If there is more than one reference, then we need to do a full
// atomic decrement with full memory barrier, because:
// - If this is the final decrement then we need to acquire the object state in order to destroy
// it.
// - If this is not the final decrement then we need to release the object state so that another
// thread may destroy it.
if (__atomic_load_n(&refcount, __ATOMIC_ACQUIRE) == 1 ||
__atomic_sub_fetch(&refcount, 1, __ATOMIC_ACQ_REL) == 0) {
delete this;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment