Commit 72af96df authored by Kenton Varda's avatar Kenton Varda

Implement opt-in cancellation for local calls. (Previously, cancellation would…

Implement opt-in cancellation for local calls.  (Previously, cancellation would happen whether or not the callee had opted in.)
parent 55d661f7
...@@ -185,6 +185,90 @@ TEST(Capability, TailCall) { ...@@ -185,6 +185,90 @@ TEST(Capability, TailCall) {
EXPECT_EQ(1, callerCallCount); EXPECT_EQ(1, callerCallCount);
} }
TEST(Capability, AsyncCancelation) {
// Tests allowAsyncCancellation().
kj::SimpleEventLoop loop;
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
int callCount = 0;
test::TestMoreStuff::Client client(kj::heap<TestMoreStuffImpl>(callCount), loop);
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectAsyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectAsyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// We can detect that the method was canceled because it will drop the cap.
EXPECT_FALSE(destroyed);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
EXPECT_FALSE(returned);
}
TEST(Capability, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation().
kj::SimpleEventLoop loop;
int callCount = 0;
int innerCallCount = 0;
test::TestMoreStuff::Client client(kj::heap<TestMoreStuffImpl>(callCount), loop);
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectSyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(innerCallCount), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectSyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// expectSyncCancel() will make a call to the TestInterfaceImpl only once it noticed isCanceled()
// is true.
EXPECT_EQ(0, innerCallCount);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
EXPECT_EQ(1, innerCallCount);
EXPECT_FALSE(returned);
}
// ======================================================================================= // =======================================================================================
TEST(Capability, DynamicClient) { TEST(Capability, DynamicClient) {
......
...@@ -88,8 +88,10 @@ public: ...@@ -88,8 +88,10 @@ public:
class LocalCallContext final: public CallContextHook, public kj::Refcounted { class LocalCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<const ClientHook> clientRef) LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<const ClientHook> clientRef,
: request(kj::mv(request)), clientRef(kj::mv(clientRef)) {} kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)),
cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {}
ObjectPointer::Reader getParams() override { ObjectPointer::Reader getParams() override {
KJ_IF_MAYBE(r, request) { KJ_IF_MAYBE(r, request) {
...@@ -134,12 +136,11 @@ public: ...@@ -134,12 +136,11 @@ public:
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
void allowAsyncCancellation() override { void allowAsyncCancellation() override {
releaseParams(); KJ_REQUIRE(request == nullptr, "Must call releaseParams() before allowAsyncCancellation().");
cancelAllowedFulfiller->fulfill();
// TODO(soon): Implement.
} }
bool isCanceled() override { bool isCanceled() override {
return false; return cancelRequested;
} }
kj::Own<CallContextHook> addRef() override { kj::Own<CallContextHook> addRef() override {
return kj::addRef(*this); return kj::addRef(*this);
...@@ -150,6 +151,21 @@ public: ...@@ -150,6 +151,21 @@ public:
ObjectPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null ObjectPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null
kj::Own<const ClientHook> clientRef; kj::Own<const ClientHook> clientRef;
kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller; kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller;
kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller;
bool cancelRequested = false;
class Canceler {
public:
Canceler(kj::Own<LocalCallContext>&& context): context(kj::mv(context)) {}
Canceler(Canceler&&) = default;
~Canceler() {
if (context) context->cancelRequested = true;
}
private:
kj::Own<LocalCallContext> context;
};
}; };
class LocalRequest final: public RequestHook { class LocalRequest final: public RequestHook {
...@@ -168,18 +184,37 @@ public: ...@@ -168,18 +184,37 @@ public:
uint64_t interfaceId = this->interfaceId; uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId; uint16_t methodId = this->methodId;
auto context = kj::refcounted<LocalCallContext>(kj::mv(message), client->addRef()); auto cancelPaf = kj::newPromiseAndFulfiller<void>();
auto context = kj::refcounted<LocalCallContext>(
kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller));
auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context)); auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context));
auto promise = loop.there(kj::mv(promiseAndPipeline.promise), // We have to make sure the call is not canceled unless permitted. We need to fork the promise
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) { // so that if the client drops their copy, the promise isn't necessarily canceled.
KJ_IF_MAYBE(r, context->response) { auto forked = loop.fork(kj::mv(promiseAndPipeline.promise));
return kj::mv(*r);
} else { // We daemonize one branch, but only after joining it with the promise that fires if
KJ_FAIL_ASSERT("Method implementation failed to fill in results."); // cancellation is allowed.
} auto daemonPromise = forked.addBranch();
})); daemonPromise.attach(kj::addRef(*context));
daemonPromise = loop.exclusiveJoin(kj::mv(cancelPaf.promise), kj::mv(daemonPromise));
// Ignore exceptions.
daemonPromise = loop.there(kj::mv(daemonPromise), []() {}, [](kj::Exception&&) {});
loop.daemonize(kj::mv(daemonPromise));
// Now the other branch returns the response from the context.
auto contextPtr = context.get();
auto promise = loop.there(forked.addBranch(), [contextPtr]() {
contextPtr->getResults(1); // force response allocation
return kj::mv(KJ_ASSERT_NONNULL(contextPtr->response));
});
// We also want to notify the context that cancellation was requested in this branch is
// destroyed.
promise.attach(LocalCallContext::Canceler(kj::mv(context)));
// We return the other branch.
return RemotePromise<ObjectPointer>( return RemotePromise<ObjectPointer>(
kj::mv(promise), ObjectPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline))); kj::mv(promise), ObjectPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
} }
......
...@@ -250,15 +250,9 @@ public: ...@@ -250,15 +250,9 @@ public:
// executing on a local thread. The method must perform an asynchronous operation or call // executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control. // `EventLoop::current().runLater()` to yield control.
// //
// This method implies `releaseParams()` -- you cannot allow async cancellation while still // Currently, you must call `releaseParams()` before `allowAsyncCancellation()`, otherwise the
// holding the params. (This is because of a quirk of the current RPC implementation; in theory // latter will throw an exception. This is a limitation of the current RPC implementation, but
// it could be fixed.) // this requirement could be lifted in the future.
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault
// tolerant and exception-safe, and those are pretty similar to being cancel-tolerant, though
// with less direct control by the attacker...
bool isCanceled(); bool isCanceled();
// As an alternative to `allowAsyncCancellation()`, a server can call this to check for // As an alternative to `allowAsyncCancellation()`, a server can call this to check for
...@@ -358,6 +352,10 @@ public: ...@@ -358,6 +352,10 @@ public:
// pipelined calls are waiting for it; the call is only truly done when the CallContextHook is // pipelined calls are waiting for it; the call is only truly done when the CallContextHook is
// destroyed. // destroyed.
// //
// Since the caller of this method chooses the CallContext implementation, it is the caller's
// responsibility to ensure that the returned promise is not canceled unless allowed via
// the context's `allowAsyncCancellation()`.
//
// The call must not begin synchronously, as the caller may hold arbitrary mutexes. // The call must not begin synchronously, as the caller may hold arbitrary mutexes.
virtual kj::Maybe<const ClientHook&> getResolved() const = 0; virtual kj::Maybe<const ClientHook&> getResolved() const = 0;
......
...@@ -314,7 +314,7 @@ public: ...@@ -314,7 +314,7 @@ public:
private: private:
TestNetworkAdapter& network; TestNetworkAdapter& network;
RpcDumper::Sender sender; RpcDumper::Sender sender KJ_UNUSED_MEMBER;
kj::Maybe<ConnectionImpl&> partner; kj::Maybe<ConnectionImpl&> partner;
struct Queues { struct Queues {
...@@ -559,6 +559,91 @@ TEST_F(RpcTest, TailCall) { ...@@ -559,6 +559,91 @@ TEST_F(RpcTest, TailCall) {
EXPECT_EQ(1, restorer.callCount); EXPECT_EQ(1, restorer.callCount);
} }
TEST_F(RpcTest, AsyncCancelation) {
// Tests allowAsyncCancellation().
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectAsyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectAsyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// We can detect that the method was canceled because it will drop the cap.
EXPECT_FALSE(destroyed);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
EXPECT_FALSE(returned);
}
TEST_F(RpcTest, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation().
int innerCallCount = 0;
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectSyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(innerCallCount), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectSyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// expectSyncCancel() will make a call to the TestInterfaceImpl only once it noticed isCanceled()
// is true.
EXPECT_EQ(0, innerCallCount);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
EXPECT_EQ(1, innerCallCount);
EXPECT_FALSE(returned);
}
TEST_F(RpcTest, PromiseResolve) { TEST_F(RpcTest, PromiseResolve) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
...@@ -596,25 +681,6 @@ TEST_F(RpcTest, PromiseResolve) { ...@@ -596,25 +681,6 @@ TEST_F(RpcTest, PromiseResolve) {
EXPECT_EQ(2, chainedCallCount); EXPECT_EQ(2, chainedCallCount);
} }
class TestCapDestructor final: public test::TestInterface::Server {
public:
TestCapDestructor(kj::Own<kj::PromiseFulfiller<void>>&& fulfiller)
: fulfiller(kj::mv(fulfiller)), impl(dummy) {}
~TestCapDestructor() {
fulfiller->fulfill();
}
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) {
return impl.foo(params, result);
}
private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
int dummy = 0;
TestInterfaceImpl impl;
};
TEST_F(RpcTest, RetainAndRelease) { TEST_F(RpcTest, RetainAndRelease) {
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false; bool destroyed = false;
......
...@@ -2113,7 +2113,7 @@ private: ...@@ -2113,7 +2113,7 @@ private:
// a call started doesn't really need to hold the call open. To fix this we'd presumably // a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot // need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation. // at creation.
releaseParams(); KJ_REQUIRE(request == nullptr, "Must call releaseParams() before allowAsyncCancellation().");
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) == if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
CANCEL_REQUESTED) { CANCEL_REQUESTED) {
......
...@@ -1058,5 +1058,53 @@ kj::Promise<void> TestMoreStuffImpl::echo(EchoParams::Reader params, EchoResults ...@@ -1058,5 +1058,53 @@ kj::Promise<void> TestMoreStuffImpl::echo(EchoParams::Reader params, EchoResults
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> TestMoreStuffImpl::expectAsyncCancelAdvanced(
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) {
auto cap = context.getParams().getCap();
context.releaseParams();
context.allowAsyncCancellation();
return loop(0, cap, context);
}
kj::Promise<void> TestMoreStuffImpl::loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) {
if (depth > 100) {
ADD_FAILURE() << "Looped too long, giving up.";
return kj::READY_NOW;
} else {
return kj::EventLoop::current().evalLater([=]() mutable {
return loop(depth + 1, cap, context);
});
}
}
kj::Promise<void> TestMoreStuffImpl::expectSyncCancelAdvanced(
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) {
auto cap = context.getParams().getCap();
context.releaseParams();
return loop(0, cap, context);
}
kj::Promise<void> TestMoreStuffImpl::loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) {
if (depth > 100) {
ADD_FAILURE() << "Looped too long, giving up.";
return kj::READY_NOW;
} else if (context.isCanceled()) {
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
});
} else {
return kj::EventLoop::current().evalLater([=]() mutable {
return loop(depth + 1, cap, context);
});
}
}
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -239,10 +239,42 @@ public: ...@@ -239,10 +239,42 @@ public:
kj::Promise<void> echo(EchoParams::Reader params, EchoResults::Builder result) override; kj::Promise<void> echo(EchoParams::Reader params, EchoResults::Builder result) override;
kj::Promise<void> expectAsyncCancelAdvanced(
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) override;
kj::Promise<void> expectSyncCancelAdvanced(
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) override;
private: private:
int& callCount; int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill; kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
test::TestInterface::Client clientToHold = nullptr; test::TestInterface::Client clientToHold = nullptr;
kj::Promise<void> loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context);
kj::Promise<void> loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context);
};
class TestCapDestructor final: public test::TestInterface::Server {
// Implementation of TestInterface that notifies when it is destroyed.
public:
TestCapDestructor(kj::Own<kj::PromiseFulfiller<void>>&& fulfiller)
: fulfiller(kj::mv(fulfiller)), impl(dummy) {}
~TestCapDestructor() {
fulfiller->fulfill();
}
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) {
return impl.foo(params, result);
}
private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
int dummy = 0;
TestInterfaceImpl impl;
}; };
} // namespace _ (private) } // namespace _ (private)
......
...@@ -654,6 +654,13 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -654,6 +654,13 @@ interface TestMoreStuff extends(TestCallOrder) {
echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder); echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
# Just returns the input cap. # Just returns the input cap.
expectAsyncCancel @7 (cap :TestInterface) -> ();
# evalLater()-loops forever, holding `cap`. Must be canceled.
expectSyncCancel @8 (cap :TestInterface) -> ();
# evalLater()-loops until context.isCanceled() returns true, then makes a call to `cap` before
# returning.
} }
struct TestSturdyRefHostId { struct TestSturdyRefHostId {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment