Commit 2d1c9a53 authored by Kenton Varda's avatar Kenton Varda

Remove a bunch of mutexes.

parent bf7af0b6
...@@ -161,7 +161,7 @@ private: ...@@ -161,7 +161,7 @@ private:
} }
private: private:
mutable kj::Vector<kj::String> caps; kj::Vector<kj::String> caps;
}; };
}; };
......
...@@ -249,26 +249,22 @@ public: ...@@ -249,26 +249,22 @@ public:
disconnectFulfiller(kj::mv(disconnectFulfiller)), disconnectFulfiller(kj::mv(disconnectFulfiller)),
tasks(*this) { tasks(*this) {
tasks.add(messageLoop()); tasks.add(messageLoop());
tables.getWithoutLock().resolutionChainTail = kj::refcounted<ResolutionChain>(); resolutionChainTail = kj::refcounted<ResolutionChain>();
} }
kj::Own<ClientHook> restore(ObjectPointer::Reader objectId) { kj::Own<ClientHook> restore(ObjectPointer::Reader objectId) {
QuestionId questionId; QuestionId questionId;
kj::Own<QuestionRef> questionRef; auto& question = questions.next(questionId);
auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<RpcResponse>>>();
{
auto lock = tables.lockExclusive();
auto& question = lock->questions.next(questionId);
// We need a dummy paramCaps since null normally indicates that the question has completed. // We need a dummy paramCaps since null normally indicates that the question has completed.
question.paramCaps = kj::heap<CapInjectorImpl>(*this); question.paramCaps = kj::heap<CapInjectorImpl>(*this);
questionRef = kj::refcounted<QuestionRef>(*this, questionId, kj::mv(paf.fulfiller)); auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<RpcResponse>>>();
auto questionRef = kj::refcounted<QuestionRef>(*this, questionId, kj::mv(paf.fulfiller));
question.selfRef = *questionRef; question.selfRef = *questionRef;
paf.promise.attach(kj::addRef(*questionRef)); paf.promise.attach(kj::addRef(*questionRef));
}
{ {
auto message = connection->newOutgoingMessage( auto message = connection->newOutgoingMessage(
...@@ -293,15 +289,15 @@ public: ...@@ -293,15 +289,15 @@ public:
void disconnect(kj::Exception&& exception) { void disconnect(kj::Exception&& exception) {
{ {
// Carefully pull all the objects out of the tables prior to releasing them because their
// destructors could come back and mess with the tables.
kj::Vector<kj::Own<PipelineHook>> pipelinesToRelease; kj::Vector<kj::Own<PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<ClientHook>> clientsToRelease; kj::Vector<kj::Own<ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> capInjectorsToRelease; kj::Vector<kj::Own<CapInjectorImpl>> capInjectorsToRelease;
kj::Vector<kj::Promise<kj::Own<RpcResponse>>> tailCallsToRelease; kj::Vector<kj::Promise<kj::Own<RpcResponse>>> tailCallsToRelease;
kj::Vector<kj::Promise<void>> resolveOpsToRelease; kj::Vector<kj::Promise<void>> resolveOpsToRelease;
auto lock = tables.lockExclusive(); if (networkException != nullptr) {
if (lock->networkException != nullptr) {
// Oops, already disconnected. // Oops, already disconnected.
return; return;
} }
...@@ -311,7 +307,7 @@ public: ...@@ -311,7 +307,7 @@ public:
__FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription())); __FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription()));
// All current questions complete with exceptions. // All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id, Question& question) { questions.forEach([&](QuestionId id, Question& question) {
KJ_IF_MAYBE(questionRef, question.selfRef) { KJ_IF_MAYBE(questionRef, question.selfRef) {
// QuestionRef still present. Make sure it's not in the midst of being destroyed, then // QuestionRef still present. Make sure it's not in the midst of being destroyed, then
// reject it. // reject it.
...@@ -324,7 +320,7 @@ public: ...@@ -324,7 +320,7 @@ public:
} }
}); });
lock->answers.forEach([&](QuestionId id, Answer& answer) { answers.forEach([&](QuestionId id, Answer& answer) {
KJ_IF_MAYBE(p, answer.pipeline) { KJ_IF_MAYBE(p, answer.pipeline) {
pipelinesToRelease.add(kj::mv(*p)); pipelinesToRelease.add(kj::mv(*p));
} }
...@@ -342,25 +338,25 @@ public: ...@@ -342,25 +338,25 @@ public:
} }
}); });
lock->exports.forEach([&](ExportId id, Export& exp) { exports.forEach([&](ExportId id, Export& exp) {
clientsToRelease.add(kj::mv(exp.clientHook)); clientsToRelease.add(kj::mv(exp.clientHook));
resolveOpsToRelease.add(kj::mv(exp.resolveOp)); resolveOpsToRelease.add(kj::mv(exp.resolveOp));
exp = Export(); exp = Export();
}); });
lock->imports.forEach([&](ExportId id, Import& import) { imports.forEach([&](ExportId id, Import& import) {
KJ_IF_MAYBE(f, import.promiseFulfiller) { KJ_IF_MAYBE(f, import.promiseFulfiller) {
f->get()->reject(kj::cp(networkException)); f->get()->reject(kj::cp(networkException));
} }
}); });
lock->embargoes.forEach([&](EmbargoId id, Embargo& embargo) { embargoes.forEach([&](EmbargoId id, Embargo& embargo) {
KJ_IF_MAYBE(f, embargo.fulfiller) { KJ_IF_MAYBE(f, embargo.fulfiller) {
f->get()->reject(kj::cp(networkException)); f->get()->reject(kj::cp(networkException));
} }
}); });
lock->networkException = kj::mv(networkException); this->networkException = kj::mv(networkException);
} }
{ {
...@@ -492,11 +488,11 @@ private: ...@@ -492,11 +488,11 @@ private:
kj::Own<VatNetworkBase::Connection> connection; kj::Own<VatNetworkBase::Connection> connection;
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller; kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
struct Tables {
ExportTable<ExportId, Export> exports; ExportTable<ExportId, Export> exports;
ExportTable<QuestionId, Question> questions; ExportTable<QuestionId, Question> questions;
ImportTable<QuestionId, Answer> answers; ImportTable<QuestionId, Answer> answers;
ImportTable<ExportId, Import> imports; ImportTable<ExportId, Import> imports;
// The Four Tables!
// The order of the tables is important for correct destruction. // The order of the tables is important for correct destruction.
std::unordered_map<ClientHook*, ExportId> exportsByCap; std::unordered_map<ClientHook*, ExportId> exportsByCap;
...@@ -515,8 +511,6 @@ private: ...@@ -515,8 +511,6 @@ private:
ExportTable<EmbargoId, Embargo> embargoes; ExportTable<EmbargoId, Embargo> embargoes;
// There are only four tables. This definitely isn't a fifth table. I don't know what you're // There are only four tables. This definitely isn't a fifth table. I don't know what you're
// talking about. // talking about.
};
kj::MutexGuarded<Tables> tables;
kj::TaskSet tasks; kj::TaskSet tasks;
...@@ -655,12 +649,10 @@ private: ...@@ -655,12 +649,10 @@ private:
RpcClient(RpcConnectionState& connectionState) RpcClient(RpcConnectionState& connectionState)
: connectionState(kj::addRef(connectionState)) {} : connectionState(kj::addRef(connectionState)) {}
virtual kj::Maybe<ExportId> writeDescriptor( virtual kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) = 0;
rpc::CapDescriptor::Builder descriptor, Tables& tables) = 0; // Writes a CapDescriptor referencing this client. The CapDescriptor must be sent as part of
// Writes a CapDescriptor referencing this client. Must be called with the // the very next message sent on the connection, as it may become invalid if other things
// RpcConnectionState's table locked -- a reference to them is passed as the second argument. // happen.
// The CapDescriptor must be sent before unlocking the tables, as it may become invalid at
// any time once the tables are unlocked.
// //
// If writing the descriptor adds a new export to the export table, or increments the refcount // If writing the descriptor adds a new export to the export table, or increments the refcount
// on an existing one, then the ID is returned and the caller is responsible for removing it // on an existing one, then the ID is returned and the caller is responsible for removing it
...@@ -743,18 +735,15 @@ private: ...@@ -743,18 +735,15 @@ private:
: RpcClient(connectionState), importId(importId) {} : RpcClient(connectionState), importId(importId) {}
~ImportClient() noexcept(false) { ~ImportClient() noexcept(false) {
{
// Remove self from the import table, if the table is still pointing at us. (It's possible // Remove self from the import table, if the table is still pointing at us. (It's possible
// that another thread attempted to obtain this import just as the destructor started, in // that another thread attempted to obtain this import just as the destructor started, in
// which case that other thread will have constructed a new ImportClient and placed it in // which case that other thread will have constructed a new ImportClient and placed it in
// the import table. Therefore, we must actually verify that the import table points at // the import table. Therefore, we must actually verify that the import table points at
// this object.) // this object.)
auto lock = connectionState->tables.lockExclusive(); KJ_IF_MAYBE(import, connectionState->imports.find(importId)) {
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
KJ_IF_MAYBE(i, import->importClient) { KJ_IF_MAYBE(i, import->importClient) {
if (i == this) { if (i == this) {
lock->imports.erase(importId); connectionState->imports.erase(importId);
}
} }
} }
} }
...@@ -775,8 +764,7 @@ private: ...@@ -775,8 +764,7 @@ private:
++remoteRefcount; ++remoteRefcount;
} }
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override {
rpc::CapDescriptor::Builder descriptor, Tables& tables) override {
descriptor.setReceiverHosted(importId); descriptor.setReceiverHosted(importId);
return nullptr; return nullptr;
} }
...@@ -817,8 +805,7 @@ private: ...@@ -817,8 +805,7 @@ private:
kj::Array<PipelineOp>&& ops) kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {} : RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {}
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override {
rpc::CapDescriptor::Builder descriptor, Tables& tables) override {
auto promisedAnswer = descriptor.initReceiverAnswer(); auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(questionRef->getId()); promisedAnswer.setQuestionId(questionRef->getId());
promisedAnswer.adoptTransform(fromPipelineOps( promisedAnswer.adoptTransform(fromPipelineOps(
...@@ -863,7 +850,8 @@ private: ...@@ -863,7 +850,8 @@ private:
kj::Promise<kj::Own<ClientHook>> eventual, kj::Promise<kj::Own<ClientHook>> eventual,
kj::Maybe<ExportId> importId) kj::Maybe<ExportId> importId)
: RpcClient(connectionState), : RpcClient(connectionState),
inner(Inner {false, kj::mv(initial)}), isResolved(false),
cap(kj::mv(initial)),
importId(importId), importId(importId),
fork(eventual.fork()), fork(eventual.fork()),
resolveSelfPromise(fork.addBranch().then( resolveSelfPromise(fork.addBranch().then(
...@@ -893,8 +881,7 @@ private: ...@@ -893,8 +881,7 @@ private:
// contain a pointer back to it. Remove that pointer. Note that we have to verify that // contain a pointer back to it. Remove that pointer. Note that we have to verify that
// the import still exists and the pointer still points back to this object because this // the import still exists and the pointer still points back to this object because this
// object may actually outlive the import. // object may actually outlive the import.
auto lock = connectionState->tables.lockExclusive(); KJ_IF_MAYBE(import, connectionState->imports.find(*id)) {
KJ_IF_MAYBE(import, lock->imports.find(*id)) {
KJ_IF_MAYBE(c, import->appClient) { KJ_IF_MAYBE(c, import->appClient) {
if (c == this) { if (c == this) {
import->appClient = nullptr; import->appClient = nullptr;
...@@ -904,21 +891,20 @@ private: ...@@ -904,21 +891,20 @@ private:
} }
} }
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor) override {
rpc::CapDescriptor::Builder descriptor, Tables& tables) override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED); __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->writeDescriptor(*inner.lockExclusive()->cap, descriptor, tables); return connectionState->writeDescriptor(*cap, descriptor);
} }
kj::Maybe<kj::Own<ClientHook>> writeTarget( kj::Maybe<kj::Own<ClientHook>> writeTarget(
rpc::MessageTarget::Builder target) override { rpc::MessageTarget::Builder target) override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED); __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->writeTarget(*inner.lockExclusive()->cap, target); return connectionState->writeTarget(*cap, target);
} }
kj::Own<ClientHook> getInnermostClient() override { kj::Own<ClientHook> getInnermostClient() override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED); __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->getInnermostClient(*inner.lockExclusive()->cap); return connectionState->getInnermostClient(*cap);
} }
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
...@@ -926,19 +912,18 @@ private: ...@@ -926,19 +912,18 @@ private:
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED); __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->newCall(interfaceId, methodId, firstSegmentWordSize); return cap->newCall(interfaceId, methodId, firstSegmentWordSize);
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override { kj::Own<CallContextHook>&& context) override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED); __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->call(interfaceId, methodId, kj::mv(context)); return cap->call(interfaceId, methodId, kj::mv(context));
} }
kj::Maybe<ClientHook&> getResolved() override { kj::Maybe<ClientHook&> getResolved() override {
auto lock = inner.lockExclusive(); if (isResolved) {
if (lock->isResolved) { return *cap;
return *lock->cap;
} else { } else {
return nullptr; return nullptr;
} }
...@@ -949,12 +934,9 @@ private: ...@@ -949,12 +934,9 @@ private:
} }
private: private:
struct Inner {
bool isResolved; bool isResolved;
kj::Own<ClientHook> cap; kj::Own<ClientHook> cap;
};
kj::MutexGuarded<Inner> inner;
kj::Maybe<ExportId> importId; kj::Maybe<ExportId> importId;
kj::ForkedPromise<kj::Own<ClientHook>> fork; kj::ForkedPromise<kj::Own<ClientHook>> fork;
...@@ -962,7 +944,7 @@ private: ...@@ -962,7 +944,7 @@ private:
// ensure the continuation is not still running. // ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise; kj::Promise<void> resolveSelfPromise;
mutable bool receivedCall = false; bool receivedCall = false;
void resolve(kj::Own<ClientHook> replacement) { void resolve(kj::Own<ClientHook> replacement) {
if (replacement->getBrand() != connectionState.get() && if (replacement->getBrand() != connectionState.get() &&
...@@ -978,15 +960,13 @@ private: ...@@ -978,15 +960,13 @@ private:
auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo(); auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo();
{ {
auto redirect = connectionState->writeTarget( auto redirect = connectionState->writeTarget(*cap, disembargo.initTarget());
*inner.lockExclusive()->cap, disembargo.initTarget());
KJ_ASSERT(redirect == nullptr, KJ_ASSERT(redirect == nullptr,
"Original promise target should always be from this RPC connection."); "Original promise target should always be from this RPC connection.");
} }
EmbargoId embargoId; EmbargoId embargoId;
auto lock = connectionState->tables.lockExclusive(); Embargo& embargo = connectionState->embargoes.next(embargoId);
Embargo& embargo = lock->embargoes.next(embargoId);
disembargo.getContext().setSenderLoopback(embargoId); disembargo.getContext().setSenderLoopback(embargoId);
...@@ -1007,18 +987,12 @@ private: ...@@ -1007,18 +987,12 @@ private:
message->send(); message->send();
} }
// Careful to make sure the old client is not destroyed until we release the lock. cap = replacement->addRef();
kj::Own<ClientHook> old; isResolved = true;
auto lock = inner.lockExclusive();
old = kj::mv(lock->cap);
lock->cap = replacement->addRef();
lock->isResolved = true;
} }
}; };
kj::Maybe<ExportId> writeDescriptor( kj::Maybe<ExportId> writeDescriptor(ClientHook& cap, rpc::CapDescriptor::Builder descriptor) {
ClientHook& cap, rpc::CapDescriptor::Builder descriptor, Tables& tables) {
// Write a descriptor for the given capability. The tables must be locked by the caller and // Write a descriptor for the given capability. The tables must be locked by the caller and
// passed in as a parameter. // passed in as a parameter.
...@@ -1033,20 +1007,20 @@ private: ...@@ -1033,20 +1007,20 @@ private:
} }
if (inner->getBrand() == this) { if (inner->getBrand() == this) {
return kj::downcast<RpcClient>(*inner).writeDescriptor(descriptor, tables); return kj::downcast<RpcClient>(*inner).writeDescriptor(descriptor);
} else { } else {
auto iter = tables.exportsByCap.find(inner); auto iter = exportsByCap.find(inner);
if (iter != tables.exportsByCap.end()) { if (iter != exportsByCap.end()) {
// We've already seen and exported this capability before. Just up the refcount. // We've already seen and exported this capability before. Just up the refcount.
auto& exp = KJ_ASSERT_NONNULL(tables.exports.find(iter->second)); auto& exp = KJ_ASSERT_NONNULL(exports.find(iter->second));
++exp.refcount; ++exp.refcount;
descriptor.setSenderHosted(iter->second); descriptor.setSenderHosted(iter->second);
return iter->second; return iter->second;
} else { } else {
// This is the first time we've seen this capability. // This is the first time we've seen this capability.
ExportId exportId; ExportId exportId;
auto& exp = tables.exports.next(exportId); auto& exp = exports.next(exportId);
tables.exportsByCap[inner] = exportId; exportsByCap[inner] = exportId;
exp.refcount = 1; exp.refcount = 1;
exp.clientHook = inner->addRef(); exp.clientHook = inner->addRef();
descriptor.setSenderHosted(exportId); descriptor.setSenderHosted(exportId);
...@@ -1117,9 +1091,8 @@ private: ...@@ -1117,9 +1091,8 @@ private:
// Update the export table to point at this object instead. We know that our entry in the // Update the export table to point at this object instead. We know that our entry in the
// export table is still live because when it is destroyed the asynchronous resolution task // export table is still live because when it is destroyed the asynchronous resolution task
// (i.e. this code) is canceled. // (i.e. this code) is canceled.
auto lock = tables.lockExclusive(); auto& exp = KJ_ASSERT_NONNULL(exports.find(exportId));
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId)); exportsByCap.erase(exp.clientHook);
lock->exportsByCap.erase(exp.clientHook);
exp.clientHook = kj::mv(resolution); exp.clientHook = kj::mv(resolution);
if (exp.clientHook->getBrand() != this) { if (exp.clientHook->getBrand() != this) {
...@@ -1131,8 +1104,7 @@ private: ...@@ -1131,8 +1104,7 @@ private:
// be able to just reuse the existing export table entry to represent the new promise -- // be able to just reuse the existing export table entry to represent the new promise --
// unless it already has an entry. Let's check. // unless it already has an entry. Let's check.
auto insertResult = lock->exportsByCap.insert( auto insertResult = exportsByCap.insert(std::make_pair(exp.clientHook.get(), exportId));
std::make_pair(exp.clientHook.get(), exportId));
if (insertResult.second) { if (insertResult.second) {
// The new promise was not already in the table, therefore the existing export table // The new promise was not already in the table, therefore the existing export table
...@@ -1148,7 +1120,7 @@ private: ...@@ -1148,7 +1120,7 @@ private:
messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16); messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve(); auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
writeDescriptor(*exp.clientHook, resolve.initCap(), *lock); writeDescriptor(*exp.clientHook, resolve.initCap());
message->send(); message->send();
return kj::READY_NOW; return kj::READY_NOW;
...@@ -1178,8 +1150,7 @@ private: ...@@ -1178,8 +1150,7 @@ private:
resolutionChain(kj::mv(resolutionChain)) {} resolutionChain(kj::mv(resolutionChain)) {}
~CapExtractorImpl() noexcept(false) { ~CapExtractorImpl() noexcept(false) {
KJ_ASSERT(retainedCaps.getWithoutLock().size() == 0 || KJ_ASSERT(retainedCaps.size() == 0 || connectionState.networkException != nullptr,
connectionState.tables.lockShared()->networkException != nullptr,
"CapExtractorImpl destroyed without getting a chance to retain the caps!") { "CapExtractorImpl destroyed without getting a chance to retain the caps!") {
break; break;
} }
...@@ -1197,7 +1168,7 @@ private: ...@@ -1197,7 +1168,7 @@ private:
// If `final` is true then there's no need to lock. If it is false, then asynchronous // If `final` is true then there's no need to lock. If it is false, then asynchronous
// access is possible. It's probably not worth taking the lock to look; we'll just return // access is possible. It's probably not worth taking the lock to look; we'll just return
// a silly estimate. // a silly estimate.
uint count = final ? retainedCaps.getWithoutLock().size() : 32; uint count = final ? retainedCaps.size() : 32;
return (count * sizeof(ExportId) + (sizeof(word) - 1)) / sizeof(word); return (count * sizeof(ExportId) + (sizeof(word) - 1)) / sizeof(word);
} }
...@@ -1218,15 +1189,13 @@ private: ...@@ -1218,15 +1189,13 @@ private:
// message's release. // message's release.
// Called on finalization, when all extractions have ceased, so we can skip the lock. // Called on finalization, when all extractions have ceased, so we can skip the lock.
kj::Vector<ExportId> retainedCaps = kj::mv(this->retainedCaps.getWithoutLock()); kj::Vector<ExportId> retainedCaps = kj::mv(this->retainedCaps);
kj::Vector<kj::Own<ClientHook>> refs(retainedCaps.size()); kj::Vector<kj::Own<ClientHook>> refs(retainedCaps.size());
auto lock = connectionState.tables.lockExclusive();
auto actualRetained = retainedCaps.begin(); auto actualRetained = retainedCaps.begin();
for (ExportId importId: retainedCaps) { for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID. // Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, connectionState.imports.find(importId)) {
KJ_IF_MAYBE(ic, import->importClient) { KJ_IF_MAYBE(ic, import->importClient) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*ic)) { KJ_IF_MAYBE(ref, kj::tryAddRef(*ic)) {
// Import indeed still exists! We'll return it in the retained caps, which means it // Import indeed still exists! We'll return it in the retained caps, which means it
...@@ -1253,23 +1222,21 @@ private: ...@@ -1253,23 +1222,21 @@ private:
} }
static kj::Own<ClientHook> extractCapAndAddRef( static kj::Own<ClientHook> extractCapAndAddRef(
RpcConnectionState& connectionState, Tables& lockedTables, RpcConnectionState& connectionState, rpc::CapDescriptor::Reader descriptor) {
rpc::CapDescriptor::Reader descriptor) {
// Interpret the given capability descriptor and, if it is an import, immediately give it // Interpret the given capability descriptor and, if it is an import, immediately give it
// a remote ref. This is called when interpreting messages that have a CapabilityDescriptor // a remote ref. This is called when interpreting messages that have a CapabilityDescriptor
// but do not have a corresponding response message where a list of retained caps is given. // but do not have a corresponding response message where a list of retained caps is given.
// In these cases, the cap is always assumed retained, and must be explicitly released. // In these cases, the cap is always assumed retained, and must be explicitly released.
// For example, the 'Resolve' message contains a capability which is presumed to be retained. // For example, the 'Resolve' message contains a capability which is presumed to be retained.
return extractCapImpl(connectionState, lockedTables, descriptor, return extractCapImpl(connectionState, descriptor,
*lockedTables.resolutionChainTail, nullptr); *connectionState.resolutionChainTail, nullptr);
} }
// implements CapDescriptor ------------------------------------------------ // implements CapDescriptor ------------------------------------------------
kj::Own<ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) override { kj::Own<ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) override {
return extractCapImpl(connectionState, *connectionState.tables.lockExclusive(), descriptor, return extractCapImpl(connectionState, descriptor, *resolutionChain, retainedCaps);
*resolutionChain, retainedCaps);
} }
private: private:
...@@ -1280,15 +1247,15 @@ private: ...@@ -1280,15 +1247,15 @@ private:
// this message from being invalidated by `Resolve` messages before extraction is finished. // this message from being invalidated by `Resolve` messages before extraction is finished.
// Simply holding on to the chain keeps the import table entries valid. // Simply holding on to the chain keeps the import table entries valid.
kj::MutexGuarded<kj::Vector<ExportId>> retainedCaps; kj::Vector<ExportId> retainedCaps;
// Imports which we are responsible for retaining, should they still exist at the time that // Imports which we are responsible for retaining, should they still exist at the time that
// this message is released. // this message is released.
static kj::Own<ClientHook> extractCapImpl( static kj::Own<ClientHook> extractCapImpl(
RpcConnectionState& connectionState, Tables& tables, RpcConnectionState& connectionState,
rpc::CapDescriptor::Reader descriptor, rpc::CapDescriptor::Reader descriptor,
ResolutionChain& resolutionChain, ResolutionChain& resolutionChain,
kj::Maybe<const kj::MutexGuarded<kj::Vector<ExportId>>&> retainedCaps) { kj::Maybe<kj::Vector<ExportId>&> retainedCaps) {
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
case rpc::CapDescriptor::SENDER_PROMISE: { case rpc::CapDescriptor::SENDER_PROMISE: {
...@@ -1303,7 +1270,7 @@ private: ...@@ -1303,7 +1270,7 @@ private:
} }
// No recent resolutions. Check the import table then. // No recent resolutions. Check the import table then.
auto& import = tables.imports[importId]; auto& import = connectionState.imports[importId];
KJ_IF_MAYBE(c, import.appClient) { KJ_IF_MAYBE(c, import.appClient) {
// The import is already on the table, but it could be being deleted in another // The import is already on the table, but it could be being deleted in another
// thread. // thread.
...@@ -1323,7 +1290,7 @@ private: ...@@ -1323,7 +1290,7 @@ private:
KJ_IF_MAYBE(rc, retainedCaps) { KJ_IF_MAYBE(rc, retainedCaps) {
// We need to retain this import later if it still exists. // We need to retain this import later if it still exists.
rc->lockExclusive()->add(importId); rc->add(importId);
} else { } else {
// Automatically increment the refcount. // Automatically increment the refcount.
importClient->addRemoteRef(); importClient->addRemoteRef();
...@@ -1354,7 +1321,7 @@ private: ...@@ -1354,7 +1321,7 @@ private:
return cap->addRef(); return cap->addRef();
} }
KJ_IF_MAYBE(exp, tables.exports.find(descriptor.getReceiverHosted())) { KJ_IF_MAYBE(exp, connectionState.exports.find(descriptor.getReceiverHosted())) {
return exp->clientHook->addRef(); return exp->clientHook->addRef();
} }
return newBrokenCap("invalid 'receiverHosted' export ID"); return newBrokenCap("invalid 'receiverHosted' export ID");
...@@ -1369,7 +1336,7 @@ private: ...@@ -1369,7 +1336,7 @@ private:
KJ_IF_MAYBE(p, resolutionChain.findPipeline(promisedAnswer.getQuestionId())) { KJ_IF_MAYBE(p, resolutionChain.findPipeline(promisedAnswer.getQuestionId())) {
pipeline = p; pipeline = p;
} else { } else {
KJ_IF_MAYBE(answer, tables.answers.find(promisedAnswer.getQuestionId())) { KJ_IF_MAYBE(answer, connectionState.answers.find(promisedAnswer.getQuestionId())) {
if (answer->active) { if (answer->active) {
KJ_IF_MAYBE(p, answer->pipeline) { KJ_IF_MAYBE(p, answer->pipeline) {
pipeline = p->get(); pipeline = p->get();
...@@ -1407,11 +1374,9 @@ private: ...@@ -1407,11 +1374,9 @@ private:
~CapInjectorImpl() noexcept(false) { ~CapInjectorImpl() noexcept(false) {
kj::Vector<kj::Own<ResolutionChain>> thingsToRelease(exports.size()); kj::Vector<kj::Own<ResolutionChain>> thingsToRelease(exports.size());
auto lock = connectionState.tables.lockExclusive(); if (connectionState.networkException == nullptr) {
if (lock->networkException == nullptr) {
for (auto exportId: exports) { for (auto exportId: exports) {
thingsToRelease.add(releaseExport(*lock, exportId, 1)); thingsToRelease.add(connectionState.releaseExport(exportId, 1));
} }
} }
} }
...@@ -1420,22 +1385,22 @@ private: ...@@ -1420,22 +1385,22 @@ private:
// Return true if the message contains any capabilities. (If not, it may be possible to // Return true if the message contains any capabilities. (If not, it may be possible to
// release earlier.) // release earlier.)
return !caps.getWithoutLock().empty(); return !caps.empty();
} }
void finishDescriptors(Tables& tables) { void finishDescriptors() {
// Finish writing all of the CapDescriptors. Must be called with the tables locked, and the // Finish writing all of the CapDescriptors. Must be called with the tables locked, and the
// message must be sent before the tables are unlocked. // message must be sent before the tables are unlocked.
exports = kj::Vector<ExportId>(caps.getWithoutLock().size()); exports = kj::Vector<ExportId>(caps.size());
for (auto& entry: caps.getWithoutLock()) { for (auto& entry: caps) {
// If maybeExportId is inlined, GCC 4.7 reports a spurious "may be used uninitialized" // If maybeExportId is inlined, GCC 4.7 reports a spurious "may be used uninitialized"
// error (GCC 4.8 and Clang do not complain). // error (GCC 4.8 and Clang do not complain).
auto maybeExportId = connectionState.writeDescriptor( auto maybeExportId = connectionState.writeDescriptor(
*entry.second.cap, entry.second.builder, tables); *entry.second.cap, entry.second.builder);
KJ_IF_MAYBE(exportId, maybeExportId) { KJ_IF_MAYBE(exportId, maybeExportId) {
KJ_ASSERT(tables.exports.find(*exportId) != nullptr); KJ_ASSERT(connectionState.exports.find(*exportId) != nullptr);
exports.add(*exportId); exports.add(*exportId);
} }
} }
...@@ -1444,21 +1409,19 @@ private: ...@@ -1444,21 +1409,19 @@ private:
// implements CapInjector ---------------------------------------- // implements CapInjector ----------------------------------------
void injectCap(rpc::CapDescriptor::Builder descriptor, kj::Own<ClientHook>&& cap) override { void injectCap(rpc::CapDescriptor::Builder descriptor, kj::Own<ClientHook>&& cap) override {
auto lock = caps.lockExclusive(); auto result = caps.insert(std::make_pair(
auto result = lock->insert(std::make_pair(
identity(descriptor), CapInfo(descriptor, kj::mv(cap)))); identity(descriptor), CapInfo(descriptor, kj::mv(cap))));
KJ_REQUIRE(result.second, "A cap has already been injected at this location.") { KJ_REQUIRE(result.second, "A cap has already been injected at this location.") {
break; break;
} }
} }
kj::Own<ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) override { kj::Own<ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) override {
auto lock = caps.lockExclusive(); auto iter = caps.find(identity(descriptor));
auto iter = lock->find(identity(descriptor)); KJ_REQUIRE(iter != caps.end(), "getInjectedCap() called on descriptor I didn't write.");
KJ_REQUIRE(iter != lock->end(), "getInjectedCap() called on descriptor I didn't write.");
return iter->second.cap->addRef(); return iter->second.cap->addRef();
} }
void dropCap(rpc::CapDescriptor::Reader descriptor) override { void dropCap(rpc::CapDescriptor::Reader descriptor) override {
caps.lockExclusive()->erase(identity(descriptor)); caps.erase(identity(descriptor));
} }
private: private:
...@@ -1481,7 +1444,7 @@ private: ...@@ -1481,7 +1444,7 @@ private:
CapInfo(CapInfo&& other) = default; CapInfo(CapInfo&& other) = default;
}; };
kj::MutexGuarded<std::map<const void*, CapInfo>> caps; std::map<const void*, CapInfo> caps;
// Maps CapDescriptor locations to embedded caps. The descriptors aren't actually filled in // Maps CapDescriptor locations to embedded caps. The descriptors aren't actually filled in
// until just before the message is sent. // until just before the message is sent.
...@@ -1508,7 +1471,7 @@ private: ...@@ -1508,7 +1471,7 @@ private:
: connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)) {} : connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)) {}
~QuestionRef() { ~QuestionRef() {
if (connectionState->tables.lockShared()->networkException != nullptr) { if (connectionState->networkException != nullptr) {
return; return;
} }
...@@ -1534,18 +1497,15 @@ private: ...@@ -1534,18 +1497,15 @@ private:
// Check if the question has returned and, if so, remove it from the table. // Check if the question has returned and, if so, remove it from the table.
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
// the ID is not re-allocated before the `Finish` message can be sent. // the ID is not re-allocated before the `Finish` message can be sent.
{
auto lock = connectionState->tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL( auto& question = KJ_ASSERT_NONNULL(
lock->questions.find(id), "Question ID no longer on table?"); connectionState->questions.find(id), "Question ID no longer on table?");
if (question.paramCaps == nullptr) { if (question.paramCaps == nullptr) {
// Call has already returned, so we can now remove it from the table. // Call has already returned, so we can now remove it from the table.
KJ_ASSERT(lock->questions.erase(id)); KJ_ASSERT(connectionState->questions.erase(id));
} else { } else {
question.selfRef = nullptr; question.selfRef = nullptr;
} }
} }
}
inline QuestionId getId() const { return id; } inline QuestionId getId() const { return id; }
...@@ -1599,10 +1559,7 @@ private: ...@@ -1599,10 +1559,7 @@ private:
RemotePromise<ObjectPointer> send() override { RemotePromise<ObjectPointer> send() override {
SendInternalResult sendResult; SendInternalResult sendResult;
{ KJ_IF_MAYBE(e, connectionState->networkException) {
auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(e, lock->networkException) {
return RemotePromise<ObjectPointer>( return RemotePromise<ObjectPointer>(
kj::Promise<Response<ObjectPointer>>(kj::cp(*e)), kj::Promise<Response<ObjectPointer>>(kj::cp(*e)),
ObjectPointer::Pipeline(newBrokenPipeline(kj::cp(*e)))); ObjectPointer::Pipeline(newBrokenPipeline(kj::cp(*e))));
...@@ -1612,8 +1569,6 @@ private: ...@@ -1612,8 +1569,6 @@ private:
// Whoops, this capability has been redirected while we were building the request! // Whoops, this capability has been redirected while we were building the request!
// We'll have to make a new request and do a copy. Ick. // We'll have to make a new request and do a copy. Ick.
lock.release();
size_t sizeHint = paramsBuilder.targetSizeInWords(); size_t sizeHint = paramsBuilder.targetSizeInWords();
// TODO(perf): See TODO in RpcClient::call() about why we need to inflate the size a bit. // TODO(perf): See TODO in RpcClient::call() about why we need to inflate the size a bit.
...@@ -1629,8 +1584,7 @@ private: ...@@ -1629,8 +1584,7 @@ private:
replacement.set(paramsBuilder); replacement.set(paramsBuilder);
return replacement.send(); return replacement.send();
} else { } else {
sendResult = sendInternal(false, *lock); sendResult = sendInternal(false);
}
} }
auto forkedPromise = sendResult.promise.fork(); auto forkedPromise = sendResult.promise.fork();
...@@ -1664,10 +1618,7 @@ private: ...@@ -1664,10 +1618,7 @@ private:
SendInternalResult sendResult; SendInternalResult sendResult;
{ if (connectionState->networkException != nullptr) {
auto lock = connectionState->tables.lockExclusive();
if (lock->networkException != nullptr) {
// Disconnected; fall back to a regular send() which will fail appropriately. // Disconnected; fall back to a regular send() which will fail appropriately.
return nullptr; return nullptr;
} }
...@@ -1677,8 +1628,7 @@ private: ...@@ -1677,8 +1628,7 @@ private:
// Fall back to regular send(). // Fall back to regular send().
return nullptr; return nullptr;
} else { } else {
sendResult = sendInternal(true, *lock); sendResult = sendInternal(true);
}
} }
auto promise = sendResult.promise.then([](kj::Own<RpcResponse>&& response) { auto promise = sendResult.promise.then([](kj::Own<RpcResponse>&& response) {
...@@ -1712,12 +1662,12 @@ private: ...@@ -1712,12 +1662,12 @@ private:
kj::Promise<kj::Own<RpcResponse>> promise = nullptr; kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
}; };
SendInternalResult sendInternal(bool isTailCall, Tables& lockedTables) { SendInternalResult sendInternal(bool isTailCall) {
injector->finishDescriptors(lockedTables); injector->finishDescriptors();
auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<RpcResponse>>>(); auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<RpcResponse>>>();
QuestionId questionId; QuestionId questionId;
auto& question = lockedTables.questions.next(questionId); auto& question = connectionState->questions.next(questionId);
callBuilder.setQuestionId(questionId); callBuilder.setQuestionId(questionId);
if (isTailCall) { if (isTailCall) {
...@@ -1757,7 +1707,7 @@ private: ...@@ -1757,7 +1707,7 @@ private:
// Construct a new RpcPipeline. // Construct a new RpcPipeline.
resolveSelfPromise.eagerlyEvaluate(); resolveSelfPromise.eagerlyEvaluate();
state.getWithoutLock().init<Waiting>(kj::mv(questionRef)); state.init<Waiting>(kj::mv(questionRef));
} }
RpcPipeline(RpcConnectionState& connectionState, kj::Own<QuestionRef>&& questionRef) RpcPipeline(RpcConnectionState& connectionState, kj::Own<QuestionRef>&& questionRef)
...@@ -1765,7 +1715,7 @@ private: ...@@ -1765,7 +1715,7 @@ private:
resolveSelfPromise(nullptr) { resolveSelfPromise(nullptr) {
// Construct a new RpcPipeline that is never expected to resolve. // Construct a new RpcPipeline that is never expected to resolve.
state.getWithoutLock().init<Waiting>(kj::mv(questionRef)); state.init<Waiting>(kj::mv(questionRef));
} }
// implements PipelineHook --------------------------------------- // implements PipelineHook ---------------------------------------
...@@ -1783,11 +1733,10 @@ private: ...@@ -1783,11 +1733,10 @@ private:
} }
kj::Own<ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) override { kj::Own<ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) override {
auto lock = state.lockExclusive(); if (state.is<Waiting>()) {
if (lock->is<Waiting>()) {
// Wrap a PipelineClient in a PromiseClient. // Wrap a PipelineClient in a PromiseClient.
auto pipelineClient = kj::refcounted<PipelineClient>( auto pipelineClient = kj::refcounted<PipelineClient>(
*connectionState, kj::addRef(*lock->get<Waiting>()), kj::heapArray(ops.asPtr())); *connectionState, kj::addRef(*state.get<Waiting>()), kj::heapArray(ops.asPtr()));
KJ_IF_MAYBE(r, redirectLater) { KJ_IF_MAYBE(r, redirectLater) {
auto resolutionPromise = r->addBranch().then(kj::mvCapture(ops, auto resolutionPromise = r->addBranch().then(kj::mvCapture(ops,
...@@ -1801,10 +1750,10 @@ private: ...@@ -1801,10 +1750,10 @@ private:
// Oh, this pipeline will never get redirected, so just return the PipelineClient. // Oh, this pipeline will never get redirected, so just return the PipelineClient.
return kj::mv(pipelineClient); return kj::mv(pipelineClient);
} }
} else if (lock->is<Resolved>()) { } else if (state.is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops); return state.get<Resolved>()->getResults().getPipelinedCap(ops);
} else { } else {
return newBrokenCap(kj::cp(lock->get<Broken>())); return newBrokenCap(kj::cp(state.get<Broken>()));
} }
} }
...@@ -1816,22 +1765,20 @@ private: ...@@ -1816,22 +1765,20 @@ private:
typedef kj::Own<QuestionRef> Waiting; typedef kj::Own<QuestionRef> Waiting;
typedef kj::Own<RpcResponse> Resolved; typedef kj::Own<RpcResponse> Resolved;
typedef kj::Exception Broken; typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state; kj::OneOf<Waiting, Resolved, Broken> state;
// Keep this last, because the continuation uses *this, so it should be destroyed first to // Keep this last, because the continuation uses *this, so it should be destroyed first to
// ensure the continuation is not still running. // ensure the continuation is not still running.
kj::Promise<void> resolveSelfPromise; kj::Promise<void> resolveSelfPromise;
void resolve(kj::Own<RpcResponse>&& response) { void resolve(kj::Own<RpcResponse>&& response) {
auto lock = state.lockExclusive(); KJ_ASSERT(state.is<Waiting>(), "Already resolved?");
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?"); state.init<Resolved>(kj::mv(response));
lock->init<Resolved>(kj::mv(response));
} }
void resolve(const kj::Exception&& exception) { void resolve(const kj::Exception&& exception) {
auto lock = state.lockExclusive(); KJ_ASSERT(state.is<Waiting>(), "Already resolved?");
KJ_ASSERT(lock->is<Waiting>(), "Already resolved?"); state.init<Broken>(kj::mv(exception));
lock->init<Broken>(kj::mv(exception));
} }
}; };
...@@ -1892,8 +1839,8 @@ private: ...@@ -1892,8 +1839,8 @@ private:
return builder; return builder;
} }
kj::Own<CapInjectorImpl> send(Tables& lockedTables) { kj::Own<CapInjectorImpl> send() {
injector->finishDescriptors(lockedTables); injector->finishDescriptors();
message->send(); message->send();
return kj::mv(injector); return kj::mv(injector);
} }
...@@ -1949,7 +1896,7 @@ private: ...@@ -1949,7 +1896,7 @@ private:
// We haven't sent a return yet, so we must have been canceled. Send a cancellation return. // We haven't sent a return yet, so we must have been canceled. Send a cancellation return.
unwindDetector.catchExceptionsIfUnwinding([&]() { unwindDetector.catchExceptionsIfUnwinding([&]() {
// Don't send anything if the connection is broken. // Don't send anything if the connection is broken.
if (connectionState->tables.lockShared()->networkException == nullptr) { if (connectionState->networkException == nullptr) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
requestCapExtractor.retainedListSizeHint(true) + messageSizeHint<rpc::Return>()); requestCapExtractor.retainedListSizeHint(true) + messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1970,7 +1917,7 @@ private: ...@@ -1970,7 +1917,7 @@ private:
message->send(); message->send();
} }
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr, true); cleanupAnswerTable(nullptr, true);
}); });
} }
} }
...@@ -1996,10 +1943,8 @@ private: ...@@ -1996,10 +1943,8 @@ private:
returnMessage.adoptRetainedCaps(kj::mv(retainedCaps.exportList)); returnMessage.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
kj::Own<PipelineHook> pipelineToRelease; kj::Own<PipelineHook> pipelineToRelease;
auto lock = connectionState->tables.lockExclusive(); cleanupAnswerTable(
auto& tables = *lock; kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send(), true);
cleanupAnswerTable(kj::mv(lock),
kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send(tables), true);
} }
} }
void sendErrorReturn(kj::Exception&& exception) { void sendErrorReturn(kj::Exception&& exception) {
...@@ -2017,7 +1962,7 @@ private: ...@@ -2017,7 +1962,7 @@ private:
fromException(exception, builder.initException()); fromException(exception, builder.initException());
message->send(); message->send();
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr, true); cleanupAnswerTable(nullptr, true);
} }
} }
...@@ -2028,12 +1973,6 @@ private: ...@@ -2028,12 +1973,6 @@ private:
// the RpcCallContext is now responsible for cleaning up the entry in the answer table, since // the RpcCallContext is now responsible for cleaning up the entry in the answer table, since
// a Finish message was already received. // a Finish message was already received.
// Verify that we're holding the tables mutex. This is important because we're handing off
// responsibility for deleting the answer. Moreover, the callContext pointer in the answer
// table should not be null as this would indicate that we've already returned a result.
KJ_DASSERT(connectionState->tables.getAlreadyLockedExclusive()
.answers[questionId].callContext != nullptr);
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) == if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
CANCEL_ALLOWED) { CANCEL_ALLOWED) {
// We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate
...@@ -2105,7 +2044,7 @@ private: ...@@ -2105,7 +2044,7 @@ private:
builder.setTakeFromOtherAnswer(tailInfo->questionId); builder.setTakeFromOtherAnswer(tailInfo->questionId);
message->send(); message->send();
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr, false); cleanupAnswerTable(nullptr, false);
} }
return { kj::mv(tailInfo->promise), kj::mv(tailInfo->pipeline) }; return { kj::mv(tailInfo->promise), kj::mv(tailInfo->pipeline) };
} }
...@@ -2182,11 +2121,11 @@ private: ...@@ -2182,11 +2121,11 @@ private:
CANCEL_ALLOWED = 2 CANCEL_ALLOWED = 2
}; };
mutable uint8_t cancellationFlags = 0; uint8_t cancellationFlags = 0;
// When both flags are set, the cancellation process will begin. Must be manipulated atomically // When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads. // as it may be accessed from multiple threads.
mutable kj::Own<kj::PromiseFulfiller<void>> cancelFulfiller; kj::Own<kj::PromiseFulfiller<void>> cancelFulfiller;
// Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is // Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is
// exclusive-joined with the outermost promise waiting on the call return, so fulfilling it // exclusive-joined with the outermost promise waiting on the call return, so fulfilling it
// cancels that promise. // cancels that promise.
...@@ -2204,8 +2143,7 @@ private: ...@@ -2204,8 +2143,7 @@ private:
} }
} }
void cleanupAnswerTable(kj::Locked<Tables>&& lock, void cleanupAnswerTable(kj::Maybe<kj::Own<CapInjectorImpl>> resultCaps,
kj::Maybe<kj::Own<CapInjectorImpl>> resultCaps,
bool freePipelineIfNoCaps) { bool freePipelineIfNoCaps) {
// We need to remove the `callContext` pointer -- which points back to us -- from the // We need to remove the `callContext` pointer -- which points back to us -- from the
// answer table. Or we might even be responsible for removing the entire answer table // answer table. Or we might even be responsible for removing the entire answer table
...@@ -2218,18 +2156,14 @@ private: ...@@ -2218,18 +2156,14 @@ private:
kj::Own<PipelineHook> pipelineToRelease; kj::Own<PipelineHook> pipelineToRelease;
Answer answerToDelete; Answer answerToDelete;
// Release lock later so that pipelineToRelease and resultCaps can be deleted without
// deadlock.
KJ_DEFER(lock.release());
if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) { if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) {
Answer answerToDelete = kj::mv(lock->answers[questionId]); answerToDelete = kj::mv(connectionState->answers[questionId]);
// Erase from the table. // Erase from the table.
lock->answers.erase(questionId); connectionState->answers.erase(questionId);
} else { } else {
// We just have to null out callContext. // We just have to null out callContext.
auto& answer = lock->answers[questionId]; auto& answer = connectionState->answers[questionId];
answer.callContext = nullptr; answer.callContext = nullptr;
// If the response has capabilities, we need to arrange to keep the CapInjector around // If the response has capabilities, we need to arrange to keep the CapInjector around
...@@ -2327,17 +2261,17 @@ private: ...@@ -2327,17 +2261,17 @@ private:
auto cap = message.getResolve().getCap(); auto cap = message.getResolve().getCap();
switch (cap.which()) { switch (cap.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: case rpc::CapDescriptor::SENDER_HOSTED:
releaseExport(*tables.lockExclusive(), cap.getSenderHosted(), 1); releaseExport(cap.getSenderHosted(), 1);
break; break;
case rpc::CapDescriptor::SENDER_PROMISE: case rpc::CapDescriptor::SENDER_PROMISE:
releaseExport(*tables.lockExclusive(), cap.getSenderPromise(), 1); releaseExport(cap.getSenderPromise(), 1);
break; break;
case rpc::CapDescriptor::RECEIVER_ANSWER: case rpc::CapDescriptor::RECEIVER_ANSWER:
case rpc::CapDescriptor::RECEIVER_HOSTED: case rpc::CapDescriptor::RECEIVER_HOSTED:
// Nothing to do. // Nothing to do.
break; break;
case rpc::CapDescriptor::THIRD_PARTY_HOSTED: case rpc::CapDescriptor::THIRD_PARTY_HOSTED:
releaseExport(*tables.lockExclusive(), cap.getThirdPartyHosted().getVineId(), 1); releaseExport(cap.getThirdPartyHosted().getVineId(), 1);
break; break;
} }
break; break;
...@@ -2386,14 +2320,13 @@ private: ...@@ -2386,14 +2320,13 @@ private:
// message at a time, so we can hold off locking the tables for a bit longer. // message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(), *this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail), kj::addRef(*resolutionChainTail),
redirectResults, kj::mv(cancelPaf.fulfiller)); redirectResults, kj::mv(cancelPaf.fulfiller));
// No more using `call` after this point! // No more using `call` after this point!
{ {
auto lock = tables.lockExclusive(); auto& answer = answers[questionId];
auto& answer = lock->answers[questionId];
KJ_REQUIRE(!answer.active, "questionId is already in use") { KJ_REQUIRE(!answer.active, "questionId is already in use") {
return; return;
...@@ -2410,8 +2343,7 @@ private: ...@@ -2410,8 +2343,7 @@ private:
// context->directTailCall(). // context->directTailCall().
{ {
auto lock = tables.lockExclusive(); auto& answer = answers[questionId];
auto& answer = lock->answers[questionId];
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
...@@ -2455,8 +2387,7 @@ private: ...@@ -2455,8 +2387,7 @@ private:
kj::Maybe<kj::Own<ClientHook>> getMessageTarget(const rpc::MessageTarget::Reader& target) { kj::Maybe<kj::Own<ClientHook>> getMessageTarget(const rpc::MessageTarget::Reader& target) {
switch (target.which()) { switch (target.which()) {
case rpc::MessageTarget::EXPORTED_CAP: { case rpc::MessageTarget::EXPORTED_CAP: {
auto lock = tables.lockExclusive(); // TODO(perf): shared? KJ_IF_MAYBE(exp, exports.find(target.getExportedCap())) {
KJ_IF_MAYBE(exp, lock->exports.find(target.getExportedCap())) {
return exp->clientHook->addRef(); return exp->clientHook->addRef();
} else { } else {
KJ_FAIL_REQUIRE("Message target is not a current export ID.") { KJ_FAIL_REQUIRE("Message target is not a current export ID.") {
...@@ -2470,9 +2401,7 @@ private: ...@@ -2470,9 +2401,7 @@ private:
auto promisedAnswer = target.getPromisedAnswer(); auto promisedAnswer = target.getPromisedAnswer();
kj::Own<PipelineHook> pipeline; kj::Own<PipelineHook> pipeline;
{ auto& base = answers[promisedAnswer.getQuestionId()];
auto lock = tables.lockExclusive(); // TODO(perf): shared?
auto& base = lock->answers[promisedAnswer.getQuestionId()];
KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") { KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") {
return nullptr; return nullptr;
} }
...@@ -2484,7 +2413,6 @@ private: ...@@ -2484,7 +2413,6 @@ private:
return nullptr; return nullptr;
} }
} }
}
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
return pipeline->getPipelinedCap(*ops); return pipeline->getPipelinedCap(*ops);
...@@ -2507,8 +2435,7 @@ private: ...@@ -2507,8 +2435,7 @@ private:
kj::Own<CapInjectorImpl> paramCapsToRelease; kj::Own<CapInjectorImpl> paramCapsToRelease;
kj::Promise<kj::Own<RpcResponse>> promiseToRelease = nullptr; kj::Promise<kj::Own<RpcResponse>> promiseToRelease = nullptr;
auto lock = tables.lockExclusive(); KJ_IF_MAYBE(question, questions.find(ret.getQuestionId())) {
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
KJ_REQUIRE(question->paramCaps != nullptr, "Duplicate Return.") { return; } KJ_REQUIRE(question->paramCaps != nullptr, "Duplicate Return.") { return; }
KJ_IF_MAYBE(pc, question->paramCaps) { KJ_IF_MAYBE(pc, question->paramCaps) {
...@@ -2519,7 +2446,7 @@ private: ...@@ -2519,7 +2446,7 @@ private:
} }
for (ExportId retained: ret.getRetainedCaps()) { for (ExportId retained: ret.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) { KJ_IF_MAYBE(exp, exports.find(retained)) {
++exp->refcount; ++exp->refcount;
} else { } else {
KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; } KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; }
...@@ -2537,7 +2464,7 @@ private: ...@@ -2537,7 +2464,7 @@ private:
// Not being deleted. // Not being deleted.
questionRef->fulfill(kj::refcounted<RpcResponseImpl>( questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::mv(*ownRef), kj::mv(message), ret.getResults(), *this, kj::mv(*ownRef), kj::mv(message), ret.getResults(),
kj::addRef(*lock->resolutionChainTail))); kj::addRef(*resolutionChainTail)));
} }
} }
break; break;
...@@ -2579,7 +2506,7 @@ private: ...@@ -2579,7 +2506,7 @@ private:
break; break;
case rpc::Return::TAKE_FROM_OTHER_ANSWER: case rpc::Return::TAKE_FROM_OTHER_ANSWER:
KJ_IF_MAYBE(answer, lock->answers.find(ret.getTakeFromOtherAnswer())) { KJ_IF_MAYBE(answer, answers.find(ret.getTakeFromOtherAnswer())) {
KJ_IF_MAYBE(response, answer->redirectedResults) { KJ_IF_MAYBE(response, answer->redirectedResults) {
// If we don't manage to fill in a questionRef here, we will want to release the // If we don't manage to fill in a questionRef here, we will want to release the
// promise. // promise.
...@@ -2607,7 +2534,7 @@ private: ...@@ -2607,7 +2534,7 @@ private:
} }
if (question->selfRef == nullptr) { if (question->selfRef == nullptr) {
lock->questions.erase(ret.getQuestionId()); questions.erase(ret.getQuestionId());
} }
} else { } else {
...@@ -2619,25 +2546,23 @@ private: ...@@ -2619,25 +2546,23 @@ private:
kj::Own<ResolutionChain> chainToRelease; kj::Own<ResolutionChain> chainToRelease;
Answer answerToRelease; Answer answerToRelease;
auto lock = tables.lockExclusive();
for (ExportId retained: finish.getRetainedCaps()) { for (ExportId retained: finish.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) { KJ_IF_MAYBE(exp, exports.find(retained)) {
++exp->refcount; ++exp->refcount;
} else { } else {
KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; } KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; }
} }
} }
KJ_IF_MAYBE(answer, lock->answers.find(finish.getQuestionId())) { KJ_IF_MAYBE(answer, answers.find(finish.getQuestionId())) {
KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; } KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; }
// `Finish` indicates that no further pipeline requests will be made. // `Finish` indicates that no further pipeline requests will be made.
// However, previously-sent messages that are still being processed could still refer to this // However, previously-sent messages that are still being processed could still refer to this
// pipeline, so we have to move it into the resolution chain. // pipeline, so we have to move it into the resolution chain.
KJ_IF_MAYBE(p, kj::mv(answer->pipeline)) { KJ_IF_MAYBE(p, kj::mv(answer->pipeline)) {
chainToRelease = kj::mv(lock->resolutionChainTail); chainToRelease = kj::mv(resolutionChainTail);
lock->resolutionChainTail = chainToRelease->addFinish(finish.getQuestionId(), kj::mv(*p)); resolutionChainTail = chainToRelease->addFinish(finish.getQuestionId(), kj::mv(*p));
} }
// If the call isn't actually done yet, cancel it. Otherwise, we can go ahead and erase the // If the call isn't actually done yet, cancel it. Otherwise, we can go ahead and erase the
...@@ -2646,7 +2571,7 @@ private: ...@@ -2646,7 +2571,7 @@ private:
context->requestCancel(); context->requestCancel();
} else { } else {
answerToRelease = kj::mv(*answer); answerToRelease = kj::mv(*answer);
lock->answers.erase(finish.getQuestionId()); answers.erase(finish.getQuestionId());
} }
} else { } else {
KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; } KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; }
...@@ -2659,14 +2584,12 @@ private: ...@@ -2659,14 +2584,12 @@ private:
void handleResolve(const rpc::Resolve::Reader& resolve) { void handleResolve(const rpc::Resolve::Reader& resolve) {
kj::Own<ResolutionChain> oldResolutionChainTail; // must be freed outside of lock kj::Own<ResolutionChain> oldResolutionChainTail; // must be freed outside of lock
auto lock = tables.lockExclusive();
kj::Own<ClientHook> replacement; kj::Own<ClientHook> replacement;
// Extract the replacement capability. // Extract the replacement capability.
switch (resolve.which()) { switch (resolve.which()) {
case rpc::Resolve::CAP: case rpc::Resolve::CAP:
replacement = CapExtractorImpl::extractCapAndAddRef(*this, *lock, resolve.getCap()); replacement = CapExtractorImpl::extractCapAndAddRef(*this, resolve.getCap());
break; break;
case rpc::Resolve::EXCEPTION: case rpc::Resolve::EXCEPTION:
...@@ -2678,12 +2601,12 @@ private: ...@@ -2678,12 +2601,12 @@ private:
} }
// Extend the resolution chain. // Extend the resolution chain.
oldResolutionChainTail = kj::mv(lock->resolutionChainTail); oldResolutionChainTail = kj::mv(resolutionChainTail);
lock->resolutionChainTail = oldResolutionChainTail->addResolve( resolutionChainTail = oldResolutionChainTail->addResolve(
resolve.getPromiseId(), kj::mv(replacement)); resolve.getPromiseId(), kj::mv(replacement));
// If the import is on the table, fulfill it. // If the import is on the table, fulfill it.
KJ_IF_MAYBE(import, lock->imports.find(resolve.getPromiseId())) { KJ_IF_MAYBE(import, imports.find(resolve.getPromiseId())) {
KJ_IF_MAYBE(fulfiller, import->promiseFulfiller) { KJ_IF_MAYBE(fulfiller, import->promiseFulfiller) {
// OK, this is in fact an unfulfilled promise! // OK, this is in fact an unfulfilled promise!
fulfiller->get()->fulfill(kj::mv(replacement)); fulfiller->get()->fulfill(kj::mv(replacement));
...@@ -2696,24 +2619,23 @@ private: ...@@ -2696,24 +2619,23 @@ private:
} }
void handleRelease(const rpc::Release::Reader& release) { void handleRelease(const rpc::Release::Reader& release) {
auto chainToRelease = releaseExport( auto chainToRelease = releaseExport(release.getId(), release.getReferenceCount());
*tables.lockExclusive(), release.getId(), release.getReferenceCount());
} }
static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) { kj::Own<ResolutionChain> releaseExport(ExportId id, uint refcount) {
kj::Own<ResolutionChain> result; kj::Own<ResolutionChain> result;
KJ_IF_MAYBE(exp, lockedTables.exports.find(id)) { KJ_IF_MAYBE(exp, exports.find(id)) {
KJ_REQUIRE(refcount <= exp->refcount, "Tried to drop export's refcount below zero.") { KJ_REQUIRE(refcount <= exp->refcount, "Tried to drop export's refcount below zero.") {
return result; return result;
} }
exp->refcount -= refcount; exp->refcount -= refcount;
if (exp->refcount == 0) { if (exp->refcount == 0) {
lockedTables.exportsByCap.erase(exp->clientHook); exportsByCap.erase(exp->clientHook);
result = kj::mv(lockedTables.resolutionChainTail); result = kj::mv(resolutionChainTail);
lockedTables.resolutionChainTail = result->addRelease(id, kj::mv(exp->clientHook)); resolutionChainTail = result->addRelease(id, kj::mv(exp->clientHook));
lockedTables.exports.erase(id); exports.erase(id);
return result; return result;
} else { } else {
return result; return result;
...@@ -2787,10 +2709,9 @@ private: ...@@ -2787,10 +2709,9 @@ private:
} }
case rpc::Disembargo::Context::RECEIVER_LOOPBACK: { case rpc::Disembargo::Context::RECEIVER_LOOPBACK: {
auto lock = tables.lockExclusive(); KJ_IF_MAYBE(embargo, embargoes.find(context.getReceiverLoopback())) {
KJ_IF_MAYBE(embargo, lock->embargoes.find(context.getReceiverLoopback())) {
KJ_ASSERT_NONNULL(embargo->fulfiller)->fulfill(); KJ_ASSERT_NONNULL(embargo->fulfiller)->fulfill();
lock->embargoes.erase(context.getReceiverLoopback()); embargoes.erase(context.getReceiverLoopback());
} else { } else {
KJ_FAIL_REQUIRE("Invalid embargo ID in 'Disembargo.context.receiverLoopback'.") { KJ_FAIL_REQUIRE("Invalid embargo ID in 'Disembargo.context.receiverLoopback'.") {
return; return;
...@@ -2865,15 +2786,12 @@ private: ...@@ -2865,15 +2786,12 @@ private:
message = nullptr; message = nullptr;
// Add the answer to the answer table for pipelining and send the response. // Add the answer to the answer table for pipelining and send the response.
{ auto& answer = answers[questionId];
auto lock = tables.lockExclusive();
auto& answer = lock->answers[questionId];
KJ_REQUIRE(!answer.active, "questionId is already in use") { KJ_REQUIRE(!answer.active, "questionId is already in use") {
return; return;
} }
injector->finishDescriptors(*lock); injector->finishDescriptors();
answer.active = true; answer.active = true;
answer.pipeline = kj::Own<PipelineHook>( answer.pipeline = kj::Own<PipelineHook>(
...@@ -2881,7 +2799,6 @@ private: ...@@ -2881,7 +2799,6 @@ private:
response->send(); response->send();
} }
}
}; };
} // namespace } // namespace
...@@ -2896,13 +2813,12 @@ public: ...@@ -2896,13 +2813,12 @@ public:
~Impl() noexcept(false) { ~Impl() noexcept(false) {
// std::unordered_map doesn't like it when elements' destructors throw, so carefully // std::unordered_map doesn't like it when elements' destructors throw, so carefully
// disassemble it. // disassemble it.
auto& connectionMap = connections.getWithoutLock(); if (!connections.empty()) {
if (!connectionMap.empty()) { kj::Vector<kj::Own<RpcConnectionState>> deleteMe(connections.size());
kj::Vector<kj::Own<RpcConnectionState>> deleteMe(connectionMap.size());
kj::Exception shutdownException( kj::Exception shutdownException(
kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT, kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("RpcSystem was destroyed.")); __FILE__, __LINE__, kj::str("RpcSystem was destroyed."));
for (auto& entry: connectionMap) { for (auto& entry: connections) {
entry.second->disconnect(kj::cp(shutdownException)); entry.second->disconnect(kj::cp(shutdownException));
deleteMe.add(kj::mv(entry.second)); deleteMe.add(kj::mv(entry.second));
} }
...@@ -2911,8 +2827,7 @@ public: ...@@ -2911,8 +2827,7 @@ public:
Capability::Client restore(_::StructReader hostId, ObjectPointer::Reader objectId) { Capability::Client restore(_::StructReader hostId, ObjectPointer::Reader objectId) {
KJ_IF_MAYBE(connection, network.baseConnectToRefHost(hostId)) { KJ_IF_MAYBE(connection, network.baseConnectToRefHost(hostId)) {
auto lock = connections.lockExclusive(); auto& state = getConnectionState(kj::mv(*connection));
auto& state = getConnectionState(kj::mv(*connection), *lock);
return Capability::Client(state.restore(objectId)); return Capability::Client(state.restore(objectId));
} else KJ_IF_MAYBE(r, restorer) { } else KJ_IF_MAYBE(r, restorer) {
return r->baseRestore(objectId); return r->baseRestore(objectId);
...@@ -2933,21 +2848,20 @@ private: ...@@ -2933,21 +2848,20 @@ private:
typedef std::unordered_map<VatNetworkBase::Connection*, kj::Own<RpcConnectionState>> typedef std::unordered_map<VatNetworkBase::Connection*, kj::Own<RpcConnectionState>>
ConnectionMap; ConnectionMap;
kj::MutexGuarded<ConnectionMap> connections; ConnectionMap connections;
RpcConnectionState& getConnectionState(kj::Own<VatNetworkBase::Connection>&& connection, RpcConnectionState& getConnectionState(kj::Own<VatNetworkBase::Connection>&& connection) {
ConnectionMap& lockedMap) { auto iter = connections.find(connection);
auto iter = lockedMap.find(connection); if (iter == connections.end()) {
if (iter == lockedMap.end()) {
VatNetworkBase::Connection* connectionPtr = connection; VatNetworkBase::Connection* connectionPtr = connection;
auto onDisconnect = kj::newPromiseAndFulfiller<void>(); auto onDisconnect = kj::newPromiseAndFulfiller<void>();
tasks.add(onDisconnect.promise.then([this,connectionPtr]() { tasks.add(onDisconnect.promise.then([this,connectionPtr]() {
connections.lockExclusive()->erase(connectionPtr); connections.erase(connectionPtr);
})); }));
auto newState = kj::refcounted<RpcConnectionState>( auto newState = kj::refcounted<RpcConnectionState>(
restorer, kj::mv(connection), kj::mv(onDisconnect.fulfiller)); restorer, kj::mv(connection), kj::mv(onDisconnect.fulfiller));
RpcConnectionState& result = *newState; RpcConnectionState& result = *newState;
lockedMap.insert(std::make_pair(connectionPtr, kj::mv(newState))); connections.insert(std::make_pair(connectionPtr, kj::mv(newState)));
return result; return result;
} else { } else {
return *iter->second; return *iter->second;
...@@ -2957,8 +2871,7 @@ private: ...@@ -2957,8 +2871,7 @@ private:
kj::Promise<void> acceptLoop() { kj::Promise<void> acceptLoop() {
auto receive = network.baseAcceptConnectionAsRefHost().then( auto receive = network.baseAcceptConnectionAsRefHost().then(
[this](kj::Own<VatNetworkBase::Connection>&& connection) { [this](kj::Own<VatNetworkBase::Connection>&& connection) {
auto lock = connections.lockExclusive(); getConnectionState(kj::mv(connection));
getConnectionState(kj::mv(connection), *lock);
}); });
return receive.then([this]() { return receive.then([this]() {
// No exceptions; continue loop. // No exceptions; continue loop.
......
...@@ -560,25 +560,22 @@ void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) { ...@@ -560,25 +560,22 @@ void TransformPromiseNodeBase::getDepResult(ExceptionOrValue& output) {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
ForkBranchBase::ForkBranchBase(Own<ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) { ForkBranchBase::ForkBranchBase(Own<ForkHubBase>&& hubParam): hub(kj::mv(hubParam)) {
auto lock = hub->branchList.lockExclusive(); if (hub->tailBranch == nullptr) {
if (lock->lastPtr == nullptr) {
onReadyEvent.arm(); onReadyEvent.arm();
} else { } else {
// Insert into hub's linked list of branches. // Insert into hub's linked list of branches.
prevPtr = lock->lastPtr; prevPtr = hub->tailBranch;
*prevPtr = this; *prevPtr = this;
next = nullptr; next = nullptr;
lock->lastPtr = &next; hub->tailBranch = &next;
} }
} }
ForkBranchBase::~ForkBranchBase() noexcept(false) { ForkBranchBase::~ForkBranchBase() noexcept(false) {
if (prevPtr != nullptr) { if (prevPtr != nullptr) {
// Remove from hub's linked list of branches. // Remove from hub's linked list of branches.
auto lock = hub->branchList.lockExclusive();
*prevPtr = next; *prevPtr = next;
(next == nullptr ? lock->lastPtr : next->prevPtr) = prevPtr; (next == nullptr ? hub->tailBranch : next->prevPtr) = prevPtr;
} }
} }
...@@ -618,16 +615,15 @@ Maybe<Own<EventLoop::Event>> ForkHubBase::fire() { ...@@ -618,16 +615,15 @@ Maybe<Own<EventLoop::Event>> ForkHubBase::fire() {
resultRef.addException(kj::mv(*exception)); resultRef.addException(kj::mv(*exception));
} }
auto lock = branchList.lockExclusive(); for (auto branch = headBranch; branch != nullptr; branch = branch->next) {
for (auto branch = lock->first; branch != nullptr; branch = branch->next) {
branch->hubReady(); branch->hubReady();
*branch->prevPtr = nullptr; *branch->prevPtr = nullptr;
branch->prevPtr = nullptr; branch->prevPtr = nullptr;
} }
*lock->lastPtr = nullptr; *tailBranch = nullptr;
// Indicate that the list is no longer active. // Indicate that the list is no longer active.
lock->lastPtr = nullptr; tailBranch = nullptr;
return nullptr; return nullptr;
} }
......
...@@ -1108,16 +1108,12 @@ public: ...@@ -1108,16 +1108,12 @@ public:
inline ExceptionOrValue& getResultRef() { return resultRef; } inline ExceptionOrValue& getResultRef() { return resultRef; }
private: private:
struct BranchList {
ForkBranchBase* first = nullptr;
ForkBranchBase** lastPtr = &first;
};
Own<PromiseNode> inner; Own<PromiseNode> inner;
ExceptionOrValue& resultRef; ExceptionOrValue& resultRef;
MutexGuarded<BranchList> branchList; ForkBranchBase* headBranch = nullptr;
// Becomes null once the inner promise is ready and all branches have been notified. ForkBranchBase** tailBranch = &headBranch;
// Tail becomes null once the inner promise is ready and all branches have been notified.
Maybe<Own<Event>> fire() override; Maybe<Own<Event>> fire() override;
_::PromiseNode* getInnerForTrace() override; _::PromiseNode* getInnerForTrace() override;
...@@ -1451,61 +1447,54 @@ public: ...@@ -1451,61 +1447,54 @@ public:
} }
void fulfill(FixVoid<T>&& value) override { void fulfill(FixVoid<T>&& value) override {
auto lock = inner.lockExclusive(); if (inner != nullptr) {
if (*lock != nullptr) { inner->fulfill(kj::mv(value));
(*lock)->fulfill(kj::mv(value));
} }
} }
void reject(Exception&& exception) override { void reject(Exception&& exception) override {
auto lock = inner.lockExclusive(); if (inner != nullptr) {
if (*lock != nullptr) { inner->reject(kj::mv(exception));
(*lock)->reject(kj::mv(exception));
} }
} }
bool isWaiting() override { bool isWaiting() override {
auto lock = inner.lockExclusive(); return inner != nullptr && inner->isWaiting();
return *lock != nullptr && (*lock)->isWaiting();
} }
void attach(PromiseFulfiller<T>& newInner) { void attach(PromiseFulfiller<T>& newInner) {
inner.getWithoutLock() = &newInner; inner = &newInner;
} }
void detach(PromiseFulfiller<T>& from) { void detach(PromiseFulfiller<T>& from) {
auto lock = inner.lockExclusive(); if (inner == nullptr) {
if (*lock == nullptr) {
// Already disposed. // Already disposed.
lock.release();
delete this; delete this;
} else { } else {
KJ_IREQUIRE(*lock == &from); KJ_IREQUIRE(inner == &from);
*lock = nullptr; inner = nullptr;
} }
} }
private: private:
MutexGuarded<PromiseFulfiller<T>*> inner; mutable PromiseFulfiller<T>* inner;
WeakFulfiller(): inner(nullptr) {} WeakFulfiller(): inner(nullptr) {}
void disposeImpl(void* pointer) const override { void disposeImpl(void* pointer) const override {
// TODO(perf): Factor some of this out so it isn't regenerated for every fulfiller type? // TODO(perf): Factor some of this out so it isn't regenerated for every fulfiller type?
auto lock = inner.lockExclusive(); if (inner == nullptr) {
if (*lock == nullptr) {
// Already detached. // Already detached.
lock.release();
delete this; delete this;
} else { } else {
if ((*lock)->isWaiting()) { if (inner->isWaiting()) {
(*lock)->reject(kj::Exception( inner->reject(kj::Exception(
kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT, kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, __FILE__, __LINE__,
kj::heapString("PromiseFulfiller was destroyed without fulfilling the promise."))); kj::heapString("PromiseFulfiller was destroyed without fulfilling the promise.")));
} }
*lock = nullptr; inner = nullptr;
} }
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment