Commit 2aa37a8a authored by Kenton Varda's avatar Kenton Varda

Make RPC system handle disconnect.

parent 38bbbd32
...@@ -215,7 +215,7 @@ private: ...@@ -215,7 +215,7 @@ private:
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap( kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const { kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception); return kj::refcounted<BrokenClient>(exception);
} }
} // namespace } // namespace
...@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) { ...@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason)); return kj::refcounted<BrokenClient>(kj::mv(reason));
} }
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp } // namespace capnp
...@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason); ...@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason); kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called. // Helper function that creates a capability which simply throws exceptions when called.
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// ======================================================================================= // =======================================================================================
// inline implementation details // inline implementation details
......
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
if (lock->fulfillers.empty()) { if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message)); lock->messages.push(kj::mv(message));
} else { } else {
lock->fulfillers.front()->fulfill(kj::mv(message)); lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop(); lock->fulfillers.pop();
} }
} }
...@@ -119,16 +119,16 @@ public: ...@@ -119,16 +119,16 @@ public:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override { kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override {
return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize); return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
} }
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override {
auto lock = queues.lockExclusive(); auto lock = queues.lockExclusive();
if (lock->messages.empty()) { if (lock->messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<IncomingRpcMessage>>(); auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller)); lock->fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise); return kj::mv(paf.promise);
} else { } else {
auto result = kj::mv(lock->messages.front()); auto result = kj::mv(lock->messages.front());
lock->messages.pop(); lock->messages.pop();
return kj::mv(result); return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
} }
} }
void introduceTo(Connection& recipient, void introduceTo(Connection& recipient,
...@@ -149,7 +149,7 @@ public: ...@@ -149,7 +149,7 @@ public:
kj::Maybe<ConnectionImpl&> partner; kj::Maybe<ConnectionImpl&> partner;
struct Queues { struct Queues {
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<IncomingRpcMessage>>>> fulfillers; std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages; std::queue<kj::Own<IncomingRpcMessage>> messages;
}; };
kj::MutexGuarded<Queues> queues; kj::MutexGuarded<Queues> queues;
......
...@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) {
// Start up server in another thread. // Start up server in another thread.
auto quitter = kj::newPromiseAndFulfiller<void>(); auto quitter = kj::newPromiseAndFulfiller<void>();
kj::Thread thread([&]() { auto thread = kj::heap<kj::Thread>([&]() {
runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount); runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount);
}); });
KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread. KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread.
...@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) { ...@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) {
TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT); TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network, loop); auto rpcClient = makeRpcClient(network, loop);
// Request the particular capability from the server. bool disconnected = false;
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER, bool drained = false;
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>(); kj::Promise<void> disconnectPromise = loop.there(network.onDisconnect(),
[&]() { disconnected = true; });
kj::Promise<void> drainedPromise = loop.there(network.onDrained(),
[&]() { drained = true; });
// Use the capability. {
auto request = client.getCapRequest(); // Request the particular capability from the server.
request.setN(234); auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
request.setInCap(test::TestInterface::Client( test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send(); {
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto pipelineRequest = promise.getOutBox().getCap().fooRequest(); auto promise = request.send();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap().castAs<test::TestExtends>().graultRequest(); auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
auto pipelinePromise2 = pipelineRequest2.send(); pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
promise = nullptr; // Just to be annoying, drop the original promise. auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_EQ(0, callCount); promise = nullptr; // Just to be annoying, drop the original promise.
EXPECT_EQ(0, reverseCallCount);
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
auto response = loop.wait(kj::mv(pipelinePromise)); EXPECT_FALSE(disconnected);
EXPECT_EQ("bar", response.getX()); EXPECT_FALSE(drained);
auto response2 = loop.wait(kj::mv(pipelinePromise2)); // What if the other side disconnects?
checkTestMessage(response2); quitter.fulfiller->fulfill();
thread = nullptr;
loop.wait(kj::mv(disconnectPromise));
EXPECT_FALSE(drained);
{
// Use the now-broken capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise)));
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise2)));
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
EXPECT_FALSE(drained);
}
EXPECT_EQ(3, callCount); loop.wait(kj::mv(drainedPromise));
EXPECT_EQ(1, reverseCallCount);
} }
} // namespace } // namespace
......
...@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork( ...@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(
const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side, const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions) ReaderOptions receiveOptions)
: eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions), : eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions),
previousWrite(kj::READY_NOW) {} previousWrite(kj::READY_NOW) {
{
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = eventLoop.fork(kj::mv(paf.promise));
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = eventLoop.fork(kj::mv(paf.promise));
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
}
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost( kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost(
rpc::twoparty::SturdyRefHostId::Reader ref) { rpc::twoparty::SturdyRefHostId::Reader ref) {
if (ref.getSide() == side) { if (ref.getSide() == side) {
return nullptr; return nullptr;
} else { } else {
return kj::Own<TwoPartyVatNetworkBase::Connection>(this, return kj::Own<TwoPartyVatNetworkBase::Connection>(this, drainedFulfiller);
kj::DestructorOnlyDisposer<TwoPartyVatNetworkBase::Connection>::instance);
} }
} }
...@@ -70,13 +80,21 @@ public: ...@@ -70,13 +80,21 @@ public:
void send() override { void send() override {
auto lock = network.previousWrite.lockExclusive(); auto lock = network.previousWrite.lockExclusive();
*lock = network.eventLoop.there(network.eventLoop.there(kj::mv(*lock), [this]() { *lock = network.eventLoop.there(kj::mv(*lock),
return writeMessage(network.stream, message); kj::mvCapture(kj::addRef(*this), [&](kj::Own<OutgoingMessageImpl>&& self) {
}), kj::mvCapture(kj::addRef(*this), return writeMessage(network.stream, message)
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> { .then(kj::mvCapture(kj::mv(self),
// Hack to force this continuation to run (thus allowing `self` to be released) even if [](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// no one is waiting on the promise. // Just here to hold a reference to `self` until the write completes.
return kj::READY_NOW;
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
}), [&](kj::Exception&& exception) -> kj::Promise<void> {
// Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill();
return kj::READY_NOW;
});
})); }));
} }
...@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage( ...@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(
return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize); return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
} }
kj::Promise<kj::Own<IncomingRpcMessage>> TwoPartyVatNetwork::receiveIncomingMessage() { kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return eventLoop.evalLater([&]() { return eventLoop.evalLater([&]() {
return readMessage(stream, receiveOptions) return tryReadMessage(stream, receiveOptions)
.then([](kj::Own<MessageReader>&& message) -> kj::Own<IncomingRpcMessage> { .then([&](kj::Maybe<kj::Own<MessageReader>>&& message)
return kj::heap<IncomingMessageImpl>(kj::mv(message)); -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
disconnectFulfiller.lockExclusive()->get()->fulfill();
return nullptr;
}
}, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
}); });
}); });
} }
......
...@@ -41,6 +41,15 @@ public: ...@@ -41,6 +41,15 @@ public:
TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions()); rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
// Returns a promise that resolves when the peer disconnects.
kj::Promise<void> onDrained() { return drainedPromise.addBranch(); }
// Returns a promise that resolves once the peer has disconnected *and* all local objects
// referencing this connection have been destroyed. A caller might use this to decide when it
// is safe to destroy the RpcSystem, if it isn't able to reliably destroy all objects using it
// directly.
// implements VatNetwork ----------------------------------------------------- // implements VatNetwork -----------------------------------------------------
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost( kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost(
...@@ -64,10 +73,22 @@ private: ...@@ -64,10 +73,22 @@ private:
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the // Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
// second call on the server side. Never fulfilled, because there is only one connection. // second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer {
public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); }
};
FulfillerDisposer drainedFulfiller;
// implements Connection ----------------------------------------------------- // implements Connection -----------------------------------------------------
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override; kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override;
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override; kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
void introduceTo(TwoPartyVatNetworkBase::Connection& recipient, void introduceTo(TwoPartyVatNetworkBase::Connection& recipient,
rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient, rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient,
rpc::twoparty::RecipientId::Builder sendToTarget) override; rpc::twoparty::RecipientId::Builder sendToTarget) override;
......
This diff is collapsed.
...@@ -31,7 +31,7 @@ namespace capnp { ...@@ -31,7 +31,7 @@ namespace capnp {
// ======================================================================================= // =======================================================================================
// *************************************************************************************** // ***************************************************************************************
// This section contains various internal stuff that needs to be declared upfront. // This section contains various internal stuff that needs to be declared upfront.
// Scroll down to `class EventLoop` or `class Promise` for the public interfaces. // Scroll down to `class VatNetwork` or `class RpcSystem` for the public interfaces.
// *************************************************************************************** // ***************************************************************************************
// ======================================================================================= // =======================================================================================
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
class Connection { class Connection {
public: public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0; virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0;
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual void baseIntroduceTo(Connection& recipient, virtual void baseIntroduceTo(Connection& recipient,
ObjectPointer::Builder sendToRecipient, ObjectPointer::Builder sendToRecipient,
ObjectPointer::Builder sendToTarget) = 0; ObjectPointer::Builder sendToTarget) = 0;
...@@ -163,10 +163,9 @@ public: ...@@ -163,10 +163,9 @@ public:
// //
// Notice that this may be called from any thread. // Notice that this may be called from any thread.
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the connection fails before a message // Wait for a message to be received and return it. If the read stream cleanly terminates,
// is received, the promise will be broken -- this is the only way to tell if a connection has // return null. If any other problem occurs, throw an exception.
// died.
// Level 3 features ---------------------------------------------- // Level 3 features ----------------------------------------------
......
...@@ -30,10 +30,12 @@ namespace { ...@@ -30,10 +30,12 @@ namespace {
class AsyncMessageReader: public MessageReader { class AsyncMessageReader: public MessageReader {
public: public:
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {} inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {
memset(firstWord, 0, sizeof(firstWord));
}
~AsyncMessageReader() noexcept(false) {} ~AsyncMessageReader() noexcept(false) {}
kj::Promise<void> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
...@@ -56,68 +58,93 @@ private: ...@@ -56,68 +58,93 @@ private:
inline uint segmentCount() { return firstWord[0].get() + 1; } inline uint segmentCount() { return firstWord[0].get() + 1; }
inline uint segment0Size() { return firstWord[1].get(); } inline uint segment0Size() { return firstWord[1].get(); }
kj::Promise<void> readAfterFirstWord(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<void> readSegments(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
}; };
kj::Promise<void> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) { kj::ArrayPtr<word> scratchSpace) {
return inputStream.read(firstWord, sizeof(firstWord)) return inputStream.tryRead(firstWord, sizeof(firstWord), sizeof(firstWord))
.then([this,&inputStream]() -> kj::Promise<void> { .then([this,&inputStream,scratchSpace](size_t n) mutable -> kj::Promise<bool> {
if (segmentCount() == 0) { if (n == 0) {
firstWord[1].set(0); return false;
} else if (n < sizeof(firstWord)) {
// EOF in first word.
KJ_FAIL_REQUIRE("Premature EOF.") {
return false;
}
} }
// Reject messages with too many segments for security reasons. return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; });
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") { });
return kj::READY_NOW; // exception will be propagated }
}
if (segmentCount() > 1) { kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
// Read sizes for all segments except the first. Include padding if necessary. kj::ArrayPtr<word> scratchSpace) {
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1); if (segmentCount() == 0) {
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0])); firstWord[1].set(0);
} else { }
return kj::READY_NOW;
}
}).then([this,&inputStream,scratchSpace]() mutable -> kj::Promise<void> {
size_t totalWords = segment0Size();
if (segmentCount() > 1) { // Reject messages with too many segments for security reasons.
for (uint i = 0; i < segmentCount() - 1; i++) { KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
totalWords += moreSizes[i].get(); return kj::READY_NOW; // exception will be propagated
} }
}
// Don't accept a message which the receiver couldn't possibly traverse without hitting the if (segmentCount() > 1) {
// traversal limit. Without this check, a malicious client could transmit a very large segment // Read sizes for all segments except the first. Include padding if necessary.
// size to make the receiver allocate excessive space and possibly crash. moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords, return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]))
"Message is too large. To increase the limit on the receiving end, see " .then([this,&inputStream,scratchSpace]() mutable {
"capnp::ReaderOptions.") { return readSegments(inputStream, scratchSpace);
return kj::READY_NOW; // exception will be propagated });
} } else {
return readSegments(inputStream, scratchSpace);
}
}
if (scratchSpace.size() < totalWords) { kj::Promise<void> AsyncMessageReader::readSegments(kj::AsyncInputStream& inputStream,
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory kj::ArrayPtr<word> scratchSpace) {
// fragmentation. size_t totalWords = segment0Size();
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace; if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
} }
}
segmentStarts = kj::heapArray<const word*>(segmentCount()); // Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
}
segmentStarts[0] = scratchSpace.begin(); segmentStarts = kj::heapArray<const word*>(segmentCount());
if (segmentCount() > 1) { segmentStarts[0] = scratchSpace.begin();
size_t offset = segment0Size();
for (uint i = 1; i < segmentCount(); i++) { if (segmentCount() > 1) {
segmentStarts[i] = scratchSpace.begin() + offset; size_t offset = segment0Size();
offset += moreSizes[i-1].get();
} for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
} }
}
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word)); return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
});
} }
...@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options); auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace); auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader) { return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) {
KJ_REQUIRE(success, "Premature EOF.") { break; }
return kj::mv(reader); return kj::mv(reader);
})); }));
} }
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader,
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> {
if (success) {
return kj::mv(reader);
} else {
return nullptr;
}
}));
}
// ======================================================================================= // =======================================================================================
namespace { namespace {
......
...@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage( ...@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
// //
// `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed. // `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed.
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(),
kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT; KJ_WARN_UNUSED_RESULT;
......
...@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) { ...@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) {
void registerSigusr1() { void registerSigusr1() {
registerSignalHandler(SIGUSR1); registerSignalHandler(SIGUSR1);
// We also disable SIGPIPE because users of UnixEventLoop almost certainly don't want it.
signal(SIGPIPE, SIG_IGN);
} }
pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT; pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT;
......
...@@ -25,4 +25,6 @@ ...@@ -25,4 +25,6 @@
namespace kj { namespace kj {
const NullDisposer NullDisposer::instance = NullDisposer();
} // namespace kj } // namespace kj
...@@ -81,6 +81,15 @@ public: ...@@ -81,6 +81,15 @@ public:
template <typename T> template <typename T>
const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>(); const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>();
class NullDisposer: public Disposer {
// A disposer that does nothing.
public:
static const NullDisposer instance;
void disposeImpl(void* pointer) const override {}
};
// ======================================================================================= // =======================================================================================
// Own<T> -- An owned pointer. // Own<T> -- An owned pointer.
......
...@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) { ...@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) {
EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); })); EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); }));
} }
TEST(Mutex, LazyException) {
Lazy<uint> lazy;
auto exception = kj::runCatchingExceptions([&]() {
lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
KJ_FAIL_ASSERT("foo") { break; }
return space.construct(123);
});
});
EXPECT_TRUE(exception != nullptr);
uint i = lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
return space.construct(456);
});
EXPECT_EQ(456, i);
}
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) { ...@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) {
} }
void Once::runOnce(Initializer& init) { void Once::runOnce(Initializer& init) {
startOver:
uint state = UNINITIALIZED; uint state = UNINITIALIZED;
if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false, if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false,
__ATOMIC_RELAXED, __ATOMIC_RELAXED)) { __ATOMIC_RELAXED, __ATOMIC_RELAXED)) {
// It's our job to initialize! // It's our job to initialize!
init.run(); {
KJ_ON_SCOPE_FAILURE({
// An exception was thrown by the initializer. We have to revert.
if (__atomic_exchange_n(&futex, UNINITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0);
}
});
init.run();
}
if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) == if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) { INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish. // Someone was waiting for us to finish.
...@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) { ...@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) {
// Wait for initialization. // Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0); syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
if (state == UNINITIALIZED) {
// Oh hey, apparently whoever was trying to initialize gave up. Let's take it from the
// top.
goto startOver;
}
} }
} }
} }
...@@ -209,7 +227,7 @@ void Once::disable() noexcept { ...@@ -209,7 +227,7 @@ void Once::disable() noexcept {
// Wait for initialization. // Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0); syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE); state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
break; continue;
} }
} }
} }
......
...@@ -28,22 +28,26 @@ ...@@ -28,22 +28,26 @@
namespace kj { namespace kj {
Thread::Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg) { Thread::Thread(Function<void()> func): func(kj::mv(func)) {
static_assert(sizeof(threadId) >= sizeof(pthread_t), static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port."); "pthread_t is larger than a long long on your platform. Please port.");
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId), nullptr, run, arg); int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId),
nullptr, &runThread, this);
if (pthreadResult != 0) { if (pthreadResult != 0) {
deleteArg(arg);
KJ_FAIL_SYSCALL("pthread_create", pthreadResult); KJ_FAIL_SYSCALL("pthread_create", pthreadResult);
} }
} }
Thread::~Thread() { Thread::~Thread() noexcept(false) {
int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr); int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr);
if (pthreadResult != 0) { if (pthreadResult != 0) {
KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; } KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; }
} }
KJ_IF_MAYBE(e, exception) {
kj::throwRecoverableException(kj::mv(*e));
}
} }
void Thread::sendSignal(int signo) { void Thread::sendSignal(int signo) {
...@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) { ...@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) {
} }
} }
void* Thread::runThread(void* ptr) {
Thread* thread = reinterpret_cast<Thread*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
thread->func();
})) {
thread->exception = kj::mv(*exception);
}
return nullptr;
}
} // namespace kj } // namespace kj
...@@ -25,43 +25,30 @@ ...@@ -25,43 +25,30 @@
#define KJ_THREAD_H_ #define KJ_THREAD_H_
#include "common.h" #include "common.h"
#include "function.h"
#include "exception.h"
namespace kj { namespace kj {
class Thread { class Thread {
// A thread! Pass a lambda to the constructor. The destructor joins the thread. // A thread! Pass a lambda to the constructor, and it runs in the thread. The destructor joins
// the thread. If the function throws an exception, it is rethrown from the thread's destructor
// (if not unwinding from another exception).
public: public:
template <typename Func> explicit Thread(Function<void()> func);
explicit Thread(Func&& func)
: Thread(&runThread<Decay<Func>>,
&deleteArg<Decay<Func>>,
new Decay<Func>(kj::fwd<Func>(func))) {}
~Thread(); ~Thread() noexcept(false);
void sendSignal(int signo); void sendSignal(int signo);
// Send a Unix signal to the given thread, using pthread_kill or an equivalent. // Send a Unix signal to the given thread, using pthread_kill or an equivalent.
private: private:
Function<void()> func;
unsigned long long threadId; // actually pthread_t unsigned long long threadId; // actually pthread_t
kj::Maybe<kj::Exception> exception;
Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg); static void* runThread(void* ptr);
template <typename Func>
static void* runThread(void* ptr) {
// TODO(someday): Catch exceptions and propagate to the joiner.
Func* func = reinterpret_cast<Func*>(ptr);
KJ_DEFER(delete func);
(*func)();
return nullptr;
}
template <typename Func>
static void deleteArg(void* ptr) {
delete reinterpret_cast<Func*>(ptr);
}
}; };
} // namespace kj } // namespace kj
......
linux-gcc-4.7 1735 ./super-test.sh tmpdir capnp-gcc-4.7 quick linux-gcc-4.7 1737 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1738 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8 linux-gcc-4.8 1740 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1758 ./super-test.sh tmpdir capnp-clang quick clang linux-clang 1760 ./super-test.sh tmpdir capnp-clang quick clang
mac 807 ./super-test.sh remote beat caffeinate quick mac 805 ./super-test.sh remote beat caffeinate quick
cygwin 810 ./super-test.sh remote Kenton@flashman quick cygwin 810 ./super-test.sh remote Kenton@flashman quick
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment