Unverified Commit fd949ff9 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #608 from capnproto/revocable-membrane

Extend membrane framework to make revocation easy.
parents ff4cd3d7 4b28ee7f
...@@ -88,10 +88,19 @@ protected: ...@@ -88,10 +88,19 @@ protected:
context.getResults().setThing(context.getParams().getThing()); context.getResults().setThing(context.getParams().getThing());
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> waitForever(WaitForeverContext context) override {
context.allowCancellation();
return kj::NEVER_DONE;
}
}; };
class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted { class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted {
public: public:
MembranePolicyImpl() = default;
MembranePolicyImpl(kj::Maybe<kj::Promise<void>> revokePromise)
: revokePromise(revokePromise.map([](kj::Promise<void>& p) { return p.fork(); })) {}
kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId, kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId,
Capability::Client target) override { Capability::Client target) override {
if (interfaceId == capnp::typeId<Thing>() && methodId == 1) { if (interfaceId == capnp::typeId<Thing>() && methodId == 1) {
...@@ -113,6 +122,15 @@ public: ...@@ -113,6 +122,15 @@ public:
kj::Own<MembranePolicy> addRef() override { kj::Own<MembranePolicy> addRef() override {
return kj::addRef(*this); return kj::addRef(*this);
} }
kj::Maybe<kj::Promise<void>> onRevoked() override {
return revokePromise.map([](kj::ForkedPromise<void>& fork) {
return fork.addBranch();
});
}
private:
kj::Maybe<kj::ForkedPromise<void>> revokePromise;
}; };
void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membraned, void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membraned,
...@@ -265,12 +283,13 @@ struct TestRpcEnv { ...@@ -265,12 +283,13 @@ struct TestRpcEnv {
TwoPartyClient server; TwoPartyClient server;
test::TestMembrane::Client membraned; test::TestMembrane::Client membraned;
TestRpcEnv() TestRpcEnv(kj::Maybe<kj::Promise<void>> revokePromise = nullptr)
: io(kj::setupAsyncIo()), : io(kj::setupAsyncIo()),
pipe(io.provider->newTwoWayPipe()), pipe(io.provider->newTwoWayPipe()),
client(*pipe.ends[0]), client(*pipe.ends[0]),
server(*pipe.ends[1], server(*pipe.ends[1],
membrane(kj::heap<TestMembraneImpl>(), kj::refcounted<MembranePolicyImpl>()), membrane(kj::heap<TestMembraneImpl>(),
kj::refcounted<MembranePolicyImpl>(kj::mv(revokePromise))),
rpc::twoparty::Side::SERVER), rpc::twoparty::Side::SERVER),
membraned(client.bootstrap().castAs<test::TestMembrane>()) {} membraned(client.bootstrap().castAs<test::TestMembrane>()) {}
...@@ -330,6 +349,29 @@ KJ_TEST("call remote promise pointing into membrane that eventually resolves to ...@@ -330,6 +349,29 @@ KJ_TEST("call remote promise pointing into membrane that eventually resolves to
}, "outside", "outside", "outside", "outbound"); }, "outside", "outside", "outside", "outbound");
} }
KJ_TEST("revoke membrane") {
auto paf = kj::newPromiseAndFulfiller<void>();
TestRpcEnv env(kj::mv(paf.promise));
auto thing = env.membraned.makeThingRequest().send().wait(env.io.waitScope).getThing();
auto callPromise = env.membraned.waitForeverRequest().send();
KJ_EXPECT(!callPromise.poll(env.io.waitScope));
paf.fulfiller->reject(KJ_EXCEPTION(DISCONNECTED, "foobar"));
KJ_ASSERT(callPromise.poll(env.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("foobar", callPromise.wait(env.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("foobar",
env.membraned.makeThingRequest().send().wait(env.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("foobar",
thing.passThroughRequest().send().wait(env.io.waitScope));
}
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -198,6 +198,8 @@ public: ...@@ -198,6 +198,8 @@ public:
auto newPipeline = AnyPointer::Pipeline(kj::refcounted<MembranePipelineHook>( auto newPipeline = AnyPointer::Pipeline(kj::refcounted<MembranePipelineHook>(
PipelineHook::from(kj::mv(promise)), policy->addRef(), reverse)); PipelineHook::from(kj::mv(promise)), policy->addRef(), reverse));
auto onRevoked = policy->onRevoked();
bool reverse = this->reverse; // for capture bool reverse = this->reverse; // for capture
auto newPromise = promise.then(kj::mvCapture(policy, auto newPromise = promise.then(kj::mvCapture(policy,
[reverse](kj::Own<MembranePolicy>&& policy, Response<AnyPointer>&& response) { [reverse](kj::Own<MembranePolicy>&& policy, Response<AnyPointer>&& response) {
...@@ -208,6 +210,12 @@ public: ...@@ -208,6 +210,12 @@ public:
return Response<AnyPointer>(reader, kj::mv(newRespHook)); return Response<AnyPointer>(reader, kj::mv(newRespHook));
})); }));
KJ_IF_MAYBE(r, kj::mv(onRevoked)) {
newPromise = newPromise.exclusiveJoin(r->then([]() -> Response<AnyPointer> {
KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject");
}));
}
return RemotePromise<AnyPointer>(kj::mv(newPromise), kj::mv(newPipeline)); return RemotePromise<AnyPointer>(kj::mv(newPromise), kj::mv(newPipeline));
} }
...@@ -301,8 +309,14 @@ private: ...@@ -301,8 +309,14 @@ private:
class MembraneHook final: public ClientHook, public kj::Refcounted { class MembraneHook final: public ClientHook, public kj::Refcounted {
public: public:
MembraneHook(kj::Own<ClientHook>&& inner, kj::Own<MembranePolicy>&& policy, bool reverse) MembraneHook(kj::Own<ClientHook>&& inner, kj::Own<MembranePolicy>&& policyParam, bool reverse)
: inner(kj::mv(inner)), policy(kj::mv(policy)), reverse(reverse) {} : inner(kj::mv(inner)), policy(kj::mv(policyParam)), reverse(reverse) {
KJ_IF_MAYBE(r, policy->onRevoked()) {
revocationTask = r->eagerlyEvaluate([this](kj::Exception&& exception) {
this->inner = newBrokenCap(kj::mv(exception));
});
}
}
static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) { static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) {
if (cap.getBrand() == MEMBRANE_BRAND) { if (cap.getBrand() == MEMBRANE_BRAND) {
...@@ -381,6 +395,10 @@ public: ...@@ -381,6 +395,10 @@ public:
auto result = inner->call(interfaceId, methodId, auto result = inner->call(interfaceId, methodId,
kj::refcounted<MembraneCallContextHook>(kj::mv(context), policy->addRef(), !reverse)); kj::refcounted<MembraneCallContextHook>(kj::mv(context), policy->addRef(), !reverse));
KJ_IF_MAYBE(r, policy->onRevoked()) {
result.promise = result.promise.exclusiveJoin(kj::mv(*r));
}
return { return {
kj::mv(result.promise), kj::mv(result.promise),
kj::refcounted<MembranePipelineHook>(kj::mv(result.pipeline), policy->addRef(), reverse) kj::refcounted<MembranePipelineHook>(kj::mv(result.pipeline), policy->addRef(), reverse)
...@@ -409,6 +427,12 @@ public: ...@@ -409,6 +427,12 @@ public:
} }
KJ_IF_MAYBE(promise, inner->whenMoreResolved()) { KJ_IF_MAYBE(promise, inner->whenMoreResolved()) {
KJ_IF_MAYBE(r, policy->onRevoked()) {
*promise = promise->exclusiveJoin(r->then([]() -> kj::Own<ClientHook> {
KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject");
}));
}
return promise->then([this](kj::Own<ClientHook>&& newInner) { return promise->then([this](kj::Own<ClientHook>&& newInner) {
kj::Own<ClientHook> newResolved = wrap(*newInner, *policy, reverse); kj::Own<ClientHook> newResolved = wrap(*newInner, *policy, reverse);
if (resolved == nullptr) { if (resolved == nullptr) {
...@@ -434,6 +458,7 @@ private: ...@@ -434,6 +458,7 @@ private:
kj::Own<MembranePolicy> policy; kj::Own<MembranePolicy> policy;
bool reverse; bool reverse;
kj::Maybe<kj::Own<ClientHook>> resolved; kj::Maybe<kj::Own<ClientHook>> resolved;
kj::Promise<void> revocationTask = nullptr;
}; };
kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy, bool reverse) { kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy, bool reverse) {
......
...@@ -104,6 +104,16 @@ public: ...@@ -104,6 +104,16 @@ public:
// object actually to be the *same* membrane. This is relevant when an object passes into the // object actually to be the *same* membrane. This is relevant when an object passes into the
// membrane and then back out (or out and then back in): instead of double-wrapping the object, // membrane and then back out (or out and then back in): instead of double-wrapping the object,
// the wrapping will be removed. // the wrapping will be removed.
virtual kj::Maybe<kj::Promise<void>> onRevoked() { return nullptr; }
// If this returns non-null, then it is a promise that will reject (throw an exception) when the
// membrane should be revoked. On revocation, all capabilities pointing across the membrane will
// be dropped and all outstanding calls canceled. The exception thrown by the promise will be
// propagated to all these calls. It is an error for the promise to resolve without throwing.
//
// After the revocation promise has rejected, inboundCall() and outboundCall() will still be
// invoked for new calls, but the `target` passed to them will be a capability that always
// rethrows the revocation exception.
}; };
Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy); Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy);
......
...@@ -866,6 +866,8 @@ interface TestMembrane { ...@@ -866,6 +866,8 @@ interface TestMembrane {
callIntercept @2 (thing :Thing, tailCall :Bool) -> Result; callIntercept @2 (thing :Thing, tailCall :Bool) -> Result;
loopback @3 (thing :Thing) -> (thing :Thing); loopback @3 (thing :Thing) -> (thing :Thing);
waitForever @4 ();
interface Thing { interface Thing {
passThrough @0 () -> Result; passThrough @0 () -> Result;
intercept @1 () -> Result; intercept @1 () -> Result;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment