Commit a30c4171 authored by Kenton Varda's avatar Kenton Varda

Finish other end of RPC tail call implementation.

parent 85a2ec20
...@@ -188,6 +188,7 @@ public: ...@@ -188,6 +188,7 @@ public:
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY); AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
inline ObjectPointer::Builder getRoot() { return root; } inline ObjectPointer::Builder getRoot() { return root; }
inline ObjectPointer::Reader getRootReader() const { return root.asReader(); }
private: private:
MallocMessageBuilder message; MallocMessageBuilder message;
......
...@@ -256,7 +256,7 @@ public: ...@@ -256,7 +256,7 @@ public:
kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) { kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) {
QuestionId questionId; QuestionId questionId;
kj::Own<QuestionRef> questionRef; kj::Own<QuestionRef> questionRef;
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>(eventLoop); auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<const RpcResponse>>>(eventLoop);
{ {
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
...@@ -329,6 +329,12 @@ public: ...@@ -329,6 +329,12 @@ public:
pipelinesToRelease.add(kj::mv(*p)); pipelinesToRelease.add(kj::mv(*p));
} }
KJ_IF_MAYBE(promise, answer.redirectedResults) {
// Answer contains a result redirection that hasn't been picked up yet. Make the call
// properly cancelable by transforming the redirect promise into a regular asyncOp.
answer.asyncOp = promise->thenInAnyThread([](kj::Own<const RpcResponse>&& response) {});
}
KJ_IF_MAYBE(context, answer.callContext) { KJ_IF_MAYBE(context, answer.callContext) {
context->requestCancel(); context->requestCancel();
} }
...@@ -408,7 +414,11 @@ private: ...@@ -408,7 +414,11 @@ private:
// Send pipelined calls here. Becomes null as soon as a `Finish` is received. // Send pipelined calls here. Becomes null as soon as a `Finish` is received.
kj::Promise<void> asyncOp = kj::Promise<void>(nullptr); kj::Promise<void> asyncOp = kj::Promise<void>(nullptr);
// Delete this promise to cancel the call. // Delete this promise to cancel the call. For redirected calls, this is null.
kj::Maybe<kj::Promise<kj::Own<const RpcResponse>>> redirectedResults;
// For locally-redirected calls (Call.sendResultsTo.yourself), this is a promise for the call
// result, to be picked up by a subsequent `Return`.
kj::Maybe<const RpcCallContext&> callContext; kj::Maybe<const RpcCallContext&> callContext;
// The call context, if it's still active. Becomes null when the `Return` message is sent. // The call context, if it's still active. Becomes null when the `Return` message is sent.
...@@ -1489,8 +1499,9 @@ private: ...@@ -1489,8 +1499,9 @@ private:
// can be sent. // can be sent.
public: public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id, inline QuestionRef(
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller, const RpcConnectionState& connectionState, QuestionId id,
kj::Own<kj::PromiseFulfiller<kj::Promise<kj::Own<const RpcResponse>>>> fulfiller,
kj::Own<const ResolutionChain> resolutionChain) kj::Own<const ResolutionChain> resolutionChain)
: connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)), : connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)),
resultCaps(connectionState, kj::mv(resolutionChain)) {} resultCaps(connectionState, kj::mv(resolutionChain)) {}
...@@ -1532,6 +1543,10 @@ private: ...@@ -1532,6 +1543,10 @@ private:
fulfiller->fulfill(kj::mv(response)); fulfiller->fulfill(kj::mv(response));
} }
void fulfill(kj::Promise<kj::Own<const RpcResponse>>&& promise) {
fulfiller->fulfill(kj::mv(promise));
}
void reject(kj::Exception&& exception) { void reject(kj::Exception&& exception) {
fulfiller->reject(kj::mv(exception)); fulfiller->reject(kj::mv(exception));
} }
...@@ -1539,7 +1554,7 @@ private: ...@@ -1539,7 +1554,7 @@ private:
private: private:
kj::Own<const RpcConnectionState> connectionState; kj::Own<const RpcConnectionState> connectionState;
QuestionId id; QuestionId id;
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller; kj::Own<kj::PromiseFulfiller<kj::Promise<kj::Own<const RpcResponse>>>> fulfiller;
CapExtractorImpl resultCaps; CapExtractorImpl resultCaps;
}; };
...@@ -1683,7 +1698,8 @@ private: ...@@ -1683,7 +1698,8 @@ private:
SendInternalResult sendInternal(bool isTailCall, Tables& lockedTables) { SendInternalResult sendInternal(bool isTailCall, Tables& lockedTables) {
injector->finishDescriptors(lockedTables); injector->finishDescriptors(lockedTables);
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>(); auto paf = kj::newPromiseAndFulfiller<kj::Promise<kj::Own<const RpcResponse>>>(
connectionState->eventLoop);
QuestionId questionId; QuestionId questionId;
auto& question = lockedTables.questions.next(questionId); auto& question = lockedTables.questions.next(questionId);
...@@ -1806,9 +1822,15 @@ private: ...@@ -1806,9 +1822,15 @@ private:
} }
}; };
class RpcResponse final: public ResponseHook, public kj::Refcounted { class RpcResponse: public ResponseHook {
public:
virtual ObjectPointer::Reader getResults() const = 0;
virtual kj::Own<const RpcResponse> addRef() const = 0;
};
class RpcResponseImpl final: public RpcResponse, public kj::Refcounted {
public: public:
RpcResponse(const RpcConnectionState& connectionState, RpcResponseImpl(const RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef, kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message, kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results) ObjectPointer::Reader results)
...@@ -1818,11 +1840,11 @@ private: ...@@ -1818,11 +1840,11 @@ private:
reader(context.imbue(results)), reader(context.imbue(results)),
questionRef(kj::mv(questionRef)) {} questionRef(kj::mv(questionRef)) {}
ObjectPointer::Reader getResults() const { ObjectPointer::Reader getResults() const override {
return reader; return reader;
} }
kj::Own<const RpcResponse> addRef() const { kj::Own<const RpcResponse> addRef() const override {
return kj::addRef(*this); return kj::addRef(*this);
} }
...@@ -1839,7 +1861,12 @@ private: ...@@ -1839,7 +1861,12 @@ private:
class RpcServerResponse { class RpcServerResponse {
public: public:
RpcServerResponse(const RpcConnectionState& connectionState, virtual ObjectPointer::Builder getResultsBuilder() = 0;
};
class RpcServerResponseImpl final: public RpcServerResponse {
public:
RpcServerResponseImpl(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message, kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results) ObjectPointer::Builder results)
: message(kj::mv(message)), : message(kj::mv(message)),
...@@ -1847,7 +1874,7 @@ private: ...@@ -1847,7 +1874,7 @@ private:
context(*injector), context(*injector),
builder(context.imbue(results)) {} builder(context.imbue(results)) {}
ObjectPointer::Builder getResults() { ObjectPointer::Builder getResultsBuilder() override {
return builder; return builder;
} }
...@@ -1864,18 +1891,42 @@ private: ...@@ -1864,18 +1891,42 @@ private:
ObjectPointer::Builder builder; ObjectPointer::Builder builder;
}; };
class LocallyRedirectedRpcResponse final
: public RpcResponse, public RpcServerResponse, public kj::Refcounted{
public:
LocallyRedirectedRpcResponse(uint firstSegmentWordSize)
: message(firstSegmentWordSize == 0 ?
SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize + 1) {}
ObjectPointer::Builder getResultsBuilder() override {
return message.getRoot();
}
ObjectPointer::Reader getResults() const override {
return message.getRootReader();
}
kj::Own<const RpcResponse> addRef() const override {
return kj::addRef(*this);
}
private:
LocalMessage message;
};
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId, RpcCallContext(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params, kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params,
kj::Own<const ResolutionChain> resolutionChain) kj::Own<const ResolutionChain> resolutionChain, bool redirectResults)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
questionId(questionId), questionId(questionId),
request(kj::mv(request)), request(kj::mv(request)),
requestCapExtractor(connectionState, kj::mv(resolutionChain)), requestCapExtractor(connectionState, kj::mv(resolutionChain)),
requestCapContext(requestCapExtractor), requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)), params(requestCapContext.imbue(params)),
returnMessage(nullptr) {} returnMessage(nullptr),
redirectResults(redirectResults) {}
~RpcCallContext() noexcept(false) { ~RpcCallContext() noexcept(false) {
if (isFirstResponder()) { if (isFirstResponder()) {
...@@ -1889,7 +1940,13 @@ private: ...@@ -1889,7 +1940,13 @@ private:
auto retainedCaps = requestCapExtractor.finalizeRetainedCaps( auto retainedCaps = requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)); Orphanage::getForMessageContaining(builder));
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList)); builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
if (redirectResults) {
// The reason we haven't sent a return is because the results were sent somewhere else.
builder.setResultsSentElsewhere();
} else {
builder.setCanceled(); builder.setCanceled();
}
message->send(); message->send();
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr); cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr);
...@@ -1897,7 +1954,18 @@ private: ...@@ -1897,7 +1954,18 @@ private:
} }
} }
kj::Own<const RpcResponse> consumeRedirectedResponse() {
KJ_ASSERT(redirectResults);
if (response == nullptr) getResults(1); // force initialization of response
// Note that the context needs to keep its own reference to the response so that it doesn't
// get GC'd until the PipelineHook drops its reference to the context.
return kj::downcast<LocallyRedirectedRpcResponse>(*KJ_ASSERT_NONNULL(response)).addRef();
}
void sendReturn() { void sendReturn() {
KJ_ASSERT(!redirectResults);
if (isFirstResponder()) { if (isFirstResponder()) {
if (response == nullptr) getResults(1); // force initialization of response if (response == nullptr) getResults(1); // force initialization of response
...@@ -1909,10 +1977,12 @@ private: ...@@ -1909,10 +1977,12 @@ private:
kj::Own<const PipelineHook> pipelineToRelease; kj::Own<const PipelineHook> pipelineToRelease;
auto lock = connectionState->tables.lockExclusive(); auto lock = connectionState->tables.lockExclusive();
auto& tables = *lock; auto& tables = *lock;
cleanupAnswerTable(kj::mv(lock), KJ_ASSERT_NONNULL(response)->send(tables)); cleanupAnswerTable(kj::mv(lock),
kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send(tables));
} }
} }
void sendErrorReturn(kj::Exception&& exception) { void sendErrorReturn(kj::Exception&& exception) {
KJ_ASSERT(!redirectResults);
if (isFirstResponder()) { if (isFirstResponder()) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Return>() + requestCapExtractor.retainedListSizeHint(true) + messageSizeHint<rpc::Return>() + requestCapExtractor.retainedListSizeHint(true) +
...@@ -1962,16 +2032,23 @@ private: ...@@ -1962,16 +2032,23 @@ private:
} }
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override { ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
return r->get()->getResults(); return r->get()->getResultsBuilder();
} else {
kj::Own<RpcServerResponse> response;
if (redirectResults) {
response = kj::refcounted<LocallyRedirectedRpcResponse>(firstSegmentWordSize);
} else { } else {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() + firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr)); requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn(); returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcServerResponse>( response = kj::heap<RpcServerResponseImpl>(
*connectionState, kj::mv(message), returnMessage.getResults()); *connectionState, kj::mv(message), returnMessage.getResults());
auto results = response->getResults(); }
auto results = response->getResultsBuilder();
this->response = kj::mv(response); this->response = kj::mv(response);
return results; return results;
} }
...@@ -2064,6 +2141,7 @@ private: ...@@ -2064,6 +2141,7 @@ private:
kj::Maybe<kj::Own<RpcServerResponse>> response; kj::Maybe<kj::Own<RpcServerResponse>> response;
rpc::Return::Builder returnMessage; rpc::Return::Builder returnMessage;
bool redirectResults = false;
bool responseSent = false; bool responseSent = false;
kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller; kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller;
...@@ -2303,12 +2381,25 @@ private: ...@@ -2303,12 +2381,25 @@ private:
return; return;
} }
bool redirectResults;
switch (call.getSendResultsTo().which()) {
case rpc::Call::SendResultsTo::CALLER:
redirectResults = false;
break;
case rpc::Call::SendResultsTo::YOURSELF:
redirectResults = true;
break;
default:
KJ_FAIL_REQUIRE("Unsupported `Call.sendResultsTo`.") { return; }
}
QuestionId questionId = call.getQuestionId(); QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one // Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer. // message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(), *this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail)); kj::addRef(*tables.getWithoutLock().resolutionChainTail),
redirectResults);
auto promiseAndPipeline = capability->call( auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef()); call.getInterfaceId(), call.getMethodId(), context->addRef());
...@@ -2331,6 +2422,12 @@ private: ...@@ -2331,6 +2422,12 @@ private:
answer.callContext = *context; answer.callContext = *context;
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) {
answer.redirectedResults = promiseAndPipeline.promise.then(
kj::mvCapture(context, [](kj::Own<RpcCallContext>&& context) {
return context->consumeRedirectedResponse();
}));
} else {
// Hack: Both the success and error continuations need to use the context. We could // Hack: Both the success and error continuations need to use the context. We could
// refcount, but both will be destroyed at the same time anyway. // refcount, but both will be destroyed at the same time anyway.
RpcCallContext* contextPtr = context; RpcCallContext* contextPtr = context;
...@@ -2340,17 +2437,15 @@ private: ...@@ -2340,17 +2437,15 @@ private:
contextPtr->sendReturn(); contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) { }, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception)); contextPtr->sendErrorReturn(kj::mv(exception));
}).then([]() { }).then([]() {}, [&](kj::Exception&& exception) {
// Success. // Handle exceptions that occur in sendReturn()/sendErrorReturn().
}, [&](kj::Exception&& exception) {
// We never actually wait on `asyncOp` so we need to manually report exceptions.
// TODO(cleanup): Perhaps there should be a better, more-automated approach to this?
taskFailed(kj::mv(exception)); taskFailed(kj::mv(exception));
}); });
answer.asyncOp.attach(kj::mv(context)); answer.asyncOp.attach(kj::mv(context));
answer.asyncOp.eagerlyEvaluate(eventLoop); answer.asyncOp.eagerlyEvaluate(eventLoop);
} }
} }
}
kj::Maybe<kj::Own<const ClientHook>> getMessageTarget(const rpc::MessageTarget::Reader& target) { kj::Maybe<kj::Own<const ClientHook>> getMessageTarget(const rpc::MessageTarget::Reader& target) {
switch (target.which()) { switch (target.which()) {
...@@ -2403,6 +2498,7 @@ private: ...@@ -2403,6 +2498,7 @@ private:
void handleReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) { void handleReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
kj::Own<CapInjectorImpl> paramCapsToRelease; kj::Own<CapInjectorImpl> paramCapsToRelease;
kj::Promise<kj::Own<const RpcResponse>> promiseToRelease = nullptr;
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) { KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
...@@ -2432,7 +2528,7 @@ private: ...@@ -2432,7 +2528,7 @@ private:
// The questionRef still exists, but could be being deleted in another thread. // The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) { KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
// Not being deleted. // Not being deleted.
questionRef->fulfill(kj::refcounted<RpcResponse>( questionRef->fulfill(kj::refcounted<RpcResponseImpl>(
*this, kj::mv(*ownRef), kj::mv(message), ret.getResults())); *this, kj::mv(*ownRef), kj::mv(message), ret.getResults()));
} }
} }
...@@ -2474,6 +2570,30 @@ private: ...@@ -2474,6 +2570,30 @@ private:
} }
break; break;
case rpc::Return::TAKE_FROM_OTHER_ANSWER:
KJ_IF_MAYBE(answer, lock->answers.find(ret.getTakeFromOtherAnswer())) {
KJ_IF_MAYBE(response, answer->redirectedResults) {
// If we don't manage to fill in a questionRef here, we will want to release the
// promise.
promiseToRelease = kj::mv(*response);
KJ_IF_MAYBE(questionRef, question->selfRef) {
// The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
// Not being deleted.
questionRef->fulfill(kj::mv(promiseToRelease));
}
}
} else {
KJ_FAIL_REQUIRE("`Return.takeFromOtherAnswer` referenced a call that did not "
"use `sendResultsTo.yourself`.") { return; }
}
} else {
KJ_FAIL_REQUIRE("`Return.takeFromOtherAnswer` had invalid answer ID.") { return; }
}
break;
default: default:
KJ_FAIL_REQUIRE("Unknown 'Return' type.") { return; } KJ_FAIL_REQUIRE("Unknown 'Return' type.") { return; }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment