Commit a30c4171 authored by Kenton Varda's avatar Kenton Varda

Finish other end of RPC tail call implementation.

parent 85a2ec20
......@@ -188,6 +188,7 @@ public:
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
inline ObjectPointer::Builder getRoot() { return root; }
inline ObjectPointer::Reader getRootReader() const { return root.asReader(); }
private:
MallocMessageBuilder message;
......
......@@ -256,7 +256,7 @@ public:
kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) {
QuestionId questionId;
kj::Own<QuestionRef> questionRef;
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>(eventLoop);
auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<const RpcResponse>>>(eventLoop);
{
auto lock = tables.lockExclusive();
......@@ -329,6 +329,12 @@ public:
pipelinesToRelease.add(kj::mv(*p));
}
KJ_IF_MAYBE(promise, answer.redirectedResults) {
// Answer contains a result redirection that hasn't been picked up yet. Make the call
// properly cancelable by transforming the redirect promise into a regular asyncOp.
answer.asyncOp = promise->thenInAnyThread([](kj::Own<const RpcResponse>&& response) {});
}
KJ_IF_MAYBE(context, answer.callContext) {
context->requestCancel();
}
......@@ -408,7 +414,11 @@ private:
// Send pipelined calls here. Becomes null as soon as a `Finish` is received.
kj::Promise<void> asyncOp = kj::Promise<void>(nullptr);
// Delete this promise to cancel the call.
// Delete this promise to cancel the call. For redirected calls, this is null.
kj::Maybe<kj::Promise<kj::Own<const RpcResponse>>> redirectedResults;
// For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call
// result, to be picked up by a subsequent `Return`.
kj::Maybe<const RpcCallContext&> callContext;
// The call context, if it's still active. Becomes null when the `Return` message is sent.
......@@ -1489,9 +1499,10 @@ private:
// can be sent.
public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id,
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller,
kj::Own<const ResolutionChain> resolutionChain)
inline QuestionRef(
const RpcConnectionState& connectionState, QuestionId id,
kj::Own<kj::PromiseFulfiller<kj::Promise<kj::Own<const RpcResponse>>>> fulfiller,
kj::Own<const ResolutionChain> resolutionChain)
: connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)),
resultCaps(connectionState, kj::mv(resolutionChain)) {}
......@@ -1532,6 +1543,10 @@ private:
fulfiller->fulfill(kj::mv(response));
}
void fulfill(kj::Promise<kj::Own<const RpcResponse>>&& promise) {
fulfiller->fulfill(kj::mv(promise));
}
void reject(kj::Exception&& exception) {
fulfiller->reject(kj::mv(exception));
}
......@@ -1539,7 +1554,7 @@ private:
private:
kj::Own<const RpcConnectionState> connectionState;
QuestionId id;
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller;
kj::Own<kj::PromiseFulfiller<kj::Promise<kj::Own<const RpcResponse>>>> fulfiller;
CapExtractorImpl resultCaps;
};
......@@ -1683,7 +1698,8 @@ private:
SendInternalResult sendInternal(bool isTailCall, Tables& lockedTables) {
injector->finishDescriptors(lockedTables);
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>();
auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<const RpcResponse>>>(
connectionState->eventLoop);
QuestionId questionId;
auto& question = lockedTables.questions.next(questionId);
......@@ -1806,23 +1822,29 @@ private:
}
};
class RpcResponse final: public ResponseHook, public kj::Refcounted {
class RpcResponse: public ResponseHook {
public:
virtual ObjectPointer::Reader getResults() const = 0;
virtual kj::Own<const RpcResponse> addRef() const = 0;
};
class RpcResponseImpl final: public RpcResponse, public kj::Refcounted {
public:
RpcResponse(const RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results)
RpcResponseImpl(const RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results)
: connectionState(kj::addRef(connectionState)),
message(kj::mv(message)),
context(questionRef->getCapExtractor()),
reader(context.imbue(results)),
questionRef(kj::mv(questionRef)) {}
ObjectPointer::Reader getResults() const {
ObjectPointer::Reader getResults() const override {
return reader;
}
kj::Own<const RpcResponse> addRef() const {
kj::Own<const RpcResponse> addRef() const override {
return kj::addRef(*this);
}
......@@ -1839,15 +1861,20 @@ private:
class RpcServerResponse {
public:
RpcServerResponse(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results)
virtual ObjectPointer::Builder getResultsBuilder() = 0;
};
class RpcServerResponseImpl final: public RpcServerResponse {
public:
RpcServerResponseImpl(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results)
: message(kj::mv(message)),
injector(kj::heap<CapInjectorImpl>(connectionState)),
context(*injector),
builder(context.imbue(results)) {}
ObjectPointer::Builder getResults() {
ObjectPointer::Builder getResultsBuilder() override {
return builder;
}
......@@ -1864,18 +1891,42 @@ private:
ObjectPointer::Builder builder;
};
class LocallyRedirectedRpcResponse final
: public RpcResponse, public RpcServerResponse, public kj::Refcounted{
public:
LocallyRedirectedRpcResponse(uint firstSegmentWordSize)
: message(firstSegmentWordSize == 0 ?
SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize + 1) {}
ObjectPointer::Builder getResultsBuilder() override {
return message.getRoot();
}
ObjectPointer::Reader getResults() const override {
return message.getRootReader();
}
kj::Own<const RpcResponse> addRef() const override {
return kj::addRef(*this);
}
private:
LocalMessage message;
};
class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public:
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params,
kj::Own<const ResolutionChain> resolutionChain)
kj::Own<const ResolutionChain> resolutionChain, bool redirectResults)
: connectionState(kj::addRef(connectionState)),
questionId(questionId),
request(kj::mv(request)),
requestCapExtractor(connectionState, kj::mv(resolutionChain)),
requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)),
returnMessage(nullptr) {}
returnMessage(nullptr),
redirectResults(redirectResults) {}
~RpcCallContext() noexcept(false) {
if (isFirstResponder()) {
......@@ -1889,7 +1940,13 @@ private:
auto retainedCaps = requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder));
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
builder.setCanceled();
if (redirectResults) {
// The reason we haven't sent a return is because the results were sent somewhere else.
builder.setResultsSentElsewhere();
} else {
builder.setCanceled();
}
message->send();
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr);
......@@ -1897,7 +1954,18 @@ private:
}
}
kj::Own<const RpcResponse> consumeRedirectedResponse() {
KJ_ASSERT(redirectResults);
if (response == nullptr) getResults(1); // force initialization of response
// Note that the context needs to keep its own reference to the response so that it doesn't
// get GC'd until the PipelineHook drops its reference to the context.
return kj::downcast<LocallyRedirectedRpcResponse>(*KJ_ASSERT_NONNULL(response)).addRef();
}
void sendReturn() {
KJ_ASSERT(!redirectResults);
if (isFirstResponder()) {
if (response == nullptr) getResults(1); // force initialization of response
......@@ -1909,10 +1977,12 @@ private:
kj::Own<const PipelineHook> pipelineToRelease;
auto lock = connectionState->tables.lockExclusive();
auto& tables = *lock;
cleanupAnswerTable(kj::mv(lock), KJ_ASSERT_NONNULL(response)->send(tables));
cleanupAnswerTable(kj::mv(lock),
kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send(tables));
}
}
void sendErrorReturn(kj::Exception&& exception) {
KJ_ASSERT(!redirectResults);
if (isFirstResponder()) {
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>() + requestCapExtractor.retainedListSizeHint(true) +
......@@ -1962,16 +2032,23 @@ private:
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) {
return r->get()->getResults();
return r->get()->getResultsBuilder();
} else {
auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcServerResponse>(
*connectionState, kj::mv(message), returnMessage.getResults());
auto results = response->getResults();
kj::Own<RpcServerResponse> response;
if (redirectResults) {
response = kj::refcounted<LocallyRedirectedRpcResponse>(firstSegmentWordSize);
} else {
auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
response = kj::heap<RpcServerResponseImpl>(
*connectionState, kj::mv(message), returnMessage.getResults());
}
auto results = response->getResultsBuilder();
this->response = kj::mv(response);
return results;
}
......@@ -2064,6 +2141,7 @@ private:
kj::Maybe<kj::Own<RpcServerResponse>> response;
rpc::Return::Builder returnMessage;
bool redirectResults = false;
bool responseSent = false;
kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller;
......@@ -2303,12 +2381,25 @@ private:
return;
}
bool redirectResults;
switch (call.getSendResultsTo().which()) {
case rpc::Call::SendResultsTo::CALLER:
redirectResults = false;
break;
case rpc::Call::SendResultsTo::YOURSELF:
redirectResults = true;
break;
default:
KJ_FAIL_REQUIRE("Unsupported `Call.sendResultsTo`.") { return; }
}
QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail));
kj::addRef(*tables.getWithoutLock().resolutionChainTail),
redirectResults);
auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef());
......@@ -2331,24 +2422,28 @@ private:
answer.callContext = *context;
answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
// Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context;
answer.asyncOp = promiseAndPipeline.promise.then(
[contextPtr]() {
contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception));
}).then([]() {
// Success.
}, [&](kj::Exception&& exception) {
// We never actually wait on `asyncOp` so we need to manually report exceptions.
// TODO(cleanup): Perhaps there should be a better, more-automated approach to this?
taskFailed(kj::mv(exception));
});
answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop);
if (redirectResults) {
answer.redirectedResults = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse();
}));
} else {
// Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context;
answer.asyncOp = promiseAndPipeline.promise.then(
[contextPtr]() {
contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception));
}).then([]() {}, [&](kj::Exception&& exception) {
// Handle exceptions that occur in sendReturn()/sendErrorReturn().
taskFailed(kj::mv(exception));
});
answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop);
}
}
}
......@@ -2403,6 +2498,7 @@ private:
void handleReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
kj::Own<CapInjectorImpl> paramCapsToRelease;
kj::Promise<kj::Own<const RpcResponse>> promiseToRelease = nullptr;
auto lock = tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
......@@ -2432,7 +2528,7 @@ private:
// The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
// Not being deleted.
questionRef->fulfill(kj::refcounted<RpcResponse>(
questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::mv(*ownRef), kj::mv(message), ret.getResults()));
}
}
......@@ -2474,6 +2570,30 @@ private:
}
break;
case rpc::Return::TAKE_FROM_OTHER_ANSWER:
KJ_IF_MAYBE(answer, lock->answers.find(ret.getTakeFromOtherAnswer())) {
KJ_IF_MAYBE(response, answer->redirectedResults) {
// If we don't manage to fill in a questionRef here, we will want to release the
// promise.
promiseToRelease = kj::mv(*response);
KJ_IF_MAYBE(questionRef, question->selfRef) {
// The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
// Not being deleted.
questionRef->fulfill(kj::mv(promiseToRelease));
}
}
} else {
KJ_FAIL_REQUIRE("`Return.takeFromOtherAnswer` referenced a call that did not "
"use `sendResultsTo.yourself`.") { return; }
}
} else {
KJ_FAIL_REQUIRE("`Return.takeFromOtherAnswer` had invalid answer ID.") { return; }
}
break;
default:
KJ_FAIL_REQUIRE("Unknown 'Return' type.") { return; }
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment