Commit 517e1e4d authored by Kenton Varda's avatar Kenton Varda

Factor out Timer implementation from async-unix.c++ into a reusable class.

parent 2ebe56c4
...@@ -893,28 +893,10 @@ public: ...@@ -893,28 +893,10 @@ public:
UnixEventPort::FdObserver observer; UnixEventPort::FdObserver observer;
}; };
class TimerImpl final: public Timer {
public:
TimerImpl(UnixEventPort& eventPort): eventPort(eventPort) {}
TimePoint now() override { return eventPort.steadyTime(); }
Promise<void> atTime(TimePoint time) override {
return eventPort.atSteadyTime(time);
}
Promise<void> afterDelay(Duration delay) override {
return eventPort.atSteadyTime(eventPort.steadyTime() + delay);
}
private:
UnixEventPort& eventPort;
};
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public: public:
LowLevelAsyncIoProviderImpl() LowLevelAsyncIoProviderImpl()
: eventLoop(eventPort), timer(eventPort), waitScope(eventLoop) {} : eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; } inline WaitScope& getWaitScope() { return waitScope; }
...@@ -948,14 +930,13 @@ public: ...@@ -948,14 +930,13 @@ public:
return heap<DatagramPortImpl>(*this, eventPort, fd, flags); return heap<DatagramPortImpl>(*this, eventPort, fd, flags);
} }
Timer& getTimer() override { return timer; } Timer& getTimer() override { return eventPort.getTimer(); }
UnixEventPort& getEventPort() { return eventPort; } UnixEventPort& getEventPort() { return eventPort; }
private: private:
UnixEventPort eventPort; UnixEventPort eventPort;
EventLoop eventLoop; EventLoop eventLoop;
TimerImpl timer;
WaitScope waitScope; WaitScope waitScope;
}; };
......
...@@ -532,14 +532,16 @@ TEST(AsyncUnixTest, SteadyTimers) { ...@@ -532,14 +532,16 @@ TEST(AsyncUnixTest, SteadyTimers) {
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop); WaitScope waitScope(loop);
auto start = port.steadyTime(); auto& timer = port.getTimer();
auto start = timer.now();
kj::Vector<TimePoint> expected; kj::Vector<TimePoint> expected;
kj::Vector<TimePoint> actual; kj::Vector<TimePoint> actual;
auto addTimer = [&](Duration delay) { auto addTimer = [&](Duration delay) {
expected.add(max(start + delay, start)); expected.add(max(start + delay, start));
port.atSteadyTime(start + delay).then([&]() { timer.atTime(start + delay).then([&]() {
actual.add(port.steadyTime()); actual.add(timer.now());
}).detach([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); }); }).detach([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
}; };
...@@ -550,7 +552,7 @@ TEST(AsyncUnixTest, SteadyTimers) { ...@@ -550,7 +552,7 @@ TEST(AsyncUnixTest, SteadyTimers) {
addTimer(-10 * MILLISECONDS); addTimer(-10 * MILLISECONDS);
std::sort(expected.begin(), expected.end()); std::sort(expected.begin(), expected.end());
port.atSteadyTime(expected.back() + MILLISECONDS).wait(waitScope); timer.atTime(expected.back() + MILLISECONDS).wait(waitScope);
ASSERT_EQ(expected.size(), actual.size()); ASSERT_EQ(expected.size(), actual.size());
for (int i = 0; i < expected.size(); ++i) { for (int i = 0; i < expected.size(); ++i) {
...@@ -574,7 +576,7 @@ TEST(AsyncUnixTest, Wake) { ...@@ -574,7 +576,7 @@ TEST(AsyncUnixTest, Wake) {
EXPECT_TRUE(port.wait()); EXPECT_TRUE(port.wait());
{ {
auto promise = port.atSteadyTime(port.steadyTime()); auto promise = port.getTimer().atTime(port.getTimer().now());
EXPECT_FALSE(port.wait()); EXPECT_FALSE(port.wait());
} }
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <errno.h> #include <errno.h>
#include <inttypes.h> #include <inttypes.h>
#include <limits> #include <limits>
#include <set>
#include <chrono> #include <chrono>
#include <pthread.h> #include <pthread.h>
...@@ -46,64 +45,11 @@ namespace kj { ...@@ -46,64 +45,11 @@ namespace kj {
// ======================================================================================= // =======================================================================================
// Timer code common to multiple implementations // Timer code common to multiple implementations
struct UnixEventPort::TimerSet { TimePoint UnixEventPort::readClock() {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class UnixEventPort::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, UnixEventPort& port, TimePoint time)
: time(time), fulfiller(fulfiller), port(port) {
pos = port.timers->timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != port.timers->timers.end()) {
port.timers->timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
port.timers->timers.erase(pos);
pos = port.timers->timers.end();
}
const TimePoint time;
PromiseFulfiller<void>& fulfiller;
UnixEventPort& port;
TimerSet::Timers::const_iterator pos;
};
bool UnixEventPort::TimerSet::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> UnixEventPort::atSteadyTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*this, time);
}
TimePoint UnixEventPort::currentSteadyTime() {
return origin<TimePoint>() + std::chrono::duration_cast<std::chrono::nanoseconds>( return origin<TimePoint>() + std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS; std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS;
} }
void UnixEventPort::processTimers() {
frozenSteadyTime = currentSteadyTime();
for (;;) {
auto front = timers->timers.begin();
if (front == timers->timers.end() || (*front)->time > frozenSteadyTime) {
break;
}
(*front)->fulfill();
}
}
// ======================================================================================= // =======================================================================================
// Signal code common to multiple implementations // Signal code common to multiple implementations
...@@ -249,8 +195,7 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) { ...@@ -249,8 +195,7 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
// epoll FdObserver implementation // epoll FdObserver implementation
UnixEventPort::UnixEventPort() UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()), : timerImpl(readClock()),
frozenSteadyTime(currentSteadyTime()),
epollFd(-1), epollFd(-1),
signalFd(-1), signalFd(-1),
eventFd(-1) { eventFd(-1) {
...@@ -360,27 +305,10 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() { ...@@ -360,27 +305,10 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
} }
bool UnixEventPort::wait() { bool UnixEventPort::wait() {
// epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that. return doEpollWait(
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
// will break the math below. .map([](uint64_t t) -> int { return t; })
constexpr Duration MAX_TIMEOUT = .orDefault(-1));
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int epollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
epollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
epollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
epollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
return doEpollWait(epollTimeout);
} }
bool UnixEventPort::poll() { bool UnixEventPort::poll() {
...@@ -554,7 +482,7 @@ bool UnixEventPort::doEpollWait(int timeout) { ...@@ -554,7 +482,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
} }
} }
processTimers(); timerImpl.advanceTo(readClock());
return woken; return woken;
} }
...@@ -568,8 +496,7 @@ bool UnixEventPort::doEpollWait(int timeout) { ...@@ -568,8 +496,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
#endif #endif
UnixEventPort::UnixEventPort() UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()), : timerImpl(readClock()) {
frozenSteadyTime(currentSteadyTime()) {
static_assert(sizeof(threadId) >= sizeof(pthread_t), static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port."); "pthread_t is larger than a long long on your platform. Please port.");
*reinterpret_cast<pthread_t*>(&threadId) = pthread_self(); *reinterpret_cast<pthread_t*>(&threadId) = pthread_self();
...@@ -771,33 +698,17 @@ bool UnixEventPort::wait() { ...@@ -771,33 +698,17 @@ bool UnixEventPort::wait() {
threadCapture = &capture; threadCapture = &capture;
sigprocmask(SIG_UNBLOCK, &newMask, &origMask); sigprocmask(SIG_UNBLOCK, &newMask, &origMask);
// poll()'s timeout is an `int` count of milliseconds, so truncate to that. pollContext.run(
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
// will break the math below. .map([](uint64_t t) -> int { return t; })
constexpr Duration MAX_TIMEOUT = .orDefault(-1));
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int pollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
pollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
pollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
pollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
pollContext.run(pollTimeout);
sigprocmask(SIG_SETMASK, &origMask, nullptr); sigprocmask(SIG_SETMASK, &origMask, nullptr);
threadCapture = nullptr; threadCapture = nullptr;
// Queue events. // Queue events.
pollContext.processResults(); pollContext.processResults();
processTimers(); timerImpl.advanceTo(readClock());
return false; return false;
} }
...@@ -859,7 +770,7 @@ bool UnixEventPort::poll() { ...@@ -859,7 +770,7 @@ bool UnixEventPort::poll() {
pollContext.run(0); pollContext.run(0);
pollContext.processResults(); pollContext.processResults();
} }
processTimers(); timerImpl.advanceTo(readClock());
return woken; return woken;
} }
......
...@@ -97,8 +97,7 @@ public: ...@@ -97,8 +97,7 @@ public:
// needs to use SIGUSR1, call this at startup (before any calls to `captureSignal()` and before // needs to use SIGUSR1, call this at startup (before any calls to `captureSignal()` and before
// constructing an `UnixEventPort`) to offer a different signal. // constructing an `UnixEventPort`) to offer a different signal.
TimePoint steadyTime() { return frozenSteadyTime; } Timer& getTimer() { return timerImpl; }
Promise<void> atSteadyTime(TimePoint time);
// implements EventPort ------------------------------------------------------ // implements EventPort ------------------------------------------------------
bool wait() override; bool wait() override;
...@@ -110,14 +109,12 @@ private: ...@@ -110,14 +109,12 @@ private:
class TimerPromiseAdapter; class TimerPromiseAdapter;
class SignalPromiseAdapter; class SignalPromiseAdapter;
Own<TimerSet> timers; TimerImpl timerImpl;
TimePoint frozenSteadyTime;
SignalPromiseAdapter* signalHead = nullptr; SignalPromiseAdapter* signalHead = nullptr;
SignalPromiseAdapter** signalTail = &signalHead; SignalPromiseAdapter** signalTail = &signalHead;
TimePoint currentSteadyTime(); TimePoint readClock();
void processTimers();
void gotSignal(const siginfo_t& siginfo); void gotSignal(const siginfo_t& siginfo);
friend class TimerPromiseAdapter; friend class TimerPromiseAdapter;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "time.h" #include "time.h"
#include "debug.h" #include "debug.h"
#include <set>
namespace kj { namespace kj {
...@@ -29,4 +30,96 @@ kj::Exception Timer::makeTimeoutException() { ...@@ -29,4 +30,96 @@ kj::Exception Timer::makeTimeoutException() {
return KJ_EXCEPTION(OVERLOADED, "operation timed out"); return KJ_EXCEPTION(OVERLOADED, "operation timed out");
} }
struct TimerImpl::Impl {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class TimerImpl::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, TimerImpl::Impl& impl, TimePoint time)
: time(time), fulfiller(fulfiller), impl(impl) {
pos = impl.timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != impl.timers.end()) {
impl.timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
impl.timers.erase(pos);
pos = impl.timers.end();
}
const TimePoint time;
private:
PromiseFulfiller<void>& fulfiller;
TimerImpl::Impl& impl;
Impl::Timers::const_iterator pos;
};
inline bool TimerImpl::Impl::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> TimerImpl::atTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time);
}
Promise<void> TimerImpl::afterDelay(Duration delay) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time + delay);
}
TimerImpl::TimerImpl(TimePoint startTime)
: time(startTime), impl(heap<Impl>()) {}
TimerImpl::~TimerImpl() noexcept(false) {}
Maybe<TimePoint> TimerImpl::nextEvent() {
auto iter = impl->timers.begin();
if (iter == impl->timers.end()) {
return nullptr;
} else {
return (*iter)->time;
}
}
Maybe<uint64_t> TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max) {
return nextEvent().map([&](TimePoint nextTime) -> uint64_t {
if (nextTime <= start) return 0;
Duration timeout = nextTime - start;
uint64_t result = timeout / unit;
bool roundUp = timeout % unit > 0 * SECONDS;
if (result >= max) {
return max;
} else {
return result + roundUp;
}
});
}
void TimerImpl::advanceTo(TimePoint newTime) {
KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; }
time = newTime;
for (;;) {
auto front = impl->timers.begin();
if (front == impl->timers.end() || (*front)->time > time) {
break;
}
(*front)->fulfill();
}
}
} // namespace kj } // namespace kj
...@@ -97,6 +97,49 @@ private: ...@@ -97,6 +97,49 @@ private:
static kj::Exception makeTimeoutException(); static kj::Exception makeTimeoutException();
}; };
class TimerImpl final: public Timer {
// Implementation of Timer that expects an external caller -- usually, the EventPort
// implementation -- to tell it when time has advanced.
public:
TimerImpl(TimePoint startTime);
~TimerImpl() noexcept(false);
Maybe<TimePoint> nextEvent();
// Returns the time at which the next scheduled timer event will occur, or null if no timer
// events are scheduled.
Maybe<uint64_t> timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max);
// Convenience method which computes a timeout value to pass to an event-waiting system call to
// cause it to time out when the next timer event occurs.
//
// `start` is the time at which the timeout starts counting. This is typically not the same as
// now() since some time may have passed since the last time advanceTo() was called.
//
// `unit` is the time unit in which the timeout is measured. This is often MILLISECONDS. Note
// that this method will fractional values *up*, to guarantee that the returned timeout waits
// until just *after* the time the event is scheduled.
//
// The timeout will be clamped to `max`. Use this to avoid an overflow if e.g. the OS wants a
// 32-bit value or a signed value.
//
// Returns nullptr if there are no future events.
void advanceTo(TimePoint newTime);
// Set the time to `time` and fire any at() events that have been passed.
// implements Timer ----------------------------------------------------------
TimePoint now() override;
Promise<void> atTime(TimePoint time) override;
Promise<void> afterDelay(Duration delay) override;
private:
struct Impl;
class TimerPromiseAdapter;
TimePoint time;
Own<Impl> impl;
};
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
...@@ -114,6 +157,8 @@ Promise<T> Timer::timeoutAfter(Duration delay, Promise<T>&& promise) { ...@@ -114,6 +157,8 @@ Promise<T> Timer::timeoutAfter(Duration delay, Promise<T>&& promise) {
})); }));
} }
inline TimePoint TimerImpl::now() { return time; }
} // namespace kj } // namespace kj
#endif // KJ_TIME_H_ #endif // KJ_TIME_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment