Commit a181ec36 authored by Kenton Varda's avatar Kenton Varda

Add tests for RealmGateway, fix bugs.

parent 89b25093
...@@ -57,7 +57,9 @@ interface Persistent@0xc8cb212fcd9f5691(SturdyRef) { ...@@ -57,7 +57,9 @@ interface Persistent@0xc8cb212fcd9f5691(SturdyRef) {
# capabilities can be saved -- application interfaces should define which capabilities support # capabilities can be saved -- application interfaces should define which capabilities support
# this and which do not. # this and which do not.
struct SaveParams {} struct SaveParams {
# Nothing for now.
}
struct SaveResults { struct SaveResults {
sturdyRef @0 :SturdyRef; sturdyRef @0 :SturdyRef;
} }
......
...@@ -85,6 +85,26 @@ private: ...@@ -85,6 +85,26 @@ private:
friend class capnp::RpcSystem; friend class capnp::RpcSystem;
}; };
template <typename T> struct InternalRefFromRealmGateway_;
template <typename InternalRef, typename ExternalRef>
struct InternalRefFromRealmGateway_<RealmGateway<InternalRef, ExternalRef>> {
typedef InternalRef Type;
};
template <typename T>
using InternalRefFromRealmGateway = typename InternalRefFromRealmGateway_<T>::Type;
template <typename T>
using InternalRefFromRealmGatewayClient = InternalRefFromRealmGateway<typename T::Calls>;
template <typename T> struct ExternalRefFromRealmGateway_;
template <typename InternalRef, typename ExternalRef>
struct ExternalRefFromRealmGateway_<RealmGateway<InternalRef, ExternalRef>> {
typedef ExternalRef Type;
};
template <typename T>
using ExternalRefFromRealmGateway = typename ExternalRefFromRealmGateway_<T>::Type;
template <typename T>
using ExternalRefFromRealmGatewayClient = ExternalRefFromRealmGateway<typename T::Calls>;
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
......
...@@ -427,6 +427,21 @@ struct TestContext { ...@@ -427,6 +427,21 @@ struct TestContext {
serverNetwork(network.add("server")), serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)), rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, restorer)) {} rpcServer(makeRpcServer(serverNetwork, restorer)) {}
TestContext(Capability::Client bootstrap,
RealmGateway<test::TestSturdyRef, Text>::Client gateway)
: waitScope(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, gateway)),
rpcServer(makeRpcServer(serverNetwork, bootstrap)) {}
TestContext(Capability::Client bootstrap,
RealmGateway<test::TestSturdyRef, Text>::Client gateway,
bool)
: waitScope(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, bootstrap, gateway)) {}
Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) { Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) {
MallocMessageBuilder refMessage(128); MallocMessageBuilder refMessage(128);
...@@ -962,6 +977,91 @@ TEST(Rpc, CallBrokenPromise) { ...@@ -962,6 +977,91 @@ TEST(Rpc, CallBrokenPromise) {
getCallSequence(client, 1).wait(context.waitScope); getCallSequence(client, 1).wait(context.waitScope);
} }
// =======================================================================================
typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway;
class TestGateway: public TestRealmGateway::Server {
public:
kj::Promise<void> import(ImportContext context) override {
auto cap = context.getParams().getCap();
context.releaseParams();
return cap.saveRequest().send()
.then([context](Response<Persistent<Text>::SaveResults> response) mutable {
context.getResults().initSturdyRef().getObjectId().setAs<Text>(
kj::str("imported-", response.getSturdyRef()));
});
}
kj::Promise<void> export_(ExportContext context) override {
auto cap = context.getParams().getCap();
context.releaseParams();
return cap.saveRequest().send()
.then([context](Response<Persistent<test::TestSturdyRef>::SaveResults> response) mutable {
context.getResults().setSturdyRef(kj::str("exported-",
response.getSturdyRef().getObjectId().getAs<Text>()));
});
}
};
class TestPersistent: public Persistent<test::TestSturdyRef>::Server {
public:
TestPersistent(kj::StringPtr name): name(name) {}
kj::Promise<void> save(SaveContext context) override {
context.initResults().initSturdyRef().getObjectId().setAs<Text>(name);
return kj::READY_NOW;
}
private:
kj::StringPtr name;
};
class TestPersistentText: public Persistent<Text>::Server {
public:
TestPersistentText(kj::StringPtr name): name(name) {}
kj::Promise<void> save(SaveContext context) override {
context.initResults().setSturdyRef(name);
return kj::READY_NOW;
}
private:
kj::StringPtr name;
};
TEST(Rpc, RealmGatewayImport) {
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<Text>::Client bootstrap = kj::heap<TestPersistentText>("foo");
MallocMessageBuilder hostIdBuilder;
auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
TestContext context(bootstrap, gateway);
auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<test::TestSturdyRef>>();
auto response = client.saveRequest().send().wait(context.waitScope);
EXPECT_EQ("imported-foo", response.getSturdyRef().getObjectId().getAs<Text>());
}
TEST(Rpc, RealmGatewayExport) {
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<test::TestSturdyRef>::Client bootstrap = kj::heap<TestPersistent>("foo");
MallocMessageBuilder hostIdBuilder;
auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
TestContext context(bootstrap, gateway, true);
auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<Text>>();
auto response = client.saveRequest().send().wait(context.waitScope);
EXPECT_EQ("exported-foo", response.getSturdyRef());
}
} // namespace } // namespace
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -592,7 +592,7 @@ private: ...@@ -592,7 +592,7 @@ private:
}); });
auto request = g->importRequest(sizeHint); auto request = g->importRequest(sizeHint);
request.setCap(Persistent<>::Client(addRef())); request.setCap(Persistent<>::Client(kj::refcounted<NoInterceptClient>(*this)));
// Awkwardly, request.initParams() would return a SaveParams struct, but to construct // Awkwardly, request.initParams() would return a SaveParams struct, but to construct
// the Request<AnyPointer, AnyPointer> to return we need an AnyPointer::Builder, and you // the Request<AnyPointer, AnyPointer> to return we need an AnyPointer::Builder, and you
...@@ -608,6 +608,11 @@ private: ...@@ -608,6 +608,11 @@ private:
} }
} }
return newCallNoIntercept(interfaceId, methodId, sizeHint);
}
Request<AnyPointer, AnyPointer> newCallNoIntercept(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) {
if (!connectionState->connection.is<Connected>()) { if (!connectionState->connection.is<Connected>()) {
return newBrokenRequest(kj::cp(connectionState->connection.get<Disconnected>()), sizeHint); return newBrokenRequest(kj::cp(connectionState->connection.get<Disconnected>()), sizeHint);
} }
...@@ -626,8 +631,6 @@ private: ...@@ -626,8 +631,6 @@ private:
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override { kj::Own<CallContextHook>&& context) override {
// Implement call() by copying params and results messages.
if (interfaceId == typeId<Persistent<>>() && methodId == 0) { if (interfaceId == typeId<Persistent<>>() && methodId == 0) {
KJ_IF_MAYBE(g, connectionState->gateway) { KJ_IF_MAYBE(g, connectionState->gateway) {
// Wait, this is a call to Persistent.save() and we need to translate it through our // Wait, this is a call to Persistent.save() and we need to translate it through our
...@@ -639,7 +642,7 @@ private: ...@@ -639,7 +642,7 @@ private:
requestSize.wordCount += sizeInWords<RealmGateway<>::ImportParams>(); requestSize.wordCount += sizeInWords<RealmGateway<>::ImportParams>();
auto request = g->importRequest(requestSize); auto request = g->importRequest(requestSize);
request.setCap(Persistent<>::Client(addRef())); request.setCap(Persistent<>::Client(kj::refcounted<NoInterceptClient>(*this)));
request.setParams(params); request.setParams(params);
context->allowCancellation(); context->allowCancellation();
...@@ -648,6 +651,13 @@ private: ...@@ -648,6 +651,13 @@ private:
} }
} }
return callNoIntercept(interfaceId, methodId, kj::mv(context));
}
VoidPromiseAndPipeline callNoIntercept(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) {
// Implement call() by copying params and results messages.
auto params = context->getParams(); auto params = context->getParams();
auto request = newCall(interfaceId, methodId, params.targetSize()); auto request = newCall(interfaceId, methodId, params.targetSize());
...@@ -667,7 +677,6 @@ private: ...@@ -667,7 +677,6 @@ private:
return connectionState.get(); return connectionState.get();
} }
protected:
kj::Own<RpcConnectionState> connectionState; kj::Own<RpcConnectionState> connectionState;
}; };
...@@ -933,6 +942,55 @@ private: ...@@ -933,6 +942,55 @@ private:
} }
}; };
class NoInterceptClient final: public RpcClient {
// A wrapper around an RpcClient which bypasses special handling of "save" requests. When we
// intercept a "save" request and invoke a RealmGateway, we give it a version of the capability
// with intercepting disabled, since usually the first thing the RealmGateway will do is turn
// around and call save() again.
//
// This is admittedly sort of backwards: the interception of "save" ought to be the part
// implemented by a wrapper. However, that would require placing a wrapper around every
// RpcClient we create whereas NoInterceptClient only needs to be injected after a save()
// request occurs and is intercepted.
public:
NoInterceptClient(RpcClient& inner)
: RpcClient(*inner.connectionState),
inner(kj::addRef(inner)) {}
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override {
return inner->writeDescriptor(descriptor);
}
kj::Maybe<kj::Own<ClientHook>> writeTarget(rpc::MessageTarget::Builder target) override {
return inner->writeTarget(target);
}
kj::Own<ClientHook> getInnermostClient() override {
return inner->getInnermostClient();
}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
return inner->newCallNoIntercept(interfaceId, methodId, sizeHint);
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return inner->callNoIntercept(interfaceId, methodId, kj::mv(context));
}
kj::Maybe<ClientHook&> getResolved() override {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return nullptr;
}
private:
kj::Own<RpcClient> inner;
};
kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) { kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) {
// Write a descriptor for the given capability. // Write a descriptor for the given capability.
......
...@@ -101,12 +101,12 @@ RpcSystem<VatId> makeRpcServer( ...@@ -101,12 +101,12 @@ RpcSystem<VatId> makeRpcServer(
// client-server RPC connection. // client-server RPC connection.
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult, typename ThirdPartyCapId, typename JoinResult, typename RealmGatewayClient,
typename InternalRef, typename ExternalRef> typename InternalRef = _::InternalRefFromRealmGatewayClient<RealmGatewayClient>,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer( RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface, Capability::Client bootstrapInterface, RealmGatewayClient gateway);
typename RealmGateway<InternalRef, ExternalRef>::Client gateway);
// Make an RPC server for a VatNetwork that resides in a different realm from the application. // Make an RPC server for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format // The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format. // and the network's ("external") format.
...@@ -147,11 +147,12 @@ RpcSystem<VatId> makeRpcClient( ...@@ -147,11 +147,12 @@ RpcSystem<VatId> makeRpcClient(
// client-server RPC connection. // client-server RPC connection.
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult, typename ThirdPartyCapId, typename JoinResult, typename RealmGatewayClient,
typename InternalRef, typename ExternalRef> typename InternalRef = _::InternalRefFromRealmGatewayClient<RealmGatewayClient>,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcClient( RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway); RealmGatewayClient gateway);
// Make an RPC client for a VatNetwork that resides in a different realm from the application. // Make an RPC client for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format // The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format. // and the network's ("external") format.
...@@ -366,11 +367,10 @@ RpcSystem<VatId> makeRpcServer( ...@@ -366,11 +367,10 @@ RpcSystem<VatId> makeRpcServer(
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult, typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef> typename RealmGatewayClient, typename InternalRef, typename ExternalRef>
RpcSystem<VatId> makeRpcServer( RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface, Capability::Client bootstrapInterface, RealmGatewayClient gateway) {
typename RealmGateway<InternalRef, ExternalRef>::Client gateway) {
return RpcSystem<VatId>(network, kj::mv(bootstrapInterface), return RpcSystem<VatId>(network, kj::mv(bootstrapInterface),
gateway.template castAs<RealmGateway<>>()); gateway.template castAs<RealmGateway<>>());
} }
...@@ -392,10 +392,10 @@ RpcSystem<VatId> makeRpcClient( ...@@ -392,10 +392,10 @@ RpcSystem<VatId> makeRpcClient(
template <typename VatId, typename ProvisionId, template <typename VatId, typename ProvisionId,
typename RecipientId, typename ThirdPartyCapId, typename JoinResult, typename RecipientId, typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef> typename RealmGatewayClient, typename InternalRef, typename ExternalRef>
RpcSystem<VatId> makeRpcClient( RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway) { RealmGatewayClient gateway) {
return RpcSystem<VatId>(network, nullptr, gateway.template castAs<RealmGateway<>>()); return RpcSystem<VatId>(network, nullptr, gateway.template castAs<RealmGateway<>>());
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment