Commit 6b8b8c71 authored by Kenton Varda's avatar Kenton Varda

Test and fix cancellation and release.

parent b1e502f7
......@@ -6,7 +6,7 @@ ifeq ($(CXX),clang++)
# Clang's verbose diagnostics don't play nice with the Ekam Eclipse plugin's error parsing,
# so disable them. Also enable some useful Clang warnings (dunno if GCC supports them, and don't
# care).
EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi
EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi -Werror=return-type
# EXTRA_FLAG=-fno-caret-diagnostics -Weverything -Wno-c++98-compat -Wno-shadow -Wno-c++98-compat-pedantic -Wno-padded -Wno-weak-vtables -Wno-gnu -Wno-unused-parameter -Wno-sign-conversion -Wno-undef -Wno-shorten-64-to-32 -Wno-conversion -Wno-unreachable-code -Wno-non-virtual-dtor
else
EXTRA_FLAG=
......
......@@ -131,7 +131,9 @@ public:
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
// ignored for local calls
releaseParams();
// TODO(soon): Implement.
}
bool isCanceled() override {
return false;
......
......@@ -107,7 +107,7 @@ class Capability::Client {
// Base type for capability clients.
public:
explicit Client(decltype(nullptr));
Client(decltype(nullptr));
// If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to
// `nullptr`. The resulting client is not meant to be called and throws exceptions from all
......@@ -249,6 +249,10 @@ public:
// executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control.
//
// This method implies `releaseParams()` -- you cannot allow async cancellation while still
// holding the params. (This is because of a quirk of the current RPC implementation; in theory
// it could be fixed.)
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault
......
......@@ -1205,6 +1205,11 @@ private:
kj::String resultType = resultProto.getScopeId() == 0 ?
kj::str(interfaceName, "::", titleCase, "Results") : cppFullName(resultSchema).flatten();
kj::String shortParamType = paramProto.getScopeId() == 0 ?
kj::str(titleCase, "Params") : cppFullName(paramSchema).flatten();
kj::String shortResultType = resultProto.getScopeId() == 0 ?
kj::str(titleCase, "Results") : cppFullName(resultSchema).flatten();
auto interfaceProto = method.getContainingInterface().getProto();
uint64_t interfaceId = interfaceProto.getId();
auto interfaceIdHex = kj::hex(interfaceId);
......@@ -1216,11 +1221,15 @@ private:
" unsigned int firstSegmentWordSize = 0) const;\n"),
kj::strTree(
paramProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", paramType, " ", titleCase, "Params;\n"),
resultProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", resultType, " ", titleCase, "Results;\n"),
" virtual ::kj::Promise<void> ", name, "(\n"
" ", paramType, "::Reader params,\n"
" ", resultType, "::Builder result);\n"
" ", shortParamType, "::Reader params,\n"
" ", shortResultType, "::Builder result);\n"
" virtual ::kj::Promise<void> ", name, "Advanced(\n"
" ::capnp::CallContext<", paramType, ", ", resultType, "> context);\n"),
" ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> context);\n"),
kj::strTree(),
......@@ -1298,17 +1307,15 @@ private:
" typedef ", fullName, " Calls;\n"
" typedef ", fullName, " Reads;\n"
"\n"
" inline explicit Client(decltype(nullptr))\n"
" inline Client(decltype(nullptr))\n"
" : ::capnp::Capability::Client(nullptr) {}\n"
" inline explicit Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n"
" : ::capnp::Capability::Client(::kj::mv(hook)) {}\n"
" template <typename T,\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" inline Client(::kj::Own<T>&& server,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n"
" template <typename T,\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" inline Client(::kj::Promise<T>&& promise,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n"
......
This diff is collapsed.
......@@ -690,9 +690,6 @@ private:
kj::Own<CallContextHook>&& context) const override {
// Implement call() by copying params and results messages.
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto params = context->getParams();
size_t sizeHint = params.targetSizeInWords();
......@@ -709,9 +706,12 @@ private:
auto request = newCall(interfaceId, methodId, sizeHint);
request.set(context->getParams());
request.set(params);
context->releaseParams();
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto promise = request.send();
auto pipeline = promise.releasePipelineHook();
......@@ -738,7 +738,7 @@ private:
return kj::addRef(*this);
}
const void* getBrand() const override {
return &connectionState;
return connectionState.get();
}
protected:
......@@ -1183,6 +1183,10 @@ private:
}
}
void doneExtracting() {
resolutionChain = nullptr;
}
uint retainedListSizeHint(bool final) {
// Get the expected size of the retained caps list, in words. If `final` is true, then it
// is known that no more caps will be extracted after this point, so an exact value can be
......@@ -2029,6 +2033,7 @@ private:
}
void releaseParams() override {
request = nullptr;
requestCapExtractor.doneExtracting();
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) {
......@@ -2108,7 +2113,17 @@ private:
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
if (threadAcceptingCancellation != nullptr) {
if (threadAcceptingCancellation == nullptr) {
// TODO(cleanup): We need to drop the request because it is holding on to the resolution
// chain which in turn holds on to the pipeline which holds on to this object thus
// preventing cancellation from working. This is a bit silly because obviously our
// request couldn't contain PromisedAnswers referring to itself, but currently the chain
// is a linear list and we have no way to tell that a reference to the chain taken before
// a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation.
releaseParams();
threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
......@@ -2156,10 +2171,6 @@ private:
// When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads.
mutable kj::Promise<void> deferredCancellation = nullptr;
// Cancellation operation scheduled by cancelLater(). Must only be scheduled once, from one
// thread.
kj::EventLoop* threadAcceptingCancellation = nullptr;
// EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
// optimization: if the application thread is independent from the network thread, we'd rather
......@@ -2176,31 +2187,29 @@ private:
// this call, shortly. We have to do it asynchronously because the caller might hold
// arbitrary locks or might in fact be part of the task being canceled.
deferredCancellation = threadAcceptingCancellation->evalLater([this]() {
// Make sure we don't accidentally delete ourselves in the process of canceling, since the
// last reference to the context may be owned by the asyncOp.
auto self = kj::addRef(*this);
connectionState->tasks.add(threadAcceptingCancellation->evalLater(
kj::mvCapture(kj::addRef(*this), [](kj::Own<const RpcCallContext>&& self) {
// Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr;
{
auto lock = connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp);
auto lock = self->connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[self->questionId].asyncOp);
}
// Delete the promise, thereby canceling the operation. Note that if a continuation is
// running in another thread, this line blocks waiting for it to complete. This is why
// we try to schedule doCancel() on the application thread, so that it won't need to block.
asyncOp = nullptr;
// When `asyncOp` goes out of scope, if it holds the last reference to the ongoing
// operation, that operation will be canceled. Note that if a continuation is
// running in another thread, the destructor will block waiting for it to complete. This
// is why we try to schedule doCancel() on the application thread, so that it won't need
// to block.
// The `Return` will be sent when the context is destroyed. That might be right now, when
// `self` goes out of scope. However, it is also possible that the pipeline is still in
// use: although `Finish` removes the pipeline reference from the answer table, it might
// be held by an outstanding pipelined call, or by a pipelined promise that was echoed back
// to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be held in the
// resolution chain. In all of these cases, the call will continue running until those
// references are dropped or the call completes.
});
// `self` and `asyncOp` go out of scope. However, it is also possible that the pipeline
// is still in use: although `Finish` removes the pipeline reference from the answer
// table, it might be held by an outstanding pipelined call, or by a pipelined promise that
// was echoed back to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be
// held in the resolution chain. In all of these cases, the call will continue running
// until those references are dropped or the call completes.
})));
}
bool isFirstResponder() {
......@@ -2423,10 +2432,12 @@ private:
answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) {
answer.redirectedResults = promiseAndPipeline.promise.then(
auto promise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse();
}));
promise.eagerlyEvaluate(eventLoop);
answer.redirectedResults = kj::mv(promise);
} else {
// Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway.
......@@ -2688,7 +2699,8 @@ private:
}
void handleRelease(const rpc::Release::Reader& release) {
releaseExport(*tables.lockExclusive(), release.getId(), release.getReferenceCount());
auto chainToRelease = releaseExport(
*tables.lockExclusive(), release.getId(), release.getReferenceCount());
}
static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) {
......
......@@ -862,9 +862,7 @@ void checkDynamicTestMessageAllZero(DynamicStruct::Reader reader) {
TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestInterfaceImpl::foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
kj::Promise<void> TestInterfaceImpl::foo(FooParams::Reader params, FooResults::Builder result) {
++callCount;
EXPECT_EQ(123, params.getI());
EXPECT_TRUE(params.getJ());
......@@ -872,9 +870,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW;
}
::kj::Promise<void> TestInterfaceImpl::bazAdvanced(
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) {
kj::Promise<void> TestInterfaceImpl::bazAdvanced(CallContext<BazParams, BazResults> context) {
++callCount;
auto params = context.getParams();
checkTestMessage(params.getS());
......@@ -886,9 +882,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestExtendsImpl::foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
kj::Promise<void> TestExtendsImpl::foo(FooParams::Reader params, FooResults::Builder result) {
++callCount;
EXPECT_EQ(321, params.getI());
EXPECT_FALSE(params.getJ());
......@@ -896,8 +890,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW;
}
::kj::Promise<void> TestExtendsImpl::graultAdvanced(
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) {
kj::Promise<void> TestExtendsImpl::graultAdvanced(
CallContext<GraultParams, test::TestAllTypes> context) {
++callCount;
context.releaseParams();
......@@ -908,9 +902,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestPipelineImpl::getCapAdvanced(
capnp::CallContext<test::TestPipeline::GetCapParams,
test::TestPipeline::GetCapResults> context) {
kj::Promise<void> TestPipelineImpl::getCapAdvanced(
CallContext<GetCapParams, GetCapResults> context) {
++callCount;
auto params = context.getParams();
......@@ -924,7 +917,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
request.setJ(true);
return request.send().then(
[this,context](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
[this,context](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
auto result = context.getResults();
......@@ -934,8 +927,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
}
kj::Promise<void> TestCallOrderImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) {
GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
result.setN(count++);
return kj::READY_NOW;
}
......@@ -943,8 +935,7 @@ kj::Promise<void> TestCallOrderImpl::getCallSequence(
TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCallerImpl::fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams,
test::TestTailCallee::TailResult> context) {
CallContext<FooParams, test::TestTailCallee::TailResult> context) {
++callCount;
auto params = context.getParams();
......@@ -957,8 +948,7 @@ kj::Promise<void> TestTailCallerImpl::fooAdvanced(
TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams,
test::TestTailCallee::TailResult> context) {
CallContext<FooParams, test::TestTailCallee::TailResult> context) {
++callCount;
auto params = context.getParams();
......@@ -974,15 +964,13 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestMoreStuffImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) {
GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
result.setN(callCount++);
return kj::READY_NOW;
}
::kj::Promise<void> TestMoreStuffImpl::callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) {
kj::Promise<void> TestMoreStuffImpl::callFoo(
CallFooParams::Reader params, CallFooResults::Builder result) {
++callCount;
auto cap = params.getCap();
......@@ -992,9 +980,8 @@ kj::Promise<void> TestMoreStuffImpl::getCallSequence(
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
[result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
......@@ -1012,13 +999,58 @@ kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved(
request.setJ(true);
return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable {
[result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
});
}
kj::Promise<void> TestMoreStuffImpl::neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) {
++callCount;
auto paf = kj::newPromiseAndFulfiller<void>();
neverFulfill = kj::mv(paf.fulfiller);
// Attach `cap` to the promise to make sure it is released.
paf.promise.attach(context.getParams().getCap());
// Also attach `cap` to the result struct to make sure that is released.
context.getResults().setCapCopy(context.getParams().getCap());
context.allowAsyncCancellation();
return kj::mv(paf.promise);
}
kj::Promise<void> TestMoreStuffImpl::hold(HoldParams::Reader params, HoldResults::Builder result) {
++callCount;
clientToHold = params.getCap();
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::callHeld(
CallHeldParams::Reader params, CallHeldResults::Builder result) {
++callCount;
auto request = clientToHold.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
kj::Promise<void> TestMoreStuffImpl::getHeld(
GetHeldParams::Reader params, GetHeldResults::Builder result) {
++callCount;
result.setCap(clientToHold);
return kj::READY_NOW;
}
} // namespace _ (private)
} // namespace capnp
......@@ -148,13 +148,9 @@ class TestInterfaceImpl final: public test::TestInterface::Server {
public:
TestInterfaceImpl(int& callCount);
::kj::Promise<void> foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
::kj::Promise<void> bazAdvanced(
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) override;
kj::Promise<void> bazAdvanced(CallContext<BazParams, BazResults> context) override;
private:
int& callCount;
......@@ -164,12 +160,9 @@ class TestExtendsImpl final: public test::TestExtends::Server {
public:
TestExtendsImpl(int& callCount);
::kj::Promise<void> foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
::kj::Promise<void> graultAdvanced(
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) override;
kj::Promise<void> graultAdvanced(CallContext<GraultParams, test::TestAllTypes> context) override;
private:
int& callCount;
......@@ -179,9 +172,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server {
public:
TestPipelineImpl(int& callCount);
::kj::Promise<void> getCapAdvanced(
capnp::CallContext<test::TestPipeline::GetCapParams,
test::TestPipeline::GetCapResults> context) override;
kj::Promise<void> getCapAdvanced(CallContext<GetCapParams, GetCapResults> context) override;
private:
int& callCount;
......@@ -190,8 +181,8 @@ private:
class TestCallOrderImpl final: public test::TestCallOrder::Server {
public:
kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override;
GetCallSequenceParams::Reader params,
GetCallSequenceResults::Builder result) override;
private:
uint count = 0;
......@@ -202,8 +193,7 @@ public:
TestTailCallerImpl(int& callCount);
kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams,
test::TestTailCallee::TailResult> context) override;
CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
private:
int& callCount;
......@@ -214,8 +204,7 @@ public:
TestTailCalleeImpl(int& callCount);
kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams,
test::TestTailCallee::TailResult> context) override;
CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
private:
int& callCount;
......@@ -226,19 +215,32 @@ public:
TestMoreStuffImpl(int& callCount);
kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override;
GetCallSequenceParams::Reader params,
GetCallSequenceResults::Builder result) override;
::kj::Promise<void> callFoo(
test::TestMoreStuff::CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) override;
kj::Promise<void> callFoo(
CallFooParams::Reader params,
CallFooResults::Builder result) override;
kj::Promise<void> callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) override;
CallFooWhenResolvedParams::Reader params,
CallFooWhenResolvedResults::Builder result) override;
kj::Promise<void> neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) override;
kj::Promise<void> hold(HoldParams::Reader params, HoldResults::Builder result) override;
kj::Promise<void> callHeld(CallHeldParams::Reader params,
CallHeldResults::Builder result) override;
kj::Promise<void> getHeld(GetHeldParams::Reader params,
GetHeldResults::Builder result) override;
private:
int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
test::TestInterface::Client clientToHold = nullptr;
};
} // namespace _ (private)
......
......@@ -637,6 +637,18 @@ interface TestMoreStuff extends(TestCallOrder) {
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first.
neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface);
# Doesn't return. You should cancel it.
hold @3 (cap :TestInterface) -> ();
# Returns immediately but holds on to the capability.
callHeld @4 () -> (s: Text);
# Calls the capability previously held using `hold` (and keeps holding it).
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
}
struct TestSturdyRefHostId {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment