Commit 97aad141 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #579 from capnproto/taskset-improvements

Add TaskSet.onEmpty() to wait until all tasks have completed
parents 7c8870eb b62c247b
...@@ -38,6 +38,7 @@ class EventLoop; ...@@ -38,6 +38,7 @@ class EventLoop;
template <typename T> template <typename T>
class Promise; class Promise;
class WaitScope; class WaitScope;
class TaskSet;
template <typename T> template <typename T>
Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises); Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises);
...@@ -172,8 +173,6 @@ class ChainPromiseNode; ...@@ -172,8 +173,6 @@ class ChainPromiseNode;
template <typename T> template <typename T>
class ForkHub; class ForkHub;
class TaskSetImpl;
class Event; class Event;
class PromiseBase { class PromiseBase {
...@@ -191,7 +190,7 @@ private: ...@@ -191,7 +190,7 @@ private:
friend class ChainPromiseNode; friend class ChainPromiseNode;
template <typename> template <typename>
friend class kj::Promise; friend class kj::Promise;
friend class TaskSetImpl; friend class kj::TaskSet;
template <typename U> template <typename U>
friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises); friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises);
friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises); friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises);
......
...@@ -620,6 +620,30 @@ TEST(Async, TaskSet) { ...@@ -620,6 +620,30 @@ TEST(Async, TaskSet) {
EXPECT_EQ(1u, errorHandler.exceptionCount); EXPECT_EQ(1u, errorHandler.exceptionCount);
} }
TEST(Async, TaskSetOnEmpty) {
EventLoop loop;
WaitScope waitScope(loop);
ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler);
KJ_EXPECT(tasks.isEmpty());
auto paf = newPromiseAndFulfiller<void>();
tasks.add(kj::mv(paf.promise));
tasks.add(evalLater([]() {}));
KJ_EXPECT(!tasks.isEmpty());
auto promise = tasks.onEmpty();
KJ_EXPECT(!promise.poll(waitScope));
KJ_EXPECT(!tasks.isEmpty());
paf.fulfiller->fulfill();
KJ_ASSERT(promise.poll(waitScope));
KJ_EXPECT(tasks.isEmpty());
promise.wait(waitScope);
}
class DestructorDetector { class DestructorDetector {
public: public:
DestructorDetector(bool& setTrue): setTrue(setTrue) {} DestructorDetector(bool& setTrue): setTrue(setTrue) {}
......
...@@ -23,8 +23,6 @@ ...@@ -23,8 +23,6 @@
#include "debug.h" #include "debug.h"
#include "vector.h" #include "vector.h"
#include "threadlocal.h" #include "threadlocal.h"
#include <exception>
#include <map>
#if KJ_USE_FUTEX #if KJ_USE_FUTEX
#include <unistd.h> #include <unistd.h>
...@@ -86,86 +84,108 @@ public: ...@@ -86,86 +84,108 @@ public:
} // namespace } // namespace
namespace _ { // private TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler)
: errorHandler(errorHandler) {}
TaskSet::~TaskSet() noexcept(false) {}
class TaskSetImpl { class TaskSet::Task final: public _::Event {
public: public:
inline TaskSetImpl(TaskSet::ErrorHandler& errorHandler) Task(TaskSet& taskSet, Own<_::PromiseNode>&& nodeParam)
: errorHandler(errorHandler) {} : taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node);
~TaskSetImpl() noexcept(false) { node->onReady(this);
// std::map doesn't like it when elements' destructors throw, so carefully disassemble it. }
if (!tasks.empty()) {
Vector<Own<Task>> deleteMe(tasks.size()); Maybe<Own<Task>> next;
for (auto& entry: tasks) { Maybe<Own<Task>>* prev = nullptr;
deleteMe.add(kj::mv(entry.second));
} protected:
Maybe<Own<Event>> fire() override {
// Get the result.
_::ExceptionOr<_::Void> result;
node->get(result);
// Delete the node, catching any exceptions.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
node = nullptr;
})) {
result.addException(kj::mv(*exception));
} }
}
class Task final: public Event { // Call the error handler if there was an exception.
public: KJ_IF_MAYBE(e, result.exception) {
Task(TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam) taskSet.errorHandler.taskFailed(kj::mv(*e));
: taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node);
node->onReady(this);
} }
protected: // Remove from the task list.
Maybe<Own<Event>> fire() override { KJ_IF_MAYBE(n, next) {
// Get the result. n->get()->prev = prev;
_::ExceptionOr<_::Void> result; }
node->get(result); Own<Event> self = kj::mv(KJ_ASSERT_NONNULL(*prev));
KJ_ASSERT(self.get() == this);
// Delete the node, catching any exceptions. *prev = kj::mv(next);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() { next = nullptr;
node = nullptr; prev = nullptr;
})) {
result.addException(kj::mv(*exception));
}
// Call the error handler if there was an exception. KJ_IF_MAYBE(f, taskSet.emptyFulfiller) {
KJ_IF_MAYBE(e, result.exception) { if (taskSet.tasks == nullptr) {
taskSet.errorHandler.taskFailed(kj::mv(*e)); f->get()->fulfill();
taskSet.emptyFulfiller = nullptr;
} }
// Remove from the task map.
auto iter = taskSet.tasks.find(this);
KJ_ASSERT(iter != taskSet.tasks.end());
Own<Event> self = kj::mv(iter->second);
taskSet.tasks.erase(iter);
return mv(self);
} }
_::PromiseNode* getInnerForTrace() override { return mv(self);
return node; }
}
_::PromiseNode* getInnerForTrace() override {
return node;
}
private: private:
TaskSetImpl& taskSet; TaskSet& taskSet;
kj::Own<_::PromiseNode> node; Own<_::PromiseNode> node;
}; };
void add(Promise<void>&& promise) { void TaskSet::add(Promise<void>&& promise) {
auto task = heap<Task>(*this, kj::mv(promise.node)); auto task = heap<Task>(*this, kj::mv(promise.node));
Task* ptr = task; KJ_IF_MAYBE(head, tasks) {
tasks.insert(std::make_pair(ptr, kj::mv(task))); head->get()->prev = &task->next;
task->next = kj::mv(tasks);
} }
task->prev = &tasks;
tasks = kj::mv(task);
}
kj::String trace() { kj::String TaskSet::trace() {
kj::Vector<kj::String> traces; kj::Vector<kj::String> traces;
for (auto& entry: tasks) {
traces.add(entry.second->trace()); Maybe<Own<Task>>* ptr = &tasks;
for (;;) {
KJ_IF_MAYBE(task, *ptr) {
traces.add(task->get()->trace());
ptr = &task->get()->next;
} else {
break;
} }
return kj::strArray(traces, "\n============================================\n");
} }
private: return kj::strArray(traces, "\n============================================\n");
TaskSet::ErrorHandler& errorHandler; }
// TODO(perf): Use a linked list instead. Promise<void> TaskSet::onEmpty() {
std::map<Task*, Own<Task>> tasks; KJ_REQUIRE(emptyFulfiller == nullptr, "onEmpty() can only be called once at a time");
};
if (tasks == nullptr) {
return READY_NOW;
} else {
auto paf = newPromiseAndFulfiller<void>();
emptyFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
}
namespace _ { // private
class LoggingErrorHandler: public TaskSet::ErrorHandler { class LoggingErrorHandler: public TaskSet::ErrorHandler {
public: public:
...@@ -210,11 +230,11 @@ void EventPort::wake() const { ...@@ -210,11 +230,11 @@ void EventPort::wake() const {
EventLoop::EventLoop() EventLoop::EventLoop()
: port(_::NullEventPort::instance), : port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::EventLoop(EventPort& port) EventLoop::EventLoop(EventPort& port)
: port(port), : port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::~EventLoop() noexcept(false) { EventLoop::~EventLoop() noexcept(false) {
// Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop // Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop
...@@ -524,19 +544,6 @@ kj::String Event::trace() { ...@@ -524,19 +544,6 @@ kj::String Event::trace() {
// ======================================================================================= // =======================================================================================
TaskSet::TaskSet(ErrorHandler& errorHandler)
: impl(heap<_::TaskSetImpl>(errorHandler)) {}
TaskSet::~TaskSet() noexcept(false) {}
void TaskSet::add(Promise<void>&& promise) {
impl->add(kj::mv(promise));
}
kj::String TaskSet::trace() {
return impl->trace();
}
namespace _ { // private namespace _ { // private
kj::String PromiseBase::trace() { kj::String PromiseBase::trace() {
......
...@@ -317,7 +317,7 @@ private: ...@@ -317,7 +317,7 @@ private:
friend PromiseFulfillerPair<U> newPromiseAndFulfiller(); friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename> template <typename>
friend class _::ForkHub; friend class _::ForkHub;
friend class _::TaskSetImpl; friend class TaskSet;
friend Promise<void> _::yield(); friend Promise<void> _::yield();
friend class _::NeverDone; friend class _::NeverDone;
template <typename U> template <typename U>
...@@ -522,8 +522,8 @@ public: ...@@ -522,8 +522,8 @@ public:
}; };
TaskSet(ErrorHandler& errorHandler); TaskSet(ErrorHandler& errorHandler);
// `loop` will be used to wait on promises. `errorHandler` will be executed any time a task // `errorHandler` will be executed any time a task throws an exception, and will execute within
// throws an exception, and will execute within the given EventLoop. // the given EventLoop.
~TaskSet() noexcept(false); ~TaskSet() noexcept(false);
...@@ -532,8 +532,19 @@ public: ...@@ -532,8 +532,19 @@ public:
kj::String trace(); kj::String trace();
// Return debug info about all promises currently in the TaskSet. // Return debug info about all promises currently in the TaskSet.
bool isEmpty() { return tasks == nullptr; }
// Check if any tasks are running.
Promise<void> onEmpty();
// Returns a promise that fulfills the next time the TaskSet is empty. Only one such promise can
// exist at a time.
private: private:
Own<_::TaskSetImpl> impl; class Task;
TaskSet::ErrorHandler& errorHandler;
Maybe<Own<Task>> tasks;
Maybe<Own<PromiseFulfiller<void>>> emptyFulfiller;
}; };
// ======================================================================================= // =======================================================================================
...@@ -653,7 +664,7 @@ private: ...@@ -653,7 +664,7 @@ private:
_::Event** tail = &head; _::Event** tail = &head;
_::Event** depthFirstInsertPoint = &head; _::Event** depthFirstInsertPoint = &head;
Own<_::TaskSetImpl> daemons; Own<TaskSet> daemons;
bool turn(); bool turn();
void setRunnable(bool runnable); void setRunnable(bool runnable);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment