Commit e9c46ac1 authored by Kenton Varda's avatar Kenton Varda

Implement FD passing in Cap'n Proto.

An endpoint (subclass of `Capability::Server`) may override `kj::Maybe<int> getFd()` to expose an underlying file descriptor.

A remote client may use `Capability::Client::getFd()` on the endpoint's capability to get that FD.

The client and server must explicitly opt into FD passing by passing a max-FDs-per-message limit to TwoPartyVatNetwork and using Unix sockets as the transport. Nothing other than that is needed.
parent 84c7ba81
......@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr))
Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
auto fd = hook->getFd();
if (fd != nullptr) {
return fd;
} else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) {
return promise->attach(hook->addRef()).then([](kj::Own<ClientHook> newHook) {
return Client(kj::mv(newHook)).getFd();
});
} else {
return kj::Maybe<int>(nullptr);
}
}
kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
......@@ -374,6 +387,14 @@ public:
return nullptr;
}
kj::Maybe<int> getFd() override {
KJ_IF_MAYBE(r, redirect) {
return r->get()->getFd();
} else {
return nullptr;
}
}
private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
......@@ -524,6 +545,10 @@ public:
}
}
kj::Maybe<int> getFd() override {
return server->getFd();
}
private:
kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr;
......@@ -616,6 +641,10 @@ public:
return brand;
}
kj::Maybe<int> getFd() override {
return nullptr;
}
private:
kj::Exception exception;
bool resolved;
......
......@@ -206,6 +206,19 @@ public:
// Make a request without knowing the types of the params or results. You specify the type ID
// and method number manually.
kj::Promise<kj::Maybe<int>> getFd();
// If the capability's server implemented Capability::Server::getFd() returning non-null, and all
// RPC links between the client and server support FD passing, returns a file descriptor pointing
// to the same undelying file description as the server did. Returns null if the server provided
// no FD or if FD passing was unavailable at some intervening link.
//
// This returns a Promise to handle the case of an unresolved promise capability, e.g. a
// pipelined capability. The promise resolves no later than when the capability settles, i.e.
// the same time `whenResolved()` would complete.
//
// The file descriptor will remain open at least as long as the Capability::Client remains alive.
// If you need it to last longer, you will need to `dup()` it.
// TODO(someday): method(s) for Join
protected:
......@@ -331,6 +344,11 @@ public:
// is no longer needed. `context` may be used to allocate the output struct and deal with
// cancellation.
virtual kj::Maybe<int> getFd() { return nullptr; }
// If this capability is backed by a file descriptor that is safe to directly expose to clients,
// returns that FD. When FD passing has been enabled in the RPC layer, this FD may be sent to
// other processes along with the capability.
// TODO(someday): Method which can optionally be overridden to implement Join when the object is
// a proxy.
......@@ -563,6 +581,10 @@ public:
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr.
virtual kj::Maybe<int> getFd() = 0;
// Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns
// non-null, then Capability::Client::getFd() waits for resolution and tries again.
static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); }
};
......
......@@ -469,6 +469,13 @@ public:
return MEMBRANE_BRAND;
}
kj::Maybe<int> getFd() override {
// We can't let FDs pass over membranes because we have no way to enforce the membrane policy
// on them. If the MembranePolicy wishes to explicitly permit certain FDs to pass, it can
// always do so by overriding the appropriate policy methods.
return nullptr;
}
private:
kj::Own<ClientHook> inner;
kj::Own<MembranePolicy> policy;
......
......@@ -27,6 +27,7 @@
#include <kj/debug.h>
#include <kj/thread.h>
#include <kj/compat/gtest.h>
#include <kj/miniposix.h>
// TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp {
......@@ -419,6 +420,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) {
EXPECT_TRUE(bootstrapFactory.called);
}
// =======================================================================================
#if !_WIN32
KJ_TEST("send FD over RPC") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 2);
TwoPartyClient client(*pipe.ends[1], 2);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "bar");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
KJ_TEST("FD per message limit") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 1);
TwoPartyClient client(*pipe.ends[1], 1);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(!response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
#endif // !_WIN32
} // namespace
} // namespace _
} // namespace capnp
......@@ -22,12 +22,13 @@
#include "rpc-twoparty.h"
#include "serialize-async.h"
#include <kj/debug.h>
#include <kj/io.h>
namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions)
: stream(stream), side(side), peerVatId(4),
: stream(&stream), maxFdsPerMesage(0), side(side), peerVatId(4),
receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
......@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
}
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions)
: TwoPartyVatNetwork(stream, side, receiveOptions) {
this->stream = &stream;
this->maxFdsPerMesage = maxFdsPerMesage;
}
void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
if (--refcount == 0) {
fulfiller->fulfill();
......@@ -81,6 +89,12 @@ public:
return message.getRoot<AnyPointer>();
}
void setFds(kj::Array<int> fds) override {
if (network.stream.is<kj::AsyncCapabilityStream*>()) {
this->fds = kj::mv(fds);
}
}
void send() override {
size_t size = 0;
for (auto& segment: message.getSegmentsForOutput()) {
......@@ -98,7 +112,15 @@ public:
// Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there.
return writeMessage(network.stream, message);
KJ_SWITCH_ONEOF(network.stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return writeMessage(*ioStream, message);
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
return writeMessage(*capStream, fds, message);
}
}
KJ_UNREACHABLE;
}).attach(kj::addRef(*this))
// Note that it's important that the eagerlyEvaluate() come *after* the attach() because
// otherwise the message (and any capabilities in it) will not be released until a new
......@@ -109,18 +131,32 @@ public:
private:
TwoPartyVatNetwork& network;
MallocMessageBuilder message;
kj::Array<int> fds;
};
class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
public:
IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
: message(kj::mv(init.reader)),
fdSpace(kj::mv(fdSpace)),
fds(init.fds) {
KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
}
AnyPointer::Reader getBody() override {
return message->getRoot<AnyPointer>();
}
kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
return fds;
}
private:
kj::Own<MessageReader> message;
kj::Array<kj::AutoCloseFd> fdSpace;
kj::ArrayPtr<kj::AutoCloseFd> fds;
};
rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
......@@ -132,22 +168,52 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg
}
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return kj::evalLater([&]() {
return tryReadMessage(stream, receiveOptions)
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
return nullptr;
return kj::evalLater([this]() {
KJ_SWITCH_ONEOF(stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return tryReadMessage(*ioStream, receiveOptions)
.then([](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
return nullptr;
}
});
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
auto fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMesage);
auto promise = tryReadMessage(*capStream, fdSpace, receiveOptions);
return promise.then([fdSpace = kj::mv(fdSpace)]
(kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, messageAndFds) {
if (m->fds.size() > 0) {
return kj::Own<IncomingRpcMessage>(
kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
} else {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
}
} else {
return nullptr;
}
});
}
});
}
KJ_UNREACHABLE;
});
}
kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite();
KJ_SWITCH_ONEOF(stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
ioStream->shutdownWrite();
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
capStream->shutdownWrite();
}
}
});
previousWrite = nullptr;
return kj::mv(result);
......@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection {
: connection(kj::mv(connectionParam)),
network(*connection, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
explicit AcceptedConnection(Capability::Client bootstrapInterface,
kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
uint maxFdsPerMesage)
: connection(kj::mv(connectionParam)),
network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
maxFdsPerMesage, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
};
void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
......@@ -178,6 +252,15 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
tasks.add(promise.attach(kj::mv(connectionState)));
}
void TwoPartyServer::accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage) {
auto connectionState = kj::heap<AcceptedConnection>(
bootstrapInterface, kj::mv(connection), maxFdsPerMesage);
// Run the connection until disconnect.
auto promise = connectionState->network.onDisconnect();
tasks.add(promise.attach(kj::mv(connectionState)));
}
kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
return listener.accept()
.then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
......@@ -186,6 +269,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
});
}
kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
kj::ConnectionReceiver& listener, uint maxFdsPerMesage) {
return listener.accept()
.then([this,&listener,maxFdsPerMesage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMesage);
return listenCapStreamReceiver(listener, maxFdsPerMesage);
});
}
void TwoPartyServer::taskFailed(kj::Exception&& exception) {
KJ_LOG(ERROR, exception);
}
......@@ -195,12 +287,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage)
: network(connection, maxFdsPerMesage, rpc::twoparty::Side::CLIENT),
rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side)
: network(connection, side),
rpcSystem(network, bootstrapInterface) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side)
: network(connection, maxFdsPerMesage, side),
rpcSystem(network, bootstrapInterface) {}
Capability::Client TwoPartyClient::bootstrap() {
MallocMessageBuilder message(4);
auto vatId = message.getRoot<rpc::twoparty::VatId>();
......
......@@ -29,6 +29,7 @@
#include "message.h"
#include <kj/async-io.h>
#include <capnp/rpc-twoparty.capnp.h>
#include <kj/one-of.h>
namespace capnp {
......@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase,
public:
TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions = ReaderOptions());
TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
// To support FD passing, pass an AsyncCapabilityStream and `maxFdsPerMesage`, which specifies
// the maximum number of file descriptors to accept from the peer in any one RPC message. It is
// important to keep maxFdsPerMesage low in order to stop DoS attacks that fill up your FD table.
//
// Note that this limit applies only to incoming messages; outgoing messages are allowed to have
// more FDs. Sometimes it makes sense to enforce a limit of zero in one direction while having
// a non-zero limit in the other. For example, in a supervisor/sandbox scenario, typically there
// are many use cases for passing FDs from supervisor to sandbox but no use case for vice versa.
// The supervisor may be configured not to accept any FDs from the sandbox in order to reduce
// risk of DoS attacks.
KJ_DISALLOW_COPY(TwoPartyVatNetwork);
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
......@@ -70,7 +84,8 @@ private:
class OutgoingMessageImpl;
class IncomingMessageImpl;
kj::AsyncIoStream& stream;
kj::OneOf<kj::AsyncIoStream*, kj::AsyncCapabilityStream*> stream;
uint maxFdsPerMesage;
rpc::twoparty::Side side;
MallocMessageBuilder peerVatId;
ReaderOptions receiveOptions;
......@@ -120,6 +135,7 @@ public:
explicit TwoPartyServer(Capability::Client bootstrapInterface);
void accept(kj::Own<kj::AsyncIoStream>&& connection);
void accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage);
// Accepts the connection for servicing.
kj::Promise<void> listen(kj::ConnectionReceiver& listener);
......@@ -127,6 +143,10 @@ public:
// exception is thrown while trying to accept. You may discard the returned promise to cancel
// listening.
kj::Promise<void> listenCapStreamReceiver(kj::ConnectionReceiver& listener, uint maxFdsPerMesage);
// Listen with support for FD transfers. `listener.accept()` must return instances of
// AsyncCapabilityStream, otherwise this will crash.
private:
Capability::Client bootstrapInterface;
kj::TaskSet tasks;
......@@ -141,8 +161,12 @@ class TwoPartyClient {
public:
explicit TwoPartyClient(kj::AsyncIoStream& connection);
explicit TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage);
TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
Capability::Client bootstrap();
// Get the server's bootstrap interface.
......
This diff is collapsed.
......@@ -28,6 +28,8 @@
#include "capability.h"
#include "rpc-prelude.h"
namespace kj { class AutoCloseFd; }
namespace capnp {
template <typename VatId, typename ProvisionId, typename RecipientId,
......@@ -305,6 +307,10 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.)
virtual void setFds(kj::Array<int> fds) {}
// Set the list of file descriptors to send along with this message, if FD passing is supported.
// An implementation may ignore this.
virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
......@@ -317,6 +323,14 @@ public:
virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() { return nullptr; }
// If the transport supports attached file descriptors and some were attached to this message,
// returns them. Otherwise returns an empty array. It is intended that the caller will move the
// FDs out of this table when they are consumed, possibly leaving behind a null slot. Callers
// should be careful to check if an FD was already consumed by comparing the slot with `nullptr`.
// (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only
// add confusion. Moving from an AutoCloseFd does in fact make it null.)
};
template <typename VatId, typename ProvisionId, typename RecipientId,
......
......@@ -21,6 +21,7 @@
#include "serialize-async.h"
#include <kj/debug.h>
#include <kj/io.h>
namespace capnp {
......@@ -35,6 +36,10 @@ public:
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<kj::Maybe<size_t>> readWithFds(
kj::AsyncCapabilityStream& inputStream,
kj::ArrayPtr<kj::AutoCloseFd> fds, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ----------------------------------------
kj::ArrayPtr<const word> getSegment(uint id) override {
......@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
});
}
kj::Promise<kj::Maybe<size_t>> AsyncMessageReader::readWithFds(
kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr<kj::AutoCloseFd> fds,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord),
fds.begin(), fds.size())
.then([this,&inputStream,KJ_CPCAP(scratchSpace)]
(kj::AsyncCapabilityStream::ReadResult result) mutable
-> kj::Promise<kj::Maybe<size_t>> {
if (result.byteCount == 0) {
return kj::Maybe<size_t>(nullptr);
} else if (result.byteCount < sizeof(firstWord)) {
// EOF in first word.
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return kj::Maybe<size_t>(nullptr);
}
return readAfterFirstWord(inputStream, scratchSpace)
.then([result]() -> kj::Maybe<size_t> { return result.capCount; });
});
}
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) {
......@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) {
return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own<MessageReader> {
if (!success) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
}
return kj::mv(reader);
}));
});
}
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader,
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> {
return promise.then([reader = kj::mv(reader)](bool success) mutable
-> kj::Maybe<kj::Own<MessageReader>> {
if (success) {
return kj::mv(reader);
} else {
return nullptr;
}
}));
});
}
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> MessageReaderAndFds {
KJ_IF_MAYBE(n, nfds) {
return { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return { kj::mv(reader), nullptr };
}
});
}
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> kj::Maybe<MessageReaderAndFds> {
KJ_IF_MAYBE(n, nfds) {
return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
return nullptr;
}
});
}
// =======================================================================================
......@@ -184,10 +241,9 @@ struct WriteArrays {
kj::Array<kj::ArrayPtr<const byte>> pieces;
};
} // namespace
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
template <typename WriteFunc>
kj::Promise<void> writeMessageImpl(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments,
WriteFunc&& writeFunc) {
KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
WriteArrays arrays;
......@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
arrays.pieces[i + 1] = segments[i].asBytes();
}
auto promise = output.write(arrays.pieces);
auto promise = writeFunc(arrays.pieces);
// Make sure the arrays aren't freed until the write completes.
return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {}));
}
} // namespace
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.write(pieces);
});
}
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), fds);
});
}
} // namespace capnp
......@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu
KJ_WARN_UNUSED_RESULT;
// Write asynchronously. The parameters must remain valid until the returned promise resolves.
// -----------------------------------------------------------------------------
// Versions that support FD passing.
struct MessageReaderAndFds {
kj::Own<MessageReader> reader;
kj::ArrayPtr<kj::AutoCloseFd> fds;
};
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Read a message that may also have file descriptors attached, e.g. from a Unix socket with
// SCM_RIGHTS.
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
MessageBuilder& builder)
KJ_WARN_UNUSED_RESULT;
// Write a message with FDs attached, e.g. to a Unix socket with SCM_RIGHTS.
// =======================================================================================
// inline implementation details
inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) {
return writeMessage(output, builder.getSegmentsForOutput());
}
inline kj::Promise<void> writeMessage(
kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds, MessageBuilder& builder) {
return writeMessage(output, fds, builder.getSegmentsForOutput());
}
} // namespace capnp
......@@ -22,6 +22,8 @@
#include "test-util.h"
#include <kj/debug.h>
#include <kj/compat/gtest.h>
#include <kj/io.h>
#include <kj/miniposix.h>
namespace capnp {
namespace _ { // private
......@@ -1144,6 +1146,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::writeToFd(WriteToFdContext context) {
auto params = context.getParams();
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(params.getFdCap1().getFd()
.then([](kj::Maybe<int> fd) {
kj::FdOutputStream(KJ_ASSERT_NONNULL(fd)).write("foo", 3);
}));
promises.add(params.getFdCap2().getFd()
.then([context](kj::Maybe<int> fd) mutable {
context.getResults().setSecondFdPresent(fd != nullptr);
KJ_IF_MAYBE(f, fd) {
kj::FdOutputStream(*f).write("bar", 3);
}
}));
int pair[2];
KJ_SYSCALL(kj::miniposix::pipe(pair));
kj::AutoCloseFd in(pair[0]);
kj::AutoCloseFd out(pair[1]);
kj::FdOutputStream(kj::mv(out)).write("baz", 3);
context.getResults().setFdCap3(kj::heap<TestFdCap>(kj::mv(in)));
return kj::joinPromises(promises.finish());
}
#endif // !CAPNP_LITE
} // namespace _ (private)
......
......@@ -32,6 +32,7 @@
#if !CAPNP_LITE
#include "dynamic.h"
#include <kj/io.h>
#endif // !CAPNP_LITE
// TODO(cleanup): Auto-generate stringification functions for union discriminants.
......@@ -274,6 +275,8 @@ public:
kj::Promise<void> getEnormousString(GetEnormousStringContext context) override;
kj::Promise<void> writeToFd(WriteToFdContext context) override;
private:
int& callCount;
int& handleCount;
......@@ -303,6 +306,18 @@ private:
TestInterfaceImpl impl;
};
class TestFdCap final: public test::TestInterface::Server {
// Implementation of TestInterface that wraps a file descriptor.
public:
TestFdCap(kj::AutoCloseFd fd): fd(kj::mv(fd)) {}
kj::Maybe<int> getFd() override { return fd.get(); }
private:
kj::AutoCloseFd fd;
};
#endif // !CAPNP_LITE
} // namespace _ (private)
......
......@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) {
getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
}
interface TestMembrane {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment