Commit 684a7f34 authored by Kenton Varda's avatar Kenton Varda

Pass the inner capability into the membrane filter callbacks.

Also introduce a way to copy a struct or list which applies membranes to all embedded capabilities, since this seems like it will be needed in conjuction with the above.
parent 8f614cc4
......@@ -305,6 +305,9 @@ public:
Orphan() = default;
KJ_DISALLOW_COPY(Orphan);
Orphan(Orphan&&) = default;
inline Orphan(_::OrphanBuilder&& builder)
: builder(kj::mv(builder)) {}
Orphan& operator=(Orphan&&) = default;
template <typename T>
......@@ -350,9 +353,6 @@ public:
private:
_::OrphanBuilder builder;
inline Orphan(_::OrphanBuilder&& builder)
: builder(kj::mv(builder)) {}
template <typename, Kind>
friend struct _::PointerHelpers;
friend class Orphanage;
......
......@@ -1762,7 +1762,8 @@ struct WireHelpers {
#if !CAPNP_LITE
KJ_IF_MAYBE(cap, srcCapTable->extractCap(src->capRef.index.get())) {
setCapabilityPointer(dstSegment, dstCapTable, dst, kj::mv(*cap));
return { dstSegment, nullptr };
// Return dummy non-null pointer so OrphanBuilder doesn't end up null.
return { dstSegment, reinterpret_cast<word*>(1) };
} else {
#endif // !CAPNP_LITE
KJ_FAIL_REQUIRE("Message contained invalid capability pointer.") {
......@@ -1810,7 +1811,7 @@ struct WireHelpers {
location = nullptr;
} else if (ref->kind() == WirePointer::OTHER) {
KJ_REQUIRE(ref->isCapability(), "Unknown pointer type.") { break; }
location = reinterpret_cast<word*>(ref); // dummy so that it is non-null
location = reinterpret_cast<word*>(1); // dummy so that it is non-null
} else {
WirePointer* refCopy = ref;
location = followFarsNoWritableCheck(refCopy, ref->target(), segment);
......@@ -2581,6 +2582,10 @@ MessageSizeCounts StructReader::totalSize() const {
return result;
}
CapTableReader* StructReader::getCapTable() {
return capTable;
}
StructReader StructReader::imbue(CapTableReader* capTable) const {
auto result = *this;
result.capTable = capTable;
......@@ -2716,6 +2721,10 @@ StructReader ListReader::getStructElement(ElementCount index) const {
nestingLimit - 1);
}
CapTableReader* ListReader::getCapTable() {
return capTable;
}
ListReader ListReader::imbue(CapTableReader* capTable) const {
auto result = *this;
result.capTable = capTable;
......
......@@ -567,6 +567,9 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
CapTableReader* getCapTable();
// Gets the capability context in which this object is operating.
StructReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context.
......@@ -727,6 +730,9 @@ public:
StructReader getStructElement(ElementCount index) const;
CapTableReader* getCapTable();
// Gets the capability context in which this object is operating.
ListReader imbue(CapTableReader* capTable) const;
// Return a copy of this reader except using the given capability context.
......@@ -860,7 +866,7 @@ private:
word* location;
// Pointer to the object, or nullptr if the pointer is null. For capabilities, we make this
// point at `tag` just so that it is non-null for operator==, but it is never used.
// 0x1 just so that it is non-null for operator==, but it is never used.
inline OrphanBuilder(const void* tagPtr, SegmentBuilder* segment,
CapTableBuilder* capTable, word* location)
......
......@@ -90,7 +90,8 @@ protected:
class MembranePolicyImpl: public MembranePolicy, public kj::Refcounted {
public:
kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId) override {
kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId,
Capability::Client target) override {
if (interfaceId == capnp::typeId<Thing>() && methodId == 1) {
return Capability::Client(kj::heap<ThingImpl>("inbound"));
} else {
......@@ -98,7 +99,8 @@ public:
}
}
kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId) override {
kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId,
Capability::Client target) override {
if (interfaceId == capnp::typeId<Thing>() && methodId == 1) {
return Capability::Client(kj::heap<ThingImpl>("outbound"));
} else {
......@@ -147,11 +149,13 @@ void testThingImpl(kj::WaitScope& waitScope, test::TestMembrane::Client membrane
struct TestEnv {
kj::EventLoop loop;
kj::WaitScope waitScope;
kj::Own<MembranePolicyImpl> policy;
test::TestMembrane::Client membraned;
TestEnv()
: waitScope(loop),
membraned(membrane(kj::heap<TestMembraneImpl>(), kj::refcounted<MembranePolicyImpl>())) {}
policy(kj::refcounted<MembranePolicyImpl>()),
membraned(membrane(kj::heap<TestMembraneImpl>(), policy->addRef())) {}
void testThing(kj::Function<Thing::Client()> makeThing,
kj::StringPtr localPassThrough, kj::StringPtr localIntercept,
......@@ -209,6 +213,49 @@ KJ_TEST("call local promise pointing into membrane that eventually resolves to o
}, "outside", "outside", "outside", "outbound");
}
KJ_TEST("apply membrane using copyOutOfMembrane() on struct") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto root = outsideBuilder.initRoot<test::TestContainMembrane>();
root.setCap(kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.adoptRoot(copyOutOfMembrane(
root.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestContainMembrane>().getCap();
}, "inside", "inbound", "inside", "inside");
}
KJ_TEST("apply membrane using copyOutOfMembrane() on list") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto list = outsideBuilder.initRoot<test::TestContainMembrane>().initList(1);
list.set(0, kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.initRoot<test::TestContainMembrane>().adoptList(copyOutOfMembrane(
list.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestContainMembrane>().getList()[0];
}, "inside", "inbound", "inside", "inside");
}
KJ_TEST("apply membrane using copyOutOfMembrane() on AnyPointer") {
TestEnv env;
env.testThing([&]() {
MallocMessageBuilder outsideBuilder;
auto ptr = outsideBuilder.initRoot<test::TestAnyPointer>().getAnyPointerField();
ptr.setAs<test::TestMembrane::Thing>(kj::heap<ThingImpl>("inside"));
MallocMessageBuilder insideBuilder;
insideBuilder.initRoot<test::TestAnyPointer>().getAnyPointerField().adopt(copyOutOfMembrane(
ptr.asReader(), insideBuilder.getOrphanage(), env.policy->addRef()));
return insideBuilder.getRoot<test::TestAnyPointer>().getAnyPointerField()
.getAs<test::TestMembrane::Thing>();
}, "inside", "inbound", "inside", "inside");
}
struct TestRpcEnv {
kj::AsyncIoContext io;
kj::TwoWayPipe pipe;
......
......@@ -37,10 +37,26 @@ public:
: policy(policy), reverse(reverse) {}
AnyPointer::Reader imbue(AnyPointer::Reader reader) {
return AnyPointer::Reader(imbue(
_::PointerHelpers<AnyPointer>::getInternalReader(kj::mv(reader))));
}
_::PointerReader imbue(_::PointerReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once");
inner = reader.getCapTable();
return reader.imbue(this);
}
_::StructReader imbue(_::StructReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once");
inner = reader.getCapTable();
return reader.imbue(this);
}
_::ListReader imbue(_::ListReader reader) {
KJ_REQUIRE(inner == nullptr, "can only call this once");
auto pointerReader = _::PointerHelpers<AnyPointer>::getInternalReader(kj::mv(reader));
inner = pointerReader.getCapTable();
return AnyPointer::Reader(pointerReader.imbue(this));
inner = reader.getCapTable();
return reader.imbue(this);
}
kj::Maybe<kj::Own<ClientHook>> extractCap(uint index) override {
......@@ -320,8 +336,9 @@ public:
return r->get()->newCall(interfaceId, methodId, sizeHint);
}
auto redirect = reverse ? policy->outboundCall(interfaceId, methodId)
: policy->inboundCall(interfaceId, methodId);
auto redirect = reverse
? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
: policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
KJ_IF_MAYBE(r, redirect) {
// The policy says that *if* this capability points into the membrane, then we want to
// redirect the call. However, if this capability is a promise, then it could resolve to
......@@ -346,8 +363,9 @@ public:
return r->get()->call(interfaceId, methodId, kj::mv(context));
}
auto redirect = reverse ? policy->outboundCall(interfaceId, methodId)
: policy->inboundCall(interfaceId, methodId);
auto redirect = reverse
? policy->outboundCall(interfaceId, methodId, Capability::Client(inner->addRef()))
: policy->inboundCall(interfaceId, methodId, Capability::Client(inner->addRef()));
KJ_IF_MAYBE(r, redirect) {
// The policy says that *if* this capability points into the membrane, then we want to
// redirect the call. However, if this capability is a promise, then it could resolve to
......@@ -434,5 +452,36 @@ Capability::Client reverseMembrane(Capability::Client inner, kj::Own<MembranePol
ClientHook::from(kj::mv(inner)), *policy, true));
}
} // namespace capnp
namespace _ { // private
_::OrphanBuilder copyOutOfMembrane(PointerReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
_::OrphanBuilder copyOutOfMembrane(StructReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
_::OrphanBuilder copyOutOfMembrane(ListReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse) {
MembraneCapTableReader capTable(*policy, reverse);
return _::OrphanBuilder::copy(
OrphanageInternal::getArena(to),
OrphanageInternal::getCapTable(to),
capTable.imbue(from));
}
} // namespace _ (private)
} // namespace capnp
......@@ -57,7 +57,8 @@ class MembranePolicy {
// calls crossing the membrane to be blocked or redirected.
public:
virtual kj::Maybe<Capability::Client> inboundCall(uint64_t interfaceId, uint16_t methodId) = 0;
virtual kj::Maybe<Capability::Client> inboundCall(
uint64_t interfaceId, uint16_t methodId, Capability::Client target) = 0;
// Given an inbound call (a call originating "outside" the membrane destined for an object
// "inside" the membrane), decides what to do with it. The policy may:
//
......@@ -68,8 +69,20 @@ public:
// auto-wrapped; however, the callee can easily wrap the returned capability in the membrane
// itself before returning to achieve this effect.
// - Throw an exception to cause the call to fail with that exception.
//
// `target` is the underlying capability (*inside* the membrane) for which the call is destined.
// Generally, the only way you should use `target` is to wrap it in some capbaility which you
// return as a redirect. The redirect capability may modify the call in some way and send it to
// `target`. Be careful to use `copyIntoMembrane()` and `copyOutOfMembrane()` as appropriate when
// copying parameters or results across the membrane.
//
// Note that since `target` is inside the capability, if you were to directly return it (rather
// than return null), the effect would be that the membrane would be broken: the call would
// proceed directly and any new capabilities introduced through it would not be membraned. You
// generally should not do that.
virtual kj::Maybe<Capability::Client> outboundCall(uint64_t interfaceId, uint16_t methodId) = 0;
virtual kj::Maybe<Capability::Client> outboundCall(
uint64_t interfaceId, uint16_t methodId, Capability::Client target) = 0;
// Like `inboundCall()`, but applies to calls originating *inside* the membrane and terminating
// outside.
//
......@@ -119,6 +132,17 @@ typename ServerType::Serves::Client reverseMembrane(
// Convenience templates which input a capability server type and return the appropriate client
// type.
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyIntoMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy);
// Copy a Cap'n Proto object (e.g. struct or list), adding the given membrane to any capabilities
// found within it. `from` is interpreted as "outside" the membrane while `to` is "inside".
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyOutOfMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy);
// Like copyIntoMembrane() except that `from` is "inside" the membrane and `to` is "outside".
// =======================================================================================
// inline implementation details
......@@ -146,6 +170,33 @@ typename ServerType::Serves::Client reverseMembrane(
.castAs<typename ServerType::Serves::Client>();
}
namespace _ { // private
OrphanBuilder copyOutOfMembrane(PointerReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
OrphanBuilder copyOutOfMembrane(StructReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
OrphanBuilder copyOutOfMembrane(ListReader from, Orphanage to,
kj::Own<MembranePolicy> policy, bool reverse);
} // namespace _ (private)
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyIntoMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy) {
return _::copyOutOfMembrane(
_::PointerHelpers<typename kj::Decay<Reader>::Reads>::getInternalReader(from),
to, kj::mv(policy), true);
}
template <typename Reader>
Orphan<typename kj::Decay<Reader>::Reads> copyOutOfMembrane(
Reader&& from, Orphanage to, kj::Own<MembranePolicy> policy) {
return _::copyOutOfMembrane(
_::PointerHelpers<typename kj::Decay<Reader>::Reads>::getInternalReader(from),
to, kj::mv(policy), false);
}
} // namespace capnp
#endif // CAPNP_MEMBRANE_H_
......@@ -34,6 +34,7 @@ class StructSchema;
class ListSchema;
struct DynamicStruct;
struct DynamicList;
namespace _ { struct OrphanageInternal; }
template <typename T>
class Orphan {
......@@ -53,6 +54,7 @@ public:
KJ_DISALLOW_COPY(Orphan);
Orphan(Orphan&&) = default;
Orphan& operator=(Orphan&&) = default;
inline Orphan(_::OrphanBuilder&& builder): builder(kj::mv(builder)) {}
inline BuilderFor<T> get();
// Get the underlying builder. If the orphan is null, this will allocate and return a default
......@@ -89,8 +91,6 @@ public:
private:
_::OrphanBuilder builder;
inline Orphan(_::OrphanBuilder&& builder): builder(kj::mv(builder)) {}
template <typename, Kind>
friend struct _::PointerHelpers;
template <typename, Kind>
......@@ -179,6 +179,7 @@ private:
struct NewOrphanListImpl;
friend class MessageBuilder;
friend struct _::OrphanageInternal;
};
// =======================================================================================
......@@ -276,6 +277,11 @@ struct OrphanGetImpl<Data, Kind::BLOB> {
}
};
struct OrphanageInternal {
static inline _::BuilderArena* getArena(Orphanage orphanage) { return orphanage.arena; }
static inline _::CapTableBuilder* getCapTable(Orphanage orphanage) { return orphanage.capTable; }
};
} // namespace _ (private)
template <typename T>
......
......@@ -819,6 +819,11 @@ interface TestMembrane {
}
}
struct TestContainMembrane {
cap @0 :TestMembrane.Thing;
list @1 :List(TestMembrane.Thing);
}
struct TestTransferCap {
list @0 :List(Element);
struct Element {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment