Commit 851f51e5 authored by Kenton Varda's avatar Kenton Varda

Overhaul the way EventLoop is specialized so that it's possible to hook up to an…

Overhaul the way EventLoop is specialized so that it's possible to hook up to an existing event loop infrastructure that is not KJ-aware.  This also makes the async IO API more dependency-injection-friendly.
parent 296f9af1
...@@ -51,7 +51,7 @@ namespace { ...@@ -51,7 +51,7 @@ namespace {
#endif #endif
TEST(Capability, Basic) { TEST(Capability, Basic) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
test::TestInterface::Client client(kj::heap<TestInterfaceImpl>(callCount)); test::TestInterface::Client client(kj::heap<TestInterfaceImpl>(callCount));
...@@ -89,7 +89,7 @@ TEST(Capability, Basic) { ...@@ -89,7 +89,7 @@ TEST(Capability, Basic) {
} }
TEST(Capability, Inheritance) { TEST(Capability, Inheritance) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
test::TestExtends::Client client1(kj::heap<TestExtendsImpl>(callCount)); test::TestExtends::Client client1(kj::heap<TestExtendsImpl>(callCount));
...@@ -117,7 +117,7 @@ TEST(Capability, Inheritance) { ...@@ -117,7 +117,7 @@ TEST(Capability, Inheritance) {
} }
TEST(Capability, Pipelining) { TEST(Capability, Pipelining) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
int chainedCallCount = 0; int chainedCallCount = 0;
...@@ -152,7 +152,7 @@ TEST(Capability, Pipelining) { ...@@ -152,7 +152,7 @@ TEST(Capability, Pipelining) {
} }
TEST(Capability, TailCall) { TEST(Capability, TailCall) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int calleeCallCount = 0; int calleeCallCount = 0;
int callerCallCount = 0; int callerCallCount = 0;
...@@ -187,7 +187,7 @@ TEST(Capability, TailCall) { ...@@ -187,7 +187,7 @@ TEST(Capability, TailCall) {
TEST(Capability, AsyncCancelation) { TEST(Capability, AsyncCancelation) {
// Tests allowAsyncCancellation(). // Tests allowAsyncCancellation().
kj::SimpleEventLoop loop; kj::EventLoop loop;
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false; bool destroyed = false;
...@@ -227,7 +227,7 @@ TEST(Capability, AsyncCancelation) { ...@@ -227,7 +227,7 @@ TEST(Capability, AsyncCancelation) {
TEST(Capability, SyncCancelation) { TEST(Capability, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation(). // Tests isCanceled() without allowAsyncCancellation().
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
int innerCallCount = 0; int innerCallCount = 0;
...@@ -269,7 +269,7 @@ TEST(Capability, SyncCancelation) { ...@@ -269,7 +269,7 @@ TEST(Capability, SyncCancelation) {
// ======================================================================================= // =======================================================================================
TEST(Capability, DynamicClient) { TEST(Capability, DynamicClient) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
DynamicCapability::Client client = DynamicCapability::Client client =
...@@ -308,7 +308,7 @@ TEST(Capability, DynamicClient) { ...@@ -308,7 +308,7 @@ TEST(Capability, DynamicClient) {
} }
TEST(Capability, DynamicClientInheritance) { TEST(Capability, DynamicClientInheritance) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
...@@ -344,7 +344,7 @@ TEST(Capability, DynamicClientInheritance) { ...@@ -344,7 +344,7 @@ TEST(Capability, DynamicClientInheritance) {
} }
TEST(Capability, DynamicClientPipelining) { TEST(Capability, DynamicClientPipelining) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
int chainedCallCount = 0; int chainedCallCount = 0;
...@@ -415,7 +415,7 @@ public: ...@@ -415,7 +415,7 @@ public:
}; };
TEST(Capability, DynamicServer) { TEST(Capability, DynamicServer) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
test::TestInterface::Client client = test::TestInterface::Client client =
...@@ -484,7 +484,7 @@ public: ...@@ -484,7 +484,7 @@ public:
}; };
TEST(Capability, DynamicServerInheritance) { TEST(Capability, DynamicServerInheritance) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
test::TestExtends::Client client1 = test::TestExtends::Client client1 =
...@@ -558,7 +558,7 @@ public: ...@@ -558,7 +558,7 @@ public:
}; };
TEST(Capability, DynamicServerPipelining) { TEST(Capability, DynamicServerPipelining) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount = 0; int callCount = 0;
int chainedCallCount = 0; int chainedCallCount = 0;
...@@ -615,7 +615,7 @@ void verifyClient(DynamicCapability::Client client, const int& callCount) { ...@@ -615,7 +615,7 @@ void verifyClient(DynamicCapability::Client client, const int& callCount) {
} }
TEST(Capability, ObjectsAndOrphans) { TEST(Capability, ObjectsAndOrphans) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount1 = 0; int callCount1 = 0;
int callCount2 = 0; int callCount2 = 0;
...@@ -678,7 +678,7 @@ TEST(Capability, ObjectsAndOrphans) { ...@@ -678,7 +678,7 @@ TEST(Capability, ObjectsAndOrphans) {
} }
TEST(Capability, Lists) { TEST(Capability, Lists) {
kj::SimpleEventLoop loop; kj::EventLoop loop;
int callCount1 = 0; int callCount1 = 0;
int callCount2 = 0; int callCount2 = 0;
......
...@@ -430,7 +430,7 @@ public: ...@@ -430,7 +430,7 @@ public:
}; };
struct TestContext { struct TestContext {
kj::SimpleEventLoop loop; kj::EventLoop loop;
TestNetwork network; TestNetwork network;
TestRestorer restorer; TestRestorer restorer;
TestNetworkAdapter& clientNetwork; TestNetworkAdapter& clientNetwork;
......
...@@ -59,14 +59,14 @@ private: ...@@ -59,14 +59,14 @@ private:
int& callCount; int& callCount;
}; };
void runServer(kj::Own<kj::AsyncIoStream> stream, int& callCount) { kj::Own<kj::AsyncIoStream> runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
// Set up the server. return ioProvider.newPipeThread(
kj::UnixEventLoop eventLoop; [&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) {
TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::SERVER); TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
TestRestorer restorer(callCount); TestRestorer restorer(callCount);
auto server = makeRpcServer(network, restorer); auto server = makeRpcServer(network, restorer);
network.onDisconnect().wait();
eventLoop.onSignal(SIGUSR2).wait(); });
} }
Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& client, Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& client,
...@@ -85,29 +85,12 @@ Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& c ...@@ -85,29 +85,12 @@ Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& c
return client.restore(hostId, objectIdMessage.getRoot<ObjectPointer>()); return client.restore(hostId, objectIdMessage.getRoot<ObjectPointer>());
} }
class CaptureSignalsOnInit {
public:
CaptureSignalsOnInit() {
kj::UnixEventLoop::captureSignal(SIGUSR2);
}
};
static CaptureSignalsOnInit captureSignalsOnInit;
TEST(TwoPartyNetwork, Basic) { TEST(TwoPartyNetwork, Basic) {
auto ioProvider = kj::setupIoEventLoop();
int callCount = 0; int callCount = 0;
// We'll communicate over this two-way pipe (actually, a socketpair). auto stream = runServer(*ioProvider, callCount);
auto pipe = kj::newTwoWayPipe(); TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
// Start up server in another thread.
kj::Thread thread([&]() {
runServer(kj::mv(pipe.ends[1]), callCount);
});
KJ_DEFER(thread.sendSignal(SIGUSR2));
// Set up the client-side objects.
kj::UnixEventLoop loop;
TwoPartyVatNetwork network(*pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network); auto rpcClient = makeRpcClient(network);
// Request the particular capability from the server. // Request the particular capability from the server.
...@@ -148,21 +131,12 @@ TEST(TwoPartyNetwork, Basic) { ...@@ -148,21 +131,12 @@ TEST(TwoPartyNetwork, Basic) {
} }
TEST(TwoPartyNetwork, Pipelining) { TEST(TwoPartyNetwork, Pipelining) {
auto ioProvider = kj::setupIoEventLoop();
int callCount = 0; int callCount = 0;
int reverseCallCount = 0; // Calls back from server to client. int reverseCallCount = 0; // Calls back from server to client.
// We'll communicate over this two-way pipe (actually, a socketpair). auto stream = runServer(*ioProvider, callCount);
auto pipe = kj::newTwoWayPipe(); TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
// Start up server in another thread.
auto thread = kj::heap<kj::Thread>([&]() {
runServer(kj::mv(pipe.ends[1]), callCount);
});
KJ_ON_SCOPE_FAILURE(thread->sendSignal(SIGUSR2));
// Set up the client-side objects.
kj::UnixEventLoop loop;
TwoPartyVatNetwork network(*pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network); auto rpcClient = makeRpcClient(network);
bool disconnected = false; bool disconnected = false;
...@@ -209,10 +183,10 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -209,10 +183,10 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_FALSE(disconnected); EXPECT_FALSE(disconnected);
EXPECT_FALSE(drained); EXPECT_FALSE(drained);
// What if the other side disconnects? // What if we disconnect?
thread->sendSignal(SIGUSR2); stream->shutdownWrite();
thread = nullptr;
// The other side should also disconnect.
disconnectPromise.wait(); disconnectPromise.wait();
EXPECT_FALSE(drained); EXPECT_FALSE(drained);
......
...@@ -269,7 +269,7 @@ RpcSystem<SturdyRefHostId> makeRpcServer( ...@@ -269,7 +269,7 @@ RpcSystem<SturdyRefHostId> makeRpcServer(
// MyNetwork network; // MyNetwork network;
// MyRestorer restorer; // MyRestorer restorer;
// auto server = makeRpcServer(network, restorer); // auto server = makeRpcServer(network, restorer);
// eventLoop.wait(...); // (e.g. wait on a promise that never returns) // kj::NEVER_DONE.wait(); // run forever
template <typename SturdyRefHostId, typename ProvisionId, template <typename SturdyRefHostId, typename ProvisionId,
typename RecipientId, typename ThirdPartyCapId, typename JoinResult> typename RecipientId, typename ThirdPartyCapId, typename JoinResult>
......
...@@ -120,7 +120,8 @@ protected: ...@@ -120,7 +120,8 @@ protected:
}; };
TEST_F(SerializeAsyncTest, ParseAsync) { TEST_F(SerializeAsyncTest, ParseAsync) {
auto input = kj::AsyncInputStream::wrapFd(fds[0]); auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]); kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput); FragmentingOutputStream output(rawOutput);
...@@ -131,15 +132,14 @@ TEST_F(SerializeAsyncTest, ParseAsync) { ...@@ -131,15 +132,14 @@ TEST_F(SerializeAsyncTest, ParseAsync) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = kj::runIoEventLoop([&]() { auto received = readMessage(*input).wait();
return readMessage(*input);
});
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) { TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
auto input = kj::AsyncInputStream::wrapFd(fds[0]); auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]); kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput); FragmentingOutputStream output(rawOutput);
...@@ -150,15 +150,14 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) { ...@@ -150,15 +150,14 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = kj::runIoEventLoop([&]() { auto received = readMessage(*input).wait();
return readMessage(*input);
});
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) { TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
auto input = kj::AsyncInputStream::wrapFd(fds[0]); auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]); kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput); FragmentingOutputStream output(rawOutput);
...@@ -169,15 +168,14 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) { ...@@ -169,15 +168,14 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = kj::runIoEventLoop([&]() { auto received = readMessage(*input).wait();
return readMessage(*input);
});
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
TEST_F(SerializeAsyncTest, WriteAsync) { TEST_F(SerializeAsyncTest, WriteAsync) {
auto output = kj::AsyncOutputStream::wrapFd(fds[1]); auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(1); TestMessageBuilder message(1);
auto root = message.getRoot<TestAllTypes>(); auto root = message.getRoot<TestAllTypes>();
...@@ -195,13 +193,12 @@ TEST_F(SerializeAsyncTest, WriteAsync) { ...@@ -195,13 +193,12 @@ TEST_F(SerializeAsyncTest, WriteAsync) {
} }
}); });
kj::runIoEventLoop([&]() { writeMessage(*output, message).wait();
return writeMessage(*output, message);
});
} }
TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) { TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
auto output = kj::AsyncOutputStream::wrapFd(fds[1]); auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(7); TestMessageBuilder message(7);
auto root = message.getRoot<TestAllTypes>(); auto root = message.getRoot<TestAllTypes>();
...@@ -219,13 +216,12 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) { ...@@ -219,13 +216,12 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
} }
}); });
kj::runIoEventLoop([&]() { writeMessage(*output, message).wait();
return writeMessage(*output, message);
});
} }
TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) { TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
auto output = kj::AsyncOutputStream::wrapFd(fds[1]); auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(10); TestMessageBuilder message(10);
auto root = message.getRoot<TestAllTypes>(); auto root = message.getRoot<TestAllTypes>();
...@@ -243,9 +239,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) { ...@@ -243,9 +239,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
} }
}); });
kj::runIoEventLoop([&]() { writeMessage(*output, message).wait();
return writeMessage(*output, message);
});
} }
} // namespace } // namespace
......
...@@ -567,43 +567,6 @@ private: ...@@ -567,43 +567,6 @@ private:
// ======================================================================================= // =======================================================================================
template <typename T>
T EventLoop::wait(Promise<T>&& promise) {
_::ExceptionOr<_::FixVoid<T>> result;
waitImpl(kj::mv(promise.node), result);
KJ_IF_MAYBE(value, result.value) {
KJ_IF_MAYBE(exception, result.exception) {
throwRecoverableException(kj::mv(*exception));
}
return _::returnMaybeVoid(kj::mv(*value));
} else KJ_IF_MAYBE(exception, result.exception) {
throwFatalException(kj::mv(*exception));
} else {
// Result contained neither a value nor an exception?
KJ_UNREACHABLE;
}
}
template <typename Func>
PromiseForResult<Func, void> EventLoop::evalLater(Func&& func) {
// Invoke thenImpl() on yield().
return PromiseForResult<Func, void>(false,
thenImpl(yield(), kj::fwd<Func>(func), _::PropagateException()));
}
template <typename T, typename Func, typename ErrorFunc>
Own<_::PromiseNode> EventLoop::thenImpl(Promise<T>&& promise, Func&& func,
ErrorFunc&& errorHandler) {
typedef _::FixVoid<_::ReturnType<Func, T>> ResultT;
Own<_::PromiseNode> intermediate =
heap<_::TransformPromiseNode<ResultT, _::FixVoid<T>, Func, ErrorFunc>>(
kj::mv(promise.node), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler));
return _::maybeChain(kj::mv(intermediate), implicitCast<ResultT*>(nullptr));
}
template <typename T> template <typename T>
Promise<T>::Promise(_::FixVoid<T> value) Promise<T>::Promise(_::FixVoid<T> value)
: PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid<T>>>(kj::mv(value))) {} : PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid<T>>>(kj::mv(value))) {}
...@@ -615,13 +578,32 @@ Promise<T>::Promise(kj::Exception&& exception) ...@@ -615,13 +578,32 @@ Promise<T>::Promise(kj::Exception&& exception)
template <typename T> template <typename T>
template <typename Func, typename ErrorFunc> template <typename Func, typename ErrorFunc>
PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) { PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) {
return PromiseForResult<Func, T>(false, EventLoop::current().thenImpl( typedef _::FixVoid<_::ReturnType<Func, T>> ResultT;
kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler)));
Own<_::PromiseNode> intermediate =
heap<_::TransformPromiseNode<ResultT, _::FixVoid<T>, Func, ErrorFunc>>(
kj::mv(node), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler));
return PromiseForResult<Func, T>(false,
_::maybeChain(kj::mv(intermediate), implicitCast<ResultT*>(nullptr)));
} }
template <typename T> template <typename T>
T Promise<T>::wait() { T Promise<T>::wait() {
return EventLoop::current().wait(kj::mv(*this)); _::ExceptionOr<_::FixVoid<T>> result;
waitImpl(kj::mv(node), result);
KJ_IF_MAYBE(value, result.value) {
KJ_IF_MAYBE(exception, result.exception) {
throwRecoverableException(kj::mv(*exception));
}
return _::returnMaybeVoid(kj::mv(*value));
} else KJ_IF_MAYBE(exception, result.exception) {
throwFatalException(kj::mv(*exception));
} else {
// Result contained neither a value nor an exception?
KJ_UNREACHABLE;
}
} }
template <typename T> template <typename T>
...@@ -658,19 +640,19 @@ kj::String Promise<T>::trace() { ...@@ -658,19 +640,19 @@ kj::String Promise<T>::trace() {
template <typename Func> template <typename Func>
inline PromiseForResult<Func, void> evalLater(Func&& func) { inline PromiseForResult<Func, void> evalLater(Func&& func) {
return EventLoop::current().evalLater(kj::fwd<Func>(func)); return _::yield().then(kj::fwd<Func>(func), _::PropagateException());
} }
template <typename T> template <typename T>
template <typename ErrorFunc> template <typename ErrorFunc>
void Promise<T>::daemonize(ErrorFunc&& errorHandler) { void Promise<T>::daemonize(ErrorFunc&& errorHandler) {
return EventLoop::current().daemonize(then([](T&&) {}, kj::fwd<ErrorFunc>(errorHandler))); return _::daemonize(then([](T&&) {}, kj::fwd<ErrorFunc>(errorHandler)));
} }
template <> template <>
template <typename ErrorFunc> template <typename ErrorFunc>
void Promise<void>::daemonize(ErrorFunc&& errorHandler) { void Promise<void>::daemonize(ErrorFunc&& errorHandler) {
return EventLoop::current().daemonize(then([]() {}, kj::fwd<ErrorFunc>(errorHandler))); return _::daemonize(then([]() {}, kj::fwd<ErrorFunc>(errorHandler)));
} }
// ======================================================================================= // =======================================================================================
......
...@@ -30,8 +30,8 @@ namespace kj { ...@@ -30,8 +30,8 @@ namespace kj {
namespace { namespace {
TEST(AsyncIo, SimpleNetwork) { TEST(AsyncIo, SimpleNetwork) {
UnixEventLoop loop; auto ioProvider = setupIoEventLoop();
auto network = Network::newSystemNetwork(); auto& network = ioProvider->getNetwork();
Own<ConnectionReceiver> listener; Own<ConnectionReceiver> listener;
Own<AsyncIoStream> server; Own<AsyncIoStream> server;
...@@ -42,7 +42,7 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -42,7 +42,7 @@ TEST(AsyncIo, SimpleNetwork) {
auto port = newPromiseAndFulfiller<uint>(); auto port = newPromiseAndFulfiller<uint>();
port.promise.then([&](uint portnum) { port.promise.then([&](uint portnum) {
return network->parseRemoteAddress("127.0.0.1", portnum); return network.parseRemoteAddress("127.0.0.1", portnum);
}).then([&](Own<RemoteAddress>&& result) { }).then([&](Own<RemoteAddress>&& result) {
return result->connect(); return result->connect();
}).then([&](Own<AsyncIoStream>&& result) { }).then([&](Own<AsyncIoStream>&& result) {
...@@ -52,7 +52,7 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -52,7 +52,7 @@ TEST(AsyncIo, SimpleNetwork) {
ADD_FAILURE() << kj::str(exception).cStr(); ADD_FAILURE() << kj::str(exception).cStr();
}); });
kj::String result = network->parseLocalAddress("*").then([&](Own<LocalAddress>&& result) { kj::String result = network.parseLocalAddress("*").then([&](Own<LocalAddress>&& result) {
listener = result->listen(); listener = result->listen();
port.fulfiller->fulfill(listener->getPort()); port.fulfiller->fulfill(listener->getPort());
return listener->accept(); return listener->accept();
...@@ -67,34 +67,34 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -67,34 +67,34 @@ TEST(AsyncIo, SimpleNetwork) {
EXPECT_EQ("foo", result); EXPECT_EQ("foo", result);
} }
String tryParseLocal(EventLoop& loop, Network& network, StringPtr text, uint portHint = 0) { String tryParseLocal(Network& network, StringPtr text, uint portHint = 0) {
return network.parseLocalAddress(text, portHint).wait()->toString(); return network.parseLocalAddress(text, portHint).wait()->toString();
} }
String tryParseRemote(EventLoop& loop, Network& network, StringPtr text, uint portHint = 0) { String tryParseRemote(Network& network, StringPtr text, uint portHint = 0) {
return network.parseRemoteAddress(text, portHint).wait()->toString(); return network.parseRemoteAddress(text, portHint).wait()->toString();
} }
TEST(AsyncIo, AddressParsing) { TEST(AsyncIo, AddressParsing) {
UnixEventLoop loop; auto ioProvider = setupIoEventLoop();
auto network = Network::newSystemNetwork(); auto& network = ioProvider->getNetwork();
EXPECT_EQ("*:0", tryParseLocal(loop, *network, "*")); EXPECT_EQ("*:0", tryParseLocal(network, "*"));
EXPECT_EQ("*:123", tryParseLocal(loop, *network, "123")); EXPECT_EQ("*:123", tryParseLocal(network, "123"));
EXPECT_EQ("*:123", tryParseLocal(loop, *network, ":123")); EXPECT_EQ("*:123", tryParseLocal(network, ":123"));
EXPECT_EQ("[::]:123", tryParseLocal(loop, *network, "0::0", 123)); EXPECT_EQ("[::]:123", tryParseLocal(network, "0::0", 123));
EXPECT_EQ("0.0.0.0:0", tryParseLocal(loop, *network, "0.0.0.0")); EXPECT_EQ("0.0.0.0:0", tryParseLocal(network, "0.0.0.0"));
EXPECT_EQ("1.2.3.4:5678", tryParseRemote(loop, *network, "1.2.3.4", 5678)); EXPECT_EQ("1.2.3.4:5678", tryParseRemote(network, "1.2.3.4", 5678));
EXPECT_EQ("[12ab:cd::34]:321", tryParseRemote(loop, *network, "[12ab:cd:0::0:34]:321", 432)); EXPECT_EQ("[12ab:cd::34]:321", tryParseRemote(network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("unix:foo/bar/baz", tryParseLocal(loop, *network, "unix:foo/bar/baz")); EXPECT_EQ("unix:foo/bar/baz", tryParseLocal(network, "unix:foo/bar/baz"));
EXPECT_EQ("unix:foo/bar/baz", tryParseRemote(loop, *network, "unix:foo/bar/baz")); EXPECT_EQ("unix:foo/bar/baz", tryParseRemote(network, "unix:foo/bar/baz"));
} }
TEST(AsyncIo, OneWayPipe) { TEST(AsyncIo, OneWayPipe) {
UnixEventLoop loop; auto ioProvider = setupIoEventLoop();
auto pipe = newOneWayPipe(); auto pipe = ioProvider->newOneWayPipe();
char receiveBuffer[4]; char receiveBuffer[4];
pipe.out->write("foo", 3).daemonize([](kj::Exception&& exception) { pipe.out->write("foo", 3).daemonize([](kj::Exception&& exception) {
...@@ -110,9 +110,9 @@ TEST(AsyncIo, OneWayPipe) { ...@@ -110,9 +110,9 @@ TEST(AsyncIo, OneWayPipe) {
} }
TEST(AsyncIo, TwoWayPipe) { TEST(AsyncIo, TwoWayPipe) {
UnixEventLoop loop; auto ioProvider = setupIoEventLoop();
auto pipe = newTwoWayPipe(); auto pipe = ioProvider->newTwoWayPipe();
char receiveBuffer1[4]; char receiveBuffer1[4];
char receiveBuffer2[4]; char receiveBuffer2[4];
...@@ -136,23 +136,45 @@ TEST(AsyncIo, TwoWayPipe) { ...@@ -136,23 +136,45 @@ TEST(AsyncIo, TwoWayPipe) {
EXPECT_EQ("bar", result2); EXPECT_EQ("bar", result2);
} }
TEST(AsyncIo, RunIoEventLoop) { TEST(AsyncIo, PipeThread) {
auto pipe = newOneWayPipe(); auto ioProvider = setupIoEventLoop();
char receiveBuffer[4];
String result = runIoEventLoop([&]() { auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
auto promise1 = pipe.out->write("foo", 3); char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
EXPECT_EQ("bar", heapString(buf, 3));
auto promise2 = pipe.in->tryRead(receiveBuffer, 3, 4) // Expect disconnect.
.then([&](size_t n) { EXPECT_EQ(0, stream.tryRead(buf, 1, 1).wait());
EXPECT_EQ(3u, n);
return heapString(receiveBuffer, n);
}); });
return promise1.then(mvCapture(promise2, [](Promise<String> promise2) { return promise2; })); char buf[4];
stream->write("bar", 3).wait();
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
}
TEST(AsyncIo, PipeThreadDisconnects) {
// Like above, but in this case we expect the main thread to detect the pipe thread disconnecting.
auto ioProvider = setupIoEventLoop();
auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
EXPECT_EQ("bar", heapString(buf, 3));
}); });
EXPECT_EQ("foo", result); char buf[4];
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
stream->write("bar", 3).wait();
// Expect disconnect.
EXPECT_EQ(0, stream->tryRead(buf, 1, 1).wait());
} }
} // namespace } // namespace
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "async-io.h" #include "async-io.h"
#include "async-unix.h" #include "async-unix.h"
#include "debug.h" #include "debug.h"
#include "thread.h"
#include <unistd.h> #include <unistd.h>
#include <sys/uio.h> #include <sys/uio.h>
#include <errno.h> #include <errno.h>
...@@ -45,10 +46,6 @@ namespace kj { ...@@ -45,10 +46,6 @@ namespace kj {
namespace { namespace {
UnixEventLoop& eventLoop() {
return downcast<UnixEventLoop>(EventLoop::current());
}
void setNonblocking(int fd) { void setNonblocking(int fd) {
int flags; int flags;
KJ_SYSCALL(flags = fcntl(fd, F_GETFL)); KJ_SYSCALL(flags = fcntl(fd, F_GETFL));
...@@ -90,7 +87,8 @@ protected: ...@@ -90,7 +87,8 @@ protected:
class AsyncStreamFd: public AsyncIoStream { class AsyncStreamFd: public AsyncIoStream {
public: public:
AsyncStreamFd(int readFd, int writeFd): readFd(readFd), writeFd(writeFd) {} AsyncStreamFd(UnixEventPort& eventPort, int readFd, int writeFd)
: eventPort(eventPort), readFd(readFd), writeFd(writeFd) {}
virtual ~AsyncStreamFd() noexcept(false) {} virtual ~AsyncStreamFd() noexcept(false) {}
Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override { Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
...@@ -124,7 +122,7 @@ public: ...@@ -124,7 +122,7 @@ public:
size -= n; size -= n;
} }
return eventLoop().onFdEvent(writeFd, POLLOUT).then([=](short) { return eventPort.onFdEvent(writeFd, POLLOUT).then([=](short) {
return write(buffer, size); return write(buffer, size);
}); });
} }
...@@ -137,7 +135,15 @@ public: ...@@ -137,7 +135,15 @@ public:
} }
} }
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// UnixAsyncIoProvider interface.
KJ_REQUIRE(readFd == writeFd, "shutdownWrite() is only implemented on sockets.");
KJ_SYSCALL(shutdown(writeFd, SHUT_WR));
}
private: private:
UnixEventPort& eventPort;
int readFd; int readFd;
int writeFd; int writeFd;
bool gotHup = false; bool gotHup = false;
...@@ -155,7 +161,7 @@ private: ...@@ -155,7 +161,7 @@ private:
if (n < 0) { if (n < 0) {
// Read would block. // Read would block.
return eventLoop().onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) { return eventPort.onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) {
gotHup = events & (POLLHUP | POLLRDHUP); gotHup = events & (POLLHUP | POLLRDHUP);
return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
}); });
...@@ -180,7 +186,7 @@ private: ...@@ -180,7 +186,7 @@ private:
minBytes -= n; minBytes -= n;
maxBytes -= n; maxBytes -= n;
alreadyRead += n; alreadyRead += n;
return eventLoop().onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) { return eventPort.onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) {
gotHup = events & (POLLHUP | POLLRDHUP); gotHup = events & (POLLHUP | POLLRDHUP);
return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
}); });
...@@ -217,7 +223,7 @@ private: ...@@ -217,7 +223,7 @@ private:
if (n < firstPiece.size()) { if (n < firstPiece.size()) {
// Only part of the first piece was consumed. Wait for POLLOUT and then write again. // Only part of the first piece was consumed. Wait for POLLOUT and then write again.
firstPiece = firstPiece.slice(n, firstPiece.size()); firstPiece = firstPiece.slice(n, firstPiece.size());
return eventLoop().onFdEvent(writeFd, POLLOUT).then([=](short) { return eventPort.onFdEvent(writeFd, POLLOUT).then([=](short) {
return writeInternal(firstPiece, morePieces); return writeInternal(firstPiece, morePieces);
}); });
} else if (morePieces.size() == 0) { } else if (morePieces.size() == 0) {
...@@ -236,7 +242,19 @@ private: ...@@ -236,7 +242,19 @@ private:
class Socket final: public OwnedFileDescriptor, public AsyncStreamFd { class Socket final: public OwnedFileDescriptor, public AsyncStreamFd {
public: public:
Socket(int fd): OwnedFileDescriptor(fd), AsyncStreamFd(fd, fd) {} Socket(UnixEventPort& eventPort, int fd)
: OwnedFileDescriptor(fd), AsyncStreamFd(eventPort, fd, fd) {}
};
class ThreadSocket final: public Thread, public OwnedFileDescriptor, public AsyncStreamFd {
// Combination thread and socket. The thread must be joined strictly after the socket is closed.
public:
template <typename StartFunc>
ThreadSocket(UnixEventPort& eventPort, int fd, StartFunc&& startFunc)
: Thread(kj::fwd<StartFunc>(startFunc)),
OwnedFileDescriptor(fd),
AsyncStreamFd(eventPort, fd, fd) {}
}; };
// ======================================================================================= // =======================================================================================
...@@ -488,7 +506,8 @@ private: ...@@ -488,7 +506,8 @@ private:
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor { class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public: public:
FdConnectionReceiver(int fd): OwnedFileDescriptor(fd) {} FdConnectionReceiver(UnixEventPort& eventPort, int fd)
: OwnedFileDescriptor(fd), eventPort(eventPort) {}
Promise<Own<AsyncIoStream>> accept() override { Promise<Own<AsyncIoStream>> accept() override {
int newFd; int newFd;
...@@ -507,28 +526,32 @@ public: ...@@ -507,28 +526,32 @@ public:
if (newFd < 0) { if (newFd < 0) {
// Gotta wait. // Gotta wait.
return eventLoop().onFdEvent(fd, POLLIN).then([this](short) { return eventPort.onFdEvent(fd, POLLIN).then([this](short) {
return accept(); return accept();
}); });
} else { } else {
return Own<AsyncIoStream>(heap<Socket>(newFd)); return Own<AsyncIoStream>(heap<Socket>(eventPort, newFd));
} }
} }
uint getPort() override { uint getPort() override {
return SocketAddress::getLocalAddress(fd).getPort(); return SocketAddress::getLocalAddress(fd).getPort();
} }
public:
UnixEventPort& eventPort;
}; };
// ======================================================================================= // =======================================================================================
class LocalSocketAddress final: public LocalAddress { class LocalSocketAddress final: public LocalAddress {
public: public:
LocalSocketAddress(SocketAddress addr): addr(addr) {} LocalSocketAddress(UnixEventPort& eventPort, SocketAddress addr)
: eventPort(eventPort), addr(addr) {}
Own<ConnectionReceiver> listen() override { Own<ConnectionReceiver> listen() override {
int fd = addr.socket(SOCK_STREAM); int fd = addr.socket(SOCK_STREAM);
auto result = heap<FdConnectionReceiver>(fd); auto result = heap<FdConnectionReceiver>(eventPort, fd);
// We always enable SO_REUSEADDR because having to take your server down for five minutes // We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks. // before it can restart really sucks.
...@@ -548,19 +571,21 @@ public: ...@@ -548,19 +571,21 @@ public:
} }
private: private:
UnixEventPort& eventPort;
SocketAddress addr; SocketAddress addr;
}; };
class RemoteSocketAddress final: public RemoteAddress { class RemoteSocketAddress final: public RemoteAddress {
public: public:
RemoteSocketAddress(SocketAddress addr): addr(addr) {} RemoteSocketAddress(UnixEventPort& eventPort, SocketAddress addr)
: eventPort(eventPort), addr(addr) {}
Promise<Own<AsyncIoStream>> connect() override { Promise<Own<AsyncIoStream>> connect() override {
int fd = addr.socket(SOCK_STREAM); int fd = addr.socket(SOCK_STREAM);
auto result = heap<Socket>(fd); auto result = heap<Socket>(eventPort, fd);
addr.connect(fd); addr.connect(fd);
return eventLoop().onFdEvent(fd, POLLOUT).then(kj::mvCapture(result, return eventPort.onFdEvent(fd, POLLOUT).then(kj::mvCapture(result,
[fd](Own<AsyncIoStream>&& stream, short events) { [fd](Own<AsyncIoStream>&& stream, short events) {
int err; int err;
socklen_t errlen = sizeof(err); socklen_t errlen = sizeof(err);
...@@ -577,83 +602,118 @@ public: ...@@ -577,83 +602,118 @@ public:
} }
private: private:
UnixEventPort& eventPort;
SocketAddress addr; SocketAddress addr;
}; };
class SocketNetwork final: public Network { class SocketNetwork final: public Network {
public: public:
Promise<Own<LocalAddress>> parseLocalAddress(StringPtr addr, uint portHint = 0) const override { explicit SocketNetwork(UnixEventPort& eventPort): eventPort(eventPort) {}
Promise<Own<LocalAddress>> parseLocalAddress(StringPtr addr, uint portHint = 0) override {
auto& eventPortCopy = eventPort;
return evalLater(mvCapture(heapString(addr), return evalLater(mvCapture(heapString(addr),
[portHint](String&& addr) -> Own<LocalAddress> { [&eventPortCopy,portHint](String&& addr) -> Own<LocalAddress> {
return heap<LocalSocketAddress>(SocketAddress::parseLocal(addr, portHint)); return heap<LocalSocketAddress>(eventPortCopy, SocketAddress::parseLocal(addr, portHint));
})); }));
} }
Promise<Own<RemoteAddress>> parseRemoteAddress(StringPtr addr, uint portHint = 0) const override { Promise<Own<RemoteAddress>> parseRemoteAddress(StringPtr addr, uint portHint = 0) override {
auto& eventPortCopy = eventPort;
return evalLater(mvCapture(heapString(addr), return evalLater(mvCapture(heapString(addr),
[portHint](String&& addr) -> Own<RemoteAddress> { [&eventPortCopy,portHint](String&& addr) -> Own<RemoteAddress> {
return heap<RemoteSocketAddress>(SocketAddress::parse(addr, portHint)); return heap<RemoteSocketAddress>(eventPortCopy, SocketAddress::parse(addr, portHint));
})); }));
} }
Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) const override { Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) override {
return Own<LocalAddress>(heap<LocalSocketAddress>(SocketAddress(sockaddr, len))); return Own<LocalAddress>(heap<LocalSocketAddress>(eventPort, SocketAddress(sockaddr, len)));
} }
Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) const override { Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) override {
return Own<RemoteAddress>(heap<RemoteSocketAddress>(SocketAddress(sockaddr, len))); return Own<RemoteAddress>(heap<RemoteSocketAddress>(eventPort, SocketAddress(sockaddr, len)));
} }
};
} // namespace private:
UnixEventPort& eventPort;
Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) { };
return read(buffer, bytes, bytes).then([](size_t) {});
}
Own<AsyncInputStream> AsyncInputStream::wrapFd(int fd) {
setNonblocking(fd);
return heap<AsyncStreamFd>(fd, -1);
}
Own<AsyncOutputStream> AsyncOutputStream::wrapFd(int fd) {
setNonblocking(fd);
return heap<AsyncStreamFd>(-1, fd);
}
Own<AsyncIoStream> AsyncIoStream::wrapFd(int fd) { // =======================================================================================
setNonblocking(fd);
return heap<AsyncStreamFd>(fd, fd);
}
Own<Network> Network::newSystemNetwork() { class UnixAsyncIoProvider final: public AsyncIoProvider {
return heap<SocketNetwork>(); public:
} UnixAsyncIoProvider()
: eventLoop(eventPort), network(eventPort) {}
OneWayPipe newOneWayPipe() { OneWayPipe newOneWayPipe() override {
int fds[2]; int fds[2];
#if __linux__ #if __linux__
KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC)); KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
#else #else
KJ_SYSCALL(pipe(fds)); KJ_SYSCALL(pipe(fds));
#endif #endif
return OneWayPipe { heap<Socket>(fds[0]), heap<Socket>(fds[1]) }; return OneWayPipe { heap<Socket>(eventPort, fds[0]), heap<Socket>(eventPort, fds[1]) };
} }
TwoWayPipe newTwoWayPipe() { TwoWayPipe newTwoWayPipe() override {
int fds[2]; int fds[2];
int type = SOCK_STREAM; int type = SOCK_STREAM;
#if __linux__ #if __linux__
type |= SOCK_NONBLOCK | SOCK_CLOEXEC; type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif #endif
KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds)); KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
return TwoWayPipe { { heap<Socket>(fds[0]), heap<Socket>(fds[1]) } }; return TwoWayPipe { { heap<Socket>(eventPort, fds[0]), heap<Socket>(eventPort, fds[1]) } };
} }
Network& getNetwork() override {
return network;
}
Own<AsyncIoStream> newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) override {
int fds[2];
int type = SOCK_STREAM;
#if __linux__
type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif
KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
int threadFd = fds[1];
return heap<ThreadSocket>(eventPort, fds[0], kj::mvCapture(startFunc,
[threadFd](Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)>&& startFunc) {
KJ_DEFER(KJ_SYSCALL(close(threadFd)));
UnixAsyncIoProvider ioProvider;
auto stream = ioProvider.wrapSocketFd(threadFd);
startFunc(ioProvider, *stream);
}));
}
namespace _ { // private Own<AsyncInputStream> wrapInputFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, fd, -1);
}
Own<AsyncOutputStream> wrapOutputFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, -1, fd);
}
Own<AsyncIoStream> wrapSocketFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, fd, fd);
}
private:
UnixEventPort eventPort;
EventLoop eventLoop;
SocketNetwork network;
};
} // namespace
Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) {
return read(buffer, bytes, bytes).then([](size_t) {});
}
void runIoEventLoopInternal(IoLoopMain& func) { Own<AsyncIoProvider> setupIoEventLoop() {
UnixEventLoop loop; return heap<UnixAsyncIoProvider>();
func.run(loop);
} }
} // namespace _ (private)
} // namespace kj } // namespace kj
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define KJ_ASYNC_IO_H_ #define KJ_ASYNC_IO_H_
#include "async.h" #include "async.h"
#include "function.h"
namespace kj { namespace kj {
...@@ -36,13 +37,6 @@ public: ...@@ -36,13 +37,6 @@ public:
virtual Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) = 0; virtual Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) = 0;
Promise<void> read(void* buffer, size_t bytes); Promise<void> read(void* buffer, size_t bytes);
static Own<AsyncInputStream> wrapFd(int fd);
// Create an AsyncInputStream wrapping a file descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
//
// The returned object can only be called from within a system event loop (e.g. `UnixEventLoop`).
}; };
class AsyncOutputStream { class AsyncOutputStream {
...@@ -51,25 +45,14 @@ class AsyncOutputStream { ...@@ -51,25 +45,14 @@ class AsyncOutputStream {
public: public:
virtual Promise<void> write(const void* buffer, size_t size) = 0; virtual Promise<void> write(const void* buffer, size_t size) = 0;
virtual Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) = 0; virtual Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) = 0;
static Own<AsyncOutputStream> wrapFd(int fd);
// Create an AsyncOutputStream wrapping a file descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
//
// The returned object can only be called from within a system event loop (e.g. `UnixEventLoop`).
}; };
class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream {
// A combination input and output stream. // A combination input and output stream.
public: public:
static Own<AsyncIoStream> wrapFd(int fd); virtual void shutdownWrite() = 0;
// Create an AsyncIoStream wrapping a file descriptor. // Cleanly shut down just the write end of the stream, while keeping the read end open.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
//
// The returned object can only be called from within a system event loop (e.g. `UnixEventLoop`).
}; };
class ConnectionReceiver { class ConnectionReceiver {
...@@ -121,33 +104,10 @@ class Network { ...@@ -121,33 +104,10 @@ class Network {
// LocalAddress and/or RemoteAddress instances directly and work from there, if at all possible. // LocalAddress and/or RemoteAddress instances directly and work from there, if at all possible.
public: public:
static Own<Network> newSystemNetwork();
// Creates a new `Network` instance representing the networks exposed by the operating system.
//
// DO NOT CALL THIS except at the highest levels of your code, ideally in the main() function. If
// you call this from low-level code, then you are preventing higher-level code from injecting an
// alternative implementation. Instead, if your code needs to use network functionality, it
// should ask for a `Network` as a constructor or method parameter, so that higher-level code can
// chose what implementation to use. The system network is essentially a singleton. See:
// http://www.object-oriented-security.org/lets-argue/singletons
//
// Code that uses the system network should not make any assumptions about what kinds of
// addresses it will parse, as this could differ across platforms. String addresses should come
// strictly from the user, who will know how to write them correctly for their system.
//
// With that said, KJ currently supports the following string address formats:
// - IPv4: "1.2.3.4", "1.2.3.4:80"
// - IPv6: "1234:5678::abcd", "[1234:5678::abcd]:80"
// - Local IP wildcard (local addresses only; covers both v4 and v6): "*", "*:80", ":80", "80"
// - Unix domain: "unix:/path/to/socket"
//
// The system network -- and all objects it creates -- can only be used from threads running
// a system event loop (e.g. `UnixEventLoop`).
virtual Promise<Own<LocalAddress>> parseLocalAddress( virtual Promise<Own<LocalAddress>> parseLocalAddress(
StringPtr addr, uint portHint = 0) const = 0; StringPtr addr, uint portHint = 0) = 0;
virtual Promise<Own<RemoteAddress>> parseRemoteAddress( virtual Promise<Own<RemoteAddress>> parseRemoteAddress(
StringPtr addr, uint portHint = 0) const = 0; StringPtr addr, uint portHint = 0) = 0;
// Construct a local or remote address from a user-provided string. The format of the address // Construct a local or remote address from a user-provided string. The format of the address
// strings is not specified at the API level, and application code should make no assumptions // strings is not specified at the API level, and application code should make no assumptions
// about them. These strings should always be provided by humans, and said humans will know // about them. These strings should always be provided by humans, and said humans will know
...@@ -160,86 +120,130 @@ public: ...@@ -160,86 +120,130 @@ public:
// In practice, a local address is usually just a port number (or even an empty string, if a // In practice, a local address is usually just a port number (or even an empty string, if a
// reasonable `portHint` is provided), whereas a remote address usually requires a hostname. // reasonable `portHint` is provided), whereas a remote address usually requires a hostname.
virtual Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) const = 0; virtual Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) = 0;
virtual Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) const = 0; virtual Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) = 0;
// Construct a local or remote address from a legacy struct sockaddr. // Construct a local or remote address from a legacy struct sockaddr.
}; };
struct OneWayPipe { struct OneWayPipe {
// A data pipe with an input end and an output end. The two ends are safe to use in different
// threads. (Typically backed by pipe() system call.)
Own<AsyncInputStream> in; Own<AsyncInputStream> in;
Own<AsyncOutputStream> out; Own<AsyncOutputStream> out;
}; };
OneWayPipe newOneWayPipe();
// Creates an input/output stream pair representing the ends of a one-way OS pipe (created with
// pipe(2)).
struct TwoWayPipe { struct TwoWayPipe {
// A data pipe that supports sending in both directions. Each end's output sends data to the
// other end's input. The ends can be used in separate threads. (Typically backed by
// socketpair() system call.)
Own<AsyncIoStream> ends[2]; Own<AsyncIoStream> ends[2];
}; };
TwoWayPipe newTwoWayPipe();
// Creates two AsyncIoStreams representing the two ends of a two-way OS pipe (created with
// socketpair(2)). Data written to one end can be read from the other.
// =======================================================================================
namespace _ { // private class AsyncIoProvider {
// Class which constructs asynchronous wrappers around the operating system's I/O facilities.
//
// Generally, the implementation of this interface must integrate closely with a particular
// `EventLoop` implementation. Typically, the EventLoop implementation itself will provide
// an AsyncIoProvider.
class IoLoopMain {
public: public:
virtual void run(EventLoop& loop) = 0; virtual OneWayPipe newOneWayPipe() = 0;
}; // Creates an input/output stream pair representing the ends of a one-way pipe (e.g. created with
// the pipe(2) system call).
template <typename Func, typename Result> virtual TwoWayPipe newTwoWayPipe() = 0;
class IoLoopMainImpl: public IoLoopMain { // Creates two AsyncIoStreams representing the two ends of a two-way pipe (e.g. created with
public: // socketpair(2) system call). Data written to one end can be read from the other.
IoLoopMainImpl(Func&& func): func(kj::mv(func)) {}
void run(EventLoop& loop) override {
result = space.construct(kj::evalLater(func).wait());
}
Result getResult() { return kj::mv(*result); }
private:
Func func;
SpaceFor<Result> space;
Own<Result> result;
};
template <typename Func> virtual Network& getNetwork() = 0;
class IoLoopMainImpl<Func, void>: public IoLoopMain { // Creates a new `Network` instance representing the networks exposed by the operating system.
public: //
IoLoopMainImpl(Func&& func): func(kj::mv(func)) {} // DO NOT CALL THIS except at the highest levels of your code, ideally in the main() function. If
void run(EventLoop& loop) override { // you call this from low-level code, then you are preventing higher-level code from injecting an
kj::evalLater(func).wait(); // alternative implementation. Instead, if your code needs to use network functionality, it
} // should ask for a `Network` as a constructor or method parameter, so that higher-level code can
void getResult() {} // chose what implementation to use. The system network is essentially a singleton. See:
// http://www.object-oriented-security.org/lets-argue/singletons
private: //
Func func; // Code that uses the system network should not make any assumptions about what kinds of
}; // addresses it will parse, as this could differ across platforms. String addresses should come
// strictly from the user, who will know how to write them correctly for their system.
//
// With that said, KJ currently supports the following string address formats:
// - IPv4: "1.2.3.4", "1.2.3.4:80"
// - IPv6: "1234:5678::abcd", "[1234:5678::abcd]:80"
// - Local IP wildcard (local addresses only; covers both v4 and v6): "*", "*:80", ":80", "80"
// - Unix domain: "unix:/path/to/socket"
virtual Own<AsyncIoStream> newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) = 0;
// Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate
// with it. One end of the pipe is passed to the thread's starct function and the other end of
// the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will
// already have an active `EventLoop` when `startFunc` is called.
//
// The returned stream's destructor first closes its end of the pipe then waits for the thread to
// finish (joins it). The thread should therefore be designed to exit soon after receiving EOF
// on the input stream.
//
// TODO(someday): I'm not entirely comfortable with this interface. It seems to be doing too
// much at once but I'm not sure how to cleanly break it down.
void runIoEventLoopInternal(IoLoopMain& func); // ---------------------------------------------------------------------------
// Unix-only methods
//
// TODO(cleanup): Should these be in a subclass?
virtual Own<AsyncInputStream> wrapInputFd(int fd) = 0;
// Create an AsyncInputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
} // namespace _ (private) virtual Own<AsyncOutputStream> wrapOutputFd(int fd) = 0;
// Create an AsyncOutputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
template <typename Func> virtual Own<AsyncIoStream> wrapSocketFd(int fd) = 0;
auto runIoEventLoop(Func&& func) -> decltype(func().wait()) { // Create an AsyncIoStream wrapping a socket file descriptor.
// Sets up an appropriate EventLoop for doing I/O, then executes the given function. The function //
// returns a promise. The EventLoop will continue running until that promise resolves, then the // Does not take ownership of the descriptor.
// whole function will return its resolution. On return, the EventLoop is destroyed, cancelling
// all outstanding I/O.
// //
// This function is great for running inside main() to set up an Async I/O environment without // This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// specifying a platform-specific EventLoop or other such things.
// ---------------------------------------------------------------------------
// Windows-only methods
// TODO(cleanup): I wanted to forward-declare this function in order to document it separate // TODO(port): IOCP
// from the implementation details but GCC claimed the two declarations were overloads rather };
// than the same function, even though the signature was identical. FFFFFFFFFFUUUUUUUUUUUUUUU-
typedef decltype(instance<Func>()().wait()) Result; Own<AsyncIoProvider> setupIoEventLoop();
_::IoLoopMainImpl<Func, Result> func2(kj::fwd<Func>(func)); // Convenience method which sets up the current thread with everything it needs to do async I/O.
_::runIoEventLoopInternal(func2); // The returned object contains an `EventLoop` which is wrapping an appropriate `EventPort` for
return func2.getResult(); // doing I/O on the host system, so everything is ready for the thread to start making async calls
} // and waiting on promises.
//
// You would typically call this in your main() loop or in the start function of a thread.
// Example:
//
// int main() {
// auto ioSystem = kj::setupIoEventLoop();
//
// // Now we can call an async function.
// Promise<String> textPromise = getHttp(ioSystem->getNetwork(), "http://example.com");
//
// // And we can wait for the promise to complete. Note that you can only use `wait()`
// // from the top level, not from inside a promise callback.
// String text = textPromise.wait();
// print(text);
// return 0;
// }
} // namespace kj } // namespace kj
......
...@@ -172,6 +172,18 @@ private: ...@@ -172,6 +172,18 @@ private:
friend class TaskSetImpl; friend class TaskSetImpl;
}; };
void daemonize(kj::Promise<void>&& promise);
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result);
Promise<void> yield();
class NeverDone {
public:
template <typename T>
operator Promise<T>() const;
void wait() KJ_NORETURN;
};
} // namespace _ (private) } // namespace _ (private)
} // namespace kj } // namespace kj
......
...@@ -29,7 +29,7 @@ namespace kj { ...@@ -29,7 +29,7 @@ namespace kj {
namespace { namespace {
TEST(Async, EvalVoid) { TEST(Async, EvalVoid) {
SimpleEventLoop loop; EventLoop loop;
bool done = false; bool done = false;
...@@ -40,7 +40,7 @@ TEST(Async, EvalVoid) { ...@@ -40,7 +40,7 @@ TEST(Async, EvalVoid) {
} }
TEST(Async, EvalInt) { TEST(Async, EvalInt) {
SimpleEventLoop loop; EventLoop loop;
bool done = false; bool done = false;
...@@ -51,7 +51,7 @@ TEST(Async, EvalInt) { ...@@ -51,7 +51,7 @@ TEST(Async, EvalInt) {
} }
TEST(Async, There) { TEST(Async, There) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> a = 123; Promise<int> a = 123;
bool done = false; bool done = false;
...@@ -63,7 +63,7 @@ TEST(Async, There) { ...@@ -63,7 +63,7 @@ TEST(Async, There) {
} }
TEST(Async, ThereVoid) { TEST(Async, ThereVoid) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> a = 123; Promise<int> a = 123;
int value = 0; int value = 0;
...@@ -75,7 +75,7 @@ TEST(Async, ThereVoid) { ...@@ -75,7 +75,7 @@ TEST(Async, ThereVoid) {
} }
TEST(Async, Exception) { TEST(Async, Exception) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -86,7 +86,7 @@ TEST(Async, Exception) { ...@@ -86,7 +86,7 @@ TEST(Async, Exception) {
} }
TEST(Async, HandleException) { TEST(Async, HandleException) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -100,7 +100,7 @@ TEST(Async, HandleException) { ...@@ -100,7 +100,7 @@ TEST(Async, HandleException) {
} }
TEST(Async, PropagateException) { TEST(Async, PropagateException) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -116,7 +116,7 @@ TEST(Async, PropagateException) { ...@@ -116,7 +116,7 @@ TEST(Async, PropagateException) {
} }
TEST(Async, PropagateExceptionTypeChange) { TEST(Async, PropagateExceptionTypeChange) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -132,12 +132,11 @@ TEST(Async, PropagateExceptionTypeChange) { ...@@ -132,12 +132,11 @@ TEST(Async, PropagateExceptionTypeChange) {
} }
TEST(Async, Then) { TEST(Async, Then) {
SimpleEventLoop loop; EventLoop loop;
bool done = false; bool done = false;
Promise<int> promise = Promise<int>(123).then([&](int i) { Promise<int> promise = Promise<int>(123).then([&](int i) {
EXPECT_EQ(&loop, &EventLoop::current());
done = true; done = true;
return i + 321; return i + 321;
}); });
...@@ -150,15 +149,13 @@ TEST(Async, Then) { ...@@ -150,15 +149,13 @@ TEST(Async, Then) {
} }
TEST(Async, Chain) { TEST(Async, Chain) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater([&]() -> int { return 123; }); Promise<int> promise = evalLater([&]() -> int { return 123; });
Promise<int> promise2 = evalLater([&]() -> int { return 321; }); Promise<int> promise2 = evalLater([&]() -> int { return 321; });
auto promise3 = promise.then([&](int i) { auto promise3 = promise.then([&](int i) {
EXPECT_EQ(&loop, &EventLoop::current());
return promise2.then([&loop,i](int j) { return promise2.then([&loop,i](int j) {
EXPECT_EQ(&loop, &EventLoop::current());
return i + j; return i + j;
}); });
}); });
...@@ -167,7 +164,7 @@ TEST(Async, Chain) { ...@@ -167,7 +164,7 @@ TEST(Async, Chain) {
} }
TEST(Async, SeparateFulfiller) { TEST(Async, SeparateFulfiller) {
SimpleEventLoop loop; EventLoop loop;
auto pair = newPromiseAndFulfiller<int>(); auto pair = newPromiseAndFulfiller<int>();
...@@ -179,7 +176,7 @@ TEST(Async, SeparateFulfiller) { ...@@ -179,7 +176,7 @@ TEST(Async, SeparateFulfiller) {
} }
TEST(Async, SeparateFulfillerVoid) { TEST(Async, SeparateFulfillerVoid) {
SimpleEventLoop loop; EventLoop loop;
auto pair = newPromiseAndFulfiller<void>(); auto pair = newPromiseAndFulfiller<void>();
...@@ -199,7 +196,7 @@ TEST(Async, SeparateFulfillerCanceled) { ...@@ -199,7 +196,7 @@ TEST(Async, SeparateFulfillerCanceled) {
} }
TEST(Async, SeparateFulfillerChained) { TEST(Async, SeparateFulfillerChained) {
SimpleEventLoop loop; EventLoop loop;
auto pair = newPromiseAndFulfiller<Promise<int>>(); auto pair = newPromiseAndFulfiller<Promise<int>>();
auto inner = newPromiseAndFulfiller<int>(); auto inner = newPromiseAndFulfiller<int>();
...@@ -219,7 +216,7 @@ TEST(Async, SeparateFulfillerChained) { ...@@ -219,7 +216,7 @@ TEST(Async, SeparateFulfillerChained) {
#endif #endif
TEST(Async, SeparateFulfillerDiscarded) { TEST(Async, SeparateFulfillerDiscarded) {
SimpleEventLoop loop; EventLoop loop;
auto pair = newPromiseAndFulfiller<int>(); auto pair = newPromiseAndFulfiller<int>();
pair.fulfiller = nullptr; pair.fulfiller = nullptr;
...@@ -233,7 +230,7 @@ TEST(Async, SeparateFulfillerMemoryLeak) { ...@@ -233,7 +230,7 @@ TEST(Async, SeparateFulfillerMemoryLeak) {
} }
TEST(Async, Ordering) { TEST(Async, Ordering) {
SimpleEventLoop loop; EventLoop loop;
int counter = 0; int counter = 0;
Promise<void> promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; Promise<void> promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
...@@ -295,7 +292,7 @@ TEST(Async, Ordering) { ...@@ -295,7 +292,7 @@ TEST(Async, Ordering) {
} }
TEST(Async, Fork) { TEST(Async, Fork) {
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater([&]() { return 123; }); Promise<int> promise = evalLater([&]() { return 123; });
...@@ -325,7 +322,7 @@ struct RefcountedInt: public Refcounted { ...@@ -325,7 +322,7 @@ struct RefcountedInt: public Refcounted {
}; };
TEST(Async, ForkRef) { TEST(Async, ForkRef) {
SimpleEventLoop loop; EventLoop loop;
Promise<Own<RefcountedInt>> promise = evalLater([&]() { Promise<Own<RefcountedInt>> promise = evalLater([&]() {
return refcounted<RefcountedInt>(123); return refcounted<RefcountedInt>(123);
...@@ -352,7 +349,7 @@ TEST(Async, ForkRef) { ...@@ -352,7 +349,7 @@ TEST(Async, ForkRef) {
TEST(Async, ExclusiveJoin) { TEST(Async, ExclusiveJoin) {
{ {
SimpleEventLoop loop; EventLoop loop;
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = newPromiseAndFulfiller<int>(); // never fulfilled auto right = newPromiseAndFulfiller<int>(); // never fulfilled
...@@ -363,7 +360,7 @@ TEST(Async, ExclusiveJoin) { ...@@ -363,7 +360,7 @@ TEST(Async, ExclusiveJoin) {
} }
{ {
SimpleEventLoop loop; EventLoop loop;
auto left = newPromiseAndFulfiller<int>(); // never fulfilled auto left = newPromiseAndFulfiller<int>(); // never fulfilled
auto right = evalLater([&]() { return 123; }); auto right = evalLater([&]() { return 123; });
...@@ -374,7 +371,7 @@ TEST(Async, ExclusiveJoin) { ...@@ -374,7 +371,7 @@ TEST(Async, ExclusiveJoin) {
} }
{ {
SimpleEventLoop loop; EventLoop loop;
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; }); auto right = evalLater([&]() { return 456; });
...@@ -385,7 +382,7 @@ TEST(Async, ExclusiveJoin) { ...@@ -385,7 +382,7 @@ TEST(Async, ExclusiveJoin) {
} }
{ {
SimpleEventLoop loop; EventLoop loop;
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; }); auto right = evalLater([&]() { return 456; });
...@@ -408,7 +405,7 @@ public: ...@@ -408,7 +405,7 @@ public:
}; };
TEST(Async, TaskSet) { TEST(Async, TaskSet) {
SimpleEventLoop loop; EventLoop loop;
ErrorHandlerImpl errorHandler; ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler); TaskSet tasks(errorHandler);
...@@ -449,7 +446,7 @@ private: ...@@ -449,7 +446,7 @@ private:
TEST(Async, Attach) { TEST(Async, Attach) {
bool destroyed = false; bool destroyed = false;
SimpleEventLoop loop; EventLoop loop;
Promise<int> promise = evalLater([&]() { Promise<int> promise = evalLater([&]() {
EXPECT_FALSE(destroyed); EXPECT_FALSE(destroyed);
...@@ -471,7 +468,7 @@ TEST(Async, Attach) { ...@@ -471,7 +468,7 @@ TEST(Async, Attach) {
TEST(Async, EagerlyEvaluate) { TEST(Async, EagerlyEvaluate) {
bool called = false; bool called = false;
SimpleEventLoop loop; EventLoop loop;
Promise<void> promise = Promise<void>(READY_NOW).then([&]() { Promise<void> promise = Promise<void>(READY_NOW).then([&]() {
called = true; called = true;
...@@ -488,7 +485,7 @@ TEST(Async, EagerlyEvaluate) { ...@@ -488,7 +485,7 @@ TEST(Async, EagerlyEvaluate) {
} }
TEST(Async, Daemonize) { TEST(Async, Daemonize) {
SimpleEventLoop loop; EventLoop loop;
bool ran1 = false; bool ran1 = false;
bool ran2 = false; bool ran2 = false;
......
...@@ -45,17 +45,18 @@ inline void delay() { usleep(10000); } ...@@ -45,17 +45,18 @@ inline void delay() { usleep(10000); }
class AsyncUnixTest: public testing::Test { class AsyncUnixTest: public testing::Test {
public: public:
static void SetUpTestCase() { static void SetUpTestCase() {
UnixEventLoop::captureSignal(SIGUSR2); UnixEventPort::captureSignal(SIGUSR2);
UnixEventLoop::captureSignal(SIGIO); UnixEventPort::captureSignal(SIGIO);
} }
}; };
TEST_F(AsyncUnixTest, Signals) { TEST_F(AsyncUnixTest, Signals) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
siginfo_t info = loop.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait();
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
...@@ -67,14 +68,15 @@ TEST_F(AsyncUnixTest, SignalWithValue) { ...@@ -67,14 +68,15 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
// though the signal we're sending is SIGUSR2, the sigqueue() system call is introduced by RT // though the signal we're sending is SIGUSR2, the sigqueue() system call is introduced by RT
// signals. Hence this test won't run on e.g. Mac OSX. // signals. Hence this test won't run on e.g. Mac OSX.
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
union sigval value; union sigval value;
memset(&value, 0, sizeof(value)); memset(&value, 0, sizeof(value));
value.sival_int = 123; value.sival_int = 123;
sigqueue(getpid(), SIGUSR2, value); sigqueue(getpid(), SIGUSR2, value);
siginfo_t info = loop.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait();
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_QUEUE, info.si_code); EXPECT_SI_CODE(SI_QUEUE, info.si_code);
EXPECT_EQ(123, info.si_value.sival_int); EXPECT_EQ(123, info.si_value.sival_int);
...@@ -82,9 +84,10 @@ TEST_F(AsyncUnixTest, SignalWithValue) { ...@@ -82,9 +84,10 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
#endif #endif
TEST_F(AsyncUnixTest, SignalsMultiListen) { TEST_F(AsyncUnixTest, SignalsMultiListen) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
loop.onSignal(SIGIO).then([](siginfo_t&&) { port.onSignal(SIGIO).then([](siginfo_t&&) {
ADD_FAILURE() << "Received wrong signal."; ADD_FAILURE() << "Received wrong signal.";
}).daemonize([](kj::Exception&& exception) { }).daemonize([](kj::Exception&& exception) {
ADD_FAILURE() << kj::str(exception).cStr(); ADD_FAILURE() << kj::str(exception).cStr();
...@@ -92,28 +95,30 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) { ...@@ -92,28 +95,30 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) {
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
siginfo_t info = loop.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait();
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
TEST_F(AsyncUnixTest, SignalsMultiReceive) { TEST_F(AsyncUnixTest, SignalsMultiReceive) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
kill(getpid(), SIGIO); kill(getpid(), SIGIO);
siginfo_t info = loop.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait();
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
info = loop.onSignal(SIGIO).wait(); info = port.onSignal(SIGIO).wait();
EXPECT_EQ(SIGIO, info.si_signo); EXPECT_EQ(SIGIO, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
TEST_F(AsyncUnixTest, SignalsAsync) { TEST_F(AsyncUnixTest, SignalsAsync) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
// Arrange for a signal to be sent from another thread. // Arrange for a signal to be sent from another thread.
pthread_t mainThread = pthread_self(); pthread_t mainThread = pthread_self();
...@@ -122,33 +127,75 @@ TEST_F(AsyncUnixTest, SignalsAsync) { ...@@ -122,33 +127,75 @@ TEST_F(AsyncUnixTest, SignalsAsync) {
pthread_kill(mainThread, SIGUSR2); pthread_kill(mainThread, SIGUSR2);
}); });
siginfo_t info = loop.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait();
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
#if __linux__ #if __linux__
EXPECT_SI_CODE(SI_TKILL, info.si_code); EXPECT_SI_CODE(SI_TKILL, info.si_code);
#endif #endif
} }
TEST_F(AsyncUnixTest, SignalsNoWait) {
// Verify that UnixEventPort::poll() correctly receives pending signals.
UnixEventPort port;
EventLoop loop(port);
bool receivedSigusr2 = false;
bool receivedSigio = false;
port.onSignal(SIGUSR2).then([&](siginfo_t&& info) {
receivedSigusr2 = true;
EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
}).daemonize([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
port.onSignal(SIGIO).then([&](siginfo_t&& info) {
receivedSigio = true;
EXPECT_EQ(SIGIO, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
}).daemonize([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
kill(getpid(), SIGUSR2);
kill(getpid(), SIGIO);
EXPECT_FALSE(receivedSigusr2);
EXPECT_FALSE(receivedSigio);
loop.run();
EXPECT_FALSE(receivedSigusr2);
EXPECT_FALSE(receivedSigio);
port.poll();
EXPECT_FALSE(receivedSigusr2);
EXPECT_FALSE(receivedSigio);
loop.run();
EXPECT_TRUE(receivedSigusr2);
EXPECT_TRUE(receivedSigio);
}
TEST_F(AsyncUnixTest, Poll) { TEST_F(AsyncUnixTest, Poll) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
int pipefds[2]; int pipefds[2];
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); }); KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(pipe(pipefds)); KJ_SYSCALL(pipe(pipefds));
KJ_SYSCALL(write(pipefds[1], "foo", 3)); KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, loop.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
} }
TEST_F(AsyncUnixTest, PollMultiListen) { TEST_F(AsyncUnixTest, PollMultiListen) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
int bogusPipefds[2]; int bogusPipefds[2];
KJ_SYSCALL(pipe(bogusPipefds)); KJ_SYSCALL(pipe(bogusPipefds));
KJ_DEFER({ close(bogusPipefds[1]); close(bogusPipefds[0]); }); KJ_DEFER({ close(bogusPipefds[1]); close(bogusPipefds[0]); });
loop.onFdEvent(bogusPipefds[0], POLLIN | POLLPRI).then([](short s) { port.onFdEvent(bogusPipefds[0], POLLIN | POLLPRI).then([](short s) {
KJ_DBG(s);
ADD_FAILURE() << "Received wrong poll."; ADD_FAILURE() << "Received wrong poll.";
}).daemonize([](kj::Exception&& exception) { }).daemonize([](kj::Exception&& exception) {
ADD_FAILURE() << kj::str(exception).cStr(); ADD_FAILURE() << kj::str(exception).cStr();
...@@ -159,11 +206,12 @@ TEST_F(AsyncUnixTest, PollMultiListen) { ...@@ -159,11 +206,12 @@ TEST_F(AsyncUnixTest, PollMultiListen) {
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); }); KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(write(pipefds[1], "foo", 3)); KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, loop.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
} }
TEST_F(AsyncUnixTest, PollMultiReceive) { TEST_F(AsyncUnixTest, PollMultiReceive) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
int pipefds[2]; int pipefds[2];
KJ_SYSCALL(pipe(pipefds)); KJ_SYSCALL(pipe(pipefds));
...@@ -175,12 +223,13 @@ TEST_F(AsyncUnixTest, PollMultiReceive) { ...@@ -175,12 +223,13 @@ TEST_F(AsyncUnixTest, PollMultiReceive) {
KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); }); KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); });
KJ_SYSCALL(write(pipefds2[1], "bar", 3)); KJ_SYSCALL(write(pipefds2[1], "bar", 3));
EXPECT_EQ(POLLIN, loop.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, loop.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait());
} }
TEST_F(AsyncUnixTest, PollAsync) { TEST_F(AsyncUnixTest, PollAsync) {
UnixEventLoop loop; UnixEventPort port;
EventLoop loop(port);
// Make a pipe and wait on its read end while another thread writes to it. // Make a pipe and wait on its read end while another thread writes to it.
int pipefds[2]; int pipefds[2];
...@@ -192,7 +241,49 @@ TEST_F(AsyncUnixTest, PollAsync) { ...@@ -192,7 +241,49 @@ TEST_F(AsyncUnixTest, PollAsync) {
}); });
// Wait for the event in this thread. // Wait for the event in this thread.
EXPECT_EQ(POLLIN, loop.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
}
TEST_F(AsyncUnixTest, PollNoWait) {
// Verify that UnixEventPort::poll() correctly receives pending FD events.
UnixEventPort port;
EventLoop loop(port);
int pipefds[2];
KJ_SYSCALL(pipe(pipefds));
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
int pipefds2[2];
KJ_SYSCALL(pipe(pipefds2));
KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); });
int receivedCount = 0;
port.onFdEvent(pipefds[0], POLLIN | POLLPRI).then([&](short&& events) {
receivedCount++;
EXPECT_EQ(POLLIN, events);
}).daemonize([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).then([&](short&& events) {
receivedCount++;
EXPECT_EQ(POLLIN, events);
}).daemonize([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
KJ_SYSCALL(write(pipefds[1], "foo", 3));
KJ_SYSCALL(write(pipefds2[1], "bar", 3));
EXPECT_EQ(0, receivedCount);
loop.run();
EXPECT_EQ(0, receivedCount);
port.poll();
EXPECT_EQ(0, receivedCount);
loop.run();
EXPECT_EQ(2, receivedCount);
} }
} // namespace kj } // namespace kj
...@@ -74,10 +74,10 @@ pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT; ...@@ -74,10 +74,10 @@ pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT;
// ======================================================================================= // =======================================================================================
class UnixEventLoop::SignalPromiseAdapter { class UnixEventPort::SignalPromiseAdapter {
public: public:
inline SignalPromiseAdapter(PromiseFulfiller<siginfo_t>& fulfiller, inline SignalPromiseAdapter(PromiseFulfiller<siginfo_t>& fulfiller,
UnixEventLoop& loop, int signum) UnixEventPort& loop, int signum)
: loop(loop), signum(signum), fulfiller(fulfiller) { : loop(loop), signum(signum), fulfiller(fulfiller) {
prev = loop.signalTail; prev = loop.signalTail;
*loop.signalTail = this; *loop.signalTail = this;
...@@ -108,17 +108,17 @@ public: ...@@ -108,17 +108,17 @@ public:
return result; return result;
} }
UnixEventLoop& loop; UnixEventPort& loop;
int signum; int signum;
PromiseFulfiller<siginfo_t>& fulfiller; PromiseFulfiller<siginfo_t>& fulfiller;
SignalPromiseAdapter* next = nullptr; SignalPromiseAdapter* next = nullptr;
SignalPromiseAdapter** prev = nullptr; SignalPromiseAdapter** prev = nullptr;
}; };
class UnixEventLoop::PollPromiseAdapter { class UnixEventPort::PollPromiseAdapter {
public: public:
inline PollPromiseAdapter(PromiseFulfiller<short>& fulfiller, inline PollPromiseAdapter(PromiseFulfiller<short>& fulfiller,
UnixEventLoop& loop, int fd, short eventMask) UnixEventPort& loop, int fd, short eventMask)
: loop(loop), fd(fd), eventMask(eventMask), fulfiller(fulfiller) { : loop(loop), fd(fd), eventMask(eventMask), fulfiller(fulfiller) {
prev = loop.pollTail; prev = loop.pollTail;
*loop.pollTail = this; *loop.pollTail = this;
...@@ -147,7 +147,7 @@ public: ...@@ -147,7 +147,7 @@ public:
prev = nullptr; prev = nullptr;
} }
UnixEventLoop& loop; UnixEventPort& loop;
int fd; int fd;
short eventMask; short eventMask;
PromiseFulfiller<short>& fulfiller; PromiseFulfiller<short>& fulfiller;
...@@ -155,55 +155,73 @@ public: ...@@ -155,55 +155,73 @@ public:
PollPromiseAdapter** prev = nullptr; PollPromiseAdapter** prev = nullptr;
}; };
UnixEventLoop::UnixEventLoop() { UnixEventPort::UnixEventPort() {
pthread_once(&registerSigusr1Once, &registerSigusr1); pthread_once(&registerSigusr1Once, &registerSigusr1);
} }
UnixEventLoop::~UnixEventLoop() {} UnixEventPort::~UnixEventPort() {}
Promise<short> UnixEventLoop::onFdEvent(int fd, short eventMask) { Promise<short> UnixEventPort::onFdEvent(int fd, short eventMask) {
return newAdaptedPromise<short, PollPromiseAdapter>(*this, fd, eventMask); return newAdaptedPromise<short, PollPromiseAdapter>(*this, fd, eventMask);
} }
Promise<siginfo_t> UnixEventLoop::onSignal(int signum) { Promise<siginfo_t> UnixEventPort::onSignal(int signum) {
return newAdaptedPromise<siginfo_t, SignalPromiseAdapter>(*this, signum); return newAdaptedPromise<siginfo_t, SignalPromiseAdapter>(*this, signum);
} }
void UnixEventLoop::captureSignal(int signum) { void UnixEventPort::captureSignal(int signum) {
KJ_REQUIRE(signum != SIGUSR1, "Sorry, SIGUSR1 is reserved by the UnixEventLoop implementation."); KJ_REQUIRE(signum != SIGUSR1, "Sorry, SIGUSR1 is reserved by the UnixEventPort implementation.");
registerSignalHandler(signum); registerSignalHandler(signum);
} }
void UnixEventLoop::prepareToSleep() noexcept { class UnixEventPort::PollContext {
waitThread = pthread_self(); public:
__atomic_store_n(&isSleeping, true, __ATOMIC_RELEASE); PollContext(PollPromiseAdapter* ptr) {
}
void UnixEventLoop::sleep() {
SignalCapture capture;
threadCapture = &capture;
if (sigsetjmp(capture.jumpTo, true)) {
// We received a signal and longjmp'd back out of the signal handler.
threadCapture = nullptr;
__atomic_store_n(&isSleeping, false, __ATOMIC_RELAXED);
if (capture.siginfo.si_signo != SIGUSR1) {
// Fire any events waiting on this signal.
auto ptr = signalHead;
while (ptr != nullptr) { while (ptr != nullptr) {
if (ptr->signum == capture.siginfo.si_signo) { struct pollfd pollfd;
ptr->fulfiller.fulfill(kj::cp(capture.siginfo)); memset(&pollfd, 0, sizeof(pollfd));
ptr = ptr->removeFromList(); pollfd.fd = ptr->fd;
} else { pollfd.events = ptr->eventMask;
pollfds.add(pollfd);
pollEvents.add(ptr);
ptr = ptr->next; ptr = ptr->next;
} }
} }
void run(int timeout) {
do {
pollResult = ::poll(pollfds.begin(), pollfds.size(), timeout);
pollError = pollResult < 0 ? errno : 0;
// EINTR should only happen if we received a signal *other than* the ones registered via
// the UnixEventPort, so we don't care about that case.
} while (pollError == EINTR);
}
void processResults() {
if (pollResult < 0) {
KJ_FAIL_SYSCALL("poll()", pollError);
} }
return; for (auto i: indices(pollfds)) {
if (pollfds[i].revents != 0) {
pollEvents[i]->fulfiller.fulfill(kj::mv(pollfds[i].revents));
pollEvents[i]->removeFromList();
if (--pollResult <= 0) {
break;
}
}
}
} }
private:
kj::Vector<struct pollfd> pollfds;
kj::Vector<PollPromiseAdapter*> pollEvents;
int pollResult = 0;
int pollError = 0;
};
void UnixEventPort::wait() {
sigset_t newMask; sigset_t newMask;
sigemptyset(&newMask); sigemptyset(&newMask);
sigaddset(&newMask, SIGUSR1); sigaddset(&newMask, SIGUSR1);
...@@ -216,62 +234,94 @@ void UnixEventLoop::sleep() { ...@@ -216,62 +234,94 @@ void UnixEventLoop::sleep() {
} }
} }
kj::Vector<struct pollfd> pollfds; PollContext pollContext(pollHead);
kj::Vector<PollPromiseAdapter*> pollEvents;
{ // Capture signals.
auto ptr = pollHead; SignalCapture capture;
while (ptr != nullptr) {
struct pollfd pollfd; if (sigsetjmp(capture.jumpTo, true)) {
memset(&pollfd, 0, sizeof(pollfd)); // We received a signal and longjmp'd back out of the signal handler.
pollfd.fd = ptr->fd; threadCapture = nullptr;
pollfd.events = ptr->eventMask;
pollfds.add(pollfd); if (capture.siginfo.si_signo != SIGUSR1) {
pollEvents.add(ptr); gotSignal(capture.siginfo);
ptr = ptr->next;
} }
return;
} }
// Enable signals, run the poll, then mask them again.
sigset_t origMask; sigset_t origMask;
threadCapture = &capture;
sigprocmask(SIG_UNBLOCK, &newMask, &origMask); sigprocmask(SIG_UNBLOCK, &newMask, &origMask);
int pollResult; pollContext.run(-1);
int pollError;
do {
pollResult = poll(pollfds.begin(), pollfds.size(), -1);
pollError = pollResult < 0 ? errno : 0;
// EINTR should only happen if we received a signal *other than* the ones registered via
// the UnixEventLoop, so we don't care about that case.
} while (pollError == EINTR);
sigprocmask(SIG_SETMASK, &origMask, nullptr); sigprocmask(SIG_SETMASK, &origMask, nullptr);
threadCapture = nullptr; threadCapture = nullptr;
__atomic_store_n(&isSleeping, false, __ATOMIC_RELAXED);
if (pollResult < 0) { // Queue events.
KJ_FAIL_SYSCALL("poll()", pollError); pollContext.processResults();
}
void UnixEventPort::poll() {
sigset_t pending;
sigset_t waitMask;
sigemptyset(&pending);
sigfillset(&waitMask);
// Count how many signals that we care about are pending.
KJ_SYSCALL(sigpending(&pending));
uint signalCount = 0;
{
auto ptr = signalHead;
while (ptr != nullptr) {
if (sigismember(&pending, ptr->signum)) {
++signalCount;
sigdelset(&pending, ptr->signum);
sigdelset(&waitMask, ptr->signum);
}
ptr = ptr->next;
}
} }
for (auto i: indices(pollfds)) { // Wait for each pending signal. It would be nice to use sigtimedwait() here but it is not
if (pollfds[i].revents != 0) { // available on OSX. :( Instead, we call sigsuspend() once per expected signal.
pollEvents[i]->fulfiller.fulfill(kj::mv(pollfds[i].revents)); while (signalCount-- > 0) {
pollEvents[i]->removeFromList(); SignalCapture capture;
if (--pollResult <= 0) { threadCapture = &capture;
break; if (sigsetjmp(capture.jumpTo, true)) {
// We received a signal and longjmp'd back out of the signal handler.
KJ_DBG("unsuspend", signalCount);
sigdelset(&waitMask, capture.siginfo.si_signo);
gotSignal(capture.siginfo);
} else {
KJ_DBG("suspend", signalCount);
sigsuspend(&waitMask);
KJ_FAIL_ASSERT("sigsuspend() shouldn't return because the signal handler should "
"have siglongjmp()ed.");
} }
threadCapture = nullptr;
} }
{
PollContext pollContext(pollHead);
pollContext.run(0);
pollContext.processResults();
} }
} }
void UnixEventLoop::wake() const { void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
// The first load is a fast-path check -- if false, we can avoid a barrier. If true, then we // Fire any events waiting on this signal.
// follow up with an exchange to set it false. If it turns out we were in fact the one thread auto ptr = signalHead;
// to transition the value from true to false, then we go ahead and raise SIGUSR1 on the target while (ptr != nullptr) {
// thread to wake it up. if (ptr->signum == siginfo.si_signo) {
if (__atomic_load_n(&isSleeping, __ATOMIC_RELAXED) && ptr->fulfiller.fulfill(kj::cp(siginfo));
__atomic_exchange_n(&isSleeping, false, __ATOMIC_ACQUIRE)) { ptr = ptr->removeFromList();
pthread_kill(waitThread, SIGUSR1); } else {
ptr = ptr->next;
}
} }
} }
......
...@@ -32,8 +32,10 @@ ...@@ -32,8 +32,10 @@
namespace kj { namespace kj {
class UnixEventLoop: public EventLoop { class UnixEventPort: public EventPort {
// An EventLoop implementation which can wait for events on file descriptors as well as signals. // THIS INTERFACE IS LIKELY TO CHANGE; consider using only what is defined in async-io.h instead.
//
// An EventPort implementation which can wait for events on file descriptors as well as signals.
// This API only makes sense on Unix. // This API only makes sense on Unix.
// //
// The implementation uses `poll()` or possibly a platform-specific API (e.g. epoll, kqueue). // The implementation uses `poll()` or possibly a platform-specific API (e.g. epoll, kqueue).
...@@ -45,8 +47,8 @@ class UnixEventLoop: public EventLoop { ...@@ -45,8 +47,8 @@ class UnixEventLoop: public EventLoop {
// purposes. // purposes.
public: public:
UnixEventLoop(); UnixEventPort();
~UnixEventLoop(); ~UnixEventPort();
Promise<short> onFdEvent(int fd, short eventMask); Promise<short> onFdEvent(int fd, short eventMask);
// `eventMask` is a bitwise-OR of poll events (e.g. `POLLIN`, `POLLOUT`, etc.). The next time // `eventMask` is a bitwise-OR of poll events (e.g. `POLLIN`, `POLLOUT`, etc.). The next time
...@@ -66,7 +68,7 @@ public: ...@@ -66,7 +68,7 @@ public:
// The result of waiting on the same signal twice at once is undefined. // The result of waiting on the same signal twice at once is undefined.
static void captureSignal(int signum); static void captureSignal(int signum);
// Arranges for the given signal to be captured and handled via UnixEventLoop, so that you may // Arranges for the given signal to be captured and handled via UnixEventPort, so that you may
// then pass it to `onSignal()`. This method is static because it registers a signal handler // then pass it to `onSignal()`. This method is static because it registers a signal handler
// which applies process-wide. If any other threads exist in the process when `captureSignal()` // which applies process-wide. If any other threads exist in the process when `captureSignal()`
// is called, you *must* set the signal mask in those threads to block this signal, otherwise // is called, you *must* set the signal mask in those threads to block this signal, otherwise
...@@ -77,22 +79,21 @@ public: ...@@ -77,22 +79,21 @@ public:
// To un-capture a signal, simply install a different signal handler and then un-block it from // To un-capture a signal, simply install a different signal handler and then un-block it from
// the signal mask. // the signal mask.
protected: // implements EventPort ------------------------------------------------------
void prepareToSleep() noexcept override; void wait() override;
void sleep() override; void poll() override;
void wake() const override;
private: private:
class PollPromiseAdapter; class PollPromiseAdapter;
class SignalPromiseAdapter; class SignalPromiseAdapter;
class PollContext;
PollPromiseAdapter* pollHead = nullptr; PollPromiseAdapter* pollHead = nullptr;
PollPromiseAdapter** pollTail = &pollHead; PollPromiseAdapter** pollTail = &pollHead;
SignalPromiseAdapter* signalHead = nullptr; SignalPromiseAdapter* signalHead = nullptr;
SignalPromiseAdapter** signalTail = &signalHead; SignalPromiseAdapter** signalTail = &signalHead;
pthread_t waitThread; void gotSignal(const siginfo_t& siginfo);
mutable bool isSleeping = false;
}; };
} // namespace kj } // namespace kj
......
...@@ -49,6 +49,12 @@ static __thread EventLoop* threadLocalEventLoop = nullptr; ...@@ -49,6 +49,12 @@ static __thread EventLoop* threadLocalEventLoop = nullptr;
#define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) #define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1)
EventLoop& currentEventLoop() {
EventLoop* loop = threadLocalEventLoop;
KJ_REQUIRE(loop != nullptr, "No event loop is running on this thread.");
return *loop;
}
class BoolEvent: public _::Event { class BoolEvent: public _::Event {
public: public:
bool fired = false; bool fired = false;
...@@ -172,22 +178,35 @@ public: ...@@ -172,22 +178,35 @@ public:
LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler(); LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler();
class NullEventPort: public EventPort {
public:
void wait() override {
KJ_FAIL_REQUIRE("Nothing to wait for; this thread would hang forever.");
}
void poll() override {}
static NullEventPort instance;
};
NullEventPort NullEventPort::instance = NullEventPort();
} // namespace _ (private) } // namespace _ (private)
// ======================================================================================= // =======================================================================================
EventLoop& EventLoop::current() { void EventPort::setRunnable(bool runnable) {}
EventLoop* result = threadLocalEventLoop;
KJ_REQUIRE(result != nullptr, "No event loop is running on this thread.");
return *result;
}
bool EventLoop::isCurrent() const { EventLoop::EventLoop()
return threadLocalEventLoop == this; : port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
} }
EventLoop::EventLoop() EventLoop::EventLoop(EventPort& port)
: daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) { : port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop."); KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this; threadLocalEventLoop = this;
} }
...@@ -220,36 +239,21 @@ EventLoop::~EventLoop() noexcept(false) { ...@@ -220,36 +239,21 @@ EventLoop::~EventLoop() noexcept(false) {
} }
} }
void EventLoop::runForever() { void EventLoop::run(uint maxTurnCount) {
_::ExceptionOr<_::Void> result; for (uint i = 0; i < maxTurnCount; i++) {
waitImpl(kj::heap<NeverReadyPromiseNode>(), result); if (!turn()) {
KJ_UNREACHABLE; break;
}
}
} }
void EventLoop::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result) { bool EventLoop::turn() {
KJ_REQUIRE(threadLocalEventLoop == this,
"Can only call wait() in the thread that created this EventLoop.");
KJ_REQUIRE(!running, "wait() is not allowed from within event callbacks.");
BoolEvent doneEvent;
node->onReady(doneEvent);
running = true;
KJ_DEFER(running = false);
while (!doneEvent.fired) {
if (head == nullptr) {
// No events in the queue. Wait for callback.
prepareToSleep();
if (head != nullptr) {
// Whoa, new job was just added.
// TODO(now): Can't happen anymore?
wake();
}
sleep();
} else {
_::Event* event = head; _::Event* event = head;
if (event == nullptr) {
// No events in the queue.
return false;
} else {
head = event->next; head = event->next;
if (head != nullptr) { if (head != nullptr) {
head->prev = &head; head->prev = &head;
...@@ -269,32 +273,51 @@ void EventLoop::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result ...@@ -269,32 +273,51 @@ void EventLoop::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result
KJ_DEFER(event->firing = false); KJ_DEFER(event->firing = false);
eventToDestroy = event->fire(); eventToDestroy = event->fire();
} }
}
depthFirstInsertPoint = &head; depthFirstInsertPoint = &head;
return true;
}
}
namespace _ { // private
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result) {
EventLoop& loop = currentEventLoop();
KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks.");
BoolEvent doneEvent;
node->onReady(doneEvent);
loop.running = true;
KJ_DEFER(loop.running = false);
while (!doneEvent.fired) {
if (!loop.turn()) {
// No events in the queue. Wait for callback.
loop.port.wait();
}
} }
node->get(result); node->get(result);
KJ_IF_MAYBE(exception, runCatchingExceptions([&]() { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
node = nullptr; node = nullptr;
})) { })) {
result.addException(kj::mv(*exception)); result.addException(kj::mv(*exception));
} }
} }
Promise<void> EventLoop::yield() { Promise<void> yield() {
return Promise<void>(false, kj::heap<YieldPromiseNode>()); return Promise<void>(false, kj::heap<YieldPromiseNode>());
} }
void EventLoop::daemonize(kj::Promise<void>&& promise) { void daemonize(kj::Promise<void>&& promise) {
KJ_REQUIRE(daemons.get() != nullptr, "EventLoop is shutting down.") { return; } EventLoop& loop = currentEventLoop();
daemons->add(kj::mv(promise)); KJ_REQUIRE(loop.daemons.get() != nullptr, "EventLoop is shutting down.") { return; }
loop.daemons->add(kj::mv(promise));
} }
namespace _ { // private
Event::Event() Event::Event()
: loop(EventLoop::current()), next(nullptr), prev(nullptr) {} : loop(currentEventLoop()), next(nullptr), prev(nullptr) {}
Event::~Event() noexcept(false) { Event::~Event() noexcept(false) {
if (prev != nullptr) { if (prev != nullptr) {
...@@ -401,80 +424,6 @@ kj::String Event::trace() { ...@@ -401,80 +424,6 @@ kj::String Event::trace() {
// ======================================================================================= // =======================================================================================
#if KJ_USE_FUTEX
SimpleEventLoop::SimpleEventLoop() {}
SimpleEventLoop::~SimpleEventLoop() noexcept(false) {}
void SimpleEventLoop::prepareToSleep() noexcept {
__atomic_store_n(&preparedToSleep, 1, __ATOMIC_RELAXED);
}
void SimpleEventLoop::sleep() {
while (__atomic_load_n(&preparedToSleep, __ATOMIC_RELAXED) == 1) {
syscall(SYS_futex, &preparedToSleep, FUTEX_WAIT_PRIVATE, 1, NULL, NULL, 0);
}
}
void SimpleEventLoop::wake() const {
if (__atomic_exchange_n(&preparedToSleep, 0, __ATOMIC_RELAXED) != 0) {
// preparedToSleep was 1 before the exchange, so a sleep must be in progress in another thread.
syscall(SYS_futex, &preparedToSleep, FUTEX_WAKE_PRIVATE, 1, NULL, NULL, 0);
}
}
#else
#define KJ_PTHREAD_CALL(code) \
{ \
int pthreadError = code; \
if (pthreadError != 0) { \
KJ_FAIL_SYSCALL(#code, pthreadError); \
} \
}
#define KJ_PTHREAD_CLEANUP(code) \
{ \
int pthreadError = code; \
if (pthreadError != 0) { \
KJ_LOG(ERROR, #code, strerror(pthreadError)); \
} \
}
SimpleEventLoop::SimpleEventLoop() {
KJ_PTHREAD_CALL(pthread_mutex_init(&mutex, nullptr));
KJ_PTHREAD_CALL(pthread_cond_init(&condvar, nullptr));
}
SimpleEventLoop::~SimpleEventLoop() noexcept(false) {
KJ_PTHREAD_CLEANUP(pthread_cond_destroy(&condvar));
KJ_PTHREAD_CLEANUP(pthread_mutex_destroy(&mutex));
}
void SimpleEventLoop::prepareToSleep() noexcept {
__atomic_store_n(&preparedToSleep, 1, __ATOMIC_RELAXED);
}
void SimpleEventLoop::sleep() {
pthread_mutex_lock(&mutex);
while (__atomic_load_n(&preparedToSleep, __ATOMIC_RELAXED) == 1) {
pthread_cond_wait(&condvar, &mutex);
}
pthread_mutex_unlock(&mutex);
}
void SimpleEventLoop::wake() const {
pthread_mutex_lock(&mutex);
if (__atomic_exchange_n(&preparedToSleep, 0, __ATOMIC_RELAXED) != 0) {
// preparedToSleep was 1 before the exchange, so a sleep must be in progress in another thread.
pthread_cond_signal(&condvar);
}
pthread_mutex_unlock(&mutex);
}
#endif
// =======================================================================================
TaskSet::TaskSet(ErrorHandler& errorHandler) TaskSet::TaskSet(ErrorHandler& errorHandler)
: impl(heap<_::TaskSetImpl>(errorHandler)) {} : impl(heap<_::TaskSetImpl>(errorHandler)) {}
......
...@@ -26,14 +26,12 @@ ...@@ -26,14 +26,12 @@
#include "async-prelude.h" #include "async-prelude.h"
#include "exception.h" #include "exception.h"
#include "mutex.h"
#include "refcount.h" #include "refcount.h"
#include "tuple.h" #include "tuple.h"
namespace kj { namespace kj {
class EventLoop; class EventLoop;
class SimpleEventLoop;
template <typename T> template <typename T>
class Promise; class Promise;
...@@ -274,6 +272,7 @@ private: ...@@ -274,6 +272,7 @@ private:
template <typename> template <typename>
friend class _::ForkHub; friend class _::ForkHub;
friend class _::TaskSetImpl; friend class _::TaskSetImpl;
friend Promise<void> _::yield();
}; };
template <typename T> template <typename T>
...@@ -300,6 +299,11 @@ constexpr _::Void READY_NOW = _::Void(); ...@@ -300,6 +299,11 @@ constexpr _::Void READY_NOW = _::Void();
// Use this when you need a Promise<void> that is already fulfilled -- this value can be implicitly // Use this when you need a Promise<void> that is already fulfilled -- this value can be implicitly
// cast to `Promise<void>`. // cast to `Promise<void>`.
constexpr _::NeverDone NEVER_DONE = _::NeverDone();
// The opposite of `READY_NOW`, return this when the promise should never resolve. This can be
// implicitly converted to any promise type. You may also call `NEVER_DONE.wait()` to wait
// forever (useful for servers).
template <typename Func> template <typename Func>
PromiseForResult<Func, void> evalLater(Func&& func); PromiseForResult<Func, void> evalLater(Func&& func);
// Schedule for the given zero-parameter function to be executed in the event loop at some // Schedule for the given zero-parameter function to be executed in the event loop at some
...@@ -475,6 +479,40 @@ private: ...@@ -475,6 +479,40 @@ private:
// ======================================================================================= // =======================================================================================
// The EventLoop class // The EventLoop class
class EventPort {
// Interfaces between an `EventLoop` and events originating from outside of the loop's thread.
// All such events come in through the `EventPort` implementation.
//
// An `EventPort` implementation may interface with low-level operating system APIs and/or other
// threads. You can also write an `EventPort` which wraps some other (non-KJ) event loop
// framework, allowing the two to coexist in a single thread.
public:
virtual void wait() = 0;
// Wait for an external event to arrive, sleeping if necessary. Once at least one event has
// arrived, queue it to the event loop (e.g. by fulfilling a promise) and return.
//
// It is safe to return even if nothing has actually been queued, so long as calling `wait()` in
// a loop will eventually sleep. (That is to say, false positives are fine.)
//
// If the implementation knows that no event will ever arrive, it should throw an exception
// rather than deadlock.
//
// This is called only during `Promise::wait()`.
virtual void poll() = 0;
// Check if any external events have arrived, but do not sleep. If any events have arrived,
// add them to the event queue (e.g. by fulfilling promises) before returning.
//
// This is called only during `Promise::wait()`.
virtual void setRunnable(bool runnable);
// Called to notify the `EventPort` when the `EventLoop` has work to do; specifically when it
// transitions from empty -> runnable or runnable -> empty. This is typically useful when
// integrating with an external event loop; if the loop is currently runnable then you should
// arrange to call run() on it soon. The default implementation does nothing.
};
class EventLoop { class EventLoop {
// Represents a queue of events being executed in a loop. Most code won't interact with // Represents a queue of events being executed in a loop. Most code won't interact with
// EventLoop directly, but instead use `Promise`s to interact with it indirectly. See the // EventLoop directly, but instead use `Promise`s to interact with it indirectly. See the
...@@ -490,7 +528,8 @@ class EventLoop { ...@@ -490,7 +528,8 @@ class EventLoop {
// //
// int main() { // int main() {
// // `loop` becomes the official EventLoop for the thread. // // `loop` becomes the official EventLoop for the thread.
// SimpleEventLoop loop; // MyEventPort eventPort;
// EventLoop loop(eventPort);
// //
// // Now we can call an async function. // // Now we can call an async function.
// Promise<String> textPromise = getHttp("http://example.com"); // Promise<String> textPromise = getHttp("http://example.com");
...@@ -501,42 +540,30 @@ class EventLoop { ...@@ -501,42 +540,30 @@ class EventLoop {
// print(text); // print(text);
// return 0; // return 0;
// } // }
//
class EventJob; // Most applications that do I/O will prefer to use `setupIoEventLoop()` from `async-io.h` rather
// than allocate an `EventLoop` directly.
public: public:
EventLoop(); EventLoop();
~EventLoop() noexcept(false); // Construct an `EventLoop` which does not receive external events at all.
static EventLoop& current(); explicit EventLoop(EventPort& port);
// Get the event loop for the current thread. Throws an exception if no event loop is active. // Construct an `EventLoop` which receives external events through the given `EventPort`.
bool isCurrent() const; ~EventLoop() noexcept(false);
// Is this EventLoop the current one for this thread? This can safely be called from any thread.
void runForever() KJ_NORETURN;
// Runs the loop forever. Useful for servers.
protected:
// -----------------------------------------------------------------
// Subclasses should implement these.
virtual void prepareToSleep() noexcept = 0;
// Called just before `sleep()`. After calling this, the caller checks if any events are
// scheduled. If so, it calls `wake()`. Then, whether or not events were scheduled, it calls
// `sleep()`. Thus, `prepareToSleep()` is always followed by exactly one call to `sleep()`.
virtual void sleep() = 0; void run(uint maxTurnCount = maxValue);
// Do not return until `wake()` is called. Always preceded by a call to `prepareToSleep()`. // Run the event loop for `maxTurnCount` turns or until there is nothing left to be done,
// whichever comes first. This never calls the `EventPort`'s `sleep()` or `poll()`. It will
// call the `EventPort`'s `setRunnable(false)` if the queue becomes empty.
virtual void wake() const = 0; bool isRunnable();
// Cancel any calls to sleep() that occurred *after* the last call to `prepareToSleep()`. // Returns true if run() would currently do anything, or false if the queue is empty.
// May be called from a different thread. The interaction with `prepareToSleep()` is important:
// a `wake()` may occur between a call to `prepareToSleep()` and `sleep()`, in which case
// the subsequent `sleep()` must return immediately. `wake()` may be called any time an event
// is armed; it should return quickly if the loop isn't prepared to sleep.
private: private:
EventPort& port;
bool running = false; bool running = false;
// True while looping -- wait() is then not allowed. // True while looping -- wait() is then not allowed.
...@@ -546,56 +573,13 @@ private: ...@@ -546,56 +573,13 @@ private:
Own<_::TaskSetImpl> daemons; Own<_::TaskSetImpl> daemons;
template <typename T, typename Func, typename ErrorFunc> bool turn();
Own<_::PromiseNode> thenImpl(Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler);
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result);
// Run the event loop until `node` is fulfilled, and then `get()` its result into `result`.
Promise<void> yield();
// Returns a promise that won't resolve until all events currently on the queue are fired.
// Otherwise, returns an already-resolved promise. Used to implement evalLater().
template <typename T>
T wait(Promise<T>&& promise);
template <typename Func>
PromiseForResult<Func, void> evalLater(Func&& func) KJ_WARN_UNUSED_RESULT;
void daemonize(kj::Promise<void>&& promise);
template <typename> friend void _::daemonize(kj::Promise<void>&& promise);
friend class Promise; friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result);
friend Promise<void> yield();
template <typename ErrorFunc>
friend void daemonize(kj::Promise<void>&& promise, ErrorFunc&& errorHandler);
template <typename Func>
friend PromiseForResult<Func, void> evalLater(Func&& func);
friend class _::Event; friend class _::Event;
}; };
// -------------------------------------------------------------------
class SimpleEventLoop final: public EventLoop {
// A simple EventLoop implementation that does not know how to wait for any external I/O.
public:
SimpleEventLoop();
~SimpleEventLoop() noexcept(false);
protected:
void prepareToSleep() noexcept override;
void sleep() override;
void wake() const override;
private:
mutable int preparedToSleep = 0;
#if !KJ_USE_FUTEX
mutable pthread_mutex_t mutex;
mutable pthread_cond_t condvar;
#endif
};
} // namespace kj } // namespace kj
#include "async-inl.h" #include "async-inl.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment