Commit e1fdba61 authored by Kenton Varda's avatar Kenton Varda

Remove more mutexes.

parent 540b1887
...@@ -253,7 +253,7 @@ public: ...@@ -253,7 +253,7 @@ public:
: promise(promiseParam.fork()), : promise(promiseParam.fork()),
selfResolutionOp(promise.addBranch().then( selfResolutionOp(promise.addBranch().then(
[this](kj::Own<PipelineHook>&& inner) { [this](kj::Own<PipelineHook>&& inner) {
*redirect.lockExclusive() = kj::mv(inner); redirect = kj::mv(inner);
})) { })) {
selfResolutionOp.eagerlyEvaluate(); selfResolutionOp.eagerlyEvaluate();
} }
...@@ -275,7 +275,7 @@ public: ...@@ -275,7 +275,7 @@ public:
private: private:
kj::ForkedPromise<kj::Own<PipelineHook>> promise; kj::ForkedPromise<kj::Own<PipelineHook>> promise;
kj::MutexGuarded<kj::Maybe<kj::Own<PipelineHook>>> redirect; kj::Maybe<kj::Own<PipelineHook>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object. // Once the promise resolves, this will become non-null and point to the underlying object.
kj::Promise<void> selfResolutionOp; kj::Promise<void> selfResolutionOp;
...@@ -291,8 +291,10 @@ public: ...@@ -291,8 +291,10 @@ public:
: promise(promiseParam.fork()), : promise(promiseParam.fork()),
selfResolutionOp(promise.addBranch().then( selfResolutionOp(promise.addBranch().then(
[this](kj::Own<ClientHook>&& inner) { [this](kj::Own<ClientHook>&& inner) {
*redirect.lockExclusive() = kj::mv(inner); redirect = kj::mv(inner);
})) { })),
promiseForCallForwarding(promise.addBranch().fork()),
promiseForClientResolution(promise.addBranch().fork()) {
selfResolutionOp.eagerlyEvaluate(); selfResolutionOp.eagerlyEvaluate();
} }
...@@ -334,7 +336,7 @@ public: ...@@ -334,7 +336,7 @@ public:
// Create a promise for the call initiation. // Create a promise for the call initiation.
kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise = kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise =
getPromiseForCallForwarding().addBranch().then(kj::mvCapture(context, promiseForCallForwarding.addBranch().then(kj::mvCapture(context,
[=](kj::Own<CallContextHook>&& context, kj::Own<ClientHook>&& client){ [=](kj::Own<CallContextHook>&& context, kj::Own<ClientHook>&& client){
return kj::refcounted<CallResultHolder>( return kj::refcounted<CallResultHolder>(
client->call(interfaceId, methodId, kj::mv(context))); client->call(interfaceId, methodId, kj::mv(context)));
...@@ -359,7 +361,7 @@ public: ...@@ -359,7 +361,7 @@ public:
} }
kj::Maybe<ClientHook&> getResolved() override { kj::Maybe<ClientHook&> getResolved() override {
KJ_IF_MAYBE(inner, *redirect.lockExclusive()) { KJ_IF_MAYBE(inner, redirect) {
return **inner; return **inner;
} else { } else {
return nullptr; return nullptr;
...@@ -367,7 +369,7 @@ public: ...@@ -367,7 +369,7 @@ public:
} }
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return getPromiseForClientResolution().addBranch(); return promiseForClientResolution.addBranch();
} }
kj::Own<ClientHook> addRef() override { kj::Own<ClientHook> addRef() override {
...@@ -381,50 +383,35 @@ public: ...@@ -381,50 +383,35 @@ public:
private: private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork; typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
kj::Maybe<kj::Own<ClientHook>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
ClientHookPromiseFork promise; ClientHookPromiseFork promise;
// Promise that resolves when we have a new ClientHook to forward to. // Promise that resolves when we have a new ClientHook to forward to.
// //
// This fork shall only have two branches: `promiseForCallForwarding` and // This fork shall only have three branches: `selfResolutionOp`, `promiseForCallForwarding`, and
// `promiseForClientResolution`, in that order. // `promiseForClientResolution`, in that order.
kj::Lazy<ClientHookPromiseFork> promiseForCallForwarding; kj::Promise<void> selfResolutionOp;
// Represents the operation which will set `redirect` when possible.
ClientHookPromiseFork promiseForCallForwarding;
// When this promise resolves, each queued call will be forwarded to the real client. This needs // When this promise resolves, each queued call will be forwarded to the real client. This needs
// to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure // to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure
// previously-queued calls are delivered before any new calls made in response to the resolution. // previously-queued calls are delivered before any new calls made in response to the resolution.
kj::Lazy<ClientHookPromiseFork> promiseForClientResolution; ClientHookPromiseFork promiseForClientResolution;
// whenMoreResolved() returns forks of this promise. These must resolve *after* queued calls // whenMoreResolved() returns forks of this promise. These must resolve *after* queued calls
// have been initiated (so that any calls made in the whenMoreResolved() handler are correctly // have been initiated (so that any calls made in the whenMoreResolved() handler are correctly
// delivered after calls made earlier), but *before* any queued calls return (because it might // delivered after calls made earlier), but *before* any queued calls return (because it might
// confuse the application if a queued call returns before the capability on which it was made // confuse the application if a queued call returns before the capability on which it was made
// resolves). Luckily, we know that queued calls will involve, at the very least, an // resolves). Luckily, we know that queued calls will involve, at the very least, an
// eventLoop.evalLater. // eventLoop.evalLater.
kj::MutexGuarded<kj::Maybe<kj::Own<ClientHook>>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
kj::Promise<void> selfResolutionOp;
// Represents the operation which will set `redirect` when possible.
ClientHookPromiseFork& getPromiseForCallForwarding() {
return promiseForCallForwarding.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
return space.construct(promise.addBranch().fork());
});
}
kj::ForkedPromise<kj::Own<ClientHook>>& getPromiseForClientResolution() {
return promiseForClientResolution.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
getPromiseForCallForwarding(); // must be initialized first.
return space.construct(promise.addBranch().fork());
});
}
}; };
kj::Own<ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) { kj::Own<ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) {
auto lock = redirect.lockExclusive(); KJ_IF_MAYBE(r, redirect) {
return r->get()->getPipelinedCap(kj::mv(ops));
KJ_IF_MAYBE(redirect, *lock) {
return redirect->get()->getPipelinedCap(kj::mv(ops));
} else { } else {
auto clientPromise = promise.addBranch().then(kj::mvCapture(ops, auto clientPromise = promise.addBranch().then(kj::mvCapture(ops,
[](kj::Array<PipelineOp>&& ops, kj::Own<PipelineHook> pipeline) { [](kj::Array<PipelineOp>&& ops, kj::Own<PipelineHook> pipeline) {
......
...@@ -211,7 +211,7 @@ public: ...@@ -211,7 +211,7 @@ public:
kj::Exception exception( kj::Exception exception(
kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT, kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("Network was destroyed.")); __FILE__, __LINE__, kj::str("Network was destroyed."));
for (auto& entry: state.getWithoutLock().connections) { for (auto& entry: connections) {
entry.second->disconnect(kj::cp(exception)); entry.second->disconnect(kj::cp(exception));
} }
} }
...@@ -235,11 +235,9 @@ public: ...@@ -235,11 +235,9 @@ public:
} }
void disconnect(kj::Exception&& exception) { void disconnect(kj::Exception&& exception) {
auto lock = queues.lockExclusive(); while (!fulfillers.empty()) {
fulfillers.front()->reject(kj::cp(exception));
while (!lock->fulfillers.empty()) { fulfillers.pop();
lock->fulfillers.front()->reject(kj::cp(exception));
lock->fulfillers.pop();
} }
networkException = kj::mv(exception); networkException = kj::mv(exception);
...@@ -285,13 +283,13 @@ public: ...@@ -285,13 +283,13 @@ public:
connection.tasks->add(kj::evalLater(kj::mvCapture(kj::addRef(*message), connection.tasks->add(kj::evalLater(kj::mvCapture(kj::addRef(*message),
[connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) { [connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
KJ_IF_MAYBE(p, connectionPtr->partner) { KJ_IF_MAYBE(p, connectionPtr->partner) {
auto lock = p->queues.lockExclusive(); if (p->fulfillers.empty()) {
if (lock->fulfillers.empty()) { p->messages.push(kj::mv(message));
lock->messages.push(kj::mv(message));
} else { } else {
++connectionPtr->network.received; ++p->network.received;
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message))); p->fulfillers.front()->fulfill(
lock->fulfillers.pop(); kj::Own<IncomingRpcMessage>(kj::mv(message)));
p->fulfillers.pop();
} }
} }
}))); })));
...@@ -310,15 +308,14 @@ public: ...@@ -310,15 +308,14 @@ public:
return kj::cp(*e); return kj::cp(*e);
} }
auto lock = queues.lockExclusive(); if (messages.empty()) {
if (lock->messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>(); auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller)); fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise); return kj::mv(paf.promise);
} else { } else {
++network.received; ++network.received;
auto result = kj::mv(lock->messages.front()); auto result = kj::mv(messages.front());
lock->messages.pop(); messages.pop();
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result)); return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
} }
} }
...@@ -347,11 +344,8 @@ public: ...@@ -347,11 +344,8 @@ public:
kj::Maybe<kj::Exception> networkException; kj::Maybe<kj::Exception> networkException;
struct Queues { std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers; std::queue<kj::Own<IncomingRpcMessage>> messages;
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
kj::Own<kj::TaskSet> tasks; kj::Own<kj::TaskSet> tasks;
}; };
...@@ -360,31 +354,20 @@ public: ...@@ -360,31 +354,20 @@ public:
test::TestSturdyRefHostId::Reader hostId) override { test::TestSturdyRefHostId::Reader hostId) override {
TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost())); TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost()));
kj::Locked<State> myLock; auto iter = connections.find(&dst);
kj::Locked<State> dstLock; if (iter == connections.end()) {
if (&dst < this) {
dstLock = dst.state.lockExclusive();
myLock = state.lockExclusive();
} else {
myLock = state.lockExclusive();
dstLock = dst.state.lockExclusive();
}
auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) {
auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT); auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT);
auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER); auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER);
local->attach(*remote); local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local); connections[&dst] = kj::addRef(*local);
dstLock->connections[this] = kj::addRef(*remote); dst.connections[this] = kj::addRef(*remote);
if (dstLock->fulfillerQueue.empty()) { if (dst.fulfillerQueue.empty()) {
dstLock->connectionQueue.push(kj::mv(remote)); dst.connectionQueue.push(kj::mv(remote));
} else { } else {
dstLock->fulfillerQueue.front()->fulfill(kj::mv(remote)); dst.fulfillerQueue.front()->fulfill(kj::mv(remote));
dstLock->fulfillerQueue.pop(); dst.fulfillerQueue.pop();
} }
return kj::Own<Connection>(kj::mv(local)); return kj::Own<Connection>(kj::mv(local));
...@@ -394,14 +377,13 @@ public: ...@@ -394,14 +377,13 @@ public:
} }
kj::Promise<kj::Own<Connection>> acceptConnectionAsRefHost() override { kj::Promise<kj::Own<Connection>> acceptConnectionAsRefHost() override {
auto lock = state.lockExclusive(); if (connectionQueue.empty()) {
if (lock->connectionQueue.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<Connection>>(); auto paf = kj::newPromiseAndFulfiller<kj::Own<Connection>>();
lock->fulfillerQueue.push(kj::mv(paf.fulfiller)); fulfillerQueue.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise); return kj::mv(paf.promise);
} else { } else {
auto result = kj::mv(lock->connectionQueue.front()); auto result = kj::mv(connectionQueue.front());
lock->connectionQueue.pop(); connectionQueue.pop();
return kj::mv(result); return kj::mv(result);
} }
} }
...@@ -411,12 +393,9 @@ private: ...@@ -411,12 +393,9 @@ private:
uint sent = 0; uint sent = 0;
uint received = 0; uint received = 0;
struct State { std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections;
std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections; std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<Connection>>>> fulfillerQueue;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<Connection>>>> fulfillerQueue; std::queue<kj::Own<Connection>> connectionQueue;
std::queue<kj::Own<Connection>> connectionQueue;
};
kj::MutexGuarded<State> state;
}; };
TestNetwork::~TestNetwork() noexcept(false) {} TestNetwork::~TestNetwork() noexcept(false) {}
......
...@@ -33,12 +33,12 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty: ...@@ -33,12 +33,12 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
{ {
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = paf.promise.fork(); disconnectPromise = paf.promise.fork();
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller); disconnectFulfiller = kj::mv(paf.fulfiller);
} }
{ {
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = paf.promise.fork(); drainedPromise = paf.promise.fork();
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller); drainedFulfiller.fulfiller = kj::mv(paf.fulfiller);
} }
} }
...@@ -68,7 +68,7 @@ kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>> ...@@ -68,7 +68,7 @@ kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>>
class TwoPartyVatNetwork::OutgoingMessageImpl final class TwoPartyVatNetwork::OutgoingMessageImpl final
: public OutgoingRpcMessage, public kj::Refcounted { : public OutgoingRpcMessage, public kj::Refcounted {
public: public:
OutgoingMessageImpl(const TwoPartyVatNetwork& network, uint firstSegmentWordSize) OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize)
: network(network), : network(network),
message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {} message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
...@@ -77,22 +77,21 @@ public: ...@@ -77,22 +77,21 @@ public:
} }
void send() override { void send() override {
auto lock = network.previousWrite.lockExclusive(); network.previousWrite = network.previousWrite.then([&]() {
*lock = lock->then([&]() {
auto promise = writeMessage(network.stream, message).then([]() { auto promise = writeMessage(network.stream, message).then([]() {
// success; do nothing // success; do nothing
}, [&](kj::Exception&& exception) { }, [&](kj::Exception&& exception) {
// Exception during write! // Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill(); network.disconnectFulfiller->fulfill();
}); });
promise.eagerlyEvaluate(); promise.eagerlyEvaluate();
return kj::mv(promise); return kj::mv(promise);
}); });
lock->attach(kj::addRef(*this)); network.previousWrite.attach(kj::addRef(*this));
} }
private: private:
const TwoPartyVatNetwork& network; TwoPartyVatNetwork& network;
MallocMessageBuilder message; MallocMessageBuilder message;
}; };
...@@ -120,11 +119,11 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI ...@@ -120,11 +119,11 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
KJ_IF_MAYBE(m, message) { KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m))); return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else { } else {
disconnectFulfiller.lockExclusive()->get()->fulfill(); disconnectFulfiller->fulfill();
return nullptr; return nullptr;
} }
}, [&](kj::Exception&& exception) { }, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill(); disconnectFulfiller->fulfill();
kj::throwRecoverableException(kj::mv(exception)); kj::throwRecoverableException(kj::mv(exception));
return nullptr; return nullptr;
}); });
......
...@@ -65,7 +65,7 @@ private: ...@@ -65,7 +65,7 @@ private:
ReaderOptions receiveOptions; ReaderOptions receiveOptions;
bool accepted = false; bool accepted = false;
kj::MutexGuarded<kj::Promise<void>> previousWrite; kj::Promise<void> previousWrite;
// Resolves when the previous write completes. This effectively serves as the write queue. // Resolves when the previous write completes. This effectively serves as the write queue.
kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller; kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller;
...@@ -73,14 +73,14 @@ private: ...@@ -73,14 +73,14 @@ private:
// second call on the server side. Never fulfilled, because there is only one connection. // second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr; kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller; kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr; kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer { class FulfillerDisposer: public kj::Disposer {
public: public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller; mutable kj::Own<kj::PromiseFulfiller<void>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); } void disposeImpl(void* pointer) const override { fulfiller->fulfill(); }
}; };
FulfillerDisposer drainedFulfiller; FulfillerDisposer drainedFulfiller;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment