Commit e1fdba61 authored by Kenton Varda's avatar Kenton Varda

Remove more mutexes.

parent 540b1887
......@@ -253,7 +253,7 @@ public:
: promise(promiseParam.fork()),
selfResolutionOp(promise.addBranch().then(
[this](kj::Own<PipelineHook>&& inner) {
*redirect.lockExclusive() = kj::mv(inner);
redirect = kj::mv(inner);
})) {
selfResolutionOp.eagerlyEvaluate();
}
......@@ -275,7 +275,7 @@ public:
private:
kj::ForkedPromise<kj::Own<PipelineHook>> promise;
kj::MutexGuarded<kj::Maybe<kj::Own<PipelineHook>>> redirect;
kj::Maybe<kj::Own<PipelineHook>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
kj::Promise<void> selfResolutionOp;
......@@ -291,8 +291,10 @@ public:
: promise(promiseParam.fork()),
selfResolutionOp(promise.addBranch().then(
[this](kj::Own<ClientHook>&& inner) {
*redirect.lockExclusive() = kj::mv(inner);
})) {
redirect = kj::mv(inner);
})),
promiseForCallForwarding(promise.addBranch().fork()),
promiseForClientResolution(promise.addBranch().fork()) {
selfResolutionOp.eagerlyEvaluate();
}
......@@ -334,7 +336,7 @@ public:
// Create a promise for the call initiation.
kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise =
getPromiseForCallForwarding().addBranch().then(kj::mvCapture(context,
promiseForCallForwarding.addBranch().then(kj::mvCapture(context,
[=](kj::Own<CallContextHook>&& context, kj::Own<ClientHook>&& client){
return kj::refcounted<CallResultHolder>(
client->call(interfaceId, methodId, kj::mv(context)));
......@@ -359,7 +361,7 @@ public:
}
kj::Maybe<ClientHook&> getResolved() override {
KJ_IF_MAYBE(inner, *redirect.lockExclusive()) {
KJ_IF_MAYBE(inner, redirect) {
return **inner;
} else {
return nullptr;
......@@ -367,7 +369,7 @@ public:
}
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return getPromiseForClientResolution().addBranch();
return promiseForClientResolution.addBranch();
}
kj::Own<ClientHook> addRef() override {
......@@ -381,50 +383,35 @@ public:
private:
typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
kj::Maybe<kj::Own<ClientHook>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
ClientHookPromiseFork promise;
// Promise that resolves when we have a new ClientHook to forward to.
//
// This fork shall only have two branches: `promiseForCallForwarding` and
// This fork shall only have three branches: `selfResolutionOp`, `promiseForCallForwarding`, and
// `promiseForClientResolution`, in that order.
kj::Lazy<ClientHookPromiseFork> promiseForCallForwarding;
kj::Promise<void> selfResolutionOp;
// Represents the operation which will set `redirect` when possible.
ClientHookPromiseFork promiseForCallForwarding;
// When this promise resolves, each queued call will be forwarded to the real client. This needs
// to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure
// previously-queued calls are delivered before any new calls made in response to the resolution.
kj::Lazy<ClientHookPromiseFork> promiseForClientResolution;
ClientHookPromiseFork promiseForClientResolution;
// whenMoreResolved() returns forks of this promise. These must resolve *after* queued calls
// have been initiated (so that any calls made in the whenMoreResolved() handler are correctly
// delivered after calls made earlier), but *before* any queued calls return (because it might
// confuse the application if a queued call returns before the capability on which it was made
// resolves). Luckily, we know that queued calls will involve, at the very least, an
// eventLoop.evalLater.
kj::MutexGuarded<kj::Maybe<kj::Own<ClientHook>>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
kj::Promise<void> selfResolutionOp;
// Represents the operation which will set `redirect` when possible.
ClientHookPromiseFork& getPromiseForCallForwarding() {
return promiseForCallForwarding.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
return space.construct(promise.addBranch().fork());
});
}
kj::ForkedPromise<kj::Own<ClientHook>>& getPromiseForClientResolution() {
return promiseForClientResolution.get([this](kj::SpaceFor<ClientHookPromiseFork>& space) {
getPromiseForCallForwarding(); // must be initialized first.
return space.construct(promise.addBranch().fork());
});
}
};
kj::Own<ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) {
auto lock = redirect.lockExclusive();
KJ_IF_MAYBE(redirect, *lock) {
return redirect->get()->getPipelinedCap(kj::mv(ops));
KJ_IF_MAYBE(r, redirect) {
return r->get()->getPipelinedCap(kj::mv(ops));
} else {
auto clientPromise = promise.addBranch().then(kj::mvCapture(ops,
[](kj::Array<PipelineOp>&& ops, kj::Own<PipelineHook> pipeline) {
......
......@@ -211,7 +211,7 @@ public:
kj::Exception exception(
kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("Network was destroyed."));
for (auto& entry: state.getWithoutLock().connections) {
for (auto& entry: connections) {
entry.second->disconnect(kj::cp(exception));
}
}
......@@ -235,11 +235,9 @@ public:
}
void disconnect(kj::Exception&& exception) {
auto lock = queues.lockExclusive();
while (!lock->fulfillers.empty()) {
lock->fulfillers.front()->reject(kj::cp(exception));
lock->fulfillers.pop();
while (!fulfillers.empty()) {
fulfillers.front()->reject(kj::cp(exception));
fulfillers.pop();
}
networkException = kj::mv(exception);
......@@ -285,13 +283,13 @@ public:
connection.tasks->add(kj::evalLater(kj::mvCapture(kj::addRef(*message),
[connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
KJ_IF_MAYBE(p, connectionPtr->partner) {
auto lock = p->queues.lockExclusive();
if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message));
if (p->fulfillers.empty()) {
p->messages.push(kj::mv(message));
} else {
++connectionPtr->network.received;
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop();
++p->network.received;
p->fulfillers.front()->fulfill(
kj::Own<IncomingRpcMessage>(kj::mv(message)));
p->fulfillers.pop();
}
}
})));
......@@ -310,15 +308,14 @@ public:
return kj::cp(*e);
}
auto lock = queues.lockExclusive();
if (lock->messages.empty()) {
if (messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller));
fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
} else {
++network.received;
auto result = kj::mv(lock->messages.front());
lock->messages.pop();
auto result = kj::mv(messages.front());
messages.pop();
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
}
}
......@@ -347,11 +344,8 @@ public:
kj::Maybe<kj::Exception> networkException;
struct Queues {
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages;
kj::Own<kj::TaskSet> tasks;
};
......@@ -360,31 +354,20 @@ public:
test::TestSturdyRefHostId::Reader hostId) override {
TestNetworkAdapter& dst = KJ_REQUIRE_NONNULL(network.find(hostId.getHost()));
kj::Locked<State> myLock;
kj::Locked<State> dstLock;
if (&dst < this) {
dstLock = dst.state.lockExclusive();
myLock = state.lockExclusive();
} else {
myLock = state.lockExclusive();
dstLock = dst.state.lockExclusive();
}
auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) {
auto iter = connections.find(&dst);
if (iter == connections.end()) {
auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT);
auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER);
local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local);
dstLock->connections[this] = kj::addRef(*remote);
connections[&dst] = kj::addRef(*local);
dst.connections[this] = kj::addRef(*remote);
if (dstLock->fulfillerQueue.empty()) {
dstLock->connectionQueue.push(kj::mv(remote));
if (dst.fulfillerQueue.empty()) {
dst.connectionQueue.push(kj::mv(remote));
} else {
dstLock->fulfillerQueue.front()->fulfill(kj::mv(remote));
dstLock->fulfillerQueue.pop();
dst.fulfillerQueue.front()->fulfill(kj::mv(remote));
dst.fulfillerQueue.pop();
}
return kj::Own<Connection>(kj::mv(local));
......@@ -394,14 +377,13 @@ public:
}
kj::Promise<kj::Own<Connection>> acceptConnectionAsRefHost() override {
auto lock = state.lockExclusive();
if (lock->connectionQueue.empty()) {
if (connectionQueue.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<Connection>>();
lock->fulfillerQueue.push(kj::mv(paf.fulfiller));
fulfillerQueue.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
} else {
auto result = kj::mv(lock->connectionQueue.front());
lock->connectionQueue.pop();
auto result = kj::mv(connectionQueue.front());
connectionQueue.pop();
return kj::mv(result);
}
}
......@@ -411,12 +393,9 @@ private:
uint sent = 0;
uint received = 0;
struct State {
std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<Connection>>>> fulfillerQueue;
std::queue<kj::Own<Connection>> connectionQueue;
};
kj::MutexGuarded<State> state;
std::map<const TestNetworkAdapter*, kj::Own<ConnectionImpl>> connections;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<Connection>>>> fulfillerQueue;
std::queue<kj::Own<Connection>> connectionQueue;
};
TestNetwork::~TestNetwork() noexcept(false) {}
......
......@@ -33,12 +33,12 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty:
{
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = paf.promise.fork();
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
disconnectFulfiller = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = paf.promise.fork();
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
drainedFulfiller.fulfiller = kj::mv(paf.fulfiller);
}
}
......@@ -68,7 +68,7 @@ kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>>
class TwoPartyVatNetwork::OutgoingMessageImpl final
: public OutgoingRpcMessage, public kj::Refcounted {
public:
OutgoingMessageImpl(const TwoPartyVatNetwork& network, uint firstSegmentWordSize)
OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize)
: network(network),
message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
......@@ -77,22 +77,21 @@ public:
}
void send() override {
auto lock = network.previousWrite.lockExclusive();
*lock = lock->then([&]() {
network.previousWrite = network.previousWrite.then([&]() {
auto promise = writeMessage(network.stream, message).then([]() {
// success; do nothing
}, [&](kj::Exception&& exception) {
// Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill();
network.disconnectFulfiller->fulfill();
});
promise.eagerlyEvaluate();
return kj::mv(promise);
});
lock->attach(kj::addRef(*this));
network.previousWrite.attach(kj::addRef(*this));
}
private:
const TwoPartyVatNetwork& network;
TwoPartyVatNetwork& network;
MallocMessageBuilder message;
};
......@@ -120,11 +119,11 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
disconnectFulfiller.lockExclusive()->get()->fulfill();
disconnectFulfiller->fulfill();
return nullptr;
}
}, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill();
disconnectFulfiller->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
});
......
......@@ -65,7 +65,7 @@ private:
ReaderOptions receiveOptions;
bool accepted = false;
kj::MutexGuarded<kj::Promise<void>> previousWrite;
kj::Promise<void> previousWrite;
// Resolves when the previous write completes. This effectively serves as the write queue.
kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller;
......@@ -73,14 +73,14 @@ private:
// second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller;
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer {
public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
mutable kj::Own<kj::PromiseFulfiller<void>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); }
void disposeImpl(void* pointer) const override { fulfiller->fulfill(); }
};
FulfillerDisposer drainedFulfiller;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment