Commit 97aad141 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #579 from capnproto/taskset-improvements

Add TaskSet.onEmpty() to wait until all tasks have completed
parents 7c8870eb b62c247b
...@@ -38,6 +38,7 @@ class EventLoop; ...@@ -38,6 +38,7 @@ class EventLoop;
template <typename T> template <typename T>
class Promise; class Promise;
class WaitScope; class WaitScope;
class TaskSet;
template <typename T> template <typename T>
Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises); Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises);
...@@ -172,8 +173,6 @@ class ChainPromiseNode; ...@@ -172,8 +173,6 @@ class ChainPromiseNode;
template <typename T> template <typename T>
class ForkHub; class ForkHub;
class TaskSetImpl;
class Event; class Event;
class PromiseBase { class PromiseBase {
...@@ -191,7 +190,7 @@ private: ...@@ -191,7 +190,7 @@ private:
friend class ChainPromiseNode; friend class ChainPromiseNode;
template <typename> template <typename>
friend class kj::Promise; friend class kj::Promise;
friend class TaskSetImpl; friend class kj::TaskSet;
template <typename U> template <typename U>
friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises); friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises);
friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises); friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises);
......
...@@ -620,6 +620,30 @@ TEST(Async, TaskSet) { ...@@ -620,6 +620,30 @@ TEST(Async, TaskSet) {
EXPECT_EQ(1u, errorHandler.exceptionCount); EXPECT_EQ(1u, errorHandler.exceptionCount);
} }
TEST(Async, TaskSetOnEmpty) {
EventLoop loop;
WaitScope waitScope(loop);
ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler);
KJ_EXPECT(tasks.isEmpty());
auto paf = newPromiseAndFulfiller<void>();
tasks.add(kj::mv(paf.promise));
tasks.add(evalLater([]() {}));
KJ_EXPECT(!tasks.isEmpty());
auto promise = tasks.onEmpty();
KJ_EXPECT(!promise.poll(waitScope));
KJ_EXPECT(!tasks.isEmpty());
paf.fulfiller->fulfill();
KJ_ASSERT(promise.poll(waitScope));
KJ_EXPECT(tasks.isEmpty());
promise.wait(waitScope);
}
class DestructorDetector { class DestructorDetector {
public: public:
DestructorDetector(bool& setTrue): setTrue(setTrue) {} DestructorDetector(bool& setTrue): setTrue(setTrue) {}
......
...@@ -23,8 +23,6 @@ ...@@ -23,8 +23,6 @@
#include "debug.h" #include "debug.h"
#include "vector.h" #include "vector.h"
#include "threadlocal.h" #include "threadlocal.h"
#include <exception>
#include <map>
#if KJ_USE_FUTEX #if KJ_USE_FUTEX
#include <unistd.h> #include <unistd.h>
...@@ -86,32 +84,23 @@ public: ...@@ -86,32 +84,23 @@ public:
} // namespace } // namespace
namespace _ { // private TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler)
class TaskSetImpl {
public:
inline TaskSetImpl(TaskSet::ErrorHandler& errorHandler)
: errorHandler(errorHandler) {} : errorHandler(errorHandler) {}
~TaskSetImpl() noexcept(false) { TaskSet::~TaskSet() noexcept(false) {}
// std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
if (!tasks.empty()) {
Vector<Own<Task>> deleteMe(tasks.size());
for (auto& entry: tasks) {
deleteMe.add(kj::mv(entry.second));
}
}
}
class Task final: public Event { class TaskSet::Task final: public _::Event {
public: public:
Task(TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam) Task(TaskSet& taskSet, Own<_::PromiseNode>&& nodeParam)
: taskSet(taskSet), node(kj::mv(nodeParam)) { : taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node); node->setSelfPointer(&node);
node->onReady(this); node->onReady(this);
} }
protected: Maybe<Own<Task>> next;
Maybe<Own<Task>>* prev = nullptr;
protected:
Maybe<Own<Event>> fire() override { Maybe<Own<Event>> fire() override {
// Get the result. // Get the result.
_::ExceptionOr<_::Void> result; _::ExceptionOr<_::Void> result;
...@@ -129,11 +118,23 @@ public: ...@@ -129,11 +118,23 @@ public:
taskSet.errorHandler.taskFailed(kj::mv(*e)); taskSet.errorHandler.taskFailed(kj::mv(*e));
} }
// Remove from the task map. // Remove from the task list.
auto iter = taskSet.tasks.find(this); KJ_IF_MAYBE(n, next) {
KJ_ASSERT(iter != taskSet.tasks.end()); n->get()->prev = prev;
Own<Event> self = kj::mv(iter->second); }
taskSet.tasks.erase(iter); Own<Event> self = kj::mv(KJ_ASSERT_NONNULL(*prev));
KJ_ASSERT(self.get() == this);
*prev = kj::mv(next);
next = nullptr;
prev = nullptr;
KJ_IF_MAYBE(f, taskSet.emptyFulfiller) {
if (taskSet.tasks == nullptr) {
f->get()->fulfill();
taskSet.emptyFulfiller = nullptr;
}
}
return mv(self); return mv(self);
} }
...@@ -141,31 +142,50 @@ public: ...@@ -141,31 +142,50 @@ public:
return node; return node;
} }
private: private:
TaskSetImpl& taskSet; TaskSet& taskSet;
kj::Own<_::PromiseNode> node; Own<_::PromiseNode> node;
}; };
void add(Promise<void>&& promise) { void TaskSet::add(Promise<void>&& promise) {
auto task = heap<Task>(*this, kj::mv(promise.node)); auto task = heap<Task>(*this, kj::mv(promise.node));
Task* ptr = task; KJ_IF_MAYBE(head, tasks) {
tasks.insert(std::make_pair(ptr, kj::mv(task))); head->get()->prev = &task->next;
task->next = kj::mv(tasks);
} }
task->prev = &tasks;
tasks = kj::mv(task);
}
kj::String trace() { kj::String TaskSet::trace() {
kj::Vector<kj::String> traces; kj::Vector<kj::String> traces;
for (auto& entry: tasks) {
traces.add(entry.second->trace()); Maybe<Own<Task>>* ptr = &tasks;
for (;;) {
KJ_IF_MAYBE(task, *ptr) {
traces.add(task->get()->trace());
ptr = &task->get()->next;
} else {
break;
} }
return kj::strArray(traces, "\n============================================\n");
} }
private: return kj::strArray(traces, "\n============================================\n");
TaskSet::ErrorHandler& errorHandler; }
// TODO(perf): Use a linked list instead. Promise<void> TaskSet::onEmpty() {
std::map<Task*, Own<Task>> tasks; KJ_REQUIRE(emptyFulfiller == nullptr, "onEmpty() can only be called once at a time");
};
if (tasks == nullptr) {
return READY_NOW;
} else {
auto paf = newPromiseAndFulfiller<void>();
emptyFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
}
namespace _ { // private
class LoggingErrorHandler: public TaskSet::ErrorHandler { class LoggingErrorHandler: public TaskSet::ErrorHandler {
public: public:
...@@ -210,11 +230,11 @@ void EventPort::wake() const { ...@@ -210,11 +230,11 @@ void EventPort::wake() const {
EventLoop::EventLoop() EventLoop::EventLoop()
: port(_::NullEventPort::instance), : port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::EventLoop(EventPort& port) EventLoop::EventLoop(EventPort& port)
: port(port), : port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::~EventLoop() noexcept(false) { EventLoop::~EventLoop() noexcept(false) {
// Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop // Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop
...@@ -524,19 +544,6 @@ kj::String Event::trace() { ...@@ -524,19 +544,6 @@ kj::String Event::trace() {
// ======================================================================================= // =======================================================================================
TaskSet::TaskSet(ErrorHandler& errorHandler)
: impl(heap<_::TaskSetImpl>(errorHandler)) {}
TaskSet::~TaskSet() noexcept(false) {}
void TaskSet::add(Promise<void>&& promise) {
impl->add(kj::mv(promise));
}
kj::String TaskSet::trace() {
return impl->trace();
}
namespace _ { // private namespace _ { // private
kj::String PromiseBase::trace() { kj::String PromiseBase::trace() {
......
...@@ -317,7 +317,7 @@ private: ...@@ -317,7 +317,7 @@ private:
friend PromiseFulfillerPair<U> newPromiseAndFulfiller(); friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename> template <typename>
friend class _::ForkHub; friend class _::ForkHub;
friend class _::TaskSetImpl; friend class TaskSet;
friend Promise<void> _::yield(); friend Promise<void> _::yield();
friend class _::NeverDone; friend class _::NeverDone;
template <typename U> template <typename U>
...@@ -522,8 +522,8 @@ public: ...@@ -522,8 +522,8 @@ public:
}; };
TaskSet(ErrorHandler& errorHandler); TaskSet(ErrorHandler& errorHandler);
// `loop` will be used to wait on promises. `errorHandler` will be executed any time a task // `errorHandler` will be executed any time a task throws an exception, and will execute within
// throws an exception, and will execute within the given EventLoop. // the given EventLoop.
~TaskSet() noexcept(false); ~TaskSet() noexcept(false);
...@@ -532,8 +532,19 @@ public: ...@@ -532,8 +532,19 @@ public:
kj::String trace(); kj::String trace();
// Return debug info about all promises currently in the TaskSet. // Return debug info about all promises currently in the TaskSet.
bool isEmpty() { return tasks == nullptr; }
// Check if any tasks are running.
Promise<void> onEmpty();
// Returns a promise that fulfills the next time the TaskSet is empty. Only one such promise can
// exist at a time.
private: private:
Own<_::TaskSetImpl> impl; class Task;
TaskSet::ErrorHandler& errorHandler;
Maybe<Own<Task>> tasks;
Maybe<Own<PromiseFulfiller<void>>> emptyFulfiller;
}; };
// ======================================================================================= // =======================================================================================
...@@ -653,7 +664,7 @@ private: ...@@ -653,7 +664,7 @@ private:
_::Event** tail = &head; _::Event** tail = &head;
_::Event** depthFirstInsertPoint = &head; _::Event** depthFirstInsertPoint = &head;
Own<_::TaskSetImpl> daemons; Own<TaskSet> daemons;
bool turn(); bool turn();
void setRunnable(bool runnable); void setRunnable(bool runnable);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment