Commit b1e502f7 authored by Kenton Varda's avatar Kenton Varda

Test and fix promise resolution across RPC.

parent a30c4171
...@@ -37,6 +37,9 @@ namespace capnp { ...@@ -37,6 +37,9 @@ namespace capnp {
Capability::Client::Client(decltype(nullptr)) Capability::Client::Client(decltype(nullptr))
: hook(newBrokenCap("Called null capability.")) {} : hook(newBrokenCap("Called null capability.")) {}
Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<void> Capability::Server::internalUnimplemented( kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) { const char* actualInterfaceName, uint64_t requestedTypeId) {
KJ_FAIL_REQUIRE("Requested interface not implemented.", actualInterfaceName, requestedTypeId) { KJ_FAIL_REQUIRE("Requested interface not implemented.", actualInterfaceName, requestedTypeId) {
......
...@@ -108,22 +108,36 @@ class Capability::Client { ...@@ -108,22 +108,36 @@ class Capability::Client {
public: public:
explicit Client(decltype(nullptr)); explicit Client(decltype(nullptr));
explicit Client(kj::Own<const ClientHook>&& hook); // If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to
// `nullptr`. The resulting client is not meant to be called and throws exceptions from all
// methods.
template <typename T, typename = kj::EnableIf<kj::canConvert<T*, Capability::Server*>()>> template <typename T, typename = kj::EnableIf<kj::canConvert<T*, Capability::Server*>()>>
Client(kj::Own<T>&& server, const kj::EventLoop& loop = kj::EventLoop::current()); Client(kj::Own<T>&& server, const kj::EventLoop& loop = kj::EventLoop::current());
// Make a client capability that wraps the given server capability. The server's methods will // Make a client capability that wraps the given server capability. The server's methods will
// only be executed in the given EventLoop, regardless of what thread calls the client's methods. // only be executed in the given EventLoop, regardless of what thread calls the client's methods.
template <typename T, typename = kj::EnableIf<kj::canConvert<T*, Client*>()>>
Client(kj::Promise<T>&& promise, const kj::EventLoop& loop = kj::EventLoop::current());
// Make a client from a promise for a future client. The resulting client queues calls until the
// promise resolves.
Client(kj::Exception&& exception);
// Make a broken client that throws the given exception from all calls.
Client(const Client& other); Client(const Client& other);
Client& operator=(const Client& other); Client& operator=(const Client& other);
// Copies by reference counting. Warning: Refcounting is slow due to atomic ops. Try to only // Copies by reference counting. Warning: Refcounting may be slow due to atomic ops. Try to
// use move instead. // only use move instead.
Client(Client&&) = default; Client(Client&&) = default;
Client& operator=(Client&&) = default; Client& operator=(Client&&) = default;
// Move is fast. // Move is fast.
explicit Client(kj::Own<const ClientHook>&& hook);
// For use by the RPC implementation: Wrap a ClientHook.
template <typename T> template <typename T>
typename T::Client castAs() const; typename T::Client castAs() const;
// Reinterpret the capability as implementing the given interface. Note that no error will occur // Reinterpret the capability as implementing the given interface. Note that no error will occur
...@@ -533,6 +547,11 @@ inline Capability::Client::Client(kj::Own<const ClientHook>&& hook): hook(kj::mv ...@@ -533,6 +547,11 @@ inline Capability::Client::Client(kj::Own<const ClientHook>&& hook): hook(kj::mv
template <typename T, typename> template <typename T, typename>
inline Capability::Client::Client(kj::Own<T>&& server, const kj::EventLoop& loop) inline Capability::Client::Client(kj::Own<T>&& server, const kj::EventLoop& loop)
: hook(makeLocalClient(kj::mv(server), loop)) {} : hook(makeLocalClient(kj::mv(server), loop)) {}
template <typename T, typename>
inline Capability::Client::Client(kj::Promise<T>&& promise, const kj::EventLoop& loop)
: hook(newLocalPromiseClient(
promise.thenInAnyThread([](T&& t) { return kj::mv(t.hook); }),
loop)) {}
inline Capability::Client::Client(const Client& other): hook(other.hook->addRef()) {} inline Capability::Client::Client(const Client& other): hook(other.hook->addRef()) {}
inline Capability::Client& Capability::Client::operator=(const Client& other) { inline Capability::Client& Capability::Client::operator=(const Client& other) {
hook = other.hook->addRef(); hook = other.hook->addRef();
......
...@@ -1307,6 +1307,13 @@ private: ...@@ -1307,6 +1307,13 @@ private:
" inline Client(::kj::Own<T>&& server,\n" " inline Client(::kj::Own<T>&& server,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n" " const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n" " : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n"
" template <typename T,\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" inline Client(::kj::Promise<T>&& promise,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n"
" inline Client(::kj::Exception&& exception)\n"
" : ::capnp::Capability::Client(::kj::mv(exception)) {}\n"
"\n", "\n",
KJ_MAP(m, methods) { return kj::mv(m.clientDecls); }, KJ_MAP(m, methods) { return kj::mv(m.clientDecls); },
"\n" "\n"
......
...@@ -245,6 +245,8 @@ public: ...@@ -245,6 +245,8 @@ public:
return kj::heap<TestTailCalleeImpl>(callCount); return kj::heap<TestTailCalleeImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER: case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
return kj::heap<TestTailCallerImpl>(callCount); return kj::heap<TestTailCallerImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
return kj::heap<TestMoreStuffImpl>(callCount);
} }
KJ_UNREACHABLE; KJ_UNREACHABLE;
} }
...@@ -383,6 +385,45 @@ TEST_F(RpcTest, TailCall) { ...@@ -383,6 +385,45 @@ TEST_F(RpcTest, TailCall) {
EXPECT_EQ(1, restorer.callCount); EXPECT_EQ(1, restorer.callCount);
} }
TEST_F(RpcTest, PromiseResolve) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
int chainedCallCount = 0;
auto request = client.callFooRequest();
auto request2 = client.callFooWhenResolvedRequest();
auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();
{
auto fork = loop.fork(kj::mv(paf.promise));
request.setCap(test::TestInterface::Client(fork.addBranch(), loop));
request2.setCap(test::TestInterface::Client(fork.addBranch(), loop));
}
auto promise = request.send();
auto promise2 = request2.send();
{
// Make sure getCap() has been called on the server side by sending another call and waiting
// for it.
EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(3, restorer.callCount);
}
// OK, now fulfill the local promise.
paf.fulfiller->fulfill(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(chainedCallCount), loop));
// We should now be able to wait for getCap() to finish.
EXPECT_EQ("bar", loop.wait(kj::mv(promise)).getS());
EXPECT_EQ("bar", loop.wait(kj::mv(promise2)).getS());
EXPECT_EQ(3, restorer.callCount);
EXPECT_EQ(2, chainedCallCount);
}
} // namespace } // namespace
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -49,6 +49,8 @@ public: ...@@ -49,6 +49,8 @@ public:
return kj::heap<TestTailCalleeImpl>(callCount); return kj::heap<TestTailCalleeImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER: case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
return kj::heap<TestTailCallerImpl>(callCount); return kj::heap<TestTailCallerImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
return kj::heap<TestMoreStuffImpl>(callCount);
} }
KJ_UNREACHABLE; KJ_UNREACHABLE;
} }
......
...@@ -1121,11 +1121,11 @@ private: ...@@ -1121,11 +1121,11 @@ private:
lock->exportsByCap.erase(exp.clientHook); lock->exportsByCap.erase(exp.clientHook);
exp.clientHook = kj::mv(resolution); exp.clientHook = kj::mv(resolution);
if (resolution->getBrand() != this) { if (exp.clientHook->getBrand() != this) {
// We're resolving to a local capability. If we're resolving to a promise, we might be // We're resolving to a local capability. If we're resolving to a promise, we might be
// able to reuse our export table entry and avoid sending a message. // able to reuse our export table entry and avoid sending a message.
KJ_IF_MAYBE(promise, resolution->whenMoreResolved()) { KJ_IF_MAYBE(promise, exp.clientHook->whenMoreResolved()) {
// We're replacing a promise with another local promise. In this case, we might actually // We're replacing a promise with another local promise. In this case, we might actually
// be able to just reuse the existing export table entry to represent the new promise -- // be able to just reuse the existing export table entry to represent the new promise --
// unless it already has an entry. Let's check. // unless it already has an entry. Let's check.
...@@ -1147,7 +1147,7 @@ private: ...@@ -1147,7 +1147,7 @@ private:
messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16); messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve(); auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
writeDescriptor(*resolution, resolve.initCap(), *this->tables.lockExclusive()); writeDescriptor(*exp.clientHook, resolve.initCap(), *lock);
message->send(); message->send();
return kj::READY_NOW; return kj::READY_NOW;
...@@ -2670,9 +2670,9 @@ private: ...@@ -2670,9 +2670,9 @@ private:
} }
// Extend the resolution chain. // Extend the resolution chain.
auto oldTail = kj::mv(lock->resolutionChainTail); oldResolutionChainTail = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = oldTail->addResolve(resolve.getPromiseId(), kj::mv(replacement)); lock->resolutionChainTail = oldResolutionChainTail->addResolve(
lock.release(); // in case oldTail is destroyed resolve.getPromiseId(), kj::mv(replacement));
// If the import is on the table, fulfill it. // If the import is on the table, fulfill it.
KJ_IF_MAYBE(import, lock->imports.find(resolve.getPromiseId())) { KJ_IF_MAYBE(import, lock->imports.find(resolve.getPromiseId())) {
......
...@@ -971,5 +971,54 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced( ...@@ -971,5 +971,54 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
return kj::READY_NOW; return kj::READY_NOW;
} }
TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestMoreStuffImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(callCount++);
return kj::READY_NOW;
}
::kj::Promise<void> TestMoreStuffImpl::callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) {
++callCount;
auto cap = params.getCap();
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) {
++callCount;
auto cap = params.getCap();
return cap.whenResolved().then([cap,result]() mutable {
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
});
}
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -221,6 +221,26 @@ private: ...@@ -221,6 +221,26 @@ private:
int& callCount; int& callCount;
}; };
class TestMoreStuffImpl final: public test::TestMoreStuff::Server {
public:
TestMoreStuffImpl(int& callCount);
kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override;
::kj::Promise<void> callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) override;
kj::Promise<void> callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) override;
private:
int& callCount;
};
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
......
...@@ -590,7 +590,7 @@ const derivedConstant :TestAllTypes = ( ...@@ -590,7 +590,7 @@ const derivedConstant :TestAllTypes = (
structList = TestConstants.structListConst); structList = TestConstants.structListConst);
interface TestInterface { interface TestInterface {
foo @0 (i :UInt32, j :Bool) -> (x: Text); foo @0 (i :UInt32, j :Bool) -> (x :Text);
bar @1 () -> (); bar @1 () -> ();
baz @2 (s: TestAllTypes); baz @2 (s: TestAllTypes);
} }
...@@ -629,6 +629,16 @@ interface TestTailCaller { ...@@ -629,6 +629,16 @@ interface TestTailCaller {
foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult; foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult;
} }
interface TestMoreStuff extends(TestCallOrder) {
# Catch-all type that contains lots of testing methods.
callFoo @0 (cap :TestInterface) -> (s: Text);
# Call `cap.foo()`, check the result, and return "bar".
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first.
}
struct TestSturdyRefHostId { struct TestSturdyRefHostId {
host @0 :Text; host @0 :Text;
} }
...@@ -641,6 +651,7 @@ struct TestSturdyRefObjectId { ...@@ -641,6 +651,7 @@ struct TestSturdyRefObjectId {
testPipeline @2; testPipeline @2;
testTailCallee @3; testTailCallee @3;
testTailCaller @4; testTailCaller @4;
testMoreStuff @5;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment