Commit a2c1269c authored by Kenton Varda's avatar Kenton Varda

Fixes #39 - race conditions in promise framework.

parent 2e3b3413
...@@ -138,9 +138,11 @@ Promise<void> EventLoop::yieldIfSameThread() const { ...@@ -138,9 +138,11 @@ Promise<void> EventLoop::yieldIfSameThread() const {
EventLoop::Event::~Event() noexcept(false) { EventLoop::Event::~Event() noexcept(false) {
if (this != &loop.queue) { if (this != &loop.queue) {
KJ_ASSERT(next == nullptr || std::uncaught_exception(), KJ_ASSERT(next == this,
"Event destroyed while armed. You must call disarm() in the subclass's destructor " "Event destroyed while armed. You must call disarm() in the subclass's destructor "
"in order to ensure that fire() is not running when the event is destroyed."); "in order to ensure that fire() is not running when the event is destroyed.") {
break;
}
} }
} }
...@@ -180,9 +182,9 @@ void EventLoop::Event::arm(bool preemptIfSameThread) { ...@@ -180,9 +182,9 @@ void EventLoop::Event::arm(bool preemptIfSameThread) {
} }
void EventLoop::Event::disarm() { void EventLoop::Event::disarm() {
if (next != nullptr) {
loop.queue.mutex.lock(_::Mutex::EXCLUSIVE); loop.queue.mutex.lock(_::Mutex::EXCLUSIVE);
if (next != nullptr && next != this) {
if (loop.insertPoint == this) { if (loop.insertPoint == this) {
loop.insertPoint = next; loop.insertPoint = next;
} }
...@@ -191,9 +193,11 @@ void EventLoop::Event::disarm() { ...@@ -191,9 +193,11 @@ void EventLoop::Event::disarm() {
prev->next = next; prev->next = next;
next = nullptr; next = nullptr;
prev = nullptr; prev = nullptr;
}
next = this;
loop.queue.mutex.unlock(_::Mutex::EXCLUSIVE); loop.queue.mutex.unlock(_::Mutex::EXCLUSIVE);
}
// Ensure that if fire() is currently running, it completes before disarm() returns. // Ensure that if fire() is currently running, it completes before disarm() returns.
mutex.lock(_::Mutex::EXCLUSIVE); mutex.lock(_::Mutex::EXCLUSIVE);
...@@ -482,11 +486,6 @@ ForkHubBase::ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner, ...@@ -482,11 +486,6 @@ ForkHubBase::ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner,
ExceptionOrValue& resultRef) ExceptionOrValue& resultRef)
: EventLoop::Event(loop), inner(kj::mv(inner)), resultRef(resultRef) { : EventLoop::Event(loop), inner(kj::mv(inner)), resultRef(resultRef) {
KJ_DREQUIRE(this->inner->isSafeEventLoop(loop)); KJ_DREQUIRE(this->inner->isSafeEventLoop(loop));
arm();
}
ForkHubBase::~ForkHubBase() noexcept(false) {
disarm();
} }
void ForkHubBase::fire() { void ForkHubBase::fire() {
...@@ -597,11 +596,6 @@ CrossThreadPromiseNodeBase::CrossThreadPromiseNodeBase( ...@@ -597,11 +596,6 @@ CrossThreadPromiseNodeBase::CrossThreadPromiseNodeBase(
const EventLoop& loop, Own<PromiseNode>&& dependency, ExceptionOrValue& resultRef) const EventLoop& loop, Own<PromiseNode>&& dependency, ExceptionOrValue& resultRef)
: Event(loop), dependency(kj::mv(dependency)), resultRef(resultRef) { : Event(loop), dependency(kj::mv(dependency)), resultRef(resultRef) {
KJ_DREQUIRE(this->dependency->isSafeEventLoop(loop)); KJ_DREQUIRE(this->dependency->isSafeEventLoop(loop));
arm();
}
CrossThreadPromiseNodeBase::~CrossThreadPromiseNodeBase() noexcept(false) {
disarm();
} }
bool CrossThreadPromiseNodeBase::onReady(EventLoop::Event& event) noexcept { bool CrossThreadPromiseNodeBase::onReady(EventLoop::Event& event) noexcept {
......
...@@ -296,6 +296,9 @@ public: ...@@ -296,6 +296,9 @@ public:
class Event { class Event {
// An event waiting to be executed. Not for direct use by applications -- promises use this // An event waiting to be executed. Not for direct use by applications -- promises use this
// internally. // internally.
//
// WARNING: This class is difficult to use correctly. It's easy to have subtle race
// conditions.
public: public:
Event(const EventLoop& loop): loop(loop), next(nullptr), prev(nullptr) {} Event(const EventLoop& loop): loop(loop), next(nullptr), prev(nullptr) {}
...@@ -318,10 +321,10 @@ public: ...@@ -318,10 +321,10 @@ public:
// order in which they were queued. // order in which they were queued.
void disarm(); void disarm();
// Cancel this event if it is armed. If it is already running, block until it finishes // Cancel this event if it is armed, and ignore any further arm()s. If it is already running,
// before returning. MUST be called in the subclass's destructor if it is possible that // block until it finishes before returning. MUST be called in the subclass's destructor if
// the event is still armed, because once Event's destructor is reached, fire() is a // it is possible that the event is still armed, because once Event's destructor is reached,
// pure-virtual function. // fire() is a pure-virtual function.
inline const EventLoop& getEventLoop() { return loop; } inline const EventLoop& getEventLoop() { return loop; }
// Get the event loop on which this event will run. // Get the event loop on which this event will run.
...@@ -333,7 +336,7 @@ public: ...@@ -333,7 +336,7 @@ public:
private: private:
friend class EventLoop; friend class EventLoop;
const EventLoop& loop; const EventLoop& loop;
Event* next; Event* next; // if == this, disarm() has been called.
Event* prev; Event* prev;
mutable kj::_::Mutex mutex; mutable kj::_::Mutex mutex;
...@@ -1066,10 +1069,9 @@ public: ...@@ -1066,10 +1069,9 @@ public:
// ------------------------------------------------------------------- // -------------------------------------------------------------------
class ForkHubBase: public Refcounted, private EventLoop::Event { class ForkHubBase: public Refcounted, protected EventLoop::Event {
public: public:
ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner, ExceptionOrValue& resultRef); ForkHubBase(const EventLoop& loop, Own<PromiseNode>&& inner, ExceptionOrValue& resultRef);
~ForkHubBase() noexcept(false);
inline const ExceptionOrValue& getResultRef() const { return resultRef; } inline const ExceptionOrValue& getResultRef() const { return resultRef; }
...@@ -1100,7 +1102,16 @@ class ForkHub final: public ForkHubBase { ...@@ -1100,7 +1102,16 @@ class ForkHub final: public ForkHubBase {
public: public:
ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner) ForkHub(const EventLoop& loop, Own<PromiseNode>&& inner)
: ForkHubBase(loop, kj::mv(inner), result) {} : ForkHubBase(loop, kj::mv(inner), result) {
// Note that it's unsafe to call this from the superclass's constructor because `result` won't
// be initialized yet and the event could fire in another thread immediately.
arm();
}
~ForkHub() noexcept(false) {
// Note that it's unsafe to call this from the superclass's destructor because we must disarm
// before `result` is destroyed.
disarm();
}
Promise<_::Forked<_::UnfixVoid<T>>> addBranch() const { Promise<_::Forked<_::UnfixVoid<T>>> addBranch() const {
return Promise<_::Forked<_::UnfixVoid<T>>>(false, kj::heap<ForkBranch<T>>(addRef(*this))); return Promise<_::Forked<_::UnfixVoid<T>>>(false, kj::heap<ForkBranch<T>>(addRef(*this)));
...@@ -1165,14 +1176,13 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) { ...@@ -1165,14 +1176,13 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
class CrossThreadPromiseNodeBase: public PromiseNode, private EventLoop::Event { class CrossThreadPromiseNodeBase: public PromiseNode, protected EventLoop::Event {
// A PromiseNode that safely imports a promised value from one EventLoop to another (which // A PromiseNode that safely imports a promised value from one EventLoop to another (which
// implies crossing threads). // implies crossing threads).
public: public:
CrossThreadPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependency, CrossThreadPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependency,
ExceptionOrValue& resultRef); ExceptionOrValue& resultRef);
~CrossThreadPromiseNodeBase() noexcept(false);
bool onReady(EventLoop::Event& event) noexcept override; bool onReady(EventLoop::Event& event) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override; Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
...@@ -1192,7 +1202,16 @@ template <typename T> ...@@ -1192,7 +1202,16 @@ template <typename T>
class CrossThreadPromiseNode final: public CrossThreadPromiseNodeBase { class CrossThreadPromiseNode final: public CrossThreadPromiseNodeBase {
public: public:
CrossThreadPromiseNode(const EventLoop& loop, Own<PromiseNode>&& dependency) CrossThreadPromiseNode(const EventLoop& loop, Own<PromiseNode>&& dependency)
: CrossThreadPromiseNodeBase(loop, kj::mv(dependency), result) {} : CrossThreadPromiseNodeBase(loop, kj::mv(dependency), result) {
// Note that it's unsafe to call this from the superclass's constructor because `result` won't
// be initialized yet and the event could fire in another thread immediately.
arm();
}
~CrossThreadPromiseNode() noexcept(false) {
// Note that it's unsafe to call this from the superclass's destructor because we must disarm
// before `result` is destroyed.
disarm();
}
void get(ExceptionOrValue& output) noexcept override { void get(ExceptionOrValue& output) noexcept override {
output.as<T>() = kj::mv(result); output.as<T>() = kj::mv(result);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment