Commit 4878663f authored by Kenton Varda's avatar Kenton Varda

Unimbue cap descriptors before extraction/injection.

parent 5de44367
......@@ -114,8 +114,7 @@ kj::Maybe<kj::Own<const ClientHook>> BasicReaderArena::newBrokenCap(kj::StringPt
// =======================================================================================
ImbuedReaderArena::ImbuedReaderArena(Arena* base, CapExtractorBase* capExtractor)
: base(base), capExtractor(capExtractor),
segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
: base(base), capExtractor(capExtractor), segment0(nullptr) {}
ImbuedReaderArena::~ImbuedReaderArena() noexcept(false) {}
SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
......@@ -124,7 +123,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
if (baseSegment->getSegmentId() == SegmentId(0)) {
if (segment0.getArena() == nullptr) {
kj::dtor(segment0);
kj::ctor(segment0, this, *baseSegment);
kj::ctor(segment0, this, baseSegment);
}
KJ_DASSERT(segment0.getArray().begin() == baseSegment->getArray().begin());
return &segment0;
......@@ -146,7 +145,7 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
*lock = kj::mv(newMap);
}
auto newSegment = kj::heap<SegmentReader>(this, *baseSegment);
auto newSegment = kj::heap<ImbuedSegmentReader>(this, baseSegment);
SegmentReader* result = newSegment;
segments->insert(std::make_pair(baseSegment, mv(newSegment)));
return result;
......@@ -162,7 +161,9 @@ void ImbuedReaderArena::reportReadLimitReached() {
kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::extractCap(
const _::StructReader& capDescriptor) {
return capExtractor->extractCapInternal(capDescriptor);
_::StructReader copy = capDescriptor;
copy.unimbue();
return capExtractor->extractCapInternal(copy);
}
kj::Maybe<kj::Own<const ClientHook>> ImbuedReaderArena::newBrokenCap(kj::StringPtr description) {
......@@ -382,7 +383,9 @@ void ImbuedBuilderArena::reportReadLimitReached() {
kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::extractCap(
const _::StructReader& capDescriptor) {
return capInjector->getInjectedCapInternal(capDescriptor);
_::StructReader copy = capDescriptor;
copy.unimbue();
return capInjector->getInjectedCapInternal(copy);
}
kj::Maybe<kj::Own<const ClientHook>> ImbuedBuilderArena::newBrokenCap(kj::StringPtr description) {
......@@ -404,7 +407,9 @@ OrphanBuilder ImbuedBuilderArena::injectCap(kj::Own<const ClientHook>&& cap) {
}
void ImbuedBuilderArena::dropCap(const StructReader& capDescriptor) {
capInjector->dropCapInternal(capDescriptor);
_::StructReader copy = capDescriptor;
copy.unimbue();
capInjector->dropCapInternal(copy);
}
} // namespace _ (private)
......
......@@ -102,7 +102,6 @@ class SegmentReader {
public:
inline SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<const word> ptr,
ReadLimiter* readLimiter);
inline SegmentReader(Arena* arena, const SegmentReader& base);
KJ_ALWAYS_INLINE(bool containsInterval(const void* from, const void* to));
......@@ -128,6 +127,18 @@ private:
friend class SegmentBuilder;
friend class ImbuedSegmentBuilder;
friend class ImbuedSegmentReader;
};
class ImbuedSegmentReader: public SegmentReader {
public:
inline ImbuedSegmentReader(Arena* arena, SegmentReader* base);
inline ImbuedSegmentReader(decltype(nullptr));
inline SegmentReader* unimbue();
private:
SegmentReader* base;
};
class SegmentBuilder: public SegmentReader {
......@@ -172,6 +183,11 @@ public:
inline ImbuedSegmentBuilder(decltype(nullptr));
KJ_DISALLOW_COPY(ImbuedSegmentBuilder);
inline SegmentBuilder* unimbue();
private:
SegmentBuilder* base;
};
class Arena {
......@@ -237,9 +253,9 @@ private:
CapExtractorBase* capExtractor;
// Optimize for single-segment messages so that small messages are handled quickly.
SegmentReader segment0;
ImbuedSegmentReader segment0;
typedef std::unordered_map<SegmentReader*, kj::Own<SegmentReader>> SegmentMap;
typedef std::unordered_map<SegmentReader*, kj::Own<ImbuedSegmentReader>> SegmentMap;
kj::MutexGuarded<kj::Maybe<kj::Own<SegmentMap>>> moreSegments;
};
......@@ -374,9 +390,6 @@ inline SegmentReader::SegmentReader(Arena* arena, SegmentId id, kj::ArrayPtr<con
ReadLimiter* readLimiter)
: arena(arena), id(id), ptr(ptr), readLimiter(readLimiter) {}
inline SegmentReader::SegmentReader(Arena* arena, const SegmentReader& base)
: arena(arena), id(base.id), ptr(base.ptr), readLimiter(base.readLimiter) {}
inline bool SegmentReader::containsInterval(const void* from, const void* to) {
return from >= this->ptr.begin() && to <= this->ptr.end() &&
readLimiter->canRead(
......@@ -395,6 +408,14 @@ inline WordCount SegmentReader::getSize() { return ptr.size() * WORDS; }
inline kj::ArrayPtr<const word> SegmentReader::getArray() { return ptr; }
inline void SegmentReader::unread(WordCount64 amount) { readLimiter->unread(amount); }
inline ImbuedSegmentReader::ImbuedSegmentReader(Arena* arena, SegmentReader* base)
: SegmentReader(arena, base->id, base->ptr, base->readLimiter), base(base) {}
inline ImbuedSegmentReader::ImbuedSegmentReader(decltype(nullptr))
: SegmentReader(nullptr, SegmentId(0), nullptr, nullptr), base(nullptr) {}
inline SegmentReader* ImbuedSegmentReader::unimbue() {
return base;
}
// -------------------------------------------------------------------
inline SegmentBuilder::SegmentBuilder(
......@@ -452,9 +473,13 @@ inline BasicSegmentBuilder::BasicSegmentBuilder(
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(ImbuedBuilderArena* arena, SegmentBuilder* base)
: SegmentBuilder(arena, base->id,
kj::arrayPtr(const_cast<word*>(base->ptr.begin()), base->ptr.size()),
base->readLimiter, base->pos) {}
base->readLimiter, base->pos),
base(base) {}
inline ImbuedSegmentBuilder::ImbuedSegmentBuilder(decltype(nullptr))
: SegmentBuilder(nullptr, SegmentId(0), nullptr, nullptr, nullptr) {}
: SegmentBuilder(nullptr, SegmentId(0), nullptr, nullptr, nullptr),
base(nullptr) {}
inline SegmentBuilder* ImbuedSegmentBuilder::unimbue() { return base; }
} // namespace _ (private)
} // namespace capnp
......
......@@ -2380,6 +2380,10 @@ BuilderArena* StructBuilder::getArena() {
return segment->getArena();
}
void StructBuilder::unimbue() {
segment = static_cast<ImbuedSegmentBuilder*>(segment)->unimbue();
}
// =======================================================================================
// StructReader
......@@ -2399,6 +2403,10 @@ WordCount64 StructReader::totalSize() const {
return result;
}
void StructReader::unimbue() {
segment = static_cast<ImbuedSegmentReader*>(segment)->unimbue();
}
// =======================================================================================
// ListBuilder
......
......@@ -462,6 +462,10 @@ public:
BuilderArena* getArena();
// Gets the arena in which this object is allocated.
void unimbue();
// Removes the capability context from the builder. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentBuilder -- with the non-imbued base segment.
private:
SegmentBuilder* segment; // Memory segment in which the struct resides.
void* data; // Pointer to the encoded data.
......@@ -527,6 +531,10 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
void unimbue();
// Removes the capability context from the reader. This means replacing the segment pointer --
// which is assumed to point to an ImbuedSegmentReader -- with the non-imbued base segment.
private:
SegmentReader* segment; // Memory segment in which the struct resides.
......
......@@ -36,8 +36,8 @@ struct SetTrueInDestructor: public Refcounted {
TEST(Refcount, Basic) {
bool b = false;
Own<SetTrueInDestructor> ref1 = kj::refcounted<SetTrueInDestructor>(&b);
Own<SetTrueInDestructor> ref2 = kj::addRef(*ref1);
Own<SetTrueInDestructor> ref3 = kj::addRef(*ref2);
Own<const SetTrueInDestructor> ref2 = kj::addRef(*ref1);
Own<const SetTrueInDestructor> ref3 = kj::addRef(*ref2);
EXPECT_FALSE(b);
ref1 = Own<SetTrueInDestructor>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment