Commit 36df9e4c authored by Kenton Varda's avatar Kenton Varda

Extend EventPort interface to support cross-thread wakeups.

This is not currently used, but eventually we'll use this to support efficient inter-thread messaging.

That said, I've decided that inter-thread messaging is complicated and I don't want to work on it right now.
parent d8958a54
......@@ -649,8 +649,8 @@ public:
bool runnable = false;
int callCount = 0;
void wait() override { KJ_FAIL_ASSERT("Nothing to wait for."); }
void poll() override {}
bool wait() override { KJ_FAIL_ASSERT("Nothing to wait for."); }
bool poll() override { return false; }
void setRunnable(bool runnable) override {
this->runnable = runnable;
++callCount;
......
......@@ -451,4 +451,32 @@ TEST_F(AsyncUnixTest, SteadyTimers) {
}
}
TEST_F(AsyncUnixTest, Wake) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
EXPECT_FALSE(port.poll());
port.wake();
EXPECT_TRUE(port.poll());
EXPECT_FALSE(port.poll());
port.wake();
EXPECT_TRUE(port.wait());
{
auto promise = port.atSteadyTime(port.steadyTime());
EXPECT_FALSE(port.wait());
}
bool woken = false;
Thread thread([&]() {
delay();
woken = true;
port.wake();
});
EXPECT_TRUE(port.wait());
}
} // namespace kj
......@@ -331,7 +331,7 @@ Promise<void> UnixEventPort::FdObserver::whenBecomesWritable() {
return kj::mv(paf.promise);
}
void UnixEventPort::wait() {
bool UnixEventPort::wait() {
// epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below.
......@@ -352,11 +352,18 @@ void UnixEventPort::wait() {
}
}
doEpollWait(epollTimeout);
return doEpollWait(epollTimeout);
}
void UnixEventPort::poll() {
doEpollWait(0);
bool UnixEventPort::poll() {
return doEpollWait(0);
}
void UnixEventPort::wake() const {
uint64_t one = 1;
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = write(eventFd, &one, sizeof(one)));
KJ_ASSERT(n < 0 || n == sizeof(one));
}
static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) {
......@@ -460,7 +467,7 @@ static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) {
return result;
}
void UnixEventPort::doEpollWait(int timeout) {
bool UnixEventPort::doEpollWait(int timeout) {
sigset_t newMask;
sigemptyset(&newMask);
......@@ -483,6 +490,8 @@ void UnixEventPort::doEpollWait(int timeout) {
int n;
KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), timeout));
bool woken = false;
for (int i = 0; i < n; i++) {
if (events[i].data.u64 == 0) {
for (;;) {
......@@ -496,11 +505,14 @@ void UnixEventPort::doEpollWait(int timeout) {
gotSignal(toRegularSiginfo(siginfo));
}
} else if (events[i].data.u64 == 1) {
// Someone wanted to wake up this thread. Read and discard the event.
// Someone called wake() from another thread. Consume the event.
uint64_t value;
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value)));
KJ_ASSERT(n < 0 || n == sizeof(value));
// We were woken. Need to return true.
woken = true;
} else {
FdObserver* observer = reinterpret_cast<FdObserver*>(events[i].data.ptr);
observer->fire(events[i].events);
......@@ -508,6 +520,8 @@ void UnixEventPort::doEpollWait(int timeout) {
}
processTimers();
return woken;
}
#else // KJ_USE_EPOLL
......@@ -521,6 +535,10 @@ void UnixEventPort::doEpollWait(int timeout) {
UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
*reinterpret_cast<pthread_t*>(&threadId) = pthread_self();
pthread_once(&registerReservedSignalOnce, &registerReservedSignal);
}
......@@ -659,7 +677,7 @@ private:
int pollError = 0;
};
void UnixEventPort::wait() {
bool UnixEventPort::wait() {
sigset_t newMask;
sigemptyset(&newMask);
sigaddset(&newMask, reservedSignal);
......@@ -681,11 +699,12 @@ void UnixEventPort::wait() {
// We received a signal and longjmp'd back out of the signal handler.
threadCapture = nullptr;
if (capture.siginfo.si_signo != reservedSignal) {
if (capture.siginfo.si_signo == reservedSignal) {
return true;
} else {
gotSignal(capture.siginfo);
return false;
}
return;
}
// Enable signals, run the poll, then mask them again.
......@@ -720,9 +739,13 @@ void UnixEventPort::wait() {
// Queue events.
pollContext.processResults();
processTimers();
return false;
}
void UnixEventPort::poll() {
bool UnixEventPort::poll() {
bool woken = false;
sigset_t pending;
sigset_t waitMask;
sigemptyset(&pending);
......@@ -732,6 +755,12 @@ void UnixEventPort::poll() {
KJ_SYSCALL(sigpending(&pending));
uint signalCount = 0;
if (sigismember(&pending, reservedSignal)) {
++signalCount;
sigdelset(&pending, reservedSignal);
sigdelset(&waitMask, reservedSignal);
}
{
auto ptr = signalHead;
while (ptr != nullptr) {
......@@ -752,7 +781,11 @@ void UnixEventPort::poll() {
if (sigsetjmp(capture.jumpTo, true)) {
// We received a signal and longjmp'd back out of the signal handler.
sigdelset(&waitMask, capture.siginfo.si_signo);
if (capture.siginfo.si_signo == reservedSignal) {
woken = true;
} else {
gotSignal(capture.siginfo);
}
} else {
sigsuspend(&waitMask);
KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should "
......@@ -767,6 +800,15 @@ void UnixEventPort::poll() {
pollContext.processResults();
}
processTimers();
return woken;
}
void UnixEventPort::wake() const {
int error = pthread_kill(*reinterpret_cast<const pthread_t*>(&threadId), reservedSignal);
if (error != 0) {
KJ_FAIL_SYSCALL("pthread_kill", error);
}
}
#endif // KJ_USE_EPOLL, else
......
......@@ -90,8 +90,9 @@ public:
Promise<void> atSteadyTime(TimePoint time);
// implements EventPort ------------------------------------------------------
void wait() override;
void poll() override;
bool wait() override;
bool poll() override;
void wake() const override;
private:
struct TimerSet; // Defined in source file to avoid STL include.
......@@ -119,13 +120,15 @@ private:
// Signal mask as currently set on the signalFd. Tracked so we can detect whether or not it
// needs updating.
void doEpollWait(int timeout);
bool doEpollWait(int timeout);
#else
class PollContext;
FdObserver* observersHead = nullptr;
FdObserver** observersTail = &observersHead;
unsigned long long threadId; // actually pthread_t
#endif
};
......
......@@ -180,11 +180,17 @@ LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler();
class NullEventPort: public EventPort {
public:
void wait() override {
bool wait() override {
KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever.");
}
void poll() override {}
bool poll() override { return false; }
void wake() const override {
// TODO(soon): Implement using condvar.
kj::throwRecoverableException(KJ_EXCEPTION(UNIMPLEMENTED,
"Cross-thread events are not yet implemented for EventLoops with no EventPort."));
}
static NullEventPort instance;
};
......@@ -197,6 +203,11 @@ NullEventPort NullEventPort::instance = NullEventPort();
void EventPort::setRunnable(bool runnable) {}
void EventPort::wake() const {
kj::throwRecoverableException(KJ_EXCEPTION(UNIMPLEMENTED,
"cross-thread wake() not implemented by this EventPort implementation"));
}
EventLoop::EventLoop()
: port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
......
......@@ -513,7 +513,7 @@ class EventPort {
// framework, allowing the two to coexist in a single thread.
public:
virtual void wait() = 0;
virtual bool wait() = 0;
// Wait for an external event to arrive, sleeping if necessary. Once at least one event has
// arrived, queue it to the event loop (e.g. by fulfilling a promise) and return.
//
......@@ -523,21 +523,37 @@ public:
// It is safe to return even if nothing has actually been queued, so long as calling `wait()` in
// a loop will eventually sleep. (That is to say, false positives are fine.)
//
// If the implementation knows that no event will ever arrive, it should throw an exception
// rather than deadlock.
// Returns true if wake() has been called from another thread. (Precisely, returns true if
// no previous call to wait `wait()` nor `poll()` has returned true since `wake()` was last
// called.)
virtual void poll() = 0;
virtual bool poll() = 0;
// Check if any external events have arrived, but do not sleep. If any events have arrived,
// add them to the event queue (e.g. by fulfilling promises) before returning.
//
// This may be called during `Promise::wait()` when the EventLoop has been executing for a while
// without a break but is still non-empty.
//
// Returns true if wake() has been called from another thread. (Precisely, returns true if
// no previous call to wait `wait()` nor `poll()` has returned true since `wake()` was last
// called.)
virtual void setRunnable(bool runnable);
// Called to notify the `EventPort` when the `EventLoop` has work to do; specifically when it
// transitions from empty -> runnable or runnable -> empty. This is typically useful when
// integrating with an external event loop; if the loop is currently runnable then you should
// arrange to call run() on it soon. The default implementation does nothing.
virtual void wake() const;
// Wake up the EventPort's thread from another thread.
//
// Unlike all other methods on this interface, `wake()` may be called from another thread, hence
// it is `const`.
//
// Technically speaking, `wake()` causes the target thread to cease sleeping and not to sleep
// again until `wait()` or `poll()` has returned true at least once.
//
// The default implementation throws an UNIMPLEMENTED exception.
};
class EventLoop {
......
......@@ -254,8 +254,8 @@ public:
return *this;
}
inline operator int() { return fd; }
inline int get() { return fd; }
inline operator int() const { return fd; }
inline int get() const { return fd; }
inline bool operator==(decltype(nullptr)) { return fd < 0; }
inline bool operator!=(decltype(nullptr)) { return fd >= 0; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment