Commit fecb6086 authored by Kenton Varda's avatar Kenton Varda

Introduce WaitScope for making it clear what functions might wait().

parent 23b792a6
This diff is collapsed.
...@@ -35,7 +35,7 @@ TEST(EzRpc, Basic) { ...@@ -35,7 +35,7 @@ TEST(EzRpc, Basic) {
server.exportCap("cap1", kj::heap<TestInterfaceImpl>(callCount)); server.exportCap("cap1", kj::heap<TestInterfaceImpl>(callCount));
server.exportCap("cap2", kj::heap<TestCallOrderImpl>()); server.exportCap("cap2", kj::heap<TestCallOrderImpl>());
EzRpcClient client("localhost", server.getPort().wait()); EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope()));
auto cap = client.importCap<test::TestInterface>("cap1"); auto cap = client.importCap<test::TestInterface>("cap1");
auto request = cap.fooRequest(); auto request = cap.fooRequest();
...@@ -43,14 +43,14 @@ TEST(EzRpc, Basic) { ...@@ -43,14 +43,14 @@ TEST(EzRpc, Basic) {
request.setJ(true); request.setJ(true);
EXPECT_EQ(0, callCount); EXPECT_EQ(0, callCount);
auto response = request.send().wait(); auto response = request.send().wait(server.getWaitScope());
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
EXPECT_EQ(1, callCount); EXPECT_EQ(1, callCount);
EXPECT_EQ(0, client.importCap("cap2").castAs<test::TestCallOrder>() EXPECT_EQ(0, client.importCap("cap2").castAs<test::TestCallOrder>()
.getCallSequenceRequest().send().wait().getN()); .getCallSequenceRequest().send().wait(server.getWaitScope()).getN());
EXPECT_EQ(1, client.importCap("cap2").castAs<test::TestCallOrder>() EXPECT_EQ(1, client.importCap("cap2").castAs<test::TestCallOrder>()
.getCallSequenceRequest().send().wait().getN()); .getCallSequenceRequest().send().wait(server.getWaitScope()).getN());
} }
} // namespace } // namespace
......
...@@ -46,6 +46,10 @@ public: ...@@ -46,6 +46,10 @@ public:
threadEzContext = nullptr; threadEzContext = nullptr;
} }
kj::WaitScope& getWaitScope() {
return ioContext.waitScope;
}
kj::AsyncIoProvider& getIoProvider() { kj::AsyncIoProvider& getIoProvider() {
return *ioContext.provider; return *ioContext.provider;
} }
...@@ -146,6 +150,10 @@ Capability::Client EzRpcClient::importCap(kj::StringPtr name) { ...@@ -146,6 +150,10 @@ Capability::Client EzRpcClient::importCap(kj::StringPtr name) {
} }
} }
kj::WaitScope& EzRpcClient::getWaitScope() {
return impl->context->getWaitScope();
}
kj::AsyncIoProvider& EzRpcClient::getIoProvider() { kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
return impl->context->getIoProvider(); return impl->context->getIoProvider();
} }
...@@ -273,6 +281,10 @@ kj::Promise<uint> EzRpcServer::getPort() { ...@@ -273,6 +281,10 @@ kj::Promise<uint> EzRpcServer::getPort() {
return impl->portPromise.addBranch(); return impl->portPromise.addBranch();
} }
kj::WaitScope& EzRpcServer::getWaitScope() {
return impl->context->getWaitScope();
}
kj::AsyncIoProvider& EzRpcServer::getIoProvider() { kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
return impl->context->getIoProvider(); return impl->context->getIoProvider();
} }
......
...@@ -47,7 +47,7 @@ class EzRpcClient { ...@@ -47,7 +47,7 @@ class EzRpcClient {
// auto request = adder.frobRequest(); // auto request = adder.frobRequest();
// request.setLeft(12); // request.setLeft(12);
// request.setRight(34); // request.setRight(34);
// auto response = request.wait(); // auto response = request.wait(client.getWaitScope());
// assert(response.getResult() == 46); // assert(response.getResult() == 46);
// return 0; // return 0;
// } // }
...@@ -64,7 +64,7 @@ class EzRpcClient { ...@@ -64,7 +64,7 @@ class EzRpcClient {
// int main() { // int main() {
// EzRpcServer server(":3456"); // EzRpcServer server(":3456");
// server.exportCap("adder", kj::heap<AdderImpl>()); // server.exportCap("adder", kj::heap<AdderImpl>());
// kj::NEVER_DONE.wait(); // kj::NEVER_DONE.wait(server.getWaitScope());
// } // }
// //
// This interface is easy, but it hides a lot of useful features available from the lower-level // This interface is easy, but it hides a lot of useful features available from the lower-level
...@@ -72,10 +72,11 @@ class EzRpcClient { ...@@ -72,10 +72,11 @@ class EzRpcClient {
// - The server can only export a small set of public, singleton capabilities under well-known // - The server can only export a small set of public, singleton capabilities under well-known
// string names. This is fine for transient services where no state needs to be kept between // string names. This is fine for transient services where no state needs to be kept between
// connections, but hides the power of Cap'n Proto when it comes to long-lived resources. // connections, but hides the power of Cap'n Proto when it comes to long-lived resources.
// - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop`. Only one `kj::EventLoop` // - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop` and make it current for the
// can exist per thread, so you cannot use these interfaces if you wish to set up your own // thread. Only one `kj::EventLoop` can exist per thread, so you cannot use these interfaces
// event loop. (However, you can safely create multiple EzRpcClient / EzRpcServer objects // if you wish to set up your own event loop. (However, you can safely create multiple
// in a single thread; they will make sure to make no more than one EventLoop.) // EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more
// than one EventLoop.)
// - These classes only support simple two-party connections, not multilateral VatNetworks. // - These classes only support simple two-party connections, not multilateral VatNetworks.
public: public:
...@@ -105,6 +106,10 @@ public: ...@@ -105,6 +106,10 @@ public:
// Ask the sever for the capability with the given name. You may specify a type to automatically // Ask the sever for the capability with the given name. You may specify a type to automatically
// down-cast to that type. It is up to you to specify the correct expected type. // down-cast to that type. It is up to you to specify the correct expected type.
kj::WaitScope& getWaitScope();
// Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on
// promises.
kj::AsyncIoProvider& getIoProvider(); kj::AsyncIoProvider& getIoProvider();
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want // Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion. // to do some non-RPC I/O in asynchronous fashion.
...@@ -159,6 +164,10 @@ public: ...@@ -159,6 +164,10 @@ public:
// the server is actually listening. If the address was not an IP address (e.g. it was a Unix // the server is actually listening. If the address was not an IP address (e.g. it was a Unix
// domain socket) then getPort() resolves to zero. // domain socket) then getPort() resolves to zero.
kj::WaitScope& getWaitScope();
// Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on
// promises.
kj::AsyncIoProvider& getIoProvider(); kj::AsyncIoProvider& getIoProvider();
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want // Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion. // to do some non-RPC I/O in asynchronous fashion.
......
This diff is collapsed.
...@@ -61,11 +61,12 @@ private: ...@@ -61,11 +61,12 @@ private:
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount) { kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
return ioProvider.newPipeThread( return ioProvider.newPipeThread(
[&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) { [&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream,
kj::WaitScope& waitScope) {
TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER); TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
TestRestorer restorer(callCount); TestRestorer restorer(callCount);
auto server = makeRpcServer(network, restorer); auto server = makeRpcServer(network, restorer);
network.onDisconnect().wait(); network.onDisconnect().wait(waitScope);
}); });
} }
...@@ -118,13 +119,13 @@ TEST(TwoPartyNetwork, Basic) { ...@@ -118,13 +119,13 @@ TEST(TwoPartyNetwork, Basic) {
EXPECT_EQ(0, callCount); EXPECT_EQ(0, callCount);
auto response1 = promise1.wait(); auto response1 = promise1.wait(ioContext.waitScope);
EXPECT_EQ("foo", response1.getX()); EXPECT_EQ("foo", response1.getX());
auto response2 = promise2.wait(); auto response2 = promise2.wait(ioContext.waitScope);
promise3.wait(); promise3.wait(ioContext.waitScope);
EXPECT_EQ(2, callCount); EXPECT_EQ(2, callCount);
EXPECT_TRUE(barFailed); EXPECT_TRUE(barFailed);
...@@ -170,10 +171,10 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -170,10 +171,10 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_EQ(0, callCount); EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount); EXPECT_EQ(0, reverseCallCount);
auto response = pipelinePromise.wait(); auto response = pipelinePromise.wait(ioContext.waitScope);
EXPECT_EQ("bar", response.getX()); EXPECT_EQ("bar", response.getX());
auto response2 = pipelinePromise2.wait(); auto response2 = pipelinePromise2.wait(ioContext.waitScope);
checkTestMessage(response2); checkTestMessage(response2);
EXPECT_EQ(3, callCount); EXPECT_EQ(3, callCount);
...@@ -187,7 +188,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -187,7 +188,7 @@ TEST(TwoPartyNetwork, Pipelining) {
serverThread.pipe->shutdownWrite(); serverThread.pipe->shutdownWrite();
// The other side should also disconnect. // The other side should also disconnect.
disconnectPromise.wait(); disconnectPromise.wait(ioContext.waitScope);
EXPECT_FALSE(drained); EXPECT_FALSE(drained);
{ {
...@@ -206,8 +207,8 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -206,8 +207,8 @@ TEST(TwoPartyNetwork, Pipelining) {
.castAs<test::TestExtends>().graultRequest(); .castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send(); auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(pipelinePromise.wait()); EXPECT_ANY_THROW(pipelinePromise.wait(ioContext.waitScope));
EXPECT_ANY_THROW(pipelinePromise2.wait()); EXPECT_ANY_THROW(pipelinePromise2.wait(ioContext.waitScope));
EXPECT_EQ(3, callCount); EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount); EXPECT_EQ(1, reverseCallCount);
...@@ -216,7 +217,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -216,7 +217,7 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_FALSE(drained); EXPECT_FALSE(drained);
} }
drainedPromise.wait(); drainedPromise.wait(ioContext.waitScope);
} }
} // namespace } // namespace
......
...@@ -132,7 +132,7 @@ TEST_F(SerializeAsyncTest, ParseAsync) { ...@@ -132,7 +132,7 @@ TEST_F(SerializeAsyncTest, ParseAsync) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = readMessage(*input).wait(); auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
...@@ -150,7 +150,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) { ...@@ -150,7 +150,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = readMessage(*input).wait(); auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
...@@ -168,7 +168,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) { ...@@ -168,7 +168,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
writeMessage(output, message); writeMessage(output, message);
}); });
auto received = readMessage(*input).wait(); auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>()); checkTestMessage(received->getRoot<TestAllTypes>());
} }
...@@ -193,7 +193,7 @@ TEST_F(SerializeAsyncTest, WriteAsync) { ...@@ -193,7 +193,7 @@ TEST_F(SerializeAsyncTest, WriteAsync) {
} }
}); });
writeMessage(*output, message).wait(); writeMessage(*output, message).wait(ioContext.waitScope);
} }
TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) { TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
...@@ -216,7 +216,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) { ...@@ -216,7 +216,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
} }
}); });
writeMessage(*output, message).wait(); writeMessage(*output, message).wait(ioContext.waitScope);
} }
TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) { TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
...@@ -239,7 +239,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) { ...@@ -239,7 +239,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
} }
}); });
writeMessage(*output, message).wait(); writeMessage(*output, message).wait(ioContext.waitScope);
} }
} // namespace } // namespace
......
...@@ -588,10 +588,10 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler ...@@ -588,10 +588,10 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler
} }
template <typename T> template <typename T>
T Promise<T>::wait() { T Promise<T>::wait(WaitScope& waitScope) {
_::ExceptionOr<_::FixVoid<T>> result; _::ExceptionOr<_::FixVoid<T>> result;
waitImpl(kj::mv(node), result); waitImpl(kj::mv(node), result, waitScope);
KJ_IF_MAYBE(value, result.value) { KJ_IF_MAYBE(value, result.value) {
KJ_IF_MAYBE(exception, result.exception) { KJ_IF_MAYBE(exception, result.exception) {
......
...@@ -62,33 +62,34 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -62,33 +62,34 @@ TEST(AsyncIo, SimpleNetwork) {
}).then([&](size_t n) { }).then([&](size_t n) {
EXPECT_EQ(3u, n); EXPECT_EQ(3u, n);
return heapString(receiveBuffer, n); return heapString(receiveBuffer, n);
}).wait(); }).wait(ioContext.waitScope);
EXPECT_EQ("foo", result); EXPECT_EQ("foo", result);
} }
String tryParse(Network& network, StringPtr text, uint portHint = 0) { String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint portHint = 0) {
return network.parseAddress(text, portHint).wait()->toString(); return network.parseAddress(text, portHint).wait(waitScope)->toString();
} }
TEST(AsyncIo, AddressParsing) { TEST(AsyncIo, AddressParsing) {
auto ioContext = setupAsyncIo(); auto ioContext = setupAsyncIo();
auto& w = ioContext.waitScope;
auto& network = ioContext.provider->getNetwork(); auto& network = ioContext.provider->getNetwork();
EXPECT_EQ("*:0", tryParse(network, "*")); EXPECT_EQ("*:0", tryParse(w, network, "*"));
EXPECT_EQ("*:123", tryParse(network, "*:123")); EXPECT_EQ("*:123", tryParse(w, network, "*:123"));
EXPECT_EQ("[::]:123", tryParse(network, "0::0", 123)); EXPECT_EQ("[::]:123", tryParse(w, network, "0::0", 123));
EXPECT_EQ("0.0.0.0:0", tryParse(network, "0.0.0.0")); EXPECT_EQ("0.0.0.0:0", tryParse(w, network, "0.0.0.0"));
EXPECT_EQ("1.2.3.4:5678", tryParse(network, "1.2.3.4", 5678)); EXPECT_EQ("1.2.3.4:5678", tryParse(w, network, "1.2.3.4", 5678));
EXPECT_EQ("[12ab:cd::34]:321", tryParse(network, "[12ab:cd:0::0:34]:321", 432)); EXPECT_EQ("[12ab:cd::34]:321", tryParse(w, network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("unix:foo/bar/baz", tryParse(network, "unix:foo/bar/baz")); EXPECT_EQ("unix:foo/bar/baz", tryParse(w, network, "unix:foo/bar/baz"));
// We can parse services by name... // We can parse services by name...
EXPECT_EQ("1.2.3.4:80", tryParse(network, "1.2.3.4:http", 5678)); EXPECT_EQ("1.2.3.4:80", tryParse(w, network, "1.2.3.4:http", 5678));
EXPECT_EQ("[::]:80", tryParse(network, "[::]:http", 5678)); EXPECT_EQ("[::]:80", tryParse(w, network, "[::]:http", 5678));
EXPECT_EQ("[12ab:cd::34]:80", tryParse(network, "[12ab:cd::34]:http", 5678)); EXPECT_EQ("[12ab:cd::34]:80", tryParse(w, network, "[12ab:cd::34]:http", 5678));
EXPECT_EQ("*:80", tryParse(network, "*:http", 5678)); EXPECT_EQ("*:80", tryParse(w, network, "*:http", 5678));
// It would be nice to test DNS lookup here but the test would not be very hermetic. Even // It would be nice to test DNS lookup here but the test would not be very hermetic. Even
// localhost can map to different addresses depending on whether IPv6 is enabled. We do // localhost can map to different addresses depending on whether IPv6 is enabled. We do
...@@ -108,7 +109,7 @@ TEST(AsyncIo, OneWayPipe) { ...@@ -108,7 +109,7 @@ TEST(AsyncIo, OneWayPipe) {
kj::String result = pipe.in->tryRead(receiveBuffer, 3, 4).then([&](size_t n) { kj::String result = pipe.in->tryRead(receiveBuffer, 3, 4).then([&](size_t n) {
EXPECT_EQ(3u, n); EXPECT_EQ(3u, n);
return heapString(receiveBuffer, n); return heapString(receiveBuffer, n);
}).wait(); }).wait(ioContext.waitScope);
EXPECT_EQ("foo", result); EXPECT_EQ("foo", result);
} }
...@@ -132,9 +133,9 @@ TEST(AsyncIo, TwoWayPipe) { ...@@ -132,9 +133,9 @@ TEST(AsyncIo, TwoWayPipe) {
}).then([&](size_t n) { }).then([&](size_t n) {
EXPECT_EQ(3u, n); EXPECT_EQ(3u, n);
return heapString(receiveBuffer2, n); return heapString(receiveBuffer2, n);
}).wait(); }).wait(ioContext.waitScope);
kj::String result2 = promise.wait(); kj::String result2 = promise.wait(ioContext.waitScope);
EXPECT_EQ("foo", result); EXPECT_EQ("foo", result);
EXPECT_EQ("bar", result2); EXPECT_EQ("bar", result2);
...@@ -144,19 +145,19 @@ TEST(AsyncIo, PipeThread) { ...@@ -144,19 +145,19 @@ TEST(AsyncIo, PipeThread) {
auto ioContext = setupAsyncIo(); auto ioContext = setupAsyncIo();
auto pipeThread = ioContext.provider->newPipeThread( auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) { [](AsyncIoProvider& ioProvider, AsyncIoStream& stream, WaitScope& waitScope) {
char buf[4]; char buf[4];
stream.write("foo", 3).wait(); stream.write("foo", 3).wait(waitScope);
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait()); EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait(waitScope));
EXPECT_EQ("bar", heapString(buf, 3)); EXPECT_EQ("bar", heapString(buf, 3));
// Expect disconnect. // Expect disconnect.
EXPECT_EQ(0, stream.tryRead(buf, 1, 1).wait()); EXPECT_EQ(0, stream.tryRead(buf, 1, 1).wait(waitScope));
}); });
char buf[4]; char buf[4];
pipeThread.pipe->write("bar", 3).wait(); pipeThread.pipe->write("bar", 3).wait(ioContext.waitScope);
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait()); EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait(ioContext.waitScope));
EXPECT_EQ("foo", heapString(buf, 3)); EXPECT_EQ("foo", heapString(buf, 3));
} }
...@@ -166,21 +167,21 @@ TEST(AsyncIo, PipeThreadDisconnects) { ...@@ -166,21 +167,21 @@ TEST(AsyncIo, PipeThreadDisconnects) {
auto ioContext = setupAsyncIo(); auto ioContext = setupAsyncIo();
auto pipeThread = ioContext.provider->newPipeThread( auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) { [](AsyncIoProvider& ioProvider, AsyncIoStream& stream, WaitScope& waitScope) {
char buf[4]; char buf[4];
stream.write("foo", 3).wait(); stream.write("foo", 3).wait(waitScope);
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait()); EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait(waitScope));
EXPECT_EQ("bar", heapString(buf, 3)); EXPECT_EQ("bar", heapString(buf, 3));
}); });
char buf[4]; char buf[4];
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait()); EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait(ioContext.waitScope));
EXPECT_EQ("foo", heapString(buf, 3)); EXPECT_EQ("foo", heapString(buf, 3));
pipeThread.pipe->write("bar", 3).wait(); pipeThread.pipe->write("bar", 3).wait(ioContext.waitScope);
// Expect disconnect. // Expect disconnect.
EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait()); EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait(ioContext.waitScope));
} }
} // namespace } // namespace
......
...@@ -716,7 +716,9 @@ public: ...@@ -716,7 +716,9 @@ public:
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public: public:
LowLevelAsyncIoProviderImpl(): eventLoop(eventPort) {} LowLevelAsyncIoProviderImpl(): eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; }
Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override { Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags); return heap<AsyncStreamFd>(eventPort, fd, flags);
...@@ -747,6 +749,7 @@ public: ...@@ -747,6 +749,7 @@ public:
private: private:
UnixEventPort eventPort; UnixEventPort eventPort;
EventLoop eventLoop; EventLoop eventLoop;
WaitScope waitScope;
}; };
// ======================================================================================= // =======================================================================================
...@@ -889,7 +892,7 @@ public: ...@@ -889,7 +892,7 @@ public:
} }
PipeThread newPipeThread( PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) override { Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
int fds[2]; int fds[2];
int type = SOCK_STREAM; int type = SOCK_STREAM;
#if __linux__ #if __linux__
...@@ -903,11 +906,11 @@ public: ...@@ -903,11 +906,11 @@ public:
auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
auto thread = heap<Thread>(kj::mvCapture(startFunc, auto thread = heap<Thread>(kj::mvCapture(startFunc,
[threadFd](Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)>&& startFunc) { [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
LowLevelAsyncIoProviderImpl lowLevel; LowLevelAsyncIoProviderImpl lowLevel;
auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
AsyncIoProviderImpl ioProvider(lowLevel); AsyncIoProviderImpl ioProvider(lowLevel);
startFunc(ioProvider, *stream); startFunc(ioProvider, *stream, lowLevel.getWaitScope());
})); }));
return { kj::mv(thread), kj::mv(pipe) }; return { kj::mv(thread), kj::mv(pipe) };
...@@ -931,7 +934,8 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) { ...@@ -931,7 +934,8 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
AsyncIoContext setupAsyncIo() { AsyncIoContext setupAsyncIo() {
auto lowLevel = heap<LowLevelAsyncIoProviderImpl>(); auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel); auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
return { kj::mv(lowLevel), kj::mv(ioProvider) }; auto& waitScope = lowLevel->getWaitScope();
return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope };
} }
} // namespace kj } // namespace kj
...@@ -188,7 +188,7 @@ public: ...@@ -188,7 +188,7 @@ public:
}; };
virtual PipeThread newPipeThread( virtual PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) = 0; Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) = 0;
// Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate // Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate
// with it. One end of the pipe is passed to the thread's start function and the other end of // with it. One end of the pipe is passed to the thread's start function and the other end of
// the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will // the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will
...@@ -278,6 +278,7 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel); ...@@ -278,6 +278,7 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel);
struct AsyncIoContext { struct AsyncIoContext {
Own<LowLevelAsyncIoProvider> lowLevelProvider; Own<LowLevelAsyncIoProvider> lowLevelProvider;
Own<AsyncIoProvider> provider; Own<AsyncIoProvider> provider;
WaitScope& waitScope;
}; };
AsyncIoContext setupAsyncIo(); AsyncIoContext setupAsyncIo();
......
...@@ -34,6 +34,7 @@ namespace kj { ...@@ -34,6 +34,7 @@ namespace kj {
class EventLoop; class EventLoop;
template <typename T> template <typename T>
class Promise; class Promise;
class WaitScope;
namespace _ { // private namespace _ { // private
...@@ -173,7 +174,7 @@ private: ...@@ -173,7 +174,7 @@ private:
}; };
void daemonize(kj::Promise<void>&& promise); void daemonize(kj::Promise<void>&& promise);
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result); void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope);
Promise<void> yield(); Promise<void> yield();
class NeverDone { class NeverDone {
......
...@@ -30,63 +30,69 @@ namespace { ...@@ -30,63 +30,69 @@ namespace {
TEST(Async, EvalVoid) { TEST(Async, EvalVoid) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
bool done = false; bool done = false;
Promise<void> promise = evalLater([&]() { done = true; }); Promise<void> promise = evalLater([&]() { done = true; });
EXPECT_FALSE(done); EXPECT_FALSE(done);
promise.wait(); promise.wait(waitScope);
EXPECT_TRUE(done); EXPECT_TRUE(done);
} }
TEST(Async, EvalInt) { TEST(Async, EvalInt) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
bool done = false; bool done = false;
Promise<int> promise = evalLater([&]() { done = true; return 123; }); Promise<int> promise = evalLater([&]() { done = true; return 123; });
EXPECT_FALSE(done); EXPECT_FALSE(done);
EXPECT_EQ(123, promise.wait()); EXPECT_EQ(123, promise.wait(waitScope));
EXPECT_TRUE(done); EXPECT_TRUE(done);
} }
TEST(Async, There) { TEST(Async, There) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> a = 123; Promise<int> a = 123;
bool done = false; bool done = false;
Promise<int> promise = a.then([&](int ai) { done = true; return ai + 321; }); Promise<int> promise = a.then([&](int ai) { done = true; return ai + 321; });
EXPECT_FALSE(done); EXPECT_FALSE(done);
EXPECT_EQ(444, promise.wait()); EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(done); EXPECT_TRUE(done);
} }
TEST(Async, ThereVoid) { TEST(Async, ThereVoid) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> a = 123; Promise<int> a = 123;
int value = 0; int value = 0;
Promise<void> promise = a.then([&](int ai) { value = ai; }); Promise<void> promise = a.then([&](int ai) { value = ai; });
EXPECT_EQ(0, value); EXPECT_EQ(0, value);
promise.wait(); promise.wait(waitScope);
EXPECT_EQ(123, value); EXPECT_EQ(123, value);
} }
TEST(Async, Exception) { TEST(Async, Exception) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
EXPECT_TRUE(kj::runCatchingExceptions([&]() { EXPECT_TRUE(kj::runCatchingExceptions([&]() {
// wait() only returns when compiling with -fno-exceptions. // wait() only returns when compiling with -fno-exceptions.
EXPECT_EQ(123, promise.wait()); EXPECT_EQ(123, promise.wait(waitScope));
}) != nullptr); }) != nullptr);
} }
TEST(Async, HandleException) { TEST(Async, HandleException) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -96,11 +102,12 @@ TEST(Async, HandleException) { ...@@ -96,11 +102,12 @@ TEST(Async, HandleException) {
[](int i) { return i + 1; }, [](int i) { return i + 1; },
[&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; }); [&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; });
EXPECT_EQ(345, promise.wait()); EXPECT_EQ(345, promise.wait(waitScope));
} }
TEST(Async, PropagateException) { TEST(Async, PropagateException) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -112,11 +119,12 @@ TEST(Async, PropagateException) { ...@@ -112,11 +119,12 @@ TEST(Async, PropagateException) {
[](int i) { return i + 2; }, [](int i) { return i + 2; },
[&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; }); [&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; });
EXPECT_EQ(345, promise.wait()); EXPECT_EQ(345, promise.wait(waitScope));
} }
TEST(Async, PropagateExceptionTypeChange) { TEST(Async, PropagateExceptionTypeChange) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater( Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } }); [&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
...@@ -128,11 +136,12 @@ TEST(Async, PropagateExceptionTypeChange) { ...@@ -128,11 +136,12 @@ TEST(Async, PropagateExceptionTypeChange) {
[](StringPtr s) -> StringPtr { return "bar"; }, [](StringPtr s) -> StringPtr { return "bar"; },
[&](Exception&& e) -> StringPtr { EXPECT_EQ(line, e.getLine()); return "baz"; }); [&](Exception&& e) -> StringPtr { EXPECT_EQ(line, e.getLine()); return "baz"; });
EXPECT_EQ("baz", promise2.wait()); EXPECT_EQ("baz", promise2.wait(waitScope));
} }
TEST(Async, Then) { TEST(Async, Then) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
bool done = false; bool done = false;
...@@ -143,13 +152,14 @@ TEST(Async, Then) { ...@@ -143,13 +152,14 @@ TEST(Async, Then) {
EXPECT_FALSE(done); EXPECT_FALSE(done);
EXPECT_EQ(444, promise.wait()); EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(done); EXPECT_TRUE(done);
} }
TEST(Async, Chain) { TEST(Async, Chain) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() -> int { return 123; }); Promise<int> promise = evalLater([&]() -> int { return 123; });
Promise<int> promise2 = evalLater([&]() -> int { return 321; }); Promise<int> promise2 = evalLater([&]() -> int { return 321; });
...@@ -160,11 +170,12 @@ TEST(Async, Chain) { ...@@ -160,11 +170,12 @@ TEST(Async, Chain) {
}); });
}); });
EXPECT_EQ(444, promise3.wait()); EXPECT_EQ(444, promise3.wait(waitScope));
} }
TEST(Async, SeparateFulfiller) { TEST(Async, SeparateFulfiller) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<int>(); auto pair = newPromiseAndFulfiller<int>();
...@@ -172,11 +183,12 @@ TEST(Async, SeparateFulfiller) { ...@@ -172,11 +183,12 @@ TEST(Async, SeparateFulfiller) {
pair.fulfiller->fulfill(123); pair.fulfiller->fulfill(123);
EXPECT_FALSE(pair.fulfiller->isWaiting()); EXPECT_FALSE(pair.fulfiller->isWaiting());
EXPECT_EQ(123, pair.promise.wait()); EXPECT_EQ(123, pair.promise.wait(waitScope));
} }
TEST(Async, SeparateFulfillerVoid) { TEST(Async, SeparateFulfillerVoid) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<void>(); auto pair = newPromiseAndFulfiller<void>();
...@@ -184,7 +196,7 @@ TEST(Async, SeparateFulfillerVoid) { ...@@ -184,7 +196,7 @@ TEST(Async, SeparateFulfillerVoid) {
pair.fulfiller->fulfill(); pair.fulfiller->fulfill();
EXPECT_FALSE(pair.fulfiller->isWaiting()); EXPECT_FALSE(pair.fulfiller->isWaiting());
pair.promise.wait(); pair.promise.wait(waitScope);
} }
TEST(Async, SeparateFulfillerCanceled) { TEST(Async, SeparateFulfillerCanceled) {
...@@ -197,6 +209,7 @@ TEST(Async, SeparateFulfillerCanceled) { ...@@ -197,6 +209,7 @@ TEST(Async, SeparateFulfillerCanceled) {
TEST(Async, SeparateFulfillerChained) { TEST(Async, SeparateFulfillerChained) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<Promise<int>>(); auto pair = newPromiseAndFulfiller<Promise<int>>();
auto inner = newPromiseAndFulfiller<int>(); auto inner = newPromiseAndFulfiller<int>();
...@@ -207,7 +220,7 @@ TEST(Async, SeparateFulfillerChained) { ...@@ -207,7 +220,7 @@ TEST(Async, SeparateFulfillerChained) {
inner.fulfiller->fulfill(123); inner.fulfiller->fulfill(123);
EXPECT_EQ(123, pair.promise.wait()); EXPECT_EQ(123, pair.promise.wait(waitScope));
} }
#if KJ_NO_EXCEPTIONS #if KJ_NO_EXCEPTIONS
...@@ -217,11 +230,12 @@ TEST(Async, SeparateFulfillerChained) { ...@@ -217,11 +230,12 @@ TEST(Async, SeparateFulfillerChained) {
TEST(Async, SeparateFulfillerDiscarded) { TEST(Async, SeparateFulfillerDiscarded) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<int>(); auto pair = newPromiseAndFulfiller<int>();
pair.fulfiller = nullptr; pair.fulfiller = nullptr;
EXPECT_ANY_THROW(pair.promise.wait()); EXPECT_ANY_THROW(pair.promise.wait(waitScope));
} }
TEST(Async, SeparateFulfillerMemoryLeak) { TEST(Async, SeparateFulfillerMemoryLeak) {
...@@ -231,6 +245,7 @@ TEST(Async, SeparateFulfillerMemoryLeak) { ...@@ -231,6 +245,7 @@ TEST(Async, SeparateFulfillerMemoryLeak) {
TEST(Async, Ordering) { TEST(Async, Ordering) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
int counter = 0; int counter = 0;
Promise<void> promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; Promise<void> promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
...@@ -285,7 +300,7 @@ TEST(Async, Ordering) { ...@@ -285,7 +300,7 @@ TEST(Async, Ordering) {
promises[0].eagerlyEvaluate(); promises[0].eagerlyEvaluate();
for (auto i: indices(promises)) { for (auto i: indices(promises)) {
kj::mv(promises[i]).wait(); kj::mv(promises[i]).wait(waitScope);
} }
EXPECT_EQ(7, counter); EXPECT_EQ(7, counter);
...@@ -293,6 +308,7 @@ TEST(Async, Ordering) { ...@@ -293,6 +308,7 @@ TEST(Async, Ordering) {
TEST(Async, Fork) { TEST(Async, Fork) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() { return 123; }); Promise<int> promise = evalLater([&]() { return 123; });
...@@ -311,8 +327,8 @@ TEST(Async, Fork) { ...@@ -311,8 +327,8 @@ TEST(Async, Fork) {
auto releaseFork = kj::mv(fork); auto releaseFork = kj::mv(fork);
} }
EXPECT_EQ(456, branch1.wait()); EXPECT_EQ(456, branch1.wait(waitScope));
EXPECT_EQ(789, branch2.wait()); EXPECT_EQ(789, branch2.wait(waitScope));
} }
struct RefcountedInt: public Refcounted { struct RefcountedInt: public Refcounted {
...@@ -323,6 +339,7 @@ struct RefcountedInt: public Refcounted { ...@@ -323,6 +339,7 @@ struct RefcountedInt: public Refcounted {
TEST(Async, ForkRef) { TEST(Async, ForkRef) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<Own<RefcountedInt>> promise = evalLater([&]() { Promise<Own<RefcountedInt>> promise = evalLater([&]() {
return refcounted<RefcountedInt>(123); return refcounted<RefcountedInt>(123);
...@@ -343,46 +360,50 @@ TEST(Async, ForkRef) { ...@@ -343,46 +360,50 @@ TEST(Async, ForkRef) {
auto releaseFork = kj::mv(fork); auto releaseFork = kj::mv(fork);
} }
EXPECT_EQ(456, branch1.wait()); EXPECT_EQ(456, branch1.wait(waitScope));
EXPECT_EQ(789, branch2.wait()); EXPECT_EQ(789, branch2.wait(waitScope));
} }
TEST(Async, ExclusiveJoin) { TEST(Async, ExclusiveJoin) {
{ {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = newPromiseAndFulfiller<int>(); // never fulfilled auto right = newPromiseAndFulfiller<int>(); // never fulfilled
left.exclusiveJoin(kj::mv(right.promise)); left.exclusiveJoin(kj::mv(right.promise));
EXPECT_EQ(123, left.wait()); EXPECT_EQ(123, left.wait(waitScope));
} }
{ {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto left = newPromiseAndFulfiller<int>(); // never fulfilled auto left = newPromiseAndFulfiller<int>(); // never fulfilled
auto right = evalLater([&]() { return 123; }); auto right = evalLater([&]() { return 123; });
left.promise.exclusiveJoin(kj::mv(right)); left.promise.exclusiveJoin(kj::mv(right));
EXPECT_EQ(123, left.promise.wait()); EXPECT_EQ(123, left.promise.wait(waitScope));
} }
{ {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; }); auto right = evalLater([&]() { return 456; });
left.exclusiveJoin(kj::mv(right)); left.exclusiveJoin(kj::mv(right));
EXPECT_EQ(123, left.wait()); EXPECT_EQ(123, left.wait(waitScope));
} }
{ {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; }); auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; }); auto right = evalLater([&]() { return 456; });
...@@ -391,7 +412,7 @@ TEST(Async, ExclusiveJoin) { ...@@ -391,7 +412,7 @@ TEST(Async, ExclusiveJoin) {
left.exclusiveJoin(kj::mv(right)); left.exclusiveJoin(kj::mv(right));
EXPECT_EQ(456, left.wait()); EXPECT_EQ(456, left.wait(waitScope));
} }
} }
...@@ -406,6 +427,7 @@ public: ...@@ -406,6 +427,7 @@ public:
TEST(Async, TaskSet) { TEST(Async, TaskSet) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
ErrorHandlerImpl errorHandler; ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler); TaskSet tasks(errorHandler);
...@@ -428,7 +450,7 @@ TEST(Async, TaskSet) { ...@@ -428,7 +450,7 @@ TEST(Async, TaskSet) {
evalLater([&]() { evalLater([&]() {
EXPECT_EQ(3, counter++); EXPECT_EQ(3, counter++);
}).wait(); }).wait(waitScope);
EXPECT_EQ(4, counter); EXPECT_EQ(4, counter);
EXPECT_EQ(1u, errorHandler.exceptionCount); EXPECT_EQ(1u, errorHandler.exceptionCount);
...@@ -447,6 +469,7 @@ TEST(Async, Attach) { ...@@ -447,6 +469,7 @@ TEST(Async, Attach) {
bool destroyed = false; bool destroyed = false;
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() { Promise<int> promise = evalLater([&]() {
EXPECT_FALSE(destroyed); EXPECT_FALSE(destroyed);
...@@ -461,7 +484,7 @@ TEST(Async, Attach) { ...@@ -461,7 +484,7 @@ TEST(Async, Attach) {
}); });
EXPECT_FALSE(destroyed); EXPECT_FALSE(destroyed);
EXPECT_EQ(444, promise.wait()); EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(destroyed); EXPECT_TRUE(destroyed);
} }
...@@ -469,23 +492,25 @@ TEST(Async, EagerlyEvaluate) { ...@@ -469,23 +492,25 @@ TEST(Async, EagerlyEvaluate) {
bool called = false; bool called = false;
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = Promise<void>(READY_NOW).then([&]() { Promise<void> promise = Promise<void>(READY_NOW).then([&]() {
called = true; called = true;
}); });
evalLater([]() {}).wait(); evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(called); EXPECT_FALSE(called);
promise.eagerlyEvaluate(); promise.eagerlyEvaluate();
evalLater([]() {}).wait(); evalLater([]() {}).wait(waitScope);
EXPECT_TRUE(called); EXPECT_TRUE(called);
} }
TEST(Async, Daemonize) { TEST(Async, Daemonize) {
EventLoop loop; EventLoop loop;
WaitScope waitScope(loop);
bool ran1 = false; bool ran1 = false;
bool ran2 = false; bool ran2 = false;
...@@ -499,7 +524,7 @@ TEST(Async, Daemonize) { ...@@ -499,7 +524,7 @@ TEST(Async, Daemonize) {
EXPECT_FALSE(ran2); EXPECT_FALSE(ran2);
EXPECT_FALSE(ran3); EXPECT_FALSE(ran3);
evalLater([]() {}).wait(); evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(ran1); EXPECT_FALSE(ran1);
EXPECT_TRUE(ran2); EXPECT_TRUE(ran2);
...@@ -522,6 +547,7 @@ public: ...@@ -522,6 +547,7 @@ public:
TEST(Async, SetRunnable) { TEST(Async, SetRunnable) {
DummyEventPort port; DummyEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
EXPECT_FALSE(port.runnable); EXPECT_FALSE(port.runnable);
EXPECT_EQ(0, port.callCount); EXPECT_EQ(0, port.callCount);
...@@ -535,7 +561,7 @@ TEST(Async, SetRunnable) { ...@@ -535,7 +561,7 @@ TEST(Async, SetRunnable) {
EXPECT_FALSE(port.runnable); EXPECT_FALSE(port.runnable);
EXPECT_EQ(2, port.callCount); EXPECT_EQ(2, port.callCount);
promise.wait(); promise.wait(waitScope);
EXPECT_FALSE(port.runnable); EXPECT_FALSE(port.runnable);
EXPECT_EQ(4, port.callCount); EXPECT_EQ(4, port.callCount);
} }
...@@ -556,7 +582,7 @@ TEST(Async, SetRunnable) { ...@@ -556,7 +582,7 @@ TEST(Async, SetRunnable) {
loop.run(10); loop.run(10);
EXPECT_FALSE(port.runnable); EXPECT_FALSE(port.runnable);
promise.wait(); promise.wait(waitScope);
EXPECT_FALSE(port.runnable); EXPECT_FALSE(port.runnable);
EXPECT_EQ(8, port.callCount); EXPECT_EQ(8, port.callCount);
......
...@@ -53,10 +53,11 @@ public: ...@@ -53,10 +53,11 @@ public:
TEST_F(AsyncUnixTest, Signals) { TEST_F(AsyncUnixTest, Signals) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
siginfo_t info = port.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
...@@ -70,13 +71,14 @@ TEST_F(AsyncUnixTest, SignalWithValue) { ...@@ -70,13 +71,14 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
union sigval value; union sigval value;
memset(&value, 0, sizeof(value)); memset(&value, 0, sizeof(value));
value.sival_int = 123; value.sival_int = 123;
sigqueue(getpid(), SIGUSR2, value); sigqueue(getpid(), SIGUSR2, value);
siginfo_t info = port.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_QUEUE, info.si_code); EXPECT_SI_CODE(SI_QUEUE, info.si_code);
EXPECT_EQ(123, info.si_value.sival_int); EXPECT_EQ(123, info.si_value.sival_int);
...@@ -86,6 +88,7 @@ TEST_F(AsyncUnixTest, SignalWithValue) { ...@@ -86,6 +88,7 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
TEST_F(AsyncUnixTest, SignalsMultiListen) { TEST_F(AsyncUnixTest, SignalsMultiListen) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
port.onSignal(SIGIO).then([](siginfo_t&&) { port.onSignal(SIGIO).then([](siginfo_t&&) {
ADD_FAILURE() << "Received wrong signal."; ADD_FAILURE() << "Received wrong signal.";
...@@ -95,7 +98,7 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) { ...@@ -95,7 +98,7 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) {
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
siginfo_t info = port.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
...@@ -103,15 +106,16 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) { ...@@ -103,15 +106,16 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) {
TEST_F(AsyncUnixTest, SignalsMultiReceive) { TEST_F(AsyncUnixTest, SignalsMultiReceive) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
kill(getpid(), SIGUSR2); kill(getpid(), SIGUSR2);
kill(getpid(), SIGIO); kill(getpid(), SIGIO);
siginfo_t info = port.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
info = port.onSignal(SIGIO).wait(); info = port.onSignal(SIGIO).wait(waitScope);
EXPECT_EQ(SIGIO, info.si_signo); EXPECT_EQ(SIGIO, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code); EXPECT_SI_CODE(SI_USER, info.si_code);
} }
...@@ -119,6 +123,7 @@ TEST_F(AsyncUnixTest, SignalsMultiReceive) { ...@@ -119,6 +123,7 @@ TEST_F(AsyncUnixTest, SignalsMultiReceive) {
TEST_F(AsyncUnixTest, SignalsAsync) { TEST_F(AsyncUnixTest, SignalsAsync) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
// Arrange for a signal to be sent from another thread. // Arrange for a signal to be sent from another thread.
pthread_t mainThread = pthread_self(); pthread_t mainThread = pthread_self();
...@@ -127,7 +132,7 @@ TEST_F(AsyncUnixTest, SignalsAsync) { ...@@ -127,7 +132,7 @@ TEST_F(AsyncUnixTest, SignalsAsync) {
pthread_kill(mainThread, SIGUSR2); pthread_kill(mainThread, SIGUSR2);
}); });
siginfo_t info = port.onSignal(SIGUSR2).wait(); siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo); EXPECT_EQ(SIGUSR2, info.si_signo);
#if __linux__ #if __linux__
EXPECT_SI_CODE(SI_TKILL, info.si_code); EXPECT_SI_CODE(SI_TKILL, info.si_code);
...@@ -139,6 +144,7 @@ TEST_F(AsyncUnixTest, SignalsNoWait) { ...@@ -139,6 +144,7 @@ TEST_F(AsyncUnixTest, SignalsNoWait) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
bool receivedSigusr2 = false; bool receivedSigusr2 = false;
bool receivedSigio = false; bool receivedSigio = false;
...@@ -178,18 +184,20 @@ TEST_F(AsyncUnixTest, SignalsNoWait) { ...@@ -178,18 +184,20 @@ TEST_F(AsyncUnixTest, SignalsNoWait) {
TEST_F(AsyncUnixTest, Poll) { TEST_F(AsyncUnixTest, Poll) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2]; int pipefds[2];
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); }); KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(pipe(pipefds)); KJ_SYSCALL(pipe(pipefds));
KJ_SYSCALL(write(pipefds[1], "foo", 3)); KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
} }
TEST_F(AsyncUnixTest, PollMultiListen) { TEST_F(AsyncUnixTest, PollMultiListen) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
int bogusPipefds[2]; int bogusPipefds[2];
KJ_SYSCALL(pipe(bogusPipefds)); KJ_SYSCALL(pipe(bogusPipefds));
...@@ -206,12 +214,13 @@ TEST_F(AsyncUnixTest, PollMultiListen) { ...@@ -206,12 +214,13 @@ TEST_F(AsyncUnixTest, PollMultiListen) {
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); }); KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(write(pipefds[1], "foo", 3)); KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
} }
TEST_F(AsyncUnixTest, PollMultiReceive) { TEST_F(AsyncUnixTest, PollMultiReceive) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2]; int pipefds[2];
KJ_SYSCALL(pipe(pipefds)); KJ_SYSCALL(pipe(pipefds));
...@@ -223,13 +232,14 @@ TEST_F(AsyncUnixTest, PollMultiReceive) { ...@@ -223,13 +232,14 @@ TEST_F(AsyncUnixTest, PollMultiReceive) {
KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); }); KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); });
KJ_SYSCALL(write(pipefds2[1], "bar", 3)); KJ_SYSCALL(write(pipefds2[1], "bar", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait(waitScope));
} }
TEST_F(AsyncUnixTest, PollAsync) { TEST_F(AsyncUnixTest, PollAsync) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
// Make a pipe and wait on its read end while another thread writes to it. // Make a pipe and wait on its read end while another thread writes to it.
int pipefds[2]; int pipefds[2];
...@@ -241,7 +251,7 @@ TEST_F(AsyncUnixTest, PollAsync) { ...@@ -241,7 +251,7 @@ TEST_F(AsyncUnixTest, PollAsync) {
}); });
// Wait for the event in this thread. // Wait for the event in this thread.
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait()); EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
} }
TEST_F(AsyncUnixTest, PollNoWait) { TEST_F(AsyncUnixTest, PollNoWait) {
...@@ -249,6 +259,7 @@ TEST_F(AsyncUnixTest, PollNoWait) { ...@@ -249,6 +259,7 @@ TEST_F(AsyncUnixTest, PollNoWait) {
UnixEventPort port; UnixEventPort port;
EventLoop loop(port); EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2]; int pipefds[2];
KJ_SYSCALL(pipe(pipefds)); KJ_SYSCALL(pipe(pipefds));
......
...@@ -199,26 +199,13 @@ void EventPort::setRunnable(bool runnable) {} ...@@ -199,26 +199,13 @@ void EventPort::setRunnable(bool runnable) {}
EventLoop::EventLoop() EventLoop::EventLoop()
: port(_::NullEventPort::instance), : port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) { daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
EventLoop::EventLoop(EventPort& port) EventLoop::EventLoop(EventPort& port)
: port(port), : port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) { daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
EventLoop::~EventLoop() noexcept(false) { EventLoop::~EventLoop() noexcept(false) {
KJ_REQUIRE(threadLocalEventLoop == this,
"EventLoop being destroyed in a different thread than it was created.") {
break;
}
KJ_DEFER(threadLocalEventLoop = nullptr);
// Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop // Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop
// some more. // some more.
daemons = nullptr; daemons = nullptr;
...@@ -237,6 +224,12 @@ EventLoop::~EventLoop() noexcept(false) { ...@@ -237,6 +224,12 @@ EventLoop::~EventLoop() noexcept(false) {
} }
break; break;
} }
KJ_REQUIRE(threadLocalEventLoop != this,
"EventLoop destroyed while still current for the thread.") {
threadLocalEventLoop = nullptr;
break;
}
} }
void EventLoop::run(uint maxTurnCount) { void EventLoop::run(uint maxTurnCount) {
...@@ -291,10 +284,24 @@ void EventLoop::setRunnable(bool runnable) { ...@@ -291,10 +284,24 @@ void EventLoop::setRunnable(bool runnable) {
} }
} }
void EventLoop::enterScope() {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
void EventLoop::leaveScope() {
KJ_REQUIRE(threadLocalEventLoop == this,
"WaitScope destroyed in a different thread than it was created in.") {
break;
}
threadLocalEventLoop = nullptr;
}
namespace _ { // private namespace _ { // private
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result) { void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope) {
EventLoop& loop = currentEventLoop(); EventLoop& loop = waitScope.loop;
KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread.");
KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks."); KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks.");
BoolEvent doneEvent; BoolEvent doneEvent;
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
namespace kj { namespace kj {
class EventLoop; class EventLoop;
class WaitScope;
template <typename T> template <typename T>
class Promise; class Promise;
...@@ -186,14 +187,10 @@ public: ...@@ -186,14 +187,10 @@ public:
// actual I/O. To solve this, use `kj::evalLater()` to yield control; this way, all other events // actual I/O. To solve this, use `kj::evalLater()` to yield control; this way, all other events
// in the queue will get a chance to run before your callback is executed. // in the queue will get a chance to run before your callback is executed.
T wait(); T wait(WaitScope& waitScope);
// Run the event loop until the promise is fulfilled, then return its result. If the promise // Run the event loop until the promise is fulfilled, then return its result. If the promise
// is rejected, throw an exception. // is rejected, throw an exception.
// //
// wait() cannot be called recursively -- that is, an event callback cannot call wait().
// Instead, callbacks that need to perform more async operations should return a promise and
// rely on promise chaining.
//
// wait() is primarily useful at the top level of a program -- typically, within the function // wait() is primarily useful at the top level of a program -- typically, within the function
// that allocated the EventLoop. For example, a program that performs one or two RPCs and then // that allocated the EventLoop. For example, a program that performs one or two RPCs and then
// exits would likely use wait() in its main() function to wait on each RPC. On the other hand, // exits would likely use wait() in its main() function to wait on each RPC. On the other hand,
...@@ -205,13 +202,27 @@ public: ...@@ -205,13 +202,27 @@ public:
// use `then()` to set an appropriate handler for the exception case, so that the promise you // use `then()` to set an appropriate handler for the exception case, so that the promise you
// actually wait on never throws. // actually wait on never throws.
// //
// `waitScope` is an object proving that the caller is in a scope where wait() is allowed. By
// convention, any function which might call wait(), or which might call another function which
// might call wait(), must take `WaitScope&` as one of its parameters. This is needed for two
// reasons:
// * `wait()` is not allowed during an event callback, because event callbacks are themselves
// called during some other `wait()`, and such recursive `wait()`s would only be able to
// complete in LIFO order, which might mean that the outer `wait()` ends up waiting longer
// than it is supposed to. To prevent this, a `WaitScope` cannot be constructed or used during
// an event callback.
// * Since `wait()` runs the event loop, unrelated event callbacks may execute before `wait()`
// returns. This means that anyone calling `wait()` must be reentrant -- state may change
// around them in arbitrary ways. Therefore, callers really need to know if a function they
// are calling might wait(), and the `WaitScope&` parameter makes this clear.
//
// Note that `wait()` consumes the promise on which it is called, in the sense of move semantics. // Note that `wait()` consumes the promise on which it is called, in the sense of move semantics.
// After returning, the original promise is no longer valid. // After returning, the original promise is no longer valid.
// //
// TODO(someday): Implement fibers, and let them call wait() even when they are handling an // TODO(someday): Implement fibers, and let them call wait() even when they are handling an
// event. // event.
ForkedPromise<T> fork(); ForkedPromise<T> fork() KJ_WARN_UNUSED_RESULT;
// Forks the promise, so that multiple different clients can independently wait on the result. // Forks the promise, so that multiple different clients can independently wait on the result.
// `T` must be copy-constructable for this to work. Or, in the special case where `T` is // `T` must be copy-constructable for this to work. Or, in the special case where `T` is
// `Own<U>`, `U` must have a method `Own<U> addRef()` which returns a new reference to the same // `Own<U>`, `U` must have a method `Own<U> addRef()` which returns a new reference to the same
...@@ -578,10 +589,35 @@ private: ...@@ -578,10 +589,35 @@ private:
bool turn(); bool turn();
void setRunnable(bool runnable); void setRunnable(bool runnable);
void enterScope();
void leaveScope();
friend void _::daemonize(kj::Promise<void>&& promise); friend void _::daemonize(kj::Promise<void>&& promise);
friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result); friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result,
WaitScope& waitScope);
friend class _::Event; friend class _::Event;
friend class WaitScope;
};
class WaitScope {
// Represents a scope in which asynchronous programming can occur. A `WaitScope` should usually
// be allocated on the stack and serves two purposes:
// * While the `WaitScope` exists, its `EventLoop` is registered as the current loop for the
// thread. Most operations dealing with `Promise` (including all of its methods) do not work
// unless the thread has a current `EventLoop`.
// * `WaitScope` may be passed to `Promise::wait()` to synchronously wait for a particular
// promise to complete. See `Promise::wait()` for an extended discussion.
public:
inline explicit WaitScope(EventLoop& loop): loop(loop) { loop.enterScope(); }
inline ~WaitScope() { loop.leaveScope(); }
KJ_DISALLOW_COPY(WaitScope);
private:
EventLoop& loop;
friend class EventLoop;
friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result,
WaitScope& waitScope);
}; };
} // namespace kj } // namespace kj
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment