Commit 2aa37a8a authored by Kenton Varda's avatar Kenton Varda

Make RPC system handle disconnect.

parent 38bbbd32
......@@ -215,7 +215,7 @@ private:
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
return kj::refcounted<BrokenClient>(exception);
}
} // namespace
......@@ -228,4 +228,8 @@ kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason));
}
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
return kj::refcounted<BrokenPipeline>(kj::mv(reason));
}
} // namespace capnp
......@@ -211,6 +211,9 @@ kj::Own<const ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<const ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
kj::Own<const PipelineHook> newBrokenPipeline(kj::Exception&& reason);
// Helper function that creates a pipeline which simply throws exceptions when called.
// =======================================================================================
// inline implementation details
......
......@@ -105,7 +105,7 @@ public:
if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message));
} else {
lock->fulfillers.front()->fulfill(kj::mv(message));
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop();
}
}
......@@ -119,16 +119,16 @@ public:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override {
return kj::heap<OutgoingRpcMessageImpl>(*this, firstSegmentWordSize);
}
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override {
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override {
auto lock = queues.lockExclusive();
if (lock->messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<IncomingRpcMessage>>();
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
lock->fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
} else {
auto result = kj::mv(lock->messages.front());
lock->messages.pop();
return kj::mv(result);
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
}
}
void introduceTo(Connection& recipient,
......@@ -149,7 +149,7 @@ public:
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
std::queue<kj::Own<kj::PromiseFulfiller<kj::Own<IncomingRpcMessage>>>> fulfillers;
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
......
......@@ -144,7 +144,7 @@ TEST(TwoPartyNetwork, Pipelining) {
// Start up server in another thread.
auto quitter = kj::newPromiseAndFulfiller<void>();
kj::Thread thread([&]() {
auto thread = kj::heap<kj::Thread>([&]() {
runServer(kj::mv(quitter.promise), kj::mv(pipe.ends[1]), callCount);
});
KJ_DEFER(quitter.fulfiller->fulfill()); // Stop the server loop before destroying the thread.
......@@ -154,38 +154,88 @@ TEST(TwoPartyNetwork, Pipelining) {
TwoPartyVatNetwork network(loop, *pipe.ends[0], rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network, loop);
// Request the particular capability from the server.
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
bool disconnected = false;
bool drained = false;
kj::Promise<void> disconnectPromise = loop.there(network.onDisconnect(),
[&]() { disconnected = true; });
kj::Promise<void> drainedPromise = loop.there(network.onDrained(),
[&]() { drained = true; });
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
{
// Request the particular capability from the server.
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();
auto promise = request.send();
{
// Use the capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto promise = request.send();
auto pipelineRequest2 = promise.getOutBox().getCap().castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
promise = nullptr; // Just to be annoying, drop the original promise.
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
promise = nullptr; // Just to be annoying, drop the original promise.
EXPECT_EQ(0, callCount);
EXPECT_EQ(0, reverseCallCount);
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
auto response = loop.wait(kj::mv(pipelinePromise));
EXPECT_EQ("bar", response.getX());
EXPECT_FALSE(disconnected);
EXPECT_FALSE(drained);
auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2);
// What if the other side disconnects?
quitter.fulfiller->fulfill();
thread = nullptr;
loop.wait(kj::mv(disconnectPromise));
EXPECT_FALSE(drained);
{
// Use the now-broken capability.
auto request = client.getCapRequest();
request.setN(234);
request.setInCap(test::TestInterface::Client(
kj::heap<TestInterfaceImpl>(reverseCallCount), loop));
auto promise = request.send();
auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
pipelineRequest.setI(321);
auto pipelinePromise = pipelineRequest.send();
auto pipelineRequest2 = promise.getOutBox().getCap()
.castAs<test::TestExtends>().graultRequest();
auto pipelinePromise2 = pipelineRequest2.send();
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise)));
EXPECT_ANY_THROW(loop.wait(kj::mv(pipelinePromise2)));
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
}
EXPECT_FALSE(drained);
}
EXPECT_EQ(3, callCount);
EXPECT_EQ(1, reverseCallCount);
loop.wait(kj::mv(drainedPromise));
}
} // namespace
......
......@@ -31,15 +31,25 @@ TwoPartyVatNetwork::TwoPartyVatNetwork(
const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream, rpc::twoparty::Side side,
ReaderOptions receiveOptions)
: eventLoop(eventLoop), stream(stream), side(side), receiveOptions(receiveOptions),
previousWrite(kj::READY_NOW) {}
previousWrite(kj::READY_NOW) {
{
auto paf = kj::newPromiseAndFulfiller<void>();
disconnectPromise = eventLoop.fork(kj::mv(paf.promise));
disconnectFulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
{
auto paf = kj::newPromiseAndFulfiller<void>();
drainedPromise = eventLoop.fork(kj::mv(paf.promise));
drainedFulfiller.fulfiller.getWithoutLock() = kj::mv(paf.fulfiller);
}
}
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connectToRefHost(
rpc::twoparty::SturdyRefHostId::Reader ref) {
if (ref.getSide() == side) {
return nullptr;
} else {
return kj::Own<TwoPartyVatNetworkBase::Connection>(this,
kj::DestructorOnlyDisposer<TwoPartyVatNetworkBase::Connection>::instance);
return kj::Own<TwoPartyVatNetworkBase::Connection>(this, drainedFulfiller);
}
}
......@@ -70,13 +80,21 @@ public:
void send() override {
auto lock = network.previousWrite.lockExclusive();
*lock = network.eventLoop.there(network.eventLoop.there(kj::mv(*lock), [this]() {
return writeMessage(network.stream, message);
}), kj::mvCapture(kj::addRef(*this),
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
*lock = network.eventLoop.there(kj::mv(*lock),
kj::mvCapture(kj::addRef(*this), [&](kj::Own<OutgoingMessageImpl>&& self) {
return writeMessage(network.stream, message)
.then(kj::mvCapture(kj::mv(self),
[](kj::Own<OutgoingMessageImpl>&& self) -> kj::Promise<void> {
// Just here to hold a reference to `self` until the write completes.
// Hack to force this continuation to run (thus allowing `self` to be released) even if
// no one is waiting on the promise.
return kj::READY_NOW;
}), [&](kj::Exception&& exception) -> kj::Promise<void> {
// Exception during write!
network.disconnectFulfiller.lockExclusive()->get()->fulfill();
return kj::READY_NOW;
});
}));
}
......@@ -102,11 +120,21 @@ kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(
return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
}
kj::Promise<kj::Own<IncomingRpcMessage>> TwoPartyVatNetwork::receiveIncomingMessage() {
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
return eventLoop.evalLater([&]() {
return readMessage(stream, receiveOptions)
.then([](kj::Own<MessageReader>&& message) -> kj::Own<IncomingRpcMessage> {
return kj::heap<IncomingMessageImpl>(kj::mv(message));
return tryReadMessage(stream, receiveOptions)
.then([&](kj::Maybe<kj::Own<MessageReader>>&& message)
-> kj::Maybe<kj::Own<IncomingRpcMessage>> {
KJ_IF_MAYBE(m, message) {
return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(*m)));
} else {
disconnectFulfiller.lockExclusive()->get()->fulfill();
return nullptr;
}
}, [&](kj::Exception&& exception) {
disconnectFulfiller.lockExclusive()->get()->fulfill();
kj::throwRecoverableException(kj::mv(exception));
return nullptr;
});
});
}
......
......@@ -41,6 +41,15 @@ public:
TwoPartyVatNetwork(const kj::EventLoop& eventLoop, kj::AsyncIoStream& stream,
rpc::twoparty::Side side, ReaderOptions receiveOptions = ReaderOptions());
kj::Promise<void> onDisconnect() { return disconnectPromise.addBranch(); }
// Returns a promise that resolves when the peer disconnects.
kj::Promise<void> onDrained() { return drainedPromise.addBranch(); }
// Returns a promise that resolves once the peer has disconnected *and* all local objects
// referencing this connection have been destroyed. A caller might use this to decide when it
// is safe to destroy the RpcSystem, if it isn't able to reliably destroy all objects using it
// directly.
// implements VatNetwork -----------------------------------------------------
kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> connectToRefHost(
......@@ -64,10 +73,22 @@ private:
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
// second call on the server side. Never fulfilled, because there is only one connection.
kj::ForkedPromise<void> disconnectPromise = nullptr;
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> disconnectFulfiller;
kj::ForkedPromise<void> drainedPromise = nullptr;
class FulfillerDisposer: public kj::Disposer {
public:
kj::MutexGuarded<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
void disposeImpl(void* pointer) const override { fulfiller.lockExclusive()->get()->fulfill(); }
};
FulfillerDisposer drainedFulfiller;
// implements Connection -----------------------------------------------------
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const override;
kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() override;
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
void introduceTo(TwoPartyVatNetworkBase::Connection& recipient,
rpc::twoparty::ThirdPartyCapId::Builder sendToRecipient,
rpc::twoparty::RecipientId::Builder sendToTarget) override;
......
This diff is collapsed.
......@@ -31,7 +31,7 @@ namespace capnp {
// =======================================================================================
// ***************************************************************************************
// This section contains various internal stuff that needs to be declared upfront.
// Scroll down to `class EventLoop` or `class Promise` for the public interfaces.
// Scroll down to `class VatNetwork` or `class RpcSystem` for the public interfaces.
// ***************************************************************************************
// =======================================================================================
......@@ -60,7 +60,7 @@ public:
class Connection {
public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) const = 0;
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0;
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual void baseIntroduceTo(Connection& recipient,
ObjectPointer::Builder sendToRecipient,
ObjectPointer::Builder sendToTarget) = 0;
......@@ -163,10 +163,9 @@ public:
//
// Notice that this may be called from any thread.
virtual kj::Promise<kj::Own<IncomingRpcMessage>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the connection fails before a message
// is received, the promise will be broken -- this is the only way to tell if a connection has
// died.
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the read stream cleanly terminates,
// return null. If any other problem occurs, throw an exception.
// Level 3 features ----------------------------------------------
......
......@@ -30,10 +30,12 @@ namespace {
class AsyncMessageReader: public MessageReader {
public:
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {}
inline AsyncMessageReader(ReaderOptions options): MessageReader(options) {
memset(firstWord, 0, sizeof(firstWord));
}
~AsyncMessageReader() noexcept(false) {}
kj::Promise<void> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
// implements MessageReader ----------------------------------------
......@@ -56,68 +58,93 @@ private:
inline uint segmentCount() { return firstWord[0].get() + 1; }
inline uint segment0Size() { return firstWord[1].get(); }
kj::Promise<void> readAfterFirstWord(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
kj::Promise<void> readSegments(
kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace);
};
kj::Promise<void> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
return inputStream.read(firstWord, sizeof(firstWord))
.then([this,&inputStream]() -> kj::Promise<void> {
if (segmentCount() == 0) {
firstWord[1].set(0);
return inputStream.tryRead(firstWord, sizeof(firstWord), sizeof(firstWord))
.then([this,&inputStream,scratchSpace](size_t n) mutable -> kj::Promise<bool> {
if (n == 0) {
return false;
} else if (n < sizeof(firstWord)) {
// EOF in first word.
KJ_FAIL_REQUIRE("Premature EOF.") {
return false;
}
}
// Reject messages with too many segments for security reasons.
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
return kj::READY_NOW; // exception will be propagated
}
return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; });
});
}
if (segmentCount() > 1) {
// Read sizes for all segments except the first. Include padding if necessary.
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]));
} else {
return kj::READY_NOW;
}
}).then([this,&inputStream,scratchSpace]() mutable -> kj::Promise<void> {
size_t totalWords = segment0Size();
kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
if (segmentCount() == 0) {
firstWord[1].set(0);
}
if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
}
}
// Reject messages with too many segments for security reasons.
KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") {
return kj::READY_NOW; // exception will be propagated
}
// Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (segmentCount() > 1) {
// Read sizes for all segments except the first. Include padding if necessary.
moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1);
return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]))
.then([this,&inputStream,scratchSpace]() mutable {
return readSegments(inputStream, scratchSpace);
});
} else {
return readSegments(inputStream, scratchSpace);
}
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
kj::Promise<void> AsyncMessageReader::readSegments(kj::AsyncInputStream& inputStream,
kj::ArrayPtr<word> scratchSpace) {
size_t totalWords = segment0Size();
if (segmentCount() > 1) {
for (uint i = 0; i < segmentCount() - 1; i++) {
totalWords += moreSizes[i].get();
}
}
segmentStarts = kj::heapArray<const word*>(segmentCount());
// Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnp::ReaderOptions.") {
return kj::READY_NOW; // exception will be propagated
}
if (scratchSpace.size() < totalWords) {
// TODO(perf): Consider allocating each segment as a separate chunk to reduce memory
// fragmentation.
ownedSpace = kj::heapArray<word>(totalWords);
scratchSpace = ownedSpace;
}
segmentStarts[0] = scratchSpace.begin();
segmentStarts = kj::heapArray<const word*>(segmentCount());
if (segmentCount() > 1) {
size_t offset = segment0Size();
segmentStarts[0] = scratchSpace.begin();
for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
}
if (segmentCount() > 1) {
size_t offset = segment0Size();
for (uint i = 1; i < segmentCount(); i++) {
segmentStarts[i] = scratchSpace.begin() + offset;
offset += moreSizes[i-1].get();
}
}
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
});
return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
}
......@@ -127,11 +154,26 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader) {
return promise.then(kj::mvCapture(reader, [](kj::Own<MessageReader>&& reader, bool success) {
KJ_REQUIRE(success, "Premature EOF.") { break; }
return kj::mv(reader);
}));
}
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
auto reader = kj::heap<AsyncMessageReader>(options);
auto promise = reader->read(input, scratchSpace);
return promise.then(kj::mvCapture(reader,
[](kj::Own<MessageReader>&& reader, bool success) -> kj::Maybe<kj::Own<MessageReader>> {
if (success) {
return kj::mv(reader);
} else {
return nullptr;
}
}));
}
// =======================================================================================
namespace {
......
......@@ -38,6 +38,11 @@ kj::Promise<kj::Own<MessageReader>> readMessage(
//
// `scratchSpace`, if provided, must remain valid until the returned MessageReader is destroyed.
kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage(
kj::AsyncInputStream& input, ReaderOptions options = ReaderOptions(),
kj::ArrayPtr<word> scratchSpace = nullptr);
// Like `readMessage` but returns null on EOF.
kj::Promise<void> writeMessage(kj::AsyncOutputStream& output,
kj::ArrayPtr<const kj::ArrayPtr<const word>> segments)
KJ_WARN_UNUSED_RESULT;
......
......@@ -64,6 +64,9 @@ void registerSignalHandler(int signum) {
void registerSigusr1() {
registerSignalHandler(SIGUSR1);
// We also disable SIGPIPE because users of UnixEventLoop almost certainly don't want it.
signal(SIGPIPE, SIG_IGN);
}
pthread_once_t registerSigusr1Once = PTHREAD_ONCE_INIT;
......
......@@ -25,4 +25,6 @@
namespace kj {
const NullDisposer NullDisposer::instance = NullDisposer();
} // namespace kj
......@@ -81,6 +81,15 @@ public:
template <typename T>
const DestructorOnlyDisposer<T> DestructorOnlyDisposer<T>::instance = DestructorOnlyDisposer<T>();
class NullDisposer: public Disposer {
// A disposer that does nothing.
public:
static const NullDisposer instance;
void disposeImpl(void* pointer) const override {}
};
// =======================================================================================
// Own<T> -- An owned pointer.
......
......@@ -142,5 +142,23 @@ TEST(Mutex, Lazy) {
EXPECT_EQ(123u, lazy.get([](SpaceFor<uint>& space) { return space.construct(789); }));
}
TEST(Mutex, LazyException) {
Lazy<uint> lazy;
auto exception = kj::runCatchingExceptions([&]() {
lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
KJ_FAIL_ASSERT("foo") { break; }
return space.construct(123);
});
});
EXPECT_TRUE(exception != nullptr);
uint i = lazy.get([&](SpaceFor<uint>& space) -> Own<uint> {
return space.construct(456);
});
EXPECT_EQ(456, i);
}
} // namespace
} // namespace kj
......@@ -137,11 +137,23 @@ void Mutex::assertLockedByCaller(Exclusivity exclusivity) {
}
void Once::runOnce(Initializer& init) {
startOver:
uint state = UNINITIALIZED;
if (__atomic_compare_exchange_n(&futex, &state, INITIALIZING, false,
__ATOMIC_RELAXED, __ATOMIC_RELAXED)) {
// It's our job to initialize!
init.run();
{
KJ_ON_SCOPE_FAILURE({
// An exception was thrown by the initializer. We have to revert.
if (__atomic_exchange_n(&futex, UNINITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
syscall(SYS_futex, &futex, FUTEX_WAKE_PRIVATE, INT_MAX, NULL, NULL, 0);
}
});
init.run();
}
if (__atomic_exchange_n(&futex, INITIALIZED, __ATOMIC_RELEASE) ==
INITIALIZING_WITH_WAITERS) {
// Someone was waiting for us to finish.
......@@ -165,6 +177,12 @@ void Once::runOnce(Initializer& init) {
// Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
if (state == UNINITIALIZED) {
// Oh hey, apparently whoever was trying to initialize gave up. Let's take it from the
// top.
goto startOver;
}
}
}
}
......@@ -209,7 +227,7 @@ void Once::disable() noexcept {
// Wait for initialization.
syscall(SYS_futex, &futex, FUTEX_WAIT_PRIVATE, INITIALIZING_WITH_WAITERS, NULL, NULL, 0);
state = __atomic_load_n(&futex, __ATOMIC_ACQUIRE);
break;
continue;
}
}
}
......
......@@ -28,22 +28,26 @@
namespace kj {
Thread::Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg) {
Thread::Thread(Function<void()> func): func(kj::mv(func)) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId), nullptr, run, arg);
int pthreadResult = pthread_create(reinterpret_cast<pthread_t*>(&threadId),
nullptr, &runThread, this);
if (pthreadResult != 0) {
deleteArg(arg);
KJ_FAIL_SYSCALL("pthread_create", pthreadResult);
}
}
Thread::~Thread() {
Thread::~Thread() noexcept(false) {
int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr);
if (pthreadResult != 0) {
KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; }
}
KJ_IF_MAYBE(e, exception) {
kj::throwRecoverableException(kj::mv(*e));
}
}
void Thread::sendSignal(int signo) {
......@@ -53,4 +57,14 @@ void Thread::sendSignal(int signo) {
}
}
void* Thread::runThread(void* ptr) {
Thread* thread = reinterpret_cast<Thread*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
thread->func();
})) {
thread->exception = kj::mv(*exception);
}
return nullptr;
}
} // namespace kj
......@@ -25,43 +25,30 @@
#define KJ_THREAD_H_
#include "common.h"
#include "function.h"
#include "exception.h"
namespace kj {
class Thread {
// A thread! Pass a lambda to the constructor. The destructor joins the thread.
// A thread! Pass a lambda to the constructor, and it runs in the thread. The destructor joins
// the thread. If the function throws an exception, it is rethrown from the thread's destructor
// (if not unwinding from another exception).
public:
template <typename Func>
explicit Thread(Func&& func)
: Thread(&runThread<Decay<Func>>,
&deleteArg<Decay<Func>>,
new Decay<Func>(kj::fwd<Func>(func))) {}
explicit Thread(Function<void()> func);
~Thread();
~Thread() noexcept(false);
void sendSignal(int signo);
// Send a Unix signal to the given thread, using pthread_kill or an equivalent.
private:
Function<void()> func;
unsigned long long threadId; // actually pthread_t
kj::Maybe<kj::Exception> exception;
Thread(void* (*run)(void*), void (*deleteArg)(void*), void* arg);
template <typename Func>
static void* runThread(void* ptr) {
// TODO(someday): Catch exceptions and propagate to the joiner.
Func* func = reinterpret_cast<Func*>(ptr);
KJ_DEFER(delete func);
(*func)();
return nullptr;
}
template <typename Func>
static void deleteArg(void* ptr) {
delete reinterpret_cast<Func*>(ptr);
}
static void* runThread(void* ptr);
};
} // namespace kj
......
linux-gcc-4.7 1735 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1738 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1758 ./super-test.sh tmpdir capnp-clang quick clang
mac 807 ./super-test.sh remote beat caffeinate quick
linux-gcc-4.7 1737 ./super-test.sh tmpdir capnp-gcc-4.7 quick
linux-gcc-4.8 1740 ./super-test.sh tmpdir capnp-gcc-4.8 quick gcc-4.8
linux-clang 1760 ./super-test.sh tmpdir capnp-clang quick clang
mac 805 ./super-test.sh remote beat caffeinate quick
cygwin 810 ./super-test.sh remote Kenton@flashman quick
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment