Commit 72af96df authored by Kenton Varda's avatar Kenton Varda

Implement opt-in cancellation for local calls. (Previously, cancellation would…

Implement opt-in cancellation for local calls.  (Previously, cancellation would happen whether or not the callee had opted in.)
parent 55d661f7
......@@ -185,6 +185,90 @@ TEST(Capability, TailCall) {
EXPECT_EQ(1, callerCallCount);
}
TEST(Capability, AsyncCancelation) {
// Tests allowAsyncCancellation().
kj::SimpleEventLoop loop;
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
int callCount = 0;
test::TestMoreStuff::Client client(kj::heap<TestMoreStuffImpl>(callCount), loop);
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectAsyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectAsyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// We can detect that the method was canceled because it will drop the cap.
EXPECT_FALSE(destroyed);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
EXPECT_FALSE(returned);
}
TEST(Capability, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation().
kj::SimpleEventLoop loop;
int callCount = 0;
int innerCallCount = 0;
test::TestMoreStuff::Client client(kj::heap<TestMoreStuffImpl>(callCount), loop);
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectSyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(innerCallCount), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectSyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// expectSyncCancel() will make a call to the TestInterfaceImpl only once it noticed isCanceled()
// is true.
EXPECT_EQ(0, innerCallCount);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
EXPECT_EQ(1, innerCallCount);
EXPECT_FALSE(returned);
}
// =======================================================================================
TEST(Capability, DynamicClient) {
......
......@@ -88,8 +88,10 @@ public:
class LocalCallContext final: public CallContextHook, public kj::Refcounted {
public:
LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<const ClientHook> clientRef)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)) {}
LocalCallContext(kj::Own<LocalMessage>&& request, kj::Own<const ClientHook> clientRef,
kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)),
cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {}
ObjectPointer::Reader getParams() override {
KJ_IF_MAYBE(r, request) {
......@@ -134,12 +136,11 @@ public:
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
releaseParams();
// TODO(soon): Implement.
KJ_REQUIRE(request == nullptr, "Must call releaseParams() before allowAsyncCancellation().");
cancelAllowedFulfiller->fulfill();
}
bool isCanceled() override {
return false;
return cancelRequested;
}
kj::Own<CallContextHook> addRef() override {
return kj::addRef(*this);
......@@ -150,6 +151,21 @@ public:
ObjectPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null
kj::Own<const ClientHook> clientRef;
kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller;
kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller;
bool cancelRequested = false;
class Canceler {
public:
Canceler(kj::Own<LocalCallContext>&& context): context(kj::mv(context)) {}
Canceler(Canceler&&) = default;
~Canceler() {
if (context) context->cancelRequested = true;
}
private:
kj::Own<LocalCallContext> context;
};
};
class LocalRequest final: public RequestHook {
......@@ -168,18 +184,37 @@ public:
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::refcounted<LocalCallContext>(kj::mv(message), client->addRef());
auto cancelPaf = kj::newPromiseAndFulfiller<void>();
auto context = kj::refcounted<LocalCallContext>(
kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller));
auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context));
auto promise = loop.there(kj::mv(promiseAndPipeline.promise),
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
KJ_IF_MAYBE(r, context->response) {
return kj::mv(*r);
} else {
KJ_FAIL_ASSERT("Method implementation failed to fill in results.");
}
}));
// We have to make sure the call is not canceled unless permitted. We need to fork the promise
// so that if the client drops their copy, the promise isn't necessarily canceled.
auto forked = loop.fork(kj::mv(promiseAndPipeline.promise));
// We daemonize one branch, but only after joining it with the promise that fires if
// cancellation is allowed.
auto daemonPromise = forked.addBranch();
daemonPromise.attach(kj::addRef(*context));
daemonPromise = loop.exclusiveJoin(kj::mv(cancelPaf.promise), kj::mv(daemonPromise));
// Ignore exceptions.
daemonPromise = loop.there(kj::mv(daemonPromise), []() {}, [](kj::Exception&&) {});
loop.daemonize(kj::mv(daemonPromise));
// Now the other branch returns the response from the context.
auto contextPtr = context.get();
auto promise = loop.there(forked.addBranch(), [contextPtr]() {
contextPtr->getResults(1); // force response allocation
return kj::mv(KJ_ASSERT_NONNULL(contextPtr->response));
});
// We also want to notify the context that cancellation was requested in this branch is
// destroyed.
promise.attach(LocalCallContext::Canceler(kj::mv(context)));
// We return the other branch.
return RemotePromise<ObjectPointer>(
kj::mv(promise), ObjectPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
}
......
......@@ -250,15 +250,9 @@ public:
// executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control.
//
// This method implies `releaseParams()` -- you cannot allow async cancellation while still
// holding the params. (This is because of a quirk of the current RPC implementation; in theory
// it could be fixed.)
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault
// tolerant and exception-safe, and those are pretty similar to being cancel-tolerant, though
// with less direct control by the attacker...
// Currently, you must call `releaseParams()` before `allowAsyncCancellation()`, otherwise the
// latter will throw an exception. This is a limitation of the current RPC implementation, but
// this requirement could be lifted in the future.
bool isCanceled();
// As an alternative to `allowAsyncCancellation()`, a server can call this to check for
......@@ -358,6 +352,10 @@ public:
// pipelined calls are waiting for it; the call is only truly done when the CallContextHook is
// destroyed.
//
// Since the caller of this method chooses the CallContext implementation, it is the caller's
// responsibility to ensure that the returned promise is not canceled unless allowed via
// the context's `allowAsyncCancellation()`.
//
// The call must not begin synchronously, as the caller may hold arbitrary mutexes.
virtual kj::Maybe<const ClientHook&> getResolved() const = 0;
......
......@@ -314,7 +314,7 @@ public:
private:
TestNetworkAdapter& network;
RpcDumper::Sender sender;
RpcDumper::Sender sender KJ_UNUSED_MEMBER;
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
......@@ -559,6 +559,91 @@ TEST_F(RpcTest, TailCall) {
EXPECT_EQ(1, restorer.callCount);
}
TEST_F(RpcTest, AsyncCancelation) {
// Tests allowAsyncCancellation().
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop);
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectAsyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestCapDestructor>(kj::mv(paf.fulfiller)), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectAsyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// We can detect that the method was canceled because it will drop the cap.
EXPECT_FALSE(destroyed);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(kj::mv(destructionPromise));
EXPECT_TRUE(destroyed);
EXPECT_FALSE(returned);
}
TEST_F(RpcTest, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation().
int innerCallCount = 0;
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr;
bool returned = false;
{
auto request = client.expectSyncCancelRequest();
request.setCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(innerCallCount), loop));
promise = loop.there(request.send(),
[&](Response<test::TestMoreStuff::ExpectSyncCancelResults>&& response) {
returned = true;
});
promise.eagerlyEvaluate(loop);
}
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
// expectSyncCancel() will make a call to the TestInterfaceImpl only once it noticed isCanceled()
// is true.
EXPECT_EQ(0, innerCallCount);
EXPECT_FALSE(returned);
promise = nullptr; // request cancellation
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
loop.wait(loop.evalLater([]() {}));
EXPECT_EQ(1, innerCallCount);
EXPECT_FALSE(returned);
}
TEST_F(RpcTest, PromiseResolve) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
......@@ -596,25 +681,6 @@ TEST_F(RpcTest, PromiseResolve) {
EXPECT_EQ(2, chainedCallCount);
}
class TestCapDestructor final: public test::TestInterface::Server {
public:
TestCapDestructor(kj::Own<kj::PromiseFulfiller<void>>&& fulfiller)
: fulfiller(kj::mv(fulfiller)), impl(dummy) {}
~TestCapDestructor() {
fulfiller->fulfill();
}
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) {
return impl.foo(params, result);
}
private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
int dummy = 0;
TestInterfaceImpl impl;
};
TEST_F(RpcTest, RetainAndRelease) {
auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false;
......
......@@ -2113,7 +2113,7 @@ private:
// a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation.
releaseParams();
KJ_REQUIRE(request == nullptr, "Must call releaseParams() before allowAsyncCancellation().");
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
CANCEL_REQUESTED) {
......
......@@ -1058,5 +1058,53 @@ kj::Promise<void> TestMoreStuffImpl::echo(EchoParams::Reader params, EchoResults
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::expectAsyncCancelAdvanced(
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) {
auto cap = context.getParams().getCap();
context.releaseParams();
context.allowAsyncCancellation();
return loop(0, cap, context);
}
kj::Promise<void> TestMoreStuffImpl::loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) {
if (depth > 100) {
ADD_FAILURE() << "Looped too long, giving up.";
return kj::READY_NOW;
} else {
return kj::EventLoop::current().evalLater([=]() mutable {
return loop(depth + 1, cap, context);
});
}
}
kj::Promise<void> TestMoreStuffImpl::expectSyncCancelAdvanced(
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) {
auto cap = context.getParams().getCap();
context.releaseParams();
return loop(0, cap, context);
}
kj::Promise<void> TestMoreStuffImpl::loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) {
if (depth > 100) {
ADD_FAILURE() << "Looped too long, giving up.";
return kj::READY_NOW;
} else if (context.isCanceled()) {
auto request = cap.fooRequest();
request.setI(123);
request.setJ(true);
return request.send().then(
[](Response<test::TestInterface::FooResults>&& response) mutable {
EXPECT_EQ("foo", response.getX());
});
} else {
return kj::EventLoop::current().evalLater([=]() mutable {
return loop(depth + 1, cap, context);
});
}
}
} // namespace _ (private)
} // namespace capnp
......@@ -239,10 +239,42 @@ public:
kj::Promise<void> echo(EchoParams::Reader params, EchoResults::Builder result) override;
kj::Promise<void> expectAsyncCancelAdvanced(
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context) override;
kj::Promise<void> expectSyncCancelAdvanced(
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context) override;
private:
int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
test::TestInterface::Client clientToHold = nullptr;
kj::Promise<void> loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectAsyncCancelParams, ExpectAsyncCancelResults> context);
kj::Promise<void> loop(uint depth, test::TestInterface::Client cap,
CallContext<ExpectSyncCancelParams, ExpectSyncCancelResults> context);
};
class TestCapDestructor final: public test::TestInterface::Server {
// Implementation of TestInterface that notifies when it is destroyed.
public:
TestCapDestructor(kj::Own<kj::PromiseFulfiller<void>>&& fulfiller)
: fulfiller(kj::mv(fulfiller)), impl(dummy) {}
~TestCapDestructor() {
fulfiller->fulfill();
}
kj::Promise<void> foo(FooParams::Reader params, FooResults::Builder result) {
return impl.foo(params, result);
}
private:
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
int dummy = 0;
TestInterfaceImpl impl;
};
} // namespace _ (private)
......
......@@ -654,6 +654,13 @@ interface TestMoreStuff extends(TestCallOrder) {
echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
# Just returns the input cap.
expectAsyncCancel @7 (cap :TestInterface) -> ();
# evalLater()-loops forever, holding `cap`. Must be canceled.
expectSyncCancel @8 (cap :TestInterface) -> ();
# evalLater()-loops until context.isCanceled() returns true, then makes a call to `cap` before
# returning.
}
struct TestSturdyRefHostId {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment