Commit 281b9faf authored by Kenton Varda's avatar Kenton Varda

Optimize promise tail calls by making ChainPromiseNode automatically detect and…

Optimize promise tail calls by making ChainPromiseNode automatically detect and remove redundant nodes.
parent fe3f7212
......@@ -136,6 +136,12 @@ public:
virtual void onReady(Event& event) noexcept = 0;
// Arms the given event when ready.
virtual void setSelfPointer(Own<PromiseNode>* selfPtr) noexcept;
// Tells the node that `selfPtr` is the pointer that owns this node, and will continue to own
// this node until it is destroyed or setSelfPointer() is called again. ChainPromiseNode uses
// this to shorten redundant chains. The default implementation does nothing; only
// ChainPromiseNode should implement this.
virtual void get(ExceptionOrValue& output) noexcept = 0;
// Get the result. `output` points to an ExceptionOr<T> into which the result will be written.
// Can only be called once, and only after the node is ready. Must be called directly from the
......@@ -396,12 +402,18 @@ inline ExceptionOrValue& ForkBranchBase::getHubResultRef() {
// -------------------------------------------------------------------
class ChainPromiseNode final: public PromiseNode, private Event {
class ChainPromiseNode final: public PromiseNode, public Event {
// Promise node which reduces Promise<Promise<T>> to Promise<T>.
//
// `Event` is only a public base class because otherwise we can't cast Own<ChainPromiseNode> to
// Own<Event>. Ugh, templates and private...
public:
explicit ChainPromiseNode(Own<PromiseNode> inner);
~ChainPromiseNode() noexcept(false);
void onReady(Event& event) noexcept override;
void setSelfPointer(Own<PromiseNode>* selfPtr) noexcept override;
void get(ExceptionOrValue& output) noexcept override;
PromiseNode* getInnerForTrace() override;
......@@ -418,6 +430,7 @@ private:
// In STEP2, a PromiseNode for a T.
Event* onReadyEvent = nullptr;
Own<PromiseNode>* selfPtr = nullptr;
Maybe<Own<Event>> fire() override;
};
......
......@@ -173,6 +173,120 @@ TEST(Async, Chain) {
EXPECT_EQ(444, promise3.wait(waitScope));
}
TEST(Async, DeepChain) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = NEVER_DONE;
// Create a ridiculous chain of promises.
for (uint i = 0; i < 1000; i++) {
promise = evalLater(mvCapture(promise, [&,i](Promise<void> promise) {
return kj::mv(promise);
}));
}
loop.run();
auto trace = promise.trace();
uint lines = 0;
for (char c: trace) {
lines += c == '\n';
}
// Chain nodes should have been collapsed such that instead of a chain of 1000 nodes, we have
// 2-ish nodes. We'll give a little room for implementation freedom.
EXPECT_LT(lines, 5);
}
TEST(Async, DeepChain2) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = nullptr;
promise = evalLater([&]() {
auto trace = promise.trace();
uint lines = 0;
for (char c: trace) {
lines += c == '\n';
}
// Chain nodes should have been collapsed such that instead of a chain of 1000 nodes, we have
// 2-ish nodes. We'll give a little room for implementation freedom.
EXPECT_LT(lines, 5);
});
// Create a ridiculous chain of promises.
for (uint i = 0; i < 1000; i++) {
promise = evalLater(mvCapture(promise, [&](Promise<void> promise) {
return kj::mv(promise);
}));
}
promise.wait(waitScope);
}
Promise<void> makeChain(uint i) {
if (i > 0) {
return evalLater([i]() -> Promise<void> {
return makeChain(i - 1);
});
} else {
return NEVER_DONE;
}
}
TEST(Async, DeepChain3) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = makeChain(1000);
loop.run();
auto trace = promise.trace();
uint lines = 0;
for (char c: trace) {
lines += c == '\n';
}
// Chain nodes should have been collapsed such that instead of a chain of 1000 nodes, we have
// 2-ish nodes. We'll give a little room for implementation freedom.
EXPECT_LT(lines, 5);
}
Promise<void> makeChain2(uint i, Promise<void> promise) {
if (i > 0) {
return evalLater(mvCapture(promise, [i](Promise<void>&& promise) -> Promise<void> {
return makeChain2(i - 1, kj::mv(promise));
}));
} else {
return kj::mv(promise);
}
}
TEST(Async, DeepChain4) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = nullptr;
promise = evalLater([&]() {
auto trace = promise.trace();
uint lines = 0;
for (char c: trace) {
lines += c == '\n';
}
// Chain nodes should have been collapsed such that instead of a chain of 1000 nodes, we have
// 2-ish nodes. We'll give a little room for implementation freedom.
EXPECT_LT(lines, 5);
});
promise = makeChain2(1000, kj::mv(promise));
promise.wait(waitScope);
}
TEST(Async, SeparateFulfiller) {
EventLoop loop;
WaitScope waitScope(loop);
......
......@@ -108,6 +108,7 @@ public:
public:
Task(TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam)
: taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node);
node->onReady(*this);
}
......@@ -305,6 +306,7 @@ void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope
KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks.");
BoolEvent doneEvent;
node->setSelfPointer(&node);
node->onReady(doneEvent);
loop.running = true;
......@@ -421,7 +423,7 @@ _::PromiseNode* Event::getInnerForTrace() {
static kj::String demangleTypeName(const char* name) {
int status;
char* buf = abi::__cxa_demangle(name, nullptr, nullptr, &status);
kj::String result = kj::heapString(buf);
kj::String result = kj::heapString(buf == nullptr ? name : buf);
free(buf);
return kj::mv(result);
}
......@@ -478,6 +480,8 @@ kj::String PromiseBase::trace() {
return traceImpl(nullptr, node);
}
void PromiseNode::setSelfPointer(Own<PromiseNode>* selfPtr) noexcept {}
PromiseNode* PromiseNode::getInnerForTrace() { return nullptr; }
void PromiseNode::OnReadyEvent::init(Event& newEvent) {
......@@ -517,8 +521,10 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept {
// -------------------------------------------------------------------
AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own<PromiseNode>&& dependency)
: dependency(kj::mv(dependency)) {}
AttachmentPromiseNodeBase::AttachmentPromiseNodeBase(Own<PromiseNode>&& dependencyParam)
: dependency(kj::mv(dependencyParam)) {
dependency->setSelfPointer(&dependency);
}
void AttachmentPromiseNodeBase::onReady(Event& event) noexcept {
dependency->onReady(event);
......@@ -538,8 +544,10 @@ void AttachmentPromiseNodeBase::dropDependency() {
// -------------------------------------------------------------------
TransformPromiseNodeBase::TransformPromiseNodeBase(Own<PromiseNode>&& dependency)
: dependency(kj::mv(dependency)) {}
TransformPromiseNodeBase::TransformPromiseNodeBase(Own<PromiseNode>&& dependencyParam)
: dependency(kj::mv(dependencyParam)) {
dependency->setSelfPointer(&dependency);
}
void TransformPromiseNodeBase::onReady(Event& event) noexcept {
dependency->onReady(event);
......@@ -617,6 +625,7 @@ PromiseNode* ForkBranchBase::getInnerForTrace() {
ForkHubBase::ForkHubBase(Own<PromiseNode>&& innerParam, ExceptionOrValue& resultRef)
: inner(kj::mv(innerParam)), resultRef(resultRef) {
inner->setSelfPointer(&inner);
inner->onReady(*this);
}
......@@ -650,6 +659,7 @@ _::PromiseNode* ForkHubBase::getInnerForTrace() {
ChainPromiseNode::ChainPromiseNode(Own<PromiseNode> innerParam)
: state(STEP1), inner(kj::mv(innerParam)) {
inner->setSelfPointer(&inner);
inner->onReady(*this);
}
......@@ -668,6 +678,15 @@ void ChainPromiseNode::onReady(Event& event) noexcept {
KJ_UNREACHABLE;
}
void ChainPromiseNode::setSelfPointer(Own<PromiseNode>* selfPtr) noexcept {
if (state == STEP2) {
*selfPtr = kj::mv(inner); // deletes this!
selfPtr->get()->setSelfPointer(selfPtr);
} else {
this->selfPtr = selfPtr;
}
}
void ChainPromiseNode::get(ExceptionOrValue& output) noexcept {
KJ_REQUIRE(state == STEP2);
return inner->get(output);
......@@ -708,11 +727,25 @@ Maybe<Own<Event>> ChainPromiseNode::fire() {
}
state = STEP2;
if (onReadyEvent != nullptr) {
inner->onReady(*onReadyEvent);
}
if (selfPtr != nullptr) {
// Hey, we can shorten the chain here.
auto chain = selfPtr->downcast<ChainPromiseNode>();
*selfPtr = kj::mv(inner);
selfPtr->get()->setSelfPointer(selfPtr);
if (onReadyEvent != nullptr) {
selfPtr->get()->onReady(*onReadyEvent);
}
return nullptr;
// Return our self-pointer so that the caller takes care of deleting it.
return Own<Event>(kj::mv(chain));
} else {
inner->setSelfPointer(&inner);
if (onReadyEvent != nullptr) {
inner->onReady(*onReadyEvent);
}
return nullptr;
}
}
// -------------------------------------------------------------------
......@@ -741,6 +774,7 @@ PromiseNode* ExclusiveJoinPromiseNode::getInnerForTrace() {
ExclusiveJoinPromiseNode::Branch::Branch(
ExclusiveJoinPromiseNode& joinNode, Own<PromiseNode> dependencyParam)
: joinNode(joinNode), dependency(kj::mv(dependencyParam)) {
dependency->setSelfPointer(&dependency);
dependency->onReady(*this);
}
......@@ -776,6 +810,7 @@ PromiseNode* ExclusiveJoinPromiseNode::Branch::getInnerForTrace() {
EagerPromiseNodeBase::EagerPromiseNodeBase(
Own<PromiseNode>&& dependencyParam, ExceptionOrValue& resultRef)
: dependency(kj::mv(dependencyParam)), resultRef(resultRef) {
dependency->setSelfPointer(&dependency);
dependency->onReady(*this);
}
......
......@@ -47,6 +47,26 @@ TEST(Memory, CanConvert) {
static_assert(!canConvert<Own<Super>, Own<Sub>>(), "failure");
}
struct Nested {
Nested(bool& destroyed): destroyed(destroyed) {}
~Nested() { destroyed = true; }
bool& destroyed;
Own<Nested> nested;
};
TEST(Memory, AssignNested) {
bool destroyed1 = false, destroyed2 = false;
auto nested = heap<Nested>(destroyed1);
nested->nested = heap<Nested>(destroyed2);
EXPECT_FALSE(destroyed1 || destroyed2);
nested = kj::mv(nested->nested);
EXPECT_TRUE(destroyed1);
EXPECT_FALSE(destroyed2);
nested = nullptr;
EXPECT_TRUE(destroyed1 && destroyed2);
}
// TODO(test): More tests.
} // namespace
......
......@@ -128,10 +128,18 @@ public:
~Own() noexcept(false) { dispose(); }
inline Own& operator=(Own&& other) {
dispose();
// Move-assingnment operator.
// Careful, this might own `other`. Therefore we have to transfer the pointers first, then
// dispose.
const Disposer* disposerCopy = disposer;
T* ptrCopy = ptr;
disposer = other.disposer;
ptr = other.ptr;
other.ptr = nullptr;
if (ptrCopy != nullptr) {
disposerCopy->dispose(const_cast<RemoveConst<T>*>(ptrCopy));
}
return *this;
}
......@@ -140,6 +148,21 @@ public:
return *this;
}
template <typename U>
Own<U> downcast() {
// Downcast the pointer to Own<U>, destroying the original pointer. If this pointer does not
// actually point at an instance of U, the results are undefined (throws an exception in debug
// mode if RTTI is enabled, otherwise you're on your own).
Own<U> result;
if (ptr != nullptr) {
result.ptr = &kj::downcast<U>(*ptr);
result.disposer = disposer;
ptr = nullptr;
}
return result;
}
inline T* operator->() { return ptr; }
inline const T* operator->() const { return ptr; }
inline T& operator*() { return *ptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment