Commit 97aad141 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #579 from capnproto/taskset-improvements

Add TaskSet.onEmpty() to wait until all tasks have completed
parents 7c8870eb b62c247b
......@@ -38,6 +38,7 @@ class EventLoop;
template <typename T>
class Promise;
class WaitScope;
class TaskSet;
template <typename T>
Promise<Array<T>> joinPromises(Array<Promise<T>>&& promises);
......@@ -172,8 +173,6 @@ class ChainPromiseNode;
template <typename T>
class ForkHub;
class TaskSetImpl;
class Event;
class PromiseBase {
......@@ -191,7 +190,7 @@ private:
friend class ChainPromiseNode;
template <typename>
friend class kj::Promise;
friend class TaskSetImpl;
friend class kj::TaskSet;
template <typename U>
friend Promise<Array<U>> kj::joinPromises(Array<Promise<U>>&& promises);
friend Promise<void> kj::joinPromises(Array<Promise<void>>&& promises);
......
......@@ -620,6 +620,30 @@ TEST(Async, TaskSet) {
EXPECT_EQ(1u, errorHandler.exceptionCount);
}
TEST(Async, TaskSetOnEmpty) {
EventLoop loop;
WaitScope waitScope(loop);
ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler);
KJ_EXPECT(tasks.isEmpty());
auto paf = newPromiseAndFulfiller<void>();
tasks.add(kj::mv(paf.promise));
tasks.add(evalLater([]() {}));
KJ_EXPECT(!tasks.isEmpty());
auto promise = tasks.onEmpty();
KJ_EXPECT(!promise.poll(waitScope));
KJ_EXPECT(!tasks.isEmpty());
paf.fulfiller->fulfill();
KJ_ASSERT(promise.poll(waitScope));
KJ_EXPECT(tasks.isEmpty());
promise.wait(waitScope);
}
class DestructorDetector {
public:
DestructorDetector(bool& setTrue): setTrue(setTrue) {}
......
......@@ -23,8 +23,6 @@
#include "debug.h"
#include "vector.h"
#include "threadlocal.h"
#include <exception>
#include <map>
#if KJ_USE_FUTEX
#include <unistd.h>
......@@ -86,86 +84,108 @@ public:
} // namespace
namespace _ { // private
TaskSet::TaskSet(TaskSet::ErrorHandler& errorHandler)
: errorHandler(errorHandler) {}
TaskSet::~TaskSet() noexcept(false) {}
class TaskSetImpl {
class TaskSet::Task final: public _::Event {
public:
inline TaskSetImpl(TaskSet::ErrorHandler& errorHandler)
: errorHandler(errorHandler) {}
~TaskSetImpl() noexcept(false) {
// std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
if (!tasks.empty()) {
Vector<Own<Task>> deleteMe(tasks.size());
for (auto& entry: tasks) {
deleteMe.add(kj::mv(entry.second));
}
Task(TaskSet& taskSet, Own<_::PromiseNode>&& nodeParam)
: taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node);
node->onReady(this);
}
Maybe<Own<Task>> next;
Maybe<Own<Task>>* prev = nullptr;
protected:
Maybe<Own<Event>> fire() override {
// Get the result.
_::ExceptionOr<_::Void> result;
node->get(result);
// Delete the node, catching any exceptions.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
node = nullptr;
})) {
result.addException(kj::mv(*exception));
}
}
class Task final: public Event {
public:
Task(TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam)
: taskSet(taskSet), node(kj::mv(nodeParam)) {
node->setSelfPointer(&node);
node->onReady(this);
// Call the error handler if there was an exception.
KJ_IF_MAYBE(e, result.exception) {
taskSet.errorHandler.taskFailed(kj::mv(*e));
}
protected:
Maybe<Own<Event>> fire() override {
// Get the result.
_::ExceptionOr<_::Void> result;
node->get(result);
// Delete the node, catching any exceptions.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
node = nullptr;
})) {
result.addException(kj::mv(*exception));
}
// Remove from the task list.
KJ_IF_MAYBE(n, next) {
n->get()->prev = prev;
}
Own<Event> self = kj::mv(KJ_ASSERT_NONNULL(*prev));
KJ_ASSERT(self.get() == this);
*prev = kj::mv(next);
next = nullptr;
prev = nullptr;
// Call the error handler if there was an exception.
KJ_IF_MAYBE(e, result.exception) {
taskSet.errorHandler.taskFailed(kj::mv(*e));
KJ_IF_MAYBE(f, taskSet.emptyFulfiller) {
if (taskSet.tasks == nullptr) {
f->get()->fulfill();
taskSet.emptyFulfiller = nullptr;
}
// Remove from the task map.
auto iter = taskSet.tasks.find(this);
KJ_ASSERT(iter != taskSet.tasks.end());
Own<Event> self = kj::mv(iter->second);
taskSet.tasks.erase(iter);
return mv(self);
}
_::PromiseNode* getInnerForTrace() override {
return node;
}
return mv(self);
}
_::PromiseNode* getInnerForTrace() override {
return node;
}
private:
TaskSetImpl& taskSet;
kj::Own<_::PromiseNode> node;
};
private:
TaskSet& taskSet;
Own<_::PromiseNode> node;
};
void add(Promise<void>&& promise) {
auto task = heap<Task>(*this, kj::mv(promise.node));
Task* ptr = task;
tasks.insert(std::make_pair(ptr, kj::mv(task)));
void TaskSet::add(Promise<void>&& promise) {
auto task = heap<Task>(*this, kj::mv(promise.node));
KJ_IF_MAYBE(head, tasks) {
head->get()->prev = &task->next;
task->next = kj::mv(tasks);
}
task->prev = &tasks;
tasks = kj::mv(task);
}
kj::String trace() {
kj::Vector<kj::String> traces;
for (auto& entry: tasks) {
traces.add(entry.second->trace());
kj::String TaskSet::trace() {
kj::Vector<kj::String> traces;
Maybe<Own<Task>>* ptr = &tasks;
for (;;) {
KJ_IF_MAYBE(task, *ptr) {
traces.add(task->get()->trace());
ptr = &task->get()->next;
} else {
break;
}
return kj::strArray(traces, "\n============================================\n");
}
private:
TaskSet::ErrorHandler& errorHandler;
return kj::strArray(traces, "\n============================================\n");
}
// TODO(perf): Use a linked list instead.
std::map<Task*, Own<Task>> tasks;
};
Promise<void> TaskSet::onEmpty() {
KJ_REQUIRE(emptyFulfiller == nullptr, "onEmpty() can only be called once at a time");
if (tasks == nullptr) {
return READY_NOW;
} else {
auto paf = newPromiseAndFulfiller<void>();
emptyFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
}
namespace _ { // private
class LoggingErrorHandler: public TaskSet::ErrorHandler {
public:
......@@ -210,11 +230,11 @@ void EventPort::wake() const {
EventLoop::EventLoop()
: port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::EventLoop(EventPort& port)
: port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
daemons(kj::heap<TaskSet>(_::LoggingErrorHandler::instance)) {}
EventLoop::~EventLoop() noexcept(false) {
// Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop
......@@ -524,19 +544,6 @@ kj::String Event::trace() {
// =======================================================================================
TaskSet::TaskSet(ErrorHandler& errorHandler)
: impl(heap<_::TaskSetImpl>(errorHandler)) {}
TaskSet::~TaskSet() noexcept(false) {}
void TaskSet::add(Promise<void>&& promise) {
impl->add(kj::mv(promise));
}
kj::String TaskSet::trace() {
return impl->trace();
}
namespace _ { // private
kj::String PromiseBase::trace() {
......
......@@ -317,7 +317,7 @@ private:
friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename>
friend class _::ForkHub;
friend class _::TaskSetImpl;
friend class TaskSet;
friend Promise<void> _::yield();
friend class _::NeverDone;
template <typename U>
......@@ -522,8 +522,8 @@ public:
};
TaskSet(ErrorHandler& errorHandler);
// `loop` will be used to wait on promises. `errorHandler` will be executed any time a task
// throws an exception, and will execute within the given EventLoop.
// `errorHandler` will be executed any time a task throws an exception, and will execute within
// the given EventLoop.
~TaskSet() noexcept(false);
......@@ -532,8 +532,19 @@ public:
kj::String trace();
// Return debug info about all promises currently in the TaskSet.
bool isEmpty() { return tasks == nullptr; }
// Check if any tasks are running.
Promise<void> onEmpty();
// Returns a promise that fulfills the next time the TaskSet is empty. Only one such promise can
// exist at a time.
private:
Own<_::TaskSetImpl> impl;
class Task;
TaskSet::ErrorHandler& errorHandler;
Maybe<Own<Task>> tasks;
Maybe<Own<PromiseFulfiller<void>>> emptyFulfiller;
};
// =======================================================================================
......@@ -653,7 +664,7 @@ private:
_::Event** tail = &head;
_::Event** depthFirstInsertPoint = &head;
Own<_::TaskSetImpl> daemons;
Own<TaskSet> daemons;
bool turn();
void setRunnable(bool runnable);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment