Commit 517e1e4d authored by Kenton Varda's avatar Kenton Varda

Factor out Timer implementation from async-unix.c++ into a reusable class.

parent 2ebe56c4
......@@ -893,28 +893,10 @@ public:
UnixEventPort::FdObserver observer;
};
class TimerImpl final: public Timer {
public:
TimerImpl(UnixEventPort& eventPort): eventPort(eventPort) {}
TimePoint now() override { return eventPort.steadyTime(); }
Promise<void> atTime(TimePoint time) override {
return eventPort.atSteadyTime(time);
}
Promise<void> afterDelay(Duration delay) override {
return eventPort.atSteadyTime(eventPort.steadyTime() + delay);
}
private:
UnixEventPort& eventPort;
};
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
LowLevelAsyncIoProviderImpl()
: eventLoop(eventPort), timer(eventPort), waitScope(eventLoop) {}
: eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; }
......@@ -948,14 +930,13 @@ public:
return heap<DatagramPortImpl>(*this, eventPort, fd, flags);
}
Timer& getTimer() override { return timer; }
Timer& getTimer() override { return eventPort.getTimer(); }
UnixEventPort& getEventPort() { return eventPort; }
private:
UnixEventPort eventPort;
EventLoop eventLoop;
TimerImpl timer;
WaitScope waitScope;
};
......
......@@ -532,14 +532,16 @@ TEST(AsyncUnixTest, SteadyTimers) {
EventLoop loop(port);
WaitScope waitScope(loop);
auto start = port.steadyTime();
auto& timer = port.getTimer();
auto start = timer.now();
kj::Vector<TimePoint> expected;
kj::Vector<TimePoint> actual;
auto addTimer = [&](Duration delay) {
expected.add(max(start + delay, start));
port.atSteadyTime(start + delay).then([&]() {
actual.add(port.steadyTime());
timer.atTime(start + delay).then([&]() {
actual.add(timer.now());
}).detach([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
};
......@@ -550,7 +552,7 @@ TEST(AsyncUnixTest, SteadyTimers) {
addTimer(-10 * MILLISECONDS);
std::sort(expected.begin(), expected.end());
port.atSteadyTime(expected.back() + MILLISECONDS).wait(waitScope);
timer.atTime(expected.back() + MILLISECONDS).wait(waitScope);
ASSERT_EQ(expected.size(), actual.size());
for (int i = 0; i < expected.size(); ++i) {
......@@ -574,7 +576,7 @@ TEST(AsyncUnixTest, Wake) {
EXPECT_TRUE(port.wait());
{
auto promise = port.atSteadyTime(port.steadyTime());
auto promise = port.getTimer().atTime(port.getTimer().now());
EXPECT_FALSE(port.wait());
}
......
......@@ -28,7 +28,6 @@
#include <errno.h>
#include <inttypes.h>
#include <limits>
#include <set>
#include <chrono>
#include <pthread.h>
......@@ -46,64 +45,11 @@ namespace kj {
// =======================================================================================
// Timer code common to multiple implementations
struct UnixEventPort::TimerSet {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class UnixEventPort::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, UnixEventPort& port, TimePoint time)
: time(time), fulfiller(fulfiller), port(port) {
pos = port.timers->timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != port.timers->timers.end()) {
port.timers->timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
port.timers->timers.erase(pos);
pos = port.timers->timers.end();
}
const TimePoint time;
PromiseFulfiller<void>& fulfiller;
UnixEventPort& port;
TimerSet::Timers::const_iterator pos;
};
bool UnixEventPort::TimerSet::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> UnixEventPort::atSteadyTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*this, time);
}
TimePoint UnixEventPort::currentSteadyTime() {
TimePoint UnixEventPort::readClock() {
return origin<TimePoint>() + std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS;
}
void UnixEventPort::processTimers() {
frozenSteadyTime = currentSteadyTime();
for (;;) {
auto front = timers->timers.begin();
if (front == timers->timers.end() || (*front)->time > frozenSteadyTime) {
break;
}
(*front)->fulfill();
}
}
// =======================================================================================
// Signal code common to multiple implementations
......@@ -249,8 +195,7 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
// epoll FdObserver implementation
UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()),
: timerImpl(readClock()),
epollFd(-1),
signalFd(-1),
eventFd(-1) {
......@@ -360,27 +305,10 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
}
bool UnixEventPort::wait() {
// epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below.
constexpr Duration MAX_TIMEOUT =
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int epollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
epollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
epollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
epollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
return doEpollWait(epollTimeout);
return doEpollWait(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
.map([](uint64_t t) -> int { return t; })
.orDefault(-1));
}
bool UnixEventPort::poll() {
......@@ -554,7 +482,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
}
}
processTimers();
timerImpl.advanceTo(readClock());
return woken;
}
......@@ -568,8 +496,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
#endif
UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()) {
: timerImpl(readClock()) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
*reinterpret_cast<pthread_t*>(&threadId) = pthread_self();
......@@ -771,33 +698,17 @@ bool UnixEventPort::wait() {
threadCapture = &capture;
sigprocmask(SIG_UNBLOCK, &newMask, &origMask);
// poll()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below.
constexpr Duration MAX_TIMEOUT =
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int pollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
pollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
pollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
pollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
pollContext.run(pollTimeout);
pollContext.run(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
.map([](uint64_t t) -> int { return t; })
.orDefault(-1));
sigprocmask(SIG_SETMASK, &origMask, nullptr);
threadCapture = nullptr;
// Queue events.
pollContext.processResults();
processTimers();
timerImpl.advanceTo(readClock());
return false;
}
......@@ -859,7 +770,7 @@ bool UnixEventPort::poll() {
pollContext.run(0);
pollContext.processResults();
}
processTimers();
timerImpl.advanceTo(readClock());
return woken;
}
......
......@@ -97,8 +97,7 @@ public:
// needs to use SIGUSR1, call this at startup (before any calls to `captureSignal()` and before
// constructing an `UnixEventPort`) to offer a different signal.
TimePoint steadyTime() { return frozenSteadyTime; }
Promise<void> atSteadyTime(TimePoint time);
Timer& getTimer() { return timerImpl; }
// implements EventPort ------------------------------------------------------
bool wait() override;
......@@ -110,14 +109,12 @@ private:
class TimerPromiseAdapter;
class SignalPromiseAdapter;
Own<TimerSet> timers;
TimePoint frozenSteadyTime;
TimerImpl timerImpl;
SignalPromiseAdapter* signalHead = nullptr;
SignalPromiseAdapter** signalTail = &signalHead;
TimePoint currentSteadyTime();
void processTimers();
TimePoint readClock();
void gotSignal(const siginfo_t& siginfo);
friend class TimerPromiseAdapter;
......
......@@ -22,6 +22,7 @@
#include "time.h"
#include "debug.h"
#include <set>
namespace kj {
......@@ -29,4 +30,96 @@ kj::Exception Timer::makeTimeoutException() {
return KJ_EXCEPTION(OVERLOADED, "operation timed out");
}
struct TimerImpl::Impl {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class TimerImpl::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, TimerImpl::Impl& impl, TimePoint time)
: time(time), fulfiller(fulfiller), impl(impl) {
pos = impl.timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != impl.timers.end()) {
impl.timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
impl.timers.erase(pos);
pos = impl.timers.end();
}
const TimePoint time;
private:
PromiseFulfiller<void>& fulfiller;
TimerImpl::Impl& impl;
Impl::Timers::const_iterator pos;
};
inline bool TimerImpl::Impl::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> TimerImpl::atTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time);
}
Promise<void> TimerImpl::afterDelay(Duration delay) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time + delay);
}
TimerImpl::TimerImpl(TimePoint startTime)
: time(startTime), impl(heap<Impl>()) {}
TimerImpl::~TimerImpl() noexcept(false) {}
Maybe<TimePoint> TimerImpl::nextEvent() {
auto iter = impl->timers.begin();
if (iter == impl->timers.end()) {
return nullptr;
} else {
return (*iter)->time;
}
}
Maybe<uint64_t> TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max) {
return nextEvent().map([&](TimePoint nextTime) -> uint64_t {
if (nextTime <= start) return 0;
Duration timeout = nextTime - start;
uint64_t result = timeout / unit;
bool roundUp = timeout % unit > 0 * SECONDS;
if (result >= max) {
return max;
} else {
return result + roundUp;
}
});
}
void TimerImpl::advanceTo(TimePoint newTime) {
KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; }
time = newTime;
for (;;) {
auto front = impl->timers.begin();
if (front == impl->timers.end() || (*front)->time > time) {
break;
}
(*front)->fulfill();
}
}
} // namespace kj
......@@ -97,6 +97,49 @@ private:
static kj::Exception makeTimeoutException();
};
class TimerImpl final: public Timer {
// Implementation of Timer that expects an external caller -- usually, the EventPort
// implementation -- to tell it when time has advanced.
public:
TimerImpl(TimePoint startTime);
~TimerImpl() noexcept(false);
Maybe<TimePoint> nextEvent();
// Returns the time at which the next scheduled timer event will occur, or null if no timer
// events are scheduled.
Maybe<uint64_t> timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max);
// Convenience method which computes a timeout value to pass to an event-waiting system call to
// cause it to time out when the next timer event occurs.
//
// `start` is the time at which the timeout starts counting. This is typically not the same as
// now() since some time may have passed since the last time advanceTo() was called.
//
// `unit` is the time unit in which the timeout is measured. This is often MILLISECONDS. Note
// that this method will fractional values *up*, to guarantee that the returned timeout waits
// until just *after* the time the event is scheduled.
//
// The timeout will be clamped to `max`. Use this to avoid an overflow if e.g. the OS wants a
// 32-bit value or a signed value.
//
// Returns nullptr if there are no future events.
void advanceTo(TimePoint newTime);
// Set the time to `time` and fire any at() events that have been passed.
// implements Timer ----------------------------------------------------------
TimePoint now() override;
Promise<void> atTime(TimePoint time) override;
Promise<void> afterDelay(Duration delay) override;
private:
struct Impl;
class TimerPromiseAdapter;
TimePoint time;
Own<Impl> impl;
};
// =======================================================================================
// inline implementation details
......@@ -114,6 +157,8 @@ Promise<T> Timer::timeoutAfter(Duration delay, Promise<T>&& promise) {
}));
}
inline TimePoint TimerImpl::now() { return time; }
} // namespace kj
#endif // KJ_TIME_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment