// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "rpc.h"
#include "test-util.h"
#include "schema.h"
#include "serialize.h"
#include <kj/debug.h>
#include <kj/string-tree.h>
#include <kj/compat/gtest.h>
#include <capnp/rpc.capnp.h>
#include <map>
#include <queue>

// TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp {
namespace rpc {
  inline kj::String KJ_STRINGIFY(Message::Which which) {
    // Stringifies the union discriminant as its numeric (uint16_t) value.
    return kj::str(static_cast<uint16_t>(which));
  }
}  // namespace rpc
}  // namespace capnp

namespace capnp {
namespace _ {  // private
namespace {

class RpcDumper {
  // Class which stringifies RPC messages for debugging purposes, including decoding params and
  // results based on the call's interface and method IDs and extracting cap descriptors.
  //
  // TODO(cleanup): Clean this code up and move it to someplace reusable, so it can be used as
  //   a packet inspector / debugging tool for Cap'n Proto network traffic.

public:
  void addSchema(InterfaceSchema schema) {
    // Registers `schema`, keyed by its type ID, so that dump() can decode calls made on that
    // interface.
    schemas[schema.getProto().getId()] = schema;
  }

  enum Sender {
    // Identifies which side of the connection sent the message being dumped.
    CLIENT,
    SERVER
  };

  kj::String dump(rpc::Message::Reader message, Sender sender) {
    // Returns a human-readable rendering of `message`, prefixed with the sender's name.
    //
    // Call, Return and Bootstrap messages get a decoded rendering when enough type information
    // is available. Every other case `break`s out of the switch and falls through to the generic
    // stringification at the bottom.
    //
    // Note that this method mutates `returnTypes`: calls and bootstraps record the expected
    // result type, and the matching return consumes it.

    const char* senderName = sender == CLIENT ? "client" : "server";

    switch (message.which()) {
      case rpc::Message::CALL: {
        auto call = message.getCall();
        auto iter = schemas.find(call.getInterfaceId());
        if (iter == schemas.end()) {
          // Unknown interface; use the generic rendering.
          break;
        }
        InterfaceSchema schema = iter->second;
        auto methods = schema.getMethods();
        if (call.getMethodId() >= methods.size()) {
          // Method ordinal is out of range for the registered schema.
          break;
        }
        InterfaceSchema::Method method = methods[call.getMethodId()];

        auto schemaProto = schema.getProto();
        // Drop the prefix of the display name, leaving the part after
        // getDisplayNamePrefixLength().
        auto interfaceName =
            schemaProto.getDisplayName().slice(schemaProto.getDisplayNamePrefixLength());

        auto methodProto = method.getProto();
        auto paramType = method.getParamType();
        auto resultType = method.getResultType();

        if (call.getSendResultsTo().isCaller()) {
          // Remember the result type, keyed by (caller, question ID), so that the corresponding
          // Return can be decoded later.
          returnTypes[std::make_pair(sender, call.getQuestionId())] = resultType;
        }

        auto payload = call.getParams();
        auto params = kj::str(payload.getContent().getAs<DynamicStruct>(paramType));

        auto sendResultsTo = call.getSendResultsTo();

        return kj::str(senderName, "(", call.getQuestionId(), "): call ",
                       call.getTarget(), " <- ", interfaceName, ".",
                       methodProto.getName(), params,
                       " caps:[", kj::strArray(payload.getCapTable(), ", "), "]",
                       sendResultsTo.isCaller() ? kj::str()
                                                : kj::str(" sendResultsTo:", sendResultsTo));
      }

      case rpc::Message::RETURN: {
        auto ret = message.getReturn();

        // The entry was recorded under the *other* side's identity, since the return travels in
        // the opposite direction from the call.
        auto iter = returnTypes.find(
            std::make_pair(sender == CLIENT ? SERVER : CLIENT, ret.getAnswerId()));
        if (iter == returnTypes.end()) {
          break;
        }

        auto schema = iter->second;
        returnTypes.erase(iter);
        if (ret.which() != rpc::Return::RESULTS) {
          // Oops, no results returned.  We don't check this earlier because we want to make sure
          // returnTypes.erase() gets a chance to happen.
          break;
        }

        auto payload = ret.getResults();

        if (schema.getProto().isStruct()) {
          // Result of an ordinary call: decode the content as the method's result struct.
          auto results = kj::str(payload.getContent().getAs<DynamicStruct>(schema.asStruct()));

          return kj::str(senderName, "(", ret.getAnswerId(), "): return ", results,
                         " caps:[", kj::strArray(payload.getCapTable(), ", "), "]");
        } else if (schema.getProto().isInterface()) {
          // Result recorded with an interface schema (the BOOTSTRAP case below records one).
          // The value returned by getAs() is discarded; only the cap table is printed.
          payload.getContent().getAs<DynamicCapability>(schema.asInterface());
          return kj::str(senderName, "(", ret.getAnswerId(), "): return cap ",
                         kj::strArray(payload.getCapTable(), ", "));
        } else {
          break;
        }
      }

      case rpc::Message::BOOTSTRAP: {
        auto restore = message.getBootstrap();

        // Record a default-constructed InterfaceSchema as the expected result type for this
        // question.
        returnTypes[std::make_pair(sender, restore.getQuestionId())] = InterfaceSchema();

        return kj::str(senderName, "(", restore.getQuestionId(), "): bootstrap ",
                       restore.getDeprecatedObjectId().getAs<test::TestSturdyRefObjectId>());
      }

      default:
        break;
    }

    // Generic fallback: stringify the whole message.
    return kj::str(senderName, ": ", message);
  }

private:
  std::map<uint64_t, InterfaceSchema> schemas;
  // Registered interface schemas, keyed by type ID.

  std::map<std::pair<Sender, uint32_t>, Schema> returnTypes;
  // Expected result type of each outstanding question, keyed by (sender of the question,
  // question ID). Entries are added by dump() on Call/Bootstrap and removed on Return.
};

// =======================================================================================

class TestNetworkAdapter;

class TestNetwork {
  // An in-process "network" for tests: a registry of named TestNetworkAdapters (one per vat)
  // that find each other by name.

public:
  TestNetwork() {
    // Register every test interface with the dumper so that debug dumps can decode calls.
    dumper.addSchema(Schema::from<test::TestInterface>());
    dumper.addSchema(Schema::from<test::TestExtends>());
    dumper.addSchema(Schema::from<test::TestPipeline>());
    dumper.addSchema(Schema::from<test::TestCallOrder>());
    dumper.addSchema(Schema::from<test::TestTailCallee>());
    dumper.addSchema(Schema::from<test::TestTailCaller>());
    dumper.addSchema(Schema::from<test::TestMoreStuff>());
  }
  ~TestNetwork() noexcept(false);

  TestNetworkAdapter& add(kj::StringPtr name);
  // Creates a new adapter registered under `name` and returns a reference to it. The adapter is
  // owned by this TestNetwork.

  kj::Maybe<TestNetworkAdapter&> find(kj::StringPtr name) {
    // Looks up the adapter registered under `name`; returns nullptr if there is none.
    auto iter = map.find(name);
    if (iter == map.end()) {
      return nullptr;
    } else {
      return *iter->second;
    }
  }

  RpcDumper dumper;

private:
  std::map<kj::StringPtr, kj::Own<TestNetworkAdapter>> map;
  // NOTE(review): keys are kj::StringPtr, which presumably does not own its text, so names
  // passed to add() must outlive the network (the tests pass string literals) -- confirm.
};

typedef VatNetwork<
    test::TestSturdyRefHostId, test::TestProvisionId, test::TestRecipientId,
    test::TestThirdPartyCapId, test::TestJoinResult> TestNetworkAdapterBase;

class TestNetworkAdapter final: public TestNetworkAdapterBase {
  // VatNetwork implementation representing one vat on a TestNetwork. Connections are pairs of
  // ConnectionImpl objects living in the same process; messages are copied from one end to the
  // other via the event loop.

public:
  TestNetworkAdapter(TestNetwork& network): network(network) {}

  ~TestNetworkAdapter() {
    // Disconnect every connection this adapter knows about, so that pending receives are
    // rejected with the exception below.
    kj::Exception exception = KJ_EXCEPTION(FAILED, "Network was destroyed.");
    for (auto& entry: connections) {
      entry.second->disconnect(kj::cp(exception));
    }
  }

  uint getSentCount() { return sent; }
  // Number of messages sent through this adapter's connections (incremented in send()).

  uint getReceivedCount() { return received; }
  // Number of messages handed to receivers on this adapter's connections.

  typedef TestNetworkAdapterBase::Connection Connection;

  class ConnectionImpl final
      : public Connection, public kj::Refcounted, public kj::TaskSet::ErrorHandler {
    // One end of an in-process connection. Two ConnectionImpls are linked to each other with
    // attach(); a message sent on one end is queued for (or delivered directly to) its partner.

  public:
    ConnectionImpl(TestNetworkAdapter& network, RpcDumper::Sender sender)
        : network(network), sender(sender), tasks(kj::heap<kj::TaskSet>(*this)) {}

    void attach(ConnectionImpl& other) {
      // Links this connection and `other` as each other's partner. Neither may already have
      // a partner.
      KJ_REQUIRE(partner == nullptr);
      KJ_REQUIRE(other.partner == nullptr);
      partner = other;
      other.partner = *this;
    }

    void disconnect(kj::Exception&& exception) {
      // Rejects all pending receives with a copy of `exception`, records it so later sends are
      // dropped and later receives fail, and discards the task set.
      while (!fulfillers.empty()) {
        fulfillers.front()->reject(kj::cp(exception));
        fulfillers.pop();
      }

      networkException = kj::mv(exception);

      tasks = nullptr;
    }

    class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted {
      // Incoming message backed by a flat word array that this object owns.

    public:
      IncomingRpcMessageImpl(kj::Array<word> data)
          : data(kj::mv(data)),
            message(this->data) {}
      // `message` reads from the member `this->data` (not the moved-from parameter), which is
      // why `data` is declared before `message` below.

      AnyPointer::Reader getBody() override {
        return message.getRoot<AnyPointer>();
      }

      void initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>>&& capTable) override {
        message.initCapTable(kj::mv(capTable));
      }

      kj::Array<word> data;
      FlatArrayMessageReader message;
    };

    class OutgoingRpcMessageImpl final: public OutgoingRpcMessage {
      // Outgoing message built in a MallocMessageBuilder; send() copies it to the partner.

    public:
      OutgoingRpcMessageImpl(ConnectionImpl& connection, uint firstSegmentWordSize)
          : connection(connection),
            message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS
                                              : firstSegmentWordSize) {}
      // A `firstSegmentWordSize` of zero selects SUGGESTED_FIRST_SEGMENT_WORDS.

      AnyPointer::Builder getBody() override {
        return message.getRoot<AnyPointer>();
      }

      kj::ArrayPtr<kj::Maybe<kj::Own<ClientHook>>> getCapTable() override {
        return message.getCapTable();
      }

      void send() override {
        // Copies the message into a flat array and schedules its delivery to the partner
        // connection on the event loop. Silently does nothing if the connection has been
        // disconnected.

        if (connection.networkException != nullptr) {
          return;
        }

        ++connection.network.sent;

        // Uncomment to get a debug dump.
//        kj::String msg = connection.network.network.dumper.dump(
//            message.getRoot<rpc::Message>(), connection.sender);
//        KJ_ DBG(msg);

        auto incomingMessage = kj::heap<IncomingRpcMessageImpl>(messageToFlatArray(message));

        // Capture a raw pointer to the connection; the task is added to the connection's own
        // TaskSet.
        auto connectionPtr = &connection;
        connection.tasks->add(kj::evalLater(kj::mvCapture(incomingMessage,
            [connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
          KJ_IF_MAYBE(p, connectionPtr->partner) {
            if (p->fulfillers.empty()) {
              // Nobody is waiting on the partner; queue the message for a later receive. The
              // `received` counter is bumped when it is dequeued.
              p->messages.push(kj::mv(message));
            } else {
              // Hand the message straight to the oldest pending receive.
              ++p->network.received;
              p->fulfillers.front()->fulfill(
                  kj::Own<IncomingRpcMessage>(kj::mv(message)));
              p->fulfillers.pop();
            }
          }
        })));
      }

    private:
      ConnectionImpl& connection;
      MallocMessageBuilder message;
    };

    kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override {
      return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
    }
    kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override {
      // Returns the next queued message if there is one. Otherwise:
      // - if disconnect() has been called, fails with (a copy of) the recorded exception;
      // - if `fulfillOnEnd` is set (by the partner's shutdown()), fulfills it and returns
      //   nullptr to signal end-of-stream;
      // - else returns a promise that is fulfilled when the partner sends a message.

      KJ_IF_MAYBE(e, networkException) {
        return kj::cp(*e);
      }

      if (messages.empty()) {
        KJ_IF_MAYBE(f, fulfillOnEnd) {
          f->get()->fulfill();
          return kj::Maybe<kj::Own<IncomingRpcMessage>>(nullptr);
        } else {
          auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
          fulfillers.push(kj::mv(paf.fulfiller));
          return kj::mv(paf.promise);
        }
      } else {
        ++network.received;
        auto result = kj::mv(messages.front());
        messages.pop();
        return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
      }
    }
    kj::Promise<void> shutdown() override {
      // If there is a partner, installs a fulfiller as the partner's `fulfillOnEnd` and returns
      // the corresponding promise; the partner fulfills it from receiveIncomingMessage() once
      // its message queue is empty. With no partner, completes immediately.
      KJ_IF_MAYBE(p, partner) {
        auto paf = kj::newPromiseAndFulfiller<void>();
        p->fulfillOnEnd = kj::mv(paf.fulfiller);
        return kj::mv(paf.promise);
      } else {
        return kj::READY_NOW;
      }
    }

    void taskFailed(kj::Exception&& exception) override {
      // TaskSet error handler: report a failed delivery task as a test failure.
      ADD_FAILURE() << kj::str(exception).cStr();
    }

  private:
    TestNetworkAdapter& network;
    RpcDumper::Sender sender KJ_UNUSED_MEMBER;
    // Only referenced by the commented-out debug dump in send().

    kj::Maybe<ConnectionImpl&> partner;
    // The other end of the connection, set by attach().

    kj::Maybe<kj::Exception> networkException;
    // Set by disconnect(); non-null means the connection is dead.

    std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
    // Pending receiveIncomingMessage() calls, oldest first.

    std::queue<kj::Own<IncomingRpcMessage>> messages;
    // Messages delivered while no receive was pending, oldest first.

    kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> fulfillOnEnd;
    // Set by the partner's shutdown(); fulfilled when this end reports end-of-stream.

    kj::Own<kj::TaskSet> tasks;
    // Holds the scheduled message-delivery tasks created by send().
  };

  kj::Maybe<kj::Own<Connection>> connect(test::TestSturdyRefHostId::Reader hostId) override {
    // Returns a connection to the adapter named by `hostId`. The first call for a given
    // destination creates a linked pair of ConnectionImpls, registers them on both adapters and
    // hands the remote end to the destination's accept(); later calls return a new reference to
    // the existing local end. Fails (KJ_REQUIRE_NONNULL) if no adapter has that name.

    TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost()));

    auto iter = connections.find(&dst);
    if (iter == connections.end()) {
      auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT);
      auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER);
      local->attach(*remote);

      connections[&dst] = kj::addRef(*local);
      dst.connections[this] = kj::addRef(*remote);

      if (dst.fulfillerQueue.empty()) {
        // Nobody is currently accepting; queue the connection for a later accept().
        dst.connectionQueue.push(kj::mv(remote));
      } else {
        dst.fulfillerQueue.front()->fulfill(kj::mv(remote));
        dst.fulfillerQueue.pop();
      }

      return kj::Own<Connection>(kj::mv(local));
    } else {
      return kj::Own<Connection>(kj::addRef(*iter->second));
    }
  }

  kj::Promise<kj::Own<Connection>> accept() override {
    // Returns the next queued incoming connection, or a promise for the next one created by
    // a peer's connect().
    if (connectionQueue.empty()) {
      auto paf = kj::newPromiseAndFulfiller<kj::Own<Connection>>();
      fulfillerQueue.push(kj::mv(paf.fulfiller));
      return kj::mv(paf.promise);
    } else {
      auto result = kj::mv(connectionQueue.front());
      connectionQueue.pop();
      return kj::mv(result);
    }
  }

private:
  TestNetwork& network;
  uint sent = 0;
  uint received = 0;

  std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections;
  // This adapter's end of each connection, keyed by the adapter at the other end.

  std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<Connection>>>> fulfillerQueue;
  // Pending accept() calls, oldest first.

  std::queue<kj::Own<Connection>> connectionQueue;
  // Incoming connections not yet accepted, oldest first.
};

// Defined out-of-line, after TestNetworkAdapter is complete, because TestNetwork holds
// kj::Own<TestNetworkAdapter> and the adapter type is only forward-declared above.
TestNetwork::~TestNetwork() noexcept(false) {}

TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) {
  // Note: assigning to map[name] replaces any adapter previously registered under that name.
  return *(map[name] = kj::heap<TestNetworkAdapter>(*this));
}

// =======================================================================================

class TestRestorer final: public SturdyRefRestorer<test::TestSturdyRefObjectId> {
  // Server-side restorer: maps the tag in a TestSturdyRefObjectId to a freshly created test
  // capability. The counters are passed by reference into the created implementations so the
  // tests can observe server-side activity.

public:
  int callCount = 0;
  int handleCount = 0;

  Capability::Client restore(test::TestSturdyRefObjectId::Reader objectId) override {
    switch (objectId.getTag()) {
      case test::TestSturdyRefObjectId::Tag::TEST_INTERFACE:
        return kj::heap<TestInterfaceImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_EXTENDS:
        return Capability::Client(newBrokenCap("No TestExtends implemented."));
      case test::TestSturdyRefObjectId::Tag::TEST_PIPELINE:
        return kj::heap<TestPipelineImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLEE:
        return kj::heap<TestTailCalleeImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
        return kj::heap<TestTailCallerImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
        return kj::heap<TestMoreStuffImpl>(callCount, handleCount);
    }
    KJ_UNREACHABLE;
  }
};

struct TestContext {
  // Common fixture: an event loop, a TestNetwork with a "client" and a "server" vat, and an
  // RpcSystem on each.
  //
  // Member order matters: the constructors' initializers use `loop` and `network` to build the
  // later members.

  kj::EventLoop loop;
  kj::WaitScope waitScope;
  TestNetwork network;
  TestRestorer restorer;
  TestNetworkAdapter& clientNetwork;
  TestNetworkAdapter& serverNetwork;
  RpcSystem<test::TestSturdyRefHostId> rpcClient;
  RpcSystem<test::TestSturdyRefHostId> rpcServer;

  TestContext()
      : waitScope(loop),
        clientNetwork(network.add("client")),
        serverNetwork(network.add("server")),
        rpcClient(makeRpcClient(clientNetwork)),
        rpcServer(makeRpcServer(serverNetwork, restorer)) {}
  // Default setup: the server exports capabilities through `restorer`.

  TestContext(Capability::Client bootstrap,
              RealmGateway<test::TestSturdyRef, Text>::Client gateway)
      : waitScope(loop),
        clientNetwork(network.add("client")),
        serverNetwork(network.add("server")),
        rpcClient(makeRpcClient(clientNetwork, gateway)),
        rpcServer(makeRpcServer(serverNetwork, bootstrap)) {}
  // Server exports `bootstrap`; the realm gateway is installed on the *client* side.

  TestContext(Capability::Client bootstrap,
              RealmGateway<test::TestSturdyRef, Text>::Client gateway,
              bool)
      : waitScope(loop),
        clientNetwork(network.add("client")),
        serverNetwork(network.add("server")),
        rpcClient(makeRpcClient(clientNetwork)),
        rpcServer(makeRpcServer(serverNetwork, bootstrap, gateway)) {}
  // Server exports `bootstrap`; the realm gateway is installed on the *server* side. The
  // unnamed bool parameter is unused and only selects this overload.

  Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) {
    // Restores, from the host named "server", the object identified by `tag`.
    MallocMessageBuilder refMessage(128);
    auto ref = refMessage.initRoot<test::TestSturdyRef>();
    auto hostId = ref.initHostId();
    hostId.setHost("server");
    ref.getObjectId().initAs<test::TestSturdyRefObjectId>().setTag(tag);

    return rpcClient.restore(hostId, ref.getObjectId());
  }
};

TEST(Rpc, Basic) {
  // Sends foo(), bar() and baz() on a TestInterface; expects foo and baz to complete and bar
  // to fail.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_INTERFACE)
      .castAs<test::TestInterface>();

  auto request1 = client.fooRequest();
  request1.setI(123);
  request1.setJ(true);
  auto promise1 = request1.send();

  // We used to call bar() after baz(), hence the numbering, but this masked the case where the
  // RPC system actually disconnected on bar() (thus returning an exception, which we decided
  // was expected).
  bool barFailed = false;
  auto request3 = client.barRequest();
  auto promise3 = request3.send().then(
      [](Response<test::TestInterface::BarResults>&& response) {
        ADD_FAILURE() << "Expected bar() call to fail.";
      }, [&](kj::Exception&& e) {
        barFailed = true;
      });

  auto request2 = client.bazRequest();
  initTestMessage(request2.initS());
  auto promise2 = request2.send();

  // No call has been counted by the server at this point; nothing has been waited on yet.
  EXPECT_EQ(0, context.restorer.callCount);

  auto response1 = promise1.wait(context.waitScope);

  EXPECT_EQ("foo", response1.getX());

  auto response2 = promise2.wait(context.waitScope);

  promise3.wait(context.waitScope);

  EXPECT_EQ(2, context.restorer.callCount);
  EXPECT_TRUE(barFailed);
}

TEST(Rpc, Pipelining) {
  // Makes pipelined calls on the capability inside a getCap() result that has not been waited
  // for, then drops the original promise before waiting on the pipelined calls.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE)
      .castAs<test::TestPipeline>();

  int chainedCallCount = 0;

  auto request = client.getCapRequest();
  request.setN(234);
  request.setInCap(kj::heap<TestInterfaceImpl>(chainedCallCount));

  auto promise = request.send();

  auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
  pipelineRequest.setI(321);
  auto pipelinePromise = pipelineRequest.send();

  auto pipelineRequest2 = promise.getOutBox().getCap()
      .castAs<test::TestExtends>().graultRequest();
  auto pipelinePromise2 = pipelineRequest2.send();

  promise = nullptr;  // Just to be annoying, drop the original promise.

  EXPECT_EQ(0, context.restorer.callCount);
  EXPECT_EQ(0, chainedCallCount);

  auto response = pipelinePromise.wait(context.waitScope);
  EXPECT_EQ("bar", response.getX());

  auto response2 = pipelinePromise2.wait(context.waitScope);
  checkTestMessage(response2);

  EXPECT_EQ(3, context.restorer.callCount);
  EXPECT_EQ(1, chainedCallCount);
}

TEST(Rpc, Release) {
  // Checks that the server-side handle count drops as the client releases its references.
  // Each loop of evalLater() waits turns the event loop 16 times before the count is checked.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  auto handle1 = client.getHandleRequest().send().wait(context.waitScope).getHandle();
  auto promise = client.getHandleRequest().send();
  auto handle2 = promise.wait(context.waitScope).getHandle();

  EXPECT_EQ(2, context.restorer.handleCount);

  handle1 = nullptr;

  for (uint i = 0; i < 16; i++) kj::evalLater([]() {}).wait(context.waitScope);
  EXPECT_EQ(1, context.restorer.handleCount);

  handle2 = nullptr;

  // The count is expected to stay at 1 here; it only reaches 0 after `promise` is dropped too.
  for (uint i = 0; i < 16; i++) kj::evalLater([]() {}).wait(context.waitScope);
  EXPECT_EQ(1, context.restorer.handleCount);

  promise = nullptr;

  for (uint i = 0; i < 16; i++) kj::evalLater([]() {}).wait(context.waitScope);
  EXPECT_EQ(0, context.restorer.handleCount);
}

TEST(Rpc, ReleaseOnCancel) {
  // At one time, there was a bug where if a Return contained capabilities, but the client had
  // canceled the request and already sent a Finish (which presumably didn't reach the server
  // before the Return), then we'd leak those caps. Test for that.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();
  client.whenResolved().wait(context.waitScope);

  {
    auto promise = client.getHandleRequest().send();

    // If the server receives cancellation too early, it won't even return a capability in the
    // results, it will just return "canceled". We want to emulate the case where the return
    // message and the cancel (finish) message cross paths. It turns out that exactly two
    // evalLater()s get us there.
    //
    // TODO(cleanup): This is fragile, but I'm not sure how else to write it without a ton
    //   of scaffolding.
    kj::evalLater([]() {}).wait(context.waitScope);
    kj::evalLater([]() {}).wait(context.waitScope);
  }

  for (uint i = 0; i < 16; i++) kj::evalLater([]() {}).wait(context.waitScope);
  EXPECT_EQ(0, context.restorer.handleCount);
}

TEST(Rpc, TailCall) {
  // Calls TestTailCaller.foo() with a locally implemented callee, and makes dependent calls on
  // the returned capability `c` both through the pipeline and through the final response.

  TestContext context;

  auto caller = context.connect(test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER)
      .castAs<test::TestTailCaller>();

  int calleeCallCount = 0;

  test::TestTailCallee::Client callee(kj::heap<TestTailCalleeImpl>(calleeCallCount));

  auto request = caller.fooRequest();
  request.setI(456);
  request.setCallee(callee);

  auto promise = request.send();

  auto dependentCall0 = promise.getC().getCallSequenceRequest().send();

  auto response = promise.wait(context.waitScope);
  EXPECT_EQ(456, response.getI());
  EXPECT_EQ("from TestTailCaller", response.getT());

  auto dependentCall1 = promise.getC().getCallSequenceRequest().send();

  auto dependentCall2 = response.getC().getCallSequenceRequest().send();

  // The three dependent calls are expected to be sequenced in the order they were sent.
  EXPECT_EQ(0, dependentCall0.wait(context.waitScope).getN());
  EXPECT_EQ(1, dependentCall1.wait(context.waitScope).getN());
  EXPECT_EQ(2, dependentCall2.wait(context.waitScope).getN());

  EXPECT_EQ(1, calleeCallCount);
  EXPECT_EQ(1, context.restorer.callCount);
}

TEST(Rpc, Cancelation) {
  // Tests allowCancellation().

  TestContext context;

  // `destroyed` becomes true once the TestCapDestructor created below fulfills `paf`.
  auto paf = kj::newPromiseAndFulfiller<void>();
  bool destroyed = false;
  auto destructionPromise = paf.promise.then([&]() { destroyed = true; }).eagerlyEvaluate(nullptr);

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  kj::Promise<void> promise = nullptr;

  bool returned = false;
  {
    auto request = client.expectCancelRequest();
    request.setCap(kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)));
    promise = request.send().then(
        [&](Response<test::TestMoreStuff::ExpectCancelResults>&& response) {
      returned = true;
    }).eagerlyEvaluate(nullptr);
  }
  // Turn the event loop several times before checking that nothing has happened yet.
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);

  // We can detect that the method was canceled because it will drop the cap.
  EXPECT_FALSE(destroyed);
  EXPECT_FALSE(returned);

  promise = nullptr;  // request cancellation
  destructionPromise.wait(context.waitScope);

  EXPECT_TRUE(destroyed);
  EXPECT_FALSE(returned);
}

TEST(Rpc, PromiseResolve) {
  // Passes capabilities backed by a not-yet-fulfilled local promise as call parameters, and
  // verifies that both calls complete once the promise is fulfilled.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  int chainedCallCount = 0;

  auto request = client.callFooRequest();
  auto request2 = client.callFooWhenResolvedRequest();

  auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();

  {
    // Both requests get a branch of the same forked promise.
    auto fork = paf.promise.fork();
    request.setCap(fork.addBranch());
    request2.setCap(fork.addBranch());
  }

  auto promise = request.send();
  auto promise2 = request2.send();

  // Make sure getCap() has been called on the server side by sending another call and waiting
  // for it.
  EXPECT_EQ(2, client.getCallSequenceRequest().send().wait(context.waitScope).getN());
  EXPECT_EQ(3, context.restorer.callCount);

  // OK, now fulfill the local promise.
  paf.fulfiller->fulfill(kj::heap<TestInterfaceImpl>(chainedCallCount));

  // We should now be able to wait for getCap() to finish.
  EXPECT_EQ("bar", promise.wait(context.waitScope).getS());
  EXPECT_EQ("bar", promise2.wait(context.waitScope).getS());

  EXPECT_EQ(3, context.restorer.callCount);
  EXPECT_EQ(2, chainedCallCount);
}

TEST(Rpc, RetainAndRelease) {
  // Has the server hold a client-provided capability, exercises it in several ways, and checks
  // that it is destroyed only after the client drops its reference to the server object.

  TestContext context;

  // `destroyed` becomes true once the TestCapDestructor created below fulfills `paf`.
  auto paf = kj::newPromiseAndFulfiller<void>();
  bool destroyed = false;
  auto destructionPromise = paf.promise.then([&]() { destroyed = true; }).eagerlyEvaluate(nullptr);

  {
    auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
        .castAs<test::TestMoreStuff>();

    {
      auto request = client.holdRequest();
      request.setCap(kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)));
      request.send().wait(context.waitScope);
    }

    // Do some other call to add a round trip.
    EXPECT_EQ(1, client.getCallSequenceRequest().send().wait(context.waitScope).getN());

    // Shouldn't be destroyed because it's being held by the server.
    EXPECT_FALSE(destroyed);

    // We can ask it to call the held capability.
    EXPECT_EQ("bar", client.callHeldRequest().send().wait(context.waitScope).getS());

    {
      // We can get the cap back from it.
      auto capCopy = client.getHeldRequest().send().wait(context.waitScope).getCap();

      {
        // And call it, without any network communications.
        uint oldSentCount = context.clientNetwork.getSentCount();
        auto request = capCopy.fooRequest();
        request.setI(123);
        request.setJ(true);
        EXPECT_EQ("foo", request.send().wait(context.waitScope).getX());
        EXPECT_EQ(oldSentCount, context.clientNetwork.getSentCount());
      }

      {
        // We can send another copy of the same cap to another method, and it works.
        auto request = client.callFooRequest();
        request.setCap(capCopy);
        EXPECT_EQ("bar", request.send().wait(context.waitScope).getS());
      }
    }

    // Give some time to settle.
    EXPECT_EQ(5, client.getCallSequenceRequest().send().wait(context.waitScope).getN());
    EXPECT_EQ(6, client.getCallSequenceRequest().send().wait(context.waitScope).getN());
    EXPECT_EQ(7, client.getCallSequenceRequest().send().wait(context.waitScope).getN());

    // Can't be destroyed, we haven't released it.
    EXPECT_FALSE(destroyed);
  }

  // We released our client, which should cause the server to be released, which in turn will
  // release the cap pointing back to us.
  destructionPromise.wait(context.waitScope);
  EXPECT_TRUE(destroyed);
}

TEST(Rpc, Cancel) {
  // Starts a call that never returns, then drops the response promise and the request, and
  // checks that the capability passed in the call's params is released as a result.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  // `destroyed` becomes true once the TestCapDestructor created below fulfills `paf`.
  auto paf = kj::newPromiseAndFulfiller<void>();
  bool destroyed = false;
  auto destructionPromise = paf.promise.then([&]() { destroyed = true; }).eagerlyEvaluate(nullptr);

  {
    auto request = client.neverReturnRequest();
    request.setCap(kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)));

    {
      auto responsePromise = request.send();

      // Allow some time to settle.
      EXPECT_EQ(1, client.getCallSequenceRequest().send().wait(context.waitScope).getN());
      EXPECT_EQ(2, client.getCallSequenceRequest().send().wait(context.waitScope).getN());

      // The cap shouldn't have been destroyed yet because the call never returned.
      EXPECT_FALSE(destroyed);
    }
  }

  // Now the cap should be released.
  destructionPromise.wait(context.waitScope);
  EXPECT_TRUE(destroyed);
}

TEST(Rpc, SendTwice) {
  // Sends the same local capability to the server in several requests, including two requests
  // in flight at once, and checks that it is released after the last one.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  // `destroyed` becomes true once the TestCapDestructor created below fulfills `paf`.
  auto paf = kj::newPromiseAndFulfiller<void>();
  bool destroyed = false;
  auto destructionPromise = paf.promise.then([&]() { destroyed = true; }).eagerlyEvaluate(nullptr);

  auto cap = test::TestInterface::Client(kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)));

  {
    auto request = client.callFooRequest();
    request.setCap(cap);

    EXPECT_EQ("bar", request.send().wait(context.waitScope).getS());
  }

  // Allow some time for the server to release `cap`.
  EXPECT_EQ(1, client.getCallSequenceRequest().send().wait(context.waitScope).getN());

  {
    // More requests with the same cap.
    auto request = client.callFooRequest();
    auto request2 = client.callFooRequest();
    request.setCap(cap);
    // Moving `cap` gives away the last local reference held by this test.
    request2.setCap(kj::mv(cap));

    auto promise = request.send();
    auto promise2 = request2.send();

    EXPECT_EQ("bar", promise.wait(context.waitScope).getS());
    EXPECT_EQ("bar", promise2.wait(context.waitScope).getS());
  }

  // Now the cap should be released.
  destructionPromise.wait(context.waitScope);
  EXPECT_TRUE(destroyed);
}

RemotePromise<test::TestCallOrder::GetCallSequenceResults> getCallSequence(
    test::TestCallOrder::Client& client, uint expected) {
  // Sends getCallSequence() on `client` with the request's `expected` field set to `expected`,
  // and returns the promise for the result.
  auto req = client.getCallSequenceRequest();
  req.setExpected(expected);
  return req.send();
}

TEST(Rpc, Embargo) {
  // Echoes a local capability through the server and makes calls on the pipelined result both
  // before and after the echo returns. The calls must be delivered in the order they were made
  // (sequence numbers 0..5).

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  auto cap = test::TestCallOrder::Client(kj::heap<TestCallOrderImpl>());

  auto earlyCall = client.getCallSequenceRequest().send();

  auto echoRequest = client.echoRequest();
  echoRequest.setCap(cap);
  auto echo = echoRequest.send();

  auto pipeline = echo.getCap();

  auto call0 = getCallSequence(pipeline, 0);
  auto call1 = getCallSequence(pipeline, 1);

  earlyCall.wait(context.waitScope);

  auto call2 = getCallSequence(pipeline, 2);

  auto resolved = echo.wait(context.waitScope).getCap();

  auto call3 = getCallSequence(pipeline, 3);
  auto call4 = getCallSequence(pipeline, 4);
  auto call5 = getCallSequence(pipeline, 5);

  EXPECT_EQ(0, call0.wait(context.waitScope).getN());
  EXPECT_EQ(1, call1.wait(context.waitScope).getN());
  EXPECT_EQ(2, call2.wait(context.waitScope).getN());
  EXPECT_EQ(3, call3.wait(context.waitScope).getN());
  EXPECT_EQ(4, call4.wait(context.waitScope).getN());
  EXPECT_EQ(5, call5.wait(context.waitScope).getN());
}

// Waits on `promise` and records a test failure unless it is rejected with an exception.
template <typename T>
void expectPromiseThrows(kj::Promise<T>&& promise, kj::WaitScope& waitScope) {
  EXPECT_TRUE(promise.then([](T&&) { return false; }, [](kj::Exception&&) { return true; })
      .wait(waitScope));
}

// Specialization for kj::Promise<void>, whose success continuation takes no argument.
template <>
void expectPromiseThrows(kj::Promise<void>&& promise, kj::WaitScope& waitScope) {
  EXPECT_TRUE(promise.then([]() { return false; }, [](kj::Exception&&) { return true; })
      .wait(waitScope));
}

TEST(Rpc, EmbargoError) {
  // Same call pattern as the Embargo test, but the echoed capability is a promise that is
  // eventually rejected. All pipelined calls must fail, and the connection must remain usable.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  auto paf = kj::newPromiseAndFulfiller<test::TestCallOrder::Client>();

  auto cap = test::TestCallOrder::Client(kj::mv(paf.promise));

  auto earlyCall = client.getCallSequenceRequest().send();

  auto echoRequest = client.echoRequest();
  echoRequest.setCap(cap);
  auto echo = echoRequest.send();

  auto pipeline = echo.getCap();

  auto call0 = getCallSequence(pipeline, 0);
  auto call1 = getCallSequence(pipeline, 1);

  earlyCall.wait(context.waitScope);

  auto call2 = getCallSequence(pipeline, 2);

  auto resolved = echo.wait(context.waitScope).getCap();

  auto call3 = getCallSequence(pipeline, 3);
  auto call4 = getCallSequence(pipeline, 4);
  auto call5 = getCallSequence(pipeline, 5);

  // Reject the promise capability with an assertion failure.
  paf.fulfiller->rejectIfThrows([]() { KJ_FAIL_ASSERT("foo") { break; } });

  expectPromiseThrows(kj::mv(call0), context.waitScope);
  expectPromiseThrows(kj::mv(call1), context.waitScope);
  expectPromiseThrows(kj::mv(call2), context.waitScope);
  expectPromiseThrows(kj::mv(call3), context.waitScope);
  expectPromiseThrows(kj::mv(call4), context.waitScope);
  expectPromiseThrows(kj::mv(call5), context.waitScope);

  // Verify that we're still connected (there were no protocol errors).
  getCallSequence(client, 1).wait(context.waitScope);
}

TEST(Rpc, CallBrokenPromise) {
  // Tell the server to call back to a promise client, then resolve the promise to an error.

  TestContext context;

  auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();
  auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();

  {
    auto req = client.holdRequest();
    req.setCap(kj::mv(paf.promise));
    req.send().wait(context.waitScope);
  }

  // `returned` is set on either outcome; the error path rethrows so that `req` still fails.
  bool returned = false;
  auto req = client.callHeldRequest().send()
      .then([&](capnp::Response<test::TestMoreStuff::CallHeldResults>&&) {
    returned = true;
  }, [&](kj::Exception&& e) {
    returned = true;
    kj::throwRecoverableException(kj::mv(e));
  }).eagerlyEvaluate(nullptr);

  // Turn the event loop several times; the call must not complete while the promise is
  // unresolved.
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);

  EXPECT_FALSE(returned);

  // Reject the held promise capability with an assertion failure.
  paf.fulfiller->rejectIfThrows([]() { KJ_FAIL_ASSERT("foo") { break; } });

  expectPromiseThrows(kj::mv(req), context.waitScope);
  EXPECT_TRUE(returned);

  // Turn the event loop several more times before checking the connection.
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);
  kj::evalLater([]() {}).wait(context.waitScope);

  // Verify that we're still connected (there were no protocol errors).
  getCallSequence(client, 1).wait(context.waitScope);
}

TEST(Rpc, Abort) {
  // Verify that aborts are received.

  TestContext context;

  MallocMessageBuilder refMessage(128);
  auto hostId = refMessage.initRoot<test::TestSturdyRefHostId>();
  hostId.setHost("server");

  // Talk to the server over a raw connection rather than through the client RpcSystem.
  auto conn = KJ_ASSERT_NONNULL(context.clientNetwork.connect(hostId));

  {
    // Send an invalid message (Return to non-existent question).
    auto msg = conn->newOutgoingMessage(128);
    auto body = msg->getBody().initAs<rpc::Message>().initReturn();
    body.setAnswerId(1234);
    body.setCanceled();
    msg->send();
  }

  auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(context.waitScope));
  EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());

  // After the abort, the next receive must report end-of-stream (nullptr).
  EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr);
}

// =======================================================================================

typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway;

class TestGateway final: public TestRealmGateway::Server {
  // Realm gateway used by the tests below. It forwards save() to the underlying capability and
  // rewrites the resulting SturdyRef with an "imported-" or "exported-" prefix.

public:
  kj::Promise<void> import(ImportContext context) override {
    // Saves the underlying capability (whose SturdyRef is Text) and returns a TestSturdyRef
    // whose object ID is the text "imported-" followed by that SturdyRef.
    auto cap = context.getParams().getCap();
    context.releaseParams();
    return cap.saveRequest().send()
        .then([context](Response<Persistent<Text>::SaveResults> response) mutable {
      context.getResults().initSturdyRef().getObjectId().setAs<Text>(
          kj::str("imported-", response.getSturdyRef()));
    });
  }

  kj::Promise<void> export_(ExportContext context) override {
    // Saves the underlying capability (whose SturdyRef is a TestSturdyRef) and returns the text
    // "exported-" followed by the Text stored in that SturdyRef's object ID.
    auto cap = context.getParams().getCap();
    context.releaseParams();
    return cap.saveRequest().send()
        .then([context](Response<Persistent<test::TestSturdyRef>::SaveResults> response) mutable {
      context.getResults().setSturdyRef(kj::str("exported-",
          response.getSturdyRef().getObjectId().getAs<Text>()));
    });
  }
};

class TestPersistent final: public Persistent<test::TestSturdyRef>::Server {
  // Persistent capability whose save() returns a TestSturdyRef with `name` as its object ID.

public:
  TestPersistent(kj::StringPtr name): name(name) {}

  kj::Promise<void> save(SaveContext context) override {
    context.initResults().initSturdyRef().getObjectId().setAs<Text>(name);
    return kj::READY_NOW;
  }

private:
  kj::StringPtr name;
  // NOTE(review): kj::StringPtr presumably does not own its text, so the caller's string must
  // outlive this object (the tests pass a string literal) -- confirm.
};

class TestPersistentText final: public Persistent<Text>::Server {
  // Persistent capability whose save() returns `name` itself as the (Text) SturdyRef.

public:
  TestPersistentText(kj::StringPtr name): name(name) {}

  kj::Promise<void> save(SaveContext context) override {
    context.initResults().setSturdyRef(name);
    return kj::READY_NOW;
  }

private:
  kj::StringPtr name;
  // NOTE(review): kj::StringPtr presumably does not own its text, so the caller's string must
  // outlive this object (the tests pass a string literal) -- confirm.
};

TEST(Rpc, RealmGatewayImport) {
  // Gateway on the client side: saving a Persistent<Text> bootstrap capability through the
  // gateway must yield the "imported-" prefixed ref.

  TestRealmGateway::Client gateway = kj::heap<TestGateway>();
  Persistent<Text>::Client bootstrap = kj::heap<TestPersistentText>("foo");

  MallocMessageBuilder hostIdBuilder;
  auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
  hostId.setHost("server");

  TestContext context(bootstrap, gateway);
  auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<test::TestSturdyRef>>();

  auto response = client.saveRequest().send().wait(context.waitScope);

  EXPECT_EQ("imported-foo", response.getSturdyRef().getObjectId().getAs<Text>());
}

TEST(Rpc, RealmGatewayExport) {
  // Gateway on the server side: saving a Persistent<TestSturdyRef> bootstrap capability through
  // the gateway must yield the "exported-" prefixed ref.

  TestRealmGateway::Client gateway = kj::heap<TestGateway>();
  Persistent<test::TestSturdyRef>::Client bootstrap = kj::heap<TestPersistent>("foo");

  MallocMessageBuilder hostIdBuilder;
  auto hostId = hostIdBuilder.getRoot<test::TestSturdyRefHostId>();
  hostId.setHost("server");

  // The trailing `true` only selects the TestContext overload that installs the gateway on
  // the server.
  TestContext context(bootstrap, gateway, true);
  auto client = context.rpcClient.bootstrap(hostId).castAs<Persistent<Text>>();

  auto response = client.saveRequest().send().wait(context.waitScope);

  EXPECT_EQ("exported-foo", response.getSturdyRef());
}

}  // namespace
}  // namespace _ (private)
}  // namespace capnp