Commit 2aa37a8a authored by Kenton Varda's avatar Kenton Varda

Make RPC system handle disconnect.

parent 38bbbd32
......@@ -215,7 +215,7 @@ private:
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
......@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp
......@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// =======================================================================================
// inline implementation details
......
......@@ -105,7 +105,7 @@ public:
if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message));
} else {
lock->fulfillers.front()->fulfill(kj::mv(message));
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop();
}
}
......@@ -119,16 +119,16 @@ public:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override {
return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
}
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override {
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override {
auto lock = queues.lockExclusive();
if (lock->messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<IncomingRpcMessage>>();
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
} else {
auto result = kj::mv(lock->messages.front());
lock->messages.pop();
return kj::mv(result);
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
}
}
void introduceTo(Connection& recipient,
......@@ -149,7 +149,7 @@ public:
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<IncomingRpcMessage>>>> fulfillers;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
......
......@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) {
// Start up server in another thread.
auto quitter = kj::newPromiseAndFulfiller<void>();
kj::Thread thread([&]() {
auto thread = kj::heap<kj::Thread>([&]() {
runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount);
});
KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread.
......@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) {
TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network, loop);
// Request the particular capability from the server.
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
bool disconnected = false;
bool drained = false;
kj::Promise<void> disconnectPromise = loop.there(network.onDisconnect(),
[&]() { disconnected = true; });
kj::Promise<void> drainedPromise = loop.there(network.onDrained(),
[&]() { drained = true; });
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
{
// Request the particular capability from the server.
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
auto promise = request.send();
{
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto promise = request.send();
auto pipelineRequest2 = promise.getOutBox().getCap().castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
promise = nullptr; // Just to be annoying, drop the original promise.
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
promise = nullptr; // Just to be annoying, drop the original promise.
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
EXPECT_FALSE(disconnected);
EXPECT_FALSE(drained);
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
// What if the other side disconnects?
quitter.fulfiller->fulfill();
thread = nullptr;
loop.wait(kj::mv(disconnectPromise));
EXPECT_FALSE(drained);
{
// Use the now-broken capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise)));
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise2)));
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
EXPECT_FALSE(drained);
}
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
loop.wait(kj::mv(drainedPromise));
}
} // namespace
......
......@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(
const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions)
: eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions),
previousWrite(kj::READY_NOW) {}
previousWrite(kj::READY_NOW) {
{
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = eventLoop.fork(kj::mv(paf.promise));
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = eventLoop.fork(kj::mv(paf.promise));
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
}
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost(
rpc::twoparty::SturdyRefHostId::Reader ref) {
if (ref.getSide() == side) {
return nullptr;
} else {
return kj::Own<TwoPartyVatNetworkBase::Connection>(this,
kj::DestructorOnlyDisposer<TwoPartyVatNetworkBase::Connection>::instance);
return kj::Own<TwoPartyVatNetworkBase::Connection>(this, drainedFulfiller);
}
}
......@@ -70,13 +80,21 @@ public:
void send() override {
auto lock = network.previousWrite.lockExclusive();
*lock = network.eventLoop.there(network.eventLoop.there(kj::mv(*lock), [this]() {
return writeMessage(network.stream, message);
}), kj::mvCapture(kj::addRef(*this),
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
*lock = network.eventLoop.there(kj::mv(*lock),
kj::mvCapture(kj::addRef(*this), [&](kj::Own<OutgoingMessageImpl>&& self) {
return writeMessage(network.stream, message)
.then(kj::mvCapture(kj::mv(self),
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// Just here to hold a reference to `self` until the write completes.
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
}), [&](kj::Exception&& exception) -> kj::Promise<void> {
// Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill();
return kj::READY_NOW;
});
}));
}
......@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(
return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
}
kj::Promise<kj::Own<IncomingRpcMessage>> TwoPartyVatNetwork::receiveIncomingMessage() {
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return eventLoop.evalLater([&]() {
return readMessage(stream, receiveOptions)
.then([](kj::Own<MessageReader>&& message) -> kj::Own<IncomingRpcMessage> {
return kj::heap<IncomingMessageImpl>(kj::mv(message));
return tryReadMessage(stream, receiveOptions)
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
disconnectFulfiller.lockExclusive()->get()->fulfill();
return nullptr;
}
}, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
});
});
}
......
......@@ -41,6 +41,15 @@ public:
TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
// Returns a promise that resolves when the peer disconnects.
kj::Promise<void> onDrained() { return drainedPromise.addBranch(); }
// Returns a promise that resolves once the peer has disconnected *and* all local objects
// referencing this connection have been destroyed. A caller might use this to decide when it
// is safe to destroy the RpcSystem, if it isn't able to reliably destroy all objects using it
// directly.
// implements VatNetwork -----------------------------------------------------
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost(
......@@ -64,10 +73,22 @@ private:
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
// second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer {
public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); }
};
FulfillerDisposer drainedFulfiller;
// implements Connection -----------------------------------------------------
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override;
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override;
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
void introduceTo(TwoPartyVatNetworkBase::Connection& recipient,
rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient,
rpc::twoparty::RecipientId::Builder sendToTarget) override;
......
......@@ -167,6 +167,15 @@ public:
}
}
template <typename Func>
void forEach(Func&& func) {
for (Id i = 0; i < slots.size(); i++) {
if (slots[i] != nullptr) {
func(i, slots[i]);
}
}
}
private:
kj::Vector<T> slots;
std::priority_queue<Id, std::vector<Id>, std::greater<Id>> freeIds;
......@@ -206,6 +215,16 @@ public:
}
}
template <typename Func>
void forEach(Func&& func) {
for (Id i: kj::indices(low)) {
func(i, low[i]);
}
for (auto& entry: high) {
func(entry.first, entry.second);
}
}
private:
T low[16];
std::unordered_map<Id, T> high;
......@@ -268,14 +287,16 @@ struct Import {
// =======================================================================================
class RpcConnectionState final: public kj::TaskSet::ErrorHandler {
class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Refcounted {
class PromisedAnswerClient;
public:
RpcConnectionState(const kj::EventLoop& eventLoop,
kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connection)
kj::Own<VatNetworkBase::Connection>&& connection,
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller)
: eventLoop(eventLoop), restorer(restorer), connection(kj::mv(connection)),
disconnectFulfiller(kj::mv(disconnectFulfiller)),
tasks(eventLoop, *this), exportDisposer(*this) {
tasks.add(messageLoop());
}
......@@ -305,11 +326,11 @@ public:
message->send();
}
auto questionRef = kj::refcounted<QuestionRef>(*this, questionId);
auto questionRef = kj::heap<QuestionRef>(*this, questionId);
auto promiseWithQuestionRef = eventLoop.there(kj::mv(paf.promise),
kj::mvCapture(kj::addRef(*questionRef),
[](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
kj::mvCapture(questionRef,
[](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response);
......@@ -322,21 +343,68 @@ public:
}
void taskFailed(kj::Exception&& exception) override {
// TODO(now): Kill the connection.
// - All present and future questions must complete with exceptions.
// - All answers should be canceled (if they allow cancellation).
// - All exports are dropped.
// - All imported promises resolve to exceptions.
// - Send abort message.
// - Remove from connection map.
{
kj::Exception networkException(
kj::Exception::Nature::NETWORK_FAILURE, kj::Exception::Durability::PERMANENT,
"", 0, kj::str("Disconnected: ", exception.getDescription()));
kj::throwRecoverableException(kj::mv(exception));
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
auto lock = tables.lockExclusive();
// All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id,
Question<CapInjectorImpl, RpcPipeline, RpcResponse>& question) {
question.fulfiller->reject(kj::cp(networkException));
KJ_IF_MAYBE(pc, question.paramCaps) {
paramCapsToRelease.add(kj::mv(*pc));
}
});
lock->answers.forEach([&](QuestionId id, Answer<RpcCallContext>& answer) {
KJ_IF_MAYBE(p, answer.pipeline) {
pipelinesToRelease.add(kj::mv(*p));
}
KJ_IF_MAYBE(context, answer.callContext) {
context->requestCancel();
}
});
lock->exports.forEach([&](ExportId id, Export& exp) {
clientsToRelease.add(kj::mv(exp.clientHook));
exp = Export();
});
lock->imports.forEach([&](ExportId id, Import<ImportClient>& import) {
if (import.client != nullptr) {
import.client->disconnect(kj::cp(networkException));
}
});
lock->networkException = kj::mv(networkException);
}
{
// Send an abort message.
auto message = connection->newOutgoingMessage(
messageSizeHint<rpc::Exception>() +
(exception.getDescription().size() + 7) / sizeof(word));
fromException(exception, message->getBody().getAs<rpc::Message>().initAbort());
message->send();
}
// Indicate disconnect.
disconnectFulfiller->fulfill();
}
private:
const kj::EventLoop& eventLoop;
kj::Maybe<SturdyRefRestorerBase&> restorer;
kj::Own<VatNetworkBase::Connection> connection;
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
class ImportClient;
class CapInjectorImpl;
......@@ -353,6 +421,10 @@ private:
std::unordered_map<const ClientHook*, ExportId> exportsByCap;
// Maps already-exported ClientHook objects to their ID in the export table.
kj::Maybe<kj::Exception> networkException;
// If the connection has failed, this is the exception describing the failure. All future
// calls should throw this exception.
};
kj::MutexGuarded<Tables> tables;
......@@ -391,7 +463,7 @@ private:
class RpcClient: public ClientHook, public kj::Refcounted {
public:
RpcClient(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
: connectionState(kj::addRef(connectionState)) {}
virtual kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const = 0;
......@@ -417,6 +489,11 @@ private:
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
// Implement call() by copying params and results messages.
// We can and should propagate cancellation.
context->allowAsyncCancellation();
auto params = context->getParams();
size_t sizeHint = params.targetSizeInWords();
......@@ -466,7 +543,7 @@ private:
}
protected:
const RpcConnectionState& connectionState;
kj::Own<const RpcConnectionState> connectionState;
};
class ImportClient: public RpcClient {
......@@ -481,7 +558,7 @@ private:
// that another thread attempted to obtain this import just as the destructor started, in
// which case that other thread will have constructed a new ImportClient and placed it in
// the import table.)
auto lock = connectionState.tables.lockExclusive();
auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports.find(importId)) {
if (import->client == this) {
lock->imports.erase(importId);
......@@ -491,7 +568,7 @@ private:
// Send a message releasing our remote references.
if (remoteRefcount > 0) {
connectionState.sendReleaseLater(importId, remoteRefcount);
connectionState->sendReleaseLater(importId, remoteRefcount);
}
}
......@@ -499,6 +576,9 @@ private:
// Replace the PromiseImportClient with its resolution. Returns false if this is not a promise
// (i.e. it is a SettledImportClient).
virtual void disconnect(kj::Exception&& exception) = 0;
// Cause whenMoreResolved() to fail.
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should
......@@ -528,7 +608,8 @@ private:
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize, kj::addRef(*this));
auto request = kj::heap<RpcRequest>(
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.getTarget().setExportedCap(importId);
......@@ -555,6 +636,10 @@ private:
return false;
}
void disconnect(kj::Exception&& exception) override {
// nothing
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
......@@ -575,6 +660,10 @@ private:
return true;
}
void disconnect(kj::Exception&& exception) override {
fulfiller->reject(kj::mv(exception));
}
// TODO(now): Override writeDescriptor() and writeTarget() to redirect once the promise
// resolves.
......@@ -616,10 +705,11 @@ private:
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeDescriptor(descriptor, tables, ops);
} else if (lock->is<Resolved>()) {
return connectionState.writeDescriptor(lock->get<Resolved>()->addRef(), descriptor, tables);
return connectionState->writeDescriptor(
lock->get<Resolved>()->addRef(), descriptor, tables);
} else {
return connectionState.writeDescriptor(newBrokenCap(kj::cp(lock->get<Broken>())),
descriptor, tables);
return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
}
}
......@@ -630,7 +720,7 @@ private:
if (lock->is<Waiting>()) {
return lock->get<Waiting>()->writeTarget(target, ops);
} else if (lock->is<Resolved>()) {
return connectionState.writeTarget(*lock->get<Resolved>(), target);
return connectionState->writeTarget(*lock->get<Resolved>(), target);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
......@@ -644,7 +734,7 @@ private:
if (lock->is<Waiting>()) {
auto request = kj::heap<RpcRequest>(
connectionState, firstSegmentWordSize, kj::addRef(*this));
*connectionState, firstSegmentWordSize, kj::addRef(*this));
auto callBuilder = request->getCall();
callBuilder.setInterfaceId(interfaceId);
......@@ -900,11 +990,14 @@ private:
kj::Vector<kj::Own<const ClientHook>> clientsToRelease(exports.size());
auto lock = connectionState.tables.lockExclusive();
for (auto exportId: exports) {
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook));
lock->exports.erase(exportId);
if (lock->networkException == nullptr) {
for (auto exportId: exports) {
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId));
if (--exp.refcount == 0) {
clientsToRelease.add(kj::mv(exp.clientHook));
lock->exports.erase(exportId);
}
}
}
}
......@@ -990,19 +1083,24 @@ private:
// =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations
class QuestionRef: public kj::Refcounted {
class QuestionRef {
public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id)
: connectionState(connectionState), id(id), resultCaps(connectionState) {}
: connectionState(kj::addRef(connectionState)), id(id) {}
~QuestionRef() {
// Send the "Finish" message.
auto message = connectionState.connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true));
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() +
resultCaps.map([](CapExtractorImpl& ce) { return ce.retainedListSizeHint(true); })
.orDefault(0));
auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id);
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
KJ_IF_MAYBE(r, resultCaps) {
builder.adoptRetainedCaps(r->finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
}
message->send();
......@@ -1010,7 +1108,7 @@ private:
// Remove question ID from the table. Must do this *after* sending `Finish` to ensure that
// the ID is not re-allocated before the `Finish` message can be sent.
{
auto lock = connectionState.tables.lockExclusive();
auto lock = connectionState->tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL(
lock->questions.find(id), "Question ID no longer on table?");
if (question.paramCaps == nullptr) {
......@@ -1024,19 +1122,21 @@ private:
inline QuestionId getId() const { return id; }
CapExtractorImpl& getResultCapExtractor() { return resultCaps; }
void setResultCapExtractor(CapExtractorImpl& extractor) {
resultCaps = extractor;
}
private:
const RpcConnectionState& connectionState;
kj::Own<const RpcConnectionState> connectionState;
QuestionId id;
CapExtractorImpl resultCaps;
kj::Maybe<CapExtractorImpl&> resultCaps;
};
class RpcRequest final: public RequestHook {
public:
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize,
kj::Own<const RpcClient>&& target)
: connectionState(connectionState),
: connectionState(kj::addRef(connectionState)),
target(kj::mv(target)),
message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())),
......@@ -1057,7 +1157,13 @@ private:
kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
{
auto lock = connectionState.tables.lockExclusive();
auto lock = connectionState->tables.lockExclusive();
KJ_IF_MAYBE(e, lock->networkException) {
return RemotePromise<ObjectPointer>(
kj::Promise<Response<ObjectPointer>>(kj::cp(*e)),
ObjectPointer::Pipeline(newBrokenPipeline(kj::cp(*e))));
}
KJ_IF_MAYBE(redirect, target->writeTarget(callBuilder.getTarget())) {
// Whoops, this capability has been redirected while we were building the request!
......@@ -1082,7 +1188,7 @@ private:
} else {
injector->finishDescriptors(*lock);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState.eventLoop);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState->eventLoop);
auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId);
......@@ -1096,16 +1202,16 @@ private:
}
}
auto questionRef = kj::refcounted<QuestionRef>(connectionState, questionId);
auto questionRef = kj::heap<QuestionRef>(*connectionState, questionId);
auto promiseWithQuestionRef = promise.thenInAnyThread(kj::mvCapture(kj::addRef(*questionRef),
[](kj::Own<const QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
auto promiseWithQuestionRef = promise.thenInAnyThread(kj::mvCapture(questionRef,
[](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response);
}));
auto forkedPromise = connectionState.eventLoop.fork(kj::mv(promiseWithQuestionRef));
auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promiseWithQuestionRef));
auto appPromise = forkedPromise.addBranch().thenInAnyThread(
[](kj::Own<const RpcResponse>&& response) {
......@@ -1114,7 +1220,7 @@ private:
});
auto pipeline = kj::refcounted<RpcPipeline>(
connectionState, questionId, kj::mv(forkedPromise));
*connectionState, questionId, kj::mv(forkedPromise));
return RemotePromise<ObjectPointer>(
kj::mv(appPromise),
......@@ -1122,7 +1228,7 @@ private:
}
private:
const RpcConnectionState& connectionState;
kj::Own<const RpcConnectionState> connectionState;
kj::Own<const RpcClient> target;
kj::Own<OutgoingRpcMessage> message;
......@@ -1136,7 +1242,7 @@ private:
public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId,
kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam)
: connectionState(connectionState),
: connectionState(kj::addRef(connectionState)),
redirectLater(kj::mv(redirectLaterParam)),
resolveSelfPromise(connectionState.eventLoop.there(redirectLater.addBranch(),
[this](kj::Own<const RpcResponse>&& response) -> kj::Promise<void> {
......@@ -1186,11 +1292,11 @@ private:
Orphanage::getForMessageContaining(descriptor), ops));
return nullptr;
} else if (lock->is<Resolved>()) {
return connectionState.writeDescriptor(
return connectionState->writeDescriptor(
lock->get<Resolved>()->getResults().getPipelinedCap(ops),
descriptor, tables);
} else {
return connectionState.writeDescriptor(
return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
}
}
......@@ -1213,7 +1319,7 @@ private:
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
return kj::refcounted<PromisedAnswerClient>(
connectionState, kj::addRef(*this), kj::mv(ops));
*connectionState, kj::addRef(*this), kj::mv(ops));
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
......@@ -1222,7 +1328,7 @@ private:
}
private:
const RpcConnectionState& connectionState;
kj::Own<const RpcConnectionState> connectionState;
kj::Maybe<CapExtractorImpl&> capExtractor;
kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater;
......@@ -1253,7 +1359,8 @@ private:
RpcResponse(const RpcConnectionState& connectionState,
kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results)
: message(kj::mv(message)),
: connectionState(kj::addRef(connectionState)),
message(kj::mv(message)),
extractor(connectionState),
context(extractor),
reader(context.imbue(results)) {}
......@@ -1266,16 +1373,18 @@ private:
return kj::addRef(*this);
}
void setQuestionRef(kj::Own<const QuestionRef>&& questionRef) {
void setQuestionRef(kj::Own<QuestionRef>&& questionRef) {
this->questionRef = kj::mv(questionRef);
this->questionRef->setResultCapExtractor(extractor);
}
private:
kj::Own<const RpcConnectionState> connectionState;
kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor;
CapReaderContext context;
ObjectPointer::Reader reader;
kj::Own<const QuestionRef> questionRef;
kj::Own<QuestionRef> questionRef;
};
// =====================================================================================
......@@ -1316,9 +1425,9 @@ private:
class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public:
RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId,
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params)
: connectionState(connectionState),
: connectionState(kj::addRef(connectionState)),
questionId(questionId),
request(kj::mv(request)),
requestCapExtractor(connectionState),
......@@ -1339,7 +1448,7 @@ private:
}
void sendErrorReturn(kj::Exception&& exception) {
if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage(
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::Exception>() +
exception.getDescription().size() / sizeof(word) + 1);
auto builder = message->getBody().initAs<rpc::Message>().initReturn();
......@@ -1354,7 +1463,7 @@ private:
}
void sendCancel() {
if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage(
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn();
......@@ -1377,7 +1486,7 @@ private:
// Verify that we're holding the tables mutex. This is important because we're handing off
// responsibility for deleting the answer. Moreover, the callContext pointer in the answer
// table should not be null as this would indicate that we've already returned a result.
KJ_DASSERT(connectionState.tables.getAlreadyLockedExclusive()
KJ_DASSERT(connectionState->tables.getAlreadyLockedExclusive()
.answers[questionId].callContext != nullptr);
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
......@@ -1401,13 +1510,13 @@ private:
KJ_IF_MAYBE(r, response) {
return r->get()->getResults();
} else {
auto message = connectionState.connection->newOutgoingMessage(
auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcServerResponse>(
connectionState, kj::mv(message), returnMessage.getResults());
*connectionState, kj::mv(message), returnMessage.getResults());
auto results = response->getResults();
this->response = kj::mv(response);
return results;
......@@ -1433,7 +1542,7 @@ private:
}
private:
RpcConnectionState& connectionState;
kj::Own<const RpcConnectionState> connectionState;
QuestionId questionId;
// Request ---------------------------------------------
......@@ -1486,7 +1595,7 @@ private:
// Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr;
{
auto lock = connectionState.tables.lockExclusive();
auto lock = connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp);
}
......@@ -1515,7 +1624,7 @@ private:
// We need to remove the `callContext` pointer -- which points back to us -- from the
// answer table. Or we might even be responsible for removing the entire answer table
// entry.
auto lock = connectionState.tables.lockExclusive();
auto lock = connectionState->tables.lockExclusive();
if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) {
// We are responsible for deleting the answer table entry. Awkwardly, however, the
......@@ -1524,7 +1633,7 @@ private:
// actual deletion asynchronously. But we have to remove it from the table *now*, while
// we still hold the lock, because once we send the return message the answer ID is free
// for reuse.
connectionState.tasks.add(connectionState.eventLoop.evalLater(
connectionState->tasks.add(connectionState->eventLoop.evalLater(
kj::mvCapture(lock->answers[questionId],
[](Answer<RpcCallContext>&& answer) {
// Just let the answer be deleted.
......@@ -1558,8 +1667,12 @@ private:
kj::Promise<void> messageLoop() {
auto receive = eventLoop.there(connection->receiveIncomingMessage(),
[this](kj::Own<IncomingRpcMessage>&& message) {
handleMessage(kj::mv(message));
[this](kj::Maybe<kj::Own<IncomingRpcMessage>>&& message) {
KJ_IF_MAYBE(m, message) {
handleMessage(kj::mv(*m));
} else {
KJ_FAIL_REQUIRE("Peer disconnected.") { break; }
}
});
return eventLoop.there(kj::mv(receive),
[this]() {
......@@ -1817,7 +1930,9 @@ private:
class SingleCapPipeline: public PipelineHook, public kj::Refcounted {
public:
SingleCapPipeline(kj::Own<const ClientHook>&& cap): cap(kj::mv(cap)) {}
SingleCapPipeline(kj::Own<const ClientHook>&& cap,
kj::Own<CapInjectorImpl>&& capInjector)
: cap(kj::mv(cap)), capInjector(kj::mv(capInjector)) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
......@@ -1833,6 +1948,7 @@ private:
private:
kj::Own<const ClientHook> cap;
kj::Own<CapInjectorImpl> capInjector;
};
void handleRestore(kj::Own<IncomingRpcMessage>&& message, const rpc::Restore::Reader& restore) {
......@@ -1844,8 +1960,8 @@ private:
rpc::Return::Builder ret = response->getBody().getAs<rpc::Message>().initReturn();
ret.setQuestionId(questionId);
CapInjectorImpl injector(*this);
CapBuilderContext context(injector);
auto injector = kj::heap<CapInjectorImpl>(*this);
CapBuilderContext context(*injector);
kj::Own<const ClientHook> capHook;
......@@ -1878,11 +1994,12 @@ private:
return;
}
injector->finishDescriptors(*lock);
answer.active = true;
answer.pipeline = kj::Own<const PipelineHook>(
kj::refcounted<SingleCapPipeline>(kj::mv(capHook)));
kj::refcounted<SingleCapPipeline>(kj::mv(capHook), kj::mv(injector)));
injector.finishDescriptors(*lock);
response->send();
}
}
......@@ -1956,7 +2073,12 @@ private:
auto iter = lockedMap.find(connection);
if (iter == lockedMap.end()) {
VatNetworkBase::Connection* connectionPtr = connection;
auto newState = kj::heap<RpcConnectionState>(eventLoop, restorer, kj::mv(connection));
auto onDisconnect = kj::newPromiseAndFulfiller<void>();
tasks.add(eventLoop.there(kj::mv(onDisconnect.promise), [this,connectionPtr]() {
connections.lockExclusive()->erase(connectionPtr);
}));
auto newState = kj::refcounted<RpcConnectionState>(
eventLoop, restorer, kj::mv(connection), kj::mv(onDisconnect.fulfiller));
RpcConnectionState& result = *newState;
lockedMap.insert(std::make_pair(connectionPtr, kj::mv(newState)));
return result;
......
......@@ -31,7 +31,7 @@ namespace capnp {
// =======================================================================================
// ***************************************************************************************
// This section contains various internal stuff that needs to be declared upfront.
// Scroll down to `class EventLoop` or `class Promise` for the public interfaces.
// Scroll down to `class VatNetwork` or `class RpcSystem` for the public interfaces.
// ***************************************************************************************
// =======================================================================================
......@@ -60,7 +60,7 @@ public:
class Connection {
public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0;
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0;
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual void baseIntroduceTo(Connection& recipient,
ObjectPointer::Builder sendToRecipient,
ObjectPointer::Builder sendToTarget) = 0;
......@@ -163,10 +163,9 @@ public:
//
// Notice that this may be called from any thread.
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the connection fails before a message
// is received, the promise will be broken -- this is the only way to tell if a connection has
// died.
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the read stream cleanly terminates,
// return null. If any other problem occurs, throw an exception.
// Level 3 features ----------------------------------------------
......
......@@ -30,10 +30,12 @@ namespace {
class AsyncMessageReader: public MessageReader {
public:
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {}
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {
memset(firstWord, 0, sizeof(firstWord));
}
~AsyncMessageReader() noexcept(false) {}
kj::Promise<void> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ----------------------------------------
......@@ -56,68 +58,93 @@ private:
inline uint segmentCount() { return firstWord[0].get() + 1; }
inline uint segment0Size() { return firstWord[1].get(); }
kj::Promise<void> readAfterFirstWord(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<void> readSegments(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
};
kj::Promise<void> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.read(firstWord, sizeof(firstWord))
.then([this,&inputStream]() -> kj::Promise<void> {
if (segmentCount() == 0) {
firstWord[1].set(0);
return inputStream.tryRead(firstWord, sizeof(firstWord), sizeof(firstWord))
.then([this,&inputStream,scratchSpace](size_t n) mutable -> kj::Promise<bool> {
if (n == 0) {
return false;
} else if (n < sizeof(firstWord)) {
// EOF in first word.
KJ_FAIL_REQUIRE("Premature EOF.") {
return false;
}
}
// Reject messages with too many segments for security reasons.
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
return kj::READY_NOW; // exception will be propagated
}
return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; });
});
}
if (segmentCount() > 1) {
// Read sizes for all segments except the first. Include padding if necessary.
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]));
} else {
return kj::READY_NOW;
}
}).then([this,&inputStream,scratchSpace]() mutable -> kj::Promise<void> {
size_t totalWords = segment0Size();
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) {
firstWord[1].set(0);
}
if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
}
}
// Reject messages with too many segments for security reasons.
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
return kj::READY_NOW; // exception will be propagated
}
// Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (segmentCount() > 1) {
// Read sizes for all segments except the first. Include padding if necessary.
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]))
.then([this,&inputStream,scratchSpace]() mutable {
return readSegments(inputStream, scratchSpace);
});
} else {
return readSegments(inputStream, scratchSpace);
}
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
kj::Promise<void> AsyncMessageReader::readSegments(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
size_t totalWords = segment0Size();
if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
}
}
segmentStarts = kj::heapArray<const word*>(segmentCount());
// Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
}
segmentStarts[0] = scratchSpace.begin();
segmentStarts = kj::heapArray<const word*>(segmentCount());
if (segmentCount() > 1) {
size_t offset = segment0Size();
segmentStarts[0] = scratchSpace.begin();
for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
}
if (segmentCount() > 1) {
size_t offset = segment0Size();
for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
}
}
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
});
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
}
......@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader) {
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) {
KJ_REQUIRE(success, "Premature EOF.") { break; }
return kj::mv(reader);
}));
}
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader,
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> {
if (success) {
return kj::mv(reader);
} else {
return nullptr;
}
}));
}
// =======================================================================================
namespace {
......
......@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
//
// `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed.
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(),
kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
......
......@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) {
void registerSigusr1() {
registerSignalHandler(SIGUSR1);
// We also disable SIGPIPE because users of UnixEventLoop almost certainly don't want it.
signal(SIGPIPE, SIG_IGN);
}
pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT;
......
......@@ -25,4 +25,6 @@
namespace kj {
const NullDisposer NullDisposer::instance = NullDisposer();
} // namespace kj
......@@ -81,6 +81,15 @@ public:
template <typename T>
const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>();
class NullDisposer: public Disposer {
// A disposer that does nothing.
public:
static const NullDisposer instance;
void disposeImpl(void* pointer) const override {}
};
// =======================================================================================
// Own<T> -- An owned pointer.
......
......@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) {
EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); }));
}
TEST(Mutex, LazyException) {
Lazy<uint> lazy;
auto exception = kj::runCatchingExceptions([&]() {
lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
KJ_FAIL_ASSERT("foo") { break; }
return space.construct(123);
});
});
EXPECT_TRUE(exception != nullptr);
uint i = lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
return space.construct(456);
});
EXPECT_EQ(456, i);
}
} // namespace
} // namespace kj
......@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) {
}
void Once::runOnce(Initializer& init) {
startOver:
uint state = UNINITIALIZED;
if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false,
__ATOMIC_RELAXED, __ATOMIC_RELAXED)) {
// It's our job to initialize!
init.run();
{
KJ_ON_SCOPE_FAILURE({
// An exception was thrown by the initializer. We have to revert.
if (__atomic_exchange_n(&futex, UNINITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0);
}
});
init.run();
}
if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
......@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) {
// Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
if (state == UNINITIALIZED) {
// Oh hey, apparently whoever was trying to initialize gave up. Let's take it from the
// top.
goto startOver;
}
}
}
}
......@@ -209,7 +227,7 @@ void Once::disable() noexcept {
// Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
break;
continue;
}
}
}
......
......@@ -28,22 +28,26 @@
namespace kj {
Thread::Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg) {
Thread::Thread(Function<void()> func): func(kj::mv(func)) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId), nullptr, run, arg);
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId),
nullptr, &runThread, this);
if (pthreadResult != 0) {
deleteArg(arg);
KJ_FAIL_SYSCALL("pthread_create", pthreadResult);
}
}
Thread::~Thread() {
Thread::~Thread() noexcept(false) {
int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr);
if (pthreadResult != 0) {
KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; }
}
KJ_IF_MAYBE(e, exception) {
kj::throwRecoverableException(kj::mv(*e));
}
}
void Thread::sendSignal(int signo) {
......@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) {
}
}
void* Thread::runThread(void* ptr) {
Thread* thread = reinterpret_cast<Thread*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
thread->func();
})) {
thread->exception = kj::mv(*exception);
}
return nullptr;
}
} // namespace kj
......@@ -25,43 +25,30 @@
#define KJ_THREAD_H_
#include "common.h"
#include "function.h"
#include "exception.h"
namespace kj {
class Thread {
// A thread! Pass a lambda to the constructor. The destructor joins the thread.
// A thread! Pass a lambda to the constructor, and it runs in the thread. The destructor joins
// the thread. If the function throws an exception, it is rethrown from the thread's destructor
// (if not unwinding from another exception).
public:
template <typename Func>
explicit Thread(Func&& func)
: Thread(&runThread<Decay<Func>>,
&deleteArg<Decay<Func>>,
new Decay<Func>(kj::fwd<Func>(func))) {}
explicit Thread(Function<void()> func);
~Thread();
~Thread() noexcept(false);
void sendSignal(int signo);
// Send a Unix signal to the given thread, using pthread_kill or an equivalent.
private:
Function<void()> func;
unsigned long long threadId; // actually pthread_t
kj::Maybe<kj::Exception> exception;
Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg);
template <typename Func>
static void* runThread(void* ptr) {
// TODO(someday): Catch exceptions and propagate to the joiner.
Func* func = reinterpret_cast<Func*>(ptr);
KJ_DEFER(delete func);
(*func)();
return nullptr;
}
template <typename Func>
static void deleteArg(void* ptr) {
delete reinterpret_cast<Func*>(ptr);
}
static void* runThread(void* ptr);
};
} // namespace kj
......
linux-gcc-4.7 1735 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1738 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1758 ./super-test.sh tmpdir capnp-clang quick clang
mac 807 ./super-test.sh remote beat caffeinate quick
linux-gcc-4.7 1737 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1740 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1760 ./super-test.sh tmpdir capnp-clang quick clang
mac 805 ./super-test.sh remote beat caffeinate quick
cygwin 810 ./super-test.sh remote Kenton@flashman quick
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment