Commit a181ec36 authored by Kenton Varda's avatar Kenton Varda

Add tests for RealmGateway, fix bugs.

parent 89b25093
......@@ -57,7 +57,9 @@ interface Persistent@0xc8cb212fcd9f5691(SturdyRef) {
# capabilities can be saved -- application interfaces should define which capabilities support
# this and which do not.
struct SaveParams {}
struct SaveParams {
# Nothing for now.
}
struct SaveResults {
sturdyRef @0 :SturdyRef;
}
......
......@@ -85,6 +85,26 @@ private:
friend class capnp::RpcSystem;
};
template <typename T> struct InternalRefFromRealmGateway_;
template <typename InternalRef, typename ExternalRef>
struct InternalRefFromRealmGateway_<RealmGateway<InternalRef, ExternalRef>> {
typedef InternalRef Type;
};
template <typename T>
using InternalRefFromRealmGateway = typename InternalRefFromRealmGateway_<T>::Type;
template <typename T>
using InternalRefFromRealmGatewayClient = InternalRefFromRealmGateway<typename T::Calls>;
template <typename T> struct ExternalRefFromRealmGateway_;
template <typename InternalRef, typename ExternalRef>
struct ExternalRefFromRealmGateway_<RealmGateway<InternalRef, ExternalRef>> {
typedef ExternalRef Type;
};
template <typename T>
using ExternalRefFromRealmGateway = typename ExternalRefFromRealmGateway_<T>::Type;
template <typename T>
using ExternalRefFromRealmGatewayClient = ExternalRefFromRealmGateway<typename T::Calls>;
} // namespace _ (private)
} // namespace capnp
......
......@@ -427,6 +427,21 @@ struct TestContext {
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, restorer)) {}
TestContext(Capability::Client bootstrap,
RealmGateway<test::TestSturdyRef, Text>::Client gateway)
: waitScope(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, gateway)),
rpcServer(makeRpcServer(serverNetwork, bootstrap)) {}
TestContext(Capability::Client bootstrap,
RealmGateway<test::TestSturdyRef, Text>::Client gateway,
bool)
: waitScope(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, bootstrap, gateway)) {}
Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) {
MallocMessageBuilder refMessage(128);
......@@ -962,6 +977,91 @@ TEST(Rpc, CallBrokenPromise) {
getCallSequence(client, 1).wait(context.waitScope);
}
// =======================================================================================
typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway;
class TestGateway: public TestRealmGateway::Server {
public:
kj::Promise<void> import(ImportContext context) override {
auto cap = context.getParams().getCap();
context.releaseParams();
return cap.saveRequest().send()
.then([context](Response<Persistent<Text>::SaveResults> response) mutable {
context.getResults().initSturdyRef().getObjectId().setAs<Text>(
kj::str("imported-", response.getSturdyRef()));
});
}
kj::Promise<void> export_(ExportContext context) override {
auto cap = context.getParams().getCap();
context.releaseParams();
return cap.saveRequest().send()
.then([context](Response<Persistent<test::TestSturdyRef>::SaveResults> response) mutable {
context.getResults().setSturdyRef(kj::str("exported-",
response.getSturdyRef().getObjectId().getAs<Text>()));
});
}
};
class TestPersistent: public Persistent<test::TestSturdyRef>::Server {
public:
TestPersistent(kj::StringPtr name): name(name) {}
kj::Promise<void> save(SaveContext context) override {
context.initResults().initSturdyRef().getObjectId().setAs<Text>(name);
return kj::READY_NOW;
}
private:
kj::StringPtr name;
};
class TestPersistentText: public Persistent<Text>::Server {
public:
TestPersistentText(kj::StringPtr name): name(name) {}
kj::Promise<void> save(SaveContext context) override {
context.initResults().setSturdyRef(name);
return kj::READY_NOW;
}
private:
kj::StringPtr name;
};
TEST(Rpc, RealmGatewayImport) {
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<Text>::Client bootstrap = kj::heap<TestPersistentText>("foo");
MallocMessageBuilder hostIdBuilder;
auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
TestContext context(bootstrap, gateway);
auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<test::TestSturdyRef>>();
auto response = client.saveRequest().send().wait(context.waitScope);
EXPECT_EQ("imported-foo", response.getSturdyRef().getObjectId().getAs<Text>());
}
TEST(Rpc, RealmGatewayExport) {
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<test::TestSturdyRef>::Client bootstrap = kj::heap<TestPersistent>("foo");
MallocMessageBuilder hostIdBuilder;
auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
TestContext context(bootstrap, gateway, true);
auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<Text>>();
auto response = client.saveRequest().send().wait(context.waitScope);
EXPECT_EQ("exported-foo", response.getSturdyRef());
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -592,7 +592,7 @@ private:
});
auto request = g->importRequest(sizeHint);
request.setCap(Persistent<>::Client(addRef()));
request.setCap(Persistent<>::Client(kj::refcounted<NoInterceptClient>(*this)));
// Awkwardly, request.initParams() would return a SaveParams struct, but to construct
// the Request<AnyPointer, AnyPointer> to return we need an AnyPointer::Builder, and you
......@@ -608,6 +608,11 @@ private:
}
}
return newCallNoIntercept(interfaceId, methodId, sizeHint);
}
Request<AnyPointer, AnyPointer> newCallNoIntercept(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) {
if (!connectionState->connection.is<Connected>()) {
return newBrokenRequest(kj::cp(connectionState->connection.get<Disconnected>()), sizeHint);
}
......@@ -626,8 +631,6 @@ private:
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
// Implement call() by copying params and results messages.
if (interfaceId == typeId<Persistent<>>() && methodId == 0) {
KJ_IF_MAYBE(g, connectionState->gateway) {
// Wait, this is a call to Persistent.save() and we need to translate it through our
......@@ -639,7 +642,7 @@ private:
requestSize.wordCount += sizeInWords<RealmGateway<>::ImportParams>();
auto request = g->importRequest(requestSize);
request.setCap(Persistent<>::Client(addRef()));
request.setCap(Persistent<>::Client(kj::refcounted<NoInterceptClient>(*this)));
request.setParams(params);
context->allowCancellation();
......@@ -648,6 +651,13 @@ private:
}
}
return callNoIntercept(interfaceId, methodId, kj::mv(context));
}
VoidPromiseAndPipeline callNoIntercept(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) {
// Implement call() by copying params and results messages.
auto params = context->getParams();
auto request = newCall(interfaceId, methodId, params.targetSize());
......@@ -667,7 +677,6 @@ private:
return connectionState.get();
}
protected:
kj::Own<RpcConnectionState> connectionState;
};
......@@ -933,6 +942,55 @@ private:
}
};
class NoInterceptClient final: public RpcClient {
// A wrapper around an RpcClient which bypasses special handling of "save" requests. When we
// intercept a "save" request and invoke a RealmGateway, we give it a version of the capability
// with intercepting disabled, since usually the first thing the RealmGateway will do is turn
// around and call save() again.
//
// This is admittedly sort of backwards: the interception of "save" ought to be the part
// implemented by a wrapper. However, that would require placing a wrapper around every
// RpcClient we create whereas NoInterceptClient only needs to be injected after a save()
// request occurs and is intercepted.
public:
NoInterceptClient(RpcClient& inner)
: RpcClient(*inner.connectionState),
inner(kj::addRef(inner)) {}
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override {
return inner->writeDescriptor(descriptor);
}
kj::Maybe<kj::Own<ClientHook>> writeTarget(rpc::MessageTarget::Builder target) override {
return inner->writeTarget(target);
}
kj::Own<ClientHook> getInnermostClient() override {
return inner->getInnermostClient();
}
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
return inner->newCallNoIntercept(interfaceId, methodId, sizeHint);
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
return inner->callNoIntercept(interfaceId, methodId, kj::mv(context));
}
kj::Maybe<ClientHook&> getResolved() override {
return nullptr;
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return nullptr;
}
private:
kj::Own<RpcClient> inner;
};
kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) {
// Write a descriptor for the given capability.
......
......@@ -101,12 +101,12 @@ RpcSystem<VatId> makeRpcServer(
// client-server RPC connection.
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef>
typename ThirdPartyCapId, typename JoinResult, typename RealmGatewayClient,
typename InternalRef = _::InternalRefFromRealmGatewayClient<RealmGatewayClient>,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway);
Capability::Client bootstrapInterface, RealmGatewayClient gateway);
// Make an RPC server for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format.
......@@ -147,11 +147,12 @@ RpcSystem<VatId> makeRpcClient(
// client-server RPC connection.
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef>
typename ThirdPartyCapId, typename JoinResult, typename RealmGatewayClient,
typename InternalRef = _::InternalRefFromRealmGatewayClient<RealmGatewayClient>,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway);
RealmGatewayClient gateway);
// Make an RPC client for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format.
......@@ -366,11 +367,10 @@ RpcSystem<VatId> makeRpcServer(
template <typename VatId, typename ProvisionId, typename RecipientId,
typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef>
typename RealmGatewayClient, typename InternalRef, typename ExternalRef>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway) {
Capability::Client bootstrapInterface, RealmGatewayClient gateway) {
return RpcSystem<VatId>(network, kj::mv(bootstrapInterface),
gateway.template castAs<RealmGateway<>>());
}
......@@ -392,10 +392,10 @@ RpcSystem<VatId> makeRpcClient(
template <typename VatId, typename ProvisionId,
typename RecipientId, typename ThirdPartyCapId, typename JoinResult,
typename InternalRef, typename ExternalRef>
typename RealmGatewayClient, typename InternalRef, typename ExternalRef>
RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
typename RealmGateway<InternalRef, ExternalRef>::Client gateway) {
RealmGatewayClient gateway) {
return RpcSystem<VatId>(network, nullptr, gateway.template castAs<RealmGateway<>>());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment