Commit fecb6086 authored by Kenton Varda's avatar Kenton Varda

Introduce WaitScope for making it clear what functions might wait().

parent 23b792a6
This diff is collapsed.
......@@ -35,7 +35,7 @@ TEST(EzRpc, Basic) {
server.exportCap("cap1", kj::heap<TestInterfaceImpl>(callCount));
server.exportCap("cap2", kj::heap<TestCallOrderImpl>());
EzRpcClient client("localhost", server.getPort().wait());
EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope()));
auto cap = client.importCap<test::TestInterface>("cap1");
auto request = cap.fooRequest();
......@@ -43,14 +43,14 @@ TEST(EzRpc, Basic) {
request.setJ(true);
EXPECT_EQ(0, callCount);
auto response = request.send().wait();
auto response = request.send().wait(server.getWaitScope());
EXPECT_EQ("foo", response.getX());
EXPECT_EQ(1, callCount);
EXPECT_EQ(0, client.importCap("cap2").castAs<test::TestCallOrder>()
.getCallSequenceRequest().send().wait().getN());
.getCallSequenceRequest().send().wait(server.getWaitScope()).getN());
EXPECT_EQ(1, client.importCap("cap2").castAs<test::TestCallOrder>()
.getCallSequenceRequest().send().wait().getN());
.getCallSequenceRequest().send().wait(server.getWaitScope()).getN());
}
} // namespace
......
......@@ -46,6 +46,10 @@ public:
threadEzContext = nullptr;
}
kj::WaitScope& getWaitScope() {
return ioContext.waitScope;
}
kj::AsyncIoProvider& getIoProvider() {
return *ioContext.provider;
}
......@@ -146,6 +150,10 @@ Capability::Client EzRpcClient::importCap(kj::StringPtr name) {
}
}
kj::WaitScope& EzRpcClient::getWaitScope() {
return impl->context->getWaitScope();
}
kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
return impl->context->getIoProvider();
}
......@@ -273,6 +281,10 @@ kj::Promise<uint> EzRpcServer::getPort() {
return impl->portPromise.addBranch();
}
kj::WaitScope& EzRpcServer::getWaitScope() {
return impl->context->getWaitScope();
}
kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
return impl->context->getIoProvider();
}
......
......@@ -47,7 +47,7 @@ class EzRpcClient {
// auto request = adder.frobRequest();
// request.setLeft(12);
// request.setRight(34);
// auto response = request.wait();
// auto response = request.wait(client.getWaitScope());
// assert(response.getResult() == 46);
// return 0;
// }
......@@ -64,7 +64,7 @@ class EzRpcClient {
// int main() {
// EzRpcServer server(":3456");
// server.exportCap("adder", kj::heap<AdderImpl>());
// kj::NEVER_DONE.wait();
// kj::NEVER_DONE.wait(server.getWaitScope());
// }
//
// This interface is easy, but it hides a lot of useful features available from the lower-level
......@@ -72,10 +72,11 @@ class EzRpcClient {
// - The server can only export a small set of public, singleton capabilities under well-known
// string names. This is fine for transient services where no state needs to be kept between
// connections, but hides the power of Cap'n Proto when it comes to long-lived resources.
// - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop`. Only one `kj::EventLoop`
// can exist per thread, so you cannot use these interfaces if you wish to set up your own
// event loop. (However, you can safely create multiple EzRpcClient / EzRpcServer objects
// in a single thread; they will make sure to make no more than one EventLoop.)
// - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop` and make it current for the
// thread. Only one `kj::EventLoop` can exist per thread, so you cannot use these interfaces
// if you wish to set up your own event loop. (However, you can safely create multiple
// EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more
// than one EventLoop.)
// - These classes only support simple two-party connections, not multilateral VatNetworks.
public:
......@@ -105,6 +106,10 @@ public:
// Ask the sever for the capability with the given name. You may specify a type to automatically
// down-cast to that type. It is up to you to specify the correct expected type.
kj::WaitScope& getWaitScope();
// Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on
// promises.
kj::AsyncIoProvider& getIoProvider();
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
......@@ -159,6 +164,10 @@ public:
// the server is actually listening. If the address was not an IP address (e.g. it was a Unix
// domain socket) then getPort() resolves to zero.
kj::WaitScope& getWaitScope();
// Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on
// promises.
kj::AsyncIoProvider& getIoProvider();
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
......
This diff is collapsed.
......@@ -61,11 +61,12 @@ private:
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
return ioProvider.newPipeThread(
[&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) {
[&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream,
kj::WaitScope& waitScope) {
TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
TestRestorer restorer(callCount);
auto server = makeRpcServer(network, restorer);
network.onDisconnect().wait();
network.onDisconnect().wait(waitScope);
});
}
......@@ -118,13 +119,13 @@ TEST(TwoPartyNetwork, Basic) {
EXPECT_EQ(0, callCount);
auto response1 = promise1.wait();
auto response1 = promise1.wait(ioContext.waitScope);
EXPECT_EQ("foo", response1.getX());
auto response2 = promise2.wait();
auto response2 = promise2.wait(ioContext.waitScope);
promise3.wait();
promise3.wait(ioContext.waitScope);
EXPECT_EQ(2, callCount);
EXPECT_TRUE(barFailed);
......@@ -170,10 +171,10 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
auto response = pipelinePromise.wait();
auto response = pipelinePromise.wait(ioContext.waitScope);
EXPECT_EQ("bar", response.getX());
auto response2 = pipelinePromise2.wait();
auto response2 = pipelinePromise2.wait(ioContext.waitScope);
checkTestMessage(response2);
EXPECT_EQ(3, callCount);
......@@ -187,7 +188,7 @@ TEST(TwoPartyNetwork, Pipelining) {
serverThread.pipe->shutdownWrite();
// The other side should also disconnect.
disconnectPromise.wait();
disconnectPromise.wait(ioContext.waitScope);
EXPECT_FALSE(drained);
{
......@@ -206,8 +207,8 @@ TEST(TwoPartyNetwork, Pipelining) {
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(pipelinePromise.wait());
EXPECT_ANY_THROW(pipelinePromise2.wait());
EXPECT_ANY_THROW(pipelinePromise.wait(ioContext.waitScope));
EXPECT_ANY_THROW(pipelinePromise2.wait(ioContext.waitScope));
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
......@@ -216,7 +217,7 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_FALSE(drained);
}
drainedPromise.wait();
drainedPromise.wait(ioContext.waitScope);
}
} // namespace
......
......@@ -132,7 +132,7 @@ TEST_F(SerializeAsyncTest, ParseAsync) {
writeMessage(output, message);
});
auto received = readMessage(*input).wait();
auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>());
}
......@@ -150,7 +150,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
writeMessage(output, message);
});
auto received = readMessage(*input).wait();
auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>());
}
......@@ -168,7 +168,7 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
writeMessage(output, message);
});
auto received = readMessage(*input).wait();
auto received = readMessage(*input).wait(ioContext.waitScope);
checkTestMessage(received->getRoot<TestAllTypes>());
}
......@@ -193,7 +193,7 @@ TEST_F(SerializeAsyncTest, WriteAsync) {
}
});
writeMessage(*output, message).wait();
writeMessage(*output, message).wait(ioContext.waitScope);
}
TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
......@@ -216,7 +216,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
}
});
writeMessage(*output, message).wait();
writeMessage(*output, message).wait(ioContext.waitScope);
}
TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
......@@ -239,7 +239,7 @@ TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
}
});
writeMessage(*output, message).wait();
writeMessage(*output, message).wait(ioContext.waitScope);
}
} // namespace
......
......@@ -588,10 +588,10 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler
}
template <typename T>
T Promise<T>::wait() {
T Promise<T>::wait(WaitScope& waitScope) {
_::ExceptionOr<_::FixVoid<T>> result;
waitImpl(kj::mv(node), result);
waitImpl(kj::mv(node), result, waitScope);
KJ_IF_MAYBE(value, result.value) {
KJ_IF_MAYBE(exception, result.exception) {
......
......@@ -62,33 +62,34 @@ TEST(AsyncIo, SimpleNetwork) {
}).then([&](size_t n) {
EXPECT_EQ(3u, n);
return heapString(receiveBuffer, n);
}).wait();
}).wait(ioContext.waitScope);
EXPECT_EQ("foo", result);
}
String tryParse(Network& network, StringPtr text, uint portHint = 0) {
return network.parseAddress(text, portHint).wait()->toString();
String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint portHint = 0) {
return network.parseAddress(text, portHint).wait(waitScope)->toString();
}
TEST(AsyncIo, AddressParsing) {
auto ioContext = setupAsyncIo();
auto& w = ioContext.waitScope;
auto& network = ioContext.provider->getNetwork();
EXPECT_EQ("*:0", tryParse(network, "*"));
EXPECT_EQ("*:123", tryParse(network, "*:123"));
EXPECT_EQ("[::]:123", tryParse(network, "0::0", 123));
EXPECT_EQ("0.0.0.0:0", tryParse(network, "0.0.0.0"));
EXPECT_EQ("1.2.3.4:5678", tryParse(network, "1.2.3.4", 5678));
EXPECT_EQ("[12ab:cd::34]:321", tryParse(network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("*:0", tryParse(w, network, "*"));
EXPECT_EQ("*:123", tryParse(w, network, "*:123"));
EXPECT_EQ("[::]:123", tryParse(w, network, "0::0", 123));
EXPECT_EQ("0.0.0.0:0", tryParse(w, network, "0.0.0.0"));
EXPECT_EQ("1.2.3.4:5678", tryParse(w, network, "1.2.3.4", 5678));
EXPECT_EQ("[12ab:cd::34]:321", tryParse(w, network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("unix:foo/bar/baz", tryParse(network, "unix:foo/bar/baz"));
EXPECT_EQ("unix:foo/bar/baz", tryParse(w, network, "unix:foo/bar/baz"));
// We can parse services by name...
EXPECT_EQ("1.2.3.4:80", tryParse(network, "1.2.3.4:http", 5678));
EXPECT_EQ("[::]:80", tryParse(network, "[::]:http", 5678));
EXPECT_EQ("[12ab:cd::34]:80", tryParse(network, "[12ab:cd::34]:http", 5678));
EXPECT_EQ("*:80", tryParse(network, "*:http", 5678));
EXPECT_EQ("1.2.3.4:80", tryParse(w, network, "1.2.3.4:http", 5678));
EXPECT_EQ("[::]:80", tryParse(w, network, "[::]:http", 5678));
EXPECT_EQ("[12ab:cd::34]:80", tryParse(w, network, "[12ab:cd::34]:http", 5678));
EXPECT_EQ("*:80", tryParse(w, network, "*:http", 5678));
// It would be nice to test DNS lookup here but the test would not be very hermetic. Even
// localhost can map to different addresses depending on whether IPv6 is enabled. We do
......@@ -108,7 +109,7 @@ TEST(AsyncIo, OneWayPipe) {
kj::String result = pipe.in->tryRead(receiveBuffer, 3, 4).then([&](size_t n) {
EXPECT_EQ(3u, n);
return heapString(receiveBuffer, n);
}).wait();
}).wait(ioContext.waitScope);
EXPECT_EQ("foo", result);
}
......@@ -132,9 +133,9 @@ TEST(AsyncIo, TwoWayPipe) {
}).then([&](size_t n) {
EXPECT_EQ(3u, n);
return heapString(receiveBuffer2, n);
}).wait();
}).wait(ioContext.waitScope);
kj::String result2 = promise.wait();
kj::String result2 = promise.wait(ioContext.waitScope);
EXPECT_EQ("foo", result);
EXPECT_EQ("bar", result2);
......@@ -144,19 +145,19 @@ TEST(AsyncIo, PipeThread) {
auto ioContext = setupAsyncIo();
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream, WaitScope& waitScope) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
stream.write("foo", 3).wait(waitScope);
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait(waitScope));
EXPECT_EQ("bar", heapString(buf, 3));
// Expect disconnect.
EXPECT_EQ(0, stream.tryRead(buf, 1, 1).wait());
EXPECT_EQ(0, stream.tryRead(buf, 1, 1).wait(waitScope));
});
char buf[4];
pipeThread.pipe->write("bar", 3).wait();
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
pipeThread.pipe->write("bar", 3).wait(ioContext.waitScope);
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait(ioContext.waitScope));
EXPECT_EQ("foo", heapString(buf, 3));
}
......@@ -166,21 +167,21 @@ TEST(AsyncIo, PipeThreadDisconnects) {
auto ioContext = setupAsyncIo();
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream, WaitScope& waitScope) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
stream.write("foo", 3).wait(waitScope);
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait(waitScope));
EXPECT_EQ("bar", heapString(buf, 3));
});
char buf[4];
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait(ioContext.waitScope));
EXPECT_EQ("foo", heapString(buf, 3));
pipeThread.pipe->write("bar", 3).wait();
pipeThread.pipe->write("bar", 3).wait(ioContext.waitScope);
// Expect disconnect.
EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait());
EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait(ioContext.waitScope));
}
} // namespace
......
......@@ -716,7 +716,9 @@ public:
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
LowLevelAsyncIoProviderImpl(): eventLoop(eventPort) {}
LowLevelAsyncIoProviderImpl(): eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; }
Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
......@@ -747,6 +749,7 @@ public:
private:
UnixEventPort eventPort;
EventLoop eventLoop;
WaitScope waitScope;
};
// =======================================================================================
......@@ -889,7 +892,7 @@ public:
}
PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) override {
Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
int fds[2];
int type = SOCK_STREAM;
#if __linux__
......@@ -903,11 +906,11 @@ public:
auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
auto thread = heap<Thread>(kj::mvCapture(startFunc,
[threadFd](Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)>&& startFunc) {
[threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
LowLevelAsyncIoProviderImpl lowLevel;
auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
AsyncIoProviderImpl ioProvider(lowLevel);
startFunc(ioProvider, *stream);
startFunc(ioProvider, *stream, lowLevel.getWaitScope());
}));
return { kj::mv(thread), kj::mv(pipe) };
......@@ -931,7 +934,8 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
AsyncIoContext setupAsyncIo() {
auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
return { kj::mv(lowLevel), kj::mv(ioProvider) };
auto& waitScope = lowLevel->getWaitScope();
return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope };
}
} // namespace kj
......@@ -188,7 +188,7 @@ public:
};
virtual PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) = 0;
Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) = 0;
// Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate
// with it. One end of the pipe is passed to the thread's start function and the other end of
// the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will
......@@ -278,6 +278,7 @@ Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel);
struct AsyncIoContext {
Own<LowLevelAsyncIoProvider> lowLevelProvider;
Own<AsyncIoProvider> provider;
WaitScope& waitScope;
};
AsyncIoContext setupAsyncIo();
......
......@@ -34,6 +34,7 @@ namespace kj {
class EventLoop;
template <typename T>
class Promise;
class WaitScope;
namespace _ { // private
......@@ -173,7 +174,7 @@ private:
};
void daemonize(kj::Promise<void>&& promise);
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result);
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope);
Promise<void> yield();
class NeverDone {
......
......@@ -30,63 +30,69 @@ namespace {
TEST(Async, EvalVoid) {
EventLoop loop;
WaitScope waitScope(loop);
bool done = false;
Promise<void> promise = evalLater([&]() { done = true; });
EXPECT_FALSE(done);
promise.wait();
promise.wait(waitScope);
EXPECT_TRUE(done);
}
TEST(Async, EvalInt) {
EventLoop loop;
WaitScope waitScope(loop);
bool done = false;
Promise<int> promise = evalLater([&]() { done = true; return 123; });
EXPECT_FALSE(done);
EXPECT_EQ(123, promise.wait());
EXPECT_EQ(123, promise.wait(waitScope));
EXPECT_TRUE(done);
}
TEST(Async, There) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> a = 123;
bool done = false;
Promise<int> promise = a.then([&](int ai) { done = true; return ai + 321; });
EXPECT_FALSE(done);
EXPECT_EQ(444, promise.wait());
EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(done);
}
TEST(Async, ThereVoid) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> a = 123;
int value = 0;
Promise<void> promise = a.then([&](int ai) { value = ai; });
EXPECT_EQ(0, value);
promise.wait();
promise.wait(waitScope);
EXPECT_EQ(123, value);
}
TEST(Async, Exception) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
EXPECT_TRUE(kj::runCatchingExceptions([&]() {
// wait() only returns when compiling with -fno-exceptions.
EXPECT_EQ(123, promise.wait());
EXPECT_EQ(123, promise.wait(waitScope));
}) != nullptr);
}
TEST(Async, HandleException) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
......@@ -96,11 +102,12 @@ TEST(Async, HandleException) {
[](int i) { return i + 1; },
[&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; });
EXPECT_EQ(345, promise.wait());
EXPECT_EQ(345, promise.wait(waitScope));
}
TEST(Async, PropagateException) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
......@@ -112,11 +119,12 @@ TEST(Async, PropagateException) {
[](int i) { return i + 2; },
[&](Exception&& e) { EXPECT_EQ(line, e.getLine()); return 345; });
EXPECT_EQ(345, promise.wait());
EXPECT_EQ(345, promise.wait(waitScope));
}
TEST(Async, PropagateExceptionTypeChange) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater(
[&]() -> int { KJ_FAIL_ASSERT("foo") { return 123; } });
......@@ -128,11 +136,12 @@ TEST(Async, PropagateExceptionTypeChange) {
[](StringPtr s) -> StringPtr { return "bar"; },
[&](Exception&& e) -> StringPtr { EXPECT_EQ(line, e.getLine()); return "baz"; });
EXPECT_EQ("baz", promise2.wait());
EXPECT_EQ("baz", promise2.wait(waitScope));
}
TEST(Async, Then) {
EventLoop loop;
WaitScope waitScope(loop);
bool done = false;
......@@ -143,13 +152,14 @@ TEST(Async, Then) {
EXPECT_FALSE(done);
EXPECT_EQ(444, promise.wait());
EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(done);
}
TEST(Async, Chain) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() -> int { return 123; });
Promise<int> promise2 = evalLater([&]() -> int { return 321; });
......@@ -160,11 +170,12 @@ TEST(Async, Chain) {
});
});
EXPECT_EQ(444, promise3.wait());
EXPECT_EQ(444, promise3.wait(waitScope));
}
TEST(Async, SeparateFulfiller) {
EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<int>();
......@@ -172,11 +183,12 @@ TEST(Async, SeparateFulfiller) {
pair.fulfiller->fulfill(123);
EXPECT_FALSE(pair.fulfiller->isWaiting());
EXPECT_EQ(123, pair.promise.wait());
EXPECT_EQ(123, pair.promise.wait(waitScope));
}
TEST(Async, SeparateFulfillerVoid) {
EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<void>();
......@@ -184,7 +196,7 @@ TEST(Async, SeparateFulfillerVoid) {
pair.fulfiller->fulfill();
EXPECT_FALSE(pair.fulfiller->isWaiting());
pair.promise.wait();
pair.promise.wait(waitScope);
}
TEST(Async, SeparateFulfillerCanceled) {
......@@ -197,6 +209,7 @@ TEST(Async, SeparateFulfillerCanceled) {
TEST(Async, SeparateFulfillerChained) {
EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<Promise<int>>();
auto inner = newPromiseAndFulfiller<int>();
......@@ -207,7 +220,7 @@ TEST(Async, SeparateFulfillerChained) {
inner.fulfiller->fulfill(123);
EXPECT_EQ(123, pair.promise.wait());
EXPECT_EQ(123, pair.promise.wait(waitScope));
}
#if KJ_NO_EXCEPTIONS
......@@ -217,11 +230,12 @@ TEST(Async, SeparateFulfillerChained) {
TEST(Async, SeparateFulfillerDiscarded) {
EventLoop loop;
WaitScope waitScope(loop);
auto pair = newPromiseAndFulfiller<int>();
pair.fulfiller = nullptr;
EXPECT_ANY_THROW(pair.promise.wait());
EXPECT_ANY_THROW(pair.promise.wait(waitScope));
}
TEST(Async, SeparateFulfillerMemoryLeak) {
......@@ -231,6 +245,7 @@ TEST(Async, SeparateFulfillerMemoryLeak) {
TEST(Async, Ordering) {
EventLoop loop;
WaitScope waitScope(loop);
int counter = 0;
Promise<void> promises[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
......@@ -285,7 +300,7 @@ TEST(Async, Ordering) {
promises[0].eagerlyEvaluate();
for (auto i: indices(promises)) {
kj::mv(promises[i]).wait();
kj::mv(promises[i]).wait(waitScope);
}
EXPECT_EQ(7, counter);
......@@ -293,6 +308,7 @@ TEST(Async, Ordering) {
TEST(Async, Fork) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() { return 123; });
......@@ -311,8 +327,8 @@ TEST(Async, Fork) {
auto releaseFork = kj::mv(fork);
}
EXPECT_EQ(456, branch1.wait());
EXPECT_EQ(789, branch2.wait());
EXPECT_EQ(456, branch1.wait(waitScope));
EXPECT_EQ(789, branch2.wait(waitScope));
}
struct RefcountedInt: public Refcounted {
......@@ -323,6 +339,7 @@ struct RefcountedInt: public Refcounted {
TEST(Async, ForkRef) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<Own<RefcountedInt>> promise = evalLater([&]() {
return refcounted<RefcountedInt>(123);
......@@ -343,46 +360,50 @@ TEST(Async, ForkRef) {
auto releaseFork = kj::mv(fork);
}
EXPECT_EQ(456, branch1.wait());
EXPECT_EQ(789, branch2.wait());
EXPECT_EQ(456, branch1.wait(waitScope));
EXPECT_EQ(789, branch2.wait(waitScope));
}
TEST(Async, ExclusiveJoin) {
{
EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; });
auto right = newPromiseAndFulfiller<int>(); // never fulfilled
left.exclusiveJoin(kj::mv(right.promise));
EXPECT_EQ(123, left.wait());
EXPECT_EQ(123, left.wait(waitScope));
}
{
EventLoop loop;
WaitScope waitScope(loop);
auto left = newPromiseAndFulfiller<int>(); // never fulfilled
auto right = evalLater([&]() { return 123; });
left.promise.exclusiveJoin(kj::mv(right));
EXPECT_EQ(123, left.promise.wait());
EXPECT_EQ(123, left.promise.wait(waitScope));
}
{
EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; });
left.exclusiveJoin(kj::mv(right));
EXPECT_EQ(123, left.wait());
EXPECT_EQ(123, left.wait(waitScope));
}
{
EventLoop loop;
WaitScope waitScope(loop);
auto left = evalLater([&]() { return 123; });
auto right = evalLater([&]() { return 456; });
......@@ -391,7 +412,7 @@ TEST(Async, ExclusiveJoin) {
left.exclusiveJoin(kj::mv(right));
EXPECT_EQ(456, left.wait());
EXPECT_EQ(456, left.wait(waitScope));
}
}
......@@ -406,6 +427,7 @@ public:
TEST(Async, TaskSet) {
EventLoop loop;
WaitScope waitScope(loop);
ErrorHandlerImpl errorHandler;
TaskSet tasks(errorHandler);
......@@ -428,7 +450,7 @@ TEST(Async, TaskSet) {
evalLater([&]() {
EXPECT_EQ(3, counter++);
}).wait();
}).wait(waitScope);
EXPECT_EQ(4, counter);
EXPECT_EQ(1u, errorHandler.exceptionCount);
......@@ -447,6 +469,7 @@ TEST(Async, Attach) {
bool destroyed = false;
EventLoop loop;
WaitScope waitScope(loop);
Promise<int> promise = evalLater([&]() {
EXPECT_FALSE(destroyed);
......@@ -461,7 +484,7 @@ TEST(Async, Attach) {
});
EXPECT_FALSE(destroyed);
EXPECT_EQ(444, promise.wait());
EXPECT_EQ(444, promise.wait(waitScope));
EXPECT_TRUE(destroyed);
}
......@@ -469,23 +492,25 @@ TEST(Async, EagerlyEvaluate) {
bool called = false;
EventLoop loop;
WaitScope waitScope(loop);
Promise<void> promise = Promise<void>(READY_NOW).then([&]() {
called = true;
});
evalLater([]() {}).wait();
evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(called);
promise.eagerlyEvaluate();
evalLater([]() {}).wait();
evalLater([]() {}).wait(waitScope);
EXPECT_TRUE(called);
}
TEST(Async, Daemonize) {
EventLoop loop;
WaitScope waitScope(loop);
bool ran1 = false;
bool ran2 = false;
......@@ -499,7 +524,7 @@ TEST(Async, Daemonize) {
EXPECT_FALSE(ran2);
EXPECT_FALSE(ran3);
evalLater([]() {}).wait();
evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(ran1);
EXPECT_TRUE(ran2);
......@@ -522,6 +547,7 @@ public:
TEST(Async, SetRunnable) {
DummyEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
EXPECT_FALSE(port.runnable);
EXPECT_EQ(0, port.callCount);
......@@ -535,7 +561,7 @@ TEST(Async, SetRunnable) {
EXPECT_FALSE(port.runnable);
EXPECT_EQ(2, port.callCount);
promise.wait();
promise.wait(waitScope);
EXPECT_FALSE(port.runnable);
EXPECT_EQ(4, port.callCount);
}
......@@ -556,7 +582,7 @@ TEST(Async, SetRunnable) {
loop.run(10);
EXPECT_FALSE(port.runnable);
promise.wait();
promise.wait(waitScope);
EXPECT_FALSE(port.runnable);
EXPECT_EQ(8, port.callCount);
......
......@@ -53,10 +53,11 @@ public:
TEST_F(AsyncUnixTest, Signals) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
kill(getpid(), SIGUSR2);
siginfo_t info = port.onSignal(SIGUSR2).wait();
siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
}
......@@ -70,13 +71,14 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
union sigval value;
memset(&value, 0, sizeof(value));
value.sival_int = 123;
sigqueue(getpid(), SIGUSR2, value);
siginfo_t info = port.onSignal(SIGUSR2).wait();
siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_QUEUE, info.si_code);
EXPECT_EQ(123, info.si_value.sival_int);
......@@ -86,6 +88,7 @@ TEST_F(AsyncUnixTest, SignalWithValue) {
TEST_F(AsyncUnixTest, SignalsMultiListen) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
port.onSignal(SIGIO).then([](siginfo_t&&) {
ADD_FAILURE() << "Received wrong signal.";
......@@ -95,7 +98,7 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) {
kill(getpid(), SIGUSR2);
siginfo_t info = port.onSignal(SIGUSR2).wait();
siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
}
......@@ -103,15 +106,16 @@ TEST_F(AsyncUnixTest, SignalsMultiListen) {
TEST_F(AsyncUnixTest, SignalsMultiReceive) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
kill(getpid(), SIGUSR2);
kill(getpid(), SIGIO);
siginfo_t info = port.onSignal(SIGUSR2).wait();
siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
info = port.onSignal(SIGIO).wait();
info = port.onSignal(SIGIO).wait(waitScope);
EXPECT_EQ(SIGIO, info.si_signo);
EXPECT_SI_CODE(SI_USER, info.si_code);
}
......@@ -119,6 +123,7 @@ TEST_F(AsyncUnixTest, SignalsMultiReceive) {
TEST_F(AsyncUnixTest, SignalsAsync) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
// Arrange for a signal to be sent from another thread.
pthread_t mainThread = pthread_self();
......@@ -127,7 +132,7 @@ TEST_F(AsyncUnixTest, SignalsAsync) {
pthread_kill(mainThread, SIGUSR2);
});
siginfo_t info = port.onSignal(SIGUSR2).wait();
siginfo_t info = port.onSignal(SIGUSR2).wait(waitScope);
EXPECT_EQ(SIGUSR2, info.si_signo);
#if __linux__
EXPECT_SI_CODE(SI_TKILL, info.si_code);
......@@ -139,6 +144,7 @@ TEST_F(AsyncUnixTest, SignalsNoWait) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
bool receivedSigusr2 = false;
bool receivedSigio = false;
......@@ -178,18 +184,20 @@ TEST_F(AsyncUnixTest, SignalsNoWait) {
TEST_F(AsyncUnixTest, Poll) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2];
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(pipe(pipefds));
KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
}
TEST_F(AsyncUnixTest, PollMultiListen) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
int bogusPipefds[2];
KJ_SYSCALL(pipe(bogusPipefds));
......@@ -206,12 +214,13 @@ TEST_F(AsyncUnixTest, PollMultiListen) {
KJ_DEFER({ close(pipefds[1]); close(pipefds[0]); });
KJ_SYSCALL(write(pipefds[1], "foo", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
}
TEST_F(AsyncUnixTest, PollMultiReceive) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2];
KJ_SYSCALL(pipe(pipefds));
......@@ -223,13 +232,14 @@ TEST_F(AsyncUnixTest, PollMultiReceive) {
KJ_DEFER({ close(pipefds2[1]); close(pipefds2[0]); });
KJ_SYSCALL(write(pipefds2[1], "bar", 3));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds2[0], POLLIN | POLLPRI).wait(waitScope));
}
TEST_F(AsyncUnixTest, PollAsync) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
// Make a pipe and wait on its read end while another thread writes to it.
int pipefds[2];
......@@ -241,7 +251,7 @@ TEST_F(AsyncUnixTest, PollAsync) {
});
// Wait for the event in this thread.
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait());
EXPECT_EQ(POLLIN, port.onFdEvent(pipefds[0], POLLIN | POLLPRI).wait(waitScope));
}
TEST_F(AsyncUnixTest, PollNoWait) {
......@@ -249,6 +259,7 @@ TEST_F(AsyncUnixTest, PollNoWait) {
UnixEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
int pipefds[2];
KJ_SYSCALL(pipe(pipefds));
......
......@@ -199,26 +199,13 @@ void EventPort::setRunnable(bool runnable) {}
EventLoop::EventLoop()
: port(_::NullEventPort::instance),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
EventLoop::EventLoop(EventPort& port)
: port(port),
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
daemons(kj::heap<_::TaskSetImpl>(_::LoggingErrorHandler::instance)) {}
EventLoop::~EventLoop() noexcept(false) {
KJ_REQUIRE(threadLocalEventLoop == this,
"EventLoop being destroyed in a different thread than it was created.") {
break;
}
KJ_DEFER(threadLocalEventLoop = nullptr);
// Destroy all "daemon" tasks, noting that their destructors might try to access the EventLoop
// some more.
daemons = nullptr;
......@@ -237,6 +224,12 @@ EventLoop::~EventLoop() noexcept(false) {
}
break;
}
KJ_REQUIRE(threadLocalEventLoop != this,
"EventLoop destroyed while still current for the thread.") {
threadLocalEventLoop = nullptr;
break;
}
}
void EventLoop::run(uint maxTurnCount) {
......@@ -291,10 +284,24 @@ void EventLoop::setRunnable(bool runnable) {
}
}
void EventLoop::enterScope() {
KJ_REQUIRE(threadLocalEventLoop == nullptr, "This thread already has an EventLoop.");
threadLocalEventLoop = this;
}
void EventLoop::leaveScope() {
KJ_REQUIRE(threadLocalEventLoop == this,
"WaitScope destroyed in a different thread than it was created in.") {
break;
}
threadLocalEventLoop = nullptr;
}
namespace _ { // private
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result) {
EventLoop& loop = currentEventLoop();
void waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result, WaitScope& waitScope) {
EventLoop& loop = waitScope.loop;
KJ_REQUIRE(&loop == threadLocalEventLoop, "WaitScope not valid for this thread.");
KJ_REQUIRE(!loop.running, "wait() is not allowed from within event callbacks.");
BoolEvent doneEvent;
......
......@@ -32,6 +32,7 @@
namespace kj {
class EventLoop;
class WaitScope;
template <typename T>
class Promise;
......@@ -186,14 +187,10 @@ public:
// actual I/O. To solve this, use `kj::evalLater()` to yield control; this way, all other events
// in the queue will get a chance to run before your callback is executed.
T wait();
T wait(WaitScope& waitScope);
// Run the event loop until the promise is fulfilled, then return its result. If the promise
// is rejected, throw an exception.
//
// wait() cannot be called recursively -- that is, an event callback cannot call wait().
// Instead, callbacks that need to perform more async operations should return a promise and
// rely on promise chaining.
//
// wait() is primarily useful at the top level of a program -- typically, within the function
// that allocated the EventLoop. For example, a program that performs one or two RPCs and then
// exits would likely use wait() in its main() function to wait on each RPC. On the other hand,
......@@ -205,13 +202,27 @@ public:
// use `then()` to set an appropriate handler for the exception case, so that the promise you
// actually wait on never throws.
//
// `waitScope` is an object proving that the caller is in a scope where wait() is allowed. By
// convention, any function which might call wait(), or which might call another function which
// might call wait(), must take `WaitScope&` as one of its parameters. This is needed for two
// reasons:
// * `wait()` is not allowed during an event callback, because event callbacks are themselves
// called during some other `wait()`, and such recursive `wait()`s would only be able to
// complete in LIFO order, which might mean that the outer `wait()` ends up waiting longer
// than it is supposed to. To prevent this, a `WaitScope` cannot be constructed or used during
// an event callback.
// * Since `wait()` runs the event loop, unrelated event callbacks may execute before `wait()`
// returns. This means that anyone calling `wait()` must be reentrant -- state may change
// around them in arbitrary ways. Therefore, callers really need to know if a function they
// are calling might wait(), and the `WaitScope&` parameter makes this clear.
//
// Note that `wait()` consumes the promise on which it is called, in the sense of move semantics.
// After returning, the original promise is no longer valid.
//
// TODO(someday): Implement fibers, and let them call wait() even when they are handling an
// event.
ForkedPromise<T> fork();
ForkedPromise<T> fork() KJ_WARN_UNUSED_RESULT;
// Forks the promise, so that multiple different clients can independently wait on the result.
// `T` must be copy-constructable for this to work. Or, in the special case where `T` is
// `Own<U>`, `U` must have a method `Own<U> addRef()` which returns a new reference to the same
......@@ -578,10 +589,35 @@ private:
bool turn();
void setRunnable(bool runnable);
void enterScope();
void leaveScope();
friend void _::daemonize(kj::Promise<void>&& promise);
friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result);
friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result,
WaitScope& waitScope);
friend class _::Event;
friend class WaitScope;
};
class WaitScope {
// Represents a scope in which asynchronous programming can occur. A `WaitScope` should usually
// be allocated on the stack and serves two purposes:
// * While the `WaitScope` exists, its `EventLoop` is registered as the current loop for the
// thread. Most operations dealing with `Promise` (including all of its methods) do not work
// unless the thread has a current `EventLoop`.
// * `WaitScope` may be passed to `Promise::wait()` to synchronously wait for a particular
// promise to complete. See `Promise::wait()` for an extended discussion.
public:
inline explicit WaitScope(EventLoop& loop): loop(loop) { loop.enterScope(); }
inline ~WaitScope() { loop.leaveScope(); }
KJ_DISALLOW_COPY(WaitScope);
private:
EventLoop& loop;
friend class EventLoop;
friend void _::waitImpl(Own<_::PromiseNode>&& node, _::ExceptionOrValue& result,
WaitScope& waitScope);
};
} // namespace kj
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment