Unverified Commit cf34b937 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #821 from capnproto/fd-passing

Implement FD passing in Cap'n Proto. 
parents 0f368d57 d690ae52
...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr)) ...@@ -69,6 +69,19 @@ Capability::Client::Client(decltype(nullptr))
Capability::Client::Client(kj::Exception&& exception) Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {} : hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
auto fd = hook->getFd();
if (fd != nullptr) {
return fd;
} else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) {
return promise->attach(hook->addRef()).then([](kj::Own<ClientHook> newHook) {
return Client(kj::mv(newHook)).getFd();
});
} else {
return kj::Maybe<int>(nullptr);
}
}
kj::Promise<void> Capability::Server::internalUnimplemented( kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) { const char* actualInterfaceName, uint64_t requestedTypeId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
...@@ -374,6 +387,14 @@ public: ...@@ -374,6 +387,14 @@ public:
return nullptr; return nullptr;
} }
kj::Maybe<int> getFd() override {
KJ_IF_MAYBE(r, redirect) {
return r->get()->getFd();
} else {
return nullptr;
}
}
private: private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork; typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
...@@ -524,6 +545,10 @@ public: ...@@ -524,6 +545,10 @@ public:
} }
} }
kj::Maybe<int> getFd() override {
return server->getFd();
}
private: private:
kj::Own<Capability::Server> server; kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr; _::CapabilityServerSetBase* capServerSet = nullptr;
...@@ -616,6 +641,10 @@ public: ...@@ -616,6 +641,10 @@ public:
return brand; return brand;
} }
kj::Maybe<int> getFd() override {
return nullptr;
}
private: private:
kj::Exception exception; kj::Exception exception;
bool resolved; bool resolved;
......
...@@ -206,6 +206,19 @@ public: ...@@ -206,6 +206,19 @@ public:
// Make a request without knowing the types of the params or results. You specify the type ID // Make a request without knowing the types of the params or results. You specify the type ID
// and method number manually. // and method number manually.
kj::Promise<kj::Maybe<int>> getFd();
// If the capability's server implemented Capability::Server::getFd() returning non-null, and all
// RPC links between the client and server support FD passing, returns a file descriptor pointing
// to the same underlying file description as the server did. Returns null if the server provided
// no FD or if FD passing was unavailable at some intervening link.
//
// This returns a Promise to handle the case of an unresolved promise capability, e.g. a
// pipelined capability. The promise resolves no later than when the capability settles, i.e.
// the same time `whenResolved()` would complete.
//
// The file descriptor will remain open at least as long as the Capability::Client remains alive.
// If you need it to last longer, you will need to `dup()` it.
// TODO(someday): method(s) for Join // TODO(someday): method(s) for Join
protected: protected:
...@@ -331,6 +344,11 @@ public: ...@@ -331,6 +344,11 @@ public:
// is no longer needed. `context` may be used to allocate the output struct and deal with // is no longer needed. `context` may be used to allocate the output struct and deal with
// cancellation. // cancellation.
virtual kj::Maybe<int> getFd() { return nullptr; }
// If this capability is backed by a file descriptor that is safe to directly expose to clients,
// returns that FD. When FD passing has been enabled in the RPC layer, this FD may be sent to
// other processes along with the capability.
// TODO(someday): Method which can optionally be overridden to implement Join when the object is // TODO(someday): Method which can optionally be overridden to implement Join when the object is
// a proxy. // a proxy.
...@@ -563,6 +581,10 @@ public: ...@@ -563,6 +581,10 @@ public:
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr. // use) always returns nullptr.
virtual kj::Maybe<int> getFd() = 0;
// Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns
// non-null, then Capability::Client::getFd() waits for resolution and tries again.
static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); } static kj::Own<ClientHook> from(Capability::Client client) { return kj::mv(client.hook); }
}; };
......
...@@ -469,6 +469,13 @@ public: ...@@ -469,6 +469,13 @@ public:
return MEMBRANE_BRAND; return MEMBRANE_BRAND;
} }
kj::Maybe<int> getFd() override {
// We can't let FDs pass over membranes because we have no way to enforce the membrane policy
// on them. If the MembranePolicy wishes to explicitly permit certain FDs to pass, it can
// always do so by overriding the appropriate policy methods.
return nullptr;
}
private: private:
kj::Own<ClientHook> inner; kj::Own<ClientHook> inner;
kj::Own<MembranePolicy> policy; kj::Own<MembranePolicy> policy;
......
...@@ -21,12 +21,17 @@ ...@@ -21,12 +21,17 @@
#define CAPNP_TESTING_CAPNP 1 #define CAPNP_TESTING_CAPNP 1
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "test-util.h" #include "test-util.h"
#include <capnp/rpc.capnp.h> #include <capnp/rpc.capnp.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/thread.h> #include <kj/thread.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/miniposix.h>
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp { namespace capnp {
...@@ -419,6 +424,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) { ...@@ -419,6 +424,104 @@ TEST(TwoPartyNetwork, BootstrapFactory) {
EXPECT_TRUE(bootstrapFactory.called); EXPECT_TRUE(bootstrapFactory.called);
} }
// =======================================================================================
#if !_WIN32 && !__CYGWIN__ // Windows and Cygwin don't support SCM_RIGHTS.
KJ_TEST("send FD over RPC") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 2);
TwoPartyClient client(*pipe.ends[1], 2);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "bar");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
KJ_TEST("FD per message limit") {
auto io = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
TwoPartyServer server(kj::heap<TestMoreStuffImpl>(callCount, handleCount));
auto pipe = io.provider->newCapabilityPipe();
server.accept(kj::mv(pipe.ends[0]), 1);
TwoPartyClient client(*pipe.ends[1], 1);
auto cap = client.bootstrap().castAs<test::TestMoreStuff>();
int pipeFds[2];
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(kj::miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
capnp::RemotePromise<test::TestMoreStuff::WriteToFdResults> promise = nullptr;
{
auto req = cap.writeToFdRequest();
// Order reversal intentional, just trying to mix things up.
req.setFdCap1(kj::heap<TestFdCap>(kj::mv(out2)));
req.setFdCap2(kj::heap<TestFdCap>(kj::mv(out1)));
promise = req.send();
}
int in3 = KJ_ASSERT_NONNULL(promise.getFdCap3().getFd().wait(io.waitScope));
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in3))->readAllText().wait(io.waitScope)
== "baz");
{
auto promise2 = kj::mv(promise); // make sure the PipelineHook also goes out of scope
auto response = promise2.wait(io.waitScope);
KJ_EXPECT(!response.getSecondFdPresent());
}
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in1))->readAllText().wait(io.waitScope)
== "");
KJ_EXPECT(io.lowLevelProvider->wrapInputFd(kj::mv(in2))->readAllText().wait(io.waitScope)
== "foo");
}
#endif // !_WIN32 && !__CYGWIN__
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -22,12 +22,13 @@ ...@@ -22,12 +22,13 @@
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: stream(stream), side(side), peerVatId(4), : stream(&stream), maxFdsPerMessage(0), side(side), peerVatId(4),
receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) { receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
peerVatId.initRoot<rpc::twoparty::VatId>().setSide( peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty: ...@@ -38,6 +39,13 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller); disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
} }
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage,
rpc::twoparty::Side side, ReaderOptions receiveOptions)
: TwoPartyVatNetwork(stream, side, receiveOptions) {
this->stream = &stream;
this->maxFdsPerMessage = maxFdsPerMessage;
}
void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
if (--refcount == 0) { if (--refcount == 0) {
fulfiller->fulfill(); fulfiller->fulfill();
...@@ -81,6 +89,12 @@ public: ...@@ -81,6 +89,12 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
void setFds(kj::Array<int> fds) override {
if (network.stream.is<kj::AsyncCapabilityStream*>()) {
this->fds = kj::mv(fds);
}
}
void send() override { void send() override {
size_t size = 0; size_t size = 0;
for (auto& segment: message.getSegmentsForOutput()) { for (auto& segment: message.getSegmentsForOutput()) {
...@@ -98,7 +112,15 @@ public: ...@@ -98,7 +112,15 @@ public:
// Note that if the write fails, all further writes will be skipped due to the exception. // Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well // We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there. // and it's cleaner to handle the failure there.
return writeMessage(network.stream, message); KJ_SWITCH_ONEOF(network.stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
return writeMessage(*ioStream, message);
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
return writeMessage(*capStream, fds, message);
}
}
KJ_UNREACHABLE;
}).attach(kj::addRef(*this)) }).attach(kj::addRef(*this))
// Note that it's important that the eagerlyEvaluate() come *after* the attach() because // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
// otherwise the message (and any capabilities in it) will not be released until a new // otherwise the message (and any capabilities in it) will not be released until a new
...@@ -109,18 +131,32 @@ public: ...@@ -109,18 +131,32 @@ public:
private: private:
TwoPartyVatNetwork& network; TwoPartyVatNetwork& network;
MallocMessageBuilder message; MallocMessageBuilder message;
kj::Array<int> fds;
}; };
class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage { class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
public: public:
IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {} IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
: message(kj::mv(init.reader)),
fdSpace(kj::mv(fdSpace)),
fds(init.fds) {
KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
}
AnyPointer::Reader getBody() override { AnyPointer::Reader getBody() override {
return message->getRoot<AnyPointer>(); return message->getRoot<AnyPointer>();
} }
kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
return fds;
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
kj::Array<kj::AutoCloseFd> fdSpace;
kj::ArrayPtr<kj::AutoCloseFd> fds;
}; };
rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() { rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
...@@ -132,22 +168,52 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg ...@@ -132,22 +168,52 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSeg
} }
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return kj::evalLater([&]() { return kj::evalLater([this]() {
return tryReadMessage(stream, receiveOptions) KJ_SWITCH_ONEOF(stream) {
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message) KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
-> kj::Maybe<kj::Own<IncomingRpcMessage>> { return tryReadMessage(*ioStream, receiveOptions)
KJ_IF_MAYBE(m, message) { .then([](kj::Maybe<kj::Own<MessageReader>>&& message)
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m))); -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
} else { KJ_IF_MAYBE(m, message) {
return nullptr; return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
return nullptr;
}
});
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
auto fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMessage);
auto promise = tryReadMessage(*capStream, fdSpace, receiveOptions);
return promise.then([fdSpace = kj::mv(fdSpace)]
(kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, messageAndFds) {
if (m->fds.size() > 0) {
return kj::Own<IncomingRpcMessage>(
kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
} else {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
}
} else {
return nullptr;
}
});
} }
}); }
KJ_UNREACHABLE;
}); });
} }
kj::Promise<void> TwoPartyVatNetwork::shutdown() { kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() { kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite(); KJ_SWITCH_ONEOF(stream) {
KJ_CASE_ONEOF(ioStream, kj::AsyncIoStream*) {
ioStream->shutdownWrite();
}
KJ_CASE_ONEOF(capStream, kj::AsyncCapabilityStream*) {
capStream->shutdownWrite();
}
}
}); });
previousWrite = nullptr; previousWrite = nullptr;
return kj::mv(result); return kj::mv(result);
...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection { ...@@ -168,6 +234,14 @@ struct TwoPartyServer::AcceptedConnection {
: connection(kj::mv(connectionParam)), : connection(kj::mv(connectionParam)),
network(*connection, rpc::twoparty::Side::SERVER), network(*connection, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
explicit AcceptedConnection(Capability::Client bootstrapInterface,
kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
uint maxFdsPerMessage)
: connection(kj::mv(connectionParam)),
network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
maxFdsPerMessage, rpc::twoparty::Side::SERVER),
rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
}; };
void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
...@@ -178,6 +252,16 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { ...@@ -178,6 +252,16 @@ void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
tasks.add(promise.attach(kj::mv(connectionState))); tasks.add(promise.attach(kj::mv(connectionState)));
} }
void TwoPartyServer::accept(
kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage) {
auto connectionState = kj::heap<AcceptedConnection>(
bootstrapInterface, kj::mv(connection), maxFdsPerMessage);
// Run the connection until disconnect.
auto promise = connectionState->network.onDisconnect();
tasks.add(promise.attach(kj::mv(connectionState)));
}
kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
return listener.accept() return listener.accept()
.then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable { .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
...@@ -186,6 +270,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { ...@@ -186,6 +270,15 @@ kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
}); });
} }
kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
kj::ConnectionReceiver& listener, uint maxFdsPerMessage) {
return listener.accept()
.then([this,&listener,maxFdsPerMessage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMessage);
return listenCapStreamReceiver(listener, maxFdsPerMessage);
});
}
void TwoPartyServer::taskFailed(kj::Exception&& exception) { void TwoPartyServer::taskFailed(kj::Exception&& exception) {
KJ_LOG(ERROR, exception); KJ_LOG(ERROR, exception);
} }
...@@ -195,12 +288,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection) ...@@ -195,12 +288,22 @@ TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
rpcSystem(makeRpcClient(network)) {} rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage)
: network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT),
rpcSystem(makeRpcClient(network)) {}
TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection, TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
Capability::Client bootstrapInterface, Capability::Client bootstrapInterface,
rpc::twoparty::Side side) rpc::twoparty::Side side)
: network(connection, side), : network(connection, side),
rpcSystem(network, bootstrapInterface) {} rpcSystem(network, bootstrapInterface) {}
TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side)
: network(connection, maxFdsPerMessage, side),
rpcSystem(network, bootstrapInterface) {}
Capability::Client TwoPartyClient::bootstrap() { Capability::Client TwoPartyClient::bootstrap() {
MallocMessageBuilder message(4); MallocMessageBuilder message(4);
auto vatId = message.getRoot<rpc::twoparty::VatId>(); auto vatId = message.getRoot<rpc::twoparty::VatId>();
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "message.h" #include "message.h"
#include <kj/async-io.h> #include <kj/async-io.h>
#include <capnp/rpc-twoparty.capnp.h> #include <capnp/rpc-twoparty.capnp.h>
#include <kj/one-of.h>
namespace capnp { namespace capnp {
...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, ...@@ -53,6 +54,19 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase,
public: public:
TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions = ReaderOptions()); ReaderOptions receiveOptions = ReaderOptions());
TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
// To support FD passing, pass an AsyncCapabilityStream and `maxFdsPerMessage`, which specifies
// the maximum number of file descriptors to accept from the peer in any one RPC message. It is
// important to keep maxFdsPerMessage low in order to stop DoS attacks that fill up your FD table.
//
// Note that this limit applies only to incoming messages; outgoing messages are allowed to have
// more FDs. Sometimes it makes sense to enforce a limit of zero in one direction while having
// a non-zero limit in the other. For example, in a supervisor/sandbox scenario, typically there
// are many use cases for passing FDs from supervisor to sandbox but no use case for vice versa.
// The supervisor may be configured not to accept any FDs from the sandbox in order to reduce
// risk of DoS attacks.
KJ_DISALLOW_COPY(TwoPartyVatNetwork); KJ_DISALLOW_COPY(TwoPartyVatNetwork);
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); } kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
...@@ -70,7 +84,8 @@ private: ...@@ -70,7 +84,8 @@ private:
class OutgoingMessageImpl; class OutgoingMessageImpl;
class IncomingMessageImpl; class IncomingMessageImpl;
kj::AsyncIoStream& stream; kj::OneOf<kj::AsyncIoStream*, kj::AsyncCapabilityStream*> stream;
uint maxFdsPerMessage;
rpc::twoparty::Side side; rpc::twoparty::Side side;
MallocMessageBuilder peerVatId; MallocMessageBuilder peerVatId;
ReaderOptions receiveOptions; ReaderOptions receiveOptions;
...@@ -120,6 +135,7 @@ public: ...@@ -120,6 +135,7 @@ public:
explicit TwoPartyServer(Capability::Client bootstrapInterface); explicit TwoPartyServer(Capability::Client bootstrapInterface);
void accept(kj::Own<kj::AsyncIoStream>&& connection); void accept(kj::Own<kj::AsyncIoStream>&& connection);
void accept(kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage);
// Accepts the connection for servicing. // Accepts the connection for servicing.
kj::Promise<void> listen(kj::ConnectionReceiver& listener); kj::Promise<void> listen(kj::ConnectionReceiver& listener);
...@@ -127,6 +143,11 @@ public: ...@@ -127,6 +143,11 @@ public:
// exception is thrown while trying to accept. You may discard the returned promise to cancel // exception is thrown while trying to accept. You may discard the returned promise to cancel
// listening. // listening.
kj::Promise<void> listenCapStreamReceiver(
kj::ConnectionReceiver& listener, uint maxFdsPerMessage);
// Listen with support for FD transfers. `listener.accept()` must return instances of
// AsyncCapabilityStream, otherwise this will crash.
private: private:
Capability::Client bootstrapInterface; Capability::Client bootstrapInterface;
kj::TaskSet tasks; kj::TaskSet tasks;
...@@ -141,8 +162,12 @@ class TwoPartyClient { ...@@ -141,8 +162,12 @@ class TwoPartyClient {
public: public:
explicit TwoPartyClient(kj::AsyncIoStream& connection); explicit TwoPartyClient(kj::AsyncIoStream& connection);
explicit TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage);
TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface, TwoPartyClient(kj::AsyncIoStream& connection, Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT); rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage,
Capability::Client bootstrapInterface,
rpc::twoparty::Side side = rpc::twoparty::Side::CLIENT);
Capability::Client bootstrap(); Capability::Client bootstrap();
// Get the server's bootstrap interface. // Get the server's bootstrap interface.
......
This diff is collapsed.
...@@ -988,6 +988,63 @@ struct CapDescriptor { ...@@ -988,6 +988,63 @@ struct CapDescriptor {
# Level 1 and 2 implementations that receive a `thirdPartyHosted` may simply send calls to its # Level 1 and 2 implementations that receive a `thirdPartyHosted` may simply send calls to its
# `vine` instead. # `vine` instead.
} }
attachedFd @6 :UInt8 = 0xff;
# If the RPC message in which this CapDescriptor was delivered also had file descriptors
# attached, and `fd` is a valid index into the list of attached file descriptors, then
# that file descriptor should be attached to this capability. If `attachedFd` is out-of-bounds
# for said list, then no FD is attached.
#
# For example, if the RPC message arrived over a Unix socket, then file descriptors may be
# attached by sending an SCM_RIGHTS ancillary message attached to the data bytes making up the
# raw message. Receivers who wish to opt into FD passing should arrange to receive SCM_RIGHTS
# whenever receiving an RPC message. Senders who wish to send FDs need not verify whether the
# receiver knows how to receive them, because the operating system will automatically discard
# ancillary messages like SCM_RIGHTS if the receiver doesn't ask to receive them, including
# automatically closing any FDs.
#
# It is up to the application protocol to define what capabilities are expected to have file
# descriptors attached, and what those FDs mean. But, for example, an application could use this
# to open a file on disk and then transmit the open file descriptor to a sandboxed process that
# does not otherwise have permission to access the filesystem directly. This is usually an
# optimization: the sending process could instead provide an RPC interface supporting all the
# operations needed (such as reading and writing a file), but by passing the file descriptor
# directly, the recipient can often perform operations much more efficiently. Application
# designers are encouraged to provide such RPC interfaces and automatically fall back to them
# when FD passing is not available, so that the application can still work when the parties are
# remote over a network.
#
# An attached FD is most often associated with a `senderHosted` descriptor. It could also make
# sense in the case of `thirdPartyHosted`: in this case, the sender is forwarding the FD that
# they received from the third party, so that the receiver can start using it without first
# interacting with the third party. This is an optional optimization -- the middleman may choose
# not to forward capabilities, in which case the receiver will need to complete the handshake
# with the third party directly before receiving the FD. If an implementation receives a second
# attached FD after having already received one previously (e.g. both in a `thirdPartyHosted`
# CapDescriptor and then later again when receiving the final capability directly from the
# third party), the implementation should discard the later FD and stick with the original. At
# present, there is no known reason why other capability types (e.g. `receiverHosted`) would want
# to carry an attached FD, but we reserve the right to define a meaning for this in the future.
#
# Each file descriptor attached to the message must be used in no more than one CapDescriptor,
# so that the receiver does not need to use dup() or refcounting to handle the possibility of
# multiple capabilities using the same descriptor. If multiple CapDescriptors do point to the
# same FD index, then the receiver can arbitrarily choose which capability ends up having the
# FD attached.
#
# To mitigate DoS attacks, RPC implementations should limit the number of FDs they are willing to
# receive in a single message to a small value. If a message happens to contain more than that,
# the list is truncated. Moreover, in some cases, FD passing needs to be blocked entirely for
# security or implementation reasons, in which case the list may be truncated to zero. Hence,
# `attachedFd` might point past the end of the list, which the implementation should treat as if
# no FD was attached at all.
#
# The type of this field was chosen to be UInt8 because Linux supports sending only a maximum
# of 253 file descriptors in an SCM_RIGHTS message anyway, and CapDescriptor had two bytes of
# padding left -- so after adding this, there is still one byte for a future feature.
# Conveniently, this also means we're able to use 0xff as the default value, which will always
# be out-of-range (of course, the implementation should explicitly enforce that 255 descriptors
# cannot be sent at once, rather than relying on Linux to do so).
} }
struct PromisedAnswer { struct PromisedAnswer {
...@@ -1256,6 +1313,11 @@ using RecipientId = AnyPointer; ...@@ -1256,6 +1313,11 @@ using RecipientId = AnyPointer;
# #
# In a network where each vat has a public/private key pair, this could simply be the public key # In a network where each vat has a public/private key pair, this could simply be the public key
# fingerprint of the recipient along with a nonce matching the one in the `ProvisionId`. # fingerprint of the recipient along with a nonce matching the one in the `ProvisionId`.
#
# As another example, when communicating between processes on the same machine over Unix sockets,
# RecipientId could simply refer to a file descriptor attached to the message via SCM_RIGHTS.
# This file descriptor would be one end of a newly-created socketpair, with the other end having
# been sent to the capability's recipient in ThirdPartyCapId.
using ThirdPartyCapId = AnyPointer; using ThirdPartyCapId = AnyPointer;
# **(level 3)** # **(level 3)**
...@@ -1266,6 +1328,11 @@ using ThirdPartyCapId = AnyPointer; ...@@ -1266,6 +1328,11 @@ using ThirdPartyCapId = AnyPointer;
# third party's public key fingerprint, hints on how to connect to the third party (e.g. an IP # third party's public key fingerprint, hints on how to connect to the third party (e.g. an IP
# address), and the nonce used in the corresponding `Provide` message's `RecipientId` as sent # address), and the nonce used in the corresponding `Provide` message's `RecipientId` as sent
# to that third party (used to identify which capability to pick up). # to that third party (used to identify which capability to pick up).
#
# As another example, when communicating between processes on the same machine over Unix sockets,
# ThirdPartyCapId could simply refer to a file descriptor attached to the message via SCM_RIGHTS.
# This file descriptor would be one end of a newly-created socketpair, with the other end having
# been sent to the process hosting the capability in RecipientId.
using JoinKeyPart = AnyPointer; using JoinKeyPart = AnyPointer;
# **(level 4)** # **(level 4)**
......
...@@ -1413,7 +1413,7 @@ const ::capnp::_::RawSchema s_9a0e61223d96743b = { ...@@ -1413,7 +1413,7 @@ const ::capnp::_::RawSchema s_9a0e61223d96743b = {
1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr } 1, 2, i_9a0e61223d96743b, nullptr, nullptr, { &s_9a0e61223d96743b, nullptr, nullptr, 0, 0, nullptr }
}; };
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { static const ::capnp::_::AlignedData<130> b_8523ddc40b86b8b0 = {
{ 0, 0, 0, 0, 5, 0, 6, 0, { 0, 0, 0, 0, 5, 0, 6, 0,
176, 184, 134, 11, 196, 221, 35, 133, 176, 184, 134, 11, 196, 221, 35, 133,
16, 0, 0, 0, 1, 0, 1, 0, 16, 0, 0, 0, 1, 0, 1, 0,
...@@ -1423,7 +1423,7 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { ...@@ -1423,7 +1423,7 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = {
21, 0, 0, 0, 242, 0, 0, 0, 21, 0, 0, 0, 242, 0, 0, 0,
33, 0, 0, 0, 7, 0, 0, 0, 33, 0, 0, 0, 7, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
29, 0, 0, 0, 87, 1, 0, 0, 29, 0, 0, 0, 143, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
99, 97, 112, 110, 112, 47, 114, 112, 99, 97, 112, 110, 112, 47, 114, 112,
...@@ -1431,49 +1431,56 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { ...@@ -1431,49 +1431,56 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = {
67, 97, 112, 68, 101, 115, 99, 114, 67, 97, 112, 68, 101, 115, 99, 114,
105, 112, 116, 111, 114, 0, 0, 0, 105, 112, 116, 111, 114, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
24, 0, 0, 0, 3, 0, 4, 0, 28, 0, 0, 0, 3, 0, 4, 0,
0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 255, 255, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
153, 0, 0, 0, 42, 0, 0, 0, 181, 0, 0, 0, 42, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
148, 0, 0, 0, 3, 0, 1, 0, 176, 0, 0, 0, 3, 0, 1, 0,
160, 0, 0, 0, 2, 0, 1, 0, 188, 0, 0, 0, 2, 0, 1, 0,
1, 0, 254, 255, 1, 0, 0, 0, 1, 0, 254, 255, 1, 0, 0, 0,
0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
157, 0, 0, 0, 106, 0, 0, 0, 185, 0, 0, 0, 106, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
156, 0, 0, 0, 3, 0, 1, 0, 184, 0, 0, 0, 3, 0, 1, 0,
168, 0, 0, 0, 2, 0, 1, 0, 196, 0, 0, 0, 2, 0, 1, 0,
2, 0, 253, 255, 1, 0, 0, 0, 2, 0, 253, 255, 1, 0, 0, 0,
0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
165, 0, 0, 0, 114, 0, 0, 0, 193, 0, 0, 0, 114, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
164, 0, 0, 0, 3, 0, 1, 0, 192, 0, 0, 0, 3, 0, 1, 0,
176, 0, 0, 0, 2, 0, 1, 0, 204, 0, 0, 0, 2, 0, 1, 0,
3, 0, 252, 255, 1, 0, 0, 0, 3, 0, 252, 255, 1, 0, 0, 0,
0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
173, 0, 0, 0, 122, 0, 0, 0, 201, 0, 0, 0, 122, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
172, 0, 0, 0, 3, 0, 1, 0, 200, 0, 0, 0, 3, 0, 1, 0,
184, 0, 0, 0, 2, 0, 1, 0, 212, 0, 0, 0, 2, 0, 1, 0,
4, 0, 251, 255, 0, 0, 0, 0, 4, 0, 251, 255, 0, 0, 0, 0,
0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 1, 0, 4, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
181, 0, 0, 0, 122, 0, 0, 0, 209, 0, 0, 0, 122, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
180, 0, 0, 0, 3, 0, 1, 0, 208, 0, 0, 0, 3, 0, 1, 0,
192, 0, 0, 0, 2, 0, 1, 0, 220, 0, 0, 0, 2, 0, 1, 0,
5, 0, 250, 255, 0, 0, 0, 0, 5, 0, 250, 255, 0, 0, 0, 0,
0, 0, 1, 0, 5, 0, 0, 0, 0, 0, 1, 0, 5, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
189, 0, 0, 0, 138, 0, 0, 0, 217, 0, 0, 0, 138, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
192, 0, 0, 0, 3, 0, 1, 0, 220, 0, 0, 0, 3, 0, 1, 0,
204, 0, 0, 0, 2, 0, 1, 0, 232, 0, 0, 0, 2, 0, 1, 0,
6, 0, 0, 0, 2, 0, 0, 0,
0, 0, 1, 0, 6, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0,
229, 0, 0, 0, 90, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
228, 0, 0, 0, 3, 0, 1, 0,
240, 0, 0, 0, 2, 0, 1, 0,
110, 111, 110, 101, 0, 0, 0, 0, 110, 111, 110, 101, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
...@@ -1526,6 +1533,15 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = { ...@@ -1526,6 +1533,15 @@ static const ::capnp::_::AlignedData<114> b_8523ddc40b86b8b0 = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
16, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
97, 116, 116, 97, 99, 104, 101, 100,
70, 100, 0, 0, 0, 0, 0, 0,
6, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
6, 0, 255, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, } 0, 0, 0, 0, 0, 0, 0, 0, }
}; };
...@@ -1535,11 +1551,11 @@ static const ::capnp::_::RawSchema* const d_8523ddc40b86b8b0[] = { ...@@ -1535,11 +1551,11 @@ static const ::capnp::_::RawSchema* const d_8523ddc40b86b8b0[] = {
&s_d37007fde1f0027d, &s_d37007fde1f0027d,
&s_d800b1d6cd6f1ca0, &s_d800b1d6cd6f1ca0,
}; };
static const uint16_t m_8523ddc40b86b8b0[] = {0, 4, 3, 1, 2, 5}; static const uint16_t m_8523ddc40b86b8b0[] = {6, 0, 4, 3, 1, 2, 5};
static const uint16_t i_8523ddc40b86b8b0[] = {0, 1, 2, 3, 4, 5}; static const uint16_t i_8523ddc40b86b8b0[] = {0, 1, 2, 3, 4, 5, 6};
const ::capnp::_::RawSchema s_8523ddc40b86b8b0 = { const ::capnp::_::RawSchema s_8523ddc40b86b8b0 = {
0x8523ddc40b86b8b0, b_8523ddc40b86b8b0.words, 114, d_8523ddc40b86b8b0, m_8523ddc40b86b8b0, 0x8523ddc40b86b8b0, b_8523ddc40b86b8b0.words, 130, d_8523ddc40b86b8b0, m_8523ddc40b86b8b0,
2, 6, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr } 2, 7, i_8523ddc40b86b8b0, nullptr, nullptr, { &s_8523ddc40b86b8b0, nullptr, nullptr, 0, 0, nullptr }
}; };
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
static const ::capnp::_::AlignedData<57> b_d800b1d6cd6f1ca0 = { static const ::capnp::_::AlignedData<57> b_d800b1d6cd6f1ca0 = {
......
...@@ -2028,6 +2028,8 @@ public: ...@@ -2028,6 +2028,8 @@ public:
inline bool hasThirdPartyHosted() const; inline bool hasThirdPartyHosted() const;
inline ::capnp::rpc::ThirdPartyCapDescriptor::Reader getThirdPartyHosted() const; inline ::capnp::rpc::ThirdPartyCapDescriptor::Reader getThirdPartyHosted() const;
inline ::uint8_t getAttachedFd() const;
private: private:
::capnp::_::StructReader _reader; ::capnp::_::StructReader _reader;
template <typename, ::capnp::Kind> template <typename, ::capnp::Kind>
...@@ -2089,6 +2091,9 @@ public: ...@@ -2089,6 +2091,9 @@ public:
inline void adoptThirdPartyHosted(::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor>&& value); inline void adoptThirdPartyHosted(::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor>&& value);
inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> disownThirdPartyHosted(); inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> disownThirdPartyHosted();
inline ::uint8_t getAttachedFd();
inline void setAttachedFd( ::uint8_t value);
private: private:
::capnp::_::StructBuilder _builder; ::capnp::_::StructBuilder _builder;
template <typename, ::capnp::Kind> template <typename, ::capnp::Kind>
...@@ -4670,6 +4675,20 @@ inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> CapDescriptor::Bu ...@@ -4670,6 +4675,20 @@ inline ::capnp::Orphan< ::capnp::rpc::ThirdPartyCapDescriptor> CapDescriptor::Bu
::capnp::bounded<0>() * ::capnp::POINTERS)); ::capnp::bounded<0>() * ::capnp::POINTERS));
} }
inline ::uint8_t CapDescriptor::Reader::getAttachedFd() const {
return _reader.getDataField< ::uint8_t>(
::capnp::bounded<2>() * ::capnp::ELEMENTS, 255u);
}
inline ::uint8_t CapDescriptor::Builder::getAttachedFd() {
return _builder.getDataField< ::uint8_t>(
::capnp::bounded<2>() * ::capnp::ELEMENTS, 255u);
}
inline void CapDescriptor::Builder::setAttachedFd( ::uint8_t value) {
_builder.setDataField< ::uint8_t>(
::capnp::bounded<2>() * ::capnp::ELEMENTS, value, 255u);
}
inline ::uint32_t PromisedAnswer::Reader::getQuestionId() const { inline ::uint32_t PromisedAnswer::Reader::getQuestionId() const {
return _reader.getDataField< ::uint32_t>( return _reader.getDataField< ::uint32_t>(
::capnp::bounded<0>() * ::capnp::ELEMENTS); ::capnp::bounded<0>() * ::capnp::ELEMENTS);
......
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include "capability.h" #include "capability.h"
#include "rpc-prelude.h" #include "rpc-prelude.h"
namespace kj { class AutoCloseFd; }
namespace capnp { namespace capnp {
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
...@@ -305,6 +307,10 @@ public: ...@@ -305,6 +307,10 @@ public:
// Get the message body, which the caller may fill in any way it wants. (The standard RPC // Get the message body, which the caller may fill in any way it wants. (The standard RPC
// implementation initializes it as a Message as defined in rpc.capnp.) // implementation initializes it as a Message as defined in rpc.capnp.)
virtual void setFds(kj::Array<int> fds) {}
// Set the list of file descriptors to send along with this message, if FD passing is supported.
// An implementation may ignore this.
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
...@@ -317,6 +323,14 @@ public: ...@@ -317,6 +323,14 @@ public:
virtual AnyPointer::Reader getBody() = 0; virtual AnyPointer::Reader getBody() = 0;
// Get the message body, to be interpreted by the caller. (The standard RPC implementation // Get the message body, to be interpreted by the caller. (The standard RPC implementation
// interprets it as a Message as defined in rpc.capnp.) // interprets it as a Message as defined in rpc.capnp.)
virtual kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() { return nullptr; }
// If the transport supports attached file descriptors and some were attached to this message,
// returns them. Otherwise returns an empty array. It is intended that the caller will move the
// FDs out of this table when they are consumed, possibly leaving behind a null slot. Callers
// should be careful to check if an FD was already consumed by comparing the slot with `nullptr`.
// (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only
// add confusion. Moving from an AutoCloseFd does in fact make it null.)
}; };
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "serialize-async.h" #include "serialize-async.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/io.h>
namespace capnp { namespace capnp {
...@@ -35,6 +36,10 @@ public: ...@@ -35,6 +36,10 @@ public:
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<kj::Maybe<size_t>> readWithFds(
kj::AsyncCapabilityStream& inputStream,
kj::ArrayPtr<kj::AutoCloseFd> fds, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
kj::ArrayPtr<const word> getSegment(uint id) override { kj::ArrayPtr<const word> getSegment(uint id) override {
...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, ...@@ -79,6 +84,27 @@ kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
}); });
} }
kj::Promise<kj::Maybe<size_t>> AsyncMessageReader::readWithFds(
kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr<kj::AutoCloseFd> fds,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord),
fds.begin(), fds.size())
.then([this,&inputStream,KJ_CPCAP(scratchSpace)]
(kj::AsyncCapabilityStream::ReadResult result) mutable
-> kj::Promise<kj::Maybe<size_t>> {
if (result.byteCount == 0) {
return kj::Maybe<size_t>(nullptr);
} else if (result.byteCount < sizeof(firstWord)) {
// EOF in first word.
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return kj::Maybe<size_t>(nullptr);
}
return readAfterFirstWord(inputStream, scratchSpace)
.then([result]() -> kj::Maybe<size_t> { return result.capCount; });
});
}
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream, kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) { kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) { if (segmentCount() == 0) {
...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -151,26 +177,57 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) { return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own<MessageReader> {
if (!success) { if (!success) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
} }
return kj::mv(reader); return kj::mv(reader);
})); });
} }
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage( kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, return promise.then([reader = kj::mv(reader)](bool success) mutable
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> { -> kj::Maybe<kj::Own<MessageReader>> {
if (success) { if (success) {
return kj::mv(reader); return kj::mv(reader);
} else { } else {
return nullptr; return nullptr;
} }
})); });
}
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> MessageReaderAndFds {
KJ_IF_MAYBE(n, nfds) {
return { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF."));
return { kj::mv(reader), nullptr };
}
});
}
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->readWithFds(input, fdSpace, scratchSpace);
return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable
-> kj::Maybe<MessageReaderAndFds> {
KJ_IF_MAYBE(n, nfds) {
return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) };
} else {
return nullptr;
}
});
} }
// ======================================================================================= // =======================================================================================
...@@ -184,10 +241,9 @@ struct WriteArrays { ...@@ -184,10 +241,9 @@ struct WriteArrays {
kj::Array<kj::ArrayPtr<const byte>> pieces; kj::Array<kj::ArrayPtr<const byte>> pieces;
}; };
} // namespace template <typename WriteFunc>
kj::Promise<void> writeMessageImpl(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments,
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, WriteFunc&& writeFunc) {
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
WriteArrays arrays; WriteArrays arrays;
...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, ...@@ -212,10 +268,28 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
arrays.pieces[i + 1] = segments[i].asBytes(); arrays.pieces[i + 1] = segments[i].asBytes();
} }
auto promise = output.write(arrays.pieces); auto promise = writeFunc(arrays.pieces);
// Make sure the arrays aren't freed until the write completes. // Make sure the arrays aren't freed until the write completes.
return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {}));
} }
} // namespace
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.write(pieces);
});
}
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
return writeMessageImpl(segments,
[&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), fds);
});
}
} // namespace capnp } // namespace capnp
...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu ...@@ -51,11 +51,42 @@ kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& bu
KJ_WARN_UNUSED_RESULT; KJ_WARN_UNUSED_RESULT;
// Write asynchronously. The parameters must remain valid until the returned promise resolves. // Write asynchronously. The parameters must remain valid until the returned promise resolves.
// -----------------------------------------------------------------------------
// Versions that support FD passing.
struct MessageReaderAndFds {
kj::Own<MessageReader> reader;
kj::ArrayPtr<kj::AutoCloseFd> fds;
};
kj::Promise<MessageReaderAndFds> readMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Read a message that may also have file descriptors attached, e.g. from a Unix socket with
// SCM_RIGHTS.
kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage(
kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace,
ReaderOptions options = ReaderOptions(), kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds,
MessageBuilder& builder)
KJ_WARN_UNUSED_RESULT;
// Write a message with FDs attached, e.g. to a Unix socket with SCM_RIGHTS.
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) { inline kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, MessageBuilder& builder) {
return writeMessage(output, builder.getSegmentsForOutput()); return writeMessage(output, builder.getSegmentsForOutput());
} }
inline kj::Promise<void> writeMessage(
kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds, MessageBuilder& builder) {
return writeMessage(output, fds, builder.getSegmentsForOutput());
}
} // namespace capnp } // namespace capnp
...@@ -19,9 +19,15 @@ ...@@ -19,9 +19,15 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE. // THE SOFTWARE.
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "test-util.h" #include "test-util.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <kj/io.h>
#include <kj/miniposix.h>
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
...@@ -1144,6 +1150,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext ...@@ -1144,6 +1150,34 @@ kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> TestMoreStuffImpl::writeToFd(WriteToFdContext context) {
auto params = context.getParams();
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(params.getFdCap1().getFd()
.then([](kj::Maybe<int> fd) {
kj::FdOutputStream(KJ_ASSERT_NONNULL(fd)).write("foo", 3);
}));
promises.add(params.getFdCap2().getFd()
.then([context](kj::Maybe<int> fd) mutable {
context.getResults().setSecondFdPresent(fd != nullptr);
KJ_IF_MAYBE(f, fd) {
kj::FdOutputStream(*f).write("bar", 3);
}
}));
int pair[2];
KJ_SYSCALL(kj::miniposix::pipe(pair));
kj::AutoCloseFd in(pair[0]);
kj::AutoCloseFd out(pair[1]);
kj::FdOutputStream(kj::mv(out)).write("baz", 3);
context.getResults().setFdCap3(kj::heap<TestFdCap>(kj::mv(in)));
return kj::joinPromises(promises.finish());
}
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#if !CAPNP_LITE #if !CAPNP_LITE
#include "dynamic.h" #include "dynamic.h"
#include <kj/io.h>
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
// TODO(cleanup): Auto-generate stringification functions for union discriminants. // TODO(cleanup): Auto-generate stringification functions for union discriminants.
...@@ -274,6 +275,8 @@ public: ...@@ -274,6 +275,8 @@ public:
kj::Promise<void> getEnormousString(GetEnormousStringContext context) override; kj::Promise<void> getEnormousString(GetEnormousStringContext context) override;
kj::Promise<void> writeToFd(WriteToFdContext context) override;
private: private:
int& callCount; int& callCount;
int& handleCount; int& handleCount;
...@@ -303,6 +306,18 @@ private: ...@@ -303,6 +306,18 @@ private:
TestInterfaceImpl impl; TestInterfaceImpl impl;
}; };
class TestFdCap final: public test::TestInterface::Server {
// Implementation of TestInterface that wraps a file descriptor.
public:
TestFdCap(kj::AutoCloseFd fd): fd(kj::mv(fd)) {}
kj::Maybe<int> getFd() override { return fd.get(); }
private:
kj::AutoCloseFd fd;
};
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -860,6 +860,11 @@ interface TestMoreStuff extends(TestCallOrder) {
getEnormousString @11 () -> (str :Text); getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail. # Attempts to return an 100MB string. Should always fail.
writeToFd @13 (fdCap1 :TestInterface, fdCap2 :TestInterface)
-> (fdCap3 :TestInterface, secondFdPresent :Bool);
# Expects fdCap1 and fdCap2 wrap socket file descriptors. Writes "foo" to the first and "bar" to
# the second. Also creates a socketpair, writes "baz" to one end, and returns the other end.
} }
interface TestMembrane { interface TestMembrane {
......
...@@ -23,11 +23,15 @@ ...@@ -23,11 +23,15 @@
// Request Vista-level APIs. // Request Vista-level APIs.
#define WINVER 0x0600 #define WINVER 0x0600
#define _WIN32_WINNT 0x0600 #define _WIN32_WINNT 0x0600
#elif !defined(_GNU_SOURCE)
#define _GNU_SOURCE
#endif #endif
#include "async-io.h" #include "async-io.h"
#include "async-io-internal.h" #include "async-io-internal.h"
#include "debug.h" #include "debug.h"
#include "io.h"
#include "miniposix.h"
#include <kj/compat/gtest.h> #include <kj/compat/gtest.h>
#include <sys/types.h> #include <sys/types.h>
#if _WIN32 #if _WIN32
...@@ -233,6 +237,190 @@ TEST(AsyncIo, CapabilityPipe) { ...@@ -233,6 +237,190 @@ TEST(AsyncIo, CapabilityPipe) {
EXPECT_EQ("bar", result); EXPECT_EQ("bar", result);
EXPECT_EQ("foo", result2); EXPECT_EQ("foo", result2);
} }
TEST(AsyncIo, CapabilityPipeMultiStreamMessage) {
auto ioContext = setupAsyncIo();
auto pipe = ioContext.provider->newCapabilityPipe();
auto pipe2 = ioContext.provider->newCapabilityPipe();
auto pipe3 = ioContext.provider->newCapabilityPipe();
auto streams = heapArrayBuilder<Own<AsyncCapabilityStream>>(2);
streams.add(kj::mv(pipe2.ends[0]));
streams.add(kj::mv(pipe3.ends[0]));
ArrayPtr<const byte> secondBuf = "bar"_kj.asBytes();
pipe.ends[0]->writeWithStreams("foo"_kj.asBytes(), arrayPtr(&secondBuf, 1), streams.finish())
.wait(ioContext.waitScope);
char receiveBuffer[7];
Own<AsyncCapabilityStream> receiveStreams[3];
auto result = pipe.ends[1]->tryReadWithStreams(receiveBuffer, 6, 7, receiveStreams, 3)
.wait(ioContext.waitScope);
KJ_EXPECT(result.byteCount == 6);
receiveBuffer[6] = '\0';
KJ_EXPECT(kj::StringPtr(receiveBuffer) == "foobar");
KJ_ASSERT(result.capCount == 2);
receiveStreams[0]->write("baz", 3).wait(ioContext.waitScope);
receiveStreams[0] = nullptr;
KJ_EXPECT(pipe2.ends[1]->readAllText().wait(ioContext.waitScope) == "baz");
pipe3.ends[1]->write("qux", 3).wait(ioContext.waitScope);
pipe3.ends[1] = nullptr;
KJ_EXPECT(receiveStreams[1]->readAllText().wait(ioContext.waitScope) == "qux");
}
TEST(AsyncIo, ScmRightsTruncatedOdd) {
// Test that if we send two FDs over a unix socket, but the receiving end only receives one, we
// don't leak the other FD.
auto io = setupAsyncIo();
auto capPipe = io.provider->newCapabilityPipe();
int pipeFds[2];
KJ_SYSCALL(miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
{
AutoCloseFd sendFds[2] = { kj::mv(out1), kj::mv(out2) };
capPipe.ends[0]->writeWithFds("foo"_kj.asBytes(), nullptr, sendFds).wait(io.waitScope);
}
{
char buffer[4];
AutoCloseFd fdBuffer[1];
auto result = capPipe.ends[1]->tryReadWithFds(buffer, 3, 3, fdBuffer, 1).wait(io.waitScope);
KJ_ASSERT(result.capCount == 1);
kj::FdOutputStream(fdBuffer[0].get()).write("bar", 3);
}
// We want to carefully verify that out1 and out2 were closed, without deadlocking if they
// weren't. So we manually set nonblocking mode and then issue read()s.
KJ_SYSCALL(fcntl(in1, F_SETFL, O_NONBLOCK));
KJ_SYSCALL(fcntl(in2, F_SETFL, O_NONBLOCK));
char buffer[4];
ssize_t n;
// First we read "bar" from in1.
KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4));
KJ_ASSERT(n == 3);
buffer[3] = '\0';
KJ_ASSERT(kj::StringPtr(buffer) == "bar");
// Now it should be EOF.
KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4));
if (n < 0) {
KJ_FAIL_ASSERT("out1 was not closed");
}
KJ_ASSERT(n == 0);
// Second pipe should have been closed implicitly because we didn't provide space to receive it.
KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4));
if (n < 0) {
KJ_FAIL_ASSERT("out2 was not closed. This could indicate that your operating system kernel is "
"buggy and leaks file descriptors when an SCM_RIGHTS message is truncated. FreeBSD was "
"known to do this until late 2018, while MacOS still has this bug as of this writing in "
"2019. However, KJ works around the problem on those platforms. You need to enable the "
"same work-around for your OS -- search for 'SCM_RIGHTS' in src/kj/async-io-unix.c++.");
}
KJ_ASSERT(n == 0);
}
TEST(AsyncIo, ScmRightsTruncatedEven) {
// Test that if we send three FDs over a unix socket, but the receiving end only receives two, we
// don't leak the third FD. This is different from the send-two-receive-one case in that
// CMSG_SPACE() on many systems rounds up such that there is always space for an even number of
// FDs. In that case the other test only verifies that our userspace code to close unwanted FDs
// is correct, whereas *this* test really verifies that the *kernel* properly closes truncated
// FDs.
auto io = setupAsyncIo();
auto capPipe = io.provider->newCapabilityPipe();
int pipeFds[2];
KJ_SYSCALL(miniposix::pipe(pipeFds));
kj::AutoCloseFd in1(pipeFds[0]);
kj::AutoCloseFd out1(pipeFds[1]);
KJ_SYSCALL(miniposix::pipe(pipeFds));
kj::AutoCloseFd in2(pipeFds[0]);
kj::AutoCloseFd out2(pipeFds[1]);
KJ_SYSCALL(miniposix::pipe(pipeFds));
kj::AutoCloseFd in3(pipeFds[0]);
kj::AutoCloseFd out3(pipeFds[1]);
{
AutoCloseFd sendFds[3] = { kj::mv(out1), kj::mv(out2), kj::mv(out3) };
capPipe.ends[0]->writeWithFds("foo"_kj.asBytes(), nullptr, sendFds).wait(io.waitScope);
}
{
char buffer[4];
AutoCloseFd fdBuffer[2];
auto result = capPipe.ends[1]->tryReadWithFds(buffer, 3, 3, fdBuffer, 2).wait(io.waitScope);
KJ_ASSERT(result.capCount == 2);
kj::FdOutputStream(fdBuffer[0].get()).write("bar", 3);
kj::FdOutputStream(fdBuffer[1].get()).write("baz", 3);
}
// We want to carefully verify that out1, out2, and out3 were closed, without deadlocking if they
// weren't. So we manually set nonblocking mode and then issue read()s.
KJ_SYSCALL(fcntl(in1, F_SETFL, O_NONBLOCK));
KJ_SYSCALL(fcntl(in2, F_SETFL, O_NONBLOCK));
KJ_SYSCALL(fcntl(in3, F_SETFL, O_NONBLOCK));
char buffer[4];
ssize_t n;
// First we read "bar" from in1.
KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4));
KJ_ASSERT(n == 3);
buffer[3] = '\0';
KJ_ASSERT(kj::StringPtr(buffer) == "bar");
// Now it should be EOF.
KJ_NONBLOCKING_SYSCALL(n = read(in1, buffer, 4));
if (n < 0) {
KJ_FAIL_ASSERT("out1 was not closed");
}
KJ_ASSERT(n == 0);
// Next we read "baz" from in2.
KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4));
KJ_ASSERT(n == 3);
buffer[3] = '\0';
KJ_ASSERT(kj::StringPtr(buffer) == "baz");
// Now it should be EOF.
KJ_NONBLOCKING_SYSCALL(n = read(in2, buffer, 4));
if (n < 0) {
KJ_FAIL_ASSERT("out2 was not closed");
}
KJ_ASSERT(n == 0);
// Third pipe should have been closed implicitly because we didn't provide space to receive it.
KJ_NONBLOCKING_SYSCALL(n = read(in3, buffer, 4));
if (n < 0) {
KJ_FAIL_ASSERT("out3 was not closed. This could indicate that your operating system kernel is "
"buggy and leaks file descriptors when an SCM_RIGHTS message is truncated. FreeBSD was "
"known to do this until late 2018, while MacOS still has this bug as of this writing in "
"2019. However, KJ works around the problem on those platforms. You need to enable the "
"same work-around for your OS -- search for 'SCM_RIGHTS' in src/kj/async-io-unix.c++.");
}
KJ_ASSERT(n == 0);
}
#endif #endif
TEST(AsyncIo, PipeThread) { TEST(AsyncIo, PipeThread) {
...@@ -2239,9 +2427,11 @@ KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") { ...@@ -2239,9 +2427,11 @@ KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") {
abortedPromise.wait(io.waitScope); abortedPromise.wait(io.waitScope);
char buffer[4]; char buffer[4];
KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3); KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, 3, 3).wait(io.waitScope) == 3);
buffer[3] = '\0'; buffer[3] = '\0';
KJ_EXPECT(buffer == "bar"_kj); KJ_EXPECT(buffer == "bar"_kj);
// Note: Reading any further in pipe.ends[0] would throw "connection reset".
} }
KJ_TEST("import socket FD that's already broken") { KJ_TEST("import socket FD that's already broken") {
......
This diff is collapsed.
...@@ -1818,6 +1818,25 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit) { ...@@ -1818,6 +1818,25 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit) {
return { { mv(branch1), mv(branch2) } }; return { { mv(branch1), mv(branch2) } };
} }
Promise<void> AsyncCapabilityStream::writeWithFds(
ArrayPtr<const byte> data, ArrayPtr<const ArrayPtr<const byte>> moreData,
ArrayPtr<const AutoCloseFd> fds) {
// HACK: AutoCloseFd actually contains an `int` under the hood. We can reinterpret_cast to avoid
// unnecessary memory allocation.
static_assert(sizeof(AutoCloseFd) == sizeof(int), "this optimization won't work");
auto intArray = arrayPtr(reinterpret_cast<const int*>(fds.begin()), fds.size());
// Be extra-paranoid about aliasing rules by injecting a compiler barrier here. Probably
// not necessary but also probably doesn't hurt.
#if _MSC_VER
_ReadWriteBarrier();
#else
__asm__ __volatile__("": : :"memory");
#endif
return writeWithFds(data, moreData, intArray);
}
Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() { Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() {
return tryReceiveStream() return tryReceiveStream()
.then([](Maybe<Own<AsyncCapabilityStream>>&& result) .then([](Maybe<Own<AsyncCapabilityStream>>&& result)
...@@ -1830,6 +1849,35 @@ Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() { ...@@ -1830,6 +1849,35 @@ Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() {
}); });
} }
kj::Promise<Maybe<Own<AsyncCapabilityStream>>> AsyncCapabilityStream::tryReceiveStream() {
struct ResultHolder {
byte b;
Own<AsyncCapabilityStream> stream;
};
auto result = kj::heap<ResultHolder>();
auto promise = tryReadWithStreams(&result->b, 1, 1, &result->stream, 1);
return promise.then([result = kj::mv(result)](ReadResult actual) mutable
-> Maybe<Own<AsyncCapabilityStream>> {
if (actual.byteCount == 0) {
return nullptr;
}
KJ_REQUIRE(actual.capCount == 1,
"expected to receive a capability (e.g. file descirptor via SCM_RIGHTS), but didn't") {
return nullptr;
}
return kj::mv(result->stream);
});
}
Promise<void> AsyncCapabilityStream::sendStream(Own<AsyncCapabilityStream> stream) {
static constexpr byte b = 0;
auto streams = kj::heapArray<Own<AsyncCapabilityStream>>(1);
streams[0] = kj::mv(stream);
return writeWithStreams(arrayPtr(&b, 1), nullptr, kj::mv(streams));
}
Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() { Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() {
return tryReceiveFd().then([](Maybe<AutoCloseFd>&& result) -> Promise<AutoCloseFd> { return tryReceiveFd().then([](Maybe<AutoCloseFd>&& result) -> Promise<AutoCloseFd> {
KJ_IF_MAYBE(r, result) { KJ_IF_MAYBE(r, result) {
...@@ -1839,11 +1887,35 @@ Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() { ...@@ -1839,11 +1887,35 @@ Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() {
} }
}); });
} }
Promise<Maybe<AutoCloseFd>> AsyncCapabilityStream::tryReceiveFd() {
return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot receive file descriptors"); kj::Promise<kj::Maybe<AutoCloseFd>> AsyncCapabilityStream::tryReceiveFd() {
struct ResultHolder {
byte b;
AutoCloseFd fd;
};
auto result = kj::heap<ResultHolder>();
auto promise = tryReadWithFds(&result->b, 1, 1, &result->fd, 1);
return promise.then([result = kj::mv(result)](ReadResult actual) mutable
-> Maybe<AutoCloseFd> {
if (actual.byteCount == 0) {
return nullptr;
}
KJ_REQUIRE(actual.capCount == 1,
"expected to receive a file descriptor (e.g. via SCM_RIGHTS), but didn't") {
return nullptr;
}
return kj::mv(result->fd);
});
} }
Promise<void> AsyncCapabilityStream::sendFd(int fd) { Promise<void> AsyncCapabilityStream::sendFd(int fd) {
return KJ_EXCEPTION(UNIMPLEMENTED, "this stream cannot send file descriptors"); static constexpr byte b = 0;
auto fds = kj::heapArray<int>(1);
fds[0] = fd;
auto promise = writeWithFds(arrayPtr(&b, 1), nullptr, fds);
return promise.attach(kj::mv(fds));
} }
void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) { void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) {
......
...@@ -175,15 +175,57 @@ class AsyncCapabilityStream: public AsyncIoStream { ...@@ -175,15 +175,57 @@ class AsyncCapabilityStream: public AsyncIoStream {
// broker, or in terms of direct handle passing if at least one process trusts the other. // broker, or in terms of direct handle passing if at least one process trusts the other.
public: public:
virtual Promise<void> writeWithFds(ArrayPtr<const byte> data,
ArrayPtr<const ArrayPtr<const byte>> moreData,
ArrayPtr<const int> fds) = 0;
Promise<void> writeWithFds(ArrayPtr<const byte> data,
ArrayPtr<const ArrayPtr<const byte>> moreData,
ArrayPtr<const AutoCloseFd> fds);
// Write some data to the stream with some file descriptors attached to it.
//
// The maximum number of FDs that can be sent at a time is usually subject to an OS-imposed
// limit. On Linux, this is 253. In practice, sending more than a handful of FDs at once is
// probably a bad idea.
struct ReadResult {
size_t byteCount;
size_t capCount;
};
virtual Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes,
AutoCloseFd* fdBuffer, size_t maxFds) = 0;
// Read data from the stream that may have file descriptors attached. Any attached descriptors
// will be placed in `fdBuffer`. If multiple bundles of FDs are encountered in the course of
// reading the amount of data requested by minBytes/maxBytes, then they will be concatenated. If
// more FDs are received than fit in the buffer, then the excess will be discarded and closed --
// this behavior, while ugly, is important to defend against denial-of-service attacks that may
// fill up the FD table with garbage. Applications must think carefully about how many FDs they
// really need to receive at once and set a well-defined limit.
virtual Promise<void> writeWithStreams(ArrayPtr<const byte> data,
ArrayPtr<const ArrayPtr<const byte>> moreData,
Array<Own<AsyncCapabilityStream>> streams) = 0;
virtual Promise<ReadResult> tryReadWithStreams(
void* buffer, size_t minBytes, size_t maxBytes,
Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) = 0;
// Like above, but passes AsyncCapabilityStream objects. The stream implementations must be from
// the same AsyncIoProvider.
// ---------------------------------------------------------------------------
// Helpers for sending individual capabilities.
//
// These are equivalent to the above methods with the constraint that only one FD is
// sent/received at a time and the corresponding data is a single zero-valued byte.
Promise<Own<AsyncCapabilityStream>> receiveStream(); Promise<Own<AsyncCapabilityStream>> receiveStream();
virtual Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream() = 0; Promise<Maybe<Own<AsyncCapabilityStream>>> tryReceiveStream();
virtual Promise<void> sendStream(Own<AsyncCapabilityStream> stream) = 0; Promise<void> sendStream(Own<AsyncCapabilityStream> stream);
// Transfer a stream. // Transfer a single stream.
Promise<AutoCloseFd> receiveFd(); Promise<AutoCloseFd> receiveFd();
virtual Promise<Maybe<AutoCloseFd>> tryReceiveFd(); Promise<Maybe<AutoCloseFd>> tryReceiveFd();
virtual Promise<void> sendFd(int fd); Promise<void> sendFd(int fd);
// Transfer a raw file descriptor. Default implementation throws UNIMPLEMENTED. // Transfer a single raw file descriptor.
}; };
struct OneWayPipe { struct OneWayPipe {
......
...@@ -326,14 +326,13 @@ void VectorOutputStream::grow(size_t minSize) { ...@@ -326,14 +326,13 @@ void VectorOutputStream::grow(size_t minSize) {
AutoCloseFd::~AutoCloseFd() noexcept(false) { AutoCloseFd::~AutoCloseFd() noexcept(false) {
if (fd >= 0) { if (fd >= 0) {
unwindDetector.catchExceptionsIfUnwinding([&]() { // Don't use SYSCALL() here because close() should not be repeated on EINTR.
// Don't use SYSCALL() here because close() should not be repeated on EINTR. if (miniposix::close(fd) < 0) {
if (miniposix::close(fd) < 0) { KJ_FAIL_SYSCALL("close", errno, fd) {
KJ_FAIL_SYSCALL("close", errno, fd) { // This ensures we don't throw an exception if unwinding.
break; break;
}
} }
}); }
} }
} }
......
...@@ -300,7 +300,6 @@ public: ...@@ -300,7 +300,6 @@ public:
private: private:
int fd; int fd;
UnwindDetector unwindDetector;
}; };
inline auto KJ_STRINGIFY(const AutoCloseFd& fd) inline auto KJ_STRINGIFY(const AutoCloseFd& fd)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment