Commit 06945999 authored by Kenton Varda's avatar Kenton Varda

Implement ability to fork promises.

parent e8d256ab
......@@ -316,5 +316,29 @@ TEST(Async, Ordering) {
EXPECT_EQ(7, counter);
}
TEST(Async, Fork) {
SimpleEventLoop loop;
auto outer = loop.evalLater([&]() {
Promise<String> promise = loop.evalLater([&]() { return str("foo"); });
auto fork = promise.fork();
auto branch1 = fork->addBranch().then([](const String& s) {
EXPECT_EQ("foo", s);
return 456;
});
auto branch2 = fork->addBranch().then([](const String& s) {
EXPECT_EQ("foo", s);
return 789;
});
EXPECT_EQ(456, loop.wait(kj::mv(branch1)));
EXPECT_EQ(789, loop.wait(kj::mv(branch2)));
});
loop.wait(kj::mv(outer));
}
} // namespace
} // namespace kj
......@@ -296,6 +296,8 @@ void PromiseNode::atomicReady(EventLoop::Event*& onReadyEvent,
}
}
// -------------------------------------------------------------------
bool ImmediatePromiseNodeBase::onReady(EventLoop::Event& event) noexcept { return true; }
Maybe<const EventLoop&> ImmediatePromiseNodeBase::getSafeEventLoop() noexcept { return nullptr; }
......@@ -306,6 +308,8 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept {
output.exception = kj::mv(exception);
}
// -------------------------------------------------------------------
TransformPromiseNodeBase::TransformPromiseNodeBase(
const EventLoop& loop, Own<PromiseNode>&& dependency)
: loop(loop), dependency(kj::mv(dependency)) {}
......@@ -326,6 +330,94 @@ Maybe<const EventLoop&> TransformPromiseNodeBase::getSafeEventLoop() noexcept {
return loop;
}
// -------------------------------------------------------------------
ForkBranchBase::ForkBranchBase(Own<ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
auto lock = hub->branchList.lockExclusive();
if (lock->lastPtr == nullptr) {
onReadyEvent = _kJ_ALREADY_READY;
} else {
// Insert into hub's linked list of branches.
prevPtr = lock->lastPtr;
*prevPtr = this;
next = nullptr;
lock->lastPtr = &next;
}
}
ForkBranchBase::~ForkBranchBase() {
if (prevPtr != nullptr) {
// Remove from hub's linked list of branches.
auto lock = hub->branchList.lockExclusive();
*prevPtr = next;
(next == nullptr ? lock->lastPtr : next->prevPtr) = prevPtr;
}
}
void ForkBranchBase::hubReady() noexcept {
// TODO(soon): This should only yield if queuing cross-thread.
atomicReady(onReadyEvent, EventLoop::Event::YIELD);
}
void ForkBranchBase::releaseHub(ExceptionOrValue& output) {
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(hub);
})) {
output.addException(kj::mv(*exception));
}
}
bool ForkBranchBase::onReady(EventLoop::Event& event) noexcept {
return atomicOnReady(onReadyEvent, event);
}
Maybe<const EventLoop&> ForkBranchBase::getSafeEventLoop() noexcept {
// It's safe to read the hub's value from multiple threads, once it is ready, since we'll only
// be reading a const reference.
return nullptr;
}
// -------------------------------------------------------------------
ForkHubBase::ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner,
ExceptionOrValue& resultRef)
: EventLoop::Event(loop), inner(kj::mv(inner)), resultRef(resultRef) {
KJ_DREQUIRE(this->inner->isSafeEventLoop(loop));
// TODO(soon): This should only yield if queuing cross-thread.
arm(YIELD);
}
ForkHubBase::~ForkHubBase() noexcept(false) {}
void ForkHubBase::fire() {
if (!isWaiting && !inner->onReady(*this)) {
isWaiting = true;
} else {
// Dependency is ready. Fetch its result and then delete the node.
inner->get(resultRef);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(inner);
})) {
resultRef.addException(kj::mv(*exception));
}
auto lock = branchList.lockExclusive();
for (auto branch = lock->first; branch != nullptr; branch = branch->next) {
branch->hubReady();
*branch->prevPtr = nullptr;
branch->prevPtr = nullptr;
}
*lock->lastPtr = nullptr;
// Indicate that the list is no longer active.
lock->lastPtr = nullptr;
}
}
// -------------------------------------------------------------------
ChainPromiseNode::ChainPromiseNode(const EventLoop& loop, Own<PromiseNode> inner, Schedule schedule)
: Event(loop), state(PRE_STEP1), inner(kj::mv(inner)) {
KJ_DREQUIRE(this->inner->isSafeEventLoop(loop));
......@@ -401,10 +493,12 @@ void ChainPromiseNode::fire() {
}
}
// -------------------------------------------------------------------
CrossThreadPromiseNodeBase::CrossThreadPromiseNodeBase(
const EventLoop& loop, Own<PromiseNode>&& dependent, ExceptionOrValue& resultRef)
: Event(loop), dependent(kj::mv(dependent)), resultRef(resultRef) {
KJ_DREQUIRE(this->dependent->isSafeEventLoop(loop));
const EventLoop& loop, Own<PromiseNode>&& dependency, ExceptionOrValue& resultRef)
: Event(loop), dependency(kj::mv(dependency)), resultRef(resultRef) {
KJ_DREQUIRE(this->dependency->isSafeEventLoop(loop));
// The constructor may be called from any thread, so before we can even call onReady() we need
// to switch threads. We yield here so that the event is added to the end of the queue, which
......@@ -426,12 +520,12 @@ Maybe<const EventLoop&> CrossThreadPromiseNodeBase::getSafeEventLoop() noexcept
}
void CrossThreadPromiseNodeBase::fire() {
if (!isWaiting && !this->dependent->onReady(*this)) {
if (!isWaiting && !dependency->onReady(*this)) {
isWaiting = true;
} else {
dependent->get(resultRef);
dependency->get(resultRef);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
auto deleteMe = kj::mv(dependent);
auto deleteMe = kj::mv(dependency);
})) {
resultRef.addException(kj::mv(*exception));
}
......@@ -441,6 +535,8 @@ void CrossThreadPromiseNodeBase::fire() {
}
}
// -------------------------------------------------------------------
bool AdapterPromiseNodeBase::onReady(EventLoop::Event& event) noexcept {
return PromiseNode::atomicOnReady(onReadyEvent, event);
}
......
......@@ -26,6 +26,7 @@
#include "exception.h"
#include "mutex.h"
#include "refcount.h"
namespace kj {
......@@ -77,6 +78,9 @@ public:
Bottom operator()(Exception&& e) {
return Bottom(kj::mv(e));
}
Bottom operator()(const Exception& e) {
return Bottom(kj::cp(e));
}
};
template <typename Func, typename T>
......@@ -89,6 +93,19 @@ using ReturnType = typename ReturnType_<Func, T>::Type;
// The return type of functor Func given a parameter of type T, with the special exception that if
// T is void, this is the return type of Func called with no arguments.
template <typename T>
struct ConstReferenceTo_ { typedef const T& Type; };
template <typename T>
struct ConstReferenceTo_<T&> { typedef const T& Type; };
template <typename T>
struct ConstReferenceTo_<const T&> { typedef const T& Type; };
template <>
struct ConstReferenceTo_<void> { typedef void Type; };
template <typename T>
using ConstReferenceTo = typename ConstReferenceTo_<T>::Type;
// Resolves to `const T&`, or to `void` if `T` is `void`.
struct Void {};
// Application code should NOT refer to this! See `kj::READY_NOW` instead.
......@@ -112,6 +129,13 @@ struct MaybeVoidCaller {
return func(kj::mv(in));
}
};
template <typename In, typename Out>
struct MaybeVoidCaller<In&, Out> {
template <typename Func>
static inline Out apply(Func& func, In& in) {
return func(in);
}
};
template <typename Out>
struct MaybeVoidCaller<Void, Out> {
template <typename Func>
......@@ -127,6 +151,14 @@ struct MaybeVoidCaller<In, Void> {
return Void();
}
};
template <typename In>
struct MaybeVoidCaller<In&, Void> {
template <typename Func>
static inline Void apply(Func& func, In& in) {
func(in);
return Void();
}
};
template <>
struct MaybeVoidCaller<Void, Void> {
template <typename Func>
......@@ -145,6 +177,8 @@ inline void returnMaybeVoid(Void&& v) {}
class ExceptionOrValue;
class PromiseNode;
class ChainPromiseNode;
template <typename T>
class ForkHub;
} // namespace _ (private)
......@@ -238,9 +272,9 @@ public:
// `evalLater()` is equivalent to `there()` chained on `Promise<void>(READY_NOW)`.
template <typename T, typename Func, typename ErrorFunc = _::PropagateException>
auto there(Promise<T>&& promise, Func&& func,
PromiseForResult<Func, T> there(Promise<T>&& promise, Func&& func,
ErrorFunc&& errorHandler = _::PropagateException()) const
-> PromiseForResult<Func, T>;
KJ_WARN_UNUSED_RESULT;
// Like `Promise::then()`, but schedules the continuation to be executed on *this* EventLoop
// rather than the thread's current loop. See Promise::then().
......@@ -558,6 +592,23 @@ public:
// After returning, the promise is no longer valid, and cannot be `wait()`ed on or `then()`ed
// again.
class Fork {
public:
virtual Promise<_::ConstReferenceTo<T>> addBranch() = 0;
// Add a new branch to the fork. The branch is equivalent to the original promise except that
// its type is `Promise<const T&>` rather than `Promise<T>` (except when `T` was already a
// reference, or was `void`).
};
Own<Fork> fork();
// Forks the promise, so that multiple different clients can independently wait on the result.
// Returns an object that can be used to construct branches, all of which are equivalent to the
// original promise except that they produce references to its result rather than passing the
// result by move.
//
// As with `then()` and `wait()`, `fork()` consumes the original promise, in the sense of move
// semantics.
private:
Promise(Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
......@@ -570,6 +621,8 @@ private:
friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller(const EventLoop& loop);
template <typename>
friend class _::ForkHub;
};
constexpr _::Void READY_NOW = _::Void();
......@@ -725,6 +778,8 @@ public:
template <typename T>
ExceptionOr<T>& as() { return *static_cast<ExceptionOr<T>*>(this); }
template <typename T>
const ExceptionOr<T>& as() const { return *static_cast<const ExceptionOr<T>*>(this); }
Maybe<Exception> exception;
......@@ -788,6 +843,8 @@ protected:
// Useful for firing events in conjuction with atomicOnReady().
};
// -------------------------------------------------------------------
class ImmediatePromiseNodeBase: public PromiseNode {
public:
bool onReady(EventLoop::Event& event) noexcept override;
......@@ -819,6 +876,8 @@ private:
Exception exception;
};
// -------------------------------------------------------------------
class TransformPromiseNodeBase: public PromiseNode {
public:
TransformPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependency);
......@@ -871,6 +930,110 @@ private:
}
};
// -------------------------------------------------------------------
class ForkHubBase;
class ForkBranchBase: public PromiseNode {
public:
ForkBranchBase(Own<ForkHubBase>&& hub);
~ForkBranchBase();
void hubReady() noexcept;
// Called by the hub to indicate that it is ready.
// implements PromiseNode ------------------------------------------
bool onReady(EventLoop::Event& event) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
protected:
inline const ExceptionOrValue& getHubResultRef() const;
void releaseHub(ExceptionOrValue& output);
// Release the hub. If an exception is thrown, add it to `output`.
private:
EventLoop::Event* onReadyEvent = nullptr;
Own<ForkHubBase> hub;
ForkBranchBase* next = nullptr;
ForkBranchBase** prevPtr = nullptr;
friend class ForkHubBase;
};
template <typename T>
class ForkBranch final: public ForkBranchBase {
// A PromiseNode that implements one branch of a fork -- i.e. one of the branches that receives
// a const reference.
public:
ForkBranch(Own<ForkHubBase>&& hub): ForkBranchBase(kj::mv(hub)) {}
void get(ExceptionOrValue& output) noexcept override {
const ExceptionOr<T>& hubResult = getHubResultRef().template as<T>();
KJ_IF_MAYBE(value, hubResult.value) {
output.as<ConstReferenceTo<T>>().value = *value;
} else {
output.as<ConstReferenceTo<T>>().value = nullptr;
}
output.exception = hubResult.exception;
releaseHub(output);
}
};
// -------------------------------------------------------------------
class ForkHubBase: public Refcounted, private EventLoop::Event {
public:
ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner, ExceptionOrValue& resultRef);
~ForkHubBase() noexcept(false);
inline const ExceptionOrValue& getResultRef() const { return resultRef; }
private:
struct BranchList {
ForkBranchBase* first = nullptr;
ForkBranchBase** lastPtr = &first;
};
Own<PromiseNode> inner;
ExceptionOrValue& resultRef;
bool isWaiting = false;
MutexGuarded<BranchList> branchList;
// Becomes null once the inner promise is ready and all branches have been notified.
void fire() override;
friend class ForkBranchBase;
};
template <typename T>
class ForkHub final: public ForkHubBase, public Promise<T>::Fork {
// A PromiseNode that implements the hub of a fork. The first call to Promise::fork() replaces
// the promise's outer node with a ForkHub, and subsequent calls add branches to that hub (if
// possible).
public:
ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner)
: ForkHubBase(loop, kj::mv(inner), result) {}
Promise<_::ConstReferenceTo<T>> addBranch() override {
return Promise<_::ConstReferenceTo<T>>(kj::heap<ForkBranch<T>>(addRef(*this)));
}
private:
ExceptionOr<T> result;
};
inline const ExceptionOrValue& ForkBranchBase::getHubResultRef() const {
return hub->getResultRef();
}
// -------------------------------------------------------------------
class ChainPromiseNode final: public PromiseNode, private EventLoop::Event {
public:
ChainPromiseNode(const EventLoop& loop, Own<PromiseNode> inner, Schedule schedule);
......@@ -920,12 +1083,14 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) {
return kj::mv(node);
}
// -------------------------------------------------------------------
class CrossThreadPromiseNodeBase: public PromiseNode, private EventLoop::Event {
// A PromiseNode that safely imports a promised value from one EventLoop to another (which
// implies crossing threads).
public:
CrossThreadPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependent,
CrossThreadPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependency,
ExceptionOrValue& resultRef);
~CrossThreadPromiseNodeBase() noexcept(false);
......@@ -933,7 +1098,7 @@ public:
Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
private:
Own<PromiseNode> dependent;
Own<PromiseNode> dependency;
EventLoop::Event* onReadyEvent = nullptr;
ExceptionOrValue& resultRef;
......@@ -946,8 +1111,8 @@ private:
template <typename T>
class CrossThreadPromiseNode final: public CrossThreadPromiseNodeBase {
public:
CrossThreadPromiseNode(const EventLoop& loop, Own<PromiseNode>&& dependent)
: CrossThreadPromiseNodeBase(loop, kj::mv(dependent), result) {}
CrossThreadPromiseNode(const EventLoop& loop, Own<PromiseNode>&& dependency)
: CrossThreadPromiseNodeBase(loop, kj::mv(dependency), result) {}
void get(ExceptionOrValue& output) noexcept override {
output.as<T>() = kj::mv(result);
......@@ -976,6 +1141,8 @@ Own<PromiseNode> spark(Own<PromiseNode>&& node, const EventLoop& loop) {
return heap<CrossThreadPromiseNode<T>>(loop, kj::mv(node));
}
// -------------------------------------------------------------------
class AdapterPromiseNodeBase: public PromiseNode {
public:
bool onReady(EventLoop::Event& event) noexcept override;
......@@ -1029,6 +1196,8 @@ private:
} // namespace _ (private)
// =======================================================================================
template <typename T>
T EventLoop::wait(Promise<T>&& promise) {
_::ExceptionOr<_::FixVoid<T>> result;
......@@ -1054,8 +1223,8 @@ auto EventLoop::evalLater(Func&& func) const -> PromiseForResult<Func, void> {
}
template <typename T, typename Func, typename ErrorFunc>
auto EventLoop::there(Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const
-> PromiseForResult<Func, T> {
PromiseForResult<Func, T> EventLoop::there(
Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const {
return _::spark<_::FixVoid<_::JoinPromises<_::ReturnType<Func, T>>>>(thereImpl(
kj::mv(promise), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler), Event::YIELD), *this);
}
......@@ -1094,6 +1263,12 @@ T Promise<T>::wait() {
return EventLoop::current().wait(kj::mv(*this));
}
template <typename T>
Own<typename Promise<T>::Fork> Promise<T>::fork() {
auto& loop = EventLoop::current();
return refcounted<_::ForkHub<T>>(loop, _::makeSafeForLoop<_::FixVoid<T>>(kj::mv(node), loop));
}
// =======================================================================================
namespace _ { // private
......
......@@ -766,8 +766,8 @@ public:
inline Maybe(const Maybe<const U&>& other) noexcept: ptr(other.ptr) {}
inline Maybe(decltype(nullptr)) noexcept: ptr(nullptr) {}
inline Maybe& operator=(T& other) noexcept { ptr = &other; }
inline Maybe& operator=(T* other) noexcept { ptr = other; }
inline Maybe& operator=(T& other) noexcept { ptr = &other; return *this; }
inline Maybe& operator=(T* other) noexcept { ptr = other; return *this; }
template <typename U>
inline Maybe& operator=(Maybe<U&>& other) noexcept { ptr = other.ptr; return *this; }
template <typename U>
......
......@@ -52,9 +52,6 @@ class Refcounted: private Disposer {
public:
virtual ~Refcounted() noexcept(false);
template <typename T>
static Own<T> addRef(T& object);
private:
mutable volatile uint refcount = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment