// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#include "membrane.h"
#include <kj/debug.h>

namespace capnp {

namespace {

// The address of DUMMY is used as a unique brand. The membrane request and client hooks return it
// from getBrand(), which lets the wrap() functions recognize objects created by this file.
static const char DUMMY = 0;
static constexpr const void* MEMBRANE_BRAND = &DUMMY;

// Wraps `inner` in the membrane described by `policy` (`reverse` selects the direction).
// Defined at the end of this anonymous namespace.
kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy, bool reverse);

// Cap table used to read a message that lives inside the membrane. Every capability extracted
// from the message is wrapped with membrane().
//
// Each instance may be imbued only once (enforced by KJ_REQUIRE). imbue() must be called before
// extractCap(), because `inner` stays null until then.
class MembraneCapTableReader final: public _::CapTableReader {
public:
  MembraneCapTableReader(MembranePolicy& policy, bool reverse)
      : policy(policy), reverse(reverse) {}

  AnyPointer::Reader imbue(AnyPointer::Reader reader) {
    return AnyPointer::Reader(imbue(
        _::PointerHelpers<AnyPointer>::getInternalReader(kj::mv(reader))));
  }

  // The overloads below remember the reader's original cap table in `inner`. They return a copy
  // of the reader that resolves capabilities through this table instead.
  _::PointerReader imbue(_::PointerReader reader) {
    KJ_REQUIRE(inner == nullptr, "can only call this once");
    inner = reader.getCapTable();
    return reader.imbue(this);
  }

  _::StructReader imbue(_::StructReader reader) {
    KJ_REQUIRE(inner == nullptr, "can only call this once");
    inner = reader.getCapTable();
    return reader.imbue(this);
  }

  _::ListReader imbue(_::ListReader reader) {
    KJ_REQUIRE(inner == nullptr, "can only call this once");
    inner = reader.getCapTable();
    return reader.imbue(this);
  }

  kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) override {
    // The underlying message is inside the membrane, and we're pulling a cap out of it. Therefore,
    // we want to wrap the extracted capability in the membrane.
    return inner->extractCap(index).map([this](kj::Own<ClientHook>&& cap) {
      return membrane(kj::mv(cap), policy, reverse);
    });
  }

private:
  _::CapTableReader* inner = nullptr;
  MembranePolicy& policy;
  bool reverse;
};

// Cap table used to build a message that lives inside the membrane.
// - Capabilities extracted from the message are wrapped with membrane().
// - Capabilities injected into it from outside are wrapped in the opposite direction (!reverse).
// - unimbue() puts the original cap table back.
class MembraneCapTableBuilder final: public _::CapTableBuilder {
public:
  MembraneCapTableBuilder(MembranePolicy& policy, bool reverse)
      : policy(policy), reverse(reverse) {}

  AnyPointer::Builder imbue(AnyPointer::Builder builder) {
    KJ_REQUIRE(inner == nullptr, "can only call this once");
    auto pointerBuilder = _::PointerHelpers<AnyPointer>::getInternalBuilder(kj::mv(builder));
    inner = pointerBuilder.getCapTable();
    return AnyPointer::Builder(pointerBuilder.imbue(this));
  }

  // Undoes imbue(). `builder` must currently be using this cap table.
  AnyPointer::Builder unimbue(AnyPointer::Builder builder) {
    auto pointerBuilder = _::PointerHelpers<AnyPointer>::getInternalBuilder(kj::mv(builder));
    KJ_REQUIRE(pointerBuilder.getCapTable() == this);
    return AnyPointer::Builder(pointerBuilder.imbue(inner));
  }

  kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) override {
    // The underlying message is inside the membrane, and we're pulling a cap out of it. Therefore,
    // we want to wrap the extracted capability in the membrane.
    return inner->extractCap(index).map([this](kj::Own<ClientHook>&& cap) {
      return membrane(kj::mv(cap), policy, reverse);
    });
  }

  uint injectCap(kj::Own<ClientHook>&& cap) override {
    // The underlying message is inside the membrane, and we're inserting a cap from outside into
    // it. Therefore we want to add a reverse membrane.
    return inner->injectCap(membrane(kj::mv(cap), policy, !reverse));
  }

  void dropCap(uint index) override {
    inner->dropCap(index);
  }

private:
  _::CapTableBuilder* inner = nullptr;
  MembranePolicy& policy;
  bool reverse;
};

// Pipeline whose pipelined capabilities are wrapped in the membrane.
class MembranePipelineHook final: public PipelineHook, public kj::Refcounted {
public:
  MembranePipelineHook(
      kj::Own<PipelineHook>&& inner, kj::Own<MembranePolicy>&& policy, bool reverse)
      : inner(kj::mv(inner)), policy(kj::mv(policy)), reverse(reverse) {}

  kj::Own<PipelineHook> addRef() override {
    return kj::addRef(*this);
  }

  kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override {
    return membrane(inner->getPipelinedCap(ops), *policy, reverse);
  }

  kj::Own<ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) override {
    return membrane(inner->getPipelinedCap(kj::mv(ops)), *policy, reverse);
  }

private:
  kj::Own<PipelineHook> inner;
  kj::Own<MembranePolicy> policy;
  bool reverse;
};

// Owns the inner response and the policy. Also owns the cap table that wraps capabilities read
// from the response's content via imbue().
class MembraneResponseHook final: public ResponseHook {
public:
  MembraneResponseHook(
      kj::Own<ResponseHook>&& inner, kj::Own<MembranePolicy>&& policy, bool reverse)
      : inner(kj::mv(inner)), policy(kj::mv(policy)), capTable(*this->policy, reverse) {}

  AnyPointer::Reader imbue(AnyPointer::Reader reader) { return capTable.imbue(reader); }

private:
  kj::Own<ResponseHook> inner;
  kj::Own<MembranePolicy> policy;
  MembraneCapTableReader capTable;  // declared after `policy`, which it references
};

// Request whose params are built through a MembraneCapTableBuilder. send() wraps both the
// pipeline and the eventual response in the membrane.
class MembraneRequestHook final: public RequestHook {
public:
  MembraneRequestHook(kj::Own<RequestHook>&& inner, kj::Own<MembranePolicy>&& policy,
                      bool reverse)
      : inner(kj::mv(inner)), policy(kj::mv(policy)),
        reverse(reverse), capTable(*this->policy, reverse) {}

  static Request<AnyPointer, AnyPointer> wrap(
      Request<AnyPointer, AnyPointer>&& inner, MembranePolicy& policy, bool reverse) {
    AnyPointer::Builder builder = inner;
    auto innerHook = RequestHook::from(kj::mv(inner));
    if (innerHook->getBrand() == MEMBRANE_BRAND) {
      auto& otherMembrane = kj::downcast<MembraneRequestHook>(*innerHook);
      if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
        // Request that passed across the membrane one way is now passing back the other way.
        // Unwrap it rather than double-wrap it.
        builder = otherMembrane.capTable.unimbue(builder);
        return { builder, kj::mv(otherMembrane.inner) };
      }
    }

    auto newHook = kj::heap<MembraneRequestHook>(kj::mv(innerHook), policy.addRef(), reverse);
    builder = newHook->capTable.imbue(builder);
    return { builder, kj::mv(newHook) };
  }

  static kj::Own<RequestHook> wrap(
      kj::Own<RequestHook>&& inner, MembranePolicy& policy, bool reverse) {
    if (inner->getBrand() == MEMBRANE_BRAND) {
      auto& otherMembrane = kj::downcast<MembraneRequestHook>(*inner);
      if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
        // Request that passed across the membrane one way is now passing back the other way.
        // Unwrap it rather than double-wrap it.
        return kj::mv(otherMembrane.inner);
      }
    }

    return kj::heap<MembraneRequestHook>(kj::mv(inner), policy.addRef(), reverse);
  }

  RemotePromise<AnyPointer> send() override {
    auto promise = inner->send();

    auto newPipeline = AnyPointer::Pipeline(kj::refcounted<MembranePipelineHook>(
        PipelineHook::from(kj::mv(promise)), policy->addRef(), reverse));

    bool reverse = this->reverse;  // for capture
    auto newPromise = promise.then(kj::mvCapture(policy,
        [reverse](kj::Own<MembranePolicy>&& policy, Response<AnyPointer>&& response) {
      AnyPointer::Reader reader = response;
      auto newRespHook = kj::heap<MembraneResponseHook>(
          ResponseHook::from(kj::mv(response)), policy->addRef(), reverse);
      reader = newRespHook->imbue(reader);
      return Response<AnyPointer>(reader, kj::mv(newRespHook));
    }));

    return RemotePromise<AnyPointer>(kj::mv(newPromise), kj::mv(newPipeline));
  }

  const void* getBrand() override {
    return MEMBRANE_BRAND;
  }

private:
  kj::Own<RequestHook> inner;
  kj::Own<MembranePolicy> policy;
  bool reverse;
  MembraneCapTableBuilder capTable;
};

// Call context handed across the membrane.
// - Params are read through a MembraneCapTableReader and cached after the first getParams().
// - Results are built through a MembraneCapTableBuilder.
// - Tail-call requests are wrapped in the opposite direction (!reverse).
class MembraneCallContextHook final: public CallContextHook, public kj::Refcounted {
public:
  MembraneCallContextHook(kj::Own<CallContextHook>&& inner,
                          kj::Own<MembranePolicy>&& policy, bool reverse)
      : inner(kj::mv(inner)), policy(kj::mv(policy)), reverse(reverse),
        paramsCapTable(*this->policy, reverse),
        resultsCapTable(*this->policy, reverse) {}

  AnyPointer::Reader getParams() override {
    KJ_REQUIRE(!releasedParams);
    KJ_IF_MAYBE(p, params) {
      return *p;
    } else {
      auto result = paramsCapTable.imbue(inner->getParams());
      params = result;
      return result;
    }
  }

  void releaseParams() override {
    KJ_REQUIRE(!releasedParams);
    releasedParams = true;
    inner->releaseParams();
  }

  AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
    KJ_IF_MAYBE(r, results) {
      return *r;
    } else {
      auto result = resultsCapTable.imbue(inner->getResults(sizeHint));
      results = result;
      return result;
    }
  }

  kj::Promise<void> tailCall(kj::Own<RequestHook>&& request) override {
    return inner->tailCall(MembraneRequestHook::wrap(kj::mv(request), *policy, !reverse));
  }

  void allowCancellation() override {
    inner->allowCancellation();
  }

  kj::Promise<AnyPointer::Pipeline> onTailCall() override {
    // The continuation takes its own reference to the policy and a copy of `reverse` instead of
    // capturing `this`. That way it stays valid even if this context hook is released before the
    // tail call completes.
    bool reverse = this->reverse;  // for capture
    return inner->onTailCall().then(kj::mvCapture(policy->addRef(),
        [reverse](kj::Own<MembranePolicy>&& policy, AnyPointer::Pipeline&& innerPipeline) {
      return AnyPointer::Pipeline(kj::refcounted<MembranePipelineHook>(
          PipelineHook::from(kj::mv(innerPipeline)), kj::mv(policy), reverse));
    }));
  }

  ClientHook::VoidPromiseAndPipeline directTailCall(kj::Own<RequestHook>&& request) override {
    auto pair = inner->directTailCall(
        MembraneRequestHook::wrap(kj::mv(request), *policy, !reverse));

    return {
      kj::mv(pair.promise),
      kj::refcounted<MembranePipelineHook>(kj::mv(pair.pipeline), policy->addRef(), reverse)
    };
  }

  kj::Own<CallContextHook> addRef() override {
    return kj::addRef(*this);
  }

private:
  kj::Own<CallContextHook> inner;
  kj::Own<MembranePolicy> policy;
  bool reverse;

  MembraneCapTableReader paramsCapTable;
  kj::Maybe<AnyPointer::Reader> params;
  bool releasedParams = false;

  MembraneCapTableBuilder resultsCapTable;
  kj::Maybe<AnyPointer::Builder> results;
};

// ClientHook that wraps `inner` in the membrane.
//
// Each call first asks the policy what to do: outboundCall() when `reverse` is set, otherwise
// inboundCall(). If the policy returns a capability, the call is redirected to it. Otherwise the
// call is forwarded to `inner`, with params, results and pipelines wrapped in the membrane.
class MembraneHook final: public ClientHook, public kj::Refcounted {
public:
  MembraneHook(kj::Own<ClientHook>&& inner, kj::Own<MembranePolicy>&& policy, bool reverse)
      : inner(kj::mv(inner)), policy(kj::mv(policy)), reverse(reverse) {}

  static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) {
    if (cap.getBrand() == MEMBRANE_BRAND) {
      auto& otherMembrane = kj::downcast<MembraneHook>(cap);
      if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
        // Capability that passed across the membrane one way is now passing back the other way.
        // Unwrap it rather than double-wrap it.
        return otherMembrane.inner->addRef();
      }
    }

    return kj::refcounted<MembraneHook>(cap.addRef(), policy.addRef(), reverse);
  }

  static kj::Own<ClientHook> wrap(kj::Own<ClientHook> cap, MembranePolicy& policy, bool reverse) {
    if (cap->getBrand() == MEMBRANE_BRAND) {
      auto& otherMembrane = kj::downcast<MembraneHook>(*cap);
      if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
        // Capability that passed across the membrane one way is now passing back the other way.
        // Unwrap it rather than double-wrap it.
        return otherMembrane.inner->addRef();
      }
    }

    return kj::refcounted<MembraneHook>(kj::mv(cap), policy.addRef(), reverse);
  }

  Request<AnyPointer, AnyPointer> newCall(
      uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
    KJ_IF_MAYBE(r, resolved) {
      return r->get()->newCall(interfaceId, methodId, sizeHint);
    }

    auto redirect = reverse
        ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
        : policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
    KJ_IF_MAYBE(r, redirect) {
      // The policy says that *if* this capability points into the membrane, then we want to
      // redirect the call. However, if this capability is a promise, then it could resolve to
      // something outside the membrane later. We have to wait before we actually redirect,
      // otherwise behavior will differ depending on whether the promise is resolved.
      KJ_IF_MAYBE(p, whenMoreResolved()) {
        return newLocalPromiseClient(kj::mv(*p))->newCall(interfaceId, methodId, sizeHint);
      }

      return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint);
    } else {
      // For pass-through calls, we don't worry about promises, because if the capability resolves
      // to something outside the membrane, then the call will pass back out of the membrane too.
      return MembraneRequestHook::wrap(
          inner->newCall(interfaceId, methodId, sizeHint), *policy, reverse);
    }
  }

  VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
                              kj::Own<CallContextHook>&& context) override {
    KJ_IF_MAYBE(r, resolved) {
      return r->get()->call(interfaceId, methodId, kj::mv(context));
    }

    auto redirect = reverse
        ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
        : policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
    KJ_IF_MAYBE(r, redirect) {
      // The policy says that *if* this capability points into the membrane, then we want to
      // redirect the call. However, if this capability is a promise, then it could resolve to
      // something outside the membrane later. We have to wait before we actually redirect,
      // otherwise behavior will differ depending on whether the promise is resolved.
      KJ_IF_MAYBE(p, whenMoreResolved()) {
        return newLocalPromiseClient(kj::mv(*p))->call(interfaceId, methodId, kj::mv(context));
      }

      return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context));
    } else {
      // !reverse because calls to the CallContext go in the opposite direction.
      auto result = inner->call(interfaceId, methodId,
          kj::refcounted<MembraneCallContextHook>(kj::mv(context), policy->addRef(), !reverse));

      return {
        kj::mv(result.promise),
        kj::refcounted<MembranePipelineHook>(kj::mv(result.pipeline), policy->addRef(), reverse)
      };
    }
  }

  kj::Maybe<ClientHook&> getResolved() override {
    KJ_IF_MAYBE(r, resolved) {
      return **r;
    }

    KJ_IF_MAYBE(newInner, inner->getResolved()) {
      kj::Own<ClientHook> newResolved = wrap(*newInner, *policy, reverse);
      ClientHook& result = *newResolved;
      resolved = kj::mv(newResolved);
      return result;
    } else {
      return nullptr;
    }
  }

  kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
    KJ_IF_MAYBE(r, resolved) {
      return kj::Promise<kj::Own<ClientHook>>(r->get()->addRef());
    }

    KJ_IF_MAYBE(promise, inner->whenMoreResolved()) {
      // The continuation holds its own reference to this hook rather than a raw `this`.
      // newCall() and call() hand the returned promise to newLocalPromiseClient() without keeping
      // this hook alive, so the hook may be released before the inner promise resolves.
      return promise->then(kj::mvCapture(kj::addRef(*this),
          [](kj::Own<MembraneHook>&& self, kj::Own<ClientHook>&& newInner) {
        kj::Own<ClientHook> newResolved = wrap(*newInner, *self->policy, self->reverse);
        if (self->resolved == nullptr) {
          // Remember the first resolution so later calls go straight to it via `resolved`.
          self->resolved = newResolved->addRef();
        }
        return newResolved;
      }));
    } else {
      return nullptr;
    }
  }

  kj::Own<ClientHook> addRef() override {
    return kj::addRef(*this);
  }

  const void* getBrand() override {
    return MEMBRANE_BRAND;
  }

private:
  kj::Own<ClientHook> inner;
  kj::Own<MembranePolicy> policy;
  bool reverse;
  kj::Maybe<kj::Own<ClientHook>> resolved;
};

kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy, bool reverse) {
  return MembraneHook::wrap(kj::mv(inner), policy, reverse);
}

}  // namespace

Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy) {
  return Capability::Client(membrane(
      ClientHook::from(kj::mv(inner)), *policy, false));
}

Capability::Client reverseMembrane(Capability::Client inner, kj::Own<MembranePolicy> policy) {
  return Capability::Client(membrane(
      ClientHook::from(kj::mv(inner)), *policy, true));
}

namespace _ {  // private

// The three overloads below copy `from` into orphan storage owned by `to`. Every capability
// encountered during the copy is extracted through a MembraneCapTableReader, so it arrives
// wrapped in the membrane.
_::OrphanBuilder copyOutOfMembrane(PointerReader from, Orphanage to,
                                   kj::Own<MembranePolicy> policy, bool reverse) {
  MembraneCapTableReader capTable(*policy, reverse);
  return _::OrphanBuilder::copy(
      OrphanageInternal::getArena(to),
      OrphanageInternal::getCapTable(to),
      capTable.imbue(from));
}

_::OrphanBuilder copyOutOfMembrane(StructReader from, Orphanage to,
                                   kj::Own<MembranePolicy> policy, bool reverse) {
  MembraneCapTableReader capTable(*policy, reverse);
  return _::OrphanBuilder::copy(
      OrphanageInternal::getArena(to),
      OrphanageInternal::getCapTable(to),
      capTable.imbue(from));
}

_::OrphanBuilder copyOutOfMembrane(ListReader from, Orphanage to,
                                   kj::Own<MembranePolicy> policy, bool reverse) {
  MembraneCapTableReader capTable(*policy, reverse);
  return _::OrphanBuilder::copy(
      OrphanageInternal::getArena(to),
      OrphanageInternal::getCapTable(to),
      capTable.imbue(from));
}

}  // namespace _ (private)

}  // namespace capnp