Commit e9c46ac1 authored by Kenton Varda's avatar Kenton Varda

Implement FD passing in Cap'n Proto.

An endpoint (subclass of `Capability::Server`) may override `kj::Maybe<int> getFd()` to expose an underlying file descriptor.

A remote client may use `Capability::Client::getFd()` on the endpoint's capability to get that FD.

The client and server must explicitly opt into FD passing by passing a max-FDs-per-message limit to TwoPartyVatNetwork and using Unix sockets as the transport. Nothing other than that is needed.
parent 84c7ba81
...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr)) ...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr))
Capability::Client::Client(kj::Exception&& exception) Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {} : hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
auto fd = hook->getFd();
if (fd != nullptr) {
return fd;
} else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) {
return promise->attach(hook->addRef()).then([](kj::Own<ClientHook> newHook) {
return Client(kj::mv(newHook)).getFd();
});
} else {
return kj::Maybe<int>(nullptr);
}
}
kj::Promise<void> Capability::Server::internalUnimplemented( kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) { const char* actualInterfaceName, uint64_t requestedTypeId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
...@@ -374,6 +387,14 @@ public: ...@@ -374,6 +387,14 @@ public:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
KJ_IF_MAYBE(r, redirect) {
return r->get()->getFd();
} else {
return nullptr;
}
}
private: private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork; typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
...@@ -524,6 +545,10 @@ public: ...@@ -524,6 +545,10 @@ public:
} }
} }
kj::Maybe<int> getFd() override {
return server->getFd();
}
private: private:
kj::Own<Capability::Server> server; kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr; _::CapabilityServerSetBase* capServerSet = nullptr;
...@@ -616,6 +641,10 @@ public: ...@@ -616,6 +641,10 @@ public:
return brand; return brand;
} }
kj::Maybe<int> getFd() override {
return nullptr;
}
private: private:
kj::Exception exception; kj::Exception exception;
bool resolved; bool resolved;
......
...@@ -206,6 +206,19 @@ public: ...@@ -206,6 +206,19 @@ public:
// Make a request without knowing the types of the params or results. You specify the type ID // Make a request without knowing the types of the params or results. You specify the type ID
// and method number manually. // and method number manually.
kj::Promise<kj::Maybe<int>> getFd();
// If the capability's server implemented Capability::Server::getFd() returning non-null, and all
// RPC links between the client and server support FD passing, returns a file descriptor pointing
// to the same undelying file description as the server did. Returns null if the server provided
// no FD or if FD passing was unavailable at some intervening link.
//
// This returns a Promise to handle the case of an unresolved promise capability, e.g. a
// pipelined capability. The promise resolves no later than when the capability settles, i.e.
// the same time `whenResolved()` would complete.
//
// The file descriptor will remain open at least as long as the Capability::Client remains alive.
// If you need it to last longer, you will need to `dup()` it.
// TODO(someday): method(s) for Join // TODO(someday): method(s) for Join
protected: protected:
...@@ -331,6 +344,11 @@ public: ...@@ -331,6 +344,11 @@ public:
// is no longer needed. `context` may be used to allocate the output struct and deal with // is no longer needed. `context` may be used to allocate the output struct and deal with
// cancellation. // cancellation.
virtual kj::Maybe<int> getFd() { return nullptr; }
// If this capability is backed by a file descriptor that is safe to directly expose to clients,
// returns that FD. When FD passing has been enabled in the RPC layer, this FD may be sent to
// other processes along with the capability.
// TODO(someday): Method which can optionally be overridden to implement Join when the object is // TODO(someday): Method which can optionally be overridden to implement Join when the object is
// a proxy. // a proxy.
...@@ -563,6 +581,10 @@ public: ...@@ -563,6 +581,10 @@ public:
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr. // use) always returns nullptr.
virtual kj::Maybe<int> getFd() = 0;
// Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns
// non-null, then Capability::Client::getFd() waits for resolution and tries again.
static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); } static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); }
}; };
......
...@@ -469,6 +469,13 @@ public: ...@@ -469,6 +469,13 @@ public:
return MEMBRANE_BRAND; return MEMBRANE_BRAND;
} }
kj::Maybe<int> getFd() override {
// We can't let FDs pass over membranes because we have no way to enforce the membrane policy
// on them. If the MembranePolicy wishes to explicitly permit certain FDs to pass, it can
// always do so by overriding the appropriate policy methods.
return nullptr;
}
private: private:
kj::Own<ClientHook> inner; kj::Own<ClientHook> inner;
kj::Own<MembranePolicy> policy; kj::Own<MembranePolicy> policy;
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/thread.h> #include <kj/thread.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/miniposix.h>
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp { namespace capnp {
...@@ -419,6 +420,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) { ...@@ -419,6 +420,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) {
EXPECT_TRUE(bootstrapFactory.called); EXPECT_TRUE(bootstrapFactory.called);
} }
// =======================================================================================
#if !_WIN32
KJ_TEST("send FD over RPC") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 2);
TwoPartyClient client(*pipe.ends[1], 2);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "bar");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
KJ_TEST("FD per message limit") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 1);
TwoPartyClient client(*pipe.ends[1], 1);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(!response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
#endif // !_WIN32
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -22,12 +22,13 @@ ...@@ -22,12 +22,13 @@
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: stream(stream), side(side), peerVatId(4), : stream(&stream), maxFdsPerMesage(0), side(side), peerVatId(4),
receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) { receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
peerVatId.initRoot<rpc::twoparty::VatId>().setSide( peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty: ...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller); disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
} }
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions)
: TwoPartyVatNetwork(stream, side, receiveOptions) {
this->stream = &stream;
this->maxFdsPerMesage = maxFdsPerMesage;
}
void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
if (--refcount == 0) { if (--refcount == 0) {
fulfiller->fulfill(); fulfiller->fulfill();
...@@ -81,6 +89,12 @@ public: ...@@ -81,6 +89,12 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
void setFds(kj::Array<int> fds) override {
if (network.stream.is<kj::AsyncCapabilityStream*>()) {
this->fds = kj::mv(fds);
}
}
void send() override { void send() override {
size_t size = 0; size_t size = 0;
for (auto& segment: message.getSegmentsForOutput()) { for (auto& segment: message.getSegmentsForOutput()) {
...@@ -98,7 +112,15 @@ public: ...@@ -98,7 +112,15 @@ public:
// Note that if the write fails, all further writes will be skipped due to the exception. // Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well // We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there. // and it's cleaner to handle the failure there.
return writeMessage(network.stream, message); KJ_SWITCH_ONEOF(network.stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return writeMessage(*ioStream, message);
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
return writeMessage(*capStream, fds, message);
}
}
KJ_UNREACHABLE;
}).attach(kj::addRef(*this)) }).attach(kj::addRef(*this))
// Note that it's important that the eagerlyEvaluate() come *after* the attach() because // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
// otherwise the message (and any capabilities in it) will not be released until a new // otherwise the message (and any capabilities in it) will not be released until a new
...@@ -109,18 +131,32 @@ public: ...@@ -109,18 +131,32 @@ public:
private: private:
TwoPartyVatNetwork& network; TwoPartyVatNetwork& network;
MallocMessageBuilder message; MallocMessageBuilder message;
kj::Array<int> fds;
}; };
class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage { class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
public: public:
IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {} IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
: message(kj::mv(init.reader)),
fdSpace(kj::mv(fdSpace)),
fds(init.fds) {
KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
}
AnyPointer::Reader getBody() override { AnyPointer::Reader getBody() override {
return message->getRoot<AnyPointer>(); return message->getRoot<AnyPointer>();
} }
kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
return fds;
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
kj::Array<kj::AutoCloseFd> fdSpace;
kj::ArrayPtr<kj::AutoCloseFd> fds;
}; };
rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() { rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
...@@ -132,9 +168,11 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg ...@@ -132,9 +168,11 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg
} }
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return kj::evalLater([&]() { return kj::evalLater([this]() {
return tryReadMessage(stream, receiveOptions) KJ_SWITCH_ONEOF(stream) {
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message) KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return tryReadMessage(*ioStream, receiveOptions)
.then([](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> { -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) { KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m))); return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
...@@ -142,12 +180,40 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI ...@@ -142,12 +180,40 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
return nullptr; return nullptr;
} }
}); });
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
auto fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMesage);
auto promise = tryReadMessage(*capStream, fdSpace, receiveOptions);
return promise.then([fdSpace = kj::mv(fdSpace)]
(kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, messageAndFds) {
if (m->fds.size() > 0) {
return kj::Own<IncomingRpcMessage>(
kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
} else {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
}
} else {
return nullptr;
}
});
}
}
KJ_UNREACHABLE;
}); });
} }
kj::Promise<void> TwoPartyVatNetwork::shutdown() { kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() { kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite(); KJ_SWITCH_ONEOF(stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
ioStream->shutdownWrite();
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
capStream->shutdownWrite();
}
}
}); });
previousWrite = nullptr; previousWrite = nullptr;
return kj::mv(result); return kj::mv(result);
...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection { ...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection {
: connection(kj::mv(connectionParam)), : connection(kj::mv(connectionParam)),
network(*connection, rpc::twoparty::Side::SERVER), network(*connection, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
explicit AcceptedConnection(Capability::Client bootstrapInterface,
kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
uint maxFdsPerMesage)
: connection(kj::mv(connectionParam)),
network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
maxFdsPerMesage, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
}; };
void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
...@@ -178,6 +252,15 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { ...@@ -178,6 +252,15 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
tasks.add(promise.attach(kj::mv(connectionState))); tasks.add(promise.attach(kj::mv(connectionState)));
} }
void TwoPartyServer::accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage) {
auto connectionState = kj::heap<AcceptedConnection>(
bootstrapInterface, kj::mv(connection), maxFdsPerMesage);
// Run the connection until disconnect.
auto promise = connectionState->network.onDisconnect();
tasks.add(promise.attach(kj::mv(connectionState)));
}
kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
return listener.accept() return listener.accept()
.then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable { .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
...@@ -186,6 +269,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { ...@@ -186,6 +269,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
}); });
} }
kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
kj::ConnectionReceiver& listener, uint maxFdsPerMesage) {
return listener.accept()
.then([this,&listener,maxFdsPerMesage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMesage);
return listenCapStreamReceiver(listener, maxFdsPerMesage);
});
}
void TwoPartyServer::taskFailed(kj::Exception&& exception) { void TwoPartyServer::taskFailed(kj::Exception&& exception) {
KJ_LOG(ERROR, exception); KJ_LOG(ERROR, exception);
} }
...@@ -195,12 +287,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection) ...@@ -195,12 +287,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
rpcSystem(makeRpcClient(network)) {} rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage)
: network(connection, maxFdsPerMesage, rpc::twoparty::Side::CLIENT),
rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection, TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
Capability::Client bootstrapInterface, Capability::Client bootstrapInterface,
rpc::twoparty::Side side) rpc::twoparty::Side side)
: network(connection, side), : network(connection, side),
rpcSystem(network, bootstrapInterface) {} rpcSystem(network, bootstrapInterface) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side)
: network(connection, maxFdsPerMesage, side),
rpcSystem(network, bootstrapInterface) {}
Capability::Client TwoPartyClient::bootstrap() { Capability::Client TwoPartyClient::bootstrap() {
MallocMessageBuilder message(4); MallocMessageBuilder message(4);
auto vatId = message.getRoot<rpc::twoparty::VatId>(); auto vatId = message.getRoot<rpc::twoparty::VatId>();
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "message.h" #include "message.h"
#include <kj/async-io.h> #include <kj/async-io.h>
#include <capnp/rpc-twoparty.capnp.h> #include <capnp/rpc-twoparty.capnp.h>
#include <kj/one-of.h>
namespace capnp { namespace capnp {
...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, ...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase,
public: public:
TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions = ReaderOptions()); ReaderOptions receiveOptions = ReaderOptions());
TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMesage,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
// To support FD passing, pass an AsyncCapabilityStream and `maxFdsPerMesage`, which specifies
// the maximum number of file descriptors to accept from the peer in any one RPC message. It is
// important to keep maxFdsPerMesage low in order to stop DoS attacks that fill up your FD table.
//
// Note that this limit applies only to incoming messages; outgoing messages are allowed to have
// more FDs. Sometimes it makes sense to enforce a limit of zero in one direction while having
// a non-zero limit in the other. For example, in a supervisor/sandbox scenario, typically there
// are many use cases for passing FDs from supervisor to sandbox but no use case for vice versa.
// The supervisor may be configured not to accept any FDs from the sandbox in order to reduce
// risk of DoS attacks.
KJ_DISALLOW_COPY(TwoPartyVatNetwork); KJ_DISALLOW_COPY(TwoPartyVatNetwork);
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); } kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
...@@ -70,7 +84,8 @@ private: ...@@ -70,7 +84,8 @@ private:
class OutgoingMessageImpl; class OutgoingMessageImpl;
class IncomingMessageImpl; class IncomingMessageImpl;
kj::AsyncIoStream& stream; kj::OneOf<kj::AsyncIoStream*, kj::AsyncCapabilityStream*> stream;
uint maxFdsPerMesage;
rpc::twoparty::Side side; rpc::twoparty::Side side;
MallocMessageBuilder peerVatId; MallocMessageBuilder peerVatId;
ReaderOptions receiveOptions; ReaderOptions receiveOptions;
...@@ -120,6 +135,7 @@ public: ...@@ -120,6 +135,7 @@ public:
explicit TwoPartyServer(Capability::Client bootstrapInterface); explicit TwoPartyServer(Capability::Client bootstrapInterface);
void accept(kj::Own<kj::AsyncIoStream>&& connection); void accept(kj::Own<kj::AsyncIoStream>&& connection);
void accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMesage);
// Accepts the connection for servicing. // Accepts the connection for servicing.
kj::Promise<void> listen(kj::ConnectionReceiver& listener); kj::Promise<void> listen(kj::ConnectionReceiver& listener);
...@@ -127,6 +143,10 @@ public: ...@@ -127,6 +143,10 @@ public:
// exception is thrown while trying to accept. You may discard the returned promise to cancel // exception is thrown while trying to accept. You may discard the returned promise to cancel
// listening. // listening.
kj::Promise<void> listenCapStreamReceiver(kj::ConnectionReceiver& listener, uint maxFdsPerMesage);
// Listen with support for FD transfers. `listener.accept()` must return instances of
// AsyncCapabilityStream, otherwise this will crash.
private: private:
Capability::Client bootstrapInterface; Capability::Client bootstrapInterface;
kj::TaskSet tasks; kj::TaskSet tasks;
...@@ -141,8 +161,12 @@ class TwoPartyClient { ...@@ -141,8 +161,12 @@ class TwoPartyClient {
public: public:
explicit TwoPartyClient(kj::AsyncIoStream& connection); explicit TwoPartyClient(kj::AsyncIoStream& connection);
explicit TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage);
TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface, TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT); rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMesage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
Capability::Client bootstrap(); Capability::Client bootstrap();
// Get the server's bootstrap interface. // Get the server's bootstrap interface.
......
This diff is collapsed.
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include "capability.h" #include "capability.h"
#include "rpc-prelude.h" #include "rpc-prelude.h"
namespace kj { class AutoCloseFd; }
namespace capnp { namespace capnp {
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
...@@ -305,6 +307,10 @@ public: ...@@ -305,6 +307,10 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC // Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.) // implementation initializes it as a Message as defined in rpc.capnp.)
virtual void setFds(kj::Array<int> fds) {}
// Set the list of file descriptors to send along with this message, if FD passing is supported.
// An implementation may ignore this.
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
...@@ -317,6 +323,14 @@ public: ...@@ -317,6 +323,14 @@ public:
virtual AnyPointer::Reader getBody() = 0; virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation // Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.) // interprets it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() { return nullptr; }
// If the transport supports attached file descriptors and some were attached to this message,
// returns them. Otherwise returns an empty array. It is intended that the caller will move the
// FDs out of this table when they are consumed, possibly leaving behind a null slot. Callers
// should be careful to check if an FD was already consumed by comparing the slot with `nullptr`.
// (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only
// add confusion. Moving from an AutoCloseFd does in fact make it null.)
}; };
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
...@@ -35,6 +36,10 @@ public: ...@@ -35,6 +36,10 @@ public:
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<kj::Maybe<size_t>> readWithFds(
kj::AsyncCapabilityStream& inputStream,
kj::ArrayPtr<kj::AutoCloseFd> fds, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
kj::ArrayPtr<const word> getSegment(uint id) override { kj::ArrayPtr<const word> getSegment(uint id) override {
...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, ...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
}); });
} }
kj::Promise<kj::Maybe<size_t>> AsyncMessageReader::readWithFds(
kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr<kj::AutoCloseFd> fds,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord),
fds.begin(), fds.size())
.then([this,&inputStream,KJ_CPCAP(scratchSpace)]
(kj::AsyncCapabilityStream::ReadResult result) mutable
-> kj::Promise<kj::Maybe<size_t>> {
if (result.byteCount == 0) {
return kj::Maybe<size_t>(nullptr);
} else if (result.byteCount < sizeof(firstWord)) {
// EOF in first word.
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return kj::Maybe<size_t>(nullptr);
}
return readAfterFirstWord(inputStream, scratchSpace)
.then([result]() -> kj::Maybe<size_t> { return result.capCount; });
});
}
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream, kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) { kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) { if (segmentCount() == 0) {
...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) { return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own<MessageReader> {
if (!success) { if (!success) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
} }
return kj::mv(reader); return kj::mv(reader);
})); });
} }
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage( kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, return promise.then([reader = kj::mv(reader)](bool success) mutable
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> { -> kj::Maybe<kj::Own<MessageReader>> {
if (success) { if (success) {
return kj::mv(reader); return kj::mv(reader);
} else { } else {
return nullptr; return nullptr;
} }
})); });
}
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> MessageReaderAndFds {
KJ_IF_MAYBE(n, nfds) {
return { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return { kj::mv(reader), nullptr };
}
});
}
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> kj::Maybe<MessageReaderAndFds> {
KJ_IF_MAYBE(n, nfds) {
return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
return nullptr;
}
});
} }
// ======================================================================================= // =======================================================================================
...@@ -184,10 +241,9 @@ struct WriteArrays { ...@@ -184,10 +241,9 @@ struct WriteArrays {
kj::Array<kj::ArrayPtr<const byte>> pieces; kj::Array<kj::ArrayPtr<const byte>> pieces;
}; };
} // namespace template <typename WriteFunc>
kj::Promise<void> writeMessageImpl(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments,
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, WriteFunc&& writeFunc) {
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
WriteArrays arrays; WriteArrays arrays;
...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, ...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
arrays.pieces[i + 1] = segments[i].asBytes(); arrays.pieces[i + 1] = segments[i].asBytes();
} }
auto promise = output.write(arrays.pieces); auto promise = writeFunc(arrays.pieces);
// Make sure the arrays aren't freed until the write completes. // Make sure the arrays aren't freed until the write completes.
return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {}));
} }
} // namespace
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.write(pieces);
});
}
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), fds);
});
}
} // namespace capnp } // namespace capnp
...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu ...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu
KJ_WARN_UNUSED_RESULT; KJ_WARN_UNUSED_RESULT;
// Write asynchronously. The parameters must remain valid until the returned promise resolves. // Write asynchronously. The parameters must remain valid until the returned promise resolves.
// -----------------------------------------------------------------------------
// Versions that support FD passing.
struct MessageReaderAndFds {
kj::Own<MessageReader> reader;
kj::ArrayPtr<kj::AutoCloseFd> fds;
};
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Read a message that may also have file descriptors attached, e.g. from a Unix socket with
// SCM_RIGHTS.
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
MessageBuilder& builder)
KJ_WARN_UNUSED_RESULT;
// Write a message with FDs attached, e.g. to a Unix socket with SCM_RIGHTS.
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) { inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) {
return writeMessage(output, builder.getSegmentsForOutput()); return writeMessage(output, builder.getSegmentsForOutput());
} }
inline kj::Promise<void> writeMessage(
kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds, MessageBuilder& builder) {
return writeMessage(output, fds, builder.getSegmentsForOutput());
}
} // namespace capnp } // namespace capnp
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "test-util.h" #include "test-util.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/io.h>
#include <kj/miniposix.h>
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
...@@ -1144,6 +1146,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext ...@@ -1144,6 +1146,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> TestMoreStuffImpl::writeToFd(WriteToFdContext context) {
auto params = context.getParams();
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(params.getFdCap1().getFd()
.then([](kj::Maybe<int> fd) {
kj::FdOutputStream(KJ_ASSERT_NONNULL(fd)).write("foo", 3);
}));
promises.add(params.getFdCap2().getFd()
.then([context](kj::Maybe<int> fd) mutable {
context.getResults().setSecondFdPresent(fd != nullptr);
KJ_IF_MAYBE(f, fd) {
kj::FdOutputStream(*f).write("bar", 3);
}
}));
int pair[2];
KJ_SYSCALL(kj::miniposix::pipe(pair));
kj::AutoCloseFd in(pair[0]);
kj::AutoCloseFd out(pair[1]);
kj::FdOutputStream(kj::mv(out)).write("baz", 3);
context.getResults().setFdCap3(kj::heap<TestFdCap>(kj::mv(in)));
return kj::joinPromises(promises.finish());
}
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#if !CAPNP_LITE #if !CAPNP_LITE
#include "dynamic.h" #include "dynamic.h"
#include <kj/io.h>
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
...@@ -274,6 +275,8 @@ public: ...@@ -274,6 +275,8 @@ public:
kj::Promise<void> getEnormousString(GetEnormousStringContext context) override; kj::Promise<void> getEnormousString(GetEnormousStringContext context) override;
kj::Promise<void> writeToFd(WriteToFdContext context) override;
private: private:
int& callCount; int& callCount;
int& handleCount; int& handleCount;
...@@ -303,6 +306,18 @@ private: ...@@ -303,6 +306,18 @@ private:
TestInterfaceImpl impl; TestInterfaceImpl impl;
}; };
class TestFdCap final: public test::TestInterface::Server {
// Implementation of TestInterface that wraps a file descriptor.
public:
TestFdCap(kj::AutoCloseFd fd): fd(kj::mv(fd)) {}
kj::Maybe<int> getFd() override { return fd.get(); }
private:
kj::AutoCloseFd fd;
};
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) {
getEnormousString @11 () -> (str :Text); getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail. # Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
} }
interface TestMembrane { interface TestMembrane {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment