async-unix.c++ 12 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
3
//
Kenton Varda's avatar
Kenton Varda committed
4 5 6 7 8 9
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
10
//
Kenton Varda's avatar
Kenton Varda committed
11 12
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
13
//
Kenton Varda's avatar
Kenton Varda committed
14 15 16 17 18 19 20
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21 22 23

#include "async-unix.h"
#include "debug.h"
24
#include "threadlocal.h"
25 26
#include <setjmp.h>
#include <errno.h>
27 28
#include <inttypes.h>
#include <limits>
29
#include <set>
30
#include <chrono>
31 32 33 34 35 36 37

namespace kj {

// =======================================================================================

namespace {

38 39 40
// Signal used internally by UnixEventPort. Defaults to SIGUSR1; UnixEventPort::setReservedSignal()
// may replace it only while tooLateToSetReserved is still false.
int reservedSignal = SIGUSR1;

// Flipped to true by registerSignalHandler(); from then on setReservedSignal() refuses to run.
bool tooLateToSetReserved = false;

// Scratch record shared between signalHandler() and the wait()/poll() call that armed it.
struct SignalCapture {
  sigjmp_buf jumpTo;    // where signalHandler() siglongjmp()s to
  siginfo_t siginfo;    // copied from the handler's siginfo before jumping
};

46
// Per-thread pointer (KJ_THREADLOCAL_PTR) to the SignalCapture of the wait()/poll() call that is
// currently prepared to receive a signal, or nullptr when none is. Read by signalHandler(), which
// does nothing while it is nullptr.
KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr;
47 48 49 50 51 52 53 54 55 56

void signalHandler(int, siginfo_t* siginfo, void*) {
  SignalCapture* capture = threadCapture;
  if (capture != nullptr) {
    capture->siginfo = *siginfo;
    siglongjmp(capture->jumpTo, 1);
  }
}

void registerSignalHandler(int signum) {
57 58
  tooLateToSetReserved = true;

59 60 61 62 63 64 65 66 67 68 69 70 71
  sigset_t mask;
  sigemptyset(&mask);
  sigaddset(&mask, signum);
  sigprocmask(SIG_BLOCK, &mask, nullptr);

  struct sigaction action;
  memset(&action, 0, sizeof(action));
  action.sa_sigaction = &signalHandler;
  sigfillset(&action.sa_mask);
  action.sa_flags = SA_SIGINFO;
  sigaction(signum, &action, nullptr);
}

72 73
void registerReservedSignal() {
  // One-time setup, run through pthread_once() from the UnixEventPort constructor.
  registerSignalHandler(reservedSignal);

  // SIGPIPE is ignored too, since users of UnixEventLoop almost certainly don't want it.
  ::signal(SIGPIPE, SIG_IGN);
}

79
// Once-flag passed to pthread_once() in the UnixEventPort constructor so that
// registerReservedSignal() runs a single time no matter how many ports are created.
pthread_once_t registerReservedSignalOnce = PTHREAD_ONCE_INIT;
80 81 82 83 84

}  // namespace

// =======================================================================================

85 86 87 88 89 90 91 92 93 94
struct UnixEventPort::TimerSet {
  struct TimerBefore {
    bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
  };
  using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
  Timers timers;
};

// =======================================================================================

95
class UnixEventPort::SignalPromiseAdapter {
96 97
public:
  inline SignalPromiseAdapter(PromiseFulfiller<siginfo_t>& fulfiller,
98
                              UnixEventPort& loop, int signum)
99 100 101 102
      : loop(loop), signum(signum), fulfiller(fulfiller) {
    prev = loop.signalTail;
    *loop.signalTail = this;
    loop.signalTail = &next;
103 104 105
  }

  ~SignalPromiseAdapter() noexcept(false) {
106 107 108 109 110 111 112 113
    if (prev != nullptr) {
      if (next == nullptr) {
        loop.signalTail = prev;
      } else {
        next->prev = prev;
      }
      *prev = next;
    }
114 115
  }

116 117 118 119 120 121 122 123 124 125 126
  SignalPromiseAdapter* removeFromList() {
    auto result = next;
    if (next == nullptr) {
      loop.signalTail = prev;
    } else {
      next->prev = prev;
    }
    *prev = next;
    next = nullptr;
    prev = nullptr;
    return result;
127 128
  }

129
  UnixEventPort& loop;
130 131 132 133
  int signum;
  PromiseFulfiller<siginfo_t>& fulfiller;
  SignalPromiseAdapter* next = nullptr;
  SignalPromiseAdapter** prev = nullptr;
134 135
};

136
class UnixEventPort::PollPromiseAdapter {
137 138
public:
  inline PollPromiseAdapter(PromiseFulfiller<short>& fulfiller,
139
                            UnixEventPort& loop, int fd, short eventMask)
140 141 142 143
      : loop(loop), fd(fd), eventMask(eventMask), fulfiller(fulfiller) {
    prev = loop.pollTail;
    *loop.pollTail = this;
    loop.pollTail = &next;
144 145 146
  }

  ~PollPromiseAdapter() noexcept(false) {
147 148 149 150 151 152 153 154
    if (prev != nullptr) {
      if (next == nullptr) {
        loop.pollTail = prev;
      } else {
        next->prev = prev;
      }
      *prev = next;
    }
155
  }
156

157 158 159 160 161 162 163 164 165 166 167
  void removeFromList() {
    if (next == nullptr) {
      loop.pollTail = prev;
    } else {
      next->prev = prev;
    }
    *prev = next;
    next = nullptr;
    prev = nullptr;
  }

168
  UnixEventPort& loop;
169 170 171 172 173
  int fd;
  short eventMask;
  PromiseFulfiller<short>& fulfiller;
  PollPromiseAdapter* next = nullptr;
  PollPromiseAdapter** prev = nullptr;
174 175
};

176 177
class UnixEventPort::TimerPromiseAdapter {
public:
178
  TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, UnixEventPort& port, TimePoint time)
179
      : time(time), fulfiller(fulfiller), port(port) {
180
    pos = port.timers->timers.insert(this);
181 182 183
  }

  ~TimerPromiseAdapter() {
184 185
    if (pos != port.timers->timers.end()) {
      port.timers->timers.erase(pos);
186 187 188 189 190
    }
  }

  void fulfill() {
    fulfiller.fulfill();
191 192
    port.timers->timers.erase(pos);
    pos = port.timers->timers.end();
193 194
  }

195
  const TimePoint time;
196 197
  PromiseFulfiller<void>& fulfiller;
  UnixEventPort& port;
198
  TimerSet::Timers::const_iterator pos;
199 200
};

201 202
bool UnixEventPort::TimerSet::TimerBefore::operator()(
    TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
203 204 205
  return lhs->time < rhs->time;
}

206 207 208
UnixEventPort::UnixEventPort()
    : timers(kj::heap<TimerSet>()),
      frozenSteadyTime(currentSteadyTime()) {
  // The first port ever constructed installs the reserved-signal handler (and ignores SIGPIPE);
  // pthread_once() turns this into a no-op for every later construction.
  ::pthread_once(&registerReservedSignalOnce, &registerReservedSignal);
}

212
// Nothing to do explicitly; members (such as the heap-allocated TimerSet) are released by their
// own destructors.
UnixEventPort::~UnixEventPort() noexcept(false) {}
213

214
Promise<short> UnixEventPort::onFdEvent(int fd, short eventMask) {
  // The adapter links itself into the poll list; PollContext later fulfills it with the
  // `revents` that poll() reported for this fd.
  return newAdaptedPromise<short, PollPromiseAdapter>(*this, fd, eventMask);
}

218
Promise<siginfo_t> UnixEventPort::onSignal(int signum) {
  // The adapter links itself into the signal list; gotSignal() fulfills it with the siginfo of
  // the next matching signal.
  return newAdaptedPromise<siginfo_t, SignalPromiseAdapter>(*this, signum);
}

222
void UnixEventPort::captureSignal(int signum) {
223 224 225 226 227 228 229 230
  if (reservedSignal == SIGUSR1) {
    KJ_REQUIRE(signum != SIGUSR1,
               "Sorry, SIGUSR1 is reserved by the UnixEventPort implementation.  You may call "
               "UnixEventPort::setReservedSignal() to reserve a different signal.");
  } else {
    KJ_REQUIRE(signum != reservedSignal,
               "Can't capture signal reserved using setReservedSignal().", signum);
  }
231 232 233
  registerSignalHandler(signum);
}

234 235 236 237 238 239 240 241 242 243 244
void UnixEventPort::setReservedSignal(int signum) {
  KJ_REQUIRE(!tooLateToSetReserved,
             "setReservedSignal() must be called before any calls to `captureSignal()` and "
             "before any `UnixEventPort` is constructed.");

  // Repeating the current choice is fine; switching away from a non-SIGUSR1 choice is not.
  bool conflictsWithEarlierCall = reservedSignal != SIGUSR1 && reservedSignal != signum;
  if (conflictsWithEarlierCall) {
    KJ_FAIL_REQUIRE("Detected multiple conflicting calls to setReservedSignal().  Please only "
                    "call this once, or always call it with the same signal number.");
  }

  reservedSignal = signum;
}

245 246 247 248 249 250 251 252 253 254 255 256 257
// One round of poll(): snapshots the fds registered through onFdEvent() into a pollfd array,
// runs poll(), and routes each reported event back to the adapter that asked for it.
class UnixEventPort::PollContext {
public:
  // Walks the adapter list starting at `ptr`, recording one pollfd per adapter.
  explicit PollContext(PollPromiseAdapter* ptr) {
    while (ptr != nullptr) {
      struct pollfd pollfd;
      memset(&pollfd, 0, sizeof(pollfd));
      pollfd.fd = ptr->fd;
      pollfd.events = ptr->eventMask;
      pollfds.add(pollfd);
      pollEvents.add(ptr);
      ptr = ptr->next;
    }
  }

  // Runs poll() with `timeout` milliseconds (-1 = wait forever, 0 = don't block), retrying on
  // EINTR. The result is stored for processResults().
  void run(int timeout) {
    // EINTR should only happen if we received a signal *other than* the ones registered via the
    // UnixEventPort, so we simply retry. Such signals may arrive repeatedly (e.g. from a
    // profiling timer), so a positive timeout is measured against a fixed deadline. Restarting
    // with the original timeout each time could postpone the wake-up indefinitely and starve
    // pending timers.
    const bool bounded = timeout > 0;
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(timeout);

    for (;;) {
      pollResult = ::poll(pollfds.begin(), pollfds.size(), timeout);
      pollError = pollResult < 0 ? errno : 0;
      if (pollError != EINTR) {
        return;
      }

      if (bounded) {
        auto left = deadline - std::chrono::steady_clock::now();
        if (left <= std::chrono::steady_clock::duration::zero()) {
          timeout = 0;
        } else {
          // Round up so the retry never wakes before the original deadline.
          auto leftMs = std::chrono::duration_cast<std::chrono::milliseconds>(left);
          if (leftMs < left) {
            ++leftMs;
          }
          timeout = static_cast<int>(leftMs.count());
        }
      }
    }
  }

  // Fulfills and unlinks every adapter whose fd reported events. Throws (via
  // KJ_FAIL_SYSCALL) if the preceding run() ended with a poll() error.
  void processResults() {
    if (pollResult < 0) {
      KJ_FAIL_SYSCALL("poll()", pollError);
    }

    for (auto i: indices(pollfds)) {
      if (pollfds[i].revents != 0) {
        pollEvents[i]->fulfiller.fulfill(kj::mv(pollfds[i].revents));
        pollEvents[i]->removeFromList();
        // poll() returned the number of fds with events; stop once all are handled.
        if (--pollResult <= 0) {
          break;
        }
      }
    }
  }

private:
  kj::Vector<struct pollfd> pollfds;
  kj::Vector<PollPromiseAdapter*> pollEvents;  // parallel to pollfds
  int pollResult = 0;
  int pollError = 0;
};

292
Promise<void> UnixEventPort::atSteadyTime(TimePoint time) {
  // The adapter inserts itself into the timer set; processTimers() fulfills it once the frozen
  // steady time reaches `time`.
  return newAdaptedPromise<void, TimerPromiseAdapter>(*this, time);
}

296
void UnixEventPort::wait() {
297 298
  sigset_t newMask;
  sigemptyset(&newMask);
299
  sigaddset(&newMask, reservedSignal);
300 301

  {
302 303 304 305
    auto ptr = signalHead;
    while (ptr != nullptr) {
      sigaddset(&newMask, ptr->signum);
      ptr = ptr->next;
306 307 308
    }
  }

309
  PollContext pollContext(pollHead);
310

311 312 313 314 315 316 317
  // Capture signals.
  SignalCapture capture;

  if (sigsetjmp(capture.jumpTo, true)) {
    // We received a signal and longjmp'd back out of the signal handler.
    threadCapture = nullptr;

318
    if (capture.siginfo.si_signo != reservedSignal) {
319
      gotSignal(capture.siginfo);
320
    }
321 322

    return;
323 324
  }

325
  // Enable signals, run the poll, then mask them again.
326
  sigset_t origMask;
327
  threadCapture = &capture;
328 329
  sigprocmask(SIG_UNBLOCK, &newMask, &origMask);

330 331 332 333 334
  // poll()'s timeout is an `int` count of milliseconds, so truncate to that.
  // Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
  // will break the math below.
  constexpr Duration MAX_TIMEOUT =
      min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
335 336

  int pollTimeout = -1;
337 338 339 340
  auto timer = timers->timers.begin();
  if (timer != timers->timers.end()) {
    Duration timeout = (*timer)->time - currentSteadyTime();
    if (timeout < 0 * SECONDS) {
341
      pollTimeout = 0;
342
    } else if (timeout < MAX_TIMEOUT) {
343
      // Round up to the next millisecond
344
      pollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
345
    } else {
346
      pollTimeout = MAX_TIMEOUT / MILLISECONDS;
347 348 349
    }
  }
  pollContext.run(pollTimeout);
350 351 352 353

  sigprocmask(SIG_SETMASK, &origMask, nullptr);
  threadCapture = nullptr;

354 355
  // Queue events.
  pollContext.processResults();
356
  processTimers();
357
}
358

359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
void UnixEventPort::poll() {
  sigset_t pending;
  sigset_t waitMask;
  sigemptyset(&pending);
  sigfillset(&waitMask);

  // Count how many signals that we care about are pending.
  KJ_SYSCALL(sigpending(&pending));
  uint signalCount = 0;

  {
    auto ptr = signalHead;
    while (ptr != nullptr) {
      if (sigismember(&pending, ptr->signum)) {
        ++signalCount;
        sigdelset(&pending, ptr->signum);
        sigdelset(&waitMask, ptr->signum);
376
      }
377
      ptr = ptr->next;
378 379
    }
  }
380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402

  // Wait for each pending signal.  It would be nice to use sigtimedwait() here but it is not
  // available on OSX.  :(  Instead, we call sigsuspend() once per expected signal.
  while (signalCount-- > 0) {
    SignalCapture capture;
    threadCapture = &capture;
    if (sigsetjmp(capture.jumpTo, true)) {
      // We received a signal and longjmp'd back out of the signal handler.
      sigdelset(&waitMask, capture.siginfo.si_signo);
      gotSignal(capture.siginfo);
    } else {
      sigsuspend(&waitMask);
      KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should "
                     "have siglongjmp()ed.");
    }
    threadCapture = nullptr;
  }

  {
    PollContext pollContext(pollHead);
    pollContext.run(0);
    pollContext.processResults();
  }
403
  processTimers();
404 405
}

406 407 408 409 410 411 412 413 414 415
void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
  // Fulfills and unlinks every onSignal() promise waiting for siginfo.si_signo. Each promise
  // receives its own copy of the siginfo.
  auto cursor = signalHead;
  while (cursor != nullptr) {
    if (cursor->signum != siginfo.si_signo) {
      cursor = cursor->next;
      continue;
    }
    cursor->fulfiller.fulfill(kj::cp(siginfo));
    // removeFromList() returns the successor, so the walk resumes after the removed node.
    cursor = cursor->removeFromList();
  }
}

419
TimePoint UnixEventPort::currentSteadyTime() {
  // Samples std::chrono::steady_clock and expresses it as a kj TimePoint in nanoseconds.
  auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
  auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(sinceEpoch).count();
  return origin<TimePoint>() + nanos * NANOSECONDS;
}

void UnixEventPort::processTimers() {
  frozenSteadyTime = currentSteadyTime();
  for (;;) {
427 428
    auto front = timers->timers.begin();
    if (front == timers->timers.end() || (*front)->time > frozenSteadyTime) {
429 430 431 432 433 434
      break;
    }
    (*front)->fulfill();
  }
}

435
}  // namespace kj