Commit 06fd35c4 authored by Kenton Varda

Split AsyncIoProvider into high-level and low-level interfaces, so that alternate EventPort implementations can implement the low-level interface and reuse the higher-level stuff.
parent ed03e100
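
As a quick orientation, here is a minimal sketch of how the two halves fit together after this change (the names are the ones introduced in the diff below; writing to fd 1 is purely illustrative):

    #include <kj/async-io.h>

    int main() {
      // One call sets up this thread's EventLoop plus both provider layers.
      auto ioContext = kj::setupAsyncIo();

      // High-level, portable interface: networking, pipes, pipe threads.
      kj::Network& network = ioContext.provider->getNetwork();
      (void)network;  // e.g. parse addresses and connect, as in the async-io tests below

      // Low-level, OS-specific interface: import existing file descriptors.
      auto out = ioContext.lowLevelProvider->wrapOutputFd(1);  // stdout; ownership not taken
      out->write("hi\n", 3).wait();
      return 0;
    }
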
@@ -34,7 +34,7 @@ static __thread EzRpcContext* threadEzContext = nullptr;
class EzRpcContext: public kj::Refcounted {
public:
EzRpcContext(): ioProvider(kj::setupIoEventLoop()) {
EzRpcContext(): ioContext(kj::setupAsyncIo()) {
threadEzContext = this;
}
@@ -47,7 +47,11 @@ public:
}
kj::AsyncIoProvider& getIoProvider() {
return *ioProvider;
return *ioContext.provider;
}
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() {
return *ioContext.lowLevelProvider;
}
static kj::Own<EzRpcContext> getThreadLocal() {
@@ -60,7 +64,7 @@ public:
}
private:
kj::Own<kj::AsyncIoProvider> ioProvider;
kj::AsyncIoContext ioContext;
};
// =======================================================================================
@@ -116,7 +120,8 @@ struct EzRpcClient::Impl {
Impl(int socketFd)
: context(EzRpcContext::getThreadLocal()),
setupPromise(kj::Promise<void>(kj::READY_NOW).fork()),
clientContext(kj::heap<ClientContext>(context->getIoProvider().wrapSocketFd(socketFd))) {}
clientContext(kj::heap<ClientContext>(
context->getLowLevelIoProvider().wrapSocketFd(socketFd))) {}
};
EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort)
@@ -145,6 +150,10 @@ kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
return impl->context->getIoProvider();
}
kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
return impl->context->getLowLevelIoProvider();
}
// =======================================================================================
struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskSet::ErrorHandler {
@@ -209,7 +218,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS
: context(EzRpcContext::getThreadLocal()),
portPromise(kj::Promise<uint>(port).fork()),
tasks(*this) {
acceptLoop(context->getIoProvider().wrapListenSocketFd(socketFd));
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd));
}
void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener) {
@@ -268,4 +277,8 @@ kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
return impl->context->getIoProvider();
}
kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() {
return impl->context->getLowLevelIoProvider();
}
} // namespace capnp
@@ -26,7 +26,7 @@
#include "rpc.h"
namespace kj { class AsyncIoProvider; }
namespace kj { class AsyncIoProvider; class LowLevelAsyncIoProvider; }
namespace capnp {
@@ -109,6 +109,10 @@ public:
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider();
// Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you
// want to do some non-RPC I/O in asynchronous fashion.
private:
struct Impl;
kj::Own<Impl> impl;
@@ -159,6 +163,10 @@ public:
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider();
// Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you
// want to do some non-RPC I/O in asynchronous fashion.
private:
struct Impl;
kj::Own<Impl> impl;
......
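
The accessors above mirror each other: getIoProvider() gives the portable interface, getLowLevelIoProvider() the descriptor-wrapping one, both backed by the same event loop. A hedged client-side sketch (the address, the stdin descriptor, and the tryRead call are only illustrative, assuming the stream API used in the tests below):

    #include <capnp/ez-rpc.h>

    int main() {
      capnp::EzRpcClient client("localhost:12345", 0);

      // RPC requests go through `client` as usual; meanwhile the same event
      // loop can drive unrelated, non-RPC I/O via the low-level provider:
      auto input = client.getLowLevelIoProvider().wrapInputFd(0);  // wrap stdin

      char buf[64];
      input->tryRead(buf, 1, sizeof(buf)).wait();  // read whatever shows up on stdin
      return 0;
    }
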
@@ -59,7 +59,7 @@ private:
int& callCount;
};
kj::Own<kj::AsyncIoStream> runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
return ioProvider.newPipeThread(
[&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) {
TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
@@ -86,11 +86,11 @@ Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& c
}
TEST(TwoPartyNetwork, Basic) {
auto ioProvider = kj::setupIoEventLoop();
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
auto stream = runServer(*ioProvider, callCount);
TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
auto serverThread = runServer(*ioContext.provider, callCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network);
// Request the particular capability from the server.
@@ -131,12 +131,12 @@ TEST(TwoPartyNetwork, Basic) {
}
TEST(TwoPartyNetwork, Pipelining) {
auto ioProvider = kj::setupIoEventLoop();
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
int reverseCallCount = 0; // Calls back from server to client.
auto stream = runServer(*ioProvider, callCount);
TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
auto serverThread = runServer(*ioContext.provider, callCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network);
bool disconnected = false;
@@ -184,7 +184,7 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_FALSE(drained);
// What if we disconnect?
stream->shutdownWrite();
serverThread.pipe->shutdownWrite();
// The other side should also disconnect.
disconnectPromise.wait();
......
@@ -120,8 +120,8 @@ protected:
};
TEST_F(SerializeAsyncTest, ParseAsync) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -138,8 +138,8 @@ TEST_F(SerializeAsyncTest, ParseAsync) {
}
TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -156,8 +156,8 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
}
TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -174,8 +174,8 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
}
TEST_F(SerializeAsyncTest, WriteAsync) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(1);
auto root = message.getRoot<TestAllTypes>();
@@ -197,8 +197,8 @@ TEST_F(SerializeAsyncTest, WriteAsync) {
}
TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(7);
auto root = message.getRoot<TestAllTypes>();
@@ -220,8 +220,8 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
}
TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(10);
auto root = message.getRoot<TestAllTypes>();
......
@@ -30,8 +30,8 @@ namespace kj {
namespace {
TEST(AsyncIo, SimpleNetwork) {
auto ioProvider = setupIoEventLoop();
auto& network = ioProvider->getNetwork();
auto ioContext = setupAsyncIo();
auto& network = ioContext.provider->getNetwork();
Own<ConnectionReceiver> listener;
Own<AsyncIoStream> server;
@@ -76,8 +76,8 @@ String tryParseRemote(Network& network, StringPtr text, uint portHint = 0) {
}
TEST(AsyncIo, AddressParsing) {
auto ioProvider = setupIoEventLoop();
auto& network = ioProvider->getNetwork();
auto ioContext = setupAsyncIo();
auto& network = ioContext.provider->getNetwork();
EXPECT_EQ("*:0", tryParseLocal(network, "*"));
EXPECT_EQ("*:123", tryParseLocal(network, "123"));
@@ -92,9 +92,9 @@ TEST(AsyncIo, AddressParsing) {
}
TEST(AsyncIo, OneWayPipe) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto pipe = ioProvider->newOneWayPipe();
auto pipe = ioContext.provider->newOneWayPipe();
char receiveBuffer[4];
pipe.out->write("foo", 3).daemonize([](kj::Exception&& exception) {
@@ -110,9 +110,9 @@ TEST(AsyncIo, OneWayPipe) {
}
TEST(AsyncIo, TwoWayPipe) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto pipe = ioProvider->newTwoWayPipe();
auto pipe = ioContext.provider->newTwoWayPipe();
char receiveBuffer1[4];
char receiveBuffer2[4];
@@ -137,9 +137,10 @@ TEST(AsyncIo, TwoWayPipe) {
}
TEST(AsyncIo, PipeThread) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
@@ -150,17 +151,18 @@ });
});
char buf[4];
stream->write("bar", 3).wait();
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
pipeThread.pipe->write("bar", 3).wait();
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
}
TEST(AsyncIo, PipeThreadDisconnects) {
// Like above, but in this case we expect the main thread to detect the pipe thread disconnecting.
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
@@ -168,13 +170,13 @@ });
});
char buf[4];
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
stream->write("bar", 3).wait();
pipeThread.pipe->write("bar", 3).wait();
// Expect disconnect.
EXPECT_EQ(0, stream->tryRead(buf, 1, 1).wait());
EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait());
}
} // namespace
......
This diff is collapsed.
@@ -26,6 +26,7 @@
#include "async.h"
#include "function.h"
#include "thread.h"
namespace kj {
@@ -177,61 +178,111 @@ public:
// - Local IP wildcard (local addresses only; covers both v4 and v6): "*", "*:80", ":80", "80"
// - Unix domain: "unix:/path/to/socket"
virtual Own<AsyncIoStream> newPipeThread(
struct PipeThread {
// A combination of a thread and a two-way pipe that communicates with that thread.
//
// The fields are intentionally ordered so that the pipe will be destroyed (and therefore
// disconnected) before the thread is destroyed (and therefore joined). Thus if the thread
// arranges to exit when it detects disconnect, destruction should be clean.
Own<Thread> thread;
Own<AsyncIoStream> pipe;
};
virtual PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) = 0;
// Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate
// with it. One end of the pipe is passed to the thread's starct function and the other end of
// with it. One end of the pipe is passed to the thread's start function and the other end of
// the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will
// already have an active `EventLoop` when `startFunc` is called.
//
// The returned stream's destructor first closes its end of the pipe then waits for the thread to
// finish (joins it). The thread should therefore be designed to exit soon after receiving EOF
// on the input stream.
//
// TODO(someday): I'm not entirely comfortable with this interface. It seems to be doing too
// much at once but I'm not sure how to cleanly break it down.
};
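
A condensed usage sketch of the new return type, modeled on the PipeThread test earlier in this diff (member names are the ones declared in the struct above):

    auto ioContext = kj::setupAsyncIo();

    // startFunc runs in a fresh thread that already has its own EventLoop and provider.
    auto pipeThread = ioContext.provider->newPipeThread(
        [](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) {
      stream.write("foo", 3).wait();
    });

    char buf[4];
    pipeThread.pipe->tryRead(buf, 3, 4).wait();  // reads what the thread wrote
    // On destruction, `pipe` goes away before `thread`, so the far end sees a
    // disconnect before the join, exactly as the comment above describes.
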
// ---------------------------------------------------------------------------
// Unix-only methods
class LowLevelAsyncIoProvider {
// Similar to `AsyncIoProvider`, but represents a lower-level interface that may differ on
// different operating systems. You should prefer to use `AsyncIoProvider` over this interface
// whenever possible, as `AsyncIoProvider` is portable and friendlier to dependency-injection.
//
// On Unix, this interface can be used to import native file descriptors into the async framework.
// Different implementations of this interface might work on top of different event handling
// primitives, such as poll vs. epoll vs. kqueue vs. some higher-level event library.
//
// TODO(cleanup): Should these be in a subclass?
// On Windows, this interface can be used to import native HANDLEs into the async framework.
// Different implementations of this interface might work on top of different event handling
// primitives, such as I/O completion ports vs. completion routines.
//
// TODO(port): Actually implement Windows support.
virtual Own<AsyncInputStream> wrapInputFd(int fd) = 0;
public:
// ---------------------------------------------------------------------------
// Unix-specific stuff
enum Flags {
// Flags controlling how to wrap a file descriptor.
TAKE_OWNERSHIP = 1 << 0,
// The returned object should own the file descriptor, automatically closing it when destroyed.
// The close-on-exec flag will be set on the descriptor if it is not already.
//
// If this flag is not used, then the file descriptor is not automatically closed and the
// close-on-exec flag is not modified.
ALREADY_CLOEXEC = 1 << 1,
// Indicates that the close-on-exec flag is known already to be set, so need not be set again.
// Only relevant when combined with TAKE_OWNERSHIP.
//
// On Linux, all system calls which yield new file descriptors have flags or variants which
// set the close-on-exec flag immediately. Unfortunately, other OS's do not.
ALREADY_NONBLOCK = 1 << 2
// Indicates that the file descriptor is known already to be in non-blocking mode, so the flag
// need not be set again. Otherwise, all wrap*Fd() methods will enable non-blocking mode
// automatically.
//
// On Linux, all system calls which yield new file descriptors have flags or variants which
// enable non-blocking mode immediately. Unfortunately, other OS's do not.
};
virtual Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) = 0;
// Create an AsyncInputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncOutputStream> wrapOutputFd(int fd) = 0;
virtual Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) = 0;
// Create an AsyncOutputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncIoStream> wrapSocketFd(int fd) = 0;
virtual Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket file descriptor.
//
// Does not take ownership of the descriptor.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket that is in the process of connecting. The returned
// promise should not resolve until connection has completed -- traditionally indicated by the
// descriptor becoming writable.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<ConnectionReceiver> wrapListenSocketFd(int fd) = 0;
virtual Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
};
// ---------------------------------------------------------------------------
// Windows-only methods
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel);
// Make a new AsyncIoProvider wrapping a `LowLevelAsyncIoProvider`.
// TODO(port): IOCP
struct AsyncIoContext {
Own<LowLevelAsyncIoProvider> lowLevelProvider;
Own<AsyncIoProvider> provider;
};
Own<AsyncIoProvider> setupIoEventLoop();
AsyncIoContext setupAsyncIo();
// Convenience method which sets up the current thread with everything it needs to do async I/O.
// The returned object contains an `EventLoop` which is wrapping an appropriate `EventPort` for
// doing I/O on the host system, so everything is ready for the thread to start making async calls
......
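
To make the flags concrete, a small sketch of handing raw descriptors to the new low-level interface (the pipe(2) descriptors are just an example; tryRead is assumed to behave as in the stream tests above):

    #include <kj/async-io.h>
    #include <unistd.h>

    int main() {
      int fds[2];
      if (pipe(fds) < 0) return 1;

      auto ioContext = kj::setupAsyncIo();
      auto& lowLevel = *ioContext.lowLevelProvider;

      // Let the wrapper own the read end: close-on-exec is set and the fd is
      // closed automatically when `input` is destroyed.
      auto input = lowLevel.wrapInputFd(fds[0], kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP);

      // Keep ownership of the write end ourselves; only non-blocking mode is enabled.
      auto output = lowLevel.wrapOutputFd(fds[1]);

      output->write("ping", 4).wait();
      char buf[4];
      input->tryRead(buf, 4, 4).wait();

      close(fds[1]);  // still ours to close
      return 0;
    }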