Unverified Commit b2fe9c54 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #911 from capnproto/fix-self-bootstrap

Fix RPC loopback bootstrap().
parents 7c35ceb9 77a57f8c
...@@ -198,7 +198,7 @@ typedef VatNetwork< ...@@ -198,7 +198,7 @@ typedef VatNetwork<
class TestNetworkAdapter final: public TestNetworkAdapterBase { class TestNetworkAdapter final: public TestNetworkAdapterBase {
public: public:
TestNetworkAdapter(TestNetwork& network): network(network) {} TestNetworkAdapter(TestNetwork& network, kj::StringPtr self): network(network), self(self) {}
~TestNetworkAdapter() { ~TestNetworkAdapter() {
kj::Exception exception = KJ_EXCEPTION(FAILED, "Network was destroyed."); kj::Exception exception = KJ_EXCEPTION(FAILED, "Network was destroyed.");
...@@ -362,6 +362,10 @@ public: ...@@ -362,6 +362,10 @@ public:
}; };
kj::Maybe<kj::Own<Connection>> connect(test::TestSturdyRefHostId::Reader hostId) override { kj::Maybe<kj::Own<Connection>> connect(test::TestSturdyRefHostId::Reader hostId) override {
if (hostId.getHost() == self) {
return nullptr;
}
TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost())); TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost()));
auto iter = connections.find(&dst); auto iter = connections.find(&dst);
...@@ -400,6 +404,7 @@ public: ...@@ -400,6 +404,7 @@ public:
private: private:
TestNetwork& network; TestNetwork& network;
kj::StringPtr self;
uint sent = 0; uint sent = 0;
uint received = 0; uint received = 0;
...@@ -411,7 +416,7 @@ private: ...@@ -411,7 +416,7 @@ private:
TestNetwork::~TestNetwork() noexcept(false) {} TestNetwork::~TestNetwork() noexcept(false) {}
TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) { TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) {
return *(map[name] = kj::heap<TestNetworkAdapter>(*this)); return *(map[name] = kj::heap<TestNetworkAdapter>(*this, name));
} }
// ======================================================================================= // =======================================================================================
...@@ -456,6 +461,12 @@ struct TestContext { ...@@ -456,6 +461,12 @@ struct TestContext {
serverNetwork(network.add("server")), serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)), rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, restorer)) {} rpcServer(makeRpcServer(serverNetwork, restorer)) {}
TestContext(Capability::Client bootstrap)
: waitScope(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork)),
rpcServer(makeRpcServer(serverNetwork, bootstrap)) {}
TestContext(Capability::Client bootstrap, TestContext(Capability::Client bootstrap,
RealmGateway<test::TestSturdyRef, Text>::Client gateway) RealmGateway<test::TestSturdyRef, Text>::Client gateway)
: waitScope(loop), : waitScope(loop),
...@@ -1264,6 +1275,26 @@ TEST(Rpc, RealmGatewayImportExport) { ...@@ -1264,6 +1275,26 @@ TEST(Rpc, RealmGatewayImportExport) {
EXPECT_EQ("foo", response.getSturdyRef()); EXPECT_EQ("foo", response.getSturdyRef());
} }
KJ_TEST("loopback bootstrap()") {
int callCount = 0;
test::TestInterface::Client bootstrap = kj::heap<TestInterfaceImpl>(callCount);
MallocMessageBuilder hostIdBuilder;
auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
TestContext context(bootstrap);
auto client = context.rpcServer.bootstrap(hostId).castAs<test::TestInterface>();
auto request = client.fooRequest();
request.setI(123);
request.setJ(true);
auto response = request.send().wait(context.waitScope);
KJ_EXPECT(response.getX() == "foo");
KJ_EXPECT(callCount == 1);
}
} // namespace } // namespace
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -2999,11 +2999,16 @@ public: ...@@ -2999,11 +2999,16 @@ public:
KJ_IF_MAYBE(connection, network.baseConnect(vatId)) { KJ_IF_MAYBE(connection, network.baseConnect(vatId)) {
auto& state = getConnectionState(kj::mv(*connection)); auto& state = getConnectionState(kj::mv(*connection));
return Capability::Client(state.restore(objectId)); return Capability::Client(state.restore(objectId));
} else if (objectId.isNull()) {
// Turns out `vatId` refers to ourselves, so we can also pass it as the client ID for
// baseCreateFor().
return bootstrapFactory.baseCreateFor(vatId);
} else KJ_IF_MAYBE(r, restorer) { } else KJ_IF_MAYBE(r, restorer) {
return r->baseRestore(objectId); return r->baseRestore(objectId);
} else { } else {
return Capability::Client(newBrokenCap( return Capability::Client(newBrokenCap(
"SturdyRef referred to a local object but there is no local SturdyRef restorer.")); "This vat only supports a bootstrap interface, not the old Cap'n-Proto-0.4-style "
"named exports."));
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment