Commit 09aa4f79 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #558 from capnproto/capability-stream

Add AsyncCapabilityStream, an abstraction of unix sockets SCM_RIGHTS FD passing.
parents 0c482f58 dbb2e985
......@@ -192,6 +192,47 @@ TEST(AsyncIo, TwoWayPipe) {
EXPECT_EQ("bar", result2);
}
#if !_WIN32
TEST(AsyncIo, CapabilityPipe) {
auto ioContext = setupAsyncIo();
auto pipe = ioContext.provider->newCapabilityPipe();
auto pipe2 = ioContext.provider->newCapabilityPipe();
char receiveBuffer1[4];
char receiveBuffer2[4];
// Expect to receive a stream, then write "bar" to it, then receive "foo" from it.
Own<AsyncCapabilityStream> receivedStream;
auto promise = pipe2.ends[1]->receiveStream()
.then([&](Own<AsyncCapabilityStream> stream) {
receivedStream = kj::mv(stream);
return receivedStream->write("bar", 3);
}).then([&]() {
return receivedStream->tryRead(receiveBuffer2, 3, 4);
}).then([&](size_t n) {
EXPECT_EQ(3u, n);
return heapString(receiveBuffer2, n);
});
// Send a stream, then write "foo" to the other end of the sent stream, then receive "bar"
// from it.
kj::String result = pipe2.ends[0]->sendStream(kj::mv(pipe.ends[1]))
.then([&]() {
return pipe.ends[0]->write("foo", 3);
}).then([&]() {
return pipe.ends[0]->tryRead(receiveBuffer1, 3, 4);
}).then([&](size_t n) {
EXPECT_EQ(3u, n);
return heapString(receiveBuffer1, n);
}).wait(ioContext.waitScope);
kj::String result2 = promise.wait(ioContext.waitScope);
EXPECT_EQ("bar", result);
EXPECT_EQ("foo", result2);
}
#endif
TEST(AsyncIo, PipeThread) {
auto ioContext = setupAsyncIo();
......
......@@ -45,25 +45,35 @@
#include <set>
#include <poll.h>
#include <limits.h>
#include <sys/ioctl.h>
namespace kj {
namespace {
void setNonblocking(int fd) {
#ifdef FIONBIO
int opt = 1;
KJ_SYSCALL(ioctl(fd, FIONBIO, &opt));
#else
int flags;
KJ_SYSCALL(flags = fcntl(fd, F_GETFL));
if ((flags & O_NONBLOCK) == 0) {
KJ_SYSCALL(fcntl(fd, F_SETFL, flags | O_NONBLOCK));
}
#endif
}
void setCloseOnExec(int fd) {
#ifdef FIOCLEX
KJ_SYSCALL(ioctl(fd, FIOCLEX));
#else
int flags;
KJ_SYSCALL(flags = fcntl(fd, F_GETFD));
if ((flags & FD_CLOEXEC) == 0) {
KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC));
}
#endif
}
static constexpr uint NEW_FD_FLAGS =
......@@ -112,10 +122,11 @@ private:
// =======================================================================================
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncIoStream {
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream {
public:
AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags),
eventPort(eventPort),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {}
virtual ~AsyncStreamFd() noexcept(false) {}
......@@ -198,6 +209,57 @@ public:
*length = socklen;
}
kj::Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream() override {
return tryReceiveFdImpl<Own<AsyncCapabilityStream>>();
}
kj::Promise<void> sendStream(Own<AsyncCapabilityStream> stream) override {
auto downcasted = stream.downcast<AsyncStreamFd>();
auto promise = sendFd(downcasted->fd);
return promise.attach(kj::mv(downcasted));
}
kj::Promise<kj::Maybe<AutoCloseFd>> tryReceiveFd() override {
return tryReceiveFdImpl<AutoCloseFd>();
}
kj::Promise<void> sendFd(int fdToSend) override {
struct msghdr msg;
struct iovec iov;
union {
struct cmsghdr cmsg;
char cmsgSpace[CMSG_LEN(sizeof(int))];
};
memset(&msg, 0, sizeof(msg));
memset(&iov, 0, sizeof(iov));
memset(cmsgSpace, 0, sizeof(cmsgSpace));
char c = 0;
iov.iov_base = &c;
iov.iov_len = 1;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = &cmsg;
msg.msg_controllen = sizeof(cmsgSpace);
cmsg.cmsg_len = sizeof(cmsgSpace);
cmsg.cmsg_level = SOL_SOCKET;
cmsg.cmsg_type = SCM_RIGHTS;
*reinterpret_cast<int*>(CMSG_DATA(&cmsg)) = fdToSend;
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = sendmsg(fd, &msg, 0));
if (n < 0) {
return observer.whenBecomesWritable().then([this,fdToSend]() {
return sendFd(fdToSend);
});
} else {
KJ_ASSERT(n == 1);
return kj::READY_NOW;
}
}
Promise<void> waitConnected() {
// Wait until initial connection has completed. This actually just waits until it is writable.
......@@ -222,6 +284,7 @@ public:
}
private:
UnixEventPort& eventPort;
UnixEventPort::FdObserver observer;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
......@@ -354,6 +417,71 @@ private:
}
}
}
template <typename T>
kj::Promise<kj::Maybe<T>> tryReceiveFdImpl() {
struct msghdr msg;
memset(&msg, 0, sizeof(msg));
struct iovec iov;
memset(&iov, 0, sizeof(iov));
char c;
iov.iov_base = &c;
iov.iov_len = 1;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
// Allocate space to receive a cmsg.
union {
struct cmsghdr cmsg;
char cmsgSpace[CMSG_SPACE(sizeof(int))];
};
msg.msg_control = &cmsg;
msg.msg_controllen = sizeof(cmsgSpace);
#ifdef MSG_CMSG_CLOEXEC
int recvmsgFlags = MSG_CMSG_CLOEXEC;
#else
int recvmsgFlags = 0;
#endif
ssize_t n;
KJ_NONBLOCKING_SYSCALL(n = recvmsg(fd, &msg, recvmsgFlags));
if (n < 0) {
return observer.whenBecomesReadable().then([this]() {
return tryReceiveFdImpl<T>();
});
} else if (n == 0) {
return kj::Maybe<T>(nullptr);
} else {
KJ_REQUIRE(msg.msg_controllen >= sizeof(cmsg),
"expected to receive FD over socket; received data instead");
// We expect an SCM_RIGHTS message with a single FD.
KJ_REQUIRE(cmsg.cmsg_level == SOL_SOCKET);
KJ_REQUIRE(cmsg.cmsg_type == SCM_RIGHTS);
KJ_REQUIRE(cmsg.cmsg_len == CMSG_LEN(sizeof(int)));
int receivedFd;
memcpy(&receivedFd, CMSG_DATA(&cmsg), sizeof(receivedFd));
return kj::Maybe<T>(wrapFd(receivedFd, (T*)nullptr));
}
}
AutoCloseFd wrapFd(int newFd, AutoCloseFd*) {
auto result = AutoCloseFd(newFd);
#ifndef MSG_CMSG_CLOEXEC
setCloseOnExec(result);
#endif
return result;
}
Own<AsyncCapabilityStream> wrapFd(int newFd, Own<AsyncCapabilityStream>*) {
return kj::heap<AsyncStreamFd>(eventPort, newFd,
#ifdef MSG_CMSG_CLOEXEC
LowLevelAsyncIoProvider::ALREADY_CLOEXEC |
#endif
LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
}
};
// =======================================================================================
......@@ -951,6 +1079,9 @@ public:
Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
// Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates
......@@ -1375,6 +1506,19 @@ public:
} };
}
CapabilityPipe newCapabilityPipe() override {
int fds[2];
int type = SOCK_STREAM;
#if __linux__ && !__BIONIC__
type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif
KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
return CapabilityPipe { {
lowLevel.wrapUnixSocketFd(fds[0], NEW_FD_FLAGS),
lowLevel.wrapUnixSocketFd(fds[1], NEW_FD_FLAGS)
} };
}
Network& getNetwork() override {
return network;
}
......
......@@ -29,6 +29,7 @@
#include "async-io-internal.h"
#include "debug.h"
#include "vector.h"
#include "io.h"
#if _WIN32
#include <winsock2.h>
......@@ -181,6 +182,34 @@ Maybe<Promise<uint64_t>> AsyncOutputStream::tryPumpFrom(
return nullptr;
}
Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() {
return tryReceiveStream()
.then([](Maybe<Own<AsyncCapabilityStream>>&& result)
-> Promise<Own<AsyncCapabilityStream>> {
KJ_IF_MAYBE(r, result) {
return kj::mv(*r);
} else {
return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability");
}
});
}
Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() {
return tryReceiveFd().then([](Maybe<AutoCloseFd>&& result) -> Promise<AutoCloseFd> {
KJ_IF_MAYBE(r, result) {
return kj::mv(*r);
} else {
return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability");
}
});
}
Promise<Maybe<AutoCloseFd>> AsyncCapabilityStream::tryReceiveFd() {
return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot receive file descriptors");
}
Promise<void> AsyncCapabilityStream::sendFd(int fd) {
return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot send file descriptors");
}
void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
......@@ -212,6 +241,43 @@ Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
CapabilityPipe AsyncIoProvider::newCapabilityPipe() {
KJ_UNIMPLEMENTED("Capability pipes not implemented.");
}
Own<AsyncInputStream> LowLevelAsyncIoProvider::wrapInputFd(OwnFd fd, uint flags) {
return wrapInputFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
Own<AsyncOutputStream> LowLevelAsyncIoProvider::wrapOutputFd(OwnFd fd, uint flags) {
return wrapOutputFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
Own<AsyncIoStream> LowLevelAsyncIoProvider::wrapSocketFd(OwnFd fd, uint flags) {
return wrapSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
#if !_WIN32
Own<AsyncCapabilityStream> LowLevelAsyncIoProvider::wrapUnixSocketFd(OwnFd fd, uint flags) {
return wrapUnixSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
#endif
Promise<Own<AsyncIoStream>> LowLevelAsyncIoProvider::wrapConnectingSocketFd(
OwnFd fd, const struct sockaddr* addr, uint addrlen, uint flags) {
return wrapConnectingSocketFd(reinterpret_cast<Fd>(fd.release()), addr, addrlen,
flags | TAKE_OWNERSHIP);
}
Own<ConnectionReceiver> LowLevelAsyncIoProvider::wrapListenSocketFd(
OwnFd fd, NetworkFilter& filter, uint flags) {
return wrapListenSocketFd(reinterpret_cast<Fd>(fd.release()), filter, flags | TAKE_OWNERSHIP);
}
Own<ConnectionReceiver> LowLevelAsyncIoProvider::wrapListenSocketFd(OwnFd fd, uint flags) {
return wrapListenSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
OwnFd fd, NetworkFilter& filter, uint flags) {
return wrapDatagramSocketFd(reinterpret_cast<Fd>(fd.release()), filter, flags | TAKE_OWNERSHIP);
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(OwnFd fd, uint flags) {
return wrapDatagramSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP);
}
namespace {
......@@ -227,6 +293,39 @@ LowLevelAsyncIoProvider::NetworkFilter& LowLevelAsyncIoProvider::NetworkFilter::
return result;
}
// =======================================================================================
// Convenience adapters.
Promise<Own<AsyncIoStream>> CapabilityStreamConnectionReceiver::accept() {
return inner.receiveStream()
.then([](Own<AsyncCapabilityStream>&& stream) -> Own<AsyncIoStream> {
return kj::mv(stream);
});
}
uint CapabilityStreamConnectionReceiver::getPort() {
return 0;
}
Promise<Own<AsyncIoStream>> CapabilityStreamNetworkAddress::connect() {
auto pipe = provider.newCapabilityPipe();
auto result = kj::mv(pipe.ends[0]);
return inner.sendStream(kj::mv(pipe.ends[1]))
.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
return kj::mv(result);
}));
}
Own<ConnectionReceiver> CapabilityStreamNetworkAddress::listen() {
return kj::heap<CapabilityStreamConnectionReceiver>(inner);
}
Own<NetworkAddress> CapabilityStreamNetworkAddress::clone() {
KJ_UNIMPLEMENTED("can't clone CapabilityStreamNetworkAddress");
}
String CapabilityStreamNetworkAddress::toString() {
return kj::str("<CapabilityStreamNetworkAddress>");
}
// =======================================================================================
namespace _ { // private
......
......@@ -37,12 +37,15 @@ namespace kj {
#if _WIN32
class Win32EventPort;
class AutoCloseHandle;
#else
class UnixEventPort;
#endif
class AutoCloseFd;
class NetworkAddress;
class AsyncOutputStream;
class AsyncIoStream;
// =======================================================================================
// Streaming I/O
......@@ -130,6 +133,42 @@ public:
// ephemeral addresses for a single connection.
};
class AsyncCapabilityStream: public AsyncIoStream {
// An AsyncIoStream that also allows sending and receiving new connections or other kinds of
// capabilities, in addition to simple data.
//
// For correct functioning, a protocol must be designed such that the receiver knows when to
// expect a capability transfer. The receiver must not read() when a capability is expected, and
// must not receiveStream() when data is expected -- if it does, an exception may be thrown or
// invalid data may be returned. This implies that data sent over an AsyncCapabilityStream must
// be framed such that the receiver knows exactly how many bytes to read before receiving a
// capability.
//
// On Unix, KJ provides an implementation based on Unix domain sockets and file descriptor
// passing via SCM_RIGHTS. Due to the nature of SCM_RIGHTS, if the application accidentally
// read()s when it should have called receiveStream(), it will observe a NUL byte in the data
// and the capability will be discarded. Of course, an application should not depend on this
// behavior; it should avoid read()ing through a capability.
//
// KJ does not provide any implementation of this type on Windows, as there's no obvious
// implementation there. Handle passing on Windows requires at least one of the processes
// involved to have permission to modify the other's handle table, which is effectively full
// control. Handle passing between mutually non-trusting processes would require a trusted
// broker process to facilitate. One could possibly implement this type in terms of such a
// broker, or in terms of direct handle passing if at least one process trusts the other.
public:
Promise<Own<AsyncCapabilityStream>> receiveStream();
virtual Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream() = 0;
virtual Promise<void> sendStream(Own<AsyncCapabilityStream> stream) = 0;
// Transfer a stream.
Promise<AutoCloseFd> receiveFd();
virtual Promise<Maybe<AutoCloseFd>> tryReceiveFd();
virtual Promise<void> sendFd(int fd);
// Transfer a raw file descriptor. Default implementation throws UNIMPLEMENTED.
};
struct OneWayPipe {
// A data pipe with an input end and an output end. (Typically backed by pipe() system call.)
......@@ -144,6 +183,12 @@ struct TwoWayPipe {
Own<AsyncIoStream> ends[2];
};
struct CapabilityPipe {
// Like TwoWayPipe but allowing capability-passing.
Own<AsyncCapabilityStream> ends[2];
};
class ConnectionReceiver {
// Represents a server socket listening on a port.
......@@ -401,6 +446,13 @@ public:
// Creates two AsyncIoStreams representing the two ends of a two-way pipe (e.g. created with
// socketpair(2) system call). Data written to one end can be read from the other.
virtual CapabilityPipe newCapabilityPipe();
// Creates two AsyncCapabilityStreams representing the two ends of a two-way capability pipe.
//
// The default implementation throws an unimplemented exception. In particular this is not
// implemented by the default AsyncIoProvider on Windows, since Windows lacks any sane way to
// pass handles over a stream.
virtual Network& getNetwork() = 0;
// Creates a new `Network` instance representing the networks exposed by the operating system.
//
......@@ -461,16 +513,11 @@ class LowLevelAsyncIoProvider {
// Different implementations of this interface might work on top of different event handling
// primitives, such as poll vs. epoll vs. kqueue vs. some higher-level event library.
//
// On Windows, this interface can be used to import native HANDLEs into the async framework.
// On Windows, this interface can be used to import native SOCKETs into the async framework.
// Different implementations of this interface might work on top of different event handling
// primitives, such as I/O completion ports vs. completion routines.
//
// TODO(port): Actually implement Windows support.
public:
// ---------------------------------------------------------------------------
// Unix-specific stuff
enum Flags {
// Flags controlling how to wrap a file descriptor.
......@@ -501,11 +548,13 @@ public:
#if _WIN32
typedef uintptr_t Fd;
typedef AutoCloseHandle OwnFd;
// On Windows, the `fd` parameter to each of these methods must be a SOCKET, and must have the
// flag WSA_FLAG_OVERLAPPED (which socket() uses by default, but WSASocket() wants you to specify
// explicitly).
#else
typedef int Fd;
typedef AutoCloseFd OwnFd;
// On Unix, any arbitrary file descriptor is supported.
#endif
......@@ -524,6 +573,12 @@ public:
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
#if !_WIN32
virtual Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) = 0;
// Like wrapSocketFd() but also support capability passing via SCM_RIGHTS. The socket must be
// a Unix domain socket.
#endif
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
Fd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket and initiate a connection to the given address.
......@@ -562,6 +617,22 @@ public:
//
// This timer is not affected by changes to the system date. It is unspecified whether the timer
// continues to count while the system is suspended.
Own<AsyncInputStream> wrapInputFd(OwnFd fd, uint flags = 0);
Own<AsyncOutputStream> wrapOutputFd(OwnFd fd, uint flags = 0);
Own<AsyncIoStream> wrapSocketFd(OwnFd fd, uint flags = 0);
#if !_WIN32
Own<AsyncCapabilityStream> wrapUnixSocketFd(OwnFd fd, uint flags = 0);
#endif
Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
OwnFd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0);
Own<ConnectionReceiver> wrapListenSocketFd(
OwnFd fd, NetworkFilter& filter, uint flags = 0);
Own<ConnectionReceiver> wrapListenSocketFd(OwnFd fd, uint flags = 0);
Own<DatagramPort> wrapDatagramSocketFd(OwnFd fd, NetworkFilter& filter, uint flags = 0);
Own<DatagramPort> wrapDatagramSocketFd(OwnFd fd, uint flags = 0);
// Convenience wrappers which transfer ownership via AutoCloseFd (Unix) or AutoCloseHandle
// (Windows). TAKE_OWNERSHIP will be implicitly added to `flags`.
};
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel);
......@@ -609,6 +680,50 @@ AsyncIoContext setupAsyncIo();
// note that this means that server processes which daemonize themselves at startup must wait
// until after daemonization to create an AsyncIoContext.
// =======================================================================================
// Convenience adapters.
class CapabilityStreamConnectionReceiver final: public ConnectionReceiver {
// Trivial wrapper which allows an AsyncCapabilityStream to act as a ConnectionReceiver. accept()
// calls receiveStream().
public:
CapabilityStreamConnectionReceiver(AsyncCapabilityStream& inner)
: inner(inner) {}
Promise<Own<AsyncIoStream>> accept() override;
uint getPort() override;
private:
AsyncCapabilityStream& inner;
};
class CapabilityStreamNetworkAddress final: public NetworkAddress {
// Trivial wrapper which allows an AsyncCapabilityStream to act as a NetworkAddress.
//
// connect() is implemented by calling provider.newCapabilityPipe(), sending one end over the
// original capability stream, and returning the other end.
//
// listen().accept() is implemented by receiving new streams over the original stream.
//
// Note that clone() dosen't work (due to ownership issues) and toString() returns a static
// string.
public:
CapabilityStreamNetworkAddress(AsyncIoProvider& provider, AsyncCapabilityStream& inner)
: provider(provider), inner(inner) {}
Promise<Own<AsyncIoStream>> connect() override;
Own<ConnectionReceiver> listen() override;
Own<NetworkAddress> clone() override;
String toString() override;
private:
AsyncIoProvider& provider;
AsyncCapabilityStream& inner;
};
// =======================================================================================
// inline implementation details
......
......@@ -283,6 +283,13 @@ public:
inline bool operator==(decltype(nullptr)) { return fd < 0; }
inline bool operator!=(decltype(nullptr)) { return fd >= 0; }
inline int release() {
// Release ownership of an FD. Not recommended.
int result = fd;
fd = -1;
return result;
}
private:
int fd;
UnwindDetector unwindDetector;
......@@ -376,6 +383,13 @@ public:
inline bool operator==(decltype(nullptr)) { return handle != (void*)-1; }
inline bool operator!=(decltype(nullptr)) { return handle == (void*)-1; }
inline void* release() {
// Release ownership of an FD. Not recommended.
void* result = handle;
handle = (void*)-1;
return result;
}
private:
void* handle; // -1 (aka INVALID_HANDLE_VALUE) if not valid.
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment