Commit fe5b21e8 authored by Kenton Varda's avatar Kenton Varda

Using promises for references to represent each branch of a fork doesn't work,…

Using promises for references to represent each branch of a fork doesn't work, because the fork hub will be deleted when the last branch resolves, before any continuation can run.  Instead, let's just institutionalize the concept of addRef() for owned pointers.
parent 06945999
...@@ -320,19 +320,59 @@ TEST(Async, Fork) { ...@@ -320,19 +320,59 @@ TEST(Async, Fork) {
SimpleEventLoop loop; SimpleEventLoop loop;
auto outer = loop.evalLater([&]() { auto outer = loop.evalLater([&]() {
Promise<String> promise = loop.evalLater([&]() { return str("foo"); }); Promise<int> promise = loop.evalLater([&]() { return 123; });
auto fork = promise.fork(); auto fork = promise.fork();
auto branch1 = fork->addBranch().then([](const String& s) { auto branch1 = fork->addBranch().then([](int i) {
EXPECT_EQ("foo", s); EXPECT_EQ(123, i);
return 456; return 456;
}); });
auto branch2 = fork->addBranch().then([](const String& s) { auto branch2 = fork->addBranch().then([](int i) {
EXPECT_EQ("foo", s); EXPECT_EQ(123, i);
return 789; return 789;
}); });
{
auto releaseFork = kj::mv(fork);
}
EXPECT_EQ(456, loop.wait(kj::mv(branch1)));
EXPECT_EQ(789, loop.wait(kj::mv(branch2)));
});
loop.wait(kj::mv(outer));
}
struct RefcountedInt: public Refcounted {
RefcountedInt(int i): i(i) {}
int i;
Own<const RefcountedInt> addRef() const { return kj::addRef(*this); }
};
TEST(Async, ForkRef) {
SimpleEventLoop loop;
auto outer = loop.evalLater([&]() {
Promise<Own<RefcountedInt>> promise = loop.evalLater([&]() {
return refcounted<RefcountedInt>(123);
});
auto fork = promise.fork();
auto branch1 = fork->addBranch().then([](Own<const RefcountedInt>&& i) {
EXPECT_EQ(123, i->i);
return 456;
});
auto branch2 = fork->addBranch().then([](Own<const RefcountedInt>&& i) {
EXPECT_EQ(123, i->i);
return 789;
});
{
auto releaseFork = kj::mv(fork);
}
EXPECT_EQ(456, loop.wait(kj::mv(branch1))); EXPECT_EQ(456, loop.wait(kj::mv(branch1)));
EXPECT_EQ(789, loop.wait(kj::mv(branch2))); EXPECT_EQ(789, loop.wait(kj::mv(branch2)));
}); });
......
...@@ -93,19 +93,6 @@ using ReturnType = typename ReturnType_<Func, T>::Type; ...@@ -93,19 +93,6 @@ using ReturnType = typename ReturnType_<Func, T>::Type;
// The return type of functor Func given a parameter of type T, with the special exception that if // The return type of functor Func given a parameter of type T, with the special exception that if
// T is void, this is the return type of Func called with no arguments. // T is void, this is the return type of Func called with no arguments.
template <typename T>
struct ConstReferenceTo_ { typedef const T& Type; };
template <typename T>
struct ConstReferenceTo_<T&> { typedef const T& Type; };
template <typename T>
struct ConstReferenceTo_<const T&> { typedef const T& Type; };
template <>
struct ConstReferenceTo_<void> { typedef void Type; };
template <typename T>
using ConstReferenceTo = typename ConstReferenceTo_<T>::Type;
// Resolves to `const T&`, or to `void` if `T` is `void`.
struct Void {}; struct Void {};
// Application code should NOT refer to this! See `kj::READY_NOW` instead. // Application code should NOT refer to this! See `kj::READY_NOW` instead.
...@@ -119,6 +106,13 @@ template <> struct UnfixVoid_<Void> { typedef void Type; }; ...@@ -119,6 +106,13 @@ template <> struct UnfixVoid_<Void> { typedef void Type; };
template <typename T> using UnfixVoid = typename UnfixVoid_<T>::Type; template <typename T> using UnfixVoid = typename UnfixVoid_<T>::Type;
// UnfixVoid is the opposite of FixVoid. // UnfixVoid is the opposite of FixVoid.
template <typename T> struct Forked_ { typedef T Type; };
template <typename T> struct Forked_<T&> { typedef const T& Type; };
template <typename T> struct Forked_<Own<T>> { typedef Own<const T> Type; };
template <typename T> using Forked = typename Forked_<T>::Type;
// Forked<T> transforms T as a result of being forked. If T is an owned pointer or reference,
// it becomes const.
template <typename In, typename Out> template <typename In, typename Out>
struct MaybeVoidCaller { struct MaybeVoidCaller {
// Calls the function converting a Void input to an empty parameter list and a void return // Calls the function converting a Void input to an empty parameter list and a void return
...@@ -594,23 +588,20 @@ public: ...@@ -594,23 +588,20 @@ public:
class Fork { class Fork {
public: public:
virtual Promise<_::ConstReferenceTo<T>> addBranch() = 0; virtual Promise<_::Forked<T>> addBranch() = 0;
// Add a new branch to the fork. The branch is equivalent to the original promise except that // Add a new branch to the fork. The branch is equivalent to the original promise, except
// its type is `Promise<const T&>` rather than `Promise<T>` (except when `T` was already a // that if T is a reference or owned pointer, the target becomes const.
// reference, or was `void`).
}; };
Own<Fork> fork(); Own<Fork> fork();
// Forks the promise, so that multiple different clients can independently wait on the result. // Forks the promise, so that multiple different clients can independently wait on the result.
// Returns an object that can be used to construct branches, all of which are equivalent to the // `T` must be copy-constructable for this to work. Or, in the special case where `T` is
// original promise except that they produce references to its result rather than passing the // `Own<U>`, `U` must have a method `Own<const U> addRef() const` which returns a new reference
// result by move. // to the same (or an equivalent) object (probably implemented via reference counting).
//
// As with `then()` and `wait()`, `fork()` consumes the original promise, in the sense of move
// semantics.
private: private:
Promise(Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {} Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
// Second parameter prevent ambiguity with immediate-value constructor.
template <typename> template <typename>
friend class Promise; friend class Promise;
...@@ -962,6 +953,10 @@ private: ...@@ -962,6 +953,10 @@ private:
friend class ForkHubBase; friend class ForkHubBase;
}; };
template <typename T> T copyOrAddRef(const T& t) { return t; }
template <typename T> Own<const T> copyOrAddRef(const Own<T>& t) { return t->addRef(); }
template <typename T> Own<const T> copyOrAddRef(const Own<const T>& t) { return t->addRef(); }
template <typename T> template <typename T>
class ForkBranch final: public ForkBranchBase { class ForkBranch final: public ForkBranchBase {
// A PromiseNode that implements one branch of a fork -- i.e. one of the branches that receives // A PromiseNode that implements one branch of a fork -- i.e. one of the branches that receives
...@@ -973,9 +968,9 @@ public: ...@@ -973,9 +968,9 @@ public:
void get(ExceptionOrValue& output) noexcept override { void get(ExceptionOrValue& output) noexcept override {
const ExceptionOr<T>& hubResult = getHubResultRef().template as<T>(); const ExceptionOr<T>& hubResult = getHubResultRef().template as<T>();
KJ_IF_MAYBE(value, hubResult.value) { KJ_IF_MAYBE(value, hubResult.value) {
output.as<ConstReferenceTo<T>>().value = *value; output.as<Forked<T>>().value = copyOrAddRef(*value);
} else { } else {
output.as<ConstReferenceTo<T>>().value = nullptr; output.as<Forked<T>>().value = nullptr;
} }
output.exception = hubResult.exception; output.exception = hubResult.exception;
releaseHub(output); releaseHub(output);
...@@ -1020,8 +1015,8 @@ public: ...@@ -1020,8 +1015,8 @@ public:
ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner) ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner)
: ForkHubBase(loop, kj::mv(inner), result) {} : ForkHubBase(loop, kj::mv(inner), result) {}
Promise<_::ConstReferenceTo<T>> addBranch() override { Promise<_::Forked<T>> addBranch() override {
return Promise<_::ConstReferenceTo<T>>(kj::heap<ForkBranch<T>>(addRef(*this))); return Promise<_::Forked<T>>(false, kj::heap<ForkBranch<T>>(addRef(*this)));
} }
private: private:
...@@ -1225,8 +1220,10 @@ auto EventLoop::evalLater(Func&& func) const -> PromiseForResult<Func, void> { ...@@ -1225,8 +1220,10 @@ auto EventLoop::evalLater(Func&& func) const -> PromiseForResult<Func, void> {
template <typename T, typename Func, typename ErrorFunc> template <typename T, typename Func, typename ErrorFunc>
PromiseForResult<Func, T> EventLoop::there( PromiseForResult<Func, T> EventLoop::there(
Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const { Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const {
return _::spark<_::FixVoid<_::JoinPromises<_::ReturnType<Func, T>>>>(thereImpl( return PromiseForResult<Func, T>(false,
kj::mv(promise), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler), Event::YIELD), *this); _::spark<_::FixVoid<_::JoinPromises<_::ReturnType<Func, T>>>>(thereImpl(
kj::mv(promise), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler), Event::YIELD),
*this));
} }
template <typename T, typename Func, typename ErrorFunc> template <typename T, typename Func, typename ErrorFunc>
...@@ -1253,9 +1250,9 @@ Promise<T>::Promise(kj::Exception&& exception) ...@@ -1253,9 +1250,9 @@ Promise<T>::Promise(kj::Exception&& exception)
template <typename T> template <typename T>
template <typename Func, typename ErrorFunc> template <typename Func, typename ErrorFunc>
PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) { PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) {
return EventLoop::current().thereImpl( return PromiseForResult<Func, T>(false, EventLoop::current().thereImpl(
kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler), kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler),
EventLoop::Event::PREEMPT); EventLoop::Event::PREEMPT));
} }
template <typename T> template <typename T>
...@@ -1341,7 +1338,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller() { ...@@ -1341,7 +1338,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller() {
Own<_::PromiseNode> intermediate( Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper)); heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise( Promise<_::JoinPromises<T>> promise(false,
_::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr))); _::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) }; return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
...@@ -1353,7 +1350,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop) { ...@@ -1353,7 +1350,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop) {
Own<_::PromiseNode> intermediate( Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper)); heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise( Promise<_::JoinPromises<T>> promise(false,
_::maybeChain(kj::mv(intermediate), loop, EventLoop::Event::YIELD, _::maybeChain(kj::mv(intermediate), loop, EventLoop::Event::YIELD,
implicitCast<T*>(nullptr))); implicitCast<T*>(nullptr)));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment