Commit f740a60f authored by Kenton Varda's avatar Kenton Varda

Cleaner disconnect handling. Better fix for issue #71, and also simplifies the…

Cleaner disconnect handling.  Better fix for issue #71, and also simplifies the interface and improves robustness.
parent 1fddf5a6
...@@ -578,9 +578,7 @@ public: ...@@ -578,9 +578,7 @@ public:
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
auto hook = kj::heap<BrokenRequest>(exception, sizeHint); return newBrokenRequest(kj::cp(exception), sizeHint);
auto root = hook->message.getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
...@@ -626,4 +624,11 @@ kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) { ...@@ -626,4 +624,11 @@ kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason)); return kj::refcounted<BrokenPipeline>(kj::mv(reason));
} }
Request<AnyPointer, AnyPointer> newBrokenRequest(
kj::Exception&& reason, kj::Maybe<MessageSize> sizeHint) {
auto hook = kj::heap<BrokenRequest>(kj::mv(reason), sizeHint);
auto root = hook->message.getRoot<AnyPointer>();
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
}
} // namespace capnp } // namespace capnp
...@@ -419,6 +419,10 @@ kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason); ...@@ -419,6 +419,10 @@ kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason); kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called. // Helper function that creates a pipeline which simply throws exceptions when called.
Request<AnyPointer, AnyPointer> newBrokenRequest(
kj::Exception&& reason, kj::Maybe<MessageSize> sizeHint);
// Helper function that creates a Request object that simply throws exceptions when sent.
// ======================================================================================= // =======================================================================================
// Extend PointerHelpers for interfaces // Extend PointerHelpers for interfaces
......
...@@ -240,7 +240,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS ...@@ -240,7 +240,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS
// Arrange to destroy the server context when all references are gone, or when the // Arrange to destroy the server context when all references are gone, or when the
// EzRpcServer is destroyed (which will destroy the TaskSet). // EzRpcServer is destroyed (which will destroy the TaskSet).
tasks.add(server->network.onDrained().attach(kj::mv(server))); tasks.add(server->network.onDisconnect().attach(kj::mv(server)));
}))); })));
} }
......
...@@ -66,7 +66,6 @@ kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int& ...@@ -66,7 +66,6 @@ kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider, int&
TestRestorer restorer(callCount); TestRestorer restorer(callCount);
auto server = makeRpcServer(network, restorer); auto server = makeRpcServer(network, restorer);
network.onDisconnect().wait(waitScope); network.onDisconnect().wait(waitScope);
network.onDrained().wait(waitScope);
}); });
} }
...@@ -141,9 +140,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -141,9 +140,7 @@ TEST(TwoPartyNetwork, Pipelining) {
auto rpcClient = makeRpcClient(network); auto rpcClient = makeRpcClient(network);
bool disconnected = false; bool disconnected = false;
bool drained = false;
kj::Promise<void> disconnectPromise = network.onDisconnect().then([&]() { disconnected = true; }); kj::Promise<void> disconnectPromise = network.onDisconnect().then([&]() { disconnected = true; });
kj::Promise<void> drainedPromise = network.onDrained().then([&]() { drained = true; });
{ {
// Request the particular capability from the server. // Request the particular capability from the server.
...@@ -182,14 +179,12 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -182,14 +179,12 @@ TEST(TwoPartyNetwork, Pipelining) {
} }
EXPECT_FALSE(disconnected); EXPECT_FALSE(disconnected);
EXPECT_FALSE(drained);
// What if we disconnect? // What if we disconnect?
serverThread.pipe->shutdownWrite(); serverThread.pipe->shutdownWrite();
// The other side should also disconnect. // The other side should also disconnect.
disconnectPromise.wait(ioContext.waitScope); disconnectPromise.wait(ioContext.waitScope);
EXPECT_FALSE(drained);
{ {
// Use the now-broken capability. // Use the now-broken capability.
...@@ -213,11 +208,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -213,11 +208,7 @@ TEST(TwoPartyNetwork, Pipelining) {
EXPECT_EQ(3, callCount); EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount); EXPECT_EQ(1, reverseCallCount);
} }
EXPECT_FALSE(drained);
} }
drainedPromise.wait(ioContext.waitScope);
} }
} // namespace } // namespace
......
...@@ -30,37 +30,20 @@ namespace capnp { ...@@ -30,37 +30,20 @@ namespace capnp {
TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: stream(stream), side(side), receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) { : stream(stream), side(side), receiveOptions(receiveOptions), previousWrite(kj::READY_NOW) {
{
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = paf.promise.fork(); disconnectPromise = paf.promise.fork();
drainedFulfiller.fulfiller = kj::mv(paf.fulfiller); disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
// If the RPC system on this side drops the connection, thus firing onDrained() before
// onDisconnected(), we also want to consider ourselves disconnected. Otherwise, we might
// not detect actual disconnect because the RPC system won't attempt to send or receive any
// more messages on the connection. So, we exclusive-join the disconnect promise with the
// first branch of drainedPromise.
disconnectPromise = paf.promise.exclusiveJoin(drainedPromise.addBranch()).fork();
disconnectFulfiller = kj::mv(paf.fulfiller);
}
} }
void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
KJ_DBG("deref", this, refcount);
if (--refcount == 0) { if (--refcount == 0) {
fulfiller->fulfill(); fulfiller->fulfill();
} }
} }
kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() { kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() {
KJ_DBG("ref", &drainedFulfiller, drainedFulfiller.refcount); ++disconnectFulfiller.refcount;
++drainedFulfiller.refcount; return kj::Own<TwoPartyVatNetworkBase::Connection>(this, disconnectFulfiller);
return kj::Own<TwoPartyVatNetworkBase::Connection>(this, drainedFulfiller);
} }
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost( kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost(
...@@ -102,12 +85,10 @@ public: ...@@ -102,12 +85,10 @@ public:
void send() override { void send() override {
network.previousWrite = network.previousWrite.then([&]() { network.previousWrite = network.previousWrite.then([&]() {
auto promise = writeMessage(network.stream, message).then([]() { // Note that if the write fails, all further writes will be skipped due to the exception.
// success; do nothing // We never actually handle this exception because we assume the read end will fail as well
}, [&](kj::Exception&& exception) { // and it's cleaner to handle the failure there.
// Exception during write! auto promise = writeMessage(network.stream, message).eagerlyEvaluate(nullptr);
network.disconnectFulfiller->fulfill();
}).eagerlyEvaluate(nullptr);
return kj::mv(promise); return kj::mv(promise);
}).attach(kj::addRef(*this)); }).attach(kj::addRef(*this));
} }
...@@ -145,13 +126,8 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI ...@@ -145,13 +126,8 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
KJ_IF_MAYBE(m, message) { KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m))); return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else { } else {
disconnectFulfiller->fulfill();
return nullptr; return nullptr;
} }
}, [&](kj::Exception&& exception) {
disconnectFulfiller->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
}); });
}); });
} }
......
...@@ -49,19 +49,6 @@ public: ...@@ -49,19 +49,6 @@ public:
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); } kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
// Returns a promise that resolves when the peer disconnects. // Returns a promise that resolves when the peer disconnects.
//
// TODO(soon): Currently this fires when the underlying physical connection breaks. It should
// fire after the RPC system has detected EOF itself and dropped its connection reference, so
// that it has a chance to reply to connections ended cleanly.
kj::Promise<void> onDrained() { return drainedPromise.addBranch(); }
// Returns a promise that resolves once the peer has disconnected *and* all local objects
// referencing this connection have been destroyed. A caller might use this to decide when it
// is safe to destroy the RpcSystem, if it isn't able to reliably destroy all objects using it
// directly.
//
// TODO(soon): This is not quite designed right. Those local objects should simply be disabled.
// Their existence should not prevent the RpcSystem from being destroyed.
// implements VatNetwork ----------------------------------------------------- // implements VatNetwork -----------------------------------------------------
...@@ -86,14 +73,12 @@ private: ...@@ -86,14 +73,12 @@ private:
// second call on the server side. Never fulfilled, because there is only one connection. // second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr; kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer { class FulfillerDisposer: public kj::Disposer {
// Hack: TwoPartyVatNetwork is both a VatNetwork and a VatNetwork::Connection. When all // Hack: TwoPartyVatNetwork is both a VatNetwork and a VatNetwork::Connection. Whet the RPC
// references to the Connection have been dropped, then we want onDrained() to fire. So we // system detects (or initiates) a disconnection, it drops its reference to the Connection.
// hand out Own<Connection>s with this disposer attached, so that we can detect when they are // When all references have been dropped, then we want onDrained() to fire. So we hand out
// dropped. // Own<Connection>s with this disposer attached, so that we can detect when they are dropped.
public: public:
mutable kj::Own<kj::PromiseFulfiller<void>> fulfiller; mutable kj::Own<kj::PromiseFulfiller<void>> fulfiller;
...@@ -101,7 +86,7 @@ private: ...@@ -101,7 +86,7 @@ private:
void disposeImpl(void* pointer) const override; void disposeImpl(void* pointer) const override;
}; };
FulfillerDisposer drainedFulfiller; FulfillerDisposer disconnectFulfiller;
kj::Own<TwoPartyVatNetworkBase::Connection> asConnection(); kj::Own<TwoPartyVatNetworkBase::Connection> asConnection();
// Returns a pointer to this with the disposer set to drainedFulfiller. // Returns a pointer to this with the disposer set to drainedFulfiller.
......
...@@ -269,15 +269,18 @@ class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Ref ...@@ -269,15 +269,18 @@ class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Ref
public: public:
RpcConnectionState(kj::Maybe<SturdyRefRestorerBase&> restorer, RpcConnectionState(kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connection, kj::Own<VatNetworkBase::Connection>&& connectionParam,
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller) kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller)
: restorer(restorer), connection(kj::mv(connection)), : restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), tasks(*this) {
disconnectFulfiller(kj::mv(disconnectFulfiller)), connection.init<Connected>(kj::mv(connectionParam));
tasks(*this) {
tasks.add(messageLoop()); tasks.add(messageLoop());
} }
kj::Own<ClientHook> restore(AnyPointer::Reader objectId) { kj::Own<ClientHook> restore(AnyPointer::Reader objectId) {
if (connection.is<Disconnected>()) {
return newBrokenCap(kj::cp(connection.get<Disconnected>()));
}
QuestionId questionId; QuestionId questionId;
auto& question = questions.next(questionId); auto& question = questions.next(questionId);
...@@ -291,7 +294,7 @@ public: ...@@ -291,7 +294,7 @@ public:
paf.promise = paf.promise.attach(kj::addRef(*questionRef)); paf.promise = paf.promise.attach(kj::addRef(*questionRef));
{ {
auto message = connection->newOutgoingMessage( auto message = connection.get<Connected>()->newOutgoingMessage(
objectId.targetSize().wordCount + messageSizeHint<rpc::Restore>()); objectId.targetSize().wordCount + messageSizeHint<rpc::Restore>());
auto builder = message->getBody().initAs<rpc::Message>().initRestore(); auto builder = message->getBody().initAs<rpc::Message>().initRestore();
...@@ -311,16 +314,8 @@ public: ...@@ -311,16 +314,8 @@ public:
} }
void disconnect(kj::Exception&& exception) { void disconnect(kj::Exception&& exception) {
{ if (!connection.is<Connected>()) {
// Carefully pull all the objects out of the tables prior to releasing them because their // Already disconnected.
// destructors could come back and mess with the tables.
kj::Vector<kj::Own<PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<ClientHook>> clientsToRelease;
kj::Vector<kj::Promise<kj::Own<RpcResponse>>> tailCallsToRelease;
kj::Vector<kj::Promise<void>> resolveOpsToRelease;
if (networkException != nullptr) {
// Oops, already disconnected.
return; return;
} }
...@@ -328,6 +323,14 @@ public: ...@@ -328,6 +323,14 @@ public:
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT, kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription())); __FILE__, __LINE__, kj::str("Disconnected: ", exception.getDescription()));
KJ_IF_MAYBE(newException, kj::runCatchingExceptions([&]() {
// Carefully pull all the objects out of the tables prior to releasing them because their
// destructors could come back and mess with the tables.
kj::Vector<kj::Own<PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<ClientHook>> clientsToRelease;
kj::Vector<kj::Promise<kj::Own<RpcResponse>>> tailCallsToRelease;
kj::Vector<kj::Promise<void>> resolveOpsToRelease;
// All current questions complete with exceptions. // All current questions complete with exceptions.
questions.forEach([&](QuestionId id, Question& question) { questions.forEach([&](QuestionId id, Question& question) {
KJ_IF_MAYBE(questionRef, question.selfRef) { KJ_IF_MAYBE(questionRef, question.selfRef) {
...@@ -367,20 +370,24 @@ public: ...@@ -367,20 +370,24 @@ public:
f->get()->reject(kj::cp(networkException)); f->get()->reject(kj::cp(networkException));
} }
}); });
})) {
this->networkException = kj::mv(networkException); // Some destructor must have thrown an exception. There is no appropriate place to report
// these errors.
KJ_LOG(ERROR, "Uncaught exception when destroying capabilities dropped by disconnect.",
*newException);
} }
{ // Send an abort message, but ignore failure.
// Send an abort message. kj::runCatchingExceptions([&]() {
auto message = connection->newOutgoingMessage( auto message = connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<void>() + exceptionSizeHint(exception)); messageSizeHint<void>() + exceptionSizeHint(exception));
fromException(exception, message->getBody().getAs<rpc::Message>().initAbort()); fromException(exception, message->getBody().getAs<rpc::Message>().initAbort());
message->send(); message->send();
} });
// Indicate disconnect. // Indicate disconnect.
disconnectFulfiller->fulfill(); disconnectFulfiller->fulfill();
connection.init<Disconnected>(kj::mv(networkException));
} }
private: private:
...@@ -508,7 +515,13 @@ private: ...@@ -508,7 +515,13 @@ private:
// OK, now we can define RpcConnectionState's member data. // OK, now we can define RpcConnectionState's member data.
kj::Maybe<SturdyRefRestorerBase&> restorer; kj::Maybe<SturdyRefRestorerBase&> restorer;
kj::Own<VatNetworkBase::Connection> connection;
typedef kj::Own<VatNetworkBase::Connection> Connected;
typedef kj::Exception Disconnected;
kj::OneOf<Connected, Disconnected> connection;
// Once the connection has failed, we drop it and replace it with an exception, which will be
// thrown from all further calls.
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller; kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
ExportTable<ExportId, Export> exports; ExportTable<ExportId, Export> exports;
...@@ -521,10 +534,6 @@ private: ...@@ -521,10 +534,6 @@ private:
std::unordered_map<ClientHook*, ExportId> exportsByCap; std::unordered_map<ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table. // Maps already-exported ClientHook objects to their ID in the export table.
kj::Maybe<kj::Exception> networkException;
// If the connection has failed, this is the exception describing the failure. All future
// calls should throw this exception.
ExportTable<EmbargoId, Embargo> embargoes; ExportTable<EmbargoId, Embargo> embargoes;
// There are only four tables. This definitely isn't a fifth table. I don't know what you're // There are only four tables. This definitely isn't a fifth table. I don't know what you're
// talking about. // talking about.
...@@ -566,8 +575,13 @@ private: ...@@ -566,8 +575,13 @@ private:
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
if (!connectionState->connection.is<Connected>()) {
return newBrokenRequest(kj::cp(connectionState->connection.get<Disconnected>()), sizeHint);
}
auto request = kj::heap<RpcRequest>( auto request = kj::heap<RpcRequest>(
*connectionState, sizeHint, kj::addRef(*this)); *connectionState, *connectionState->connection.get<Connected>(),
sizeHint, kj::addRef(*this));
auto callBuilder = request->getCall(); auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId); callBuilder.setInterfaceId(interfaceId);
...@@ -623,8 +637,8 @@ private: ...@@ -623,8 +637,8 @@ private:
} }
// Send a message releasing our remote references. // Send a message releasing our remote references.
if (remoteRefcount > 0) { if (remoteRefcount > 0 && connectionState->connection.is<Connected>()) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Release>()); messageSizeHint<rpc::Release>());
rpc::Release::Builder builder = message->getBody().initAs<rpc::Message>().initRelease(); rpc::Release::Builder builder = message->getBody().initAs<rpc::Message>().initRelease();
builder.setId(importId); builder.setId(importId);
...@@ -821,13 +835,14 @@ private: ...@@ -821,13 +835,14 @@ private:
bool receivedCall = false; bool receivedCall = false;
void resolve(kj::Own<ClientHook> replacement, bool isError) { void resolve(kj::Own<ClientHook> replacement, bool isError) {
if (replacement->getBrand() != connectionState.get() && receivedCall && !isError) { if (replacement->getBrand() != connectionState.get() && receivedCall && !isError &&
connectionState->connection.is<Connected>()) {
// The new capability is hosted locally, not on the remote machine. And, we had made calls // The new capability is hosted locally, not on the remote machine. And, we had made calls
// to the promise. We need to make sure those calls echo back to us before we allow new // to the promise. We need to make sure those calls echo back to us before we allow new
// calls to go directly to the local capability, so we need to set a local embargo and send // calls to go directly to the local capability, so we need to set a local embargo and send
// a `Disembargo` to echo through the peer. // a `Disembargo` to echo through the peer.
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT); messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo(); auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo();
...@@ -970,6 +985,11 @@ private: ...@@ -970,6 +985,11 @@ private:
[this,exportId](kj::Own<ClientHook>&& resolution) -> kj::Promise<void> { [this,exportId](kj::Own<ClientHook>&& resolution) -> kj::Promise<void> {
// Successful resolution. // Successful resolution.
KJ_ASSERT(connection.is<Connected>(),
"Resolving export should have been canceled on disconnect.") {
return kj::READY_NOW;
}
// Get the innermost ClientHook backing the resolved client. This includes traversing // Get the innermost ClientHook backing the resolved client. This includes traversing
// PromiseClients that haven't resolved yet to their underlying ImportClient or // PromiseClients that haven't resolved yet to their underlying ImportClient or
// PipelineClient, so that we get a remote promise that might resolve later. This is // PipelineClient, so that we get a remote promise that might resolve later. This is
...@@ -1006,7 +1026,7 @@ private: ...@@ -1006,7 +1026,7 @@ private:
} }
// OK, we have to send a `Resolve` message. // OK, we have to send a `Resolve` message.
auto message = connection->newOutgoingMessage( auto message = connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16); messageSizeHint<rpc::Resolve>() + sizeInWords<rpc::CapDescriptor>() + 16);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve(); auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
...@@ -1016,7 +1036,7 @@ private: ...@@ -1016,7 +1036,7 @@ private:
return kj::READY_NOW; return kj::READY_NOW;
}, [this,exportId](kj::Exception&& exception) { }, [this,exportId](kj::Exception&& exception) {
// send error resolution // send error resolution
auto message = connection->newOutgoingMessage( auto message = connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Resolve>() + exceptionSizeHint(exception) + 8); messageSizeHint<rpc::Resolve>() + exceptionSizeHint(exception) + 8);
auto resolve = message->getBody().initAs<rpc::Message>().initResolve(); auto resolve = message->getBody().initAs<rpc::Message>().initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
...@@ -1142,8 +1162,8 @@ private: ...@@ -1142,8 +1162,8 @@ private:
~QuestionRef() { ~QuestionRef() {
unwindDetector.catchExceptionsIfUnwinding([&]() { unwindDetector.catchExceptionsIfUnwinding([&]() {
// Send the "Finish" message (if the connection is not already broken). // Send the "Finish" message (if the connection is not already broken).
if (connectionState->networkException == nullptr) { if (connectionState->connection.is<Connected>()) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Finish>()); messageSizeHint<rpc::Finish>());
auto builder = message->getBody().getAs<rpc::Message>().initFinish(); auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id); builder.setQuestionId(id);
...@@ -1189,11 +1209,11 @@ private: ...@@ -1189,11 +1209,11 @@ private:
class RpcRequest final: public RequestHook { class RpcRequest final: public RequestHook {
public: public:
RpcRequest(RpcConnectionState& connectionState, kj::Maybe<MessageSize> sizeHint, RpcRequest(RpcConnectionState& connectionState, VatNetworkBase::Connection& connection,
kj::Own<RpcClient>&& target) kj::Maybe<MessageSize> sizeHint, kj::Own<RpcClient>&& target)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
target(kj::mv(target)), target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage( message(connection.newOutgoingMessage(
firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() + firstSegmentSize(sizeHint, messageSizeHint<rpc::Call>() +
sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))), sizeInWords<rpc::Payload>() + MESSAGE_TARGET_SIZE_HINT))),
callBuilder(message->getBody().getAs<rpc::Message>().initCall()), callBuilder(message->getBody().getAs<rpc::Message>().initCall()),
...@@ -1207,11 +1227,12 @@ private: ...@@ -1207,11 +1227,12 @@ private:
} }
RemotePromise<AnyPointer> send() override { RemotePromise<AnyPointer> send() override {
KJ_IF_MAYBE(e, connectionState->networkException) { if (!connectionState->connection.is<Connected>()) {
// Connection is broken. // Connection is broken.
const kj::Exception& e = connectionState->connection.get<Disconnected>();
return RemotePromise<AnyPointer>( return RemotePromise<AnyPointer>(
kj::Promise<Response<AnyPointer>>(kj::cp(*e)), kj::Promise<Response<AnyPointer>>(kj::cp(e)),
AnyPointer::Pipeline(newBrokenPipeline(kj::cp(*e)))); AnyPointer::Pipeline(newBrokenPipeline(kj::cp(e))));
} }
KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) { KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) {
...@@ -1257,7 +1278,7 @@ private: ...@@ -1257,7 +1278,7 @@ private:
SendInternalResult sendResult; SendInternalResult sendResult;
if (connectionState->networkException != nullptr) { if (!connectionState->connection.is<Connected>()) {
// Disconnected; fall back to a regular send() which will fail appropriately. // Disconnected; fall back to a regular send() which will fail appropriately.
return nullptr; return nullptr;
} }
...@@ -1539,8 +1560,8 @@ private: ...@@ -1539,8 +1560,8 @@ private:
// We haven't sent a return yet, so we must have been canceled. Send a cancellation return. // We haven't sent a return yet, so we must have been canceled. Send a cancellation return.
unwindDetector.catchExceptionsIfUnwinding([&]() { unwindDetector.catchExceptionsIfUnwinding([&]() {
// Don't send anything if the connection is broken. // Don't send anything if the connection is broken.
if (connectionState->networkException == nullptr) { if (connectionState->connection.is<Connected>()) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::Payload>()); messageSizeHint<rpc::Return>() + sizeInWords<rpc::Payload>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1579,6 +1600,11 @@ private: ...@@ -1579,6 +1600,11 @@ private:
// Avoid sending results if canceled so that we don't have to figure out whether or not // Avoid sending results if canceled so that we don't have to figure out whether or not
// `releaseResultCaps` was set in the already-received `Finish`. // `releaseResultCaps` was set in the already-received `Finish`.
if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) { if (!(cancellationFlags & CANCEL_REQUESTED) && isFirstResponder()) {
KJ_ASSERT(connectionState->connection.is<Connected>(),
"Cancellation should have been requested on disconnect.") {
return;
}
if (response == nullptr) getResults(MessageSize{0, 0}); // force initialization of response if (response == nullptr) getResults(MessageSize{0, 0}); // force initialization of response
returnMessage.setAnswerId(answerId); returnMessage.setAnswerId(answerId);
...@@ -1597,7 +1623,8 @@ private: ...@@ -1597,7 +1623,8 @@ private:
void sendErrorReturn(kj::Exception&& exception) { void sendErrorReturn(kj::Exception&& exception) {
KJ_ASSERT(!redirectResults); KJ_ASSERT(!redirectResults);
if (isFirstResponder()) { if (isFirstResponder()) {
auto message = connectionState->connection->newOutgoingMessage( if (connectionState->connection.is<Connected>()) {
auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Return>() + exceptionSizeHint(exception)); messageSizeHint<rpc::Return>() + exceptionSizeHint(exception));
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1606,6 +1633,7 @@ private: ...@@ -1606,6 +1633,7 @@ private:
fromException(exception, builder.initException()); fromException(exception, builder.initException());
message->send(); message->send();
}
// Do not allow releasing the pipeline because we want pipelined calls to propagate the // Do not allow releasing the pipeline because we want pipelined calls to propagate the
// exception rather than fail with a "no such field" exception. // exception rather than fail with a "no such field" exception.
...@@ -1645,10 +1673,10 @@ private: ...@@ -1645,10 +1673,10 @@ private:
} else { } else {
kj::Own<RpcServerResponse> response; kj::Own<RpcServerResponse> response;
if (redirectResults) { if (redirectResults || !connectionState->connection.is<Connected>()) {
response = kj::refcounted<LocallyRedirectedRpcResponse>(sizeHint); response = kj::refcounted<LocallyRedirectedRpcResponse>(sizeHint);
} else { } else {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
firstSegmentSize(sizeHint, messageSizeHint<rpc::Return>() + firstSegmentSize(sizeHint, messageSizeHint<rpc::Return>() +
sizeInWords<rpc::Payload>())); sizeInWords<rpc::Payload>()));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn(); returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1678,7 +1706,8 @@ private: ...@@ -1678,7 +1706,8 @@ private:
KJ_IF_MAYBE(tailInfo, kj::downcast<RpcRequest>(*request).tailSend()) { KJ_IF_MAYBE(tailInfo, kj::downcast<RpcRequest>(*request).tailSend()) {
if (isFirstResponder()) { if (isFirstResponder()) {
auto message = connectionState->connection->newOutgoingMessage( if (connectionState->connection.is<Connected>()) {
auto message = connectionState->connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Return>()); messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1687,6 +1716,7 @@ private: ...@@ -1687,6 +1716,7 @@ private:
builder.setTakeFromOtherQuestion(tailInfo->questionId); builder.setTakeFromOtherQuestion(tailInfo->questionId);
message->send(); message->send();
}
// There are no caps in our return message, but of course the tail results could have // There are no caps in our return message, but of course the tail results could have
// caps, so we must continue to honor pipeline calls (and just bounce them back). // caps, so we must continue to honor pipeline calls (and just bounce them back).
...@@ -1803,7 +1833,11 @@ private: ...@@ -1803,7 +1833,11 @@ private:
// Message handling // Message handling
kj::Promise<void> messageLoop() { kj::Promise<void> messageLoop() {
return connection->receiveIncomingMessage().then( if (!connection.is<Connected>()) {
return kj::READY_NOW;
}
return connection.get<Connected>()->receiveIncomingMessage().then(
[this](kj::Maybe<kj::Own<IncomingRpcMessage>>&& message) { [this](kj::Maybe<kj::Own<IncomingRpcMessage>>&& message) {
KJ_IF_MAYBE(m, message) { KJ_IF_MAYBE(m, message) {
handleMessage(kj::mv(*m)); handleMessage(kj::mv(*m));
...@@ -1864,10 +1898,12 @@ private: ...@@ -1864,10 +1898,12 @@ private:
break; break;
default: { default: {
auto message = connection->newOutgoingMessage( if (connection.is<Connected>()) {
auto message = connection.get<Connected>()->newOutgoingMessage(
firstSegmentSize(reader.totalSize(), messageSizeHint<void>())); firstSegmentSize(reader.totalSize(), messageSizeHint<void>()));
message->getBody().initAs<rpc::Message>().setUnimplemented(reader); message->getBody().initAs<rpc::Message>().setUnimplemented(reader);
message->send(); message->send();
}
break; break;
} }
} }
...@@ -2269,9 +2305,13 @@ private: ...@@ -2269,9 +2305,13 @@ private:
// cap have had time to find their way through the event loop. // cap have had time to find their way through the event loop.
tasks.add(kj::evalLater(kj::mvCapture( tasks.add(kj::evalLater(kj::mvCapture(
target, [this,embargoId](kj::Own<ClientHook>&& target) { target, [this,embargoId](kj::Own<ClientHook>&& target) {
if (!connection.is<Connected>()) {
return;
}
RpcClient& downcasted = kj::downcast<RpcClient>(*target); RpcClient& downcasted = kj::downcast<RpcClient>(*target);
auto message = connection->newOutgoingMessage( auto message = connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT); messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
auto builder = message->getBody().initAs<rpc::Message>().initDisembargo(); auto builder = message->getBody().initAs<rpc::Message>().initDisembargo();
...@@ -2341,7 +2381,12 @@ private: ...@@ -2341,7 +2381,12 @@ private:
void handleRestore(kj::Own<IncomingRpcMessage>&& message, const rpc::Restore::Reader& restore) { void handleRestore(kj::Own<IncomingRpcMessage>&& message, const rpc::Restore::Reader& restore) {
AnswerId answerId = restore.getQuestionId(); AnswerId answerId = restore.getQuestionId();
auto response = connection->newOutgoingMessage( if (!connection.is<Connected>()) {
// Disconnected; ignore.
return;
}
auto response = connection.get<Connected>()->newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::CapDescriptor>() + 32); messageSizeHint<rpc::Return>() + sizeInWords<rpc::CapDescriptor>() + 32);
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn(); rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment