Commit 6b8b8c71 authored by Kenton Varda's avatar Kenton Varda

Test and fix cancellation and release.

parent b1e502f7
...@@ -6,7 +6,7 @@ ifeq ($(CXX),clang++) ...@@ -6,7 +6,7 @@ ifeq ($(CXX),clang++)
# Clang's verbose diagnostics don't play nice with the Ekam Eclipse plugin's error parsing, # Clang's verbose diagnostics don't play nice with the Ekam Eclipse plugin's error parsing,
# so disable them. Also enable some useful Clang warnings (dunno if GCC supports them, and don't # so disable them. Also enable some useful Clang warnings (dunno if GCC supports them, and don't
# care). # care).
EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi -Werror=return-type
# EXTRA_FLAG=-fno-caret-diagnostics -Weverything -Wno-c++98-compat -Wno-shadow -Wno-c++98-compat-pedantic -Wno-padded -Wno-weak-vtables -Wno-gnu -Wno-unused-parameter -Wno-sign-conversion -Wno-undef -Wno-shorten-64-to-32 -Wno-conversion -Wno-unreachable-code -Wno-non-virtual-dtor # EXTRA_FLAG=-fno-caret-diagnostics -Weverything -Wno-c++98-compat -Wno-shadow -Wno-c++98-compat-pedantic -Wno-padded -Wno-weak-vtables -Wno-gnu -Wno-unused-parameter -Wno-sign-conversion -Wno-undef -Wno-shorten-64-to-32 -Wno-conversion -Wno-unreachable-code -Wno-non-virtual-dtor
else else
EXTRA_FLAG= EXTRA_FLAG=
......
...@@ -131,7 +131,9 @@ public: ...@@ -131,7 +131,9 @@ public:
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void allowAsyncCancellation() override { void allowAsyncCancellation() override {
// ignored for local calls releaseParams();
// TODO(soon): Implement.
} }
bool isCanceled() override { bool isCanceled() override {
return false; return false;
......
...@@ -107,7 +107,7 @@ class Capability::Client { ...@@ -107,7 +107,7 @@ class Capability::Client {
// Base type for capability clients. // Base type for capability clients.
public: public:
explicit Client(decltype(nullptr)); Client(decltype(nullptr));
// If you need to declare a Client before you have anything to assign to it (perhaps because // If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to // the assignment is going to occur in an if/else scope), you can start by initializing it to
// `nullptr`. The resulting client is not meant to be called and throws exceptions from all // `nullptr`. The resulting client is not meant to be called and throws exceptions from all
...@@ -249,6 +249,10 @@ public: ...@@ -249,6 +249,10 @@ public:
// executing on a local thread. The method must perform an asynchronous operation or call // executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control. // `EventLoop::current().runLater()` to yield control.
// //
// This method implies `releaseParams()` -- you cannot allow async cancellation while still
// holding the params. (This is because of a quirk of the current RPC implementation; in theory
// it could be fixed.)
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object // TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real // in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault // threat? Maybe we can just always enable cancellation. After all, you need to be fault
......
...@@ -1205,6 +1205,11 @@ private: ...@@ -1205,6 +1205,11 @@ private:
kj::String resultType = resultProto.getScopeId() == 0 ? kj::String resultType = resultProto.getScopeId() == 0 ?
kj::str(interfaceName, "::", titleCase, "Results") : cppFullName(resultSchema).flatten(); kj::str(interfaceName, "::", titleCase, "Results") : cppFullName(resultSchema).flatten();
kj::String shortParamType = paramProto.getScopeId() == 0 ?
kj::str(titleCase, "Params") : cppFullName(paramSchema).flatten();
kj::String shortResultType = resultProto.getScopeId() == 0 ?
kj::str(titleCase, "Results") : cppFullName(resultSchema).flatten();
auto interfaceProto = method.getContainingInterface().getProto(); auto interfaceProto = method.getContainingInterface().getProto();
uint64_t interfaceId = interfaceProto.getId(); uint64_t interfaceId = interfaceProto.getId();
auto interfaceIdHex = kj::hex(interfaceId); auto interfaceIdHex = kj::hex(interfaceId);
...@@ -1216,11 +1221,15 @@ private: ...@@ -1216,11 +1221,15 @@ private:
" unsigned int firstSegmentWordSize = 0) const;\n"), " unsigned int firstSegmentWordSize = 0) const;\n"),
kj::strTree( kj::strTree(
paramProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", paramType, " ", titleCase, "Params;\n"),
resultProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", resultType, " ", titleCase, "Results;\n"),
" virtual ::kj::Promise<void> ", name, "(\n" " virtual ::kj::Promise<void> ", name, "(\n"
" ", paramType, "::Reader params,\n" " ", shortParamType, "::Reader params,\n"
" ", resultType, "::Builder result);\n" " ", shortResultType, "::Builder result);\n"
" virtual ::kj::Promise<void> ", name, "Advanced(\n" " virtual ::kj::Promise<void> ", name, "Advanced(\n"
" ::capnp::CallContext<", paramType, ", ", resultType, "> context);\n"), " ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> context);\n"),
kj::strTree(), kj::strTree(),
...@@ -1298,17 +1307,15 @@ private: ...@@ -1298,17 +1307,15 @@ private:
" typedef ", fullName, " Calls;\n" " typedef ", fullName, " Calls;\n"
" typedef ", fullName, " Reads;\n" " typedef ", fullName, " Reads;\n"
"\n" "\n"
" inline explicit Client(decltype(nullptr))\n" " inline Client(decltype(nullptr))\n"
" : ::capnp::Capability::Client(nullptr) {}\n" " : ::capnp::Capability::Client(nullptr) {}\n"
" inline explicit Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n" " inline explicit Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n"
" : ::capnp::Capability::Client(::kj::mv(hook)) {}\n" " : ::capnp::Capability::Client(::kj::mv(hook)) {}\n"
" template <typename T,\n" " template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" inline Client(::kj::Own<T>&& server,\n" " inline Client(::kj::Own<T>&& server,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n" " const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n" " : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n"
" template <typename T,\n" " template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" inline Client(::kj::Promise<T>&& promise,\n" " inline Client(::kj::Promise<T>&& promise,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n" " const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n" " : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n"
......
...@@ -42,10 +42,9 @@ public: ...@@ -42,10 +42,9 @@ public:
TestNetworkAdapter& add(kj::StringPtr name); TestNetworkAdapter& add(kj::StringPtr name);
kj::Maybe<const TestNetworkAdapter&> find(kj::StringPtr name) const { kj::Maybe<TestNetworkAdapter&> find(kj::StringPtr name) {
auto lock = map.lockShared(); auto iter = map.find(name);
auto iter = lock->find(name); if (iter == map.end()) {
if (iter == lock->end()) {
return nullptr; return nullptr;
} else { } else {
return *iter->second; return *iter->second;
...@@ -53,7 +52,7 @@ public: ...@@ -53,7 +52,7 @@ public:
} }
private: private:
kj::MutexGuarded<std::map<kj::StringPtr, kj::Own<TestNetworkAdapter>>> map; std::map<kj::StringPtr, kj::Own<TestNetworkAdapter>> map;
}; };
typedef VatNetwork< typedef VatNetwork<
...@@ -62,13 +61,16 @@ typedef VatNetwork< ...@@ -62,13 +61,16 @@ typedef VatNetwork<
class TestNetworkAdapter final: public TestNetworkAdapterBase { class TestNetworkAdapter final: public TestNetworkAdapterBase {
public: public:
TestNetworkAdapter(const TestNetwork& network): network(network) {} TestNetworkAdapter(TestNetwork& network): network(network) {}
uint getSentCount() { return sent; }
uint getReceivedCount() { return received; }
typedef TestNetworkAdapterBase::Connection Connection; typedef TestNetworkAdapterBase::Connection Connection;
class ConnectionImpl final: public Connection, public kj::Refcounted { class ConnectionImpl final: public Connection, public kj::Refcounted {
public: public:
ConnectionImpl(const char* name): name(name) {} ConnectionImpl(TestNetworkAdapter& network, const char* name): network(network), name(name) {}
void attach(ConnectionImpl& other) { void attach(ConnectionImpl& other) {
KJ_REQUIRE(partner == nullptr); KJ_REQUIRE(partner == nullptr);
...@@ -100,6 +102,8 @@ public: ...@@ -100,6 +102,8 @@ public:
return message->message.getRoot<ObjectPointer>(); return message->message.getRoot<ObjectPointer>();
} }
void send() override { void send() override {
++connection.network.sent;
kj::String msg = kj::str(connection.name, ": ", message->message.getRoot<rpc::Message>()); kj::String msg = kj::str(connection.name, ": ", message->message.getRoot<rpc::Message>());
//KJ_DBG(msg); //KJ_DBG(msg);
...@@ -108,6 +112,7 @@ public: ...@@ -108,6 +112,7 @@ public:
if (lock->fulfillers.empty()) { if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message)); lock->messages.push(kj::mv(message));
} else { } else {
++connection.network.received;
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message))); lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop(); lock->fulfillers.pop();
} }
...@@ -129,6 +134,7 @@ public: ...@@ -129,6 +134,7 @@ public:
lock->fulfillers.push(kj::mv(paf.fulfiller)); lock->fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise); return kj::mv(paf.promise);
} else { } else {
++network.received;
auto result = kj::mv(lock->messages.front()); auto result = kj::mv(lock->messages.front());
lock->messages.pop(); lock->messages.pop();
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result)); return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
...@@ -149,6 +155,7 @@ public: ...@@ -149,6 +155,7 @@ public:
} }
private: private:
TestNetworkAdapter& network;
const char* name; const char* name;
kj::Maybe<ConnectionImpl&> partner; kj::Maybe<ConnectionImpl&> partner;
...@@ -161,7 +168,7 @@ public: ...@@ -161,7 +168,7 @@ public:
kj::Maybe<kj::Own<Connection>> connectToRefHost( kj::Maybe<kj::Own<Connection>> connectToRefHost(
test::TestSturdyRefHostId::Reader hostId) override { test::TestSturdyRefHostId::Reader hostId) override {
const TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost())); TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost()));
kj::Locked<State> myLock; kj::Locked<State> myLock;
kj::Locked<State> dstLock; kj::Locked<State> dstLock;
...@@ -176,8 +183,8 @@ public: ...@@ -176,8 +183,8 @@ public:
auto iter = myLock->connections.find(&dst); auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) { if (iter == myLock->connections.end()) {
auto local = kj::refcounted<ConnectionImpl>("client"); auto local = kj::refcounted<ConnectionImpl>(*this, "client");
auto remote = kj::refcounted<ConnectionImpl>("server"); auto remote = kj::refcounted<ConnectionImpl>(dst, "server");
local->attach(*remote); local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local); myLock->connections[&dst] = kj::addRef(*local);
...@@ -210,7 +217,9 @@ public: ...@@ -210,7 +217,9 @@ public:
} }
private: private:
const TestNetwork& network; TestNetwork& network;
uint sent = 0;
uint received = 0;
struct State { struct State {
std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections; std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections;
...@@ -223,8 +232,7 @@ private: ...@@ -223,8 +232,7 @@ private:
TestNetwork::~TestNetwork() noexcept(false) {} TestNetwork::~TestNetwork() noexcept(false) {}
TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) { TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) {
auto lock = map.lockExclusive(); return *(map[name] = kj::heap<TestNetworkAdapter>(*this));
return *((*lock)[name] = kj::heap<TestNetworkAdapter>(*this));
} }
// ======================================================================================= // =======================================================================================
...@@ -257,6 +265,8 @@ protected: ...@@ -257,6 +265,8 @@ protected:
TestNetwork network; TestNetwork network;
TestRestorer restorer; TestRestorer restorer;
kj::SimpleEventLoop loop; kj::SimpleEventLoop loop;
TestNetworkAdapter& clientNetwork;
TestNetworkAdapter& serverNetwork;
RpcSystem<test::TestSturdyRefHostId> rpcClient; RpcSystem<test::TestSturdyRefHostId> rpcClient;
RpcSystem<test::TestSturdyRefHostId> rpcServer; RpcSystem<test::TestSturdyRefHostId> rpcServer;
...@@ -271,8 +281,10 @@ protected: ...@@ -271,8 +281,10 @@ protected:
} }
RpcTest() RpcTest()
: rpcClient(makeRpcClient(network.add("client"), loop)), : clientNetwork(network.add("client")),
rpcServer(makeRpcServer(network.add("server"), restorer, loop)) {} serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, loop)),
rpcServer(makeRpcServer(serverNetwork, restorer, loop)) {}
~RpcTest() noexcept {} ~RpcTest() noexcept {}
// Need to declare this with explicit noexcept otherwise it conflicts with testing::Test::~Test. // Need to declare this with explicit noexcept otherwise it conflicts with testing::Test::~Test.
...@@ -405,12 +417,10 @@ TEST_F(RpcTest, PromiseResolve) { ...@@ -405,12 +417,10 @@ TEST_F(RpcTest, PromiseResolve) {
auto promise = request.send(); auto promise = request.send();
auto promise2 = request2.send(); auto promise2 = request2.send();
{ // Make sure getCap() has been called on the server side by sending another call and waiting
// Make sure getCap() has been called on the server side by sending another call and waiting // for it.
// for it. EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN()); EXPECT_EQ(3, restorer.callCount);
EXPECT_EQ(3, restorer.callCount);
}
// OK, now fulfill the local promise. // OK, now fulfill the local promise.
paf.fulfiller->fulfill(test::TestInterface::Client( paf.fulfiller->fulfill(test::TestInterface::Client(
...@@ -424,6 +434,159 @@ TEST_F(RpcTest, PromiseResolve) { ...@@ -424,6 +434,159 @@ TEST_F(RpcTest, PromiseResolve) {
EXPECT_EQ(2, chainedCallCount); EXPECT_EQ(2, chainedCallCount);
} }
class TestCapDestructor final: public test::TestInterface::Server {
public:
TestCapDestructor(kj::Own<kj::PromiseFulfiller<void>>&& fulfiller)
: fulfiller(kj::mv(fulfiller)), impl(dummy) {}
~TestCapDestructor() {
fulfiller->fulfill();
}
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) {
return impl.foo(params, result);
}
private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
int dummy = 0;
TestInterfaceImpl impl;
};
TEST_F(RpcTest, RetainAndRelease) {
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
{
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
{
auto request = client.holdRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
loop.wait(request.send());
}
// Do some other call to add a round trip.
EXPECT_EQ(1, loop.wait(client.getCallSequenceRequest().send()).getN());
// Shouldn't be destroyed because it's being held by the server.
EXPECT_FALSE(destroyed);
// We can ask it to call the held capability.
EXPECT_EQ("bar", loop.wait(client.callHeldRequest().send()).getS());
{
// We can get the cap back from it.
auto capCopy = loop.wait(client.getHeldRequest().send()).getCap();
{
// And call it, without any network communications.
uint oldSentCount = clientNetwork.getSentCount();
auto request = capCopy.fooRequest();
request.setI(123);
request.setJ(true);
EXPECT_EQ("foo", loop.wait(request.send()).getX());
EXPECT_EQ(oldSentCount, clientNetwork.getSentCount());
}
{
// We can send another copy of the same cap to another method, and it works.
auto request = client.callFooRequest();
request.setCap(capCopy);
EXPECT_EQ("bar", loop.wait(request.send()).getS());
}
}
// Give some time to settle.
EXPECT_EQ(5, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(6, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(7, loop.wait(client.getCallSequenceRequest().send()).getN());
// Can't be destroyed, we haven't released it.
EXPECT_FALSE(destroyed);
}
// We released our client, which should cause the server to be released, which in turn will
// release the cap pointing back to us.
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
}
TEST_F(RpcTest, Cancel) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
{
auto request = client.neverReturnRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
{
auto responsePromise = request.send();
// Allow some time to settle.
EXPECT_EQ(1, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN());
// The cap shouldn't have been destroyed yet because the call never returned.
EXPECT_FALSE(destroyed);
}
}
// Now the cap should be released.
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
}
TEST_F(RpcTest, SendTwice) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
auto cap = test::TestInterface::Client(kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop);
{
auto request = client.callFooRequest();
request.setCap(cap);
EXPECT_EQ("bar", loop.wait(request.send()).getS());
}
// Allow some time for the server to release `cap`.
EXPECT_EQ(1, loop.wait(client.getCallSequenceRequest().send()).getN());
{
// More requests with the same cap.
auto request = client.callFooRequest();
auto request2 = client.callFooRequest();
request.setCap(cap);
request2.setCap(kj::mv(cap));
auto promise = request.send();
auto promise2 = request2.send();
EXPECT_EQ("bar", loop.wait(kj::mv(promise)).getS());
EXPECT_EQ("bar", loop.wait(kj::mv(promise2)).getS());
}
// Now the cap should be released.
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
}
} // namespace } // namespace
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -690,9 +690,6 @@ private: ...@@ -690,9 +690,6 @@ private:
kj::Own<CallContextHook>&& context) const override { kj::Own<CallContextHook>&& context) const override {
// Implement call() by copying params and results messages. // Implement call() by copying params and results messages.
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto params = context->getParams(); auto params = context->getParams();
size_t sizeHint = params.targetSizeInWords(); size_t sizeHint = params.targetSizeInWords();
...@@ -709,9 +706,12 @@ private: ...@@ -709,9 +706,12 @@ private:
auto request = newCall(interfaceId, methodId, sizeHint); auto request = newCall(interfaceId, methodId, sizeHint);
request.set(context->getParams()); request.set(params);
context->releaseParams(); context->releaseParams();
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto promise = request.send(); auto promise = request.send();
auto pipeline = promise.releasePipelineHook(); auto pipeline = promise.releasePipelineHook();
...@@ -738,7 +738,7 @@ private: ...@@ -738,7 +738,7 @@ private:
return kj::addRef(*this); return kj::addRef(*this);
} }
const void* getBrand() const override { const void* getBrand() const override {
return &connectionState; return connectionState.get();
} }
protected: protected:
...@@ -1183,6 +1183,10 @@ private: ...@@ -1183,6 +1183,10 @@ private:
} }
} }
void doneExtracting() {
resolutionChain = nullptr;
}
uint retainedListSizeHint(bool final) { uint retainedListSizeHint(bool final) {
// Get the expected size of the retained caps list, in words. If `final` is true, then it // Get the expected size of the retained caps list, in words. If `final` is true, then it
// is known that no more caps will be extracted after this point, so an exact value can be // is known that no more caps will be extracted after this point, so an exact value can be
...@@ -2029,6 +2033,7 @@ private: ...@@ -2029,6 +2033,7 @@ private:
} }
void releaseParams() override { void releaseParams() override {
request = nullptr; request = nullptr;
requestCapExtractor.doneExtracting();
} }
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override { ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
...@@ -2108,7 +2113,17 @@ private: ...@@ -2108,7 +2113,17 @@ private:
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void allowAsyncCancellation() override { void allowAsyncCancellation() override {
if (threadAcceptingCancellation != nullptr) { if (threadAcceptingCancellation == nullptr) {
// TODO(cleanup): We need to drop the request because it is holding on to the resolution
// chain which in turn holds on to the pipeline which holds on to this object thus
// preventing cancellation from working. This is a bit silly because obviously our
// request couldn't contain PromisedAnswers referring to itself, but currently the chain
// is a linear list and we have no way to tell that a reference to the chain taken before
// a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation.
releaseParams();
threadAcceptingCancellation = &kj::EventLoop::current(); threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) == if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
...@@ -2156,10 +2171,6 @@ private: ...@@ -2156,10 +2171,6 @@ private:
// When both flags are set, the cancellation process will begin. Must be manipulated atomically // When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads. // as it may be accessed from multiple threads.
mutable kj::Promise<void> deferredCancellation = nullptr;
// Cancellation operation scheduled by cancelLater(). Must only be scheduled once, from one
// thread.
kj::EventLoop* threadAcceptingCancellation = nullptr; kj::EventLoop* threadAcceptingCancellation = nullptr;
// EventLoop for the thread that first called allowAsyncCancellation(). We store this as an // EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
// optimization: if the application thread is independent from the network thread, we'd rather // optimization: if the application thread is independent from the network thread, we'd rather
...@@ -2176,31 +2187,29 @@ private: ...@@ -2176,31 +2187,29 @@ private:
// this call, shortly. We have to do it asynchronously because the caller might hold // this call, shortly. We have to do it asynchronously because the caller might hold
// arbitrary locks or might in fact be part of the task being canceled. // arbitrary locks or might in fact be part of the task being canceled.
deferredCancellation = threadAcceptingCancellation->evalLater([this]() { connectionState->tasks.add(threadAcceptingCancellation->evalLater(
// Make sure we don't accidentally delete ourselves in the process of canceling, since the kj::mvCapture(kj::addRef(*this), [](kj::Own<const RpcCallContext>&& self) {
// last reference to the context may be owned by the asyncOp.
auto self = kj::addRef(*this);
// Extract from the answer table the promise representing the executing call. // Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr; kj::Promise<void> asyncOp = nullptr;
{ {
auto lock = connectionState->tables.lockExclusive(); auto lock = self->connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp); asyncOp = kj::mv(lock->answers[self->questionId].asyncOp);
} }
// Delete the promise, thereby canceling the operation. Note that if a continuation is // When `asyncOp` goes out of scope, if it holds the last reference to the ongoing
// running in another thread, this line blocks waiting for it to complete. This is why // operation, that operation will be canceled. Note that if a continuation is
// we try to schedule doCancel() on the application thread, so that it won't need to block. // running in another thread, the destructor will block waiting for it to complete. This
asyncOp = nullptr; // is why we try to schedule doCancel() on the application thread, so that it won't need
// to block.
// The `Return` will be sent when the context is destroyed. That might be right now, when // The `Return` will be sent when the context is destroyed. That might be right now, when
// `self` goes out of scope. However, it is also possible that the pipeline is still in // `self` and `asyncOp` go out of scope. However, it is also possible that the pipeline
// use: although `Finish` removes the pipeline reference from the answer table, it might // is still in use: although `Finish` removes the pipeline reference from the answer
// be held by an outstanding pipelined call, or by a pipelined promise that was echoed back // table, it might be held by an outstanding pipelined call, or by a pipelined promise that
// to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be held in the // was echoed back to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be
// resolution chain. In all of these cases, the call will continue running until those // held in the resolution chain. In all of these cases, the call will continue running
// references are dropped or the call completes. // until those references are dropped or the call completes.
}); })));
} }
bool isFirstResponder() { bool isFirstResponder() {
...@@ -2423,10 +2432,12 @@ private: ...@@ -2423,10 +2432,12 @@ private:
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) { if (redirectResults) {
answer.redirectedResults = promiseAndPipeline.promise.then( auto promise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) { kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse(); return context->consumeRedirectedResponse();
})); }));
promise.eagerlyEvaluate(eventLoop);
answer.redirectedResults = kj::mv(promise);
} else { } else {
// Hack: Both the success and error continuations need to use the context. We could // Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway. // refcount, but both will be destroyed at the same time anyway.
...@@ -2688,7 +2699,8 @@ private: ...@@ -2688,7 +2699,8 @@ private:
} }
void handleRelease(const rpc::Release::Reader& release) { void handleRelease(const rpc::Release::Reader& release) {
releaseExport(*tables.lockExclusive(), release.getId(), release.getReferenceCount()); auto chainToRelease = releaseExport(
*tables.lockExclusive(), release.getId(), release.getReferenceCount());
} }
static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) { static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) {
......
...@@ -862,9 +862,7 @@ void checkDynamicTestMessageAllZero(DynamicStruct::Reader reader) { ...@@ -862,9 +862,7 @@ void checkDynamicTestMessageAllZero(DynamicStruct::Reader reader) {
TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestInterfaceImpl::foo( kj::Promise<void> TestInterfaceImpl::foo(FooParams::Reader params, FooResults::Builder result) {
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount; ++callCount;
EXPECT_EQ(123, params.getI()); EXPECT_EQ(123, params.getI());
EXPECT_TRUE(params.getJ()); EXPECT_TRUE(params.getJ());
...@@ -872,9 +870,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} ...@@ -872,9 +870,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestInterfaceImpl::bazAdvanced( kj::Promise<void> TestInterfaceImpl::bazAdvanced(CallContext<BazParams, BazResults> context) {
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
checkTestMessage(params.getS()); checkTestMessage(params.getS());
...@@ -886,9 +882,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} ...@@ -886,9 +882,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestExtendsImpl::foo( kj::Promise<void> TestExtendsImpl::foo(FooParams::Reader params, FooResults::Builder result) {
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount; ++callCount;
EXPECT_EQ(321, params.getI()); EXPECT_EQ(321, params.getI());
EXPECT_FALSE(params.getJ()); EXPECT_FALSE(params.getJ());
...@@ -896,8 +890,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} ...@@ -896,8 +890,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestExtendsImpl::graultAdvanced( kj::Promise<void> TestExtendsImpl::graultAdvanced(
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) { CallContext<GraultParams, test::TestAllTypes> context) {
++callCount; ++callCount;
context.releaseParams(); context.releaseParams();
...@@ -908,9 +902,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} ...@@ -908,9 +902,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestPipelineImpl::getCapAdvanced( kj::Promise<void> TestPipelineImpl::getCapAdvanced(
capnp::CallContext<test::TestPipeline::GetCapParams, CallContext<GetCapParams, GetCapResults> context) {
test::TestPipeline::GetCapResults> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -924,7 +917,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} ...@@ -924,7 +917,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[this,context](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [this,context](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
auto result = context.getResults(); auto result = context.getResults();
...@@ -934,8 +927,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} ...@@ -934,8 +927,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
} }
kj::Promise<void> TestCallOrderImpl::getCallSequence( kj::Promise<void> TestCallOrderImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(count++); result.setN(count++);
return kj::READY_NOW; return kj::READY_NOW;
} }
...@@ -943,8 +935,7 @@ kj::Promise<void> TestCallOrderImpl::getCallSequence( ...@@ -943,8 +935,7 @@ kj::Promise<void> TestCallOrderImpl::getCallSequence(
TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {} TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCallerImpl::fooAdvanced( kj::Promise<void> TestTailCallerImpl::fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) {
test::TestTailCallee::TailResult> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -957,8 +948,7 @@ kj::Promise<void> TestTailCallerImpl::fooAdvanced( ...@@ -957,8 +948,7 @@ kj::Promise<void> TestTailCallerImpl::fooAdvanced(
TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {} TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCalleeImpl::fooAdvanced( kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) {
test::TestTailCallee::TailResult> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -974,15 +964,13 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced( ...@@ -974,15 +964,13 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {} TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestMoreStuffImpl::getCallSequence( kj::Promise<void> TestMoreStuffImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(callCount++); result.setN(callCount++);
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestMoreStuffImpl::callFoo( kj::Promise<void> TestMoreStuffImpl::callFoo(
test::TestMoreStuff::CallFooParams::Reader params, CallFooParams::Reader params, CallFooResults::Builder result) {
test::TestMoreStuff::CallFooResults::Builder result) {
++callCount; ++callCount;
auto cap = params.getCap(); auto cap = params.getCap();
...@@ -992,9 +980,8 @@ kj::Promise<void> TestMoreStuffImpl::getCallSequence( ...@@ -992,9 +980,8 @@ kj::Promise<void> TestMoreStuffImpl::getCallSequence(
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
result.setS("bar"); result.setS("bar");
}); });
} }
...@@ -1012,13 +999,58 @@ kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved( ...@@ -1012,13 +999,58 @@ kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved(
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
result.setS("bar"); result.setS("bar");
}); });
}); });
} }
kj::Promise<void> TestMoreStuffImpl::neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) {
++callCount;
auto paf = kj::newPromiseAndFulfiller<void>();
neverFulfill = kj::mv(paf.fulfiller);
// Attach `cap` to the promise to make sure it is released.
paf.promise.attach(context.getParams().getCap());
// Also attach `cap` to the result struct to make sure that is released.
context.getResults().setCapCopy(context.getParams().getCap());
context.allowAsyncCancellation();
return kj::mv(paf.promise);
}
kj::Promise<void> TestMoreStuffImpl::hold(HoldParams::Reader params, HoldResults::Builder result) {
++callCount;
clientToHold = params.getCap();
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::callHeld(
CallHeldParams::Reader params, CallHeldResults::Builder result) {
++callCount;
auto request = clientToHold.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
kj::Promise<void> TestMoreStuffImpl::getHeld(
GetHeldParams::Reader params, GetHeldResults::Builder result) {
++callCount;
result.setCap(clientToHold);
return kj::READY_NOW;
}
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -148,13 +148,9 @@ class TestInterfaceImpl final: public test::TestInterface::Server { ...@@ -148,13 +148,9 @@ class TestInterfaceImpl final: public test::TestInterface::Server {
public: public:
TestInterfaceImpl(int& callCount); TestInterfaceImpl(int& callCount);
::kj::Promise<void> foo( kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
::kj::Promise<void> bazAdvanced( kj::Promise<void> bazAdvanced(CallContext<BazParams, BazResults> context) override;
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) override;
private: private:
int& callCount; int& callCount;
...@@ -164,12 +160,9 @@ class TestExtendsImpl final: public test::TestExtends::Server { ...@@ -164,12 +160,9 @@ class TestExtendsImpl final: public test::TestExtends::Server {
public: public:
TestExtendsImpl(int& callCount); TestExtendsImpl(int& callCount);
::kj::Promise<void> foo( kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
::kj::Promise<void> graultAdvanced( kj::Promise<void> graultAdvanced(CallContext<GraultParams, test::TestAllTypes> context) override;
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) override;
private: private:
int& callCount; int& callCount;
...@@ -179,9 +172,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server { ...@@ -179,9 +172,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server {
public: public:
TestPipelineImpl(int& callCount); TestPipelineImpl(int& callCount);
::kj::Promise<void> getCapAdvanced( kj::Promise<void> getCapAdvanced(CallContext<GetCapParams, GetCapResults> context) override;
capnp::CallContext<test::TestPipeline::GetCapParams,
test::TestPipeline::GetCapResults> context) override;
private: private:
int& callCount; int& callCount;
...@@ -190,8 +181,8 @@ private: ...@@ -190,8 +181,8 @@ private:
class TestCallOrderImpl final: public test::TestCallOrder::Server { class TestCallOrderImpl final: public test::TestCallOrder::Server {
public: public:
kj::Promise<void> getCallSequence( kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override; GetCallSequenceResults::Builder result) override;
private: private:
uint count = 0; uint count = 0;
...@@ -202,8 +193,7 @@ public: ...@@ -202,8 +193,7 @@ public:
TestTailCallerImpl(int& callCount); TestTailCallerImpl(int& callCount);
kj::Promise<void> fooAdvanced( kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
test::TestTailCallee::TailResult> context) override;
private: private:
int& callCount; int& callCount;
...@@ -214,8 +204,7 @@ public: ...@@ -214,8 +204,7 @@ public:
TestTailCalleeImpl(int& callCount); TestTailCalleeImpl(int& callCount);
kj::Promise<void> fooAdvanced( kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
test::TestTailCallee::TailResult> context) override;
private: private:
int& callCount; int& callCount;
...@@ -226,19 +215,32 @@ public: ...@@ -226,19 +215,32 @@ public:
TestMoreStuffImpl(int& callCount); TestMoreStuffImpl(int& callCount);
kj::Promise<void> getCallSequence( kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override; GetCallSequenceResults::Builder result) override;
::kj::Promise<void> callFoo( kj::Promise<void> callFoo(
test::TestMoreStuff::CallFooParams::Reader params, CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) override; CallFooResults::Builder result) override;
kj::Promise<void> callFooWhenResolved( kj::Promise<void> callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params, CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) override; CallFooWhenResolvedResults::Builder result) override;
kj::Promise<void> neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) override;
kj::Promise<void> hold(HoldParams::Reader params, HoldResults::Builder result) override;
kj::Promise<void> callHeld(CallHeldParams::Reader params,
CallHeldResults::Builder result) override;
kj::Promise<void> getHeld(GetHeldParams::Reader params,
GetHeldResults::Builder result) override;
private: private:
int& callCount; int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
test::TestInterface::Client clientToHold = nullptr;
}; };
} // namespace _ (private) } // namespace _ (private)
......
...@@ -637,6 +637,18 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -637,6 +637,18 @@ interface TestMoreStuff extends(TestCallOrder) {
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text); callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first. # Like callFoo but waits for `cap` to resolve first.
neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface);
# Doesn't return. You should cancel it.
hold @3 (cap :TestInterface) -> ();
# Returns immediately but holds on to the capability.
callHeld @4 () -> (s: Text);
# Calls the capability previously held using `hold` (and keeps holding it).
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
} }
struct TestSturdyRefHostId { struct TestSturdyRefHostId {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment