Commit 6b8b8c71 authored by Kenton Varda's avatar Kenton Varda

Test and fix cancellation and release.

parent b1e502f7
...@@ -6,7 +6,7 @@ ifeq ($(CXX),clang++) ...@@ -6,7 +6,7 @@ ifeq ($(CXX),clang++)
# Clang's verbose diagnostics don't play nice with the Ekam Eclipse plugin's error parsing, # Clang's verbose diagnostics don't play nice with the Ekam Eclipse plugin's error parsing,
# so disable them. Also enable some useful Clang warnings (dunno if GCC supports them, and don't # so disable them. Also enable some useful Clang warnings (dunno if GCC supports them, and don't
# care). # care).
EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi EXTRA_FLAG=-fno-caret-diagnostics -Wglobal-constructors -Wextra-semi -Werror=return-type
# EXTRA_FLAG=-fno-caret-diagnostics -Weverything -Wno-c++98-compat -Wno-shadow -Wno-c++98-compat-pedantic -Wno-padded -Wno-weak-vtables -Wno-gnu -Wno-unused-parameter -Wno-sign-conversion -Wno-undef -Wno-shorten-64-to-32 -Wno-conversion -Wno-unreachable-code -Wno-non-virtual-dtor # EXTRA_FLAG=-fno-caret-diagnostics -Weverything -Wno-c++98-compat -Wno-shadow -Wno-c++98-compat-pedantic -Wno-padded -Wno-weak-vtables -Wno-gnu -Wno-unused-parameter -Wno-sign-conversion -Wno-undef -Wno-shorten-64-to-32 -Wno-conversion -Wno-unreachable-code -Wno-non-virtual-dtor
else else
EXTRA_FLAG= EXTRA_FLAG=
......
...@@ -131,7 +131,9 @@ public: ...@@ -131,7 +131,9 @@ public:
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void allowAsyncCancellation() override { void allowAsyncCancellation() override {
// ignored for local calls releaseParams();
// TODO(soon): Implement.
} }
bool isCanceled() override { bool isCanceled() override {
return false; return false;
......
...@@ -107,7 +107,7 @@ class Capability::Client { ...@@ -107,7 +107,7 @@ class Capability::Client {
// Base type for capability clients. // Base type for capability clients.
public: public:
explicit Client(decltype(nullptr)); Client(decltype(nullptr));
// If you need to declare a Client before you have anything to assign to it (perhaps because // If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to // the assignment is going to occur in an if/else scope), you can start by initializing it to
// `nullptr`. The resulting client is not meant to be called and throws exceptions from all // `nullptr`. The resulting client is not meant to be called and throws exceptions from all
...@@ -249,6 +249,10 @@ public: ...@@ -249,6 +249,10 @@ public:
// executing on a local thread. The method must perform an asynchronous operation or call // executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control. // `EventLoop::current().runLater()` to yield control.
// //
// This method implies `releaseParams()` -- you cannot allow async cancellation while still
// holding the params. (This is because of a quirk of the current RPC implementation; in theory
// it could be fixed.)
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object // TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real // in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault // threat? Maybe we can just always enable cancellation. After all, you need to be fault
......
...@@ -1205,6 +1205,11 @@ private: ...@@ -1205,6 +1205,11 @@ private:
kj::String resultType = resultProto.getScopeId() == 0 ? kj::String resultType = resultProto.getScopeId() == 0 ?
kj::str(interfaceName, "::", titleCase, "Results") : cppFullName(resultSchema).flatten(); kj::str(interfaceName, "::", titleCase, "Results") : cppFullName(resultSchema).flatten();
kj::String shortParamType = paramProto.getScopeId() == 0 ?
kj::str(titleCase, "Params") : cppFullName(paramSchema).flatten();
kj::String shortResultType = resultProto.getScopeId() == 0 ?
kj::str(titleCase, "Results") : cppFullName(resultSchema).flatten();
auto interfaceProto = method.getContainingInterface().getProto(); auto interfaceProto = method.getContainingInterface().getProto();
uint64_t interfaceId = interfaceProto.getId(); uint64_t interfaceId = interfaceProto.getId();
auto interfaceIdHex = kj::hex(interfaceId); auto interfaceIdHex = kj::hex(interfaceId);
...@@ -1216,11 +1221,15 @@ private: ...@@ -1216,11 +1221,15 @@ private:
" unsigned int firstSegmentWordSize = 0) const;\n"), " unsigned int firstSegmentWordSize = 0) const;\n"),
kj::strTree( kj::strTree(
paramProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", paramType, " ", titleCase, "Params;\n"),
resultProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", resultType, " ", titleCase, "Results;\n"),
" virtual ::kj::Promise<void> ", name, "(\n" " virtual ::kj::Promise<void> ", name, "(\n"
" ", paramType, "::Reader params,\n" " ", shortParamType, "::Reader params,\n"
" ", resultType, "::Builder result);\n" " ", shortResultType, "::Builder result);\n"
" virtual ::kj::Promise<void> ", name, "Advanced(\n" " virtual ::kj::Promise<void> ", name, "Advanced(\n"
" ::capnp::CallContext<", paramType, ", ", resultType, "> context);\n"), " ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> context);\n"),
kj::strTree(), kj::strTree(),
...@@ -1298,17 +1307,15 @@ private: ...@@ -1298,17 +1307,15 @@ private:
" typedef ", fullName, " Calls;\n" " typedef ", fullName, " Calls;\n"
" typedef ", fullName, " Reads;\n" " typedef ", fullName, " Reads;\n"
"\n" "\n"
" inline explicit Client(decltype(nullptr))\n" " inline Client(decltype(nullptr))\n"
" : ::capnp::Capability::Client(nullptr) {}\n" " : ::capnp::Capability::Client(nullptr) {}\n"
" inline explicit Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n" " inline explicit Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n"
" : ::capnp::Capability::Client(::kj::mv(hook)) {}\n" " : ::capnp::Capability::Client(::kj::mv(hook)) {}\n"
" template <typename T,\n" " template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Server*>()>>\n"
" inline Client(::kj::Own<T>&& server,\n" " inline Client(::kj::Own<T>&& server,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n" " const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n" " : ::capnp::Capability::Client(::kj::mv(server), loop) {}\n"
" template <typename T,\n" " template <typename T, typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" typename = ::kj::EnableIf< ::kj::canConvert<T*, Client*>()>>\n"
" inline Client(::kj::Promise<T>&& promise,\n" " inline Client(::kj::Promise<T>&& promise,\n"
" const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n" " const ::kj::EventLoop& loop = ::kj::EventLoop::current())\n"
" : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n" " : ::capnp::Capability::Client(::kj::mv(promise), loop) {}\n"
......
This diff is collapsed.
...@@ -690,9 +690,6 @@ private: ...@@ -690,9 +690,6 @@ private:
kj::Own<CallContextHook>&& context) const override { kj::Own<CallContextHook>&& context) const override {
// Implement call() by copying params and results messages. // Implement call() by copying params and results messages.
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto params = context->getParams(); auto params = context->getParams();
size_t sizeHint = params.targetSizeInWords(); size_t sizeHint = params.targetSizeInWords();
...@@ -709,9 +706,12 @@ private: ...@@ -709,9 +706,12 @@ private:
auto request = newCall(interfaceId, methodId, sizeHint); auto request = newCall(interfaceId, methodId, sizeHint);
request.set(context->getParams()); request.set(params);
context->releaseParams(); context->releaseParams();
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto promise = request.send(); auto promise = request.send();
auto pipeline = promise.releasePipelineHook(); auto pipeline = promise.releasePipelineHook();
...@@ -738,7 +738,7 @@ private: ...@@ -738,7 +738,7 @@ private:
return kj::addRef(*this); return kj::addRef(*this);
} }
const void* getBrand() const override { const void* getBrand() const override {
return &connectionState; return connectionState.get();
} }
protected: protected:
...@@ -1183,6 +1183,10 @@ private: ...@@ -1183,6 +1183,10 @@ private:
} }
} }
void doneExtracting() {
resolutionChain = nullptr;
}
uint retainedListSizeHint(bool final) { uint retainedListSizeHint(bool final) {
// Get the expected size of the retained caps list, in words. If `final` is true, then it // Get the expected size of the retained caps list, in words. If `final` is true, then it
// is known that no more caps will be extracted after this point, so an exact value can be // is known that no more caps will be extracted after this point, so an exact value can be
...@@ -2029,6 +2033,7 @@ private: ...@@ -2029,6 +2033,7 @@ private:
} }
void releaseParams() override { void releaseParams() override {
request = nullptr; request = nullptr;
requestCapExtractor.doneExtracting();
} }
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override { ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
...@@ -2108,7 +2113,17 @@ private: ...@@ -2108,7 +2113,17 @@ private:
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void allowAsyncCancellation() override { void allowAsyncCancellation() override {
if (threadAcceptingCancellation != nullptr) { if (threadAcceptingCancellation == nullptr) {
// TODO(cleanup): We need to drop the request because it is holding on to the resolution
// chain which in turn holds on to the pipeline which holds on to this object thus
// preventing cancellation from working. This is a bit silly because obviously our
// request couldn't contain PromisedAnswers referring to itself, but currently the chain
// is a linear list and we have no way to tell that a reference to the chain taken before
// a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation.
releaseParams();
threadAcceptingCancellation = &kj::EventLoop::current(); threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) == if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
...@@ -2156,10 +2171,6 @@ private: ...@@ -2156,10 +2171,6 @@ private:
// When both flags are set, the cancellation process will begin. Must be manipulated atomically // When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads. // as it may be accessed from multiple threads.
mutable kj::Promise<void> deferredCancellation = nullptr;
// Cancellation operation scheduled by cancelLater(). Must only be scheduled once, from one
// thread.
kj::EventLoop* threadAcceptingCancellation = nullptr; kj::EventLoop* threadAcceptingCancellation = nullptr;
// EventLoop for the thread that first called allowAsyncCancellation(). We store this as an // EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
// optimization: if the application thread is independent from the network thread, we'd rather // optimization: if the application thread is independent from the network thread, we'd rather
...@@ -2176,31 +2187,29 @@ private: ...@@ -2176,31 +2187,29 @@ private:
// this call, shortly. We have to do it asynchronously because the caller might hold // this call, shortly. We have to do it asynchronously because the caller might hold
// arbitrary locks or might in fact be part of the task being canceled. // arbitrary locks or might in fact be part of the task being canceled.
deferredCancellation = threadAcceptingCancellation->evalLater([this]() { connectionState->tasks.add(threadAcceptingCancellation->evalLater(
// Make sure we don't accidentally delete ourselves in the process of canceling, since the kj::mvCapture(kj::addRef(*this), [](kj::Own<const RpcCallContext>&& self) {
// last reference to the context may be owned by the asyncOp.
auto self = kj::addRef(*this);
// Extract from the answer table the promise representing the executing call. // Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr; kj::Promise<void> asyncOp = nullptr;
{ {
auto lock = connectionState->tables.lockExclusive(); auto lock = self->connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp); asyncOp = kj::mv(lock->answers[self->questionId].asyncOp);
} }
// Delete the promise, thereby canceling the operation. Note that if a continuation is // When `asyncOp` goes out of scope, if it holds the last reference to the ongoing
// running in another thread, this line blocks waiting for it to complete. This is why // operation, that operation will be canceled. Note that if a continuation is
// we try to schedule doCancel() on the application thread, so that it won't need to block. // running in another thread, the destructor will block waiting for it to complete. This
asyncOp = nullptr; // is why we try to schedule doCancel() on the application thread, so that it won't need
// to block.
// The `Return` will be sent when the context is destroyed. That might be right now, when // The `Return` will be sent when the context is destroyed. That might be right now, when
// `self` goes out of scope. However, it is also possible that the pipeline is still in // `self` and `asyncOp` go out of scope. However, it is also possible that the pipeline
// use: although `Finish` removes the pipeline reference from the answer table, it might // is still in use: although `Finish` removes the pipeline reference from the answer
// be held by an outstanding pipelined call, or by a pipelined promise that was echoed back // table, it might be held by an outstanding pipelined call, or by a pipelined promise that
// to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be held in the // was echoed back to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be
// resolution chain. In all of these cases, the call will continue running until those // held in the resolution chain. In all of these cases, the call will continue running
// references are dropped or the call completes. // until those references are dropped or the call completes.
}); })));
} }
bool isFirstResponder() { bool isFirstResponder() {
...@@ -2423,10 +2432,12 @@ private: ...@@ -2423,10 +2432,12 @@ private:
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) { if (redirectResults) {
answer.redirectedResults = promiseAndPipeline.promise.then( auto promise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) { kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse(); return context->consumeRedirectedResponse();
})); }));
promise.eagerlyEvaluate(eventLoop);
answer.redirectedResults = kj::mv(promise);
} else { } else {
// Hack: Both the success and error continuations need to use the context. We could // Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway. // refcount, but both will be destroyed at the same time anyway.
...@@ -2688,7 +2699,8 @@ private: ...@@ -2688,7 +2699,8 @@ private:
} }
void handleRelease(const rpc::Release::Reader& release) { void handleRelease(const rpc::Release::Reader& release) {
releaseExport(*tables.lockExclusive(), release.getId(), release.getReferenceCount()); auto chainToRelease = releaseExport(
*tables.lockExclusive(), release.getId(), release.getReferenceCount());
} }
static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) { static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) {
......
...@@ -862,9 +862,7 @@ void checkDynamicTestMessageAllZero(DynamicStruct::Reader reader) { ...@@ -862,9 +862,7 @@ void checkDynamicTestMessageAllZero(DynamicStruct::Reader reader) {
TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestInterfaceImpl::foo( kj::Promise<void> TestInterfaceImpl::foo(FooParams::Reader params, FooResults::Builder result) {
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount; ++callCount;
EXPECT_EQ(123, params.getI()); EXPECT_EQ(123, params.getI());
EXPECT_TRUE(params.getJ()); EXPECT_TRUE(params.getJ());
...@@ -872,9 +870,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} ...@@ -872,9 +870,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestInterfaceImpl::bazAdvanced( kj::Promise<void> TestInterfaceImpl::bazAdvanced(CallContext<BazParams, BazResults> context) {
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
checkTestMessage(params.getS()); checkTestMessage(params.getS());
...@@ -886,9 +882,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {} ...@@ -886,9 +882,7 @@ TestInterfaceImpl::TestInterfaceImpl(int& callCount): callCount(callCount) {}
TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestExtendsImpl::foo( kj::Promise<void> TestExtendsImpl::foo(FooParams::Reader params, FooResults::Builder result) {
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount; ++callCount;
EXPECT_EQ(321, params.getI()); EXPECT_EQ(321, params.getI());
EXPECT_FALSE(params.getJ()); EXPECT_FALSE(params.getJ());
...@@ -896,8 +890,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} ...@@ -896,8 +890,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestExtendsImpl::graultAdvanced( kj::Promise<void> TestExtendsImpl::graultAdvanced(
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) { CallContext<GraultParams, test::TestAllTypes> context) {
++callCount; ++callCount;
context.releaseParams(); context.releaseParams();
...@@ -908,9 +902,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {} ...@@ -908,9 +902,8 @@ TestExtendsImpl::TestExtendsImpl(int& callCount): callCount(callCount) {}
TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
::kj::Promise<void> TestPipelineImpl::getCapAdvanced( kj::Promise<void> TestPipelineImpl::getCapAdvanced(
capnp::CallContext<test::TestPipeline::GetCapParams, CallContext<GetCapParams, GetCapResults> context) {
test::TestPipeline::GetCapResults> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -924,7 +917,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} ...@@ -924,7 +917,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[this,context](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [this,context](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
auto result = context.getResults(); auto result = context.getResults();
...@@ -934,8 +927,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {} ...@@ -934,8 +927,7 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
} }
kj::Promise<void> TestCallOrderImpl::getCallSequence( kj::Promise<void> TestCallOrderImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(count++); result.setN(count++);
return kj::READY_NOW; return kj::READY_NOW;
} }
...@@ -943,8 +935,7 @@ kj::Promise<void> TestCallOrderImpl::getCallSequence( ...@@ -943,8 +935,7 @@ kj::Promise<void> TestCallOrderImpl::getCallSequence(
TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {} TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCallerImpl::fooAdvanced( kj::Promise<void> TestTailCallerImpl::fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) {
test::TestTailCallee::TailResult> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -957,8 +948,7 @@ kj::Promise<void> TestTailCallerImpl::fooAdvanced( ...@@ -957,8 +948,7 @@ kj::Promise<void> TestTailCallerImpl::fooAdvanced(
TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {} TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCalleeImpl::fooAdvanced( kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) {
test::TestTailCallee::TailResult> context) {
++callCount; ++callCount;
auto params = context.getParams(); auto params = context.getParams();
...@@ -974,15 +964,13 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced( ...@@ -974,15 +964,13 @@ kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {} TestMoreStuffImpl::TestMoreStuffImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestMoreStuffImpl::getCallSequence( kj::Promise<void> TestMoreStuffImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params, GetCallSequenceResults::Builder result) {
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(callCount++); result.setN(callCount++);
return kj::READY_NOW; return kj::READY_NOW;
} }
::kj::Promise<void> TestMoreStuffImpl::callFoo( kj::Promise<void> TestMoreStuffImpl::callFoo(
test::TestMoreStuff::CallFooParams::Reader params, CallFooParams::Reader params, CallFooResults::Builder result) {
test::TestMoreStuff::CallFooResults::Builder result) {
++callCount; ++callCount;
auto cap = params.getCap(); auto cap = params.getCap();
...@@ -992,9 +980,8 @@ kj::Promise<void> TestMoreStuffImpl::getCallSequence( ...@@ -992,9 +980,8 @@ kj::Promise<void> TestMoreStuffImpl::getCallSequence(
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
result.setS("bar"); result.setS("bar");
}); });
} }
...@@ -1012,13 +999,58 @@ kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved( ...@@ -1012,13 +999,58 @@ kj::Promise<void> TestMoreStuffImpl::callFooWhenResolved(
request.setJ(true); request.setJ(true);
return request.send().then( return request.send().then(
[result](capnp::Response<test::TestInterface::FooResults>&& response) mutable { [result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX()); EXPECT_EQ("foo", response.getX());
result.setS("bar"); result.setS("bar");
}); });
}); });
} }
kj::Promise<void> TestMoreStuffImpl::neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) {
++callCount;
auto paf = kj::newPromiseAndFulfiller<void>();
neverFulfill = kj::mv(paf.fulfiller);
// Attach `cap` to the promise to make sure it is released.
paf.promise.attach(context.getParams().getCap());
// Also attach `cap` to the result struct to make sure that is released.
context.getResults().setCapCopy(context.getParams().getCap());
context.allowAsyncCancellation();
return kj::mv(paf.promise);
}
kj::Promise<void> TestMoreStuffImpl::hold(HoldParams::Reader params, HoldResults::Builder result) {
++callCount;
clientToHold = params.getCap();
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::callHeld(
CallHeldParams::Reader params, CallHeldResults::Builder result) {
++callCount;
auto request = clientToHold.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[result](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
result.setS("bar");
});
}
kj::Promise<void> TestMoreStuffImpl::getHeld(
GetHeldParams::Reader params, GetHeldResults::Builder result) {
++callCount;
result.setCap(clientToHold);
return kj::READY_NOW;
}
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -148,13 +148,9 @@ class TestInterfaceImpl final: public test::TestInterface::Server { ...@@ -148,13 +148,9 @@ class TestInterfaceImpl final: public test::TestInterface::Server {
public: public:
TestInterfaceImpl(int& callCount); TestInterfaceImpl(int& callCount);
::kj::Promise<void> foo( kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
::kj::Promise<void> bazAdvanced( kj::Promise<void> bazAdvanced(CallContext<BazParams, BazResults> context) override;
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) override;
private: private:
int& callCount; int& callCount;
...@@ -164,12 +160,9 @@ class TestExtendsImpl final: public test::TestExtends::Server { ...@@ -164,12 +160,9 @@ class TestExtendsImpl final: public test::TestExtends::Server {
public: public:
TestExtendsImpl(int& callCount); TestExtendsImpl(int& callCount);
::kj::Promise<void> foo( kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) override;
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) override;
::kj::Promise<void> graultAdvanced( kj::Promise<void> graultAdvanced(CallContext<GraultParams, test::TestAllTypes> context) override;
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) override;
private: private:
int& callCount; int& callCount;
...@@ -179,9 +172,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server { ...@@ -179,9 +172,7 @@ class TestPipelineImpl final: public test::TestPipeline::Server {
public: public:
TestPipelineImpl(int& callCount); TestPipelineImpl(int& callCount);
::kj::Promise<void> getCapAdvanced( kj::Promise<void> getCapAdvanced(CallContext<GetCapParams, GetCapResults> context) override;
capnp::CallContext<test::TestPipeline::GetCapParams,
test::TestPipeline::GetCapResults> context) override;
private: private:
int& callCount; int& callCount;
...@@ -190,8 +181,8 @@ private: ...@@ -190,8 +181,8 @@ private:
class TestCallOrderImpl final: public test::TestCallOrder::Server { class TestCallOrderImpl final: public test::TestCallOrder::Server {
public: public:
kj::Promise<void> getCallSequence( kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override; GetCallSequenceResults::Builder result) override;
private: private:
uint count = 0; uint count = 0;
...@@ -202,8 +193,7 @@ public: ...@@ -202,8 +193,7 @@ public:
TestTailCallerImpl(int& callCount); TestTailCallerImpl(int& callCount);
kj::Promise<void> fooAdvanced( kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
test::TestTailCallee::TailResult> context) override;
private: private:
int& callCount; int& callCount;
...@@ -214,8 +204,7 @@ public: ...@@ -214,8 +204,7 @@ public:
TestTailCalleeImpl(int& callCount); TestTailCalleeImpl(int& callCount);
kj::Promise<void> fooAdvanced( kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams, CallContext<FooParams, test::TestTailCallee::TailResult> context) override;
test::TestTailCallee::TailResult> context) override;
private: private:
int& callCount; int& callCount;
...@@ -226,19 +215,32 @@ public: ...@@ -226,19 +215,32 @@ public:
TestMoreStuffImpl(int& callCount); TestMoreStuffImpl(int& callCount);
kj::Promise<void> getCallSequence( kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params, GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override; GetCallSequenceResults::Builder result) override;
::kj::Promise<void> callFoo( kj::Promise<void> callFoo(
test::TestMoreStuff::CallFooParams::Reader params, CallFooParams::Reader params,
test::TestMoreStuff::CallFooResults::Builder result) override; CallFooResults::Builder result) override;
kj::Promise<void> callFooWhenResolved( kj::Promise<void> callFooWhenResolved(
test::TestMoreStuff::CallFooWhenResolvedParams::Reader params, CallFooWhenResolvedParams::Reader params,
test::TestMoreStuff::CallFooWhenResolvedResults::Builder result) override; CallFooWhenResolvedResults::Builder result) override;
kj::Promise<void> neverReturnAdvanced(
CallContext<NeverReturnParams, NeverReturnResults> context) override;
kj::Promise<void> hold(HoldParams::Reader params, HoldResults::Builder result) override;
kj::Promise<void> callHeld(CallHeldParams::Reader params,
CallHeldResults::Builder result) override;
kj::Promise<void> getHeld(GetHeldParams::Reader params,
GetHeldResults::Builder result) override;
private: private:
int& callCount; int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
test::TestInterface::Client clientToHold = nullptr;
}; };
} // namespace _ (private) } // namespace _ (private)
......
...@@ -637,6 +637,18 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -637,6 +637,18 @@ interface TestMoreStuff extends(TestCallOrder) {
callFooWhenResolved @1 (cap :TestInterface) -> (s: Text); callFooWhenResolved @1 (cap :TestInterface) -> (s: Text);
# Like callFoo but waits for `cap` to resolve first. # Like callFoo but waits for `cap` to resolve first.
neverReturn @2 (cap :TestInterface) -> (capCopy :TestInterface);
# Doesn't return. You should cancel it.
hold @3 (cap :TestInterface) -> ();
# Returns immediately but holds on to the capability.
callHeld @4 () -> (s: Text);
# Calls the capability previously held using `hold` (and keeps holding it).
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
} }
struct TestSturdyRefHostId { struct TestSturdyRefHostId {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment