Commit 08143efb authored by Kenton Varda's avatar Kenton Varda

Implement exclusive promise joining: Join two Promise<T>s to create a Promise…

Implement exclusive promise joining:  Join two Promise<T>s to create a Promise that resolves when either input resolves.
parent 28ad8ae3
...@@ -433,6 +433,52 @@ TEST(Async, ForkRef) { ...@@ -433,6 +433,52 @@ TEST(Async, ForkRef) {
loop.wait(kj::mv(outer)); loop.wait(kj::mv(outer));
} }
TEST(Async, ExclusiveJoin) {
{
SimpleEventLoop loop;
auto left = loop.evalLater([&]() { return 123; });
auto right = newPromiseAndFulfiller<int>(); // never fulfilled
auto promise = loop.exclusiveJoin(kj::mv(left), kj::mv(right.promise));
EXPECT_EQ(123, loop.wait(kj::mv(promise)));
}
{
SimpleEventLoop loop;
auto left = newPromiseAndFulfiller<int>(); // never fulfilled
auto right = loop.evalLater([&]() { return 123; });
auto promise = loop.exclusiveJoin(kj::mv(left.promise), kj::mv(right));
EXPECT_EQ(123, loop.wait(kj::mv(promise)));
}
{
SimpleEventLoop loop;
auto left = loop.evalLater([&]() { return 123; });
auto right = loop.evalLater([&]() { return 456; });
auto promise = loop.exclusiveJoin(kj::mv(left), kj::mv(right));
EXPECT_EQ(123, loop.wait(kj::mv(promise)));
}
{
SimpleEventLoop loop;
auto right = loop.evalLater([&]() { return 456; });
auto left = loop.evalLater([&]() { return 123; });
auto promise = loop.exclusiveJoin(kj::mv(left), kj::mv(right));
EXPECT_EQ(456, loop.wait(kj::mv(promise)));
}
}
class ErrorHandlerImpl: public TaskSet::ErrorHandler { class ErrorHandlerImpl: public TaskSet::ErrorHandler {
public: public:
uint exceptionCount = 0; uint exceptionCount = 0;
......
...@@ -564,6 +564,76 @@ void ChainPromiseNode::fire() { ...@@ -564,6 +564,76 @@ void ChainPromiseNode::fire() {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
ExclusiveJoinPromiseNode::ExclusiveJoinPromiseNode(
const EventLoop& loop, Own<PromiseNode> left, Own<PromiseNode> right)
: left(loop, *this, kj::mv(left)),
right(loop, *this, kj::mv(right)) {}
ExclusiveJoinPromiseNode::~ExclusiveJoinPromiseNode() noexcept(false) {}
bool ExclusiveJoinPromiseNode::onReady(EventLoop::Event& event) noexcept {
if (onReadyEvent == _kJ_ALREADY_READY) {
return true;
} else {
onReadyEvent = &event;
return false;
}
}
void ExclusiveJoinPromiseNode::get(ExceptionOrValue& output) noexcept {
KJ_REQUIRE(left.get(output) || right.get(output),
"get() called before ready.");
}
Maybe<const EventLoop&> ExclusiveJoinPromiseNode::getSafeEventLoop() noexcept {
return left.getEventLoop();
}
ExclusiveJoinPromiseNode::Branch::Branch(
const EventLoop& loop, ExclusiveJoinPromiseNode& joinNode, Own<PromiseNode> dependency)
: Event(loop), joinNode(joinNode), dependency(kj::mv(dependency)) {
KJ_DREQUIRE(this->dependency->isSafeEventLoop(loop));
arm();
}
ExclusiveJoinPromiseNode::Branch::~Branch() noexcept(false) {
disarm();
}
bool ExclusiveJoinPromiseNode::Branch::get(ExceptionOrValue& output) {
if (finished) {
dependency->get(output);
return true;
} else {
return false;
}
}
void ExclusiveJoinPromiseNode::Branch::fire() {
if (!isWaiting && !dependency->onReady(*this)) {
isWaiting = true;
} else {
finished = true;
// Cancel the branch that didn't return first. Ignore exceptions caused by cancellation.
if (this == &joinNode.left) {
joinNode.right.disarm();
kj::runCatchingExceptions([&]() { joinNode.right.dependency = nullptr; });
} else {
joinNode.left.disarm();
kj::runCatchingExceptions([&]() { joinNode.left.dependency = nullptr; });
}
if (joinNode.onReadyEvent == nullptr) {
joinNode.onReadyEvent = _kJ_ALREADY_READY;
} else {
joinNode.onReadyEvent->arm();
}
}
}
// -------------------------------------------------------------------
CrossThreadPromiseNodeBase::CrossThreadPromiseNodeBase( CrossThreadPromiseNodeBase::CrossThreadPromiseNodeBase(
const EventLoop& loop, Own<PromiseNode>&& dependency, ExceptionOrValue& resultRef) const EventLoop& loop, Own<PromiseNode>&& dependency, ExceptionOrValue& resultRef)
: Event(loop), dependency(kj::mv(dependency)), resultRef(resultRef) { : Event(loop), dependency(kj::mv(dependency)), resultRef(resultRef) {
......
...@@ -294,6 +294,10 @@ public: ...@@ -294,6 +294,10 @@ public:
// Like `Promise::fork()`, but manages the fork on *this* EventLoop rather than the thread's // Like `Promise::fork()`, but manages the fork on *this* EventLoop rather than the thread's
// current loop. See Promise::fork(). // current loop. See Promise::fork().
template <typename T>
Promise<T> exclusiveJoin(Promise<T>&& promise1, Promise<T>&& promise2) const;
// Like `promise1.exclusiveJoin(promise2)`, returning the joined promise.
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// Low-level interface. // Low-level interface.
...@@ -632,6 +636,13 @@ public: ...@@ -632,6 +636,13 @@ public:
// `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference // `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference
// to the same (or an equivalent) object (probably implemented via reference counting). // to the same (or an equivalent) object (probably implemented via reference counting).
void exclusiveJoin(Promise<T>&& other);
// Replace this promise with one that resolves when either the original promise resolves or
// `other` resolves (whichever comes first). The promise that didn't resolve first is canceled.
// TODO(someday): inclusiveJoin(), or perhaps just join(), which waits for both completions
// and produces a tuple?
template <typename... Attachments> template <typename... Attachments>
void attach(Attachments&&... attachments); void attach(Attachments&&... attachments);
// "Attaches" one or more movable objects (often, Own<T>s) to the promise, such that they will // "Attaches" one or more movable objects (often, Own<T>s) to the promise, such that they will
...@@ -1276,6 +1287,40 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) { ...@@ -1276,6 +1287,40 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
class ExclusiveJoinPromiseNode final: public PromiseNode {
public:
ExclusiveJoinPromiseNode(const EventLoop& loop, Own<PromiseNode> left, Own<PromiseNode> right);
~ExclusiveJoinPromiseNode() noexcept(false);
bool onReady(EventLoop::Event& event) noexcept override;
void get(ExceptionOrValue& output) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
private:
class Branch: public EventLoop::Event {
public:
Branch(const EventLoop& loop, ExclusiveJoinPromiseNode& joinNode, Own<PromiseNode> dependency);
~Branch() noexcept(false);
bool get(ExceptionOrValue& output);
// Returns true if this is the side that finished.
void fire() override;
private:
bool isWaiting = false;
bool finished = false;
ExclusiveJoinPromiseNode& joinNode;
Own<PromiseNode> dependency;
};
Branch left;
Branch right;
EventLoop::Event* onReadyEvent = nullptr;
};
// -------------------------------------------------------------------
class CrossThreadPromiseNodeBase: public PromiseNode, protected EventLoop::Event { class CrossThreadPromiseNodeBase: public PromiseNode, protected EventLoop::Event {
// A PromiseNode that safely imports a promised value from one EventLoop to another (which // A PromiseNode that safely imports a promised value from one EventLoop to another (which
// implies crossing threads). // implies crossing threads).
...@@ -1505,6 +1550,21 @@ Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const { ...@@ -1505,6 +1550,21 @@ Promise<_::Forked<T>> ForkedPromise<T>::addBranch() const {
return hub->addBranch(); return hub->addBranch();
} }
template <typename T>
void Promise<T>::exclusiveJoin(Promise<T>&& other) {
auto& loop = EventLoop::current();
node = heap<_::ExclusiveJoinPromiseNode>(loop,
_::makeSafeForLoop<_::FixVoid<T>>(kj::mv(node), loop),
_::makeSafeForLoop<_::FixVoid<T>>(kj::mv(other.node), loop));
}
template <typename T>
Promise<T> EventLoop::exclusiveJoin(Promise<T>&& promise1, Promise<T>&& promise2) const {
return Promise<T>(false, heap<_::ExclusiveJoinPromiseNode>(*this,
_::makeSafeForLoop<_::FixVoid<T>>(kj::mv(promise1.node), *this),
_::makeSafeForLoop<_::FixVoid<T>>(kj::mv(promise2.node), *this)));
}
template <typename T> template <typename T>
template <typename... Attachments> template <typename... Attachments>
void Promise<T>::attach(Attachments&&... attachments) { void Promise<T>::attach(Attachments&&... attachments) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment