Commit 4878663f authored by Kenton Varda's avatar Kenton Varda

Unimbue cap descriptors before extraction/injection.

parent 5de44367
...@@ -114,8 +114,7 @@ kj::Maybe<kj::Own<const ClientHook>> BasicReaderArena::newBrokenCap(kj::StringPt ...@@ -114,8 +114,7 @@ kj::Maybe<kj::Own<const ClientHook>> BasicReaderArena::newBrokenCap(kj::StringPt
// ======================================================================================= // =======================================================================================
ImbuedReaderArena::ImbuedReaderArena(Arena* base, CapExtractorBase* capExtractor) ImbuedReaderArena::ImbuedReaderArena(Arena* base, CapExtractorBase* capExtractor)
: base(base), capExtractor(capExtractor), : base(base), capExtractor(capExtractor), segment0(nullptr) {}
segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
ImbuedReaderArena::~ImbuedReaderArena() noexcept(false) {} ImbuedReaderArena::~ImbuedReaderArena() noexcept(false) {}
SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) { SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
...@@ -124,7 +123,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) { ...@@ -124,7 +123,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
if (baseSegment->getSegmentId() == SegmentId(0)) { if (baseSegment->getSegmentId() == SegmentId(0)) {
if (segment0.getArena() == nullptr) { if (segment0.getArena() == nullptr) {
kj::dtor(segment0); kj::dtor(segment0);
kj::ctor(segment0, this, *baseSegment); kj::ctor(segment0, this, baseSegment);
} }
KJ_DASSERT(segment0.getArray().begin() == baseSegment->getArray().begin()); KJ_DASSERT(segment0.getArray().begin() == baseSegment->getArray().begin());
return &segment0; return &segment0;
...@@ -146,7 +145,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) { ...@@ -146,7 +145,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
*lock = kj::mv(newMap); *lock = kj::mv(newMap);
} }
auto newSegment = kj::heap<SegmentReader>(this, *baseSegment); auto newSegment = kj::heap<ImbuedSegmentReader>(this, baseSegment);
SegmentReader* result = newSegment; SegmentReader* result = newSegment;
segments->insert(std::make_pair(baseSegment, mv(newSegment))); segments->insert(std::make_pair(baseSegment, mv(newSegment)));
return result; return result;
...@@ -162,7 +161,9 @@ void ImbuedReaderArena::reportReadLimitReached() { ...@@ -162,7 +161,9 @@ void ImbuedReaderArena::reportReadLimitReached() {
kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::extractCap( kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::extractCap(
const _::StructReader& capDescriptor) { const _::StructReader& capDescriptor) {
return capExtractor->extractCapInternal(capDescriptor); _::StructReader copy = capDescriptor;
copy.unimbue();
return capExtractor->extractCapInternal(copy);
} }
kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::newBrokenCap(kj::StringPtr description) { kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::newBrokenCap(kj::StringPtr description) {
...@@ -382,7 +383,9 @@ void ImbuedBuilderArena::reportReadLimitReached() { ...@@ -382,7 +383,9 @@ void ImbuedBuilderArena::reportReadLimitReached() {
kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::extractCap( kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::extractCap(
const _::StructReader& capDescriptor) { const _::StructReader& capDescriptor) {
return capInjector->getInjectedCapInternal(capDescriptor); _::StructReader copy = capDescriptor;
copy.unimbue();
return capInjector->getInjectedCapInternal(copy);
} }
kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::newBrokenCap(kj::StringPtr description) { kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::newBrokenCap(kj::StringPtr description) {
...@@ -404,7 +407,9 @@ OrphanBuilder ImbuedBuilderArena::injectCap(kj::Own<const ClientHook>&& cap) { ...@@ -404,7 +407,9 @@ OrphanBuilder ImbuedBuilderArena::injectCap(kj::Own<const ClientHook>&& cap) {
} }
void ImbuedBuilderArena::dropCap(const StructReader& capDescriptor) { void ImbuedBuilderArena::dropCap(const StructReader& capDescriptor) {
capInjector->dropCapInternal(capDescriptor); _::StructReader copy = capDescriptor;
copy.unimbue();
capInjector->dropCapInternal(copy);
} }
} // namespace _ (private) } // namespace _ (private)
......
...@@ -102,7 +102,6 @@ class SegmentReader { ...@@ -102,7 +102,6 @@ class SegmentReader {
public: public:
inline SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<const word> ptr, inline SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<const word> ptr,
ReadLimiter* readLimiter); ReadLimiter* readLimiter);
inline SegmentReader(Arena* arena, const SegmentReader& base);
KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to)); KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to));
...@@ -128,6 +127,18 @@ private: ...@@ -128,6 +127,18 @@ private:
friend class SegmentBuilder; friend class SegmentBuilder;
friend class ImbuedSegmentBuilder; friend class ImbuedSegmentBuilder;
friend class ImbuedSegmentReader;
};
class ImbuedSegmentReader: public SegmentReader {
public:
inline ImbuedSegmentReader(Arena* arena, SegmentReader* base);
inline ImbuedSegmentReader(decltype(nullptr));
inline SegmentReader* unimbue();
private:
SegmentReader* base;
}; };
class SegmentBuilder: public SegmentReader { class SegmentBuilder: public SegmentReader {
...@@ -172,6 +183,11 @@ public: ...@@ -172,6 +183,11 @@ public:
inline ImbuedSegmentBuilder(decltype(nullptr)); inline ImbuedSegmentBuilder(decltype(nullptr));
KJ_DISALLOW_COPY(ImbuedSegmentBuilder); KJ_DISALLOW_COPY(ImbuedSegmentBuilder);
inline SegmentBuilder* unimbue();
private:
SegmentBuilder* base;
}; };
class Arena { class Arena {
...@@ -237,9 +253,9 @@ private: ...@@ -237,9 +253,9 @@ private:
CapExtractorBase* capExtractor; CapExtractorBase* capExtractor;
// Optimize for single-segment messages so that small messages are handled quickly. // Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0; ImbuedSegmentReader segment0;
typedef std::unordered_map<SegmentReader*, kj::Own<SegmentReader>> SegmentMap; typedef std::unordered_map<SegmentReader*, kj::Own<ImbuedSegmentReader>> SegmentMap;
kj::MutexGuarded<kj::Maybe<kj::Own<SegmentMap>>> moreSegments; kj::MutexGuarded<kj::Maybe<kj::Own<SegmentMap>>> moreSegments;
}; };
...@@ -374,9 +390,6 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<con ...@@ -374,9 +390,6 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<con
ReadLimiter* readLimiter) ReadLimiter* readLimiter)
: arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {} : arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {}
inline SegmentReader::SegmentReader(Arena* arena, const SegmentReader& base)
: arena(arena), id(base.id), ptr(base.ptr), readLimiter(base.readLimiter) {}
inline bool SegmentReader::containsInterval(const void* from, const void* to) { inline bool SegmentReader::containsInterval(const void* from, const void* to) {
return from >= this->ptr.begin() && to <= this->ptr.end() && return from >= this->ptr.begin() && to <= this->ptr.end() &&
readLimiter->canRead( readLimiter->canRead(
...@@ -395,6 +408,14 @@ inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; } ...@@ -395,6 +408,14 @@ inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; } inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); } inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
inline ImbuedSegmentReader::ImbuedSegmentReader(Arena* arena, SegmentReader* base)
: SegmentReader(arena, base->id, base->ptr, base->readLimiter), base(base) {}
inline ImbuedSegmentReader::ImbuedSegmentReader(decltype(nullptr))
: SegmentReader(nullptr, SegmentId(0), nullptr, nullptr), base(nullptr) {}
inline SegmentReader* ImbuedSegmentReader::unimbue() {
return base;
}
// ------------------------------------------------------------------- // -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder( inline SegmentBuilder::SegmentBuilder(
...@@ -452,9 +473,13 @@ inline BasicSegmentBuilder::BasicSegmentBuilder( ...@@ -452,9 +473,13 @@ inline BasicSegmentBuilder::BasicSegmentBuilder(
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(ImbuedBuilderArena* arena, SegmentBuilder* base) inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(ImbuedBuilderArena* arena, SegmentBuilder* base)
: SegmentBuilder(arena, base->id, : SegmentBuilder(arena, base->id,
kj::arrayPtr(const_cast<word*>(base->ptr.begin()), base->ptr.size()), kj::arrayPtr(const_cast<word*>(base->ptr.begin()), base->ptr.size()),
base->readLimiter, base->pos) {} base->readLimiter, base->pos),
base(base) {}
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(decltype(nullptr)) inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(decltype(nullptr))
: SegmentBuilder(nullptr, SegmentId(0), nullptr, nullptr, nullptr) {} : SegmentBuilder(nullptr, SegmentId(0), nullptr, nullptr, nullptr),
base(nullptr) {}
inline SegmentBuilder* ImbuedSegmentBuilder::unimbue() { return base; }
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
......
...@@ -2380,6 +2380,10 @@ BuilderArena* StructBuilder::getArena() { ...@@ -2380,6 +2380,10 @@ BuilderArena* StructBuilder::getArena() {
return segment->getArena(); return segment->getArena();
} }
void StructBuilder::unimbue() {
segment = static_cast<ImbuedSegmentBuilder*>(segment)->unimbue();
}
// ======================================================================================= // =======================================================================================
// StructReader // StructReader
...@@ -2399,6 +2403,10 @@ WordCount64 StructReader::totalSize() const { ...@@ -2399,6 +2403,10 @@ WordCount64 StructReader::totalSize() const {
return result; return result;
} }
void StructReader::unimbue() {
segment = static_cast<ImbuedSegmentReader*>(segment)->unimbue();
}
// ======================================================================================= // =======================================================================================
// ListBuilder // ListBuilder
......
...@@ -462,6 +462,10 @@ public: ...@@ -462,6 +462,10 @@ public:
BuilderArena* getArena(); BuilderArena* getArena();
// Gets the arena in which this object is allocated. // Gets the arena in which this object is allocated.
void unimbue();
// Removes the capability context from the builder. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentBuilder -- with the non-imbued base segment.
private: private:
SegmentBuilder* segment; // Memory segment in which the struct resides. SegmentBuilder* segment; // Memory segment in which the struct resides.
void* data; // Pointer to the encoded data. void* data; // Pointer to the encoded data.
...@@ -527,6 +531,10 @@ public: ...@@ -527,6 +531,10 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an // use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns. // exception if it overruns.
void unimbue();
// Removes the capability context from the reader. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentReader -- with the non-imbued base segment.
private: private:
SegmentReader* segment; // Memory segment in which the struct resides. SegmentReader* segment; // Memory segment in which the struct resides.
......
...@@ -36,8 +36,8 @@ struct SetTrueInDestructor: public Refcounted { ...@@ -36,8 +36,8 @@ struct SetTrueInDestructor: public Refcounted {
TEST(Refcount, Basic) { TEST(Refcount, Basic) {
bool b = false; bool b = false;
Own<SetTrueInDestructor> ref1 = kj::refcounted<SetTrueInDestructor>(&b); Own<SetTrueInDestructor> ref1 = kj::refcounted<SetTrueInDestructor>(&b);
Own<SetTrueInDestructor> ref2 = kj::addRef(*ref1); Own<const SetTrueInDestructor> ref2 = kj::addRef(*ref1);
Own<SetTrueInDestructor> ref3 = kj::addRef(*ref2); Own<const SetTrueInDestructor> ref3 = kj::addRef(*ref2);
EXPECT_FALSE(b); EXPECT_FALSE(b);
ref1 = Own<SetTrueInDestructor>(); ref1 = Own<SetTrueInDestructor>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment