Commit c28fb99c authored by Kenton Varda's avatar Kenton Varda

Refactor: Promise should have fewer friends.

Instead of making `Promise` friend everything that needs to construct a `Promise<T>` from an `Own<PromiseNode>` or vice versa, let's just friend `PromiseNode` itself -- which is already a "private" class by virtue of being in the `_` namespace -- and let it provide some static methods to do the conversions.
parent 0497297c
...@@ -194,6 +194,22 @@ public: ...@@ -194,6 +194,22 @@ public:
// If this node wraps some other PromiseNode, get the wrapped node. Used for debug tracing. // If this node wraps some other PromiseNode, get the wrapped node. Used for debug tracing.
// Default implementation returns nullptr. // Default implementation returns nullptr.
template <typename T>
static Own<PromiseNode> from(T&& promise) {
// Given a Promise, extract the PromiseNode.
return kj::mv(promise.node);
}
template <typename T>
static PromiseNode& from(T& promise) {
// Given a Promise, extract the PromiseNode.
return *promise.node;
}
template <typename T>
static T to(Own<PromiseNode>&& node) {
// Construct a Promise from a PromiseNode. (T should be a Promise type.)
return T(false, kj::mv(node));
}
protected: protected:
class OnReadyEvent { class OnReadyEvent {
// Helper class for implementing onReady(). // Helper class for implementing onReady().
...@@ -213,6 +229,13 @@ protected: ...@@ -213,6 +229,13 @@ protected:
// ------------------------------------------------------------------- // -------------------------------------------------------------------
template <typename T>
inline NeverDone::operator Promise<T>() const {
return PromiseNode::to<Promise<T>>(neverDone());
}
// -------------------------------------------------------------------
class ImmediatePromiseNodeBase: public PromiseNode { class ImmediatePromiseNodeBase: public PromiseNode {
public: public:
ImmediatePromiseNodeBase(); ImmediatePromiseNodeBase();
...@@ -557,7 +580,7 @@ public: ...@@ -557,7 +580,7 @@ public:
ForkHub(Own<PromiseNode>&& inner): ForkHubBase(kj::mv(inner), result) {} ForkHub(Own<PromiseNode>&& inner): ForkHubBase(kj::mv(inner), result) {}
Promise<_::UnfixVoid<T>> addBranch() { Promise<_::UnfixVoid<T>> addBranch() {
return Promise<_::UnfixVoid<T>>(false, kj::heap<ForkBranch<T>>(addRef(*this))); return _::PromiseNode::to<Promise<_::UnfixVoid<T>>>(kj::heap<ForkBranch<T>>(addRef(*this)));
} }
_::SplitTuplePromise<T> split() { _::SplitTuplePromise<T> split() {
...@@ -574,9 +597,9 @@ private: ...@@ -574,9 +597,9 @@ private:
template <size_t index> template <size_t index>
ReducePromises<typename SplitBranch<T, index>::Element> addSplit() { ReducePromises<typename SplitBranch<T, index>::Element> addSplit() {
return ReducePromises<typename SplitBranch<T, index>::Element>( return _::PromiseNode::to<ReducePromises<typename SplitBranch<T, index>::Element>>(
false, maybeChain(kj::heap<SplitBranch<T, index>>(addRef(*this)), maybeChain(kj::heap<SplitBranch<T, index>>(addRef(*this)),
implicitCast<typename SplitBranch<T, index>::Element*>(nullptr))); implicitCast<typename SplitBranch<T, index>::Element*>(nullptr)));
} }
}; };
...@@ -868,7 +891,7 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler ...@@ -868,7 +891,7 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler
Own<_::PromiseNode> intermediate = Own<_::PromiseNode> intermediate =
heap<_::TransformPromiseNode<ResultT, _::FixVoid<T>, Func, ErrorFunc>>( heap<_::TransformPromiseNode<ResultT, _::FixVoid<T>, Func, ErrorFunc>>(
kj::mv(node), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler)); kj::mv(node), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler));
auto result = _::ChainPromises<_::ReturnType<Func, T>>(false, auto result = _::PromiseNode::to<_::ChainPromises<_::ReturnType<Func, T>>>(
_::maybeChain(kj::mv(intermediate), implicitCast<ResultT*>(nullptr))); _::maybeChain(kj::mv(intermediate), implicitCast<ResultT*>(nullptr)));
return _::maybeReduce(kj::mv(result), false); return _::maybeReduce(kj::mv(result), false);
} }
...@@ -1005,8 +1028,8 @@ void Promise<void>::detach(ErrorFunc&& errorHandler) { ...@@ -1005,8 +1028,8 @@ void Promise<void>::detach(ErrorFunc&& errorHandler) {
template <typename T> template <typename T>
Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises) { Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises) {
return Promise<Array<T>>(false, kj::heap<_::ArrayJoinPromiseNode<T>>( return _::PromiseNode::to<Promise<Array<T>>>(kj::heap<_::ArrayJoinPromiseNode<T>>(
KJ_MAP(p, promises) { return kj::mv(p.node); }, KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); },
heapArray<_::ExceptionOr<T>>(promises.size()))); heapArray<_::ExceptionOr<T>>(promises.size())));
} }
...@@ -1134,7 +1157,7 @@ _::ReducePromises<T> newAdaptedPromise(Params&&... adapterConstructorParams) { ...@@ -1134,7 +1157,7 @@ _::ReducePromises<T> newAdaptedPromise(Params&&... adapterConstructorParams) {
Own<_::PromiseNode> intermediate( Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, Adapter>>( heap<_::AdapterPromiseNode<_::FixVoid<T>, Adapter>>(
kj::fwd<Params>(adapterConstructorParams)...)); kj::fwd<Params>(adapterConstructorParams)...));
return _::ReducePromises<T>(false, return _::PromiseNode::to<_::ReducePromises<T>>(
_::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr))); _::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr)));
} }
...@@ -1144,7 +1167,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller() { ...@@ -1144,7 +1167,7 @@ PromiseFulfillerPair<T> newPromiseAndFulfiller() {
Own<_::PromiseNode> intermediate( Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper)); heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
_::ReducePromises<T> promise(false, auto promise = _::PromiseNode::to<_::ReducePromises<T>>(
_::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr))); _::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) }; return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
...@@ -1171,9 +1194,6 @@ protected: ...@@ -1171,9 +1194,6 @@ protected:
// Run the function. If the function returns a promise, returns the inner PromiseNode, otherwise // Run the function. If the function returns a promise, returns the inner PromiseNode, otherwise
// returns null. // returns null.
template <typename T>
Own<PromiseNode> extractNode(Promise<T> promise) { return kj::mv(promise.node); }
// implements PromiseNode ---------------------------------------------------- // implements PromiseNode ----------------------------------------------------
void onReady(Event* event) noexcept override; void onReady(Event* event) noexcept override;
...@@ -1271,7 +1291,7 @@ public: ...@@ -1271,7 +1291,7 @@ public:
typedef _::FixVoid<_::UnwrapPromise<PromiseForResult<Func, void>>> ResultT; typedef _::FixVoid<_::UnwrapPromise<PromiseForResult<Func, void>>> ResultT;
kj::Maybe<Own<_::PromiseNode>> execute() override { kj::Maybe<Own<_::PromiseNode>> execute() override {
auto result = extractNode(func()); auto result = _::PromiseNode::from(func());
KJ_IREQUIRE(result.get() != nullptr); KJ_IREQUIRE(result.get() != nullptr);
return kj::mv(result); return kj::mv(result);
} }
...@@ -1300,7 +1320,7 @@ template <typename Func> ...@@ -1300,7 +1320,7 @@ template <typename Func>
PromiseForResult<Func, void> Executor::executeAsync(Func&& func) const { PromiseForResult<Func, void> Executor::executeAsync(Func&& func) const {
auto event = kj::heap<_::XThreadEventImpl<Func>>(kj::fwd<Func>(func), *this); auto event = kj::heap<_::XThreadEventImpl<Func>>(kj::fwd<Func>(func), *this);
send(*event, false); send(*event, false);
return PromiseForResult<Func, void>(false, kj::mv(event)); return _::PromiseNode::to<PromiseForResult<Func, void>>(kj::mv(event));
} }
} // namespace kj } // namespace kj
......
...@@ -203,15 +203,9 @@ private: ...@@ -203,15 +203,9 @@ private:
PromiseBase() = default; PromiseBase() = default;
PromiseBase(Own<PromiseNode>&& node): node(kj::mv(node)) {} PromiseBase(Own<PromiseNode>&& node): node(kj::mv(node)) {}
friend class kj::EventLoop;
friend class ChainPromiseNode;
template <typename> template <typename>
friend class kj::Promise; friend class kj::Promise;
friend class kj::TaskSet; friend class PromiseNode;
template <typename U>
friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises);
friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises);
friend class XThreadEvent;
}; };
void detach(kj::Promise<void>&& promise); void detach(kj::Promise<void>&& promise);
...@@ -224,9 +218,7 @@ Own<PromiseNode> neverDone(); ...@@ -224,9 +218,7 @@ Own<PromiseNode> neverDone();
class NeverDone { class NeverDone {
public: public:
template <typename T> template <typename T>
operator Promise<T>() const { operator Promise<T>() const;
return Promise<T>(false, neverDone());
}
KJ_NORETURN(void wait(WaitScope& waitScope) const); KJ_NORETURN(void wait(WaitScope& waitScope) const);
}; };
......
...@@ -234,7 +234,7 @@ private: ...@@ -234,7 +234,7 @@ private:
}; };
void TaskSet::add(Promise<void>&& promise) { void TaskSet::add(Promise<void>&& promise) {
auto task = heap<Task>(*this, kj::mv(promise.node)); auto task = heap<Task>(*this, _::PromiseNode::from(kj::mv(promise)));
KJ_IF_MAYBE(head, tasks) { KJ_IF_MAYBE(head, tasks) {
head->get()->prev = &task->next; head->get()->prev = &task->next;
task->next = kj::mv(tasks); task->next = kj::mv(tasks);
...@@ -931,11 +931,11 @@ bool pollImpl(_::PromiseNode& node, WaitScope& waitScope) { ...@@ -931,11 +931,11 @@ bool pollImpl(_::PromiseNode& node, WaitScope& waitScope) {
} }
Promise<void> yield() { Promise<void> yield() {
return Promise<void>(false, kj::heap<YieldPromiseNode>()); return _::PromiseNode::to<Promise<void>>(kj::heap<YieldPromiseNode>());
} }
Promise<void> yieldHarder() { Promise<void> yieldHarder() {
return Promise<void>(false, kj::heap<YieldHarderPromiseNode>()); return _::PromiseNode::to<Promise<void>>(kj::heap<YieldHarderPromiseNode>());
} }
Own<PromiseNode> neverDone() { Own<PromiseNode> neverDone() {
...@@ -1379,7 +1379,7 @@ Maybe<Own<Event>> ChainPromiseNode::fire() { ...@@ -1379,7 +1379,7 @@ Maybe<Own<Event>> ChainPromiseNode::fire() {
} else KJ_IF_MAYBE(value, intermediate.value) { } else KJ_IF_MAYBE(value, intermediate.value) {
// There is a value and no exception. The value is itself a promise. Adopt it as our // There is a value and no exception. The value is itself a promise. Adopt it as our
// step2. // step2.
inner = kj::mv(value->node); inner = _::PromiseNode::from(kj::mv(*value));
} else { } else {
// We can only get here if inner->get() returned neither an exception nor a // We can only get here if inner->get() returned neither an exception nor a
// value, which never actually happens. // value, which never actually happens.
...@@ -1551,8 +1551,8 @@ void ArrayJoinPromiseNode<void>::getNoError(ExceptionOrValue& output) noexcept { ...@@ -1551,8 +1551,8 @@ void ArrayJoinPromiseNode<void>::getNoError(ExceptionOrValue& output) noexcept {
} // namespace _ (private) } // namespace _ (private)
Promise<void> joinPromises(Array<Promise<void>>&& promises) { Promise<void> joinPromises(Array<Promise<void>>&& promises) {
return Promise<void>(false, kj::heap<_::ArrayJoinPromiseNode<void>>( return _::PromiseNode::to<Promise<void>>(kj::heap<_::ArrayJoinPromiseNode<void>>(
KJ_MAP(p, promises) { return kj::mv(p.node); }, KJ_MAP(p, promises) { return _::PromiseNode::from(kj::mv(p)); },
heapArray<_::ExceptionOr<_::Void>>(promises.size()))); heapArray<_::ExceptionOr<_::Void>>(promises.size())));
} }
......
...@@ -305,24 +305,7 @@ private: ...@@ -305,24 +305,7 @@ private:
Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {} Promise(bool, Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
// Second parameter prevent ambiguity with immediate-value constructor. // Second parameter prevent ambiguity with immediate-value constructor.
template <typename> friend class _::PromiseNode;
friend class Promise;
friend class EventLoop;
template <typename U, typename Adapter, typename... Params>
friend _::ReducePromises<U> newAdaptedPromise(Params&&... adapterConstructorParams);
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename>
friend class _::ForkHub;
friend class TaskSet;
friend Promise<void> _::yield();
friend Promise<void> _::yieldHarder();
friend class _::NeverDone;
template <typename U>
friend Promise<Array<U>> joinPromises(Array<Promise<U>>&& promises);
friend Promise<void> joinPromises(Array<Promise<void>>&& promises);
friend class _::XThreadEvent;
friend class Executor;
}; };
template <typename T> template <typename T>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment