async-unix.c++ 23.6 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
3
//
Kenton Varda's avatar
Kenton Varda committed
4 5 6 7 8 9
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
10
//
Kenton Varda's avatar
Kenton Varda committed
11 12
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
13
//
Kenton Varda's avatar
Kenton Varda committed
14 15 16 17 18 19 20
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21 22 23

#include "async-unix.h"
#include "debug.h"
24
#include "threadlocal.h"
25 26
#include <setjmp.h>
#include <errno.h>
27 28
#include <inttypes.h>
#include <limits>
29
#include <set>
30
#include <chrono>
31 32 33 34 35 36 37 38
#include <pthread.h>

#if KJ_USE_EPOLL
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/signalfd.h>
#include <sys/eventfd.h>
#else
39
#include <poll.h>
40
#endif
41 42 43 44

namespace kj {

// =======================================================================================
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
// Timer code common to multiple implementations

// Holds every pending timer of one UnixEventPort, ordered by deadline so that the earliest timer
// is always at `timers.begin()`.
struct UnixEventPort::TimerSet {
  struct TimerBefore {
    // Orders adapters by their `time` member.  This must be a const member function: associative
    // containers are required to be able to invoke their comparator through a const reference,
    // and standard libraries that enforce this reject a non-const operator() at compile time.
    bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) const;
  };
  using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
  Timers timers;
};

// Promise adapter behind atSteadyTime().  An instance keeps itself registered in the owning port's
// TimerSet for as long as it is still waiting to be fulfilled.
class UnixEventPort::TimerPromiseAdapter {
public:
  TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, UnixEventPort& port, TimePoint time)
      : time(time), fulfiller(fulfiller), port(port) {
    // `time` is initialized before this insert, which matters because the set orders by it.
    pos = port.timers->timers.insert(this);
  }

  ~TimerPromiseAdapter() {
    // `pos` equals end() once fulfill() has already removed us from the set.
    if (pos != port.timers->timers.end()) {
      port.timers->timers.erase(pos);
    }
  }

  void fulfill() {
    // Fulfills the promise, then unregisters this timer and marks it as unregistered so that the
    // destructor does not erase it a second time.
    fulfiller.fulfill();
    port.timers->timers.erase(pos);
    pos = port.timers->timers.end();
  }

  const TimePoint time;
  PromiseFulfiller<void>& fulfiller;
  UnixEventPort& port;
  TimerSet::Timers::const_iterator pos;
};

bool UnixEventPort::TimerSet::TimerBefore::operator()(
    TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) const {
  return lhs->time < rhs->time;
}

// Returns a promise backed by a TimerPromiseAdapter constructed for this port and `time`.  The
// adapter registers itself in `timers`; processTimers() fulfills it once its `time` is no longer
// later than the steady time sampled there.
Promise<void> UnixEventPort::atSteadyTime(TimePoint time) {
  return newAdaptedPromise<void, TimerPromiseAdapter>(*this, time);
}

// Reads std::chrono::steady_clock and expresses the reading as a TimePoint: the number of
// nanoseconds since the clock's epoch, added to origin<TimePoint>().
TimePoint UnixEventPort::currentSteadyTime() {
  auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
  auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(sinceEpoch).count();
  return origin<TimePoint>() + nanos * NANOSECONDS;
}

void UnixEventPort::processTimers() {
  frozenSteadyTime = currentSteadyTime();
  for (;;) {
    auto front = timers->timers.begin();
    if (front == timers->timers.end() || (*front)->time > frozenSteadyTime) {
      break;
    }
    (*front)->fulfill();
  }
}

// =======================================================================================
// Signal code common to multiple implementations
107 108 109

namespace {

110 111 112
int reservedSignal = SIGUSR1;
bool tooLateToSetReserved = false;

113 114 115 116 117
struct SignalCapture {
  sigjmp_buf jumpTo;
  siginfo_t siginfo;
};

118
#if !KJ_USE_EPOLL  // on Linux we'll use signalfd
119
KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr;
120 121 122 123 124 125 126 127

void signalHandler(int, siginfo_t* siginfo, void*) {
  SignalCapture* capture = threadCapture;
  if (capture != nullptr) {
    capture->siginfo = *siginfo;
    siglongjmp(capture->jumpTo, 1);
  }
}
128
#endif
129 130

void registerSignalHandler(int signum) {
131 132
  tooLateToSetReserved = true;

133
  sigset_t mask;
134 135 136
  KJ_SYSCALL(sigemptyset(&mask));
  KJ_SYSCALL(sigaddset(&mask, signum));
  KJ_SYSCALL(sigprocmask(SIG_BLOCK, &mask, nullptr));
137

138
#if !KJ_USE_EPOLL  // on Linux we'll use signalfd
139 140 141
  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = &signalHandler;
142
  KJ_SYSCALL(sigfillset(&action.sa_mask));
143
  action.sa_flags = SA_SIGINFO;
144 145
  KJ_SYSCALL(sigaction(signum, &action, nullptr));
#endif
146 147
}

148 149
void registerReservedSignal() {
  registerSignalHandler(reservedSignal);
150

151
  // We also disable SIGPIPE because users of UnixEventPort almost certainly don't want it.
152 153 154 155 156 157
  while (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
    int error = errno;
    if (error != EINTR) {
      KJ_FAIL_SYSCALL("signal(SIGPIPE, SIG_IGN)", error);
    }
  }
158 159
}

160
pthread_once_t registerReservedSignalOnce = PTHREAD_ONCE_INIT;
161 162 163

}  // namespace

164
// Promise adapter behind onSignal().  Each instance is a node of an intrusive doubly-linked list
// rooted in the owning UnixEventPort: `next` points at the following node and `prev` points at
// whichever pointer currently points at this node.  A null `prev` means "not in the list".
class UnixEventPort::SignalPromiseAdapter {
public:
  inline SignalPromiseAdapter(PromiseFulfiller<siginfo_t>& fulfiller,
                              UnixEventPort& loop, int signum)
      : loop(loop), signum(signum), fulfiller(fulfiller) {
    // Append to the tail of the port's list.
    prev = loop.signalTail;
    *loop.signalTail = this;
    loop.signalTail = &next;
  }

  ~SignalPromiseAdapter() noexcept(false) {
    // Nothing to do if removeFromList() already detached us.
    if (prev != nullptr) {
      unlink();
    }
  }

  SignalPromiseAdapter* removeFromList() {
    // Detaches this node and returns the node that followed it, so a caller walking the list can
    // continue from there.
    auto result = next;
    unlink();
    next = nullptr;
    prev = nullptr;
    return result;
  }

  UnixEventPort& loop;
  int signum;
  PromiseFulfiller<siginfo_t>& fulfiller;
  SignalPromiseAdapter* next = nullptr;
  SignalPromiseAdapter** prev = nullptr;

private:
  void unlink() {
    // Splices this node out of the list.  Requires prev != nullptr.  Leaves next/prev untouched.
    if (next == nullptr) {
      // We were the last node, so the tail pointer moves back to our predecessor's link.
      loop.signalTail = prev;
    } else {
      next->prev = prev;
    }
    *prev = next;
  }
};

205 206 207 208 209 210 211 212 213 214 215 216
// Returns a promise backed by a SignalPromiseAdapter for `signum`.  The adapter appends itself to
// this port's signal list; gotSignal() fulfills it with the siginfo of a matching signal.
Promise<siginfo_t> UnixEventPort::onSignal(int signum) {
  return newAdaptedPromise<siginfo_t, SignalPromiseAdapter>(*this, signum);
}

void UnixEventPort::captureSignal(int signum) {
  if (reservedSignal == SIGUSR1) {
    KJ_REQUIRE(signum != SIGUSR1,
               "Sorry, SIGUSR1 is reserved by the UnixEventPort implementation.  You may call "
               "UnixEventPort::setReservedSignal() to reserve a different signal.");
  } else {
    KJ_REQUIRE(signum != reservedSignal,
               "Can't capture signal reserved using setReservedSignal().", signum);
217
  }
218 219
  registerSignalHandler(signum);
}
220

221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
// Changes which signal is reserved for the event port's own use.  Fails if any signal handler has
// already been registered, or if a different non-default signal was reserved by an earlier call.
void UnixEventPort::setReservedSignal(int signum) {
  KJ_REQUIRE(!tooLateToSetReserved,
             "setReservedSignal() must be called before any calls to `captureSignal()` and "
             "before any `UnixEventPort` is constructed.");

  // A reservation other than the default SIGUSR1 means this function already ran with that value.
  bool previouslyReserved = reservedSignal != SIGUSR1;
  bool sameAsBefore = reservedSignal == signum;
  if (previouslyReserved && !sameAsBefore) {
    KJ_FAIL_REQUIRE("Detected multiple conflicting calls to setReservedSignal().  Please only "
                    "call this once, or always call it with the same signal number.");
  }

  reservedSignal = signum;
}

// Delivers `siginfo` to every promise in the signal list that is waiting on the same signal
// number, removing each such adapter from the list as it is fulfilled.
void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
  // Fire any events waiting on this signal.
  for (auto ptr = signalHead; ptr != nullptr;) {
    if (ptr->signum == siginfo.si_signo) {
      ptr->fulfiller.fulfill(kj::cp(siginfo));
      // removeFromList() hands back the following node, so the walk continues from there.
      ptr = ptr->removeFromList();
    } else {
      ptr = ptr->next;
    }
  }
}
244

245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
#if KJ_USE_EPOLL
// =======================================================================================
// epoll FdObserver implementation

// epoll-based constructor.  Creates the epoll FD, a signalfd (initially with an empty signal set)
// and an eventfd, and registers the latter two with epoll.
UnixEventPort::UnixEventPort()
    : timers(kj::heap<TimerSet>()),
      frozenSteadyTime(currentSteadyTime()),
      epollFd(-1),
      signalFd(-1),
      eventFd(-1) {
  // Make sure the reserved signal is registered (and SIGPIPE ignored) exactly once.
  pthread_once(&registerReservedSignalOnce, &registerReservedSignal);

  int fd;
  KJ_SYSCALL(fd = epoll_create1(EPOLL_CLOEXEC));
  epollFd = AutoCloseFd(fd);

  // The signal set starts empty; doEpollWait() updates it to match the signals being waited on.
  KJ_SYSCALL(sigemptyset(&signalFdSigset));
  KJ_SYSCALL(fd = signalfd(-1, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC));
  signalFd = AutoCloseFd(fd);

  KJ_SYSCALL(fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
  eventFd = AutoCloseFd(fd);


  // doEpollWait() tells the event sources apart by the epoll user data: 0 is the signalfd, 1 is
  // the eventfd, and anything else is treated as an FdObserver pointer.
  struct epoll_event event;
  memset(&event, 0, sizeof(event));
  event.events = EPOLLIN;
  event.data.u64 = 0;
  KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, signalFd, &event));
  event.data.u64 = 1;
  KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event));
}

// Nothing to do explicitly here; the body is empty and members clean up in their own destructors.
UnixEventPort::~UnixEventPort() noexcept(false) {}

// Registers `fd` with the port's epoll set in edge-triggered mode, watching for readability
// and/or writability according to `flags`.  The epoll user data is this observer's address, which
// is how doEpollWait() routes events back to it.
UnixEventPort::FdObserver::FdObserver(UnixEventPort& eventPort, int fd, uint flags)
    : eventPort(eventPort), fd(fd), flags(flags) {
  // Always request edge-triggered mode; add the directions the caller asked for.
  uint32_t interest = EPOLLET;
  if (flags & OBSERVE_READ) {
    interest |= EPOLLIN | EPOLLRDHUP;
  }
  if (flags & OBSERVE_WRITE) {
    interest |= EPOLLOUT;
  }

  struct epoll_event event;
  memset(&event, 0, sizeof(event));
  event.events = interest;
  event.data.ptr = this;

  KJ_SYSCALL(epoll_ctl(eventPort.epollFd, EPOLL_CTL_ADD, fd, &event));
}

// Removes the file descriptor from the port's epoll set.  The `{ break; }` handler block means a
// failing epoll_ctl() takes that recovery path.
UnixEventPort::FdObserver::~FdObserver() noexcept(false) {
  KJ_SYSCALL(epoll_ctl(eventPort.epollFd, EPOLL_CTL_DEL, fd, nullptr)) { break; }
}

void UnixEventPort::FdObserver::fire(short events) {
  if (events & (EPOLLIN | EPOLLHUP | EPOLLRDHUP | EPOLLERR)) {
    if (events & (EPOLLHUP | EPOLLRDHUP)) {
      atEnd = true;
    } else {
      // Since we didn't receive EPOLLRDHUP, we know that we're not at the end.
      atEnd = false;
    }

    KJ_IF_MAYBE(f, readFulfiller) {
      f->get()->fulfill();
      readFulfiller = nullptr;
314 315 316
    }
  }

317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
  if (events & (EPOLLOUT | EPOLLHUP | EPOLLERR)) {
    KJ_IF_MAYBE(f, writeFulfiller) {
      f->get()->fulfill();
      writeFulfiller = nullptr;
    }
  }
}

// Returns a promise that fire() fulfills when a read-related event is reported.  Requires the
// observer to have been created with OBSERVE_READ.  The new fulfiller replaces any previously
// stored read fulfiller.
Promise<void> UnixEventPort::FdObserver::whenBecomesReadable() {
  KJ_REQUIRE(flags & OBSERVE_READ, "FdObserver was not set to observe reads.");

  auto paf = newPromiseAndFulfiller<void>();
  readFulfiller = kj::mv(paf.fulfiller);
  return kj::mv(paf.promise);
}

// Returns a promise that fire() fulfills when a write-related event is reported.  Requires the
// observer to have been created with OBSERVE_WRITE.  The new fulfiller replaces any previously
// stored write fulfiller.
Promise<void> UnixEventPort::FdObserver::whenBecomesWritable() {
  KJ_REQUIRE(flags & OBSERVE_WRITE, "FdObserver was not set to observe writes.");

  auto paf = newPromiseAndFulfiller<void>();
  writeFulfiller = kj::mv(paf.fulfiller);
  return kj::mv(paf.promise);
}

341
bool UnixEventPort::wait() {
342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
  // epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that.
  // Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
  // will break the math below.
  constexpr Duration MAX_TIMEOUT =
      min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);

  int epollTimeout = -1;
  auto timer = timers->timers.begin();
  if (timer != timers->timers.end()) {
    Duration timeout = (*timer)->time - currentSteadyTime();
    if (timeout < 0 * SECONDS) {
      epollTimeout = 0;
    } else if (timeout < MAX_TIMEOUT) {
      // Round up to the next millisecond
      epollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
357
    } else {
358
      epollTimeout = MAX_TIMEOUT / MILLISECONDS;
359 360 361
    }
  }

362
  return doEpollWait(epollTimeout);
363
}
364

365 366 367 368 369 370 371 372 373
// Non-blocking variant of wait(): runs one doEpollWait() pass with a zero timeout.
bool UnixEventPort::poll() {
  return doEpollWait(0);
}

// Wakes the event loop by adding 1 to the eventfd counter, which makes the eventfd readable in
// doEpollWait().  A negative `n` (the non-blocking call could not complete right now) is
// tolerated; any successful write must have written the whole 8-byte value.
void UnixEventPort::wake() const {
  uint64_t one = 1;
  ssize_t n;
  KJ_NONBLOCKING_SYSCALL(n = write(eventFd, &one, sizeof(one)));
  KJ_ASSERT(n < 0 || n == sizeof(one));
}

// Converts the flat `signalfd_siginfo` read from a signalfd into a regular `siginfo_t`.
// The result is zero-initialized; si_signo, si_errno and si_code are always copied, and the
// remaining fields are filled in according to the signal's origin as described below.
static siginfo_t toRegularSiginfo(const struct signalfd_siginfo& siginfo) {
  // Unfortunately, siginfo_t is mostly a big union and the correct set of fields to fill in
  // depends on the type of signal. OTOH, signalfd_siginfo is a flat struct that expands all
  // siginfo_t's union fields out to be non-overlapping. We can't just copy all the fields over
  // because of the unions; we have to carefully figure out which fields are appropriate to fill
  // in for this signal. Ick.

  siginfo_t result;
  memset(&result, 0, sizeof(result));

  result.si_signo = siginfo.ssi_signo;
  result.si_errno = siginfo.ssi_errno;
  result.si_code = siginfo.ssi_code;

  if (siginfo.ssi_code > 0) {
    // Signal originated from the kernel. The structure of the siginfo depends primarily on the
    // signal number.

    switch (siginfo.ssi_signo) {
    case SIGCHLD:
      result.si_pid = siginfo.ssi_pid;
      result.si_uid = siginfo.ssi_uid;
      result.si_status = siginfo.ssi_status;
      result.si_utime = siginfo.ssi_utime;
      result.si_stime = siginfo.ssi_stime;
      break;

    case SIGILL:
    case SIGFPE:
    case SIGSEGV:
    case SIGBUS:
    case SIGTRAP:
      result.si_addr = reinterpret_cast<void*>(static_cast<uintptr_t>(siginfo.ssi_addr));
#ifdef si_trapno
      result.si_trapno = siginfo.ssi_trapno;
#endif
#ifdef si_addr_lsb
      // ssi_addr_lsb is defined as coming immediately after ssi_addr in the kernel headers but
      // apparently the userspace headers were never updated. So we do a pointer hack. :(
      result.si_addr_lsb = *reinterpret_cast<const uint16_t*>(&siginfo.ssi_addr + 1);
#endif
      break;

    case SIGIO:
      static_assert(SIGIO == SIGPOLL, "SIGIO != SIGPOLL?");

      // Note: Technically, code can arrange for SIGIO signals to be delivered with a signal number
      //   other than SIGIO. AFAICT there is no way for us to detect this in the siginfo. Luckily
      //   SIGIO is totally obsoleted by epoll so it shouldn't come up.

      result.si_band = siginfo.ssi_band;
      result.si_fd = siginfo.ssi_fd;
      break;

    case SIGSYS:
      // Apparently SIGSYS's fields are not available in signalfd_siginfo?
      break;
    }

  } else {
    // Signal originated from userspace. The sender could specify whatever signal number they
    // wanted. The structure of the signal is determined by the API they used, which is identified
    // by SI_CODE.

    switch (siginfo.ssi_code) {
      case SI_USER:
      case SI_TKILL:
        // kill(), tkill(), or tgkill().
        result.si_pid = siginfo.ssi_pid;
        result.si_uid = siginfo.ssi_uid;
        break;

      case SI_QUEUE:
      case SI_MESGQ:
      case SI_ASYNCIO:
      default:
        // Any unrecognized code is treated like the queued-signal case.
        result.si_pid = siginfo.ssi_pid;
        result.si_uid = siginfo.ssi_uid;

        // This is awkward. In siginfo_t, si_ptr and si_int are in a union together. In
        // signalfd_siginfo, they are not. We don't really know whether the app intended to send
        // an int or a pointer. Presumably since the pointer is always larger than the int, if
        // we write the pointer, we'll end up with the right value for the int? Presumably the
        // two fields of signalfd_siginfo are actually extracted from one of these unions
        // originally, so actually contain redundant data? Better write some tests...
        result.si_ptr = reinterpret_cast<void*>(static_cast<uintptr_t>(siginfo.ssi_ptr));
        break;

      case SI_TIMER:
        result.si_timerid = siginfo.ssi_tid;
        result.si_overrun = siginfo.ssi_overrun;

        // Again with this weirdness...
        result.si_ptr = reinterpret_cast<void*>(static_cast<uintptr_t>(siginfo.ssi_ptr));
        break;
    }
  }

  return result;
}

477
bool UnixEventPort::doEpollWait(int timeout) {
478 479 480 481 482 483 484 485
  sigset_t newMask;
  sigemptyset(&newMask);

  {
    auto ptr = signalHead;
    while (ptr != nullptr) {
      sigaddset(&newMask, ptr->signum);
      ptr = ptr->next;
486 487 488
    }
  }

489 490 491 492 493
  if (memcmp(&newMask, &signalFdSigset, sizeof(newMask)) != 0) {
    // Apparently we're not waiting on the same signals as last time. Need to update the signal
    // FD's mask.
    signalFdSigset = newMask;
    KJ_SYSCALL(signalfd(signalFd, &signalFdSigset, SFD_NONBLOCK | SFD_CLOEXEC));
494 495
  }

496 497 498
  struct epoll_event events[16];
  int n;
  KJ_SYSCALL(n = epoll_wait(epollFd, events, kj::size(events), timeout));
499

500 501
  bool woken = false;

502 503 504 505 506 507 508 509 510 511 512 513 514
  for (int i = 0; i < n; i++) {
    if (events[i].data.u64 == 0) {
      for (;;) {
        struct signalfd_siginfo siginfo;
        ssize_t n;
        KJ_NONBLOCKING_SYSCALL(n = read(signalFd, &siginfo, sizeof(siginfo)));
        if (n < 0) break;  // no more signals

        KJ_ASSERT(n == sizeof(siginfo));

        gotSignal(toRegularSiginfo(siginfo));
      }
    } else if (events[i].data.u64 == 1) {
515
      // Someone called wake() from another thread. Consume the event.
516 517 518 519
      uint64_t value;
      ssize_t n;
      KJ_NONBLOCKING_SYSCALL(n = read(eventFd, &value, sizeof(value)));
      KJ_ASSERT(n < 0 || n == sizeof(value));
520 521 522

      // We were woken. Need to return true.
      woken = true;
523 524 525 526 527 528 529
    } else {
      FdObserver* observer = reinterpret_cast<FdObserver*>(events[i].data.ptr);
      observer->fire(events[i].events);
    }
  }

  processTimers();
530 531

  return woken;
532 533
}

534 535 536 537 538 539 540 541
#else  // KJ_USE_EPOLL
// =======================================================================================
// Traditional poll() FdObserver implementation.

// Where the platform headers do not define POLLRDHUP, define it as zero so that it contributes
// nothing to the event masks below and `#if POLLRDHUP` can detect its absence.
#ifndef POLLRDHUP
#define POLLRDHUP 0
#endif

542 543 544
// poll()-based constructor.  Records the constructing thread's pthread_t in `threadId` (wake()
// later signals that thread) and makes sure the reserved signal has been registered.
UnixEventPort::UnixEventPort()
    : timers(kj::heap<TimerSet>()),
      frozenSteadyTime(currentSteadyTime()) {
  // `threadId` is used as raw storage for a pthread_t, so it must be at least as large.
  static_assert(sizeof(threadId) >= sizeof(pthread_t),
                "pthread_t is larger than a long long on your platform.  Please port.");
  *reinterpret_cast<pthread_t*>(&threadId) = pthread_self();

  pthread_once(&registerReservedSignalOnce, &registerReservedSignal);
}

552
// Nothing to do explicitly here; the body is empty and members clean up in their own destructors.
UnixEventPort::~UnixEventPort() noexcept(false) {}
553

554 555 556 557 558 559 560 561 562 563 564 565
// Only records its arguments.  The observer starts outside the port's observer list (next and
// prev are null); whenBecomesReadable()/whenBecomesWritable() link it in on demand.
UnixEventPort::FdObserver::FdObserver(UnixEventPort& eventPort, int fd, uint flags)
    : eventPort(eventPort), fd(fd), flags(flags), next(nullptr), prev(nullptr) {}

// Unlinks the observer from the port's observer list if it is currently in it (prev non-null).
UnixEventPort::FdObserver::~FdObserver() noexcept(false) {
  if (prev != nullptr) {
    if (next == nullptr) {
      // We were the last node, so the tail pointer moves back to our predecessor's link.
      eventPort.observersTail = prev;
    } else {
      next->prev = prev;
    }
    *prev = next;
  }
}

568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
// Called by PollContext::processResults() with the `revents` that poll() reported for this
// descriptor.  Updates `atEnd`, fulfills (then clears) the applicable fulfillers, and removes the
// observer from the port's list once nothing is waiting on it any more.
void UnixEventPort::FdObserver::fire(short events) {
  if (events & (POLLIN | POLLHUP | POLLRDHUP | POLLERR | POLLNVAL)) {
    if (events & (POLLHUP | POLLRDHUP)) {
      atEnd = true;
#if POLLRDHUP
    } else {
      // Since POLLRDHUP exists on this platform, and we didn't receive it, we know that we're not
      // at the end.
      atEnd = false;
#endif
    }

    KJ_IF_MAYBE(f, readFulfiller) {
      f->get()->fulfill();
      readFulfiller = nullptr;
    }
  }

  if (events & (POLLOUT | POLLHUP | POLLERR | POLLNVAL)) {
    KJ_IF_MAYBE(f, writeFulfiller) {
      f->get()->fulfill();
      writeFulfiller = nullptr;
    }
  }

  if (readFulfiller == nullptr && writeFulfiller == nullptr) {
    // Remove from list.
    if (next == nullptr) {
      eventPort.observersTail = prev;
    } else {
      next->prev = prev;
    }
    *prev = next;
    // Null links mark the observer as not being in the list.
    next = nullptr;
    prev = nullptr;
  }
}

606 607 608
// Computes the poll() `events` mask for this observer: read interest is included only while a
// read fulfiller is stored, and write interest only while a write fulfiller is stored.
short UnixEventPort::FdObserver::getEventMask() {
  return (readFulfiller == nullptr ? 0 : (POLLIN | POLLRDHUP)) |
         (writeFulfiller == nullptr ? 0 : POLLOUT);
}

611 612 613 614 615 616 617 618
// Returns a promise that fire() fulfills when a read-related event is reported.  Requires the
// observer to have been created with OBSERVE_READ.  Links the observer into the port's observer
// list if it is not already there, and replaces any previously stored read fulfiller.
Promise<void> UnixEventPort::FdObserver::whenBecomesReadable() {
  KJ_REQUIRE(flags & OBSERVE_READ, "FdObserver was not set to observe reads.");

  if (prev == nullptr) {
    // Not currently in the list: append at the tail.
    KJ_DASSERT(next == nullptr);
    prev = eventPort.observersTail;
    *prev = this;
    eventPort.observersTail = &next;
  }

  auto paf = newPromiseAndFulfiller<void>();
  readFulfiller = kj::mv(paf.fulfiller);
  return kj::mv(paf.promise);
}

626 627 628 629 630 631 632 633
// Returns a promise that fire() fulfills when a write-related event is reported.  Requires the
// observer to have been created with OBSERVE_WRITE.  Links the observer into the port's observer
// list if it is not already there, and replaces any previously stored write fulfiller.
Promise<void> UnixEventPort::FdObserver::whenBecomesWritable() {
  KJ_REQUIRE(flags & OBSERVE_WRITE, "FdObserver was not set to observe writes.");

  if (prev == nullptr) {
    // Not currently in the list: append at the tail.
    KJ_DASSERT(next == nullptr);
    prev = eventPort.observersTail;
    *prev = this;
    eventPort.observersTail = &next;
  }

  auto paf = newPromiseAndFulfiller<void>();
  writeFulfiller = kj::mv(paf.fulfiller);
  return kj::mv(paf.promise);
}

641 642
// One poll() invocation over a snapshot of the observer list.  Usage is three steps: construct
// (snapshot the observers), run() (perform the system call), processResults() (dispatch).
class UnixEventPort::PollContext {
public:
  PollContext(FdObserver* ptr) {
    // Walk the list starting at `ptr`, building parallel arrays: pollfds[i] belongs to
    // pollEvents[i].
    while (ptr != nullptr) {
      struct pollfd pollfd;
      memset(&pollfd, 0, sizeof(pollfd));
      pollfd.fd = ptr->fd;
      pollfd.events = ptr->getEventMask();
      pollfds.add(pollfd);
      pollEvents.add(ptr);
      ptr = ptr->next;
    }
  }

  void run(int timeout) {
    // Performs the poll(), remembering the result and errno for processResults() rather than
    // reporting a failure from here.
    do {
      pollResult = ::poll(pollfds.begin(), pollfds.size(), timeout);
      pollError = pollResult < 0 ? errno : 0;

      // EINTR should only happen if we received a signal *other than* the ones registered via
      // the UnixEventPort, so we don't care about that case.
    } while (pollError == EINTR);
  }

  void processResults() {
    // Reports a failed poll(), otherwise fires every observer whose descriptor had events.
    if (pollResult < 0) {
      KJ_FAIL_SYSCALL("poll()", pollError);
    }

    for (auto i: indices(pollfds)) {
      if (pollfds[i].revents != 0) {
        pollEvents[i]->fire(pollfds[i].revents);
        // pollResult counts the descriptors with events, so stop once all have been handled.
        if (--pollResult <= 0) {
          break;
        }
      }
    }
  }

private:
  kj::Vector<struct pollfd> pollfds;
  kj::Vector<FdObserver*> pollEvents;
  int pollResult = 0;
  int pollError = 0;
};

687
bool UnixEventPort::wait() {
688 689
  sigset_t newMask;
  sigemptyset(&newMask);
690
  sigaddset(&newMask, reservedSignal);
691 692

  {
693 694 695 696
    auto ptr = signalHead;
    while (ptr != nullptr) {
      sigaddset(&newMask, ptr->signum);
      ptr = ptr->next;
697 698 699
    }
  }

700
  PollContext pollContext(observersHead);
701

702 703 704 705 706 707 708
  // Capture signals.
  SignalCapture capture;

  if (sigsetjmp(capture.jumpTo, true)) {
    // We received a signal and longjmp'd back out of the signal handler.
    threadCapture = nullptr;

709 710 711
    if (capture.siginfo.si_signo == reservedSignal) {
      return true;
    } else {
712
      gotSignal(capture.siginfo);
713
      return false;
714 715 716
    }
  }

717
  // Enable signals, run the poll, then mask them again.
718
  sigset_t origMask;
719
  threadCapture = &capture;
720 721
  sigprocmask(SIG_UNBLOCK, &newMask, &origMask);

722 723 724 725 726
  // poll()'s timeout is an `int` count of milliseconds, so truncate to that.
  // Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
  // will break the math below.
  constexpr Duration MAX_TIMEOUT =
      min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
727 728

  int pollTimeout = -1;
729 730 731 732
  auto timer = timers->timers.begin();
  if (timer != timers->timers.end()) {
    Duration timeout = (*timer)->time - currentSteadyTime();
    if (timeout < 0 * SECONDS) {
733
      pollTimeout = 0;
734
    } else if (timeout < MAX_TIMEOUT) {
735
      // Round up to the next millisecond
736
      pollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
737
    } else {
738
      pollTimeout = MAX_TIMEOUT / MILLISECONDS;
739 740 741
    }
  }
  pollContext.run(pollTimeout);
742 743 744 745

  sigprocmask(SIG_SETMASK, &origMask, nullptr);
  threadCapture = nullptr;

746 747
  // Queue events.
  pollContext.processResults();
748
  processTimers();
749 750

  return false;
751
}
752

753
bool UnixEventPort::poll() {
754 755
  // volatile so that longjmp() doesn't clobber it.
  volatile bool woken = false;
756

757 758 759 760 761 762 763 764 765
  sigset_t pending;
  sigset_t waitMask;
  sigemptyset(&pending);
  sigfillset(&waitMask);

  // Count how many signals that we care about are pending.
  KJ_SYSCALL(sigpending(&pending));
  uint signalCount = 0;

766 767 768 769 770 771
  if (sigismember(&pending, reservedSignal)) {
    ++signalCount;
    sigdelset(&pending, reservedSignal);
    sigdelset(&waitMask, reservedSignal);
  }

772 773 774 775 776 777 778
  {
    auto ptr = signalHead;
    while (ptr != nullptr) {
      if (sigismember(&pending, ptr->signum)) {
        ++signalCount;
        sigdelset(&pending, ptr->signum);
        sigdelset(&waitMask, ptr->signum);
779
      }
780
      ptr = ptr->next;
781 782
    }
  }
783 784 785 786 787 788 789 790 791

  // Wait for each pending signal.  It would be nice to use sigtimedwait() here but it is not
  // available on OSX.  :(  Instead, we call sigsuspend() once per expected signal.
  while (signalCount-- > 0) {
    SignalCapture capture;
    threadCapture = &capture;
    if (sigsetjmp(capture.jumpTo, true)) {
      // We received a signal and longjmp'd back out of the signal handler.
      sigdelset(&waitMask, capture.siginfo.si_signo);
792 793 794 795 796
      if (capture.siginfo.si_signo == reservedSignal) {
        woken = true;
      } else {
        gotSignal(capture.siginfo);
      }
797 798 799 800 801 802 803 804 805
    } else {
      sigsuspend(&waitMask);
      KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should "
                     "have siglongjmp()ed.");
    }
    threadCapture = nullptr;
  }

  {
806
    PollContext pollContext(observersHead);
807 808 809
    pollContext.run(0);
    pollContext.processResults();
  }
810
  processTimers();
811 812 813 814 815 816 817 818 819

  return woken;
}

// Wakes the event loop by sending the reserved signal to the thread recorded in `threadId` by
// the constructor.  pthread_kill() returns the error number directly instead of setting errno.
void UnixEventPort::wake() const {
  int error = pthread_kill(*reinterpret_cast<const pthread_t*>(&threadId), reservedSignal);
  if (error != 0) {
    KJ_FAIL_SYSCALL("pthread_kill", error);
  }
}

822
#endif  // KJ_USE_EPOLL, else
823

824
}  // namespace kj