Commit b1e502f7 authored by Kenton Varda's avatar Kenton Varda

Test and fix promise resolution across RPC.

parent a30c4171
......@@ -37,6 +37,9 @@ namespace capnp {
Capability::Client::Client(decltype(nullptr))
: hook(newBrokenCap("Called null capability.")) {}
Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {}
kj::Promise<void> Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) {
KJ_FAIL_REQUIRE("Requested interface not implemented.", actualInterfaceName, requestedTypeId) {
......
......@@ -108,22 +108,36 @@ class Capability::Client {
public:
explicit Client(decltype(nullptr));
explicit Client(kj::Own<const ClientHook>&& hook);
// If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to
// `nullptr`. The resulting client is not meant to be called and throws exceptions from all
// methods.
template <typename T, typename = kj::EnableIf<kj::canConvert<T*, Capability::Server*>()>>
Client(kj::Own<T>&& server, const kj::EventLoop& loop = kj::EventLoop::current());
// Make a client capability that wraps the given server capability. The server's methods will
// only be executed in the given EventLoop, regardless of what thread calls the client's methods.
template <typename T, typename = kj::EnableIf<kj::canConvert<T*, Client*>()>>
Client(kj::Promise<T>&& promise, const kj::EventLoop& loop = kj::EventLoop::current());
// Make a client from a promise for a future client. The resulting client queues calls until the
// promise resolves.
Client(kj::Exception&& exception);
// Make a broken client that throws the given exception from all calls.
Client(const Client& other);
Client& operator=(const Client& other);
// Copies by reference counting. Warning: Refcounting is slow due to atomic ops. Try to only
// use move instead.
// Copies by reference counting. Warning: Refcounting may be slow due to atomic ops. Try to
// only use move instead.
Client(Client&&) = default;
Client& operator=(Client&&) = default;
// Move is fast.
explicit Client(kj::Own<const ClientHook>&& hook);
// For use by the RPC implementation: Wrap a ClientHook.
template <typename T>
typename T::Client castAs() const;
// Reinterpret the capability as implementing the given interface. Note that no error will occur
......@@ -533,6 +547,11 @@ inline Capability::Client::Client(kj::Own<const ClientHook>&& hook): hook(kj::mv
template <typename T, typename>
inline Capability::Client::Client(kj::Own<T>&& server, const kj::EventLoop& loop)
: hook(makeLocalClient(kj::mv(server), loop)) {}
template <typename T, typename>
inline Capability::Client::Client(kj::Promise<T>&& promise, const kj::EventLoop& loop)
: hook(newLocalPromiseClient(
promise.thenInAnyThread([](T&& t) { return kj::mv(t.hook); }),
loop)) {}
inline Capability::Client::Client(const Client& other): hook(other.hook->addRef()) {}
inline Capability::Client& Capability::Client::operator=(const Client& other) {
hook = other.hook->addRef();
......
......@@ -1307,6 +1307,13 @@ private:
" inline Client(::kj::Own<T>&& server,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n"
" template <typename T,\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" inline Client(::kj::Promise<T>&& promise,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n"
" inline Client(::kj::Exception&& exception)\n"
" : ::capnp::Capability::Client(::kj::mv(exception)) {}\n"
"\n",
KJ_MAP(m, methods) { return kj::mv(m.clientDecls); },
"\n"
......
......@@ -245,6 +245,8 @@ public:
return kj::heap<TestTailCalleeImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
return kj::heap<TestTailCallerImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
return kj::heap<TestMoreStuffImpl>(callCount);
}
KJ_UNREACHABLE;
}
......@@ -383,6 +385,45 @@ TEST_F(RpcTest, TailCall) {
EXPECT_EQ(1, restorer.callCount);
}
TEST_F(RpcTest, PromiseResolve) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
int chainedCallCount = 0;
auto request = client.callFooRequest();
auto request2 = client.callFooWhenResolvedRequest();
auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();
{
auto fork = loop.fork(kj::mv(paf.promise));
request.setCap(test::TestInterface::Client(fork.addBranch(), loop));
request2.setCap(test::TestInterface::Client(fork.addBranch(), loop));
}
auto promise = request.send();
auto promise2 = request2.send();
{
// Make sure getCap() has been called on the server side by sending another call and waiting
// for it.
EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(3, restorer.callCount);
}
// OK, now fulfill the local promise.
paf.fulfiller->fulfill(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(chainedCallCount), loop));
// We should now be able to wait for getCap() to finish.
EXPECT_EQ("bar", loop.wait(kj::mv(promise)).getS());
EXPECT_EQ("bar", loop.wait(kj::mv(promise2)).getS());
EXPECT_EQ(3, restorer.callCount);
EXPECT_EQ(2, chainedCallCount);
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -49,6 +49,8 @@ public:
return kj::heap<TestTailCalleeImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
return kj::heap<TestTailCallerImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
return kj::heap<TestMoreStuffImpl>(callCount);
}
KJ_UNREACHABLE;
}
......
......@@ -1121,11 +1121,11 @@ private:
lock->exportsByCap.erase(exp.clientHook);
exp.clientHook = kj::mv(resolution);
if (resolution->getBrand() != this) {
if (exp.clientHook->getBrand() != this) {
// We're resolving to a local capability. If we're resolving to a promise, we might be
// able to reuse our export table entry and avoid sending a message.
KJ_IF_MAYBE(promise, resolution->whenMoreResolved()) {
KJ_IF_MAYBE(promise, exp.clientHook->whenMoreResolved()) {
// We're replacing a promise with another local promise. In this case, we might actually
// be able to just reuse the existing export table entry to represent the new promise --
// unless it already has an entry. Let's check.
......@@ -1147,7 +1147,7 @@ private:
messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId);
writeDescriptor(*resolution, resolve.initCap(), *this->tables.lockExclusive());
writeDescriptor(*exp.clientHook, resolve.initCap(), *lock);
message->send();
return kj::READY_NOW;
......@@ -2670,9 +2670,9 @@ private:
}
// Extend the resolution chain.
auto oldTail = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = oldTail->addResolve(resolve.getPromiseId(), kj::mv(replacement));
lock.release(); // in case oldTail is destroyed
oldResolutionChainTail = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = oldResolutionChainTail->addResolve(
resolve.getPromiseId(), kj::mv(replacement));
// If the import is on the table, fulfill it.
KJ_IF_MAYBE(import, lock->imports.find(resolve.getPromiseId())) {
......
......@@ -971,5 +971,54 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
return kj::READY_NOW;
}
TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestMoreStuffImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(callCount++);
return kj::READY_NOW;
}
::kj::Promise<void> TestMoreStuffImpl::callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) {
++callCount;
auto cap = params.getCap();
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) {
++callCount;
auto cap = params.getCap();
return cap.whenResolved().then([cap,result]() mutable {
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
});
}
} // namespace _ (private)
} // namespace capnp
......@@ -221,6 +221,26 @@ private:
int& callCount;
};
class TestMoreStuffImpl final: public test::TestMoreStuff::Server {
public:
TestMoreStuffImpl(int& callCount);
kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override;
::kj::Promise<void> callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) override;
kj::Promise<void> callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) override;
private:
int& callCount;
};
} // namespace _ (private)
} // namespace capnp
......
......@@ -590,7 +590,7 @@ const derivedConstant :TestAllTypes = (
structList = TestConstants.structListConst);
interface TestInterface {
foo @0 (i :UInt32, j :Bool) -> (x: Text);
foo @0 (i :UInt32, j :Bool) -> (x :Text);
bar @1 () -> ();
baz @2 (s: TestAllTypes);
}
......@@ -629,6 +629,16 @@ interface TestTailCaller {
foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult;
}
interface TestMoreStuff extends(TestCallOrder) {
# Catch-all type that contains lots of testing methods.
callFoo @0 (cap :TestInterface) -> (s: Text);
# Call `cap.foo()`, check the result, and return "bar".
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first.
}
struct TestSturdyRefHostId {
host @0 :Text;
}
......@@ -641,6 +651,7 @@ struct TestSturdyRefObjectId {
testPipeline @2;
testTailCallee @3;
testTailCaller @4;
testMoreStuff @5;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment