Commit e2905da7 authored by Kenton Varda's avatar Kenton Varda

Test and fix embargoes.

parent 6b8b8c71
......@@ -207,7 +207,13 @@ class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
public:
QueuedPipeline(const kj::EventLoop& loop, kj::Promise<kj::Own<const PipelineHook>>&& promise)
: loop(loop),
promise(loop.fork(kj::mv(promise))) {}
promise(loop.fork(kj::mv(promise))),
selfResolutionOp(loop.there(this->promise.addBranch(),
[this](kj::Own<const PipelineHook>&& inner) {
*redirect.lockExclusive() = kj::mv(inner);
})) {
selfResolutionOp.eagerlyEvaluate(loop);
}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
......@@ -226,6 +232,12 @@ public:
private:
const kj::EventLoop& loop;
kj::ForkedPromise<kj::Own<const PipelineHook>> promise;
kj::MutexGuarded<kj::Maybe<kj::Own<const PipelineHook>>> redirect;
// Once the promise resolves, this will become non-null and point to the underlying object.
kj::Promise<void> selfResolutionOp;
// Represents the operation which will set `redirect` when possible.
};
class QueuedClient final: public ClientHook, public kj::Refcounted {
......@@ -371,12 +383,18 @@ private:
};
kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
  // Resolve a pipelined capability path against this queued pipeline.
  //
  // If the underlying PipelineHook has already resolved (`selfResolutionOp`
  // has filled in `redirect`), delegate directly to it.  Otherwise return a
  // QueuedClient wrapping a promise for the capability, fulfilled once the
  // pipeline promise resolves.
  auto lock = redirect.lockShared();

  KJ_IF_MAYBE(redirect, *lock) {
    // Already resolved: forward to the real pipeline immediately.
    return redirect->get()->getPipelinedCap(kj::mv(ops));
  } else {
    // Not yet resolved: queue the request.  `ops` is move-captured into the
    // continuation so the path survives until the promise resolves.
    auto clientPromise = loop.there(promise.addBranch(), kj::mvCapture(ops,
        [](kj::Array<PipelineOp>&& ops, kj::Own<const PipelineHook> pipeline) {
          return pipeline->getPipelinedCap(kj::mv(ops));
        }));

    return kj::refcounted<QueuedClient>(loop, kj::mv(clientPromise));
  }
}
// =======================================================================================
......
......@@ -24,7 +24,9 @@
#include "rpc.h"
#include "capability-context.h"
#include "test-util.h"
#include "schema.h"
#include <kj/debug.h>
#include <kj/string-tree.h>
#include <gtest/gtest.h>
#include <capnp/rpc.capnp.h>
#include <map>
......@@ -34,10 +36,148 @@ namespace capnp {
namespace _ { // private
namespace {
class RpcDumper {
  // Class which stringifies RPC messages for debugging purposes, including decoding params and
  // results based on the call's interface and method IDs and extracting cap descriptors.

public:
  void addSchema(InterfaceSchema schema) {
    // Register an interface schema so calls on it can be decoded by type ID.
    schemas[schema.getProto().getId()] = schema;
  }

  enum Sender {
    // Which side of the connection produced a message.  Question IDs are
    // per-direction, so `returnTypes` is keyed by (sender, question ID).
    CLIENT,
    SERVER
  };

  kj::String dump(rpc::Message::Reader message, Sender sender) {
    // Render one RPC message as a human-readable line.  Message types we don't
    // specially decode fall through to the generic stringification at the end.
    const char* senderName = sender == CLIENT ? "client" : "server";

    switch (message.which()) {
      case rpc::Message::CALL: {
        auto call = message.getCall();
        auto iter = schemas.find(call.getInterfaceId());
        if (iter == schemas.end()) {
          // Unknown interface; use the default dump below.
          break;
        }
        InterfaceSchema schema = iter->second;
        auto methods = schema.getMethods();
        if (call.getMethodId() >= methods.size()) {
          break;
        }
        InterfaceSchema::Method method = methods[call.getMethodId()];

        auto schemaProto = schema.getProto();
        auto interfaceName =
            schemaProto.getDisplayName().slice(schemaProto.getDisplayNamePrefixLength());

        auto methodProto = method.getProto();
        auto paramType = schema.getDependency(methodProto.getParamStructType()).asStruct();
        auto resultType = schema.getDependency(methodProto.getResultStructType()).asStruct();

        // Remember the result type so the matching RETURN can be decoded later.
        returnTypes[std::make_pair(sender, call.getQuestionId())] = resultType;

        CapExtractorImpl extractor;
        CapReaderContext context(extractor);

        auto params = kj::str(context.imbue(call.getParams()).getAs<DynamicStruct>(paramType));

        auto sendResultsTo = call.getSendResultsTo();

        return kj::str(senderName, "(", call.getQuestionId(), "): call ",
                       call.getTarget(), " <- ", interfaceName, ".",
                       methodProto.getName(), params,
                       " caps:[", extractor.printCaps(), "]",
                       sendResultsTo.isCaller() ? kj::str()
                                                : kj::str(" sendResultsTo:", sendResultsTo));
      }

      case rpc::Message::RETURN: {
        auto ret = message.getReturn();

        // Look up the result type recorded when the *opposite* side sent the
        // call with this question ID.
        auto iter = returnTypes.find(
            std::make_pair(sender == CLIENT ? SERVER : CLIENT, ret.getQuestionId()));
        if (iter == returnTypes.end()) {
          break;
        }
        auto schema = iter->second;
        returnTypes.erase(iter);
        if (ret.which() != rpc::Return::RESULTS) {
          // Oops, no results returned. We don't check this earlier because we want to make sure
          // returnTypes.erase() gets a chance to happen.
          break;
        }

        CapExtractorImpl extractor;
        CapReaderContext context(extractor);
        auto imbued = context.imbue(ret.getResults());

        if (schema.getProto().isStruct()) {
          auto results = kj::str(imbued.getAs<DynamicStruct>(schema.asStruct()));
          return kj::str(senderName, "(", ret.getQuestionId(), "): return ", results,
                         " caps:[", extractor.printCaps(), "]");
        } else if (schema.getProto().isInterface()) {
          // Interface result: only the extracted cap descriptors are printed;
          // the getAs<> call runs the extractor as a side effect.
          imbued.getAs<DynamicCapability>(schema.asInterface());
          return kj::str(senderName, "(", ret.getQuestionId(), "): return cap ",
                         extractor.printCaps());
        } else {
          break;
        }
      }

      case rpc::Message::RESTORE: {
        auto restore = message.getRestore();

        // Record a placeholder schema so the eventual RETURN is at least
        // recognized (and erased from the table).
        returnTypes[std::make_pair(sender, restore.getQuestionId())] = InterfaceSchema();

        return kj::str(senderName, "(", restore.getQuestionId(), "): restore ",
                       restore.getObjectId().getAs<test::TestSturdyRefObjectId>());
      }

      default:
        break;
    }

    return kj::str(senderName, ": ", message);
  }

private:
  std::map<uint64_t, InterfaceSchema> schemas;
  // Registered interface schemas, keyed by type ID.

  std::map<std::pair<Sender, uint32_t>, Schema> returnTypes;
  // Result types of outstanding questions, keyed by (asking side, question ID).

  class CapExtractorImpl: public CapExtractor<rpc::CapDescriptor> {
    // Records a textual form of every cap descriptor seen, substituting a
    // broken placeholder cap since only the dump text matters here.
  public:
    kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const {
      caps.add(kj::str(descriptor));
      return newBrokenCap("fake cap");
    }

    kj::String printCaps() {
      return kj::strArray(caps, ", ");
    }

  private:
    mutable kj::Vector<kj::String> caps;
  };
};
// =======================================================================================
class TestNetworkAdapter;
class TestNetwork {
public:
TestNetwork(kj::EventLoop& loop): loop(loop) {
dumper.addSchema(Schema::from<test::TestInterface>());
dumper.addSchema(Schema::from<test::TestExtends>());
dumper.addSchema(Schema::from<test::TestPipeline>());
dumper.addSchema(Schema::from<test::TestCallOrder>());
dumper.addSchema(Schema::from<test::TestTailCallee>());
dumper.addSchema(Schema::from<test::TestTailCaller>());
dumper.addSchema(Schema::from<test::TestMoreStuff>());
}
~TestNetwork() noexcept(false);
TestNetworkAdapter& add(kj::StringPtr name);
......@@ -51,7 +191,10 @@ public:
}
}
RpcDumper dumper;
private:
kj::EventLoop& loop;
std::map<kj::StringPtr, kj::Own<TestNetworkAdapter>> map;
};
......@@ -61,16 +204,18 @@ typedef VatNetwork<
class TestNetworkAdapter final: public TestNetworkAdapterBase {
public:
TestNetworkAdapter(TestNetwork& network): network(network) {}
TestNetworkAdapter(kj::EventLoop& loop, TestNetwork& network): loop(loop), network(network) {}
uint getSentCount() { return sent; }
uint getReceivedCount() { return received; }
typedef TestNetworkAdapterBase::Connection Connection;
class ConnectionImpl final: public Connection, public kj::Refcounted {
class ConnectionImpl final
: public Connection, public kj::Refcounted, public kj::TaskSet::ErrorHandler {
public:
ConnectionImpl(TestNetworkAdapter& network, const char* name): network(network), name(name) {}
ConnectionImpl(TestNetworkAdapter& network, RpcDumper::Sender sender)
: network(network), sender(sender), tasks(network.loop, *this) {}
void attach(ConnectionImpl& other) {
KJ_REQUIRE(partner == nullptr);
......@@ -79,7 +224,7 @@ public:
other.partner = *this;
}
class IncomingRpcMessageImpl final: public IncomingRpcMessage {
class IncomingRpcMessageImpl final: public IncomingRpcMessage, public kj::Refcounted {
public:
IncomingRpcMessageImpl(uint firstSegmentWordSize)
: message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS
......@@ -96,7 +241,7 @@ public:
public:
OutgoingRpcMessageImpl(const ConnectionImpl& connection, uint firstSegmentWordSize)
: connection(connection),
message(kj::heap<IncomingRpcMessageImpl>(firstSegmentWordSize)) {}
message(kj::refcounted<IncomingRpcMessageImpl>(firstSegmentWordSize)) {}
ObjectPointer::Builder getBody() override {
return message->message.getRoot<ObjectPointer>();
......@@ -104,19 +249,26 @@ public:
void send() override {
++connection.network.sent;
kj::String msg = kj::str(connection.name, ": ", message->message.getRoot<rpc::Message>());
//KJ_DBG(msg);
// Uncomment to get a debug dump.
// kj::String msg = connection.network.network.dumper.dump(
// message->message.getRoot<rpc::Message>(), connection.sender);
// KJ_ DBG(msg);
KJ_IF_MAYBE(p, connection.partner) {
auto connectionPtr = &connection;
connection.tasks.add(connection.network.loop.evalLater(
kj::mvCapture(kj::addRef(*message),
[connectionPtr](kj::Own<IncomingRpcMessageImpl>&& message) {
KJ_IF_MAYBE(p, connectionPtr->partner) {
auto lock = p->queues.lockExclusive();
if (lock->fulfillers.empty()) {
lock->messages.push(kj::mv(message));
} else {
++connection.network.received;
++connectionPtr->network.received;
lock->fulfillers.front()->fulfill(kj::Own<IncomingRpcMessage>(kj::mv(message)));
lock->fulfillers.pop();
}
}
})));
}
private:
......@@ -154,9 +306,13 @@ public:
KJ_FAIL_ASSERT("not implemented");
}
void taskFailed(kj::Exception&& exception) override {
ADD_FAILURE() << kj::str(exception).cStr();
}
private:
TestNetworkAdapter& network;
const char* name;
RpcDumper::Sender sender;
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
......@@ -164,6 +320,8 @@ public:
std::queue<kj::Own<IncomingRpcMessage>> messages;
};
kj::MutexGuarded<Queues> queues;
kj::TaskSet tasks;
};
kj::Maybe<kj::Own<Connection>> connectToRefHost(
......@@ -183,8 +341,8 @@ public:
auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) {
auto local = kj::refcounted<ConnectionImpl>(*this, "client");
auto remote = kj::refcounted<ConnectionImpl>(dst, "server");
auto local = kj::refcounted<ConnectionImpl>(*this, RpcDumper::CLIENT);
auto remote = kj::refcounted<ConnectionImpl>(dst, RpcDumper::SERVER);
local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local);
......@@ -217,6 +375,7 @@ public:
}
private:
kj::EventLoop& loop;
TestNetwork& network;
uint sent = 0;
uint received = 0;
......@@ -232,7 +391,7 @@ private:
TestNetwork::~TestNetwork() noexcept(false) {}
TestNetworkAdapter& TestNetwork::add(kj::StringPtr name) {
return *(map[name] = kj::heap<TestNetworkAdapter>(*this));
return *(map[name] = kj::heap<TestNetworkAdapter>(loop, *this));
}
// =======================================================================================
......@@ -262,9 +421,9 @@ public:
class RpcTest: public testing::Test {
protected:
kj::SimpleEventLoop loop;
TestNetwork network;
TestRestorer restorer;
kj::SimpleEventLoop loop;
TestNetworkAdapter& clientNetwork;
TestNetworkAdapter& serverNetwork;
RpcSystem<test::TestSturdyRefHostId> rpcClient;
......@@ -281,7 +440,8 @@ protected:
}
RpcTest()
: clientNetwork(network.add("client")),
: network(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, loop)),
rpcServer(makeRpcServer(serverNetwork, restorer, loop)) {}
......@@ -587,6 +747,48 @@ TEST_F(RpcTest, SendTwice) {
EXPECT_TRUE(destroyed);
}
RemotePromise<test::TestCallOrder::GetCallSequenceResults> getCallSequence(
    const test::TestCallOrder::Client& client, uint expected) {
  // Dispatch a getCallSequence() RPC.  The `expected` value is carried only to
  // make debug dumps of the message stream easier to correlate.
  auto request = client.getCallSequenceRequest();
  request.setExpected(expected);
  return request.send();
}
TEST_F(RpcTest, Embargo) {
  // Verifies call ordering across promise resolution: calls made on a promised
  // capability before and after it resolves must arrive in the order issued,
  // which requires the RPC system's embargo mechanism to hold back direct
  // calls until pipelined ones have been delivered.
  auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
      .castAs<test::TestMoreStuff>();

  // A locally-hosted cap that records the order in which calls reach it.
  auto cap = test::TestCallOrder::Client(kj::heap<TestCallOrderImpl>(), loop);

  auto earlyCall = client.getCallSequenceRequest().send();

  // Send our local cap to the server, which echoes it back (see
  // TestMoreStuffImpl::echo), so `pipeline` will eventually resolve to `cap`.
  auto echoRequest = client.echoRequest();
  echoRequest.setCap(cap);
  auto echo = echoRequest.send();

  auto pipeline = echo.getCap();

  // Pipelined calls issued before the echo resolves.
  auto call0 = getCallSequence(pipeline, 0);
  auto call1 = getCallSequence(pipeline, 1);

  loop.wait(kj::mv(earlyCall));

  auto call2 = getCallSequence(pipeline, 2);

  auto resolved = loop.wait(kj::mv(echo)).getCap();

  // Calls issued after resolution; without an embargo these could shortcut
  // directly to the local cap and overtake calls 0-2.
  auto call3 = getCallSequence(pipeline, 3);
  auto call4 = getCallSequence(pipeline, 4);
  auto call5 = getCallSequence(pipeline, 5);

  // The sequence numbers must come back strictly in issue order.
  EXPECT_EQ(0, loop.wait(kj::mv(call0)).getN());
  EXPECT_EQ(1, loop.wait(kj::mv(call1)).getN());
  EXPECT_EQ(2, loop.wait(kj::mv(call2)).getN());
  EXPECT_EQ(3, loop.wait(kj::mv(call3)).getN());
  EXPECT_EQ(4, loop.wait(kj::mv(call4)).getN());
  EXPECT_EQ(5, loop.wait(kj::mv(call5)).getN());
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -298,6 +298,7 @@ public:
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
kj::Vector<kj::Own<CapInjectorImpl>> paramCapsToRelease;
kj::Vector<kj::Promise<kj::Own<const RpcResponse>>> promisesToRelease;
auto lock = tables.lockExclusive();
......@@ -330,9 +331,7 @@ public:
}
KJ_IF_MAYBE(promise, answer.redirectedResults) {
// Answer contains a result redirection that hasn't been picked up yet. Make the call
// properly cancelable by transforming the redirect promise into a regular asyncOp.
answer.asyncOp = promise->thenInAnyThread([](kj::Own<const RpcResponse>&& response) {});
promisesToRelease.add(kj::mv(*promise));
}
KJ_IF_MAYBE(context, answer.callContext) {
......@@ -413,9 +412,6 @@ private:
kj::Maybe<kj::Own<const PipelineHook>> pipeline;
// Send pipelined calls here. Becomes null as soon as a `Finish` is received.
kj::Promise<void> asyncOp = kj::Promise<void>(nullptr);
// Delete this promise to cancel the call. For redirected calls, this is null.
kj::Maybe<kj::Promise<kj::Own<const RpcResponse>>> redirectedResults;
// For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call
// result, to be picked up by a subsequent `Return`.
......@@ -885,6 +881,10 @@ private:
// the `PromiseClient` is destroyed; `eventual` must therefore make sure to hold references to
// anything that needs to stay alive in order to resolve it correctly (such as making sure the
// import ID is not released).
resolveSelfPromise = connectionState.eventLoop.there(kj::mv(resolveSelfPromise),
[]() {}, [&](kj::Exception&& e) { connectionState.tasks.add(kj::mv(e)); });
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
}
......@@ -907,6 +907,7 @@ private:
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->writeDescriptor(*inner.lockExclusive()->cap, descriptor, tables);
}
......@@ -917,6 +918,7 @@ private:
}
kj::Own<const ClientHook> getInnermostClient() const override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return connectionState->getInnermostClient(*inner.lockExclusive()->cap);
}
......@@ -924,11 +926,13 @@ private:
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->newCall(interfaceId, methodId, firstSegmentWordSize);
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
__atomic_store_n(&receivedCall, true, __ATOMIC_RELAXED);
return inner.lockExclusive()->cap->call(interfaceId, methodId, kj::mv(context));
}
......@@ -962,7 +966,7 @@ private:
mutable bool receivedCall = false;
void resolve(kj::Own<const ClientHook> replacement) {
if (replacement->getBrand() != this &&
if (replacement->getBrand() != connectionState.get() &&
__atomic_load_n(&receivedCall, __ATOMIC_RELAXED)) {
// The new capability is hosted locally, not on the remote machine. And, we had made calls
// to the promise. We need to make sure those calls echo back to us before we allow new
......@@ -972,7 +976,7 @@ private:
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Disembargo>() + MESSAGE_TARGET_SIZE_HINT);
auto disembargo = message->getBody().initAs<rpc::Message>().getDisembargo();
auto disembargo = message->getBody().initAs<rpc::Message>().initDisembargo();
{
auto redirect = connectionState->writeTarget(
......@@ -1511,6 +1515,10 @@ private:
resultCaps(connectionState, kj::mv(resolutionChain)) {}
~QuestionRef() {
if (connectionState->tables.lockShared()->networkException != nullptr) {
return;
}
// Send the "Finish" message.
{
auto message = connectionState->connection->newOutgoingMessage(
......@@ -1922,7 +1930,8 @@ private:
public:
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params,
kj::Own<const ResolutionChain> resolutionChain, bool redirectResults)
kj::Own<const ResolutionChain> resolutionChain, bool redirectResults,
kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller)
: connectionState(kj::addRef(connectionState)),
questionId(questionId),
request(kj::mv(request)),
......@@ -1930,7 +1939,8 @@ private:
requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)),
returnMessage(nullptr),
redirectResults(redirectResults) {}
redirectResults(redirectResults),
cancelFulfiller(kj::mv(cancelFulfiller)) {}
~RpcCallContext() noexcept(false) {
if (isFirstResponder()) {
......@@ -2019,9 +2029,9 @@ private:
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
CANCEL_ALLOWED) {
// We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Schedule
// We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate
// the cancellation.
scheduleCancel();
cancelFulfiller->fulfill();
}
}
......@@ -2113,7 +2123,6 @@ private:
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
if (threadAcceptingCancellation == nullptr) {
// TODO(cleanup): We need to drop the request because it is holding on to the resolution
// chain which in turn holds on to the pipeline which holds on to this object thus
// preventing cancellation from working. This is a bit silly because obviously our
......@@ -2124,14 +2133,11 @@ private:
// at creation.
releaseParams();
threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
CANCEL_REQUESTED) {
// We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Schedule
// We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Initiate
// the cancellation.
scheduleCancel();
}
cancelFulfiller->fulfill();
}
}
bool isCanceled() override {
......@@ -2171,47 +2177,15 @@ private:
// When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads.
kj::EventLoop* threadAcceptingCancellation = nullptr;
// EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
// optimization: if the application thread is independent from the network thread, we'd rather
// perform the cancellation in the application thread, because otherwise we might block waiting
// on an application promise continuation callback to finish executing, which could take
// arbitrary time.
mutable kj::Own<kj::PromiseFulfiller<void>> cancelFulfiller;
// Fulfilled when cancellation has been both requested and permitted. The fulfilled promise is
// exclusive-joined with the outermost promise waiting on the call return, so fulfilling it
// cancels that promise.
kj::UnwindDetector unwindDetector;
// -----------------------------------------------------
void scheduleCancel() const {
// Arranges for the answer's asyncOp to be deleted, thus canceling all processing related to
// this call, shortly. We have to do it asynchronously because the caller might hold
// arbitrary locks or might in fact be part of the task being canceled.
connectionState->tasks.add(threadAcceptingCancellation->evalLater(
kj::mvCapture(kj::addRef(*this), [](kj::Own<const RpcCallContext>&& self) {
// Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr;
{
auto lock = self->connectionState->tables.lockExclusive();
asyncOp = kj::mv(lock->answers[self->questionId].asyncOp);
}
// When `asyncOp` goes out of scope, if it holds the last reference to the ongoing
// operation, that operation will be canceled. Note that if a continuation is
// running in another thread, the destructor will block waiting for it to complete. This
// is why we try to schedule doCancel() on the application thread, so that it won't need
// to block.
// The `Return` will be sent when the context is destroyed. That might be right now, when
// `self` and `asyncOp` go out of scope. However, it is also possible that the pipeline
// is still in use: although `Finish` removes the pipeline reference from the answer
// table, it might be held by an outstanding pipelined call, or by a pipelined promise that
// was echoed back to us later (as a `receiverAnswer` in a `CapDescriptor`), or it may be
// held in the resolution chain. In all of these cases, the call will continue running
// until those references are dropped or the call completes.
})));
}
bool isFirstResponder() {
if (responseSent) {
return false;
......@@ -2402,13 +2376,15 @@ private:
KJ_FAIL_REQUIRE("Unsupported `Call.sendResultsTo`.") { return; }
}
auto cancelPaf = kj::newPromiseAndFulfiller<void>();
QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail),
redirectResults);
redirectResults, kj::mv(cancelPaf.fulfiller));
auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef());
......@@ -2432,18 +2408,27 @@ private:
answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) {
auto promise = promiseAndPipeline.promise.then(
auto resultsPromise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse();
}));
promise.eagerlyEvaluate(eventLoop);
answer.redirectedResults = kj::mv(promise);
// If the call that later picks up `redirectedResults` decides to discard it, we need to
// make sure our call is not itself canceled unless it has called allowAsyncCancellation().
// So we fork the promise and join one branch with the cancellation promise, in order to
// hold on to it.
auto forked = eventLoop.fork(kj::mv(resultsPromise));
answer.redirectedResults = forked.addBranch();
auto promise = kj::mv(cancelPaf.promise);
promise.exclusiveJoin(forked.addBranch().then([](kj::Own<const RpcResponse>&&){}));
eventLoop.daemonize(kj::mv(promise));
} else {
// Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context;
answer.asyncOp = promiseAndPipeline.promise.then(
auto promise = promiseAndPipeline.promise.then(
[contextPtr]() {
contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) {
......@@ -2452,8 +2437,9 @@ private:
// Handle exceptions that occur in sendReturn()/sendErrorReturn().
taskFailed(kj::mv(exception));
});
answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop);
promise.attach(kj::mv(context));
promise.exclusiveJoin(kj::mv(cancelPaf.promise));
eventLoop.daemonize(kj::mv(promise));
}
}
}
......@@ -2729,8 +2715,6 @@ private:
}
void handleDisembargo(const rpc::Disembargo::Reader& disembargo) {
auto lock = tables.lockExclusive();
auto context = disembargo.getContext();
switch (context.which()) {
case rpc::Disembargo::Context::SENDER_LOOPBACK: {
......@@ -2757,6 +2741,12 @@ private:
return;
}
EmbargoId embargoId = context.getSenderLoopback();
// We need to insert an evalLater() here to make sure that any pending calls towards this
// cap have had time to find their way through the event loop.
tasks.add(eventLoop.evalLater(kj::mvCapture(
target, [this,embargoId](kj::Own<const ClientHook>&& target) {
const RpcClient& downcasted = kj::downcast<const RpcClient>(*target);
auto message = connection->newOutgoingMessage(
......@@ -2771,20 +2761,22 @@ private:
// a PromiseClient. The code which sends `Resolve` should have replaced any promise
// with a direct node in order to solve the Tribble 4-way race condition.
KJ_REQUIRE(redirect == nullptr,
"'Disembargo' of type 'senderLoopback' sent to an object that does not appear "
"to have been the object of a previous 'Resolve' message.") {
"'Disembargo' of type 'senderLoopback' sent to an object that does not "
"appear to have been the object of a previous 'Resolve' message.") {
return;
}
}
builder.getContext().setReceiverLoopback(context.getSenderLoopback());
builder.getContext().setReceiverLoopback(embargoId);
message->send();
})));
break;
}
case rpc::Disembargo::Context::RECEIVER_LOOPBACK:
case rpc::Disembargo::Context::RECEIVER_LOOPBACK: {
auto lock = tables.lockExclusive();
KJ_IF_MAYBE(embargo, lock->embargoes.find(context.getReceiverLoopback())) {
KJ_ASSERT_NONNULL(embargo->fulfiller)->fulfill();
lock->embargoes.erase(context.getReceiverLoopback());
......@@ -2794,6 +2786,7 @@ private:
}
}
break;
}
default:
KJ_FAIL_REQUIRE("Unimplemented Disembargo type.", disembargo) { return; }
......
......@@ -1052,5 +1052,11 @@ kj::Promise<void> TestMoreStuffImpl::getHeld(
return kj::READY_NOW;
}
kj::Promise<void> TestMoreStuffImpl::echo(EchoParams::Reader params, EchoResults::Builder result) {
  // Return the caller's capability unchanged, bumping the call counter.
  result.setCap(params.getCap());
  ++callCount;
  return kj::READY_NOW;
}
} // namespace _ (private)
} // namespace capnp
......@@ -237,6 +237,8 @@ public:
kj::Promise<void> getHeld(GetHeldParams::Reader params,
GetHeldResults::Builder result) override;
kj::Promise<void> echo(EchoParams::Reader params, EchoResults::Builder result) override;
private:
int& callCount;
kj::Own<kj::PromiseFulfiller<void>> neverFulfill;
......
......@@ -611,8 +611,10 @@ interface TestPipeline {
}
interface TestCallOrder {
getCallSequence @0 () -> (n: UInt32);
getCallSequence @0 (expected: UInt32) -> (n: UInt32);
# First call returns 0, next returns 1, ...
#
# The input `expected` is ignored but useful for disambiguating debug logs.
}
interface TestTailCallee {
......@@ -649,6 +651,9 @@ interface TestMoreStuff extends(TestCallOrder) {
getHeld @5 () -> (cap :TestInterface);
# Returns the capability previously held using `hold` (and keeps holding it).
echo @6 (cap :TestCallOrder) -> (cap :TestCallOrder);
# Just returns the input cap.
}
struct TestSturdyRefHostId {
......
......@@ -69,6 +69,91 @@ public:
} // namespace
namespace _ { // private
class TaskSetImpl {
  // Shared implementation behind TaskSet and EventLoop::daemonize(): owns a
  // set of fire-and-forget promises, reporting any failure to an ErrorHandler.
public:
  inline TaskSetImpl(const EventLoop& loop, TaskSet::ErrorHandler& errorHandler)
      : loop(loop), errorHandler(errorHandler) {}

  ~TaskSetImpl() noexcept(false) {
    // std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
    auto& taskMap = tasks.getWithoutLock();
    if (!taskMap.empty()) {
      Vector<Own<Task>> deleteMe(taskMap.size());
      for (auto& entry: taskMap) {
        deleteMe.add(kj::mv(entry.second));
      }
    }
  }

  class Task final: public EventLoop::Event {
    // One owned task: an event that fires when its promise node is ready.
  public:
    Task(const TaskSetImpl& taskSet, Own<_::PromiseNode>&& nodeParam)
        : EventLoop::Event(taskSet.loop), taskSet(taskSet), node(kj::mv(nodeParam)) {
      // If the node is already ready, schedule ourselves to fire immediately.
      if (node->onReady(*this)) {
        arm();
      }
    }

    ~Task() {
      disarm();
    }

  protected:
    void fire() override {
      // Get the result.
      _::ExceptionOr<_::Void> result;
      node->get(result);

      // Delete the node, catching any exceptions.
      KJ_IF_MAYBE(exception, kj::runCatchingExceptions([this]() {
        node = nullptr;
      })) {
        result.addException(kj::mv(*exception));
      }

      // Call the error handler if there was an exception.
      KJ_IF_MAYBE(e, result.exception) {
        taskSet.errorHandler.taskFailed(kj::mv(*e));
      }
      // NOTE(review): fire() does not remove this Task from `tasks`, so the
      // map entry (though with a null node) lives until the TaskSetImpl is
      // destroyed — confirm this is intended.
    }

  private:
    const TaskSetImpl& taskSet;
    kj::Own<_::PromiseNode> node;
  };

  void add(Promise<void>&& promise) const {
    // Wrap the promise's node so it is safe to run on `loop`, then take
    // ownership of it in the task map (keyed by raw pointer).
    auto task = heap<Task>(*this, _::makeSafeForLoop<_::Void>(kj::mv(promise.node), loop));
    Task* ptr = task;
    tasks.lockExclusive()->insert(std::make_pair(ptr, kj::mv(task)));
  }

private:
  const EventLoop& loop;
  TaskSet::ErrorHandler& errorHandler;

  // TODO(soon): Use a linked list instead. We should factor out the intrusive linked list code
  // that appears in EventLoop and ForkHub.
  MutexGuarded<std::map<Task*, Own<Task>>> tasks;
};
class LoggingErrorHandler: public TaskSet::ErrorHandler {
  // Error handler used for daemonized tasks: failures are logged and dropped,
  // since there is no caller left to propagate them to.
public:
  static LoggingErrorHandler instance;

  void taskFailed(kj::Exception&& exception) override {
    KJ_LOG(ERROR, "Uncaught exception in daemonized task.", exception);
  }
};

LoggingErrorHandler LoggingErrorHandler::instance = LoggingErrorHandler();
} // namespace _ (private)
// =======================================================================================
EventLoop& EventLoop::current() {
EventLoop* result = threadLocalEventLoop;
KJ_REQUIRE(result != nullptr, "No event loop is running on this thread.");
......@@ -79,7 +164,10 @@ bool EventLoop::isCurrent() const {
return threadLocalEventLoop == this;
}
EventLoop::EventLoop() {}
// `daemons` holds promises handed to daemonize(); their failures are merely
// logged via LoggingErrorHandler rather than propagated.
EventLoop::EventLoop()
    : daemons(kj::heap<_::TaskSetImpl>(*this, _::LoggingErrorHandler::instance)) {}

EventLoop::~EventLoop() noexcept(false) {}
void EventLoop::waitImpl(Own<_::PromiseNode> node, _::ExceptionOrValue& result) {
EventLoop* oldEventLoop = threadLocalEventLoop;
......@@ -119,6 +207,10 @@ void EventLoop::receivedNewJob() const {
wake();
}
void EventLoop::daemonize(kj::Promise<void>&& promise) const {
  // Detach `promise` to run in the background for the lifetime of this
  // EventLoop; any exception is reported through the daemons task set's
  // error handler (logged) instead of being propagated.
  daemons->add(kj::mv(promise));
}
EventLoop::Event::Event(const EventLoop& loop)
: loop(loop),
jobs { loop.queue.createJob(*this), loop.queue.createJob(*this) } {}
......@@ -227,76 +319,8 @@ void PromiseBase::absolve() {
runCatchingExceptions([this]() { node = nullptr; });
}
class TaskSet::Impl {
  // Pre-refactor implementation of TaskSet (this diff replaces it with the
  // shared _::TaskSetImpl above): owns a set of fire-and-forget promises and
  // reports failures to an ErrorHandler.
public:
  inline Impl(const EventLoop& loop, ErrorHandler& errorHandler)
      : loop(loop), errorHandler(errorHandler) {}

  ~Impl() noexcept(false) {
    // std::map doesn't like it when elements' destructors throw, so carefully disassemble it.
    auto& taskMap = tasks.getWithoutLock();
    if (!taskMap.empty()) {
      Vector<Own<Task>> deleteMe(taskMap.size());
      for (auto& entry: taskMap) {
        deleteMe.add(kj::mv(entry.second));
      }
    }
  }

  class Task final: public EventLoop::Event {
    // One owned task: an event that fires when its promise node is ready.
  public:
    Task(const Impl& taskSet, Own<_::PromiseNode>&& nodeParam)
        : EventLoop::Event(taskSet.loop), taskSet(taskSet), node(kj::mv(nodeParam)) {
      // If the node is already ready, schedule ourselves to fire immediately.
      if (node->onReady(*this)) {
        arm();
      }
    }

    ~Task() {
      disarm();
    }

  protected:
    void fire() override {
      // Get the result.
      _::ExceptionOr<_::Void> result;
      node->get(result);

      // Delete the node, catching any exceptions.
      KJ_IF_MAYBE(exception, runCatchingExceptions([this]() {
        node = nullptr;
      })) {
        result.addException(kj::mv(*exception));
      }

      // Call the error handler if there was an exception.
      KJ_IF_MAYBE(e, result.exception) {
        taskSet.errorHandler.taskFailed(kj::mv(*e));
      }
    }

  private:
    const Impl& taskSet;
    kj::Own<_::PromiseNode> node;
  };

  void add(Promise<void>&& promise) const {
    // Wrap the promise's node so it is safe to run on `loop`, then take
    // ownership of it in the task map (keyed by raw pointer).
    auto task = heap<Task>(*this, _::makeSafeForLoop<_::Void>(kj::mv(promise.node), loop));
    Task* ptr = task;
    tasks.lockExclusive()->insert(std::make_pair(ptr, kj::mv(task)));
  }

private:
  const EventLoop& loop;
  ErrorHandler& errorHandler;

  // TODO(soon): Use a linked list instead. We should factor out the intrusive linked list code
  // that appears in EventLoop and ForkHub.
  MutexGuarded<std::map<Task*, Own<Task>>> tasks;
};
TaskSet::TaskSet(const EventLoop& loop, ErrorHandler& errorHandler)
: impl(heap<Impl>(loop, errorHandler)) {}
: impl(heap<_::TaskSetImpl>(loop, errorHandler)) {}
TaskSet::~TaskSet() noexcept(false) {}
......
......@@ -182,6 +182,8 @@ class ChainPromiseNode;
template <typename T>
class ForkHub;
class TaskSetImpl;
} // namespace _ (private)
// =======================================================================================
......@@ -228,6 +230,7 @@ class EventLoop: private _::NewJobCallback {
public:
EventLoop();
~EventLoop() noexcept(false);
static EventLoop& current();
// Get the event loop for the current thread. Throws an exception if no event loop is active.
......@@ -298,6 +301,15 @@ public:
Promise<T> exclusiveJoin(Promise<T>&& promise1, Promise<T>&& promise2) const;
// Like `promise1.exclusiveJoin(promise2)`, returning the joined promise.
void daemonize(kj::Promise<void>&& promise) const;
// Allows the given promise to continue running in the background until it completes or the
// `EventLoop` is destroyed. Be careful when using this: you need to make sure that the promise
// owns all the objects it touches or make sure those objects outlive the EventLoop. Also, be
// careful about error handling: exceptions will merely be logged with KJ_LOG(ERROR, ...).
//
// This method exists mainly to implement the Cap'n Proto requirement that RPC calls cannot be
// canceled unless the callee explicitly permits it.
// -----------------------------------------------------------------
// Low-level interface.
......@@ -387,6 +399,8 @@ private:
Maybe<_::WorkQueue<EventJob>::JobWrapper&> insertionPoint;
// Where to insert preemptively-scheduled events into the queue.
Own<_::TaskSetImpl> daemons;
template <typename T, typename Func, typename ErrorFunc>
Own<_::PromiseNode> thereImpl(Promise<T>&& promise, Func&& func, ErrorFunc&& errorHandler) const;
// Shared implementation of there() and Promise::then().
......@@ -456,7 +470,7 @@ private:
friend class _::ChainPromiseNode;
template <typename>
friend class Promise;
friend class TaskSet;
friend class _::TaskSetImpl;
};
template <typename T>
......@@ -763,8 +777,7 @@ public:
void add(Promise<void>&& promise) const;
private:
class Impl;
Own<Impl> impl;
Own<_::TaskSetImpl> impl;
};
constexpr _::Void READY_NOW = _::Void();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment