Commit b8431197 authored by Kenton Varda's avatar Kenton Varda

Add ability to construct a new bootstrap capability for each connecting client…

Add ability to construct a new bootstrap capability for each connecting client based on their authenticated VatId.
parent 1262c429
......@@ -43,7 +43,7 @@ class RpcSystem;
namespace _ { // private
class VatNetworkBase {
// Non-template version of VatNetwork. Ignore this class; see VatNetwork, below.
// Non-template version of VatNetwork. Ignore this class; see VatNetwork in rpc.h.
public:
class Connection;
......@@ -59,6 +59,7 @@ public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) = 0;
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual kj::Promise<void> shutdown() = 0;
virtual AnyStruct::Reader baseGetPeerVatId() = 0;
};
virtual kj::Maybe<kj::Own<Connection>> baseConnect(_::StructReader vatId) = 0;
virtual kj::Promise<kj::Own<Connection>> baseAccept() = 0;
......@@ -69,10 +70,20 @@ public:
virtual Capability::Client baseRestore(AnyPointer::Reader ref) = 0;
};
class BootstrapFactoryBase {
// Non-template version of BootstrapFactory. Ignore this class; see BootstrapFactory in rpc.h.
public:
virtual Capability::Client baseCreateFor(AnyStruct::Reader clientId) = 0;
};
class RpcSystemBase {
// Non-template version of RpcSystem. Ignore this class; see RpcSystem in rpc.h.
public:
RpcSystemBase(VatNetworkBase& network, kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway);
RpcSystemBase(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway);
RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& restorer);
RpcSystemBase(RpcSystemBase&& other) noexcept;
~RpcSystemBase() noexcept(false);
......
......@@ -302,6 +302,11 @@ public:
MallocMessageBuilder message;
};
test::TestSturdyRefHostId::Reader getPeerVatId() override {
// Not actually implemented for the purpose of this test.
return test::TestSturdyRefHostId::Reader();
}
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override {
return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
}
......
......@@ -336,6 +336,55 @@ TEST(TwoPartyNetwork, ConvenienceClasses) {
EXPECT_EQ(1, callCount);
}
class TestAuthenticatedBootstrapImpl
: public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server {
public:
TestAuthenticatedBootstrapImpl(rpc::twoparty::VatId::Reader clientId) {
this->clientId.setRoot(clientId);
}
protected:
kj::Promise<void> getCallerId(GetCallerIdContext context) override {
context.getResults().setCaller(clientId.getRoot<rpc::twoparty::VatId>());
return kj::READY_NOW;
}
private:
MallocMessageBuilder clientId;
};
class TestBootstrapFactory: public BootstrapFactory<rpc::twoparty::VatId> {
public:
Capability::Client createFor(rpc::twoparty::VatId::Reader clientId) {
called = true;
EXPECT_EQ(rpc::twoparty::Side::CLIENT, clientId.getSide());
return kj::heap<TestAuthenticatedBootstrapImpl>(clientId);
}
bool called = false;
};
kj::AsyncIoProvider::PipeThread runAuthenticatingServer(
kj::AsyncIoProvider& ioProvider, BootstrapFactory<rpc::twoparty::VatId>& bootstrapFactory) {
return ioProvider.newPipeThread([&bootstrapFactory](
kj::AsyncIoProvider& ioProvider, kj::AsyncIoStream& stream, kj::WaitScope& waitScope) {
TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
auto server = makeRpcServer(network, bootstrapFactory);
network.onDisconnect().wait(waitScope);
});
}
TEST(TwoPartyNetwork, BootstrapFactory) {
auto ioContext = kj::setupAsyncIo();
TestBootstrapFactory bootstrapFactory;
auto serverThread = runAuthenticatingServer(*ioContext.provider, bootstrapFactory);
TwoPartyClient client(*serverThread.pipe);
auto resp = client.bootstrap().castAs<test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>>()
.getCallerIdRequest().send().wait(ioContext.waitScope);
EXPECT_EQ(rpc::twoparty::Side::CLIENT, resp.getCaller().getSide());
EXPECT_TRUE(bootstrapFactory.called);
}
} // namespace
} // namespace _
} // namespace capnp
......@@ -27,7 +27,12 @@ namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions)
: stream(stream), side(side), receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
: stream(stream), side(side), peerVatId(4),
receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
: rpc::twoparty::Side::CLIENT);
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = paf.promise.fork();
disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
......@@ -115,6 +120,10 @@ private:
kj::Own<MessageReader> message;
};
rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
return peerVatId.getRoot<rpc::twoparty::VatId>();
}
kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSegmentWordSize) {
return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
}
......
......@@ -71,6 +71,7 @@ private:
kj::AsyncIoStream& stream;
rpc::twoparty::Side side;
MallocMessageBuilder peerVatId;
ReaderOptions receiveOptions;
bool accepted = false;
......@@ -103,6 +104,7 @@ private:
// implements Connection -----------------------------------------------------
rpc::twoparty::VatId::Reader getPeerVatId() override;
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override;
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
kj::Promise<void> shutdown() override;
......
......@@ -244,12 +244,12 @@ public:
// Task which is working on sending an abort message and cleanly ending the connection.
};
RpcConnectionState(kj::Maybe<Capability::Client> bootstrapInterface,
RpcConnectionState(BootstrapFactoryBase& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway,
kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connectionParam,
kj::Own<kj::PromiseFulfiller<DisconnectInfo>>&& disconnectFulfiller)
: bootstrapInterface(kj::mv(bootstrapInterface)), gateway(kj::mv(gateway)),
: bootstrapFactory(bootstrapFactory), gateway(kj::mv(gateway)),
restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), tasks(*this) {
connection.init<Connected>(kj::mv(connectionParam));
tasks.add(messageLoop());
......@@ -502,7 +502,7 @@ private:
// =======================================================================================
// OK, now we can define RpcConnectionState's member data.
kj::Maybe<Capability::Client> bootstrapInterface;
BootstrapFactoryBase& bootstrapFactory;
kj::Maybe<RealmGateway<>::Client> gateway;
kj::Maybe<SturdyRefRestorerBase&> restorer;
......@@ -2091,7 +2091,8 @@ private:
return;
}
auto response = connection.get<Connected>()->newOutgoingMessage(
VatNetworkBase::Connection& conn = *connection.get<Connected>();
auto response = conn.newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::CapDescriptor>() + 32);
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
......@@ -2104,17 +2105,16 @@ private:
// Call the restorer and initialize the answer.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
Capability::Client cap = nullptr;
KJ_IF_MAYBE(r, restorer) {
cap = r->baseRestore(bootstrap.getDeprecatedObjectId());
} else KJ_IF_MAYBE(b, bootstrapInterface) {
if (bootstrap.hasDeprecatedObjectId()) {
if (bootstrap.hasDeprecatedObjectId()) {
KJ_IF_MAYBE(r, restorer) {
cap = r->baseRestore(bootstrap.getDeprecatedObjectId());
} else {
KJ_FAIL_REQUIRE("This vat only supports a bootstrap interface, not the old "
"Cap'n-Proto-0.4-style named exports.") { return; }
} else {
cap = *b;
}
} else {
KJ_FAIL_REQUIRE("This vat does not expose any public/bootstrap interfaces.") { return; }
cap = bootstrapFactory.baseCreateFor(conn.baseGetPeerVatId());
}
auto payload = ret.initResults();
......@@ -2594,16 +2594,22 @@ private:
} // namespace
class RpcSystemBase::Impl final: public kj::TaskSet::ErrorHandler {
class RpcSystemBase::Impl final: private BootstrapFactoryBase, private kj::TaskSet::ErrorHandler {
public:
Impl(VatNetworkBase& network, kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway)
: network(network), bootstrapInterface(kj::mv(bootstrapInterface)),
bootstrapFactory(*this), gateway(kj::mv(gateway)), tasks(*this) {
tasks.add(acceptLoop());
}
Impl(VatNetworkBase& network, BootstrapFactoryBase& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway)
: network(network), bootstrapFactory(bootstrapFactory),
gateway(kj::mv(gateway)), tasks(*this) {
tasks.add(acceptLoop());
}
Impl(VatNetworkBase& network, SturdyRefRestorerBase& restorer)
: network(network), restorer(restorer), tasks(*this) {
: network(network), bootstrapFactory(*this), restorer(restorer), tasks(*this) {
tasks.add(acceptLoop());
}
......@@ -2640,13 +2646,10 @@ public:
}
}
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
private:
VatNetworkBase& network;
kj::Maybe<Capability::Client> bootstrapInterface;
BootstrapFactoryBase& bootstrapFactory;
kj::Maybe<RealmGateway<>::Client> gateway;
kj::Maybe<SturdyRefRestorerBase&> restorer;
kj::TaskSet tasks;
......@@ -2668,7 +2671,7 @@ private:
tasks.add(kj::mv(info.shutdownPromise));
}));
auto newState = kj::refcounted<RpcConnectionState>(
bootstrapInterface, gateway, restorer, kj::mv(connection),
bootstrapFactory, gateway, restorer, kj::mv(connection),
kj::mv(onDisconnect.fulfiller));
RpcConnectionState& result = *newState;
connections.insert(std::make_pair(connectionPtr, kj::mv(newState)));
......@@ -2691,12 +2694,33 @@ private:
tasks.add(acceptLoop());
});
}
Capability::Client baseCreateFor(AnyStruct::Reader clientId) override {
// Implements BootstrapFactory::baseCreateFor() in terms of `bootstrapInterface` or `restorer`,
// for use when we were given one of those instead of an actual `bootstrapFactory`.
KJ_IF_MAYBE(cap, bootstrapInterface) {
return *cap;
} else KJ_IF_MAYBE(r, restorer) {
return r->baseRestore(AnyPointer::Reader());
} else {
return KJ_EXCEPTION(FAILED, "This vat does not expose any public/bootstrap interfaces.");
}
}
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
RpcSystemBase::RpcSystemBase(VatNetworkBase& network,
kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway)
: impl(kj::heap<Impl>(network, kj::mv(bootstrapInterface), kj::mv(gateway))) {}
RpcSystemBase::RpcSystemBase(VatNetworkBase& network,
BootstrapFactoryBase& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway)
: impl(kj::heap<Impl>(network, bootstrapFactory, kj::mv(gateway))) {}
RpcSystemBase::RpcSystemBase(VatNetworkBase& network, SturdyRefRestorerBase& restorer)
: impl(kj::heap<Impl>(network, restorer)) {}
RpcSystemBase::RpcSystemBase(RpcSystemBase&& other) noexcept = default;
......
......@@ -37,6 +37,25 @@ class VatNetwork;
template <typename SturdyRefObjectId>
class SturdyRefRestorer;
template <typename VatId>
class BootstrapFactory: public _::BootstrapFactoryBase {
// Interface that constructs per-client bootstrap interfaces. Use this if you want each client
// who connects to see a different bootstrap interface based on their (authenticated) VatId.
// This allows an application to bootstrap off of the authentication performed at the VatNetwork
// level. (Typically VatId is some sort of public key.)
//
// This is only useful for multi-party networks. For TwoPartyVatNetwork, there's no reason to
// use a BootstrapFactory; just specify a single bootstrap capability in this case.
public:
virtual Capability::Client createFor(typename VatId::Reader clientId) = 0;
// Create a bootstrap capability appropriate for exposing to the given client. VatNetwork will
// have authenticated the client VatId before this is called.
private:
Capability::Client baseCreateFor(AnyStruct::Reader clientId) override;
};
template <typename VatId>
class RpcSystem: public _::RpcSystemBase {
// Represents the RPC system, which is the portal to objects available on the network.
......@@ -60,6 +79,13 @@ public:
kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway = nullptr);
template <typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
RpcSystem(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway = nullptr);
template <typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
typename LocalSturdyRefObjectId>
......@@ -115,6 +141,25 @@ RpcSystem<VatId> makeRpcServer(
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format.
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory);
// Make an RPC server that can serve different bootstrap interfaces to different clients via a
// BootstrapInterface.
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult, typename RealmGatewayClient,
typename InternalRef = _::InternalRefFromRealmGatewayClient<RealmGatewayClient>,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway);
// Make an RPC server that can serve different bootstrap interfaces to different clients via a
// BootstrapInterface and communicates with a different realm than the application is in via a
// RealmGateway.
template <typename VatId, typename LocalSturdyRefObjectId,
typename ProvisionId, typename RecipientId, typename ThirdPartyCapId, typename JoinResult>
RpcSystem<VatId> makeRpcServer(
......@@ -264,6 +309,11 @@ public:
public:
// Level 0 features ----------------------------------------------
virtual typename VatId::Reader getPeerVatId() = 0;
// Returns the connected vat's authenticated VatId. It is the VatNetwork's responsibility to
// authenticate this, so that the caller can be assured that they are really talking to the
// identified vat and not an imposter.
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) = 0;
// Allocate a new message to be sent on this connection.
//
......@@ -278,6 +328,9 @@ public:
virtual kj::Promise<void> shutdown() KJ_WARN_UNUSED_RESULT = 0;
// Waits until all outgoing messages have been sent, then shuts down the outgoing stream. The
// returned promise resolves after shutdown is complete.
private:
AnyStruct::Reader baseGetPeerVatId() override;
};
// Level 0 features ------------------------------------------------
......@@ -310,6 +363,11 @@ private:
// ***************************************************************************************
// =======================================================================================
template <typename VatId>
Capability::Client BootstrapFactory<VatId>::baseCreateFor(AnyStruct::Reader clientId) {
return createFor(clientId.as<VatId>());
}
template <typename SturdyRef, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
kj::Maybe<kj::Own<_::VatNetworkBase::Connection>>
......@@ -331,6 +389,14 @@ kj::Promise<kj::Own<_::VatNetworkBase::Connection>>
});
}
template <typename SturdyRef, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
AnyStruct::Reader VatNetwork<
SturdyRef, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>::
Connection::baseGetPeerVatId() {
return getPeerVatId();
}
template <typename SturdyRef>
Capability::Client SturdyRefRestorer<SturdyRef>::baseRestore(AnyPointer::Reader ref) {
#pragma GCC diagnostic push
......@@ -348,6 +414,15 @@ RpcSystem<VatId>::RpcSystem(
kj::Maybe<RealmGateway<>::Client> gateway)
: _::RpcSystemBase(network, kj::mv(bootstrap), kj::mv(gateway)) {}
template <typename VatId>
template <typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
RpcSystem<VatId>::RpcSystem(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory,
kj::Maybe<RealmGateway<>::Client> gateway)
: _::RpcSystemBase(network, bootstrapFactory, kj::mv(gateway)) {}
template <typename VatId>
template <typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
......@@ -386,6 +461,23 @@ RpcSystem<VatId> makeRpcServer(
gateway.template castAs<RealmGateway<>>());
}
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory) {
return RpcSystem<VatId>(network, bootstrapFactory);
}
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
typename RealmGatewayClient, typename InternalRef, typename ExternalRef>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway) {
return RpcSystem<VatId>(network, bootstrapFactory, gateway.template castAs<RealmGateway<>>());
}
template <typename VatId, typename LocalSturdyRefObjectId,
typename ProvisionId, typename RecipientId, typename ThirdPartyCapId, typename JoinResult>
RpcSystem<VatId> makeRpcServer(
......
......@@ -793,6 +793,10 @@ interface TestKeywordMethods {
return @3 ();
}
interface TestAuthenticatedBootstrap(VatId) {
getCallerId @0 () -> (caller :VatId);
}
struct TestSturdyRef {
hostId @0 :TestSturdyRefHostId;
objectId @1 :AnyPointer;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment