Commit 2aa37a8a authored by Kenton Varda's avatar Kenton Varda

Make RPC system handle disconnect.

parent 38bbbd32
...@@ -215,7 +215,7 @@ private: ...@@ -215,7 +215,7 @@ private:
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap( kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const { kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception); return kj::refcounted<BrokenClient>(exception);
} }
} // namespace } // namespace
...@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) { ...@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason)); return kj::refcounted<BrokenClient>(kj::mv(reason));
} }
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp } // namespace capnp
...@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason); ...@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason); kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called. // Helper function that creates a capability which simply throws exceptions when called.
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
......
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
if (lock->fulfillers.empty()) { if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message)); lock->messages.push(kj::mv(message));
} else { } else {
lock->fulfillers.front()->fulfill(kj::mv(message)); lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop(); lock->fulfillers.pop();
} }
} }
...@@ -119,16 +119,16 @@ public: ...@@ -119,16 +119,16 @@ public:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override { kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override {
return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize); return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
} }
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override {
auto lock = queues.lockExclusive(); auto lock = queues.lockExclusive();
if (lock->messages.empty()) { if (lock->messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<IncomingRpcMessage>>(); auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller)); lock->fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise); return kj::mv(paf.promise);
} else { } else {
auto result = kj::mv(lock->messages.front()); auto result = kj::mv(lock->messages.front());
lock->messages.pop(); lock->messages.pop();
return kj::mv(result); return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
} }
} }
void introduceTo(Connection& recipient, void introduceTo(Connection& recipient,
...@@ -149,7 +149,7 @@ public: ...@@ -149,7 +149,7 @@ public:
kj::Maybe<ConnectionImpl&> partner; kj::Maybe<ConnectionImpl&> partner;
struct Queues { struct Queues {
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<IncomingRpcMessage>>>> fulfillers; std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages; std::queue<kj::Own<IncomingRpcMessage>> messages;
}; };
kj::MutexGuarded<Queues> queues; kj::MutexGuarded<Queues> queues;
......
...@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) {
// Start up server in another thread. // Start up server in another thread.
auto quitter = kj::newPromiseAndFulfiller<void>(); auto quitter = kj::newPromiseAndFulfiller<void>();
kj::Thread thread([&]() { auto thread = kj::heap<kj::Thread>([&]() {
runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount); runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount);
}); });
KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread. KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread.
...@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) {
TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT); TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network, loop); auto rpcClient = makeRpcClient(network, loop);
// Request the particular capability from the server. bool disconnected = false;
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER, bool drained = false;
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>(); kj::Promise<void> disconnectPromise = loop.there(network.onDisconnect(),
[&]() { disconnected = true; });
kj::Promise<void> drainedPromise = loop.there(network.onDrained(),
[&]() { drained = true; });
// Use the capability. {
auto request = client.getCapRequest(); // Request the particular capability from the server.
request.setN(234); auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
request.setInCap(test::TestInterface::Client( test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send(); {
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto pipelineRequest = promise.getOutBox().getCap().fooRequest(); auto promise = request.send();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap().castAs<test::TestExtends>().graultRequest(); auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
auto pipelinePromise2 = pipelineRequest2.send(); pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
promise = nullptr; // Just to be annoying, drop the original promise. auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_EQ(0, callCount); promise = nullptr; // Just to be annoying, drop the original promise.
EXPECT_EQ(0, reverseCallCount);
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
auto response = loop.wait(kj::mv(pipelinePromise)); EXPECT_FALSE(disconnected);
EXPECT_EQ("bar", response.getX()); EXPECT_FALSE(drained);
auto response2 = loop.wait(kj::mv(pipelinePromise2)); // What if the other side disconnects?
checkTestMessage(response2); quitter.fulfiller->fulfill();
thread = nullptr;
loop.wait(kj::mv(disconnectPromise));
EXPECT_FALSE(drained);
{
// Use the now-broken capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise)));
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise2)));
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
EXPECT_FALSE(drained);
}
EXPECT_EQ(3, callCount); loop.wait(kj::mv(drainedPromise));
EXPECT_EQ(1, reverseCallCount);
} }
} // namespace } // namespace
......
...@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork( ...@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(
const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side, const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions), : eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions),
previousWrite(kj::READY_NOW) {} previousWrite(kj::READY_NOW) {
{
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = eventLoop.fork(kj::mv(paf.promise));
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = eventLoop.fork(kj::mv(paf.promise));
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
}
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost( kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost(
rpc::twoparty::SturdyRefHostId::Reader ref) { rpc::twoparty::SturdyRefHostId::Reader ref) {
if (ref.getSide() == side) { if (ref.getSide() == side) {
return nullptr; return nullptr;
} else { } else {
return kj::Own<TwoPartyVatNetworkBase::Connection>(this, return kj::Own<TwoPartyVatNetworkBase::Connection>(this, drainedFulfiller);
kj::DestructorOnlyDisposer<TwoPartyVatNetworkBase::Connection>::instance);
} }
} }
...@@ -70,13 +80,21 @@ public: ...@@ -70,13 +80,21 @@ public:
void send() override { void send() override {
auto lock = network.previousWrite.lockExclusive(); auto lock = network.previousWrite.lockExclusive();
*lock = network.eventLoop.there(network.eventLoop.there(kj::mv(*lock), [this]() { *lock = network.eventLoop.there(kj::mv(*lock),
return writeMessage(network.stream, message); kj::mvCapture(kj::addRef(*this), [&](kj::Own<OutgoingMessageImpl>&& self) {
}), kj::mvCapture(kj::addRef(*this), return writeMessage(network.stream, message)
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> { .then(kj::mvCapture(kj::mv(self),
// Hack to force this continuation to run (thus allowing `self` to be released) even if [](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// no one is waiting on the promise. // Just here to hold a reference to `self` until the write completes.
return kj::READY_NOW;
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
}), [&](kj::Exception&& exception) -> kj::Promise<void> {
// Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill();
return kj::READY_NOW;
});
})); }));
} }
...@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage( ...@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(
return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize); return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
} }
kj::Promise<kj::Own<IncomingRpcMessage>> TwoPartyVatNetwork::receiveIncomingMessage() { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return eventLoop.evalLater([&]() { return eventLoop.evalLater([&]() {
return readMessage(stream, receiveOptions) return tryReadMessage(stream, receiveOptions)
.then([](kj::Own<MessageReader>&& message) -> kj::Own<IncomingRpcMessage> { .then([&](kj::Maybe<kj::Own<MessageReader>>&& message)
return kj::heap<IncomingMessageImpl>(kj::mv(message)); -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
disconnectFulfiller.lockExclusive()->get()->fulfill();
return nullptr;
}
}, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
}); });
}); });
} }
......
...@@ -41,6 +41,15 @@ public: ...@@ -41,6 +41,15 @@ public:
TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions()); rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
// Returns a promise that resolves when the peer disconnects.
kj::Promise<void> onDrained() { return drainedPromise.addBranch(); }
// Returns a promise that resolves once the peer has disconnected *and* all local objects
// referencing this connection have been destroyed. A caller might use this to decide when it
// is safe to destroy the RpcSystem, if it isn't able to reliably destroy all objects using it
// directly.
// implements VatNetwork ----------------------------------------------------- // implements VatNetwork -----------------------------------------------------
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost( kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost(
...@@ -64,10 +73,22 @@ private: ...@@ -64,10 +73,22 @@ private:
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the // Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
// second call on the server side. Never fulfilled, because there is only one connection. // second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer {
public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); }
};
FulfillerDisposer drainedFulfiller;
// implements Connection ----------------------------------------------------- // implements Connection -----------------------------------------------------
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override; kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override;
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override; kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
void introduceTo(TwoPartyVatNetworkBase::Connection& recipient, void introduceTo(TwoPartyVatNetworkBase::Connection& recipient,
rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient, rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient,
rpc::twoparty::RecipientId::Builder sendToTarget) override; rpc::twoparty::RecipientId::Builder sendToTarget) override;
......
...@@ -167,6 +167,15 @@ public: ...@@ -167,6 +167,15 @@ public:
} }
} }
template <typename Func>
void forEach(Func&& func) {
for (Id i = 0; i < slots.size(); i++) {
if (slots[i] != nullptr) {
func(i, slots[i]);
}
}
}
private: private:
kj::Vector<T> slots; kj::Vector<T> slots;
std::priority_queue<Id, std::vector<Id>, std::greater<Id>> freeIds; std::priority_queue<Id, std::vector<Id>, std::greater<Id>> freeIds;
...@@ -206,6 +215,16 @@ public: ...@@ -206,6 +215,16 @@ public:
} }
} }
template <typename Func>
void forEach(Func&& func) {
for (Id i: kj::indices(low)) {
func(i, low[i]);
}
for (auto& entry: high) {
func(entry.first, entry.second);
}
}
private: private:
T low[16]; T low[16];
std::unordered_map<Id, T> high; std::unordered_map<Id, T> high;
...@@ -268,14 +287,16 @@ struct Import { ...@@ -268,14 +287,16 @@ struct Import {
// ======================================================================================= // =======================================================================================
class RpcConnectionState final: public kj::TaskSet::ErrorHandler { class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Refcounted {
class PromisedAnswerClient; class PromisedAnswerClient;
public: public:
RpcConnectionState(const kj::EventLoop& eventLoop, RpcConnectionState(const kj::EventLoop& eventLoop,
kj::Maybe<SturdyRefRestorerBase&> restorer, kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connection) kj::Own<VatNetworkBase::Connection>&& connection,
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller)
: eventLoop(eventLoop), restorer(restorer), connection(kj::mv(connection)), : eventLoop(eventLoop), restorer(restorer), connection(kj::mv(connection)),
disconnectFulfiller(kj::mv(disconnectFulfiller)),
tasks(eventLoop, *this), exportDisposer(*this) { tasks(eventLoop, *this), exportDisposer(*this) {
tasks.add(messageLoop()); tasks.add(messageLoop());
} }
...@@ -305,11 +326,11 @@ public: ...@@ -305,11 +326,11 @@ public:
message->send(); message->send();
} }
auto questionRef = kj::refcounted<QuestionRef>(*this, questionId); auto questionRef = kj::heap<QuestionRef>(*this, questionId);
auto promiseWithQuestionRef = eventLoop.there(kj::mv(paf.promise), auto promiseWithQuestionRef = eventLoop.there(kj::mv(paf.promise),
kj::mvCapture(kj::addRef(*questionRef), kj::mvCapture(questionRef,
[](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response) [](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> { -> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef)); response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response); return kj::mv(response);
...@@ -322,21 +343,68 @@ public: ...@@ -322,21 +343,68 @@ public:
} }
void taskFailed(kj::Exception&& exception) override { void taskFailed(kj::Exception&& exception) override {
// TODO(now): Kill the connection. {
// - All present and future questions must complete with exceptions. kj::Exception networkException(
// - All answers should be canceled (if they allow cancellation). kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
// - All exports are dropped. "", 0, kj::str("Disconnected: ", exception.getDescription()));
// - All imported promises resolve to exceptions.
// - Send abort message.
// - Remove from connection map.
kj::throwRecoverableException(kj::mv(exception)); kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
auto lock = tables.lockExclusive();
// All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id,
Question<CapInjectorImpl, RpcPipeline, RpcResponse>& question) {
question.fulfiller->reject(kj::cp(networkException));
KJ_IF_MAYBE(pc, question.paramCaps) {
paramCapsToRelease.add(kj::mv(*pc));
}
});
lock->answers.forEach([&](QuestionId id, Answer<RpcCallContext>& answer) {
KJ_IF_MAYBE(p, answer.pipeline) {
pipelinesToRelease.add(kj::mv(*p));
}
KJ_IF_MAYBE(context, answer.callContext) {
context->requestCancel();
}
});
lock->exports.forEach([&](ExportId id, Export& exp) {
clientsToRelease.add(kj::mv(exp.clientHook));
exp = Export();
});
lock->imports.forEach([&](ExportId id, Import<ImportClient>& import) {
if (import.client != nullptr) {
import.client->disconnect(kj::cp(networkException));
}
});
lock->networkException = kj::mv(networkException);
}
{
// Send an abort message.
auto message = connection->newOutgoingMessage(
messageSizeHint<rpc::Exception>() +
(exception.getDescription().size() + 7) / sizeof(word));
fromException(exception, message->getBody().getAs<rpc::Message>().initAbort());
message->send();
}
// Indicate disconnect.
disconnectFulfiller->fulfill();
} }
private: private:
const kj::EventLoop& eventLoop; const kj::EventLoop& eventLoop;
kj::Maybe<SturdyRefRestorerBase&> restorer; kj::Maybe<SturdyRefRestorerBase&> restorer;
kj::Own<VatNetworkBase::Connection> connection; kj::Own<VatNetworkBase::Connection> connection;
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
class ImportClient; class ImportClient;
class CapInjectorImpl; class CapInjectorImpl;
...@@ -353,6 +421,10 @@ private: ...@@ -353,6 +421,10 @@ private:
std::unordered_map<const ClientHook*, ExportId> exportsByCap; std::unordered_map<const ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table. // Maps already-exported ClientHook objects to their ID in the export table.
kj::Maybe<kj::Exception> networkException;
// If the connection has failed, this is the exception describing the failure. All future
// calls should throw this exception.
}; };
kj::MutexGuarded<Tables> tables; kj::MutexGuarded<Tables> tables;
...@@ -391,7 +463,7 @@ private: ...@@ -391,7 +463,7 @@ private:
class RpcClient: public ClientHook, public kj::Refcounted { class RpcClient: public ClientHook, public kj::Refcounted {
public: public:
RpcClient(const RpcConnectionState& connectionState) RpcClient(const RpcConnectionState& connectionState)
: connectionState(connectionState) {} : connectionState(kj::addRef(connectionState)) {}
virtual kj::Maybe<ExportId> writeDescriptor( virtual kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const = 0; rpc::CapDescriptor::Builder descriptor, Tables& tables) const = 0;
...@@ -417,6 +489,11 @@ private: ...@@ -417,6 +489,11 @@ private:
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override { kj::Own<CallContextHook>&& context) const override {
// Implement call() by copying params and results messages.
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto params = context->getParams(); auto params = context->getParams();
size_t sizeHint = params.targetSizeInWords(); size_t sizeHint = params.targetSizeInWords();
...@@ -466,7 +543,7 @@ private: ...@@ -466,7 +543,7 @@ private:
} }
protected: protected:
const RpcConnectionState& connectionState; kj::Own<const RpcConnectionState> connectionState;
}; };
class ImportClient: public RpcClient { class ImportClient: public RpcClient {
...@@ -481,7 +558,7 @@ private: ...@@ -481,7 +558,7 @@ private:
// that another thread attempted to obtain this import just as the destructor started, in // that another thread attempted to obtain this import just as the destructor started, in
// which case that other thread will have constructed a new ImportClient and placed it in // which case that other thread will have constructed a new ImportClient and placed it in
// the import table.) // the import table.)
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(importId)) { KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client == this) { if (import->client == this) {
lock->imports.erase(importId); lock->imports.erase(importId);
...@@ -491,7 +568,7 @@ private: ...@@ -491,7 +568,7 @@ private:
// Send a message releasing our remote references. // Send a message releasing our remote references.
if (remoteRefcount > 0) { if (remoteRefcount > 0) {
connectionState.sendReleaseLater(importId, remoteRefcount); connectionState->sendReleaseLater(importId, remoteRefcount);
} }
} }
...@@ -499,6 +576,9 @@ private: ...@@ -499,6 +576,9 @@ private:
// Replace the PromiseImportClient with its resolution. Returns false if this is not a promise // Replace the PromiseImportClient with its resolution. Returns false if this is not a promise
// (i.e. it is a SettledImportClient). // (i.e. it is a SettledImportClient).
virtual void disconnect(kj::Exception&& exception) = 0;
// Cause whenMoreResolved() to fail.
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() { kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null // Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should // if this client is being deleted in another thread, in which case the caller should
...@@ -528,7 +608,8 @@ private: ...@@ -528,7 +608,8 @@ private:
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize, kj::addRef(*this)); auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall(); auto callBuilder = request->getCall();
callBuilder.getTarget().setExportedCap(importId); callBuilder.getTarget().setExportedCap(importId);
...@@ -555,6 +636,10 @@ private: ...@@ -555,6 +636,10 @@ private:
return false; return false;
} }
void disconnect(kj::Exception&& exception) override {
// nothing
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr; return nullptr;
} }
...@@ -575,6 +660,10 @@ private: ...@@ -575,6 +660,10 @@ private:
return true; return true;
} }
void disconnect(kj::Exception&& exception) override {
fulfiller->reject(kj::mv(exception));
}
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise // TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves. // resolves.
...@@ -616,10 +705,11 @@ private: ...@@ -616,10 +705,11 @@ private:
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops); return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return connectionState.writeDescriptor(lock->get<Resolved>()->addRef(), descriptor, tables); return connectionState->writeDescriptor(
lock->get<Resolved>()->addRef(), descriptor, tables);
} else { } else {
return connectionState.writeDescriptor(newBrokenCap(kj::cp(lock->get<Broken>())), return connectionState->writeDescriptor(
descriptor, tables); newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
} }
} }
...@@ -630,7 +720,7 @@ private: ...@@ -630,7 +720,7 @@ private:
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops); return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return connectionState.writeTarget(*lock->get<Resolved>(), target); return connectionState->writeTarget(*lock->get<Resolved>(), target);
} else { } else {
return newBrokenCap(kj::cp(lock->get<Broken>())); return newBrokenCap(kj::cp(lock->get<Broken>()));
} }
...@@ -644,7 +734,7 @@ private: ...@@ -644,7 +734,7 @@ private:
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>( auto request = kj::heap<RpcRequest>(
connectionState, firstSegmentWordSize, kj::addRef(*this)); *connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall(); auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId); callBuilder.setInterfaceId(interfaceId);
...@@ -900,11 +990,14 @@ private: ...@@ -900,11 +990,14 @@ private:
kj::Vector<kj::Own<const ClientHook>> clientsToRelease(exports.size()); kj::Vector<kj::Own<const ClientHook>> clientsToRelease(exports.size());
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
for (auto exportId: exports) {
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId)); if (lock->networkException == nullptr) {
if (--exp.refcount == 0) { for (auto exportId: exports) {
clientsToRelease.add(kj::mv(exp.clientHook)); auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
lock->exports.erase(exportId); if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook));
lock->exports.erase(exportId);
}
} }
} }
} }
...@@ -990,19 +1083,24 @@ private: ...@@ -990,19 +1083,24 @@ private:
// ===================================================================================== // =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations // RequestHook/PipelineHook/ResponseHook implementations
class QuestionRef: public kj::Refcounted { class QuestionRef {
public: public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id) inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id)
: connectionState(connectionState), id(id), resultCaps(connectionState) {} : connectionState(kj::addRef(connectionState)), id(id) {}
~QuestionRef() { ~QuestionRef() {
// Send the "Finish" message. // Send the "Finish" message.
auto message = connectionState.connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true)); messageSizeHint<rpc::Finish>() +
resultCaps.map([](CapExtractorImpl& ce) { return ce.retainedListSizeHint(true); })
.orDefault(0));
auto builder = message->getBody().getAs<rpc::Message>().initFinish(); auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id); builder.setQuestionId(id);
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder))); KJ_IF_MAYBE(r, resultCaps) {
builder.adoptRetainedCaps(r->finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
}
message->send(); message->send();
...@@ -1010,7 +1108,7 @@ private: ...@@ -1010,7 +1108,7 @@ private:
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
// the ID is not re-allocated before the `Finish` message can be sent. // the ID is not re-allocated before the `Finish` message can be sent.
{ {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL( auto& question = KJ_ASSERT_NONNULL(
lock->questions.find(id), "Question ID no longer on table?"); lock->questions.find(id), "Question ID no longer on table?");
if (question.paramCaps == nullptr) { if (question.paramCaps == nullptr) {
...@@ -1024,19 +1122,21 @@ private: ...@@ -1024,19 +1122,21 @@ private:
inline QuestionId getId() const { return id; } inline QuestionId getId() const { return id; }
CapExtractorImpl& getResultCapExtractor() { return resultCaps; } void setResultCapExtractor(CapExtractorImpl& extractor) {
resultCaps = extractor;
}
private: private:
const RpcConnectionState& connectionState; kj::Own<const RpcConnectionState> connectionState;
QuestionId id; QuestionId id;
CapExtractorImpl resultCaps; kj::Maybe<CapExtractorImpl&> resultCaps;
}; };
class RpcRequest final: public RequestHook { class RpcRequest final: public RequestHook {
public: public:
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize, RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize,
kj::Own<const RpcClient>&& target) kj::Own<const RpcClient>&& target)
: connectionState(connectionState), : connectionState(kj::addRef(connectionState)),
target(kj::mv(target)), target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage( message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())), firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())),
...@@ -1057,7 +1157,13 @@ private: ...@@ -1057,7 +1157,13 @@ private:
kj::Promise<kj::Own<RpcResponse>> promise = nullptr; kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
{ {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(e, lock->networkException) {
return RemotePromise<ObjectPointer>(
kj::Promise<Response<ObjectPointer>>(kj::cp(*e)),
ObjectPointer::Pipeline(newBrokenPipeline(kj::cp(*e))));
}
KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) { KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) {
// Whoops, this capability has been redirected while we were building the request! // Whoops, this capability has been redirected while we were building the request!
...@@ -1082,7 +1188,7 @@ private: ...@@ -1082,7 +1188,7 @@ private:
} else { } else {
injector->finishDescriptors(*lock); injector->finishDescriptors(*lock);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState.eventLoop); auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState->eventLoop);
auto& question = lock->questions.next(questionId); auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId); callBuilder.setQuestionId(questionId);
...@@ -1096,16 +1202,16 @@ private: ...@@ -1096,16 +1202,16 @@ private:
} }
} }
auto questionRef = kj::refcounted<QuestionRef>(connectionState, questionId); auto questionRef = kj::heap<QuestionRef>(*connectionState, questionId);
auto promiseWithQuestionRef = promise.thenInAnyThread(kj::mvCapture(kj::addRef(*questionRef), auto promiseWithQuestionRef = promise.thenInAnyThread(kj::mvCapture(questionRef,
[](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response) [](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> { -> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef)); response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response); return kj::mv(response);
})); }));
auto forkedPromise = connectionState.eventLoop.fork(kj::mv(promiseWithQuestionRef)); auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promiseWithQuestionRef));
auto appPromise = forkedPromise.addBranch().thenInAnyThread( auto appPromise = forkedPromise.addBranch().thenInAnyThread(
[](kj::Own<const RpcResponse>&& response) { [](kj::Own<const RpcResponse>&& response) {
...@@ -1114,7 +1220,7 @@ private: ...@@ -1114,7 +1220,7 @@ private:
}); });
auto pipeline = kj::refcounted<RpcPipeline>( auto pipeline = kj::refcounted<RpcPipeline>(
connectionState, questionId, kj::mv(forkedPromise)); *connectionState, questionId, kj::mv(forkedPromise));
return RemotePromise<ObjectPointer>( return RemotePromise<ObjectPointer>(
kj::mv(appPromise), kj::mv(appPromise),
...@@ -1122,7 +1228,7 @@ private: ...@@ -1122,7 +1228,7 @@ private:
} }
private: private:
const RpcConnectionState& connectionState; kj::Own<const RpcConnectionState> connectionState;
kj::Own<const RpcClient> target; kj::Own<const RpcClient> target;
kj::Own<OutgoingRpcMessage> message; kj::Own<OutgoingRpcMessage> message;
...@@ -1136,7 +1242,7 @@ private: ...@@ -1136,7 +1242,7 @@ private:
public: public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId, RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId,
kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam) kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam)
: connectionState(connectionState), : connectionState(kj::addRef(connectionState)),
redirectLater(kj::mv(redirectLaterParam)), redirectLater(kj::mv(redirectLaterParam)),
resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(), resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> { [this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
...@@ -1186,11 +1292,11 @@ private: ...@@ -1186,11 +1292,11 @@ private:
Orphanage::getForMessageContaining(descriptor), ops)); Orphanage::getForMessageContaining(descriptor), ops));
return nullptr; return nullptr;
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return connectionState.writeDescriptor( return connectionState->writeDescriptor(
lock->get<Resolved>()->getResults().getPipelinedCap(ops), lock->get<Resolved>()->getResults().getPipelinedCap(ops),
descriptor, tables); descriptor, tables);
} else { } else {
return connectionState.writeDescriptor( return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables); newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
} }
} }
...@@ -1213,7 +1319,7 @@ private: ...@@ -1213,7 +1319,7 @@ private:
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
if (lock->is<Waiting>()) { if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>( return kj::refcounted<PromisedAnswerClient>(
connectionState, kj::addRef(*this), kj::mv(ops)); *connectionState, kj::addRef(*this), kj::mv(ops));
} else if (lock->is<Resolved>()) { } else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops); return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else { } else {
...@@ -1222,7 +1328,7 @@ private: ...@@ -1222,7 +1328,7 @@ private:
} }
private: private:
const RpcConnectionState& connectionState; kj::Own<const RpcConnectionState> connectionState;
kj::Maybe<CapExtractorImpl&> capExtractor; kj::Maybe<CapExtractorImpl&> capExtractor;
kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater; kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater;
...@@ -1253,7 +1359,8 @@ private: ...@@ -1253,7 +1359,8 @@ private:
RpcResponse(const RpcConnectionState& connectionState, RpcResponse(const RpcConnectionState& connectionState,
kj::Own<IncomingRpcMessage>&& message, kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results) ObjectPointer::Reader results)
: message(kj::mv(message)), : connectionState(kj::addRef(connectionState)),
message(kj::mv(message)),
extractor(connectionState), extractor(connectionState),
context(extractor), context(extractor),
reader(context.imbue(results)) {} reader(context.imbue(results)) {}
...@@ -1266,16 +1373,18 @@ private: ...@@ -1266,16 +1373,18 @@ private:
return kj::addRef(*this); return kj::addRef(*this);
} }
void setQuestionRef(kj::Own<const QuestionRef>&& questionRef) { void setQuestionRef(kj::Own<QuestionRef>&& questionRef) {
this->questionRef = kj::mv(questionRef); this->questionRef = kj::mv(questionRef);
this->questionRef->setResultCapExtractor(extractor);
} }
private: private:
kj::Own<const RpcConnectionState> connectionState;
kj::Own<IncomingRpcMessage> message; kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor; CapExtractorImpl extractor;
CapReaderContext context; CapReaderContext context;
ObjectPointer::Reader reader; ObjectPointer::Reader reader;
kj::Own<const QuestionRef> questionRef; kj::Own<QuestionRef> questionRef;
}; };
// ===================================================================================== // =====================================================================================
...@@ -1316,9 +1425,9 @@ private: ...@@ -1316,9 +1425,9 @@ private:
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId, RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params) kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params)
: connectionState(connectionState), : connectionState(kj::addRef(connectionState)),
questionId(questionId), questionId(questionId),
request(kj::mv(request)), request(kj::mv(request)),
requestCapExtractor(connectionState), requestCapExtractor(connectionState),
...@@ -1339,7 +1448,7 @@ private: ...@@ -1339,7 +1448,7 @@ private:
} }
void sendErrorReturn(kj::Exception&& exception) { void sendErrorReturn(kj::Exception&& exception) {
if (isFirstResponder()) { if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::Exception>() + messageSizeHint<rpc::Return>() + sizeInWords<rpc::Exception>() +
exception.getDescription().size() / sizeof(word) + 1); exception.getDescription().size() / sizeof(word) + 1);
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1354,7 +1463,7 @@ private: ...@@ -1354,7 +1463,7 @@ private:
} }
void sendCancel() { void sendCancel() {
if (isFirstResponder()) { if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>()); messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1377,7 +1486,7 @@ private: ...@@ -1377,7 +1486,7 @@ private:
// Verify that we're holding the tables mutex. This is important because we're handing off // Verify that we're holding the tables mutex. This is important because we're handing off
// responsibility for deleting the answer. Moreover, the callContext pointer in the answer // responsibility for deleting the answer. Moreover, the callContext pointer in the answer
// table should not be null as this would indicate that we've already returned a result. // table should not be null as this would indicate that we've already returned a result.
KJ_DASSERT(connectionState.tables.getAlreadyLockedExclusive() KJ_DASSERT(connectionState->tables.getAlreadyLockedExclusive()
.answers[questionId].callContext != nullptr); .answers[questionId].callContext != nullptr);
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) == if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
...@@ -1401,13 +1510,13 @@ private: ...@@ -1401,13 +1510,13 @@ private:
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
return r->get()->getResults(); return r->get()->getResults();
} else { } else {
auto message = connectionState.connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() + firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr)); requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn(); returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcServerResponse>( auto response = kj::heap<RpcServerResponse>(
connectionState, kj::mv(message), returnMessage.getResults()); *connectionState, kj::mv(message), returnMessage.getResults());
auto results = response->getResults(); auto results = response->getResults();
this->response = kj::mv(response); this->response = kj::mv(response);
return results; return results;
...@@ -1433,7 +1542,7 @@ private: ...@@ -1433,7 +1542,7 @@ private:
} }
private: private:
RpcConnectionState& connectionState; kj::Own<const RpcConnectionState> connectionState;
QuestionId questionId; QuestionId questionId;
// Request --------------------------------------------- // Request ---------------------------------------------
...@@ -1486,7 +1595,7 @@ private: ...@@ -1486,7 +1595,7 @@ private:
// Extract from the answer table the promise representing the executing call. // Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr; kj::Promise<void> asyncOp = nullptr;
{ {
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp); asyncOp = kj::mv(lock->answers[questionId].asyncOp);
} }
...@@ -1515,7 +1624,7 @@ private: ...@@ -1515,7 +1624,7 @@ private:
// We need to remove the `callContext` pointer -- which points back to us -- from the // We need to remove the `callContext` pointer -- which points back to us -- from the
// answer table. Or we might even be responsible for removing the entire answer table // answer table. Or we might even be responsible for removing the entire answer table
// entry. // entry.
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) { if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) {
// We are responsible for deleting the answer table entry. Awkwardly, however, the // We are responsible for deleting the answer table entry. Awkwardly, however, the
...@@ -1524,7 +1633,7 @@ private: ...@@ -1524,7 +1633,7 @@ private:
// actual deletion asynchronously. But we have to remove it from the table *now*, while // actual deletion asynchronously. But we have to remove it from the table *now*, while
// we still hold the lock, because once we send the return message the answer ID is free // we still hold the lock, because once we send the return message the answer ID is free
// for reuse. // for reuse.
connectionState.tasks.add(connectionState.eventLoop.evalLater( connectionState->tasks.add(connectionState->eventLoop.evalLater(
kj::mvCapture(lock->answers[questionId], kj::mvCapture(lock->answers[questionId],
[](Answer<RpcCallContext>&& answer) { [](Answer<RpcCallContext>&& answer) {
// Just let the answer be deleted. // Just let the answer be deleted.
...@@ -1558,8 +1667,12 @@ private: ...@@ -1558,8 +1667,12 @@ private:
kj::Promise<void> messageLoop() { kj::Promise<void> messageLoop() {
auto receive = eventLoop.there(connection->receiveIncomingMessage(), auto receive = eventLoop.there(connection->receiveIncomingMessage(),
[this](kj::Own<IncomingRpcMessage>&& message) { [this](kj::Maybe<kj::Own<IncomingRpcMessage>>&& message) {
handleMessage(kj::mv(message)); KJ_IF_MAYBE(m, message) {
handleMessage(kj::mv(*m));
} else {
KJ_FAIL_REQUIRE("Peer disconnected.") { break; }
}
}); });
return eventLoop.there(kj::mv(receive), return eventLoop.there(kj::mv(receive),
[this]() { [this]() {
...@@ -1817,7 +1930,9 @@ private: ...@@ -1817,7 +1930,9 @@ private:
class SingleCapPipeline: public PipelineHook, public kj::Refcounted { class SingleCapPipeline: public PipelineHook, public kj::Refcounted {
public: public:
SingleCapPipeline(kj::Own<const ClientHook>&& cap): cap(kj::mv(cap)) {} SingleCapPipeline(kj::Own<const ClientHook>&& cap,
kj::Own<CapInjectorImpl>&& capInjector)
: cap(kj::mv(cap)), capInjector(kj::mv(capInjector)) {}
kj::Own<const PipelineHook> addRef() const override { kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this); return kj::addRef(*this);
...@@ -1833,6 +1948,7 @@ private: ...@@ -1833,6 +1948,7 @@ private:
private: private:
kj::Own<const ClientHook> cap; kj::Own<const ClientHook> cap;
kj::Own<CapInjectorImpl> capInjector;
}; };
void handleRestore(kj::Own<IncomingRpcMessage>&& message, const rpc::Restore::Reader& restore) { void handleRestore(kj::Own<IncomingRpcMessage>&& message, const rpc::Restore::Reader& restore) {
...@@ -1844,8 +1960,8 @@ private: ...@@ -1844,8 +1960,8 @@ private:
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn(); rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
ret.setQuestionId(questionId); ret.setQuestionId(questionId);
CapInjectorImpl injector(*this); auto injector = kj::heap<CapInjectorImpl>(*this);
CapBuilderContext context(injector); CapBuilderContext context(*injector);
kj::Own<const ClientHook> capHook; kj::Own<const ClientHook> capHook;
...@@ -1878,11 +1994,12 @@ private: ...@@ -1878,11 +1994,12 @@ private:
return; return;
} }
injector->finishDescriptors(*lock);
answer.active = true; answer.active = true;
answer.pipeline = kj::Own<const PipelineHook>( answer.pipeline = kj::Own<const PipelineHook>(
kj::refcounted<SingleCapPipeline>(kj::mv(capHook))); kj::refcounted<SingleCapPipeline>(kj::mv(capHook), kj::mv(injector)));
injector.finishDescriptors(*lock);
response->send(); response->send();
} }
} }
...@@ -1956,7 +2073,12 @@ private: ...@@ -1956,7 +2073,12 @@ private:
auto iter = lockedMap.find(connection); auto iter = lockedMap.find(connection);
if (iter == lockedMap.end()) { if (iter == lockedMap.end()) {
VatNetworkBase::Connection* connectionPtr = connection; VatNetworkBase::Connection* connectionPtr = connection;
auto newState = kj::heap<RpcConnectionState>(eventLoop, restorer, kj::mv(connection)); auto onDisconnect = kj::newPromiseAndFulfiller<void>();
tasks.add(eventLoop.there(kj::mv(onDisconnect.promise), [this,connectionPtr]() {
connections.lockExclusive()->erase(connectionPtr);
}));
auto newState = kj::refcounted<RpcConnectionState>(
eventLoop, restorer, kj::mv(connection), kj::mv(onDisconnect.fulfiller));
RpcConnectionState& result = *newState; RpcConnectionState& result = *newState;
lockedMap.insert(std::make_pair(connectionPtr, kj::mv(newState))); lockedMap.insert(std::make_pair(connectionPtr, kj::mv(newState)));
return result; return result;
......
...@@ -31,7 +31,7 @@ namespace capnp { ...@@ -31,7 +31,7 @@ namespace capnp {
// ======================================================================================= // =======================================================================================
// *************************************************************************************** // ***************************************************************************************
// This section contains various internal stuff that needs to be declared upfront. // This section contains various internal stuff that needs to be declared upfront.
// Scroll down to `class EventLoop` or `class Promise` for the public interfaces. // Scroll down to `class VatNetwork` or `class RpcSystem` for the public interfaces.
// *************************************************************************************** // ***************************************************************************************
// ======================================================================================= // =======================================================================================
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
class Connection { class Connection {
public: public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0; virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0;
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual void baseIntroduceTo(Connection& recipient, virtual void baseIntroduceTo(Connection& recipient,
ObjectPointer::Builder sendToRecipient, ObjectPointer::Builder sendToRecipient,
ObjectPointer::Builder sendToTarget) = 0; ObjectPointer::Builder sendToTarget) = 0;
...@@ -163,10 +163,9 @@ public: ...@@ -163,10 +163,9 @@ public:
// //
// Notice that this may be called from any thread. // Notice that this may be called from any thread.
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the connection fails before a message // Wait for a message to be received and return it. If the read stream cleanly terminates,
// is received, the promise will be broken -- this is the only way to tell if a connection has // return null. If any other problem occurs, throw an exception.
// died.
// Level 3 features ---------------------------------------------- // Level 3 features ----------------------------------------------
......
...@@ -30,10 +30,12 @@ namespace { ...@@ -30,10 +30,12 @@ namespace {
class AsyncMessageReader: public MessageReader { class AsyncMessageReader: public MessageReader {
public: public:
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {} inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {
memset(firstWord, 0, sizeof(firstWord));
}
~AsyncMessageReader() noexcept(false) {} ~AsyncMessageReader() noexcept(false) {}
kj::Promise<void> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
...@@ -56,68 +58,93 @@ private: ...@@ -56,68 +58,93 @@ private:
inline uint segmentCount() { return firstWord[0].get() + 1; } inline uint segmentCount() { return firstWord[0].get() + 1; }
inline uint segment0Size() { return firstWord[1].get(); } inline uint segment0Size() { return firstWord[1].get(); }
kj::Promise<void> readAfterFirstWord(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<void> readSegments(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
}; };
kj::Promise<void> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) { kj::ArrayPtr<word> scratchSpace) {
return inputStream.read(firstWord, sizeof(firstWord)) return inputStream.tryRead(firstWord, sizeof(firstWord), sizeof(firstWord))
.then([this,&inputStream]() -> kj::Promise<void> { .then([this,&inputStream,scratchSpace](size_t n) mutable -> kj::Promise<bool> {
if (segmentCount() == 0) { if (n == 0) {
firstWord[1].set(0); return false;
} else if (n < sizeof(firstWord)) {
// EOF in first word.
KJ_FAIL_REQUIRE("Premature EOF.") {
return false;
}
} }
// Reject messages with too many segments for security reasons. return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; });
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") { });
return kj::READY_NOW; // exception will be propagated }
}
if (segmentCount() > 1) { kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
// Read sizes for all segments except the first. Include padding if necessary. kj::ArrayPtr<word> scratchSpace) {
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1); if (segmentCount() == 0) {
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0])); firstWord[1].set(0);
} else { }
return kj::READY_NOW;
}
}).then([this,&inputStream,scratchSpace]() mutable -> kj::Promise<void> {
size_t totalWords = segment0Size();
if (segmentCount() > 1) { // Reject messages with too many segments for security reasons.
for (uint i = 0; i < segmentCount() - 1; i++) { KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
totalWords += moreSizes[i].get(); return kj::READY_NOW; // exception will be propagated
} }
}
// Don't accept a message which the receiver couldn't possibly traverse without hitting the if (segmentCount() > 1) {
// traversal limit. Without this check, a malicious client could transmit a very large segment // Read sizes for all segments except the first. Include padding if necessary.
// size to make the receiver allocate excessive space and possibly crash. moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords, return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]))
"Message is too large. To increase the limit on the receiving end, see " .then([this,&inputStream,scratchSpace]() mutable {
"capnp::ReaderOptions.") { return readSegments(inputStream, scratchSpace);
return kj::READY_NOW; // exception will be propagated });
} } else {
return readSegments(inputStream, scratchSpace);
}
}
if (scratchSpace.size() < totalWords) { kj::Promise<void> AsyncMessageReader::readSegments(kj::AsyncInputStream& inputStream,
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory kj::ArrayPtr<word> scratchSpace) {
// fragmentation. size_t totalWords = segment0Size();
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace; if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
} }
}
segmentStarts = kj::heapArray<const word*>(segmentCount()); // Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
}
segmentStarts[0] = scratchSpace.begin(); segmentStarts = kj::heapArray<const word*>(segmentCount());
if (segmentCount() > 1) { segmentStarts[0] = scratchSpace.begin();
size_t offset = segment0Size();
for (uint i = 1; i < segmentCount(); i++) { if (segmentCount() > 1) {
segmentStarts[i] = scratchSpace.begin() + offset; size_t offset = segment0Size();
offset += moreSizes[i-1].get();
} for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
} }
}
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word)); return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
});
} }
...@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader) { return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) {
KJ_REQUIRE(success, "Premature EOF.") { break; }
return kj::mv(reader); return kj::mv(reader);
})); }));
} }
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader,
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> {
if (success) {
return kj::mv(reader);
} else {
return nullptr;
}
}));
}
// ======================================================================================= // =======================================================================================
namespace { namespace {
......
...@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
// //
// `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed. // `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed.
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(),
kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT; KJ_WARN_UNUSED_RESULT;
......
...@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) { ...@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) {
void registerSigusr1() { void registerSigusr1() {
registerSignalHandler(SIGUSR1); registerSignalHandler(SIGUSR1);
// We also disable SIGPIPE because users of UnixEventLoop almost certainly don't want it.
signal(SIGPIPE, SIG_IGN);
} }
pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT; pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT;
......
...@@ -25,4 +25,6 @@ ...@@ -25,4 +25,6 @@
namespace kj { namespace kj {
const NullDisposer NullDisposer::instance = NullDisposer();
} // namespace kj } // namespace kj
...@@ -81,6 +81,15 @@ public: ...@@ -81,6 +81,15 @@ public:
template <typename T> template <typename T>
const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>(); const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>();
class NullDisposer: public Disposer {
// A disposer that does nothing.
public:
static const NullDisposer instance;
void disposeImpl(void* pointer) const override {}
};
// ======================================================================================= // =======================================================================================
// Own<T> -- An owned pointer. // Own<T> -- An owned pointer.
......
...@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) { ...@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) {
EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); })); EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); }));
} }
TEST(Mutex, LazyException) {
Lazy<uint> lazy;
auto exception = kj::runCatchingExceptions([&]() {
lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
KJ_FAIL_ASSERT("foo") { break; }
return space.construct(123);
});
});
EXPECT_TRUE(exception != nullptr);
uint i = lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
return space.construct(456);
});
EXPECT_EQ(456, i);
}
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) { ...@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) {
} }
void Once::runOnce(Initializer& init) { void Once::runOnce(Initializer& init) {
startOver:
uint state = UNINITIALIZED; uint state = UNINITIALIZED;
if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false, if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false,
__ATOMIC_RELAXED, __ATOMIC_RELAXED)) { __ATOMIC_RELAXED, __ATOMIC_RELAXED)) {
// It's our job to initialize! // It's our job to initialize!
init.run(); {
KJ_ON_SCOPE_FAILURE({
// An exception was thrown by the initializer. We have to revert.
if (__atomic_exchange_n(&futex, UNINITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0);
}
});
init.run();
}
if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) == if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) { INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish. // Someone was waiting for us to finish.
...@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) { ...@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) {
// Wait for initialization. // Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0); syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
if (state == UNINITIALIZED) {
// Oh hey, apparently whoever was trying to initialize gave up. Let's take it from the
// top.
goto startOver;
}
} }
} }
} }
...@@ -209,7 +227,7 @@ void Once::disable() noexcept { ...@@ -209,7 +227,7 @@ void Once::disable() noexcept {
// Wait for initialization. // Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0); syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
break; continue;
} }
} }
} }
......
...@@ -28,22 +28,26 @@ ...@@ -28,22 +28,26 @@
namespace kj { namespace kj {
Thread::Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg) { Thread::Thread(Function<void()> func): func(kj::mv(func)) {
static_assert(sizeof(threadId) >= sizeof(pthread_t), static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port."); "pthread_t is larger than a long long on your platform. Please port.");
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId), nullptr, run, arg); int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId),
nullptr, &runThread, this);
if (pthreadResult != 0) { if (pthreadResult != 0) {
deleteArg(arg);
KJ_FAIL_SYSCALL("pthread_create", pthreadResult); KJ_FAIL_SYSCALL("pthread_create", pthreadResult);
} }
} }
Thread::~Thread() { Thread::~Thread() noexcept(false) {
int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr); int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr);
if (pthreadResult != 0) { if (pthreadResult != 0) {
KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; } KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; }
} }
KJ_IF_MAYBE(e, exception) {
kj::throwRecoverableException(kj::mv(*e));
}
} }
void Thread::sendSignal(int signo) { void Thread::sendSignal(int signo) {
...@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) { ...@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) {
} }
} }
void* Thread::runThread(void* ptr) {
Thread* thread = reinterpret_cast<Thread*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
thread->func();
})) {
thread->exception = kj::mv(*exception);
}
return nullptr;
}
} // namespace kj } // namespace kj
...@@ -25,43 +25,30 @@ ...@@ -25,43 +25,30 @@
#define KJ_THREAD_H_ #define KJ_THREAD_H_
#include "common.h" #include "common.h"
#include "function.h"
#include "exception.h"
namespace kj { namespace kj {
class Thread { class Thread {
// A thread! Pass a lambda to the constructor. The destructor joins the thread. // A thread! Pass a lambda to the constructor, and it runs in the thread. The destructor joins
// the thread. If the function throws an exception, it is rethrown from the thread's destructor
// (if not unwinding from another exception).
public: public:
template <typename Func> explicit Thread(Function<void()> func);
explicit Thread(Func&& func)
: Thread(&runThread<Decay<Func>>,
&deleteArg<Decay<Func>>,
new Decay<Func>(kj::fwd<Func>(func))) {}
~Thread(); ~Thread() noexcept(false);
void sendSignal(int signo); void sendSignal(int signo);
// Send a Unix signal to the given thread, using pthread_kill or an equivalent. // Send a Unix signal to the given thread, using pthread_kill or an equivalent.
private: private:
Function<void()> func;
unsigned long long threadId; // actually pthread_t unsigned long long threadId; // actually pthread_t
kj::Maybe<kj::Exception> exception;
Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg); static void* runThread(void* ptr);
template <typename Func>
static void* runThread(void* ptr) {
// TODO(someday): Catch exceptions and propagate to the joiner.
Func* func = reinterpret_cast<Func*>(ptr);
KJ_DEFER(delete func);
(*func)();
return nullptr;
}
template <typename Func>
static void deleteArg(void* ptr) {
delete reinterpret_cast<Func*>(ptr);
}
}; };
} // namespace kj } // namespace kj
......
linux-gcc-4.7 1735 ./super-test.sh tmpdir capnp-gcc-4.7 quick linux-gcc-4.7 1737 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1738 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8 linux-gcc-4.8 1740 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1758 ./super-test.sh tmpdir capnp-clang quick clang linux-clang 1760 ./super-test.sh tmpdir capnp-clang quick clang
mac 807 ./super-test.sh remote beat caffeinate quick mac 805 ./super-test.sh remote beat caffeinate quick
cygwin 810 ./super-test.sh remote Kenton@flashman quick cygwin 810 ./super-test.sh remote Kenton@flashman quick
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment