Commit 36df9e4c authored by Kenton Varda's avatar Kenton Varda

Extend EventPort interface to support cross-thread wakeups.

This is not currently used, but eventually we'll use this to support efficient inter-thread messaging.

That said, I've decided that inter-thread messaging is complicated and I don't want to work on it right now.
parent d8958a54
...@@ -649,8 +649,8 @@ public: ...@@ -649,8 +649,8 @@ public:
bool runnable = false; bool runnable = false;
int callCount = 0; int callCount = 0;
void wait() override { KJ_FAIL_ASSERT("Nothing to wait for."); } bool wait() override { KJ_FAIL_ASSERT("Nothing to wait for."); }
void poll() override {} bool poll() override { return false; }
void setRunnable(bool runnable) override { void setRunnable(bool runnable) override {
this->runnable = runnable; this->runnable = runnable;
++callCount; ++callCount;
......
...@@ -451,4 +451,32 @@ TEST_F(AsyncUnixTest, SteadyTimers) { ...@@ -451,4 +451,32 @@ TEST_F(AsyncUnixTest, SteadyTimers) {
} }
} }
TEST_F(AsyncUnixTest, Wake) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
EXPECT_FALSE(port.poll());
port.wake();
EXPECT_TRUE(port.poll());
EXPECT_FALSE(port.poll());
port.wake();
EXPECT_TRUE(port.wait());
{
auto promise = port.atSteadyTime(port.steadyTime());
EXPECT_FALSE(port.wait());
}
bool woken = false;
Thread thread([&]() {
delay();
woken = true;
port.wake();
});
EXPECT_TRUE(port.wait());
}
} // namespace kj } // namespace kj
...@@ -331,7 +331,7 @@ Promise<void> UnixEventPort::FdObserver::whenBecomesWritable() { ...@@ -331,7 +331,7 @@ Promise<void> UnixEventPort::FdObserver::whenBecomesWritable() {
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void UnixEventPort::wait() { bool UnixEventPort::wait() {
// epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that. // epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that // Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below. // will break the math below.
...@@ -352,11 +352,18 @@ void UnixEventPort::wait() { ...@@ -352,11 +352,18 @@ void UnixEventPort::wait() {
} }
} }
doEpollWait(epollTimeout); return doEpollWait(epollTimeout);
} }
void UnixEventPort::poll() { bool UnixEventPort::poll() {
doEpollWait(0); return doEpollWait(0);
}
void UnixEventPort::wake() const {
uint64_t one = 1;
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = write(eventFd, &one, sizeof(one)));
KJ_ASSERT(n < 0 || n == sizeof(one));
} }
static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) { static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) {
...@@ -460,7 +467,7 @@ static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) { ...@@ -460,7 +467,7 @@ static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) {
return result; return result;
} }
void UnixEventPort::doEpollWait(int timeout) { bool UnixEventPort::doEpollWait(int timeout) {
sigset_t newMask; sigset_t newMask;
sigemptyset(&newMask); sigemptyset(&newMask);
...@@ -483,6 +490,8 @@ void UnixEventPort::doEpollWait(int timeout) { ...@@ -483,6 +490,8 @@ void UnixEventPort::doEpollWait(int timeout) {
int n; int n;
KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), timeout)); KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), timeout));
bool woken = false;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
if (events[i].data.u64 == 0) { if (events[i].data.u64 == 0) {
for (;;) { for (;;) {
...@@ -496,11 +505,14 @@ void UnixEventPort::doEpollWait(int timeout) { ...@@ -496,11 +505,14 @@ void UnixEventPort::doEpollWait(int timeout) {
gotSignal(toRegularSiginfo(siginfo)); gotSignal(toRegularSiginfo(siginfo));
} }
} else if (events[i].data.u64 == 1) { } else if (events[i].data.u64 == 1) {
// Someone wanted to wake up this thread. Read and discard the event. // Someone called wake() from another thread. Consume the event.
uint64_t value; uint64_t value;
ssize_t n; ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value))); KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value)));
KJ_ASSERT(n < 0 || n == sizeof(value)); KJ_ASSERT(n < 0 || n == sizeof(value));
// We were woken. Need to return true.
woken = true;
} else { } else {
FdObserver* observer = reinterpret_cast<FdObserver*>(events[i].data.ptr); FdObserver* observer = reinterpret_cast<FdObserver*>(events[i].data.ptr);
observer->fire(events[i].events); observer->fire(events[i].events);
...@@ -508,6 +520,8 @@ void UnixEventPort::doEpollWait(int timeout) { ...@@ -508,6 +520,8 @@ void UnixEventPort::doEpollWait(int timeout) {
} }
processTimers(); processTimers();
return woken;
} }
#else // KJ_USE_EPOLL #else // KJ_USE_EPOLL
...@@ -521,6 +535,10 @@ void UnixEventPort::doEpollWait(int timeout) { ...@@ -521,6 +535,10 @@ void UnixEventPort::doEpollWait(int timeout) {
UnixEventPort::UnixEventPort() UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()), : timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()) { frozenSteadyTime(currentSteadyTime()) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
*reinterpret_cast<pthread_t*>(&threadId) = pthread_self();
pthread_once(&registerReservedSignalOnce, &registerReservedSignal); pthread_once(&registerReservedSignalOnce, &registerReservedSignal);
} }
...@@ -659,7 +677,7 @@ private: ...@@ -659,7 +677,7 @@ private:
int pollError = 0; int pollError = 0;
}; };
void UnixEventPort::wait() { bool UnixEventPort::wait() {
sigset_t newMask; sigset_t newMask;
sigemptyset(&newMask); sigemptyset(&newMask);
sigaddset(&newMask, reservedSignal); sigaddset(&newMask, reservedSignal);
...@@ -681,11 +699,12 @@ void UnixEventPort::wait() { ...@@ -681,11 +699,12 @@ void UnixEventPort::wait() {
// We received a signal and longjmp'd back out of the signal handler. // We received a signal and longjmp'd back out of the signal handler.
threadCapture = nullptr; threadCapture = nullptr;
if (capture.siginfo.si_signo != reservedSignal) { if (capture.siginfo.si_signo == reservedSignal) {
return true;
} else {
gotSignal(capture.siginfo); gotSignal(capture.siginfo);
return false;
} }
return;
} }
// Enable signals, run the poll, then mask them again. // Enable signals, run the poll, then mask them again.
...@@ -720,9 +739,13 @@ void UnixEventPort::wait() { ...@@ -720,9 +739,13 @@ void UnixEventPort::wait() {
// Queue events. // Queue events.
pollContext.processResults(); pollContext.processResults();
processTimers(); processTimers();
return false;
} }
void UnixEventPort::poll() { bool UnixEventPort::poll() {
bool woken = false;
sigset_t pending; sigset_t pending;
sigset_t waitMask; sigset_t waitMask;
sigemptyset(&pending); sigemptyset(&pending);
...@@ -732,6 +755,12 @@ void UnixEventPort::poll() { ...@@ -732,6 +755,12 @@ void UnixEventPort::poll() {
KJ_SYSCALL(sigpending(&pending)); KJ_SYSCALL(sigpending(&pending));
uint signalCount = 0; uint signalCount = 0;
if (sigismember(&pending, reservedSignal)) {
++signalCount;
sigdelset(&pending, reservedSignal);
sigdelset(&waitMask, reservedSignal);
}
{ {
auto ptr = signalHead; auto ptr = signalHead;
while (ptr != nullptr) { while (ptr != nullptr) {
...@@ -752,7 +781,11 @@ void UnixEventPort::poll() { ...@@ -752,7 +781,11 @@ void UnixEventPort::poll() {
if (sigsetjmp(capture.jumpTo, true)) { if (sigsetjmp(capture.jumpTo, true)) {
// We received a signal and longjmp'd back out of the signal handler. // We received a signal and longjmp'd back out of the signal handler.
sigdelset(&waitMask, capture.siginfo.si_signo); sigdelset(&waitMask, capture.siginfo.si_signo);
gotSignal(capture.siginfo); if (capture.siginfo.si_signo == reservedSignal) {
woken = true;
} else {
gotSignal(capture.siginfo);
}
} else { } else {
sigsuspend(&waitMask); sigsuspend(&waitMask);
KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should " KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should "
...@@ -767,6 +800,15 @@ void UnixEventPort::poll() { ...@@ -767,6 +800,15 @@ void UnixEventPort::poll() {
pollContext.processResults(); pollContext.processResults();
} }
processTimers(); processTimers();
return woken;
}
void UnixEventPort::wake() const {
int error = pthread_kill(*reinterpret_cast<const pthread_t*>(&threadId), reservedSignal);
if (error != 0) {
KJ_FAIL_SYSCALL("pthread_kill", error);
}
} }
#endif // KJ_USE_EPOLL, else #endif // KJ_USE_EPOLL, else
......
...@@ -90,8 +90,9 @@ public: ...@@ -90,8 +90,9 @@ public:
Promise<void> atSteadyTime(TimePoint time); Promise<void> atSteadyTime(TimePoint time);
// implements EventPort ------------------------------------------------------ // implements EventPort ------------------------------------------------------
void wait() override; bool wait() override;
void poll() override; bool poll() override;
void wake() const override;
private: private:
struct TimerSet; // Defined in source file to avoid STL include. struct TimerSet; // Defined in source file to avoid STL include.
...@@ -119,13 +120,15 @@ private: ...@@ -119,13 +120,15 @@ private:
// Signal mask as currently set on the signalFd. Tracked so we can detect whether or not it // Signal mask as currently set on the signalFd. Tracked so we can detect whether or not it
// needs updating. // needs updating.
void doEpollWait(int timeout); bool doEpollWait(int timeout);
#else #else
class PollContext; class PollContext;
FdObserver* observersHead = nullptr; FdObserver* observersHead = nullptr;
FdObserver** observersTail = &observersHead; FdObserver** observersTail = &observersHead;
unsigned long long threadId; // actually pthread_t
#endif #endif
}; };
......
...@@ -180,11 +180,17 @@ LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler(); ...@@ -180,11 +180,17 @@ LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler();
class NullEventPort: public EventPort { class NullEventPort: public EventPort {
public: public:
void wait() override { bool wait() override {
KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever."); KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever.");
} }
void poll() override {} bool poll() override { return false; }
void wake() const override {
// TODO(soon): Implement using condvar.
kj::throwRecoverableException(KJ_EXCEPTION(UNIMPLEMENTED,
"Cross-thread events are not yet implemented for EventLoops with no EventPort."));
}
static NullEventPort instance; static NullEventPort instance;
}; };
...@@ -197,6 +203,11 @@ NullEventPort NullEventPort::instance = NullEventPort(); ...@@ -197,6 +203,11 @@ NullEventPort NullEventPort::instance = NullEventPort();
void EventPort::setRunnable(bool runnable) {} void EventPort::setRunnable(bool runnable) {}
void EventPort::wake() const {
kj::throwRecoverableException(KJ_EXCEPTION(UNIMPLEMENTED,
"cross-thread wake() not implemented by this EventPort implementation"));
}
EventLoop::EventLoop() EventLoop::EventLoop()
: port(_::NullEventPort::instance), : port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {} daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
......
...@@ -513,7 +513,7 @@ class EventPort { ...@@ -513,7 +513,7 @@ class EventPort {
// framework, allowing the two to coexist in a single thread. // framework, allowing the two to coexist in a single thread.
public: public:
virtual void wait() = 0; virtual bool wait() = 0;
// Wait for an external event to arrive, sleeping if necessary. Once at least one event has // Wait for an external event to arrive, sleeping if necessary. Once at least one event has
// arrived, queue it to the event loop (e.g. by fulfilling a promise) and return. // arrived, queue it to the event loop (e.g. by fulfilling a promise) and return.
// //
...@@ -523,21 +523,37 @@ public: ...@@ -523,21 +523,37 @@ public:
// It is safe to return even if nothing has actually been queued, so long as calling `wait()` in // It is safe to return even if nothing has actually been queued, so long as calling `wait()` in
// a loop will eventually sleep. (That is to say, false positives are fine.) // a loop will eventually sleep. (That is to say, false positives are fine.)
// //
// If the implementation knows that no event will ever arrive, it should throw an exception // Returns true if wake() has been called from another thread. (Precisely, returns true if
// rather than deadlock. // no previous call to wait `wait()` nor `poll()` has returned true since `wake()` was last
// called.)
virtual void poll() = 0; virtual bool poll() = 0;
// Check if any external events have arrived, but do not sleep. If any events have arrived, // Check if any external events have arrived, but do not sleep. If any events have arrived,
// add them to the event queue (e.g. by fulfilling promises) before returning. // add them to the event queue (e.g. by fulfilling promises) before returning.
// //
// This may be called during `Promise::wait()` when the EventLoop has been executing for a while // This may be called during `Promise::wait()` when the EventLoop has been executing for a while
// without a break but is still non-empty. // without a break but is still non-empty.
//
// Returns true if wake() has been called from another thread. (Precisely, returns true if
// no previous call to wait `wait()` nor `poll()` has returned true since `wake()` was last
// called.)
virtual void setRunnable(bool runnable); virtual void setRunnable(bool runnable);
// Called to notify the `EventPort` when the `EventLoop` has work to do; specifically when it // Called to notify the `EventPort` when the `EventLoop` has work to do; specifically when it
// transitions from empty -> runnable or runnable -> empty. This is typically useful when // transitions from empty -> runnable or runnable -> empty. This is typically useful when
// integrating with an external event loop; if the loop is currently runnable then you should // integrating with an external event loop; if the loop is currently runnable then you should
// arrange to call run() on it soon. The default implementation does nothing. // arrange to call run() on it soon. The default implementation does nothing.
virtual void wake() const;
// Wake up the EventPort's thread from another thread.
//
// Unlike all other methods on this interface, `wake()` may be called from another thread, hence
// it is `const`.
//
// Technically speaking, `wake()` causes the target thread to cease sleeping and not to sleep
// again until `wait()` or `poll()` has returned true at least once.
//
// The default implementation throws an UNIMPLEMENTED exception.
}; };
class EventLoop { class EventLoop {
......
...@@ -254,8 +254,8 @@ public: ...@@ -254,8 +254,8 @@ public:
return *this; return *this;
} }
inline operator int() { return fd; } inline operator int() const { return fd; }
inline int get() { return fd; } inline int get() const { return fd; }
inline bool operator==(decltype(nullptr)) { return fd < 0; } inline bool operator==(decltype(nullptr)) { return fd < 0; }
inline bool operator!=(decltype(nullptr)) { return fd >= 0; } inline bool operator!=(decltype(nullptr)) { return fd >= 0; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment