Commit e9c46ac1 authored by Kenton Varda's avatar Kenton Varda

Implement FD passing in Cap'n Proto.

An endpoint (subclass of `Capability::Server`) may override `kj::Maybe<int> getFd()` to expose an underlying file descriptor.

A remote client may use `Capability::Client::getFd()` on the endpoint's capability to get that FD.

The client and server must explicitly opt into FD passing by passing a max-FDs-per-message limit to TwoPartyVatNetwork and using Unix sockets as the transport. Nothing other than that is needed.
parent 84c7ba81
...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr)) ...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr))
Capability::Client::Client(kj::Exception&& exception) Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {} : hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
auto fd = hook->getFd();
if (fd != nullptr) {
return fd;
} else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) {
return promise->attach(hook->addRef()).then([](kj::Own<ClientHook> newHook) {
return Client(kj::mv(newHook)).getFd();
});
} else {
return kj::Maybe<int>(nullptr);
}
}
kj::Promise<void> Capability::Server::internalUnimplemented( kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) { const char* actualInterfaceName, uint64_t requestedTypeId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
...@@ -374,6 +387,14 @@ public: ...@@ -374,6 +387,14 @@ public:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
KJ_IF_MAYBE(r, redirect) {
return r->get()->getFd();
} else {
return nullptr;
}
}
private: private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork; typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
...@@ -524,6 +545,10 @@ public: ...@@ -524,6 +545,10 @@ public:
} }
} }
kj::Maybe<int> getFd() override {
return server->getFd();
}
private: private:
kj::Own<Capability::Server> server; kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr; _::CapabilityServerSetBase* capServerSet = nullptr;
...@@ -616,6 +641,10 @@ public: ...@@ -616,6 +641,10 @@ public:
return brand; return brand;
} }
kj::Maybe<int> getFd() override {
return nullptr;
}
private: private:
kj::Exception exception; kj::Exception exception;
bool resolved; bool resolved;
......
...@@ -206,6 +206,19 @@ public: ...@@ -206,6 +206,19 @@ public:
// Make a request without knowing the types of the params or results. You specify the type ID // Make a request without knowing the types of the params or results. You specify the type ID
// and method number manually. // and method number manually.
kj::Promise<kj::Maybe<int>> getFd();
// If the capability's server implemented Capability::Server::getFd() returning non-null, and all
// RPC links between the client and server support FD passing, returns a file descriptor pointing
// to the same undelying file description as the server did. Returns null if the server provided
// no FD or if FD passing was unavailable at some intervening link.
//
// This returns a Promise to handle the case of an unresolved promise capability, e.g. a
// pipelined capability. The promise resolves no later than when the capability settles, i.e.
// the same time `whenResolved()` would complete.
//
// The file descriptor will remain open at least as long as the Capability::Client remains alive.
// If you need it to last longer, you will need to `dup()` it.
// TODO(someday): method(s) for Join // TODO(someday): method(s) for Join
protected: protected:
...@@ -331,6 +344,11 @@ public: ...@@ -331,6 +344,11 @@ public:
// is no longer needed. `context` may be used to allocate the output struct and deal with // is no longer needed. `context` may be used to allocate the output struct and deal with
// cancellation. // cancellation.
virtual kj::Maybe<int> getFd() { return nullptr; }
// If this capability is backed by a file descriptor that is safe to directly expose to clients,
// returns that FD. When FD passing has been enabled in the RPC layer, this FD may be sent to
// other processes along with the capability.
// TODO(someday): Method which can optionally be overridden to implement Join when the object is // TODO(someday): Method which can optionally be overridden to implement Join when the object is
// a proxy. // a proxy.
...@@ -563,6 +581,10 @@ public: ...@@ -563,6 +581,10 @@ public:
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr. // use) always returns nullptr.
virtual kj::Maybe<int> getFd() = 0;
// Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns
// non-null, then Capability::Client::getFd() waits for resolution and tries again.
static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); } static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); }
}; };
......
...@@ -469,6 +469,13 @@ public: ...@@ -469,6 +469,13 @@ public:
return MEMBRANE_BRAND; return MEMBRANE_BRAND;
} }
kj::Maybe<int> getFd() override {
// We can't let FDs pass over membranes because we have no way to enforce the membrane policy
// on them. If the MembranePolicy wishes to explicitly permit certain FDs to pass, it can
// always do so by overriding the appropriate policy methods.
return nullptr;
}
private: private:
kj::Own<ClientHook> inner; kj::Own<ClientHook> inner;
kj::Own<MembranePolicy> policy; kj::Own<MembranePolicy> policy;
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/thread.h> #include <kj/thread.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/miniposix.h>
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp { namespace capnp {
...@@ -419,6 +420,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) { ...@@ -419,6 +420,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) {
EXPECT_TRUE(bootstrapFactory.called); EXPECT_TRUE(bootstrapFactory.called);
} }
// =======================================================================================
#if !_WIN32
KJ_TEST("send FD over RPC") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 2);
TwoPartyClient client(*pipe.ends[1], 2);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "bar");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
KJ_TEST("FD per message limit") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 1);
TwoPartyClient client(*pipe.ends[1], 1);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(!response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
#endif // !_WIN32
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -22,12 +22,13 @@ ...@@ -22,12 +22,13 @@
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: stream(stream), side(side), peerVatId(4), : stream(&stream), maxFdsPerMesage(0), side(side), peerVatId(4),
receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) { receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
peerVatId.initRoot<rpc::twoparty::VatId>().setSide( peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty: ...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller); disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
} }
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions)
: TwoPartyVatNetwork(stream, side, receiveOptions) {
this->stream = &stream;
this->maxFdsPerMesage = maxFdsPerMesage;
}
void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
if (--refcount == 0) { if (--refcount == 0) {
fulfiller->fulfill(); fulfiller->fulfill();
...@@ -81,6 +89,12 @@ public: ...@@ -81,6 +89,12 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
void setFds(kj::Array<int> fds) override {
if (network.stream.is<kj::AsyncCapabilityStream*>()) {
this->fds = kj::mv(fds);
}
}
void send() override { void send() override {
size_t size = 0; size_t size = 0;
for (auto& segment: message.getSegmentsForOutput()) { for (auto& segment: message.getSegmentsForOutput()) {
...@@ -98,7 +112,15 @@ public: ...@@ -98,7 +112,15 @@ public:
// Note that if the write fails, all further writes will be skipped due to the exception. // Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well // We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there. // and it's cleaner to handle the failure there.
return writeMessage(network.stream, message); KJ_SWITCH_ONEOF(network.stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return writeMessage(*ioStream, message);
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
return writeMessage(*capStream, fds, message);
}
}
KJ_UNREACHABLE;
}).attach(kj::addRef(*this)) }).attach(kj::addRef(*this))
// Note that it's important that the eagerlyEvaluate() come *after* the attach() because // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
// otherwise the message (and any capabilities in it) will not be released until a new // otherwise the message (and any capabilities in it) will not be released until a new
...@@ -109,18 +131,32 @@ public: ...@@ -109,18 +131,32 @@ public:
private: private:
TwoPartyVatNetwork& network; TwoPartyVatNetwork& network;
MallocMessageBuilder message; MallocMessageBuilder message;
kj::Array<int> fds;
}; };
class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage { class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
public: public:
IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {} IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
: message(kj::mv(init.reader)),
fdSpace(kj::mv(fdSpace)),
fds(init.fds) {
KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
}
AnyPointer::Reader getBody() override { AnyPointer::Reader getBody() override {
return message->getRoot<AnyPointer>(); return message->getRoot<AnyPointer>();
} }
kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
return fds;
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
kj::Array<kj::AutoCloseFd> fdSpace;
kj::ArrayPtr<kj::AutoCloseFd> fds;
}; };
rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() { rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
...@@ -132,9 +168,11 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg ...@@ -132,9 +168,11 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg
} }
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return kj::evalLater([&]() { return kj::evalLater([this]() {
return tryReadMessage(stream, receiveOptions) KJ_SWITCH_ONEOF(stream) {
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message) KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return tryReadMessage(*ioStream, receiveOptions)
.then([](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> { -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) { KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m))); return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
...@@ -142,12 +180,40 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI ...@@ -142,12 +180,40 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
return nullptr; return nullptr;
} }
}); });
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
auto fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMesage);
auto promise = tryReadMessage(*capStream, fdSpace, receiveOptions);
return promise.then([fdSpace = kj::mv(fdSpace)]
(kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, messageAndFds) {
if (m->fds.size() > 0) {
return kj::Own<IncomingRpcMessage>(
kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
} else {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
}
} else {
return nullptr;
}
});
}
}
KJ_UNREACHABLE;
}); });
} }
kj::Promise<void> TwoPartyVatNetwork::shutdown() { kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() { kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite(); KJ_SWITCH_ONEOF(stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
ioStream->shutdownWrite();
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
capStream->shutdownWrite();
}
}
}); });
previousWrite = nullptr; previousWrite = nullptr;
return kj::mv(result); return kj::mv(result);
...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection { ...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection {
: connection(kj::mv(connectionParam)), : connection(kj::mv(connectionParam)),
network(*connection, rpc::twoparty::Side::SERVER), network(*connection, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
explicit AcceptedConnection(Capability::Client bootstrapInterface,
kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
uint maxFdsPerMesage)
: connection(kj::mv(connectionParam)),
network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
maxFdsPerMesage, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
}; };
void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
...@@ -178,6 +252,15 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { ...@@ -178,6 +252,15 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
tasks.add(promise.attach(kj::mv(connectionState))); tasks.add(promise.attach(kj::mv(connectionState)));
} }
void TwoPartyServer::accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage) {
auto connectionState = kj::heap<AcceptedConnection>(
bootstrapInterface, kj::mv(connection), maxFdsPerMesage);
// Run the connection until disconnect.
auto promise = connectionState->network.onDisconnect();
tasks.add(promise.attach(kj::mv(connectionState)));
}
kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
return listener.accept() return listener.accept()
.then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable { .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
...@@ -186,6 +269,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { ...@@ -186,6 +269,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
}); });
} }
kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
kj::ConnectionReceiver& listener, uint maxFdsPerMesage) {
return listener.accept()
.then([this,&listener,maxFdsPerMesage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMesage);
return listenCapStreamReceiver(listener, maxFdsPerMesage);
});
}
void TwoPartyServer::taskFailed(kj::Exception&& exception) { void TwoPartyServer::taskFailed(kj::Exception&& exception) {
KJ_LOG(ERROR, exception); KJ_LOG(ERROR, exception);
} }
...@@ -195,12 +287,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection) ...@@ -195,12 +287,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
rpcSystem(makeRpcClient(network)) {} rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage)
: network(connection, maxFdsPerMesage, rpc::twoparty::Side::CLIENT),
rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection, TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
Capability::Client bootstrapInterface, Capability::Client bootstrapInterface,
rpc::twoparty::Side side) rpc::twoparty::Side side)
: network(connection, side), : network(connection, side),
rpcSystem(network, bootstrapInterface) {} rpcSystem(network, bootstrapInterface) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side)
: network(connection, maxFdsPerMesage, side),
rpcSystem(network, bootstrapInterface) {}
Capability::Client TwoPartyClient::bootstrap() { Capability::Client TwoPartyClient::bootstrap() {
MallocMessageBuilder message(4); MallocMessageBuilder message(4);
auto vatId = message.getRoot<rpc::twoparty::VatId>(); auto vatId = message.getRoot<rpc::twoparty::VatId>();
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "message.h" #include "message.h"
#include <kj/async-io.h> #include <kj/async-io.h>
#include <capnp/rpc-twoparty.capnp.h> #include <capnp/rpc-twoparty.capnp.h>
#include <kj/one-of.h>
namespace capnp { namespace capnp {
...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, ...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase,
public: public:
TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions = ReaderOptions()); ReaderOptions receiveOptions = ReaderOptions());
TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
// To support FD passing, pass an AsyncCapabilityStream and `maxFdsPerMesage`, which specifies
// the maximum number of file descriptors to accept from the peer in any one RPC message. It is
// important to keep maxFdsPerMesage low in order to stop DoS attacks that fill up your FD table.
//
// Note that this limit applies only to incoming messages; outgoing messages are allowed to have
// more FDs. Sometimes it makes sense to enforce a limit of zero in one direction while having
// a non-zero limit in the other. For example, in a supervisor/sandbox scenario, typically there
// are many use cases for passing FDs from supervisor to sandbox but no use case for vice versa.
// The supervisor may be configured not to accept any FDs from the sandbox in order to reduce
// risk of DoS attacks.
KJ_DISALLOW_COPY(TwoPartyVatNetwork); KJ_DISALLOW_COPY(TwoPartyVatNetwork);
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); } kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
...@@ -70,7 +84,8 @@ private: ...@@ -70,7 +84,8 @@ private:
class OutgoingMessageImpl; class OutgoingMessageImpl;
class IncomingMessageImpl; class IncomingMessageImpl;
kj::AsyncIoStream& stream; kj::OneOf<kj::AsyncIoStream*, kj::AsyncCapabilityStream*> stream;
uint maxFdsPerMesage;
rpc::twoparty::Side side; rpc::twoparty::Side side;
MallocMessageBuilder peerVatId; MallocMessageBuilder peerVatId;
ReaderOptions receiveOptions; ReaderOptions receiveOptions;
...@@ -120,6 +135,7 @@ public: ...@@ -120,6 +135,7 @@ public:
explicit TwoPartyServer(Capability::Client bootstrapInterface); explicit TwoPartyServer(Capability::Client bootstrapInterface);
void accept(kj::Own<kj::AsyncIoStream>&& connection); void accept(kj::Own<kj::AsyncIoStream>&& connection);
void accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage);
// Accepts the connection for servicing. // Accepts the connection for servicing.
kj::Promise<void> listen(kj::ConnectionReceiver& listener); kj::Promise<void> listen(kj::ConnectionReceiver& listener);
...@@ -127,6 +143,10 @@ public: ...@@ -127,6 +143,10 @@ public:
// exception is thrown while trying to accept. You may discard the returned promise to cancel // exception is thrown while trying to accept. You may discard the returned promise to cancel
// listening. // listening.
kj::Promise<void> listenCapStreamReceiver(kj::ConnectionReceiver& listener, uint maxFdsPerMesage);
// Listen with support for FD transfers. `listener.accept()` must return instances of
// AsyncCapabilityStream, otherwise this will crash.
private: private:
Capability::Client bootstrapInterface; Capability::Client bootstrapInterface;
kj::TaskSet tasks; kj::TaskSet tasks;
...@@ -141,8 +161,12 @@ class TwoPartyClient { ...@@ -141,8 +161,12 @@ class TwoPartyClient {
public: public:
explicit TwoPartyClient(kj::AsyncIoStream& connection); explicit TwoPartyClient(kj::AsyncIoStream& connection);
explicit TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage);
TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface, TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT); rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
Capability::Client bootstrap(); Capability::Client bootstrap();
// Get the server's bootstrap interface. // Get the server's bootstrap interface.
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <map> #include <map>
#include <queue> #include <queue>
#include <capnp/rpc.capnp.h> #include <capnp/rpc.capnp.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
...@@ -575,7 +576,8 @@ private: ...@@ -575,7 +576,8 @@ private:
RpcClient(RpcConnectionState& connectionState) RpcClient(RpcConnectionState& connectionState)
: connectionState(kj::addRef(connectionState)) {} : connectionState(kj::addRef(connectionState)) {}
virtual kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) = 0; virtual kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor,
kj::Vector<int>& fds) = 0;
// Writes a CapDescriptor referencing this client. The CapDescriptor must be sent as part of // Writes a CapDescriptor referencing this client. The CapDescriptor must be sent as part of
// the very next message sent on the connection, as it may become invalid if other things // the very next message sent on the connection, as it may become invalid if other things
// happen. // happen.
...@@ -710,8 +712,9 @@ private: ...@@ -710,8 +712,9 @@ private:
// A ClientHook that wraps an entry in the import table. // A ClientHook that wraps an entry in the import table.
public: public:
ImportClient(RpcConnectionState& connectionState, ImportId importId) ImportClient(RpcConnectionState& connectionState, ImportId importId,
: RpcClient(connectionState), importId(importId) {} kj::Maybe<kj::AutoCloseFd> fd)
: RpcClient(connectionState), importId(importId), fd(kj::mv(fd)) {}
~ImportClient() noexcept(false) { ~ImportClient() noexcept(false) {
unwindDetector.catchExceptionsIfUnwinding([&]() { unwindDetector.catchExceptionsIfUnwinding([&]() {
...@@ -736,12 +739,19 @@ private: ...@@ -736,12 +739,19 @@ private:
}); });
} }
void setFdIfMissing(kj::Maybe<kj::AutoCloseFd> newFd) {
if (fd == nullptr) {
fd = kj::mv(newFd);
}
}
void addRemoteRef() { void addRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. // Add a new RemoteRef and return a new ref to this client representing it.
++remoteRefcount; ++remoteRefcount;
} }
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor,
kj::Vector<int>& fds) override {
descriptor.setReceiverHosted(importId); descriptor.setReceiverHosted(importId);
return nullptr; return nullptr;
} }
...@@ -766,8 +776,13 @@ private: ...@@ -766,8 +776,13 @@ private:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
return fd.map([](auto& f) { return f.get(); });
}
private: private:
ImportId importId; ImportId importId;
kj::Maybe<kj::AutoCloseFd> fd;
uint remoteRefcount = 0; uint remoteRefcount = 0;
// Number of times we've received this import from the peer. // Number of times we've received this import from the peer.
...@@ -784,7 +799,8 @@ private: ...@@ -784,7 +799,8 @@ private:
kj::Array<PipelineOp>&& ops) kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {} : RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {}
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor,
kj::Vector<int>& fds) override {
auto promisedAnswer = descriptor.initReceiverAnswer(); auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(questionRef->getId()); promisedAnswer.setQuestionId(questionRef->getId());
promisedAnswer.adoptTransform(fromPipelineOps( promisedAnswer.adoptTransform(fromPipelineOps(
...@@ -814,6 +830,10 @@ private: ...@@ -814,6 +830,10 @@ private:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
return nullptr;
}
private: private:
kj::Own<QuestionRef> questionRef; kj::Own<QuestionRef> questionRef;
kj::Array<PipelineOp> ops; kj::Array<PipelineOp> ops;
...@@ -867,9 +887,10 @@ private: ...@@ -867,9 +887,10 @@ private:
} }
} }
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor,
kj::Vector<int>& fds) override {
receivedCall = true; receivedCall = true;
return connectionState->writeDescriptor(*cap, descriptor); return connectionState->writeDescriptor(*cap, descriptor, fds);
} }
kj::Maybe<kj::Own<ClientHook>> writeTarget( kj::Maybe<kj::Own<ClientHook>> writeTarget(
...@@ -939,6 +960,20 @@ private: ...@@ -939,6 +960,20 @@ private:
return fork.addBranch(); return fork.addBranch();
} }
kj::Maybe<int> getFd() override {
if (isResolved) {
return cap->getFd();
} else {
// In theory, before resolution, the ImportClient for the promise could have an FD
// attached, if the promise itself was presented with an attached FD. However, we can't
// really return that one here because it may be closed when we get the Resolve message
// later. In theory we could have the PromiseClient itself take ownership of an FD that
// arrived attached to a promise cap, but the use case for that is questionable. I'm
// keeping it simple for now.
return nullptr;
}
}
private: private:
bool isResolved; bool isResolved;
kj::Own<ClientHook> cap; kj::Own<ClientHook> cap;
...@@ -1016,8 +1051,9 @@ private: ...@@ -1016,8 +1051,9 @@ private:
: RpcClient(*inner.connectionState), : RpcClient(*inner.connectionState),
inner(kj::addRef(inner)) {} inner(kj::addRef(inner)) {}
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override { kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor,
return inner->writeDescriptor(descriptor); kj::Vector<int>& fds) override {
return inner->writeDescriptor(descriptor, fds);
} }
kj::Maybe<kj::Own<ClientHook>> writeTarget(rpc::MessageTarget::Builder target) override { kj::Maybe<kj::Own<ClientHook>> writeTarget(rpc::MessageTarget::Builder target) override {
...@@ -1045,11 +1081,16 @@ private: ...@@ -1045,11 +1081,16 @@ private:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
return nullptr;
}
private: private:
kj::Own<RpcClient> inner; kj::Own<RpcClient> inner;
}; };
kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) { kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor,
kj::Vector<int>& fds) {
// Write a descriptor for the given capability. // Write a descriptor for the given capability.
// Find the innermost wrapped capability. // Find the innermost wrapped capability.
...@@ -1062,8 +1103,13 @@ private: ...@@ -1062,8 +1103,13 @@ private:
} }
} }
KJ_IF_MAYBE(fd, inner->getFd()) {
descriptor.setAttachedFd(fds.size());
fds.add(kj::mv(*fd));
}
if (inner->getBrand() == this) { if (inner->getBrand() == this) {
return kj::downcast<RpcClient>(*inner).writeDescriptor(descriptor); return kj::downcast<RpcClient>(*inner).writeDescriptor(descriptor, fds);
} else { } else {
auto iter = exportsByCap.find(inner); auto iter = exportsByCap.find(inner);
if (iter != exportsByCap.end()) { if (iter != exportsByCap.end()) {
...@@ -1094,12 +1140,12 @@ private: ...@@ -1094,12 +1140,12 @@ private:
} }
kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> capTable, kj::Array<ExportId> writeDescriptors(kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> capTable,
rpc::Payload::Builder payload) { rpc::Payload::Builder payload, kj::Vector<int>& fds) {
auto capTableBuilder = payload.initCapTable(capTable.size()); auto capTableBuilder = payload.initCapTable(capTable.size());
kj::Vector<ExportId> exports(capTable.size()); kj::Vector<ExportId> exports(capTable.size());
for (uint i: kj::indices(capTable)) { for (uint i: kj::indices(capTable)) {
KJ_IF_MAYBE(cap, capTable[i]) { KJ_IF_MAYBE(cap, capTable[i]) {
KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i])) { KJ_IF_MAYBE(exportId, writeDescriptor(**cap, capTableBuilder[i], fds)) {
exports.add(*exportId); exports.add(*exportId);
} }
} else { } else {
...@@ -1199,7 +1245,9 @@ private: ...@@ -1199,7 +1245,9 @@ private:
messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16); messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve(); auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
writeDescriptor(*exp.clientHook, resolve.initCap()); kj::Vector<int> fds;
writeDescriptor(*exp.clientHook, resolve.initCap(), fds);
message->setFds(fds.releaseAsArray());
message->send(); message->send();
return kj::READY_NOW; return kj::READY_NOW;
...@@ -1220,7 +1268,7 @@ private: ...@@ -1220,7 +1268,7 @@ private:
// ===================================================================================== // =====================================================================================
// Interpreting CapDescriptor // Interpreting CapDescriptor
kj::Own<ClientHook> import(ImportId importId, bool isPromise) { kj::Own<ClientHook> import(ImportId importId, bool isPromise, kj::Maybe<kj::AutoCloseFd> fd) {
// Receive a new import. // Receive a new import.
auto& import = imports[importId]; auto& import = imports[importId];
...@@ -1229,8 +1277,17 @@ private: ...@@ -1229,8 +1277,17 @@ private:
// Create the ImportClient, or if one already exists, use it. // Create the ImportClient, or if one already exists, use it.
KJ_IF_MAYBE(c, import.importClient) { KJ_IF_MAYBE(c, import.importClient) {
importClient = kj::addRef(*c); importClient = kj::addRef(*c);
// If the same import is introduced multiple times, and it is missing an FD the first time,
// but it has one on a later attempt, we want to attach the later one. This could happen
// because the first introduction was part of a message that had too many other FDs and went
// over the per-message limit. Perhaps the protocol design is such that this other message
// doesn't really care if the FDs are transferred or not, but the later message really does
// care; it would be bad if the previous message blocked later messages from delivering the
// FD just because it happened to reference the same capability.
importClient->setFdIfMissing(kj::mv(fd));
} else { } else {
importClient = kj::refcounted<ImportClient>(*this, importId); importClient = kj::refcounted<ImportClient>(*this, importId, kj::mv(fd));
import.importClient = *importClient; import.importClient = *importClient;
} }
...@@ -1262,15 +1319,22 @@ private: ...@@ -1262,15 +1319,22 @@ private:
} }
} }
kj::Maybe<kj::Own<ClientHook>> receiveCap(rpc::CapDescriptor::Reader descriptor) { kj::Maybe<kj::Own<ClientHook>> receiveCap(rpc::CapDescriptor::Reader descriptor,
kj::ArrayPtr<kj::AutoCloseFd> fds) {
uint fdIndex = descriptor.getAttachedFd();
kj::Maybe<kj::AutoCloseFd> fd;
if (fdIndex < fds.size() && fds[fdIndex] != nullptr) {
fd = kj::mv(fds[fdIndex]);
}
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::NONE: case rpc::CapDescriptor::NONE:
return nullptr; return nullptr;
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
return import(descriptor.getSenderHosted(), false); return import(descriptor.getSenderHosted(), false, kj::mv(fd));
case rpc::CapDescriptor::SENDER_PROMISE: case rpc::CapDescriptor::SENDER_PROMISE:
return import(descriptor.getSenderPromise(), true); return import(descriptor.getSenderPromise(), true, kj::mv(fd));
case rpc::CapDescriptor::RECEIVER_HOSTED: case rpc::CapDescriptor::RECEIVER_HOSTED:
KJ_IF_MAYBE(exp, exports.find(descriptor.getReceiverHosted())) { KJ_IF_MAYBE(exp, exports.find(descriptor.getReceiverHosted())) {
...@@ -1299,7 +1363,7 @@ private: ...@@ -1299,7 +1363,7 @@ private:
case rpc::CapDescriptor::THIRD_PARTY_HOSTED: case rpc::CapDescriptor::THIRD_PARTY_HOSTED:
// We don't support third-party caps, so use the vine instead. // We don't support third-party caps, so use the vine instead.
return import(descriptor.getThirdPartyHosted().getVineId(), false); return import(descriptor.getThirdPartyHosted().getVineId(), false, kj::mv(fd));
default: default:
KJ_FAIL_REQUIRE("unknown CapDescriptor type") { break; } KJ_FAIL_REQUIRE("unknown CapDescriptor type") { break; }
...@@ -1307,10 +1371,11 @@ private: ...@@ -1307,10 +1371,11 @@ private:
} }
} }
kj::Array<kj::Maybe<kj::Own<ClientHook>>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable) { kj::Array<kj::Maybe<kj::Own<ClientHook>>> receiveCaps(List<rpc::CapDescriptor>::Reader capTable,
kj::ArrayPtr<kj::AutoCloseFd> fds) {
auto result = kj::heapArrayBuilder<kj::Maybe<kj::Own<ClientHook>>>(capTable.size()); auto result = kj::heapArrayBuilder<kj::Maybe<kj::Own<ClientHook>>>(capTable.size());
for (auto cap: capTable) { for (auto cap: capTable) {
result.add(receiveCap(cap)); result.add(receiveCap(cap, fds));
} }
return result.finish(); return result.finish();
} }
...@@ -1497,8 +1562,10 @@ private: ...@@ -1497,8 +1562,10 @@ private:
SendInternalResult sendInternal(bool isTailCall) { SendInternalResult sendInternal(bool isTailCall) {
// Build the cap table. // Build the cap table.
kj::Vector<int> fds;
auto exports = connectionState->writeDescriptors( auto exports = connectionState->writeDescriptors(
capTable.getTable(), callBuilder.getParams()); capTable.getTable(), callBuilder.getParams(), fds);
message->setFds(fds.releaseAsArray());
// Init the question table. Do this after writing descriptors to avoid interference. // Init the question table. Do this after writing descriptors to avoid interference.
QuestionId questionId; QuestionId questionId;
...@@ -1691,7 +1758,9 @@ private: ...@@ -1691,7 +1758,9 @@ private:
// Build the cap table. // Build the cap table.
auto capTable = this->capTable.getTable(); auto capTable = this->capTable.getTable();
auto exports = connectionState.writeDescriptors(capTable, payload); kj::Vector<int> fds;
auto exports = connectionState.writeDescriptors(capTable, payload, fds);
message->setFds(fds.releaseAsArray());
// Capabilities that we are returning are subject to embargos. See `Disembargo` in rpc.capnp. // Capabilities that we are returning are subject to embargos. See `Disembargo` in rpc.capnp.
// As explained there, in order to deal with the Tribble 4-way race condition, we need to // As explained there, in order to deal with the Tribble 4-way race condition, we need to
...@@ -2130,7 +2199,7 @@ private: ...@@ -2130,7 +2199,7 @@ private:
break; break;
case rpc::Message::RESOLVE: case rpc::Message::RESOLVE:
handleResolve(reader.getResolve()); handleResolve(kj::mv(message), reader.getResolve());
break; break;
case rpc::Message::RELEASE: case rpc::Message::RELEASE:
...@@ -2262,7 +2331,9 @@ private: ...@@ -2262,7 +2331,9 @@ private:
auto capTableArray = capTable.getTable(); auto capTableArray = capTable.getTable();
KJ_DASSERT(capTableArray.size() == 1); KJ_DASSERT(capTableArray.size() == 1);
resultExports = writeDescriptors(capTableArray, payload); kj::Vector<int> fds;
resultExports = writeDescriptors(capTableArray, payload, fds);
response->setFds(fds.releaseAsArray());
capHook = KJ_ASSERT_NONNULL(capTableArray[0])->addRef(); capHook = KJ_ASSERT_NONNULL(capTableArray[0])->addRef();
})) { })) {
fromException(*exception, ret.initException()); fromException(*exception, ret.initException());
...@@ -2307,7 +2378,7 @@ private: ...@@ -2307,7 +2378,7 @@ private:
} }
auto payload = call.getParams(); auto payload = call.getParams();
auto capTableArray = receiveCaps(payload.getCapTable()); auto capTableArray = receiveCaps(payload.getCapTable(), message->getAttachedFds());
auto cancelPaf = kj::newPromiseAndFulfiller<void>(); auto cancelPaf = kj::newPromiseAndFulfiller<void>();
AnswerId answerId = call.getQuestionId(); AnswerId answerId = call.getQuestionId();
...@@ -2500,7 +2571,7 @@ private: ...@@ -2500,7 +2571,7 @@ private:
} }
auto payload = ret.getResults(); auto payload = ret.getResults();
auto capTableArray = receiveCaps(payload.getCapTable()); auto capTableArray = receiveCaps(payload.getCapTable(), message->getAttachedFds());
questionRef->fulfill(kj::refcounted<RpcResponseImpl>( questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::addRef(*questionRef), kj::mv(message), *this, kj::addRef(*questionRef), kj::mv(message),
kj::mv(capTableArray), payload.getContent())); kj::mv(capTableArray), payload.getContent()));
...@@ -2600,14 +2671,14 @@ private: ...@@ -2600,14 +2671,14 @@ private:
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Level 1 // Level 1
void handleResolve(const rpc::Resolve::Reader& resolve) { void handleResolve(kj::Own<IncomingRpcMessage>&& message, const rpc::Resolve::Reader& resolve) {
kj::Own<ClientHook> replacement; kj::Own<ClientHook> replacement;
kj::Maybe<kj::Exception> exception; kj::Maybe<kj::Exception> exception;
// Extract the replacement capability. // Extract the replacement capability.
switch (resolve.which()) { switch (resolve.which()) {
case rpc::Resolve::CAP: case rpc::Resolve::CAP:
KJ_IF_MAYBE(cap, receiveCap(resolve.getCap())) { KJ_IF_MAYBE(cap, receiveCap(resolve.getCap(), message->getAttachedFds())) {
replacement = kj::mv(*cap); replacement = kj::mv(*cap);
} else { } else {
KJ_FAIL_REQUIRE("'Resolve' contained 'CapDescriptor.none'.") { return; } KJ_FAIL_REQUIRE("'Resolve' contained 'CapDescriptor.none'.") { return; }
......
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include "capability.h" #include "capability.h"
#include "rpc-prelude.h" #include "rpc-prelude.h"
namespace kj { class AutoCloseFd; }
namespace capnp { namespace capnp {
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
...@@ -305,6 +307,10 @@ public: ...@@ -305,6 +307,10 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC // Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.) // implementation initializes it as a Message as defined in rpc.capnp.)
virtual void setFds(kj::Array<int> fds) {}
// Set the list of file descriptors to send along with this message, if FD passing is supported.
// An implementation may ignore this.
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
...@@ -317,6 +323,14 @@ public: ...@@ -317,6 +323,14 @@ public:
virtual AnyPointer::Reader getBody() = 0; virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation // Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.) // interprets it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() { return nullptr; }
// If the transport supports attached file descriptors and some were attached to this message,
// returns them. Otherwise returns an empty array. It is intended that the caller will move the
// FDs out of this table when they are consumed, possibly leaving behind a null slot. Callers
// should be careful to check if an FD was already consumed by comparing the slot with `nullptr`.
// (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only
// add confusion. Moving from an AutoCloseFd does in fact make it null.)
}; };
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
...@@ -35,6 +36,10 @@ public: ...@@ -35,6 +36,10 @@ public:
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<kj::Maybe<size_t>> readWithFds(
kj::AsyncCapabilityStream& inputStream,
kj::ArrayPtr<kj::AutoCloseFd> fds, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
kj::ArrayPtr<const word> getSegment(uint id) override { kj::ArrayPtr<const word> getSegment(uint id) override {
...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, ...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
}); });
} }
kj::Promise<kj::Maybe<size_t>> AsyncMessageReader::readWithFds(
kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr<kj::AutoCloseFd> fds,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord),
fds.begin(), fds.size())
.then([this,&inputStream,KJ_CPCAP(scratchSpace)]
(kj::AsyncCapabilityStream::ReadResult result) mutable
-> kj::Promise<kj::Maybe<size_t>> {
if (result.byteCount == 0) {
return kj::Maybe<size_t>(nullptr);
} else if (result.byteCount < sizeof(firstWord)) {
// EOF in first word.
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return kj::Maybe<size_t>(nullptr);
}
return readAfterFirstWord(inputStream, scratchSpace)
.then([result]() -> kj::Maybe<size_t> { return result.capCount; });
});
}
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream, kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) { kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) { if (segmentCount() == 0) {
...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) { return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own<MessageReader> {
if (!success) { if (!success) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
} }
return kj::mv(reader); return kj::mv(reader);
})); });
} }
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage( kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, return promise.then([reader = kj::mv(reader)](bool success) mutable
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> { -> kj::Maybe<kj::Own<MessageReader>> {
if (success) { if (success) {
return kj::mv(reader); return kj::mv(reader);
} else { } else {
return nullptr; return nullptr;
} }
})); });
}
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> MessageReaderAndFds {
KJ_IF_MAYBE(n, nfds) {
return { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return { kj::mv(reader), nullptr };
}
});
}
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> kj::Maybe<MessageReaderAndFds> {
KJ_IF_MAYBE(n, nfds) {
return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
return nullptr;
}
});
} }
// ======================================================================================= // =======================================================================================
...@@ -184,10 +241,9 @@ struct WriteArrays { ...@@ -184,10 +241,9 @@ struct WriteArrays {
kj::Array<kj::ArrayPtr<const byte>> pieces; kj::Array<kj::ArrayPtr<const byte>> pieces;
}; };
} // namespace template <typename WriteFunc>
kj::Promise<void> writeMessageImpl(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments,
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, WriteFunc&& writeFunc) {
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
WriteArrays arrays; WriteArrays arrays;
...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, ...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
arrays.pieces[i + 1] = segments[i].asBytes(); arrays.pieces[i + 1] = segments[i].asBytes();
} }
auto promise = output.write(arrays.pieces); auto promise = writeFunc(arrays.pieces);
// Make sure the arrays aren't freed until the write completes. // Make sure the arrays aren't freed until the write completes.
return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {}));
} }
} // namespace
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.write(pieces);
});
}
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), fds);
});
}
} // namespace capnp } // namespace capnp
...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu ...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu
KJ_WARN_UNUSED_RESULT; KJ_WARN_UNUSED_RESULT;
// Write asynchronously. The parameters must remain valid until the returned promise resolves. // Write asynchronously. The parameters must remain valid until the returned promise resolves.
// -----------------------------------------------------------------------------
// Versions that support FD passing.
struct MessageReaderAndFds {
kj::Own<MessageReader> reader;
kj::ArrayPtr<kj::AutoCloseFd> fds;
};
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Read a message that may also have file descriptors attached, e.g. from a Unix socket with
// SCM_RIGHTS.
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
MessageBuilder& builder)
KJ_WARN_UNUSED_RESULT;
// Write a message with FDs attached, e.g. to a Unix socket with SCM_RIGHTS.
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) { inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) {
return writeMessage(output, builder.getSegmentsForOutput()); return writeMessage(output, builder.getSegmentsForOutput());
} }
inline kj::Promise<void> writeMessage(
kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds, MessageBuilder& builder) {
return writeMessage(output, fds, builder.getSegmentsForOutput());
}
} // namespace capnp } // namespace capnp
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "test-util.h" #include "test-util.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/io.h>
#include <kj/miniposix.h>
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
...@@ -1144,6 +1146,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext ...@@ -1144,6 +1146,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> TestMoreStuffImpl::writeToFd(WriteToFdContext context) {
auto params = context.getParams();
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(params.getFdCap1().getFd()
.then([](kj::Maybe<int> fd) {
kj::FdOutputStream(KJ_ASSERT_NONNULL(fd)).write("foo", 3);
}));
promises.add(params.getFdCap2().getFd()
.then([context](kj::Maybe<int> fd) mutable {
context.getResults().setSecondFdPresent(fd != nullptr);
KJ_IF_MAYBE(f, fd) {
kj::FdOutputStream(*f).write("bar", 3);
}
}));
int pair[2];
KJ_SYSCALL(kj::miniposix::pipe(pair));
kj::AutoCloseFd in(pair[0]);
kj::AutoCloseFd out(pair[1]);
kj::FdOutputStream(kj::mv(out)).write("baz", 3);
context.getResults().setFdCap3(kj::heap<TestFdCap>(kj::mv(in)));
return kj::joinPromises(promises.finish());
}
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#if !CAPNP_LITE #if !CAPNP_LITE
#include "dynamic.h" #include "dynamic.h"
#include <kj/io.h>
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
...@@ -274,6 +275,8 @@ public: ...@@ -274,6 +275,8 @@ public:
kj::Promise<void> getEnormousString(GetEnormousStringContext context) override; kj::Promise<void> getEnormousString(GetEnormousStringContext context) override;
kj::Promise<void> writeToFd(WriteToFdContext context) override;
private: private:
int& callCount; int& callCount;
int& handleCount; int& handleCount;
...@@ -303,6 +306,18 @@ private: ...@@ -303,6 +306,18 @@ private:
TestInterfaceImpl impl; TestInterfaceImpl impl;
}; };
class TestFdCap final: public test::TestInterface::Server {
// Implementation of TestInterface that wraps a file descriptor.
public:
TestFdCap(kj::AutoCloseFd fd): fd(kj::mv(fd)) {}
kj::Maybe<int> getFd() override { return fd.get(); }
private:
kj::AutoCloseFd fd;
};
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) {
getEnormousString @11 () -> (str :Text); getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail. # Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
} }
interface TestMembrane { interface TestMembrane {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment