Commit 684a7f34 authored by Kenton Varda's avatar Kenton Varda

Pass the inner capability into the membrane filter callbacks.

Also introduce a way to copy a struct or list which applies membranes to all embedded capabilities, since this seems like it will be needed in conjuction with the above.
parent 8f614cc4
...@@ -305,6 +305,9 @@ public: ...@@ -305,6 +305,9 @@ public:
Orphan() = default; Orphan() = default;
KJ_DISALLOW_COPY(Orphan); KJ_DISALLOW_COPY(Orphan);
Orphan(Orphan&&) = default; Orphan(Orphan&&) = default;
inline Orphan(_::OrphanBuilder&& builder)
: builder(kj::mv(builder)) {}
Orphan& operator=(Orphan&&) = default; Orphan& operator=(Orphan&&) = default;
template <typename T> template <typename T>
...@@ -350,9 +353,6 @@ public: ...@@ -350,9 +353,6 @@ public:
private: private:
_::OrphanBuilder builder; _::OrphanBuilder builder;
inline Orphan(_::OrphanBuilder&& builder)
: builder(kj::mv(builder)) {}
template <typename, Kind> template <typename, Kind>
friend struct _::PointerHelpers; friend struct _::PointerHelpers;
friend class Orphanage; friend class Orphanage;
......
...@@ -1762,7 +1762,8 @@ struct WireHelpers { ...@@ -1762,7 +1762,8 @@ struct WireHelpers {
#if !CAPNP_LITE #if !CAPNP_LITE
KJ_IF_MAYBE(cap, srcCapTable->extractCap(src->capRef.index.get())) { KJ_IF_MAYBE(cap, srcCapTable->extractCap(src->capRef.index.get())) {
setCapabilityPointer(dstSegment, dstCapTable, dst, kj::mv(*cap)); setCapabilityPointer(dstSegment, dstCapTable, dst, kj::mv(*cap));
return { dstSegment, nullptr }; // Return dummy non-null pointer so OrphanBuilder doesn't end up null.
return { dstSegment, reinterpret_cast<word*>(1) };
} else { } else {
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
KJ_FAIL_REQUIRE("Message contained invalid capability pointer.") { KJ_FAIL_REQUIRE("Message contained invalid capability pointer.") {
...@@ -1810,7 +1811,7 @@ struct WireHelpers { ...@@ -1810,7 +1811,7 @@ struct WireHelpers {
location = nullptr; location = nullptr;
} else if (ref->kind() == WirePointer::OTHER) { } else if (ref->kind() == WirePointer::OTHER) {
KJ_REQUIRE(ref->isCapability(), "Unknown pointer type.") { break; } KJ_REQUIRE(ref->isCapability(), "Unknown pointer type.") { break; }
location = reinterpret_cast<word*>(ref); // dummy so that it is non-null location = reinterpret_cast<word*>(1); // dummy so that it is non-null
} else { } else {
WirePointer* refCopy = ref; WirePointer* refCopy = ref;
location = followFarsNoWritableCheck(refCopy, ref->target(), segment); location = followFarsNoWritableCheck(refCopy, ref->target(), segment);
...@@ -2581,6 +2582,10 @@ MessageSizeCounts StructReader::totalSize() const { ...@@ -2581,6 +2582,10 @@ MessageSizeCounts StructReader::totalSize() const {
return result; return result;
} }
CapTableReader* StructReader::getCapTable() {
return capTable;
}
StructReader StructReader::imbue(CapTableReader* capTable) const { StructReader StructReader::imbue(CapTableReader* capTable) const {
auto result = *this; auto result = *this;
result.capTable = capTable; result.capTable = capTable;
...@@ -2716,6 +2721,10 @@ StructReader ListReader::getStructElement(ElementCount index) const { ...@@ -2716,6 +2721,10 @@ StructReader ListReader::getStructElement(ElementCount index) const {
nestingLimit - 1); nestingLimit - 1);
} }
CapTableReader* ListReader::getCapTable() {
return capTable;
}
ListReader ListReader::imbue(CapTableReader* capTable) const { ListReader ListReader::imbue(CapTableReader* capTable) const {
auto result = *this; auto result = *this;
result.capTable = capTable; result.capTable = capTable;
......
...@@ -567,6 +567,9 @@ public: ...@@ -567,6 +567,9 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an // use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns. // exception if it overruns.
CapTableReader* getCapTable();
// Gets the capability context in which this object is operating.
StructReader imbue(CapTableReader* capTable) const; StructReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context. // Return a copy of this reader except using the given capability context.
...@@ -727,6 +730,9 @@ public: ...@@ -727,6 +730,9 @@ public:
StructReader getStructElement(ElementCount index) const; StructReader getStructElement(ElementCount index) const;
CapTableReader* getCapTable();
// Gets the capability context in which this object is operating.
ListReader imbue(CapTableReader* capTable) const; ListReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context. // Return a copy of this reader except using the given capability context.
...@@ -860,7 +866,7 @@ private: ...@@ -860,7 +866,7 @@ private:
word* location; word* location;
// Pointer to the object, or nullptr if the pointer is null. For capabilities, we make this // Pointer to the object, or nullptr if the pointer is null. For capabilities, we make this
// point at `tag` just so that it is non-null for operator==, but it is never used. // 0x1 just so that it is non-null for operator==, but it is never used.
inline OrphanBuilder(const void* tagPtr, SegmentBuilder* segment, inline OrphanBuilder(const void* tagPtr, SegmentBuilder* segment,
CapTableBuilder* capTable, word* location) CapTableBuilder* capTable, word* location)
......
...@@ -90,7 +90,8 @@ protected: ...@@ -90,7 +90,8 @@ protected:
class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted { class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted {
public: public:
kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId) override { kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId,
Capability::Client target) override {
if (interfaceId == capnp::typeId<Thing>() && methodId == 1) { if (interfaceId == capnp::typeId<Thing>() && methodId == 1) {
return Capability::Client(kj::heap<ThingImpl>("inbound")); return Capability::Client(kj::heap<ThingImpl>("inbound"));
} else { } else {
...@@ -98,7 +99,8 @@ public: ...@@ -98,7 +99,8 @@ public:
} }
} }
kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId) override { kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId,
Capability::Client target) override {
if (interfaceId == capnp::typeId<Thing>() && methodId == 1) { if (interfaceId == capnp::typeId<Thing>() && methodId == 1) {
return Capability::Client(kj::heap<ThingImpl>("outbound")); return Capability::Client(kj::heap<ThingImpl>("outbound"));
} else { } else {
...@@ -147,11 +149,13 @@ void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membrane ...@@ -147,11 +149,13 @@ void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membrane
struct TestEnv { struct TestEnv {
kj::EventLoop loop; kj::EventLoop loop;
kj::WaitScope waitScope; kj::WaitScope waitScope;
kj::Own<MembranePolicyImpl> policy;
test::TestMembrane::Client membraned; test::TestMembrane::Client membraned;
TestEnv() TestEnv()
: waitScope(loop), : waitScope(loop),
membraned(membrane(kj::heap<TestMembraneImpl>(), kj::refcounted<MembranePolicyImpl>())) {} policy(kj::refcounted<MembranePolicyImpl>()),
membraned(membrane(kj::heap<TestMembraneImpl>(), policy->addRef())) {}
void testThing(kj::Function<Thing::Client()> makeThing, void testThing(kj::Function<Thing::Client()> makeThing,
kj::StringPtr localPassThrough, kj::StringPtr localIntercept, kj::StringPtr localPassThrough, kj::StringPtr localIntercept,
...@@ -209,6 +213,49 @@ KJ_TEST("call local promise pointing into membrane that eventually resolves to o ...@@ -209,6 +213,49 @@ KJ_TEST("call local promise pointing into membrane that eventually resolves to o
}, "outside", "outside", "outside", "outbound"); }, "outside", "outside", "outside", "outbound");
} }
KJ_TEST("apply membrane using copyOutOfMembrane() on struct") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto root = outsideBuilder.initRoot<test::TestContainMembrane>();
root.setCap(kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.adoptRoot(copyOutOfMembrane(
root.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestContainMembrane>().getCap();
}, "inside", "inbound", "inside", "inside");
}
KJ_TEST("apply membrane using copyOutOfMembrane() on list") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto list = outsideBuilder.initRoot<test::TestContainMembrane>().initList(1);
list.set(0, kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.initRoot<test::TestContainMembrane>().adoptList(copyOutOfMembrane(
list.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestContainMembrane>().getList()[0];
}, "inside", "inbound", "inside", "inside");
}
KJ_TEST("apply membrane using copyOutOfMembrane() on AnyPointer") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto ptr = outsideBuilder.initRoot<test::TestAnyPointer>().getAnyPointerField();
ptr.setAs<test::TestMembrane::Thing>(kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.initRoot<test::TestAnyPointer>().getAnyPointerField().adopt(copyOutOfMembrane(
ptr.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestAnyPointer>().getAnyPointerField()
.getAs<test::TestMembrane::Thing>();
}, "inside", "inbound", "inside", "inside");
}
struct TestRpcEnv { struct TestRpcEnv {
kj::AsyncIoContext io; kj::AsyncIoContext io;
kj::TwoWayPipe pipe; kj::TwoWayPipe pipe;
......
...@@ -37,10 +37,26 @@ public: ...@@ -37,10 +37,26 @@ public:
: policy(policy), reverse(reverse) {} : policy(policy), reverse(reverse) {}
AnyPointer::Reader imbue(AnyPointer::Reader reader) { AnyPointer::Reader imbue(AnyPointer::Reader reader) {
return AnyPointer::Reader(imbue(
_::PointerHelpers<AnyPointer>::getInternalReader(kj::mv(reader))));
}
_::PointerReader imbue(_::PointerReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once");
inner = reader.getCapTable();
return reader.imbue(this);
}
_::StructReader imbue(_::StructReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once");
inner = reader.getCapTable();
return reader.imbue(this);
}
_::ListReader imbue(_::ListReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once"); KJ_REQUIRE(inner == nullptr, "can only call this once");
auto pointerReader = _::PointerHelpers<AnyPointer>::getInternalReader(kj::mv(reader)); inner = reader.getCapTable();
inner = pointerReader.getCapTable(); return reader.imbue(this);
return AnyPointer::Reader(pointerReader.imbue(this));
} }
kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) override { kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) override {
...@@ -320,8 +336,9 @@ public: ...@@ -320,8 +336,9 @@ public:
return r->get()->newCall(interfaceId, methodId, sizeHint); return r->get()->newCall(interfaceId, methodId, sizeHint);
} }
auto redirect = reverse ? policy->outboundCall(interfaceId, methodId) auto redirect = reverse
: policy->inboundCall(interfaceId, methodId); ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
: policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
KJ_IF_MAYBE(r, redirect) { KJ_IF_MAYBE(r, redirect) {
// The policy says that *if* this capability points into the membrane, then we want to // The policy says that *if* this capability points into the membrane, then we want to
// redirect the call. However, if this capability is a promise, then it could resolve to // redirect the call. However, if this capability is a promise, then it could resolve to
...@@ -346,8 +363,9 @@ public: ...@@ -346,8 +363,9 @@ public:
return r->get()->call(interfaceId, methodId, kj::mv(context)); return r->get()->call(interfaceId, methodId, kj::mv(context));
} }
auto redirect = reverse ? policy->outboundCall(interfaceId, methodId) auto redirect = reverse
: policy->inboundCall(interfaceId, methodId); ? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
: policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
KJ_IF_MAYBE(r, redirect) { KJ_IF_MAYBE(r, redirect) {
// The policy says that *if* this capability points into the membrane, then we want to // The policy says that *if* this capability points into the membrane, then we want to
// redirect the call. However, if this capability is a promise, then it could resolve to // redirect the call. However, if this capability is a promise, then it could resolve to
...@@ -434,5 +452,36 @@ Capability::Client reverseMembrane(Capability::Client inner, kj::Own<MembranePol ...@@ -434,5 +452,36 @@ Capability::Client reverseMembrane(Capability::Client inner, kj::Own<MembranePol
ClientHook::from(kj::mv(inner)), *policy, true)); ClientHook::from(kj::mv(inner)), *policy, true));
} }
namespace _ { // private
_::OrphanBuilder copyOutOfMembrane(PointerReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
_::OrphanBuilder copyOutOfMembrane(StructReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
_::OrphanBuilder copyOutOfMembrane(ListReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
} // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -57,7 +57,8 @@ class MembranePolicy { ...@@ -57,7 +57,8 @@ class MembranePolicy {
// calls crossing the membrane to be blocked or redirected. // calls crossing the membrane to be blocked or redirected.
public: public:
virtual kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId) = 0; virtual kj::Maybe<Capability::Client> inboundCall(
uint64_t interfaceId, uint16_t methodId, Capability::Client target) = 0;
// Given an inbound call (a call originating "outside" the membrane destined for an object // Given an inbound call (a call originating "outside" the membrane destined for an object
// "inside" the membrane), decides what to do with it. The policy may: // "inside" the membrane), decides what to do with it. The policy may:
// //
...@@ -68,8 +69,20 @@ public: ...@@ -68,8 +69,20 @@ public:
// auto-wrapped; however, the callee can easily wrap the returned capability in the membrane // auto-wrapped; however, the callee can easily wrap the returned capability in the membrane
// itself before returning to achieve this effect. // itself before returning to achieve this effect.
// - Throw an exception to cause the call to fail with that exception. // - Throw an exception to cause the call to fail with that exception.
//
// `target` is the underlying capability (*inside* the membrane) for which the call is destined.
// Generally, the only way you should use `target` is to wrap it in some capbaility which you
// return as a redirect. The redirect capability may modify the call in some way and send it to
// `target`. Be careful to use `copyIntoMembrane()` and `copyOutOfMembrane()` as appropriate when
// copying parameters or results across the membrane.
//
// Note that since `target` is inside the capability, if you were to directly return it (rather
// than return null), the effect would be that the membrane would be broken: the call would
// proceed directly and any new capabilities introduced through it would not be membraned. You
// generally should not do that.
virtual kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId) = 0; virtual kj::Maybe<Capability::Client> outboundCall(
uint64_t interfaceId, uint16_t methodId, Capability::Client target) = 0;
// Like `inboundCall()`, but applies to calls originating *inside* the membrane and terminating // Like `inboundCall()`, but applies to calls originating *inside* the membrane and terminating
// outside. // outside.
// //
...@@ -119,6 +132,17 @@ typename ServerType::Serves::Client reverseMembrane( ...@@ -119,6 +132,17 @@ typename ServerType::Serves::Client reverseMembrane(
// Convenience templates which input a capability server type and return the appropriate client // Convenience templates which input a capability server type and return the appropriate client
// type. // type.
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyIntoMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy);
// Copy a Cap'n Proto object (e.g. struct or list), adding the given membrane to any capabilities
// found within it. `from` is interpreted as "outside" the membrane while `to` is "inside".
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyOutOfMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy);
// Like copyIntoMembrane() except that `from` is "inside" the membrane and `to` is "outside".
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
...@@ -146,6 +170,33 @@ typename ServerType::Serves::Client reverseMembrane( ...@@ -146,6 +170,33 @@ typename ServerType::Serves::Client reverseMembrane(
.castAs<typename ServerType::Serves::Client>(); .castAs<typename ServerType::Serves::Client>();
} }
namespace _ { // private
OrphanBuilder copyOutOfMembrane(PointerReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
OrphanBuilder copyOutOfMembrane(StructReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
OrphanBuilder copyOutOfMembrane(ListReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
} // namespace _ (private)
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyIntoMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy) {
return _::copyOutOfMembrane(
_::PointerHelpers<typename kj::Decay<Reader>::Reads>::getInternalReader(from),
to, kj::mv(policy), true);
}
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyOutOfMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy) {
return _::copyOutOfMembrane(
_::PointerHelpers<typename kj::Decay<Reader>::Reads>::getInternalReader(from),
to, kj::mv(policy), false);
}
} // namespace capnp } // namespace capnp
#endif // CAPNP_MEMBRANE_H_ #endif // CAPNP_MEMBRANE_H_
...@@ -34,6 +34,7 @@ class StructSchema; ...@@ -34,6 +34,7 @@ class StructSchema;
class ListSchema; class ListSchema;
struct DynamicStruct; struct DynamicStruct;
struct DynamicList; struct DynamicList;
namespace _ { struct OrphanageInternal; }
template <typename T> template <typename T>
class Orphan { class Orphan {
...@@ -53,6 +54,7 @@ public: ...@@ -53,6 +54,7 @@ public:
KJ_DISALLOW_COPY(Orphan); KJ_DISALLOW_COPY(Orphan);
Orphan(Orphan&&) = default; Orphan(Orphan&&) = default;
Orphan& operator=(Orphan&&) = default; Orphan& operator=(Orphan&&) = default;
inline Orphan(_::OrphanBuilder&& builder): builder(kj::mv(builder)) {}
inline BuilderFor<T> get(); inline BuilderFor<T> get();
// Get the underlying builder. If the orphan is null, this will allocate and return a default // Get the underlying builder. If the orphan is null, this will allocate and return a default
...@@ -89,8 +91,6 @@ public: ...@@ -89,8 +91,6 @@ public:
private: private:
_::OrphanBuilder builder; _::OrphanBuilder builder;
inline Orphan(_::OrphanBuilder&& builder): builder(kj::mv(builder)) {}
template <typename, Kind> template <typename, Kind>
friend struct _::PointerHelpers; friend struct _::PointerHelpers;
template <typename, Kind> template <typename, Kind>
...@@ -179,6 +179,7 @@ private: ...@@ -179,6 +179,7 @@ private:
struct NewOrphanListImpl; struct NewOrphanListImpl;
friend class MessageBuilder; friend class MessageBuilder;
friend struct _::OrphanageInternal;
}; };
// ======================================================================================= // =======================================================================================
...@@ -276,6 +277,11 @@ struct OrphanGetImpl<Data, Kind::BLOB> { ...@@ -276,6 +277,11 @@ struct OrphanGetImpl<Data, Kind::BLOB> {
} }
}; };
struct OrphanageInternal {
static inline _::BuilderArena* getArena(Orphanage orphanage) { return orphanage.arena; }
static inline _::CapTableBuilder* getCapTable(Orphanage orphanage) { return orphanage.capTable; }
};
} // namespace _ (private) } // namespace _ (private)
template <typename T> template <typename T>
......
...@@ -819,6 +819,11 @@ interface TestMembrane { ...@@ -819,6 +819,11 @@ interface TestMembrane {
} }
} }
struct TestContainMembrane {
cap @0 :TestMembrane.Thing;
list @1 :List(TestMembrane.Thing);
}
struct TestTransferCap { struct TestTransferCap {
list @0 :List(Element); list @0 :List(Element);
struct Element { struct Element {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment