Commit 626a8c8a authored by Kenton Varda's avatar Kenton Varda

Add AsyncCapabilityStream, an abstraction of unix sockets SCM_RIGHTS FD passing.

This allows a lot of nice design patterns, and might later be the basis for capnp 3-party handoff within a machine.
parent 73a01874
...@@ -192,6 +192,47 @@ TEST(AsyncIo, TwoWayPipe) { ...@@ -192,6 +192,47 @@ TEST(AsyncIo, TwoWayPipe) {
EXPECT_EQ("bar", result2); EXPECT_EQ("bar", result2);
} }
#if !_WIN32
TEST(AsyncIo, CapabilityPipe) {
  // Transfers one end of a capability pipe across a second capability pipe, then checks that the
  // transferred end still talks to its original peer in both directions.
  auto io = setupAsyncIo();

  // `payload` is the pipe whose second end gets transferred; `carrier` is the pipe it is
  // transferred over.
  auto payload = io.provider->newCapabilityPipe();
  auto carrier = io.provider->newCapabilityPipe();
  char senderBuffer[4];
  char receiverBuffer[4];

  // Receiving side: wait for a stream to arrive, write "bar" to it, then read "foo" back from it.
  Own<AsyncCapabilityStream> transferred;
  auto receiverTask = carrier.ends[1]->receiveStream()
      .then([&](Own<AsyncCapabilityStream> incoming) {
    transferred = kj::mv(incoming);
    return transferred->write("bar", 3);
  }).then([&]() {
    return transferred->tryRead(receiverBuffer, 3, 4);
  }).then([&](size_t count) {
    EXPECT_EQ(3u, count);
    return heapString(receiverBuffer, count);
  });

  // Sending side: hand over one end of `payload`, write "foo" to the end we kept, then read
  // "bar" back from it.
  auto senderTask = carrier.ends[0]->sendStream(kj::mv(payload.ends[1]))
      .then([&]() {
    return payload.ends[0]->write("foo", 3);
  }).then([&]() {
    return payload.ends[0]->tryRead(senderBuffer, 3, 4);
  }).then([&](size_t count) {
    EXPECT_EQ(3u, count);
    return heapString(senderBuffer, count);
  });

  kj::String seenBySender = senderTask.wait(io.waitScope);
  kj::String seenByReceiver = receiverTask.wait(io.waitScope);

  EXPECT_EQ("bar", seenBySender);
  EXPECT_EQ("foo", seenByReceiver);
}
#endif
TEST(AsyncIo, PipeThread) { TEST(AsyncIo, PipeThread) {
auto ioContext = setupAsyncIo(); auto ioContext = setupAsyncIo();
......
...@@ -112,10 +112,11 @@ private: ...@@ -112,10 +112,11 @@ private:
// ======================================================================================= // =======================================================================================
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncIoStream { class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream {
public: public:
AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags) AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), : OwnedFileDescriptor(fd, flags),
eventPort(eventPort),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {} observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {}
virtual ~AsyncStreamFd() noexcept(false) {} virtual ~AsyncStreamFd() noexcept(false) {}
...@@ -198,6 +199,57 @@ public: ...@@ -198,6 +199,57 @@ public:
*length = socklen; *length = socklen;
} }
kj::Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream() override {
  // Receives a file descriptor via the shared receive helper and wraps it as a new capability
  // stream.
  using ReceivedStream = Own<AsyncCapabilityStream>;
  return tryReceiveFdImpl<ReceivedStream>();
}
kj::Promise<void> sendStream(Own<AsyncCapabilityStream> stream) override {
  // Sends the file descriptor underlying `stream` over this socket. The stream is downcast to
  // AsyncStreamFd in order to reach its fd.
  auto fdStream = stream.downcast<AsyncStreamFd>();

  // Start the send first, then attach the stream to the resulting promise so that the stream
  // (and therefore its fd) stays alive for as long as the promise does.
  auto sendPromise = sendFd(fdStream->fd);
  return sendPromise.attach(kj::mv(fdStream));
}
kj::Promise<kj::Maybe<AutoCloseFd>> tryReceiveFd() override {
  // Receives a raw file descriptor via the shared receive helper, wrapped in an AutoCloseFd.
  using ReceivedFd = AutoCloseFd;
  return tryReceiveFdImpl<ReceivedFd>();
}
kj::Promise<void> sendFd(int fdToSend) override {
  // Sends `fdToSend` over this socket as a single SCM_RIGHTS control message, accompanied by one
  // zero data byte. If the socket cannot accept the message yet, waits until it becomes writable
  // and retries. The returned promise resolves once sendmsg() has accepted the message.

  struct msghdr msg;
  struct iovec iov;
  union {
    // Overlaying the buffer with a cmsghdr lets it be addressed as a control message header.
    struct cmsghdr cmsg;
    // The buffer is sized with CMSG_SPACE() (header + payload + trailing padding), as cmsg(3)
    // prescribes for ancillary buffers; CMSG_LEN() is only meant for the cmsg_len field.
    char cmsgSpace[CMSG_SPACE(sizeof(int))];
  };
  memset(&msg, 0, sizeof(msg));
  memset(&iov, 0, sizeof(iov));
  memset(cmsgSpace, 0, sizeof(cmsgSpace));

  // One byte of ordinary data travels with the control message.
  char c = 0;
  iov.iov_base = &c;
  iov.iov_len = 1;
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;

  msg.msg_control = &cmsg;
  msg.msg_controllen = sizeof(cmsgSpace);

  // cmsg_len covers the header plus exactly one int of payload; the receiving side checks for
  // precisely CMSG_LEN(sizeof(int)).
  cmsg.cmsg_len = CMSG_LEN(sizeof(int));
  cmsg.cmsg_level = SOL_SOCKET;
  cmsg.cmsg_type = SCM_RIGHTS;

  // Copy the descriptor in with memcpy() rather than writing through a casted pointer, matching
  // how the receive path reads it back out.
  memcpy(CMSG_DATA(&cmsg), &fdToSend, sizeof(fdToSend));

  ssize_t n;
  KJ_NONBLOCKING_SYSCALL(n = sendmsg(fd, &msg, 0));
  if (n < 0) {
    // Nothing was sent; try again once the socket is writable.
    return observer.whenBecomesWritable().then([this,fdToSend]() {
      return sendFd(fdToSend);
    });
  } else {
    // Exactly the one data byte must have been written.
    KJ_ASSERT(n == 1);
    return kj::READY_NOW;
  }
}
Promise<void> waitConnected() { Promise<void> waitConnected() {
// Wait until initial connection has completed. This actually just waits until it is writable. // Wait until initial connection has completed. This actually just waits until it is writable.
...@@ -222,6 +274,7 @@ public: ...@@ -222,6 +274,7 @@ public:
} }
private: private:
UnixEventPort& eventPort;
UnixEventPort::FdObserver observer; UnixEventPort::FdObserver observer;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
...@@ -354,6 +407,59 @@ private: ...@@ -354,6 +407,59 @@ private:
} }
} }
} }
template <typename T>
kj::Promise<kj::Maybe<T>> tryReceiveFdImpl() {
  // Shared implementation of tryReceiveStream() and tryReceiveFd(). Reads one data byte plus one
  // control message from the socket and wraps the received file descriptor as a `T` through the
  // matching wrapFd() overload. Resolves to null when recvmsg() reports zero bytes.

  // Space to receive a cmsg carrying a single int.
  union {
    struct cmsghdr cmsg;
    char cmsgSpace[CMSG_SPACE(sizeof(int))];
  };

  // One byte of ordinary data is read along with the control message.
  char c;
  struct iovec iov;
  memset(&iov, 0, sizeof(iov));
  iov.iov_base = &c;
  iov.iov_len = 1;

  struct msghdr msg;
  memset(&msg, 0, sizeof(msg));
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = &cmsg;
  msg.msg_controllen = sizeof(cmsgSpace);

  ssize_t n;
  KJ_NONBLOCKING_SYSCALL(n = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC));

  if (n < 0) {
    // Nothing was received; try again once the socket is readable.
    return observer.whenBecomesReadable().then([this]() {
      return tryReceiveFdImpl<T>();
    });
  }

  if (n == 0) {
    return kj::Maybe<T>(nullptr);
  }

  KJ_REQUIRE(msg.msg_controllen >= sizeof(cmsg),
      "expected to receive FD over socket; received data instead");

  // We expect an SCM_RIGHTS message with a single FD.
  KJ_REQUIRE(cmsg.cmsg_level == SOL_SOCKET);
  KJ_REQUIRE(cmsg.cmsg_type == SCM_RIGHTS);
  KJ_REQUIRE(cmsg.cmsg_len == CMSG_LEN(sizeof(int)));

  int receivedFd;
  memcpy(&receivedFd, CMSG_DATA(&cmsg), sizeof(receivedFd));

  // The null pointer argument only selects which wrapFd() overload is used.
  return kj::Maybe<T>(wrapFd(receivedFd, (T*)nullptr));
}
AutoCloseFd wrapFd(int newFd, AutoCloseFd*) {
  // Overload used by tryReceiveFdImpl<AutoCloseFd>(): takes ownership of `newFd` and marks it
  // close-on-exec. The pointer parameter exists only for overload selection.
  AutoCloseFd owned(newFd);
  setCloseOnExec(owned);
  return owned;
}
Own<AsyncCapabilityStream> wrapFd(int newFd, Own<AsyncCapabilityStream>*) {
  // Overload used by tryReceiveFdImpl<Own<AsyncCapabilityStream>>(): wraps `newFd` in a new
  // AsyncStreamFd on the same event port, passing the TAKE_OWNERSHIP flag. The pointer parameter
  // exists only for overload selection.
  const uint ownershipFlags = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
  return kj::heap<AsyncStreamFd>(eventPort, newFd, ownershipFlags);
}
}; };
// ======================================================================================= // =======================================================================================
...@@ -951,6 +1057,9 @@ public: ...@@ -951,6 +1057,9 @@ public:
Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override { Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags); return heap<AsyncStreamFd>(eventPort, fd, flags);
} }
Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) override {
  // Wraps `fd` in an AsyncStreamFd bound to this provider's event port and returns it as a
  // capability stream.
  return kj::heap<AsyncStreamFd>(eventPort, fd, flags);
}
Promise<Own<AsyncIoStream>> wrapConnectingSocketFd( Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override { int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
// Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates
...@@ -1375,6 +1484,19 @@ public: ...@@ -1375,6 +1484,19 @@ public:
} }; } };
} }
CapabilityPipe newCapabilityPipe() override {
  // Creates a connected pair of AF_UNIX stream sockets and wraps each one as a capability
  // stream.
  int socketType = SOCK_STREAM;
#if __linux__ && !__BIONIC__
  // Request non-blocking, close-on-exec sockets directly from socketpair().
  socketType |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif

  int sockets[2];
  KJ_SYSCALL(socketpair(AF_UNIX, socketType, 0, sockets));

  auto firstEnd = lowLevel.wrapUnixSocketFd(sockets[0], NEW_FD_FLAGS);
  auto secondEnd = lowLevel.wrapUnixSocketFd(sockets[1], NEW_FD_FLAGS);
  return CapabilityPipe { { kj::mv(firstEnd), kj::mv(secondEnd) } };
}
Network& getNetwork() override { Network& getNetwork() override {
return network; return network;
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "async-io-internal.h" #include "async-io-internal.h"
#include "debug.h" #include "debug.h"
#include "vector.h" #include "vector.h"
#include "io.h"
#if _WIN32 #if _WIN32
#include <winsock2.h> #include <winsock2.h>
...@@ -181,6 +182,34 @@ Maybe<Promise<uint64_t>> AsyncOutputStream::tryPumpFrom( ...@@ -181,6 +182,34 @@ Maybe<Promise<uint64_t>> AsyncOutputStream::tryPumpFrom(
return nullptr; return nullptr;
} }
Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() {
  // Like tryReceiveStream(), but a null result (EOF) becomes a FAILED exception instead.
  auto requireStream = [](Maybe<Own<AsyncCapabilityStream>>&& maybeStream)
      -> Promise<Own<AsyncCapabilityStream>> {
    KJ_IF_MAYBE(stream, maybeStream) {
      return kj::mv(*stream);
    } else {
      return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability");
    }
  };

  return tryReceiveStream().then(kj::mv(requireStream));
}
Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() {
  // Like tryReceiveFd(), but a null result (EOF) becomes a FAILED exception instead.
  auto requireFd = [](Maybe<AutoCloseFd>&& maybeFd) -> Promise<AutoCloseFd> {
    KJ_IF_MAYBE(received, maybeFd) {
      return kj::mv(*received);
    } else {
      return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability");
    }
  };

  return tryReceiveFd().then(kj::mv(requireFd));
}
Promise<Maybe<AutoCloseFd>> AsyncCapabilityStream::tryReceiveFd() {
  // Default implementation: raw file descriptor transfer is not supported, so the returned
  // promise carries an UNIMPLEMENTED exception.
  return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot receive file descriptors");
}
Promise<void> AsyncCapabilityStream::sendFd(int fd) {
  // Default implementation: raw file descriptor transfer is not supported, so `fd` is ignored
  // and the returned promise carries an UNIMPLEMENTED exception.
  return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot send file descriptors");
}
void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) { void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) {
KJ_UNIMPLEMENTED("Not a socket."); KJ_UNIMPLEMENTED("Not a socket.");
} }
...@@ -212,6 +241,9 @@ Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd( ...@@ -212,6 +241,9 @@ Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) { Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented."); KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
} }
CapabilityPipe AsyncIoProvider::newCapabilityPipe() {
  // Default implementation for providers that do not support capability pipes.
  KJ_UNIMPLEMENTED("Capability pipes not implemented.");
}
namespace { namespace {
...@@ -227,6 +259,39 @@ LowLevelAsyncIoProvider::NetworkFilter& LowLevelAsyncIoProvider::NetworkFilter:: ...@@ -227,6 +259,39 @@ LowLevelAsyncIoProvider::NetworkFilter& LowLevelAsyncIoProvider::NetworkFilter::
return result; return result;
} }
// =======================================================================================
// Convenience adapters.
Promise<Own<AsyncIoStream>> CapabilityStreamConnectionReceiver::accept() {
  // Each accepted "connection" is a stream received over the wrapped capability stream, returned
  // as a plain AsyncIoStream.
  auto upcast = [](Own<AsyncCapabilityStream>&& received) -> Own<AsyncIoStream> {
    return kj::mv(received);
  };

  return inner.receiveStream().then(kj::mv(upcast));
}
uint CapabilityStreamConnectionReceiver::getPort() {
  // This receiver always reports port 0.
  return 0;
}
Promise<Own<AsyncIoStream>> CapabilityStreamNetworkAddress::connect() {
  // "Connects" by creating a fresh capability pipe, sending one end over the wrapped stream, and
  // handing the other end to the caller once the send has completed.
  auto freshPipe = provider.newCapabilityPipe();
  auto localEnd = kj::mv(freshPipe.ends[0]);

  auto sendDone = inner.sendStream(kj::mv(freshPipe.ends[1]));
  return sendDone.then(kj::mvCapture(localEnd, [](Own<AsyncIoStream>&& keptEnd) {
    return kj::mv(keptEnd);
  }));
}
Own<ConnectionReceiver> CapabilityStreamNetworkAddress::listen() {
  // The returned receiver accepts connections by receiving streams over the same wrapped
  // capability stream.
  return kj::heap<CapabilityStreamConnectionReceiver>(inner);
}
Own<NetworkAddress> CapabilityStreamNetworkAddress::clone() {
  // Cloning is not supported for this adapter.
  KJ_UNIMPLEMENTED("can't clone CapabilityStreamNetworkAddress");
}
String CapabilityStreamNetworkAddress::toString() {
  // There is no real address to print, so a fixed placeholder is returned.
  return kj::str("<CapabilityStreamNetworkAddress>");
}
// ======================================================================================= // =======================================================================================
namespace _ { // private namespace _ { // private
......
...@@ -39,10 +39,12 @@ namespace kj { ...@@ -39,10 +39,12 @@ namespace kj {
class Win32EventPort; class Win32EventPort;
#else #else
class UnixEventPort; class UnixEventPort;
class AutoCloseFd;
#endif #endif
class NetworkAddress; class NetworkAddress;
class AsyncOutputStream; class AsyncOutputStream;
class AsyncIoStream;
// ======================================================================================= // =======================================================================================
// Streaming I/O // Streaming I/O
...@@ -130,6 +132,42 @@ public: ...@@ -130,6 +132,42 @@ public:
// ephemeral addresses for a single connection. // ephemeral addresses for a single connection.
}; };
class AsyncCapabilityStream: public AsyncIoStream {
// An AsyncIoStream that also allows sending and receiving new connections or other kinds of
// capabilities, in addition to simple data.
//
// For correct functioning, a protocol must be designed such that the receiver knows when to
// expect a capability transfer. The receiver must not read() when a capability is expected, and
// must not receiveStream() when data is expected -- if it does, an exception may be thrown or
// invalid data may be returned. This implies that data sent over an AsyncCapabilityStream must
// be framed such that the receiver knows exactly how many bytes to read before receiving a
// capability.
//
// On Unix, KJ provides an implementation based on Unix domain sockets and file descriptor
// passing via SCM_RIGHTS. Due to the nature of SCM_RIGHTS, if the application accidentally
// read()s when it should have called receiveStream(), it will observe a NUL byte in the data
// and the capability will be discarded. Of course, an application should not depend on this
// behavior; it should avoid read()ing through a capability.
//
// KJ does not provide any implementation of this type on Windows, as there's no obvious
// implementation there. Handle passing on Windows requires at least one of the processes
// involved to have permission to modify the other's handle table, which is effectively full
// control. Handle passing between mutually non-trusting processes would require a trusted
// broker process to facilitate. One could possibly implement this type in terms of such a
// broker, or in terms of direct handle passing if at least one process trusts the other.
public:
Promise<Own<AsyncCapabilityStream>> receiveStream();
virtual Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream() = 0;
virtual Promise<void> sendStream(Own<AsyncCapabilityStream> stream) = 0;
// Transfer a stream.
//
// receiveStream() is a non-virtual wrapper around tryReceiveStream(): a null result from
// tryReceiveStream() is turned into a FAILED exception ("EOF when expecting to receive
// capability").
Promise<AutoCloseFd> receiveFd();
virtual Promise<Maybe<AutoCloseFd>> tryReceiveFd();
virtual Promise<void> sendFd(int fd);
// Transfer a raw file descriptor. Default implementation throws UNIMPLEMENTED.
//
// receiveFd() is a non-virtual wrapper around tryReceiveFd(): a null result from tryReceiveFd()
// is turned into a FAILED exception ("EOF when expecting to receive capability").
};
struct OneWayPipe { struct OneWayPipe {
// A data pipe with an input end and an output end. (Typically backed by pipe() system call.) // A data pipe with an input end and an output end. (Typically backed by pipe() system call.)
...@@ -144,6 +182,12 @@ struct TwoWayPipe { ...@@ -144,6 +182,12 @@ struct TwoWayPipe {
Own<AsyncIoStream> ends[2]; Own<AsyncIoStream> ends[2];
}; };
struct CapabilityPipe {
// Like TwoWayPipe but allowing capability-passing.
//
// This is the return type of AsyncIoProvider::newCapabilityPipe().
Own<AsyncCapabilityStream> ends[2];
// The two ends of the pipe, each exposed as an AsyncCapabilityStream.
};
class ConnectionReceiver { class ConnectionReceiver {
// Represents a server socket listening on a port. // Represents a server socket listening on a port.
...@@ -401,6 +445,13 @@ public: ...@@ -401,6 +445,13 @@ public:
// Creates two AsyncIoStreams representing the two ends of a two-way pipe (e.g. created with // Creates two AsyncIoStreams representing the two ends of a two-way pipe (e.g. created with
// socketpair(2) system call). Data written to one end can be read from the other. // socketpair(2) system call). Data written to one end can be read from the other.
virtual CapabilityPipe newCapabilityPipe();
// Creates two AsyncCapabilityStreams representing the two ends of a two-way capability pipe.
//
// The default implementation throws an unimplemented exception. In particular this is not
// implemented by the default AsyncIoProvider on Windows, since Windows lacks any sane way to
// pass handles over a stream.
virtual Network& getNetwork() = 0; virtual Network& getNetwork() = 0;
// Creates a new `Network` instance representing the networks exposed by the operating system. // Creates a new `Network` instance representing the networks exposed by the operating system.
// //
...@@ -524,6 +575,12 @@ public: ...@@ -524,6 +575,12 @@ public:
// //
// `flags` is a bitwise-OR of the values of the `Flags` enum. // `flags` is a bitwise-OR of the values of the `Flags` enum.
#if !_WIN32
virtual Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) = 0;
// Like wrapSocketFd() but also support capability passing via SCM_RIGHTS. The socket must be
// a Unix domain socket.
#endif
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd( virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
Fd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) = 0; Fd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket and initiate a connection to the given address. // Create an AsyncIoStream wrapping a socket and initiate a connection to the given address.
...@@ -609,6 +666,50 @@ AsyncIoContext setupAsyncIo(); ...@@ -609,6 +666,50 @@ AsyncIoContext setupAsyncIo();
// note that this means that server processes which daemonize themselves at startup must wait // note that this means that server processes which daemonize themselves at startup must wait
// until after daemonization to create an AsyncIoContext. // until after daemonization to create an AsyncIoContext.
// =======================================================================================
// Convenience adapters.
class CapabilityStreamConnectionReceiver: public ConnectionReceiver {
// Trivial wrapper which allows an AsyncCapabilityStream to act as a ConnectionReceiver. accept()
// calls receiveStream().
public:
CapabilityStreamConnectionReceiver(AsyncCapabilityStream& inner)
: inner(inner) {}
// `inner` is stored by reference, not owned, so the caller is responsible for keeping it alive
// while this receiver is in use.
Promise<Own<AsyncIoStream>> accept() override;
// Resolves to the next stream received over `inner`, returned as a plain AsyncIoStream.
uint getPort() override;
// Always returns 0.
private:
AsyncCapabilityStream& inner;
};
class CapabilityStreamNetworkAddress: public NetworkAddress {
// Trivial wrapper which allows an AsyncCapabilityStream to act as a NetworkAddress.
//
// connect() is implemented by calling provider.newCapabilityPipe(), sending one end over the
// original capability stream, and returning the other end.
//
// listen().accept() is implemented by receiving new streams over the original stream.
//
// Note that clone() doesn't work (due to ownership issues) and toString() returns a static
// string.
public:
CapabilityStreamNetworkAddress(AsyncIoProvider& provider, AsyncCapabilityStream& inner)
: provider(provider), inner(inner) {}
// Both `provider` and `inner` are stored by reference, not owned, so the caller is responsible
// for keeping them alive while this address is in use.
Promise<Own<AsyncIoStream>> connect() override;
Own<ConnectionReceiver> listen() override;
Own<NetworkAddress> clone() override;
// Not supported; the implementation raises an "unimplemented" failure.
String toString() override;
// Returns the fixed string "<CapabilityStreamNetworkAddress>".
private:
AsyncIoProvider& provider;
AsyncCapabilityStream& inner;
};
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment