Commit e2905da7 authored by Kenton Varda

Test and fix embargoes.

parent 6b8b8c71
@@ -207,7 +207,13 @@ class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
public:
QueuedPipeline(const kj::EventLoop& loop, kj::Promise<kj::Own<const PipelineHook>>&& promise)
: loop(loop),
- promise(loop.fork(kj::mv(promise))) {}
+ promise(loop.fork(kj::mv(promise))),
+ selfResolutionOp(loop.there(this->promise.addBranch(),
+ [this](kj::Own<const PipelineHook>&& inner) {
+ *redirect.lockExclusive() = kj::mv(inner);
+ })) {
+ selfResolutionOp.eagerlyEvaluate(loop);
+ }
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
@@ -226,6 +232,12 @@ public:
private:
const kj::EventLoop& loop;
kj::ForkedPromise<kj::Own<const PipelineHook>> promise;
+ kj::MutexGuarded<kj::Maybe<kj::Own<const PipelineHook>>> redirect;
+ // Once the promise resolves, this will become non-null and point to the underlying object.
+ kj::Promise<void> selfResolutionOp;
+ // Represents the operation which will set `redirect` when possible.
};
class QueuedClient final: public ClientHook, public kj::Refcounted {
@@ -371,12 +383,18 @@ private:
};
kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
- auto clientPromise = loop.there(promise.addBranch(), kj::mvCapture(ops,
- [](kj::Array<PipelineOp>&& ops, kj::Own<const PipelineHook> pipeline) {
- return pipeline->getPipelinedCap(kj::mv(ops));
- }));
- return kj::refcounted<QueuedClient>(loop, kj::mv(clientPromise));
+ auto lock = redirect.lockShared();
+ KJ_IF_MAYBE(redirect, *lock) {
+ return redirect->get()->getPipelinedCap(kj::mv(ops));
+ } else {
+ auto clientPromise = loop.there(promise.addBranch(), kj::mvCapture(ops,
+ [](kj::Array<PipelineOp>&& ops, kj::Own<const PipelineHook> pipeline) {
+ return pipeline->getPipelinedCap(kj::mv(ops));
+ }));
+ return kj::refcounted<QueuedClient>(loop, kj::mv(clientPromise));
+ }
}
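The three hunks above add a short-circuit path to QueuedPipeline: a background branch of the forked promise stores the resolved PipelineHook into a mutex-guarded Maybe, and later getPipelinedCap() calls consult that slot before falling back to queuing. The sketch below shows the same publish-then-short-circuit pattern in isolation, using standard C++ types and hypothetical names (Resolver, QueuedResolver); it is only meant to illustrate the idea, not this commit's KJ API.

// Illustrative sketch only (not part of the commit). `Resolver` and `QueuedResolver`
// are made-up names standing in for PipelineHook and QueuedPipeline.
#include <memory>
#include <mutex>
#include <optional>

struct Resolver {
  virtual ~Resolver() = default;
  virtual int lookup(int key) const = 0;
};

class QueuedResolver {
public:
  // Completion path: the async operation stores its result exactly once.
  void setResolved(std::shared_ptr<const Resolver> inner) {
    std::lock_guard<std::mutex> lock(mutex);
    redirect = std::move(inner);
  }

  // Fast path: if the result already arrived, use it directly; otherwise the caller
  // falls back to queuing the request behind the still-pending operation (elided).
  std::optional<int> tryLookup(int key) const {
    std::lock_guard<std::mutex> lock(mutex);
    if (redirect != nullptr) {
      return redirect->lookup(key);
    }
    return std::nullopt;
  }

private:
  mutable std::mutex mutex;
  std::shared_ptr<const Resolver> redirect;  // plays the role of kj::MutexGuarded<kj::Maybe<...>>
};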
// =======================================================================================
...
@@ -24,7 +24,9 @@
#include "rpc.h"
#include "capability-context.h"
#include "test-util.h"
+ #include "schema.h"
#include <kj/debug.h>
+ #include <kj/string-tree.h>
#include <gtest/gtest.h>
#include <capnp/rpc.capnp.h>
#include <map>
@@ -34,10 +36,148 @@ namespace capnp {
namespace _ {  // private
namespace {
class RpcDumper {
// Class which stringifies RPC messages for debugging purposes, including decoding params and
// results based on the call's interface and method IDs and extracting cap descriptors.
public:
void addSchema(InterfaceSchema schema) {
schemas[schema.getProto().getId()] = schema;
}
enum Sender {
CLIENT,
SERVER
};
kj::String dump(rpc::Message::Reader message, Sender sender) {
const char* senderName = sender == CLIENT ? "client" : "server";
switch (message.which()) {
case rpc::Message::CALL: {
auto call = message.getCall();
auto iter = schemas.find(call.getInterfaceId());
if (iter == schemas.end()) {
break;
}
InterfaceSchema schema = iter->second;
auto methods = schema.getMethods();
if (call.getMethodId() >= methods.size()) {
break;
}
InterfaceSchema::Method method = methods[call.getMethodId()];
auto schemaProto = schema.getProto();
auto interfaceName =
schemaProto.getDisplayName().slice(schemaProto.getDisplayNamePrefixLength());
auto methodProto = method.getProto();
auto paramType = schema.getDependency(methodProto.getParamStructType()).asStruct();
auto resultType = schema.getDependency(methodProto.getResultStructType()).asStruct();
returnTypes[std::make_pair(sender, call.getQuestionId())] = resultType;
CapExtractorImpl extractor;
CapReaderContext context(extractor);
auto params = kj::str(context.imbue(call.getParams()).getAs<DynamicStruct>(paramType));
auto sendResultsTo = call.getSendResultsTo();
return kj::str(senderName, "(", call.getQuestionId(), "): call ",
call.getTarget(), " <- ", interfaceName, ".",
methodProto.getName(), params,
" caps:[", extractor.printCaps(), "]",
sendResultsTo.isCaller() ? kj::str()
: kj::str(" sendResultsTo:", sendResultsTo));
}
case rpc::Message::RETURN: {
auto ret = message.getReturn();
auto iter = returnTypes.find(
std::make_pair(sender == CLIENT ? SERVER : CLIENT, ret.getQuestionId()));
if (iter == returnTypes.end()) {
break;
}
auto schema = iter->second;
returnTypes.erase(iter);
if (ret.which() != rpc::Return::RESULTS) {
// Oops, no results returned. We don't check this earlier because we want to make sure
// returnTypes.erase() gets a chance to happen.
break;
}
CapExtractorImpl extractor;
CapReaderContext context(extractor);
auto imbued = context.imbue(ret.getResults());
if (schema.getProto().isStruct()) {
auto results = kj::str(imbued.getAs<DynamicStruct>(schema.asStruct()));
return kj::str(senderName, "(", ret.getQuestionId(), "): return ", results,
" caps:[", extractor.printCaps(), "]");
} else if (schema.getProto().isInterface()) {
imbued.getAs<DynamicCapability>(schema.asInterface());
return kj::str(senderName, "(", ret.getQuestionId(), "): return cap ",
extractor.printCaps());
} else {
break;
}
}
case rpc::Message::RESTORE: {
auto restore = message.getRestore();
returnTypes[std::make_pair(sender, restore.getQuestionId())] = InterfaceSchema();
return kj::str(senderName, "(", restore.getQuestionId(), "): restore ",
restore.getObjectId().getAs<test::TestSturdyRefObjectId>());
}
default:
break;
}
return kj::str(senderName, ": ", message);
}
private:
std::map<uint64_t, InterfaceSchema> schemas;
std::map<std::pair<Sender, uint32_t>, Schema> returnTypes;
class CapExtractorImpl: public CapExtractor<rpc::CapDescriptor> {
public:
kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const {
caps.add(kj::str(descriptor));
return newBrokenCap("fake cap");
}
kj::String printCaps() {
return kj::strArray(caps, ", ");
}
private:
mutable kj::Vector<kj::String> caps;
};
};
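A hypothetical use of the dumper (not part of the commit): register the schemas of interest once, then stringify each message as it crosses the test network. Messages whose interface or method is not registered fall back to plain kj::str() output.

// Hypothetical tracing helpers, for illustration only.
RpcDumper dumper;  // assumed to outlive the connections being traced

void setUpDumper() {
  dumper.addSchema(Schema::from<test::TestInterface>());
  dumper.addSchema(Schema::from<test::TestMoreStuff>());
}

void logClientMessage(rpc::Message::Reader message) {
  // CLIENT vs. SERVER only affects the label and how returns are matched to earlier calls.
  KJ_LOG(INFO, dumper.dump(message, RpcDumper::CLIENT));
}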
// =======================================================================================
class TestNetworkAdapter;
class TestNetwork {
public:
TestNetwork(kj::EventLoop& loop): loop(loop) {
dumper.addSchema(Schema::from<test::TestInterface>());
dumper.addSchema(Schema::from<test::TestExtends>());
dumper.addSchema(Schema::from<test::TestPipeline>());
dumper.addSchema(Schema::from<test::TestCallOrder>());
dumper.addSchema(Schema::from<test::TestTailCallee>());
dumper.addSchema(Schema::from<test::TestTailCaller>());
dumper.addSchema(Schema::from<test::TestMoreStuff>());
}
~TestNetwork() noexcept(false);
TestNetworkAdapter& add(kj::StringPtr name);
@@ -51,7 +191,10 @@ public:
}
}
+ RpcDumper dumper;
private:
+ kj::EventLoop& loop;
std::map<kj::StringPtr, kj::Own<TestNetworkAdapter>> map;
};
@@ -61,16 +204,18 @@ typedef VatNetwork<
class TestNetworkAdapter final: public TestNetworkAdapterBase {
public:
- TestNetworkAdapter(TestNetwork& network): network(network) {}
+ TestNetworkAdapter(kj::EventLoop& loop, TestNetwork& network): loop(loop), network(network) {}
uint getSentCount() { return sent; }
uint getReceivedCount() { return received; }
typedef TestNetworkAdapterBase::Connection Connection;
- class ConnectionImpl final: public Connection, public kj::Refcounted {
+ class ConnectionImpl final
+ : public Connection, public kj::Refcounted, public kj::TaskSet::ErrorHandler {
public:
- ConnectionImpl(TestNetworkAdapter& network, const char* name): network(network), name(name) {}
+ ConnectionImpl(TestNetworkAdapter& network, RpcDumper::Sender sender)
+ : network(network), sender(sender), tasks(network.loop, *this) {}
void attach(ConnectionImpl& other) {
KJ_REQUIRE(partner == nullptr);
@@ -79,7 +224,7 @@ public:
other.partner = *this;
}
- class IncomingRpcMessageImpl final: public IncomingRpcMessage {
+ class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted {
public:
IncomingRpcMessageImpl(uint firstSegmentWordSize)
: message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS
@@ -96,7 +241,7 @@ public:
public:
OutgoingRpcMessageImpl(const ConnectionImpl& connection, uint firstSegmentWordSize)
: connection(connection),
- message(kj::heap<IncomingRpcMessageImpl>(firstSegmentWordSize)) {}
+ message(kj::refcounted<IncomingRpcMessageImpl>(firstSegmentWordSize)) {}
ObjectPointer::Builder getBody() override {
return message->message.getRoot<ObjectPointer>();
@@ -104,19 +249,26 @@ public:
void send() override {
++connection.network.sent;
- kj::String msg = kj::str(connection.name, ": ", message->message.getRoot<rpc::Message>());
- //KJ_DBG(msg);
- KJ_IF_MAYBE(p, connection.partner) {
- auto lock = p->queues.lockExclusive();
- if (lock->fulfillers.empty()) {
- lock->messages.push(kj::mv(message));
- } else {
- ++connection.network.received;
- lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
- lock->fulfillers.pop();
- }
- }
+ // Uncomment to get a debug dump.
+ // kj::String msg = connection.network.network.dumper.dump(
+ //     message->message.getRoot<rpc::Message>(), connection.sender);
+ // KJ_ DBG(msg);
+ auto connectionPtr = &connection;
+ connection.tasks.add(connection.network.loop.evalLater(
+ kj::mvCapture(kj::addRef(*message),
+ [connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
+ KJ_IF_MAYBE(p, connectionPtr->partner) {
+ auto lock = p->queues.lockExclusive();
+ if (lock->fulfillers.empty()) {
+ lock->messages.push(kj::mv(message));
+ } else {
+ ++connectionPtr->network.received;
+ lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
+ lock->fulfillers.pop();
+ }
+ }
+ })));
}
private:
@@ -154,9 +306,13 @@ public:
KJ_FAIL_ASSERT("not implemented");
}
+ void taskFailed(kj::Exception&& exception) override {
+ ADD_FAILURE() << kj::str(exception).cStr();
+ }
private:
TestNetworkAdapter& network;
- const char* name;
+ RpcDumper::Sender sender;
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
@@ -164,6 +320,8 @@ public:
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
+ kj::TaskSet tasks;
};
kj::Maybe<kj::Own<Connection>> connectToRefHost(
@@ -183,8 +341,8 @@ public:
auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) {
- auto local = kj::refcounted<ConnectionImpl>(*this, "client");
- auto remote = kj::refcounted<ConnectionImpl>(dst, "server");
+ auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT);
+ auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER);
local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local);
@@ -217,6 +375,7 @@ public:
}
private:
+ kj::EventLoop& loop;
TestNetwork& network;
uint sent = 0;
uint received = 0;
@@ -232,7 +391,7 @@ private:
TestNetwork::~TestNetwork() noexcept(false) {}
TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) {
- return *(map[name] = kj::heap<TestNetworkAdapter>(*this));
+ return *(map[name] = kj::heap<TestNetworkAdapter>(loop, *this));
}
// =======================================================================================
@@ -262,9 +421,9 @@ public:
class RpcTest: public testing::Test {
protected:
+ kj::SimpleEventLoop loop;
TestNetwork network;
TestRestorer restorer;
- kj::SimpleEventLoop loop;
TestNetworkAdapter& clientNetwork;
TestNetworkAdapter& serverNetwork;
RpcSystem<test::TestSturdyRefHostId> rpcClient;
@@ -281,7 +440,8 @@ protected:
}
RpcTest()
- : clientNetwork(network.add("client")),
+ : network(loop),
+ clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, loop)),
rpcServer(makeRpcServer(serverNetwork, restorer, loop)) {}
@@ -587,6 +747,48 @@ TEST_F(RpcTest, SendTwice) {
EXPECT_TRUE(destroyed);
}
RemotePromise<test::TestCallOrder::GetCallSequenceResults> getCallSequence(
const test::TestCallOrder::Client& client, uint expected) {
auto req = client.getCallSequenceRequest();
req.setExpected(expected);
return req.send();
}
TEST_F(RpcTest, Embargo) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>();
auto cap = test::TestCallOrder::Client(kj::heap<TestCallOrderImpl>(), loop);
auto earlyCall = client.getCallSequenceRequest().send();
auto echoRequest = client.echoRequest();
echoRequest.setCap(cap);
auto echo = echoRequest.send();
auto pipeline = echo.getCap();
auto call0 = getCallSequence(pipeline, 0);
auto call1 = getCallSequence(pipeline, 1);
loop.wait(kj::mv(earlyCall));
auto call2 = getCallSequence(pipeline, 2);
auto resolved = loop.wait(kj::mv(echo)).getCap();
auto call3 = getCallSequence(pipeline, 3);
auto call4 = getCallSequence(pipeline, 4);
auto call5 = getCallSequence(pipeline, 5);
EXPECT_EQ(0, loop.wait(kj::mv(call0)).getN());
EXPECT_EQ(1, loop.wait(kj::mv(call1)).getN());
EXPECT_EQ(2, loop.wait(kj::mv(call2)).getN());
EXPECT_EQ(3, loop.wait(kj::mv(call3)).getN());
EXPECT_EQ(4, loop.wait(kj::mv(call4)).getN());
EXPECT_EQ(5, loop.wait(kj::mv(call5)).getN());
}
}  // namespace
}  // namespace _ (private)
}  // namespace capnp
@@ -298,6 +298,7 @@ public:
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
+ kj::Vector<kj::Promise<kj::Own<const RpcResponse>>> promisesToRelease;
auto lock = tables.lockExclusive();
@@ -330,9 +331,7 @@ public:
}
KJ_IF_MAYBE(promise, answer.redirectedResults) {
- // Answer contains a result redirection that hasn't been picked up yet. Make the call
- // properly cancelable by transforming the redirect promise into a regular asyncOp.
- answer.asyncOp = promise->thenInAnyThread([](kj::Own<const RpcResponse>&& response) {});
+ promisesToRelease.add(kj::mv(*promise));
}
KJ_IF_MAYBE(context, answer.callContext) {
@@ -413,9 +412,6 @@ private:
kj::Maybe<kj::Own<const PipelineHook>> pipeline;
// Send pipelined calls here. Becomes null as soon as a `Finish` is received.
- kj::Promise<void> asyncOp = kj::Promise<void>(nullptr);
- // Delete this promise to cancel the call. For redirected calls, this is null.
kj::Maybe<kj::Promise<kj::Own<const RpcResponse>>> redirectedResults;
// For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call
// result, to be picked up by a subsequent `Return`.
@@ -885,6 +881,10 @@ private:
// the `PromiseClient` is destroyed; `eventual` must therefore make sure to hold references to
// anything that needs to stay alive in order to resolve it correctly (such as making sure the
// import ID is not released).
+ resolveSelfPromise = connectionState.eventLoop.there(kj::mv(resolveSelfPromise),
+ []() {}, [&](kj::Exception&& e) { connectionState.tasks.add(kj::mv(e)); });
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
}
@@ -907,6 +907,7 @@ private:
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
+ __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->writeDescriptor(*inner.lockExclusive()->cap, descriptor, tables);
}
@@ -917,6 +918,7 @@ private:
}
kj::Own<const ClientHook> getInnermostClient() const override {
+ __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->getInnermostClient(*inner.lockExclusive()->cap);
}
@@ -924,11 +926,13 @@ private:
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
+ __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->newCall(interfaceId, methodId, firstSegmentWordSize);
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
+ __atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->call(interfaceId, methodId, kj::mv(context));
}
@@ -962,7 +966,7 @@ private:
mutable bool receivedCall = false;
void resolve(kj::Own<const ClientHook> replacement) {
- if (replacement->getBrand() != this &&
+ if (replacement->getBrand() != connectionState.get() &&
__atomic_load_n(&receivedCall, __ATOMIC_RELAXED)) {
// The new capability is hosted locally, not on the remote machine. And, we had made calls
// to the promise. We need to make sure those calls echo back to us before we allow new
@@ -972,7 +976,7 @@ private:
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
- auto disembargo = message->getBody().initAs<rpc::Message>().getDisembargo();
+ auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo();
{
auto redirect = connectionState->writeTarget(
@@ -1511,6 +1515,10 @@ private:
resultCaps(connectionState, kj::mv(resolutionChain)) {}
~QuestionRef() {
+ if (connectionState->tables.lockShared()->networkException != nullptr) {
+ return;
+ }
// Send the "Finish" message.
{
auto message = connectionState->connection->newOutgoingMessage(
@@ -1922,7 +1930,8 @@ private:
public:
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params,
- kj::Own<const ResolutionChain> resolutionChain, bool redirectResults)
+ kj::Own<const ResolutionChain> resolutionChain, bool redirectResults,
+ kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
: connectionState(kj::addRef(connectionState)),
questionId(questionId),
request(kj::mv(request)),
@@ -1930,7 +1939,8 @@ private:
requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)),
returnMessage(nullptr),
- redirectResults(redirectResults) {}
+ redirectResults(redirectResults),
+ cancelFulfiller(kj::mv(cancelFulfiller)) {}
~RpcCallContext() noexcept(false) {
if (isFirstResponder()) {
@@ -2019,9 +2029,9 @@ private:
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
CANCEL_ALLOWED) {
- // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Schedule
+ // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate
// the cancellation.
- scheduleCancel();
+ cancelFulfiller->fulfill();
}
}
@@ -2113,25 +2123,21 @@ private:
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
- if (threadAcceptingCancellation == nullptr) {
// TODO(cleanup): We need to drop the request because it is holding on to the resolution
// chain which in turn holds on to the pipeline which holds on to this object thus
// preventing cancellation from working. This is a bit silly because obviously our
// request couldn't contain PromisedAnswers referring to itself, but currently the chain
// is a linear list and we have no way to tell that a reference to the chain taken before
// a call started doesn't really need to hold the call open. To fix this we'd presumably
// need to make the answer table snapshot-able and have CapExtractorImpl take a snapshot
// at creation.
releaseParams();
- threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
CANCEL_REQUESTED) {
- // We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Schedule
+ // We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Initiate
// the cancellation.
- scheduleCancel();
+ cancelFulfiller->fulfill();
}
- }
}
bool isCanceled() override {
@@ -2171,47 +2177,15 @@ private:
// When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads.
- kj::EventLoop* threadAcceptingCancellation = nullptr;
- // EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
- // optimization: if the application thread is independent from the network thread, we'd rather
- // perform the cancellation in the application thread, because otherwise we might block waiting
- // on an application promise continuation callback to finish executing, which could take
- // arbitrary time.
+ mutable kj::Own<kj::PromiseFulfiller<void>> cancelFulfiller;
+ // Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is
+ // exclusive-joined with the outermost promise waiting on the call return, so fulfilling it
+ // cancels that promise.
kj::UnwindDetector unwindDetector;
// -----------------------------------------------------
- void scheduleCancel() const {
- // Arranges for the answer's asyncOp to be deleted, thus canceling all processing related to
- // this call, shortly. We have to do it asynchronously because the caller might hold
- // arbitrary locks or might in fact be part of the task being canceled.
- connectionState->tasks.add(threadAcceptingCancellation->evalLater(
- kj::mvCapture(kj::addRef(*this), [](kj::Own<const RpcCallContext>&& self) {
- // Extract from the answer table the promise representing the executing call.
- kj::Promise<void> asyncOp = nullptr;
- {
- auto lock = self->connectionState->tables.lockExclusive();
- asyncOp = kj::mv(lock->answers[self->questionId].asyncOp);
- }
- // When `asyncOp` goes out of scope, if it holds the last reference to the ongoing
- // operation, that operation will be canceled. Note that if a continuation is
- // running in another thread, the destructor will block waiting for it to complete. This
- // is why we try to schedule doCancel() on the application thread, so that it won't need
- // to block.
- // The `Return` will be sent when the context is destroyed. That might be right now, when
- // `self` and `asyncOp` go out of scope. However, it is also possible that the pipeline
- // is still in use: although `Finish` removes the pipeline reference from the answer
- // table, it might be held by an outstanding pipelined call, or by a pipelined promise that
- // was echoed back to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be
- // held in the resolution chain. In all of these cases, the call will continue running
- // until those references are dropped or the call completes.
- })));
- }
bool isFirstResponder() {
if (responseSent) {
return false;
@@ -2402,13 +2376,15 @@ private:
KJ_FAIL_REQUIRE("Unsupported `Call.sendResultsTo`.") { return; }
}
+ auto cancelPaf = kj::newPromiseAndFulfiller<void>();
QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail),
- redirectResults);
+ redirectResults, kj::mv(cancelPaf.fulfiller));
auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef());
@@ -2432,18 +2408,27 @@ private:
answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) {
- auto promise = promiseAndPipeline.promise.then(
+ auto resultsPromise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse();
}));
- promise.eagerlyEvaluate(eventLoop);
- answer.redirectedResults = kj::mv(promise);
+ // If the call that later picks up `redirectedResults` decides to discard it, we need to
+ // make sure our call is not itself canceled unless it has called allowAsyncCancellation().
+ // So we fork the promise and join one branch with the cancellation promise, in order to
+ // hold on to it.
+ auto forked = eventLoop.fork(kj::mv(resultsPromise));
+ answer.redirectedResults = forked.addBranch();
+ auto promise = kj::mv(cancelPaf.promise);
+ promise.exclusiveJoin(forked.addBranch().then([](kj::Own<const RpcResponse>&&){}));
+ eventLoop.daemonize(kj::mv(promise));
} else {
// Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context;
- answer.asyncOp = promiseAndPipeline.promise.then(
+ auto promise = promiseAndPipeline.promise.then(
[contextPtr]() {
contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) {
@@ -2452,8 +2437,9 @@ private:
// Handle exceptions that occur in sendReturn()/sendErrorReturn().
taskFailed(kj::mv(exception));
});
- answer.asyncOp.attach(kj::mv(context));
- answer.asyncOp.eagerlyEvaluate(eventLoop);
+ promise.attach(kj::mv(context));
+ promise.exclusiveJoin(kj::mv(cancelPaf.promise));
+ eventLoop.daemonize(kj::mv(promise));
}
}
}
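The block above is the heart of the cancellation change: instead of storing an `asyncOp` in the answer table, the running call is exclusive-joined with a cancellation promise and then daemonized, so fulfilling `cancelPaf.fulfiller` is what actually drops the work. Below is a minimal sketch of that pattern in isolation, reusing only KJ calls already visible in this diff (newPromiseAndFulfiller, exclusiveJoin, daemonize); the helper name is made up.

// Sketch only (not code from the commit): make `work` cancelable by racing it against
// a cancel promise and letting it run in the background.
kj::Own<kj::PromiseFulfiller<void>> startCancelable(kj::EventLoop& loop, kj::Promise<void> work) {
  auto cancelPaf = kj::newPromiseAndFulfiller<void>();
  // Whichever side finishes first wins; fulfilling the returned fulfiller therefore
  // cancels the still-running work.
  work.exclusiveJoin(kj::mv(cancelPaf.promise));
  // Nobody waits on the result; failures are only logged by the daemon task set.
  loop.daemonize(kj::mv(work));
  return kj::mv(cancelPaf.fulfiller);
}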
@@ -2729,8 +2715,6 @@ private:
}
void handleDisembargo(const rpc::Disembargo::Reader& disembargo) {
- auto lock = tables.lockExclusive();
auto context = disembargo.getContext();
switch (context.which()) {
case rpc::Disembargo::Context::SENDER_LOOPBACK: {
@@ -2757,34 +2741,42 @@ private:
return;
}
- const RpcClient& downcasted = kj::downcast<const RpcClient>(*target);
- auto message = connection->newOutgoingMessage(
- messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
- auto builder = message->getBody().initAs<rpc::Message>().initDisembargo();
- {
- auto redirect = downcasted.writeTarget(builder.initTarget());
- // Disembargoes should only be sent to capabilities that were previously the object of
- // a `Resolve` message. But `writeTarget` only ever returns non-null when called on
- // a PromiseClient. The code which sends `Resolve` should have replaced any promise
- // with a direct node in order to solve the Tribble 4-way race condition.
- KJ_REQUIRE(redirect == nullptr,
- "'Disembargo' of type 'senderLoopback' sent to an object that does not appear "
- "to have been the object of a previous 'Resolve' message.") {
- return;
- }
- }
- builder.getContext().setReceiverLoopback(context.getSenderLoopback());
- message->send();
+ EmbargoId embargoId = context.getSenderLoopback();
+ // We need to insert an evalLater() here to make sure that any pending calls towards this
+ // cap have had time to find their way through the event loop.
+ tasks.add(eventLoop.evalLater(kj::mvCapture(
+ target, [this,embargoId](kj::Own<const ClientHook>&& target) {
+ const RpcClient& downcasted = kj::downcast<const RpcClient>(*target);
+ auto message = connection->newOutgoingMessage(
+ messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
+ auto builder = message->getBody().initAs<rpc::Message>().initDisembargo();
+ {
+ auto redirect = downcasted.writeTarget(builder.initTarget());
+ // Disembargoes should only be sent to capabilities that were previously the object of
+ // a `Resolve` message. But `writeTarget` only ever returns non-null when called on
+ // a PromiseClient. The code which sends `Resolve` should have replaced any promise
+ // with a direct node in order to solve the Tribble 4-way race condition.
+ KJ_REQUIRE(redirect == nullptr,
+ "'Disembargo' of type 'senderLoopback' sent to an object that does not "
+ "appear to have been the object of a previous 'Resolve' message.") {
+ return;
+ }
+ }
+ builder.getContext().setReceiverLoopback(embargoId);
+ message->send();
+ })));
break;
}
- case rpc::Disembargo::Context::RECEIVER_LOOPBACK:
+ case rpc::Disembargo::Context::RECEIVER_LOOPBACK: {
+ auto lock = tables.lockExclusive();
KJ_IF_MAYBE(embargo, lock->embargoes.find(context.getReceiverLoopback())) {
KJ_ASSERT_NONNULL(embargo->fulfiller)->fulfill();
lock->embargoes.erase(context.getReceiverLoopback());
@@ -2794,6 +2786,7 @@ private:
}
}
break;
+ }
default:
KJ_FAIL_REQUIRE("Unimplemented Disembargo type.", disembargo) { return; }
...
@@ -1052,5 +1052,11 @@ kj::Promise<void> TestMoreStuffImpl::getHeld(
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::echo(EchoParams::Reader params, EchoResults::Builder result) {
++callCount;
result.setCap(params.getCap());
return kj::READY_NOW;
}
}  // namespace _ (private)
}  // namespace capnp
@@ -237,6 +237,8 @@ public:
kj::Promise<void> getHeld(GetHeldParams::Reader params,
GetHeldResults::Builder result) override;
+ kj::Promise<void> echo(EchoParams::Reader params, EchoResults::Builder result) override;
private:
int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
...
@@ -611,8 +611,10 @@ interface TestPipeline {
}
interface TestCallOrder {
- getCallSequence @0 () -> (n: UInt32);
+ getCallSequence @0 (expected: UInt32) -> (n: UInt32);
# First call returns 0, next returns 1, ...
+ #
+ # The input `expected` is ignored but useful for disambiguating debug logs.
}
interface TestTailCallee {
@@ -649,6 +651,9 @@ interface TestMoreStuff extends(TestCallOrder) {
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
+ echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
+ # Just returns the input cap.
}
struct TestSturdyRefHostId {
...
@@ -69,6 +69,91 @@ public:
}  // namespace
namespace _ { // private
class TaskSetImpl {
public:
inline TaskSetImpl(const EventLoop& loop, TaskSet::ErrorHandler& errorHandler)
: loop(loop), errorHandler(errorHandler) {}
~TaskSetImpl() noexcept(false) {
// std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
auto& taskMap = tasks.getWithoutLock();
if (!taskMap.empty()) {
Vector<Own<Task>> deleteMe(taskMap.size());
for (auto& entry: taskMap) {
deleteMe.add(kj::mv(entry.second));
}
}
}
class Task final: public EventLoop::Event {
public:
Task(const TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam)
: EventLoop::Event(taskSet.loop), taskSet(taskSet), node(kj::mv(nodeParam)) {
if (node->onReady(*this)) {
arm();
}
}
~Task() {
disarm();
}
protected:
void fire() override {
// Get the result.
_::ExceptionOr<_::Void> result;
node->get(result);
// Delete the node, catching any exceptions.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
node = nullptr;
})) {
result.addException(kj::mv(*exception));
}
// Call the error handler if there was an exception.
KJ_IF_MAYBE(e, result.exception) {
taskSet.errorHandler.taskFailed(kj::mv(*e));
}
}
private:
const TaskSetImpl& taskSet;
kj::Own<_::PromiseNode> node;
};
void add(Promise<void>&& promise) const {
auto task = heap<Task>(*this, _::makeSafeForLoop<_::Void>(kj::mv(promise.node), loop));
Task* ptr = task;
tasks.lockExclusive()->insert(std::make_pair(ptr, kj::mv(task)));
}
private:
const EventLoop& loop;
TaskSet::ErrorHandler& errorHandler;
// TODO(soon): Use a linked list instead. We should factor out the intrusive linked list code
// that appears in EventLoop and ForkHub.
MutexGuarded<std::map<Task*, Own<Task>>> tasks;
};
class LoggingErrorHandler: public TaskSet::ErrorHandler {
public:
static LoggingErrorHandler instance;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, "Uncaught exception in daemonized task.", exception);
}
};
LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler();
} // namespace _ (private)
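For reference, a hedged sketch (not from the commit) of how kj::TaskSet is used by its callers, matching the pattern in the test's ConnectionImpl earlier in this diff: the owner implements ErrorHandler and hands fire-and-forget promises to the set, which keeps them alive until they complete or the set is destroyed.

// Illustrative only; `MyConnection` and `sendLater` are made-up names.
class MyConnection: private kj::TaskSet::ErrorHandler {
public:
  explicit MyConnection(const kj::EventLoop& loop): tasks(loop, *this) {}

  void sendLater(kj::Promise<void> op) {
    // The task keeps running until it finishes or `tasks` is destroyed.
    tasks.add(kj::mv(op));
  }

private:
  void taskFailed(kj::Exception&& exception) override {
    KJ_LOG(ERROR, "background send failed", exception);
  }

  kj::TaskSet tasks;
};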
// =======================================================================================
EventLoop& EventLoop::current() {
EventLoop* result = threadLocalEventLoop;
KJ_REQUIRE(result != nullptr, "No event loop is running on this thread.");
@@ -79,7 +164,10 @@ bool EventLoop::isCurrent() const {
return threadLocalEventLoop == this;
}
- EventLoop::EventLoop() {}
+ EventLoop::EventLoop()
+ : daemons(kj::heap<_::TaskSetImpl>(*this, _::LoggingErrorHandler::instance)) {}
+ EventLoop::~EventLoop() noexcept(false) {}
void EventLoop::waitImpl(Own<_::PromiseNode> node, _::ExceptionOrValue& result) {
EventLoop* oldEventLoop = threadLocalEventLoop;
@@ -119,6 +207,10 @@ void EventLoop::receivedNewJob() const {
wake();
}
void EventLoop::daemonize(kj::Promise<void>&& promise) const {
daemons->add(kj::mv(promise));
}
EventLoop::Event::Event(const EventLoop& loop)
: loop(loop),
jobs { loop.queue.createJob(*this), loop.queue.createJob(*this) } {}
@@ -227,76 +319,8 @@ void PromiseBase::absolve() {
runCatchingExceptions([this]() { node = nullptr; });
}
class TaskSet::Impl {
public:
inline Impl(const EventLoop& loop, ErrorHandler& errorHandler)
: loop(loop), errorHandler(errorHandler) {}
~Impl() noexcept(false) {
// std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
auto& taskMap = tasks.getWithoutLock();
if (!taskMap.empty()) {
Vector<Own<Task>> deleteMe(taskMap.size());
for (auto& entry: taskMap) {
deleteMe.add(kj::mv(entry.second));
}
}
}
class Task final: public EventLoop::Event {
public:
Task(const Impl& taskSet, Own<_::PromiseNode>&& nodeParam)
: EventLoop::Event(taskSet.loop), taskSet(taskSet), node(kj::mv(nodeParam)) {
if (node->onReady(*this)) {
arm();
}
}
~Task() {
disarm();
}
protected:
void fire() override {
// Get the result.
_::ExceptionOr<_::Void> result;
node->get(result);
// Delete the node, catching any exceptions.
KJ_IF_MAYBE(exception, runCatchingExceptions([this]() {
node = nullptr;
})) {
result.addException(kj::mv(*exception));
}
// Call the error handler if there was an exception.
KJ_IF_MAYBE(e, result.exception) {
taskSet.errorHandler.taskFailed(kj::mv(*e));
}
}
private:
const Impl& taskSet;
kj::Own<_::PromiseNode> node;
};
void add(Promise<void>&& promise) const {
auto task = heap<Task>(*this, _::makeSafeForLoop<_::Void>(kj::mv(promise.node), loop));
Task* ptr = task;
tasks.lockExclusive()->insert(std::make_pair(ptr, kj::mv(task)));
}
private:
const EventLoop& loop;
ErrorHandler& errorHandler;
// TODO(soon): Use a linked list instead. We should factor out the intrusive linked list code
// that appears in EventLoop and ForkHub.
MutexGuarded<std::map<Task*, Own<Task>>> tasks;
};
TaskSet::TaskSet(const EventLoop& loop, ErrorHandler& errorHandler)
- : impl(heap<Impl>(loop, errorHandler)) {}
+ : impl(heap<_::TaskSetImpl>(loop, errorHandler)) {}
TaskSet::~TaskSet() noexcept(false) {}
...
@@ -182,6 +182,8 @@ class ChainPromiseNode;
template <typename T>
class ForkHub;
+ class TaskSetImpl;
}  // namespace _ (private)
// =======================================================================================
@@ -228,6 +230,7 @@ class EventLoop: private _::NewJobCallback {
public:
EventLoop();
+ ~EventLoop() noexcept(false);
static EventLoop& current();
// Get the event loop for the current thread. Throws an exception if no event loop is active.
@@ -298,6 +301,15 @@ public:
Promise<T> exclusiveJoin(Promise<T>&& promise1, Promise<T>&& promise2) const;
// Like `promise1.exclusiveJoin(promise2)`, returning the joined promise.
void daemonize(kj::Promise<void>&& promise) const;
// Allows the given promise to continue running in the background until it completes or the
// `EventLoop` is destroyed. Be careful when using this: you need to make sure that the promise
// owns all the objects it touches or make sure those objects outlive the EventLoop. Also, be
// careful about error handling: exceptions will merely be logged with KJ_LOG(ERROR, ...).
//
// This method exists mainly to implement the Cap'n Proto requirement that RPC calls cannot be
// canceled unless the callee explicitly permits it.
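A hypothetical illustration of the comment above (not code from the commit): daemonize() is for promises that nothing will ever wait() on, such as the non-cancelable RPC calls mentioned.

// Illustrative usage of EventLoop::daemonize(); the helper name is made up.
void startBackgroundWork(kj::EventLoop& loop) {
  kj::Promise<void> op = loop.evalLater([]() {
    // ... do the actual work; everything used here must be owned by the promise ...
  });
  loop.daemonize(kj::mv(op));
  // If `op` throws, the exception is only reported via KJ_LOG(ERROR, ...).
}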
// -----------------------------------------------------------------
// Low-level interface.
@@ -387,6 +399,8 @@ private:
Maybe<_::WorkQueue<EventJob>::JobWrapper&> insertionPoint;
// Where to insert preemptively-scheduled events into the queue.
+ Own<_::TaskSetImpl> daemons;
template <typename T, typename Func, typename ErrorFunc>
Own<_::PromiseNode> thereImpl(Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const;
// Shared implementation of there() and Promise::then().
@@ -456,7 +470,7 @@ private:
friend class _::ChainPromiseNode;
template <typename>
friend class Promise;
- friend class TaskSet;
+ friend class _::TaskSetImpl;
};
template <typename T>
@@ -763,8 +777,7 @@ public:
void add(Promise<void>&& promise) const;
private:
- class Impl;
- Own<Impl> impl;
+ Own<_::TaskSetImpl> impl;
};
constexpr _::Void READY_NOW = _::Void();
...