Commit 06fd35c4 authored by Kenton Varda

Split AsyncIoProvider into high-level and low-level interfaces, so that alternate EventPort implementations can implement the low-level interface and reuse the higher-level stuff.
parent ed03e100
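
For orientation, here is a minimal sketch of the new setup API as used throughout the diff below (the include path is assumed; error handling omitted). `kj::setupAsyncIo()` replaces the old `kj::setupIoEventLoop()` and returns both layers at once:

#include <kj/async-io.h>

int main() {
  // Sets up this thread's EventLoop and returns an AsyncIoContext holding
  // both the portable provider and the OS-specific low-level provider.
  auto ioContext = kj::setupAsyncIo();

  // High-level, portable interface: pipes, networking, pipe threads.
  auto pipe = ioContext.provider->newTwoWayPipe();

  // Low-level, OS-specific interface: wrap raw file descriptors.
  auto input = ioContext.lowLevelProvider->wrapInputFd(0);  // borrow stdin

  char buf[4];
  pipe.ends[0]->write("hey", 3).daemonize([](kj::Exception&&) {});
  size_t n = pipe.ends[1]->tryRead(buf, 3, 4).wait();  // wait() pumps the loop
  return n == 3 ? 0 : 1;
}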
@@ -34,7 +34,7 @@ static __thread EzRpcContext* threadEzContext = nullptr;
class EzRpcContext: public kj::Refcounted {
public:
EzRpcContext(): ioProvider(kj::setupIoEventLoop()) {
EzRpcContext(): ioContext(kj::setupAsyncIo()) {
threadEzContext = this;
}
@@ -47,7 +47,11 @@ public:
}
kj::AsyncIoProvider& getIoProvider() {
return *ioProvider;
return *ioContext.provider;
}
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() {
return *ioContext.lowLevelProvider;
}
static kj::Own<EzRpcContext> getThreadLocal() {
@@ -60,7 +64,7 @@ public:
}
private:
kj::Own<kj::AsyncIoProvider> ioProvider;
kj::AsyncIoContext ioContext;
};
// =======================================================================================
@@ -116,7 +120,8 @@ struct EzRpcClient::Impl {
Impl(int socketFd)
: context(EzRpcContext::getThreadLocal()),
setupPromise(kj::Promise<void>(kj::READY_NOW).fork()),
clientContext(kj::heap<ClientContext>(context->getIoProvider().wrapSocketFd(socketFd))) {}
clientContext(kj::heap<ClientContext>(
context->getLowLevelIoProvider().wrapSocketFd(socketFd))) {}
};
EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort)
@@ -145,6 +150,10 @@ kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
return impl->context->getIoProvider();
}
kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
return impl->context->getLowLevelIoProvider();
}
// =======================================================================================
struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskSet::ErrorHandler {
@@ -209,7 +218,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS
: context(EzRpcContext::getThreadLocal()),
portPromise(kj::Promise<uint>(port).fork()),
tasks(*this) {
acceptLoop(context->getIoProvider().wrapListenSocketFd(socketFd));
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd));
}
void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener) {
@@ -268,4 +277,8 @@ kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
return impl->context->getIoProvider();
}
kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() {
return impl->context->getLowLevelIoProvider();
}
} // namespace capnp
@@ -26,7 +26,7 @@
#include "rpc.h"
namespace kj { class AsyncIoProvider; }
namespace kj { class AsyncIoProvider; class LowLevelAsyncIoProvider; }
namespace capnp {
@@ -109,6 +109,10 @@ public:
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider();
// Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you
// want to do some non-RPC I/O in asynchronous fashion.
private:
struct Impl;
kj::Own<Impl> impl;
@@ -159,6 +163,10 @@ public:
// Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want
// to do some non-RPC I/O in asynchronous fashion.
kj::LowLevelAsyncIoProvider& getLowLevelIoProvider();
// Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you
// want to do some non-RPC I/O in asynchronous fashion.
private:
struct Impl;
kj::Own<Impl> impl;
......
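
The header changes above add getLowLevelIoProvider() alongside the existing getIoProvider() accessor on both EzRpcClient and EzRpcServer. A hedged usage sketch (the address and `sockFd` are illustrative, not from the commit):

capnp::EzRpcClient client("localhost:1234");
kj::LowLevelAsyncIoProvider& lowLevel = client.getLowLevelIoProvider();

// Do some non-RPC async I/O on the RPC system's event loop: wrap an
// existing, already-connected socket fd (not owned, per the default flags).
auto stream = lowLevel.wrapSocketFd(sockFd);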
@@ -59,7 +59,7 @@ private:
int& callCount;
};
kj::Own<kj::AsyncIoStream> runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& callCount) {
return ioProvider.newPipeThread(
[&callCount](kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream) {
TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
@@ -86,11 +86,11 @@ Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::SturdyRefHostId>& c
}
TEST(TwoPartyNetwork, Basic) {
auto ioProvider = kj::setupIoEventLoop();
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
auto stream = runServer(*ioProvider, callCount);
TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
auto serverThread = runServer(*ioContext.provider, callCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network);
// Request the particular capability from the server.
@@ -131,12 +131,12 @@ TEST(TwoPartyNetwork, Basic) {
}
TEST(TwoPartyNetwork, Pipelining) {
auto ioProvider = kj::setupIoEventLoop();
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
int reverseCallCount = 0; // Calls back from server to client.
auto stream = runServer(*ioProvider, callCount);
TwoPartyVatNetwork network(*stream, rpc::twoparty::Side::CLIENT);
auto serverThread = runServer(*ioContext.provider, callCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network);
bool disconnected = false;
@@ -184,7 +184,7 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_FALSE(drained);
// What if we disconnect?
stream->shutdownWrite();
serverThread.pipe->shutdownWrite();
// The other side should also disconnect.
disconnectPromise.wait();
......
@@ -120,8 +120,8 @@ protected:
};
TEST_F(SerializeAsyncTest, ParseAsync) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -138,8 +138,8 @@ TEST_F(SerializeAsyncTest, ParseAsync) {
}
TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -156,8 +156,8 @@ TEST_F(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
}
TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto input = ioProvider->wrapInputFd(fds[0]);
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
@@ -174,8 +174,8 @@ TEST_F(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
}
TEST_F(SerializeAsyncTest, WriteAsync) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(1);
auto root = message.getRoot<TestAllTypes>();
@@ -197,8 +197,8 @@ TEST_F(SerializeAsyncTest, WriteAsync) {
}
TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(7);
auto root = message.getRoot<TestAllTypes>();
@@ -220,8 +220,8 @@ TEST_F(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
}
TEST_F(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
auto ioProvider = kj::setupIoEventLoop();
auto output = ioProvider->wrapOutputFd(fds[1]);
auto ioContext = kj::setupAsyncIo();
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);
TestMessageBuilder message(10);
auto root = message.getRoot<TestAllTypes>();
......
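
All six tests follow the same updated recipe: obtain the raw-fd stream from `lowLevelProvider`, then drive the existing async serialization helpers on it. A condensed sketch, assuming readMessage()/writeMessage() from capnp/serialize-async.h and a pipe pair in `fds` as in the test fixture:

auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
auto output = ioContext.lowLevelProvider->wrapOutputFd(fds[1]);

TestMessageBuilder message(1);
message.getRoot<TestAllTypes>().setTextField("foo");

writeMessage(*output, message).wait();        // kj::Promise<void>
auto received = readMessage(*input).wait();   // kj::Own<capnp::MessageReader>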
@@ -30,8 +30,8 @@ namespace kj {
namespace {
TEST(AsyncIo, SimpleNetwork) {
auto ioProvider = setupIoEventLoop();
auto& network = ioProvider->getNetwork();
auto ioContext = setupAsyncIo();
auto& network = ioContext.provider->getNetwork();
Own<ConnectionReceiver> listener;
Own<AsyncIoStream> server;
@@ -76,8 +76,8 @@ String tryParseRemote(Network& network, StringPtr text, uint portHint = 0) {
}
TEST(AsyncIo, AddressParsing) {
auto ioProvider = setupIoEventLoop();
auto& network = ioProvider->getNetwork();
auto ioContext = setupAsyncIo();
auto& network = ioContext.provider->getNetwork();
EXPECT_EQ("*:0", tryParseLocal(network, "*"));
EXPECT_EQ("*:123", tryParseLocal(network, "123"));
@@ -92,9 +92,9 @@ TEST(AsyncIo, AddressParsing) {
}
TEST(AsyncIo, OneWayPipe) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto pipe = ioProvider->newOneWayPipe();
auto pipe = ioContext.provider->newOneWayPipe();
char receiveBuffer[4];
pipe.out->write("foo", 3).daemonize([](kj::Exception&& exception) {
@@ -110,9 +110,9 @@ TEST(AsyncIo, OneWayPipe) {
}
TEST(AsyncIo, TwoWayPipe) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto pipe = ioProvider->newTwoWayPipe();
auto pipe = ioContext.provider->newTwoWayPipe();
char receiveBuffer1[4];
char receiveBuffer2[4];
@@ -137,9 +137,10 @@ TEST(AsyncIo, TwoWayPipe) {
}
TEST(AsyncIo, PipeThread) {
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
@@ -150,17 +151,18 @@ TEST(AsyncIo, PipeThread) {
});
char buf[4];
stream->write("bar", 3).wait();
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
pipeThread.pipe->write("bar", 3).wait();
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
}
TEST(AsyncIo, PipeThreadDisconnects) {
// Like above, but in this case we expect the main thread to detect the pipe thread disconnecting.
auto ioProvider = setupIoEventLoop();
auto ioContext = setupAsyncIo();
auto stream = ioProvider->newPipeThread([](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
auto pipeThread = ioContext.provider->newPipeThread(
[](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
char buf[4];
stream.write("foo", 3).wait();
EXPECT_EQ(3u, stream.tryRead(buf, 3, 4).wait());
@@ -168,13 +170,13 @@ TEST(AsyncIo, PipeThreadDisconnects) {
});
char buf[4];
EXPECT_EQ(3u, stream->tryRead(buf, 3, 4).wait());
EXPECT_EQ(3u, pipeThread.pipe->tryRead(buf, 3, 4).wait());
EXPECT_EQ("foo", heapString(buf, 3));
stream->write("bar", 3).wait();
pipeThread.pipe->write("bar", 3).wait();
// Expect disconnect.
EXPECT_EQ(0, stream->tryRead(buf, 1, 1).wait());
EXPECT_EQ(0, pipeThread.pipe->tryRead(buf, 1, 1).wait());
}
} // namespace
......
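
The pattern in these two tests generalizes: hold on to the returned PipeThread, talk to the thread through its `pipe` member, and let destruction order do the cleanup. A minimal sketch in the same style as the tests:

auto ioContext = setupAsyncIo();
auto pipeThread = ioContext.provider->newPipeThread(
    [](AsyncIoProvider& ioProvider, AsyncIoStream& stream) {
      // Runs in the new thread, which gets its own EventLoop.
      char buf[4];
      stream.tryRead(buf, 3, 4).wait();
    });

pipeThread.pipe->write("hey", 3).wait();
// Destroying pipeThread closes the pipe first, then joins the thread.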
@@ -54,24 +54,36 @@ void setNonblocking(int fd) {
}
}
void setCloseOnExec(int fd) {
int flags;
KJ_SYSCALL(flags = fcntl(fd, F_GETFD));
if ((flags & FD_CLOEXEC) == 0) {
KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC));
}
}
class OwnedFileDescriptor {
public:
OwnedFileDescriptor(int fd): fd(fd) {
#if __linux__
// Linux has alternate APIs allowing these flags to be set at FD creation; make sure we always
// use them.
KJ_DREQUIRE(fcntl(fd, F_GETFD) & FD_CLOEXEC, "You forgot to set CLOEXEC.");
KJ_DREQUIRE(fcntl(fd, F_GETFL) & O_NONBLOCK, "You forgot to set NONBLOCK.");
#else
// On non-Linux, we have to set the flags non-atomically.
fcntl(fd, F_SETFD, fcntl(fd, F_GETFD) | FD_CLOEXEC);
fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK);
#endif
OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) {
if (flags & LowLevelAsyncIoProvider::ALREADY_NONBLOCK) {
KJ_DREQUIRE(fcntl(fd, F_GETFL) & O_NONBLOCK, "You claimed you set NONBLOCK, but you didn't.");
} else {
setNonblocking(fd);
}
if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
if (flags & LowLevelAsyncIoProvider::ALREADY_CLOEXEC) {
KJ_DREQUIRE(fcntl(fd, F_GETFD) & FD_CLOEXEC,
"You claimed you set CLOEXEC, but you didn't.");
} else {
setCloseOnExec(fd);
}
}
}
~OwnedFileDescriptor() noexcept(false) {
// Don't use SYSCALL() here because close() should not be repeated on EINTR.
if (close(fd) < 0) {
if ((flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) && close(fd) < 0) {
KJ_FAIL_SYSCALL("close", errno, fd) {
// Recoverable exceptions are safe in destructors.
break;
@@ -81,14 +93,17 @@ public:
protected:
const int fd;
private:
uint flags;
};
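
In effect, the reworked constructor and destructor give callers three modes per wrapped descriptor (an illustrative summary of the logic above, not new API):

// wrapXxxFd(fd)                      -> fd borrowed; O_NONBLOCK set if missing
// wrapXxxFd(fd, TAKE_OWNERSHIP)      -> fd closed on destruction; CLOEXEC also set
// wrapXxxFd(fd, TAKE_OWNERSHIP | ALREADY_CLOEXEC | ALREADY_NONBLOCK)
//                                    -> no fcntl() calls at all; debug builds
//                                       merely verify the caller's claims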
// =======================================================================================
class AsyncStreamFd: public AsyncIoStream {
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncIoStream {
public:
AsyncStreamFd(UnixEventPort& eventPort, int readFd, int writeFd)
: eventPort(eventPort), readFd(readFd), writeFd(writeFd) {}
AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort) {}
virtual ~AsyncStreamFd() noexcept(false) {}
Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
@@ -108,7 +123,7 @@ public:
Promise<void> write(const void* buffer, size_t size) override {
ssize_t writeResult;
KJ_NONBLOCKING_SYSCALL(writeResult = ::write(writeFd, buffer, size)) {
KJ_NONBLOCKING_SYSCALL(writeResult = ::write(fd, buffer, size)) {
return READY_NOW;
}
@@ -122,7 +137,7 @@ public:
size -= n;
}
return eventPort.onFdEvent(writeFd, POLLOUT).then([=](short) {
return eventPort.onFdEvent(fd, POLLOUT).then([=](short) {
return write(buffer, size);
});
}
@@ -138,14 +153,11 @@ public:
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// UnixAsyncIoProvider interface.
KJ_REQUIRE(readFd == writeFd, "shutdownWrite() is only implemented on sockets.");
KJ_SYSCALL(shutdown(writeFd, SHUT_WR));
KJ_SYSCALL(shutdown(fd, SHUT_WR));
}
private:
UnixEventPort& eventPort;
int readFd;
int writeFd;
bool gotHup = false;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
@@ -155,13 +167,13 @@ private:
// be included in the final return value.
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = ::read(readFd, buffer, maxBytes)) {
KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) {
return alreadyRead;
}
if (n < 0) {
// Read would block.
return eventPort.onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) {
return eventPort.onFdEvent(fd, POLLIN | POLLRDHUP).then([=](short events) {
gotHup = events & (POLLHUP | POLLRDHUP);
return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
});
@@ -186,7 +198,7 @@ private:
minBytes -= n;
maxBytes -= n;
alreadyRead += n;
return eventPort.onFdEvent(readFd, POLLIN | POLLRDHUP).then([=](short events) {
return eventPort.onFdEvent(fd, POLLIN | POLLRDHUP).then([=](short events) {
gotHup = events & (POLLHUP | POLLRDHUP);
return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
});
@@ -210,7 +222,7 @@ private:
}
ssize_t writeResult;
KJ_NONBLOCKING_SYSCALL(writeResult = ::writev(writeFd, iov.begin(), iov.size())) {
KJ_NONBLOCKING_SYSCALL(writeResult = ::writev(fd, iov.begin(), iov.size())) {
// error
return READY_NOW;
}
@@ -223,7 +235,7 @@ private:
if (n < firstPiece.size()) {
// Only part of the first piece was consumed. Wait for POLLOUT and then write again.
firstPiece = firstPiece.slice(n, firstPiece.size());
return eventPort.onFdEvent(writeFd, POLLOUT).then([=](short) {
return eventPort.onFdEvent(fd, POLLOUT).then([=](short) {
return writeInternal(firstPiece, morePieces);
});
} else if (morePieces.size() == 0) {
@@ -240,23 +252,6 @@ private:
}
};
class Socket final: public OwnedFileDescriptor, public AsyncStreamFd {
public:
Socket(UnixEventPort& eventPort, int fd)
: OwnedFileDescriptor(fd), AsyncStreamFd(eventPort, fd, fd) {}
};
class ThreadSocket final: public Thread, public OwnedFileDescriptor, public AsyncStreamFd {
// Combination thread and socket. The thread must be joined strictly after the socket is closed.
public:
template <typename StartFunc>
ThreadSocket(UnixEventPort& eventPort, int fd, StartFunc&& startFunc)
: Thread(kj::fwd<StartFunc>(startFunc)),
OwnedFileDescriptor(fd),
AsyncStreamFd(eventPort, fd, fd) {}
};
// =======================================================================================
class SocketAddress {
@@ -504,10 +499,16 @@ private:
// =======================================================================================
static constexpr uint NEW_FD_FLAGS =
#if __linux__
LowLevelAsyncIoProvider::ALREADY_CLOEXEC | LowLevelAsyncIoProvider::ALREADY_NONBLOCK |
#endif
LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public:
FdConnectionReceiver(UnixEventPort& eventPort, int fd)
: OwnedFileDescriptor(fd), eventPort(eventPort) {}
FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort) {}
Promise<Own<AsyncIoStream>> accept() override {
int newFd;
@@ -520,7 +521,7 @@ public:
#endif
if (newFd >= 0) {
return Own<AsyncIoStream>(heap<Socket>(eventPort, newFd));
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
} else {
int error = errno;
@@ -562,28 +563,66 @@ public:
UnixEventPort& eventPort;
};
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
LowLevelAsyncIoProviderImpl(): eventLoop(eventPort) {}
Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) override {
auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
return eventPort.onFdEvent(fd, POLLOUT).then(kj::mvCapture(result,
[fd](Own<AsyncIoStream>&& stream, short events) {
int err;
socklen_t errlen = sizeof(err);
KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen));
if (err != 0) {
KJ_FAIL_SYSCALL("connect()", err) { break; }
}
return kj::mv(stream);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
}
private:
UnixEventPort eventPort;
EventLoop eventLoop;
};
// =======================================================================================
class LocalSocketAddress final: public LocalAddress {
public:
LocalSocketAddress(UnixEventPort& eventPort, SocketAddress addr)
: eventPort(eventPort), addr(addr) {}
LocalSocketAddress(LowLevelAsyncIoProvider& lowLevel, SocketAddress addr)
: lowLevel(lowLevel), addr(addr) {}
Own<ConnectionReceiver> listen() override {
int fd = addr.socket(SOCK_STREAM);
auto result = heap<FdConnectionReceiver>(eventPort, fd);
// We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks.
int optval = 1;
KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
{
KJ_ON_SCOPE_FAILURE(close(fd));
// We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks.
int optval = 1;
KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
addr.bind(fd);
addr.bind(fd);
// TODO(someday): Let queue size be specified explicitly in string addresses.
KJ_SYSCALL(::listen(fd, SOMAXCONN));
// TODO(someday): Let queue size be specified explicitly in string addresses.
KJ_SYSCALL(::listen(fd, SOMAXCONN));
}
return kj::mv(result);
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
}
String toString() override {
@@ -591,30 +630,24 @@ public:
}
private:
UnixEventPort& eventPort;
LowLevelAsyncIoProvider& lowLevel;
SocketAddress addr;
};
class RemoteSocketAddress final: public RemoteAddress {
public:
RemoteSocketAddress(UnixEventPort& eventPort, SocketAddress addr)
: eventPort(eventPort), addr(addr) {}
RemoteSocketAddress(LowLevelAsyncIoProvider& lowLevel, SocketAddress addr)
: lowLevel(lowLevel), addr(addr) {}
Promise<Own<AsyncIoStream>> connect() override {
int fd = addr.socket(SOCK_STREAM);
auto result = heap<Socket>(eventPort, fd);
addr.connect(fd);
return eventPort.onFdEvent(fd, POLLOUT).then(kj::mvCapture(result,
[fd](Own<AsyncIoStream>&& stream, short events) {
int err;
socklen_t errlen = sizeof(err);
KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen));
if (err != 0) {
KJ_FAIL_SYSCALL("connect()", err) { break; }
}
return kj::mv(stream);
}));
{
KJ_ON_SCOPE_FAILURE(close(fd));
addr.connect(fd);
}
return lowLevel.wrapConnectingSocketFd(fd, NEW_FD_FLAGS);
}
String toString() override {
@@ -622,46 +655,46 @@ public:
}
private:
UnixEventPort& eventPort;
LowLevelAsyncIoProvider& lowLevel;
SocketAddress addr;
};
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(UnixEventPort& eventPort): eventPort(eventPort) {}
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
Promise<Own<LocalAddress>> parseLocalAddress(StringPtr addr, uint portHint = 0) override {
auto& eventPortCopy = eventPort;
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&eventPortCopy,portHint](String&& addr) -> Own<LocalAddress> {
return heap<LocalSocketAddress>(eventPortCopy, SocketAddress::parseLocal(addr, portHint));
[&lowLevelCopy,portHint](String&& addr) -> Own<LocalAddress> {
return heap<LocalSocketAddress>(lowLevelCopy, SocketAddress::parseLocal(addr, portHint));
}));
}
Promise<Own<RemoteAddress>> parseRemoteAddress(StringPtr addr, uint portHint = 0) override {
auto& eventPortCopy = eventPort;
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&eventPortCopy,portHint](String&& addr) -> Own<RemoteAddress> {
return heap<RemoteSocketAddress>(eventPortCopy, SocketAddress::parse(addr, portHint));
[&lowLevelCopy,portHint](String&& addr) -> Own<RemoteAddress> {
return heap<RemoteSocketAddress>(lowLevelCopy, SocketAddress::parse(addr, portHint));
}));
}
Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) override {
return Own<LocalAddress>(heap<LocalSocketAddress>(eventPort, SocketAddress(sockaddr, len)));
return Own<LocalAddress>(heap<LocalSocketAddress>(lowLevel, SocketAddress(sockaddr, len)));
}
Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) override {
return Own<RemoteAddress>(heap<RemoteSocketAddress>(eventPort, SocketAddress(sockaddr, len)));
return Own<RemoteAddress>(heap<RemoteSocketAddress>(lowLevel, SocketAddress(sockaddr, len)));
}
private:
UnixEventPort& eventPort;
LowLevelAsyncIoProvider& lowLevel;
};
// =======================================================================================
class UnixAsyncIoProvider final: public AsyncIoProvider {
class AsyncIoProviderImpl final: public AsyncIoProvider {
public:
UnixAsyncIoProvider()
: eventLoop(eventPort), network(eventPort) {}
AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
: lowLevel(lowLevel), network(lowLevel) {}
OneWayPipe newOneWayPipe() override {
int fds[2];
@@ -670,7 +703,10 @@ public:
#else
KJ_SYSCALL(pipe(fds));
#endif
return OneWayPipe { heap<Socket>(eventPort, fds[0]), heap<Socket>(eventPort, fds[1]) };
return OneWayPipe {
lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS),
lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS)
};
}
TwoWayPipe newTwoWayPipe() override {
@@ -680,14 +716,17 @@ public:
type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif
KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
return TwoWayPipe { { heap<Socket>(eventPort, fds[0]), heap<Socket>(eventPort, fds[1]) } };
return TwoWayPipe { {
lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
} };
}
Network& getNetwork() override {
return network;
}
Own<AsyncIoStream> newPipeThread(
PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) override {
int fds[2];
int type = SOCK_STREAM;
@@ -697,36 +736,23 @@ public:
KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
int threadFd = fds[1];
KJ_ON_SCOPE_FAILURE(close(threadFd));
return heap<ThreadSocket>(eventPort, fds[0], kj::mvCapture(startFunc,
auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
auto thread = heap<Thread>(kj::mvCapture(startFunc,
[threadFd](Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)>&& startFunc) {
KJ_DEFER(KJ_SYSCALL(close(threadFd)));
UnixAsyncIoProvider ioProvider;
auto stream = ioProvider.wrapSocketFd(threadFd);
LowLevelAsyncIoProviderImpl lowLevel;
auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
AsyncIoProviderImpl ioProvider(lowLevel);
startFunc(ioProvider, *stream);
}));
}
Own<AsyncInputStream> wrapInputFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, fd, -1);
}
Own<AsyncOutputStream> wrapOutputFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, -1, fd);
}
Own<AsyncIoStream> wrapSocketFd(int fd) override {
setNonblocking(fd);
return heap<AsyncStreamFd>(eventPort, fd, fd);
}
Own<ConnectionReceiver> wrapListenSocketFd(int fd) override {
setNonblocking(fd);
return heap<FdConnectionReceiver>(eventPort, fd);
return { kj::mv(thread), kj::mv(pipe) };
}
private:
UnixEventPort eventPort;
EventLoop eventLoop;
LowLevelAsyncIoProvider& lowLevel;
SocketNetwork network;
};
@@ -736,8 +762,14 @@ Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) {
return read(buffer, bytes, bytes).then([](size_t) {});
}
Own<AsyncIoProvider> setupIoEventLoop() {
return heap<UnixAsyncIoProvider>();
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
return kj::heap<AsyncIoProviderImpl>(lowLevel);
}
AsyncIoContext setupAsyncIo() {
auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
return { kj::mv(lowLevel), kj::mv(ioProvider) };
}
} // namespace kj
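
This is the payoff of the split: an alternate EventPort implementation only has to supply the five wrap*Fd() methods and can reuse the portable layer via newAsyncIoProvider(). A hedged skeleton (the class name and backing primitive are illustrative):

class KqueueLowLevelProvider final: public kj::LowLevelAsyncIoProvider {
  // Hypothetical provider built on some other event primitive (kqueue, say)
  // instead of UnixEventPort; method bodies omitted.
public:
  kj::Own<kj::AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override;
  kj::Own<kj::AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) override;
  kj::Own<kj::AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override;
  kj::Promise<kj::Own<kj::AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) override;
  kj::Own<kj::ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override;
};

// The generic high-level provider then layers on top unchanged:
// KqueueLowLevelProvider lowLevel;
// kj::Own<kj::AsyncIoProvider> provider = kj::newAsyncIoProvider(lowLevel);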
@@ -26,6 +26,7 @@
#include "async.h"
#include "function.h"
#include "thread.h"
namespace kj {
@@ -177,61 +178,111 @@ public:
// - Local IP wildcard (local addresses only; covers both v4 and v6): "*", "*:80", ":80", "80"
// - Unix domain: "unix:/path/to/socket"
virtual Own<AsyncIoStream> newPipeThread(
struct PipeThread {
// A combination of a thread and a two-way pipe that communicates with that thread.
//
// The fields are intentionally ordered so that the pipe will be destroyed (and therefore
// disconnected) before the thread is destroyed (and therefore joined). Thus if the thread
// arranges to exit when it detects disconnect, destruction should be clean.
Own<Thread> thread;
Own<AsyncIoStream> pipe;
};
virtual PipeThread newPipeThread(
Function<void(AsyncIoProvider& ioProvider, AsyncIoStream& stream)> startFunc) = 0;
// Create a new thread and set up a two-way pipe (socketpair) which can be used to communicate
// with it. One end of the pipe is passed to the thread's starct function and the other end of
// with it. One end of the pipe is passed to the thread's start function and the other end of
// the pipe is returned. The new thread also gets its own `AsyncIoProvider` instance and will
// already have an active `EventLoop` when `startFunc` is called.
//
// The returned stream's destructor first closes its end of the pipe then waits for the thread to
// finish (joins it). The thread should therefore be designed to exit soon after receiving EOF
// on the input stream.
//
// TODO(someday): I'm not entirely comfortable with this interface. It seems to be doing too
// much at once but I'm not sure how to cleanly break it down.
};
// ---------------------------------------------------------------------------
// Unix-only methods
class LowLevelAsyncIoProvider {
// Similar to `AsyncIoProvider`, but represents a lower-level interface that may differ on
// different operating systems. You should prefer to use `AsyncIoProvider` over this interface
// whenever possible, as `AsyncIoProvider` is portable and friendlier to dependency-injection.
//
// On Unix, this interface can be used to import native file descriptors into the async framework.
// Different implementations of this interface might work on top of different event handling
// primitives, such as poll vs. epoll vs. kqueue vs. some higher-level event library.
//
// TODO(cleanup): Should these be in a subclass?
// On Windows, this interface can be used to import native HANDLEs into the async framework.
// Different implementations of this interface might work on top of different event handling
// primitives, such as I/O completion ports vs. completion routines.
//
// TODO(port): Actually implement Windows support.
virtual Own<AsyncInputStream> wrapInputFd(int fd) = 0;
public:
// ---------------------------------------------------------------------------
// Unix-specific stuff
enum Flags {
// Flags controlling how to wrap a file descriptor.
TAKE_OWNERSHIP = 1 << 0,
// The returned object should own the file descriptor, automatically closing it when destroyed.
// The close-on-exec flag will be set on the descriptor if it is not already.
//
// If this flag is not used, then the file descriptor is not automatically closed and the
// close-on-exec flag is not modified.
ALREADY_CLOEXEC = 1 << 1,
// Indicates that the close-on-exec flag is known already to be set, so need not be set again.
// Only relevant when combined with TAKE_OWNERSHIP.
//
// On Linux, all system calls which yield new file descriptors have flags or variants which
// set the close-on-exec flag immediately. Unfortunately, other OS's do not.
ALREADY_NONBLOCK = 1 << 2
// Indicates that the file descriptor is known already to be in non-blocking mode, so the flag
// need not be set again. Otherwise, all wrap*Fd() methods will enable non-blocking mode
// automatically.
//
// On Linux, all system calls which yield new file descriptors have flags or variants which
// enable non-blocking mode immediately. Unfortunately, other OS's do not.
};
virtual Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) = 0;
// Create an AsyncInputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncOutputStream> wrapOutputFd(int fd) = 0;
virtual Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) = 0;
// Create an AsyncOutputStream wrapping a file descriptor.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncIoStream> wrapSocketFd(int fd) = 0;
virtual Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket file descriptor.
//
// Does not take ownership of the descriptor.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket that is in the process of connecting. The returned
// promise should not resolve until connection has completed -- traditionally indicated by the
// descriptor becoming writable.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<ConnectionReceiver> wrapListenSocketFd(int fd) = 0;
virtual Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
//
// Does not take ownership of the descriptor.
//
// This will set `fd` to non-blocking mode (i.e. set O_NONBLOCK) if it isn't set already.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
};
// ---------------------------------------------------------------------------
// Windows-only methods
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel);
// Make a new AsyncIoProvider wrapping a `LowLevelAsyncIoProvider`.
// TODO(port): IOCP
struct AsyncIoContext {
Own<LowLevelAsyncIoProvider> lowLevelProvider;
Own<AsyncIoProvider> provider;
};
Own<AsyncIoProvider> setupIoEventLoop();
AsyncIoContext setupAsyncIo();
// Convenience method which sets up the current thread with everything it needs to do async I/O.
// The returned object contains an `EventLoop` which is wrapping an appropriate `EventPort` for
// doing I/O on the host system, so everything is ready for the thread to start making async calls
......
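
One closing example for the Flags enum documented above: the values combine with bitwise OR, exactly as NEW_FD_FLAGS is built in the implementation file (here `fd` is assumed to come from accept4() with SOCK_NONBLOCK | SOCK_CLOEXEC on Linux):

auto stream = lowLevel.wrapSocketFd(fd,
    kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP |
    kj::LowLevelAsyncIoProvider::ALREADY_CLOEXEC |
    kj::LowLevelAsyncIoProvider::ALREADY_NONBLOCK);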