Commit 1e518b92 authored by Kenton Varda's avatar Kenton Varda

Tail calls half-implemented, but I think it's time to delete all the…

Tail calls half-implemented, but I think it's time to delete all the cancellation code before continuing because it will simplify things.
parent 08143efb
......@@ -174,6 +174,10 @@ public:
ObjectPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
const void* getBrand() const {
return nullptr;
}
kj::Exception exception;
LocalMessage message;
};
......
......@@ -152,6 +152,39 @@ TEST(Capability, Pipelining) {
EXPECT_EQ(1, chainedCallCount);
}
TEST(Capability, TailCall) {
kj::SimpleEventLoop loop;
int calleeCallCount = 0;
int callerCallCount = 0;
test::TestTailCallee::Client callee(kj::heap<TestTailCalleeImpl>(calleeCallCount), loop);
test::TestTailCaller::Client caller(kj::heap<TestTailCallerImpl>(callerCallCount), loop);
auto request = caller.fooRequest();
request.setI(456);
request.setCallee(callee);
auto promise = request.send();
auto dependentCall0 = promise.getC().getCallSequenceRequest().send();
auto response = loop.wait(kj::mv(promise));
EXPECT_EQ(456, response.getI());
EXPECT_EQ(456, response.getI());
auto dependentCall1 = promise.getC().getCallSequenceRequest().send();
auto dependentCall2 = response.getC().getCallSequenceRequest().send();
EXPECT_EQ(0, loop.wait(kj::mv(dependentCall0)).getN());
EXPECT_EQ(1, loop.wait(kj::mv(dependentCall1)).getN());
EXPECT_EQ(2, loop.wait(kj::mv(dependentCall2)).getN());
EXPECT_EQ(1, calleeCallCount);
EXPECT_EQ(1, callerCallCount);
}
// =======================================================================================
TEST(Capability, DynamicClient) {
......
......@@ -75,8 +75,6 @@ kj::Promise<void> ClientHook::whenResolved() const {
// =======================================================================================
namespace {
class LocalResponse final: public ResponseHook, public kj::Refcounted {
public:
LocalResponse(uint sizeHint)
......@@ -101,10 +99,33 @@ public:
request = nullptr;
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
if (!response) {
response = kj::refcounted<LocalResponse>(firstSegmentWordSize);
if (response == nullptr) {
auto localResponse = kj::refcounted<LocalResponse>(firstSegmentWordSize);
responseBuilder = localResponse->message.getRoot();
response = Response<ObjectPointer>(responseBuilder.asReader(), kj::mv(localResponse));
}
return response->message.getRoot();
return responseBuilder;
}
kj::Promise<void> tailCall(kj::Own<RequestHook> request) override {
KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct.");
releaseParams();
auto promise = request->send();
// Link pipelines.
KJ_IF_MAYBE(f, tailCallPipelineFulfiller) {
f->get()->fulfill(kj::mv(kj::implicitCast<ObjectPointer::Pipeline&>(promise)));
}
// Wait for response.
return promise.then([this](Response<ObjectPointer>&& tailResponse) {
response = kj::mv(tailResponse);
});
}
kj::Promise<ObjectPointer::Pipeline> onTailCall() override {
auto paf = kj::newPromiseAndFulfiller<ObjectPointer::Pipeline>();
tailCallPipelineFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
void allowAsyncCancellation() override {
// ignored for local calls
......@@ -117,8 +138,10 @@ public:
}
kj::Maybe<kj::Own<LocalMessage>> request;
kj::Own<LocalResponse> response;
kj::Maybe<Response<ObjectPointer>> response;
ObjectPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null
kj::Own<const ClientHook> clientRef;
kj::Maybe<kj::Own<kj::PromiseFulfiller<ObjectPointer::Pipeline>>> tailCallPipelineFulfiller;
};
class LocalRequest final: public RequestHook {
......@@ -131,6 +154,8 @@ public:
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<ObjectPointer> send() override {
KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request.");
// For the lambda capture.
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
......@@ -140,15 +165,21 @@ public:
auto promise = loop.there(kj::mv(promiseAndPipeline.promise),
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
// Do not inline `reader` -- kj::mv on next line may occur first.
auto reader = context->getResults(1).asReader();
return Response<ObjectPointer>(reader, kj::mv(context->response));
KJ_IF_MAYBE(r, context->response) {
return kj::mv(*r);
} else {
KJ_FAIL_ASSERT("Method implementation failed to fill in results.");
}
}));
return RemotePromise<ObjectPointer>(
kj::mv(promise), ObjectPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
}
const void* getBrand() const {
return nullptr;
}
kj::Own<LocalMessage> message;
private:
......@@ -398,8 +429,7 @@ public:
});
// Make sure that this client cannot be destroyed until the promise completes.
promise = promise.thenInAnyThread(kj::mvCapture(kj::addRef(*this),
[](kj::Own<const LocalClient>&& ref) {}));
promise.attach(kj::addRef(*this));
// We have to fork this promise for the pipeline to receive a copy of the answer.
auto forked = server.getEventLoop().fork(kj::mv(promise));
......@@ -410,14 +440,16 @@ public:
return kj::refcounted<LocalPipeline>(kj::mv(context));
}));
auto completionPromise = forked.addBranch().thenInAnyThread(kj::mvCapture(context,
[=](kj::Own<CallContextHook>&& context) {
// Nothing to do here. We just wanted to make sure to hold on to a reference to the
// context even if the pipeline was discarded.
//
// TODO(someday): We could probably make this less ugly if we had the ability to
// convert Promise<Tuple<T, U>> -> Tuple<Promise<T>, Promise<U>>...
}));
auto tailPipelinePromise = context->onTailCall().thenInAnyThread(
[](ObjectPointer::Pipeline&& pipeline) {
return kj::mv(pipeline.hook);
});
pipelinePromise = server.getEventLoop().exclusiveJoin(
kj::mv(pipelinePromise), kj::mv(tailPipelinePromise));
auto completionPromise = forked.addBranch();
completionPromise.attach(kj::mv(context));
return VoidPromiseAndPipeline { kj::mv(completionPromise),
kj::refcounted<QueuedPipeline>(server.getEventLoop(), kj::mv(pipelinePromise)) };
......@@ -444,8 +476,6 @@ private:
kj::EventLoopGuarded<kj::Own<Capability::Server>> server;
};
} // namespace
kj::Own<const ClientHook> Capability::Client::makeLocalClient(
kj::Own<Capability::Server>&& server, const kj::EventLoop& eventLoop) {
return kj::refcounted<LocalClient>(eventLoop, kj::mv(server));
......
......@@ -83,6 +83,8 @@ private:
friend class Capability::Client;
friend struct DynamicCapability;
template <typename, typename>
friend class CallContext;
};
template <typename Results>
......@@ -207,6 +209,20 @@ public:
// `firstSegmentWordSize` indicates the suggested size of the message's first segment. This
// is a hint only. If not specified, the system will decide on its own.
template <typename SubParams>
kj::Promise<void> tailCall(Request<SubParams, Results>&& tailRequest);
// Resolve the call by making a tail call. `tailRequest` is a request that has been filled in
// but not yet sent. The context will send the call, then fill in the results with the result
// of the call. If tailCall() is used, {get,init,set,adopt}Results (above) *must not* be called.
//
// The RPC implementation may be able to optimize a tail call to another machine such that the
// results never actually pass through this machine. Even if no such optimization is possible,
// `tailCall()` may allow pipelined calls to be forwarded optimistically to the new call site.
//
// `tailCall()` implies a call to `releaseParams()`, to simplify certain implementations.
// In general, this should be the last thing a method implementation calls, and the promise
// returned from `tailCall()` should then be returned by the method implementation.
void allowAsyncCancellation();
// Indicate that it is OK for the RPC system to discard its Promise for this call's result if
// the caller cancels the call, thereby transitively canceling any asynchronous operations the
......@@ -276,6 +292,11 @@ class RequestHook {
public:
virtual RemotePromise<ObjectPointer> send() = 0;
// Send the call and return a promise for the result.
virtual const void* getBrand() const = 0;
// Returns a void* that identifies who made this request. This can be used by an RPC adapter to
// discover when tail call is going to be sent over its own connection and therefore can be
// optimized into a remote tail call.
};
class ResponseHook {
......@@ -347,9 +368,14 @@ public:
virtual ObjectPointer::Reader getParams() = 0;
virtual void releaseParams() = 0;
virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0;
virtual kj::Promise<void> tailCall(kj::Own<RequestHook> request) = 0;
virtual void allowAsyncCancellation() = 0;
virtual bool isCanceled() = 0;
virtual kj::Promise<ObjectPointer::Pipeline> onTailCall() = 0;
// If `tailCall()` is called, resolves to the PipelineHook from the tail call. An
// implementation of `ClientHook::call()` is allowed to call this at most once.
virtual kj::Own<CallContextHook> addRef() = 0;
};
......@@ -561,6 +587,12 @@ inline Orphanage CallContext<Params, Results>::getResultsOrphanage(uint firstSeg
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize));
}
template <typename Params, typename Results>
template <typename SubParams>
inline kj::Promise<void> CallContext<Params, Results>::tailCall(
Request<SubParams, Results>&& tailRequest) {
return hook->tailCall(kj::mv(tailRequest.hook));
}
template <typename Params, typename Results>
inline void CallContext<Params, Results>::allowAsyncCancellation() {
hook->allowAsyncCancellation();
}
......
......@@ -247,6 +247,8 @@ struct ObjectPointer {
inline Pipeline(kj::Own<const PipelineHook>&& hook, kj::Array<PipelineOp>&& ops)
: hook(kj::mv(hook)), ops(kj::mv(ops)) {}
friend class LocalClient;
};
};
......
......@@ -68,7 +68,7 @@ public:
class ConnectionImpl final: public Connection, public kj::Refcounted {
public:
ConnectionImpl() {}
ConnectionImpl(const char* name): name(name) {}
void attach(ConnectionImpl& other) {
KJ_REQUIRE(partner == nullptr);
......@@ -100,6 +100,9 @@ public:
return message->message.getRoot<ObjectPointer>();
}
void send() override {
//kj::String msg = kj::str(connection.name, ": ", message->message.getRoot<rpc::Message>());
//KJ_DBG(msg);
KJ_IF_MAYBE(p, connection.partner) {
auto lock = p->queues.lockExclusive();
if (lock->fulfillers.empty()) {
......@@ -146,6 +149,7 @@ public:
}
private:
const char* name;
kj::Maybe<ConnectionImpl&> partner;
struct Queues {
......@@ -172,8 +176,8 @@ public:
auto iter = myLock->connections.find(&dst);
if (iter == myLock->connections.end()) {
auto local = kj::refcounted<ConnectionImpl>();
auto remote = kj::refcounted<ConnectionImpl>();
auto local = kj::refcounted<ConnectionImpl>("client");
auto remote = kj::refcounted<ConnectionImpl>("server");
local->attach(*remote);
myLock->connections[&dst] = kj::addRef(*local);
......@@ -237,6 +241,10 @@ public:
return Capability::Client(newBrokenCap("No TestExtends implemented."));
case test::TestSturdyRefObjectId::Tag::TEST_PIPELINE:
return kj::heap<TestPipelineImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLEE:
return kj::heap<TestTailCalleeImpl>(callCount);
case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
return kj::heap<TestTailCallerImpl>(callCount);
}
KJ_UNREACHABLE;
}
......@@ -343,6 +351,38 @@ TEST_F(RpcTest, Pipelining) {
EXPECT_EQ(1, chainedCallCount);
}
TEST_F(RpcTest, TailCall) {
auto caller = connect(test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER)
.castAs<test::TestTailCaller>();
int calleeCallCount = 0;
test::TestTailCallee::Client callee(kj::heap<TestTailCalleeImpl>(calleeCallCount), loop);
auto request = caller.fooRequest();
request.setI(456);
request.setCallee(callee);
auto promise = request.send();
auto dependentCall0 = promise.getC().getCallSequenceRequest().send();
auto response = loop.wait(kj::mv(promise));
EXPECT_EQ(456, response.getI());
EXPECT_EQ(456, response.getI());
auto dependentCall1 = promise.getC().getCallSequenceRequest().send();
auto dependentCall2 = response.getC().getCallSequenceRequest().send();
EXPECT_EQ(0, loop.wait(kj::mv(dependentCall0)).getN());
EXPECT_EQ(1, loop.wait(kj::mv(dependentCall1)).getN());
EXPECT_EQ(2, loop.wait(kj::mv(dependentCall2)).getN());
EXPECT_EQ(1, calleeCallCount);
EXPECT_EQ(1, restorer.callCount);
}
} // namespace
} // namespace _ (private)
} // namespace capnp
This diff is collapsed.
......@@ -275,18 +275,31 @@ struct Call {
# The params may contain capabilities. These capabilities are automatically released when the
# call returns *unless* the Return message explicitly indicates that they are being retained.
sendReturnTo :union {
sendResultsTo :union {
# Where should the return message be sent?
caller @5 :Void;
# Send the return message back to the caller (the usual).
yourself @6 :QuestionId;
yourself @6 :Void;
# **(level 1)**
#
# This is actually an echo of a call originally made by the receiver, with the given question
# ID. The result of this call should directly resolve the original call, without ever sending
# a `Return` over the wire.
# Don't actually return the results to the sender. Instead, hold on to them and await
# instructions from the sender regarding what to do with them. In particular, the sender
# may subsequently send a `Return` for some other call (which the receiver had previously made
# to the sender) with `takeFromOtherAnswer` set. The results from this call are then used
# as the results of the other call.
#
# When `yourself` is used, the receiver must still send a `Return` for the call, but sets the
# field `resultsSentElsewhere` in that `Return` rather than including the results.
#
# This feature can be used to implement tail calls in which a call from Vat A to Vat B ends up
# returning the result of a call from Vat B back to Vat A.
#
# In particular, the most common use case for this feature is when Vat A makes a call to a
# promise in Vat B, and then that promise ends up resolving to a capability back in Vat A.
# Vat B must forward all the queued calls on that promise back to Vat A, but can set `yourself`
# in the calls so that the results need not pass back through Vat B.
#
# For example:
# - Alice, in Vat A, call foo() on Bob in Vat B.
......@@ -294,14 +307,18 @@ struct Call {
# - Later on, Bob resolves the promise from foo() to point at Carol, who lives in Vat A (next
# to Alice).
# - Vat B dutifully forwards the bar() call to Carol. Let us call this forwarded call bar'().
# - The `Call` for bar'() has `sendReturnTo` set to `yourself`, with the value being the
# Notice that bar() and bar'() are travelling in opposite directions on the same network
# link.
# - The `Call` for bar'() has `sendResultsTo` set to `yourself`, with the value being the
# question ID originally assigned to the bar() call.
# - Vat A receives bar'() and delivers it to Carol.
# - When bar'() returns, Vat A does *not* send a `Return` message to Vat B. Instead, it
# directly returns the result to Alice.
# - Vat A then sends a `Finish` message for bar().
# - Vat B, on receiving the `Finish`, sends a corresponding `Finish` for bar'().
# - Neither bar() nor bar'() ever see a `Return` message sent over the wire.
# - When bar'() returns, Vat A immediately takes the results and returns them from bar().
# - Meanwhile, Vat A sends a `Return` for bar'() to Vat B, with `resultsSentElsewhere` set in
# place of results.
# - Vat A sends a `Finish` for that call to Vat B.
# - Vat B receives the `Return` for bar'() and sends a `Return` for bar(), with
# `receivedFromYourself` set in place of the results.
# - Vat B receives the `Finish` for bar() and sends a `Finish` to bar'().
thirdParty @7 :RecipientId;
# **(level 3)**
......@@ -312,10 +329,9 @@ struct Call {
#
# This operates much like `yourself`, above, except that Carol is in a separate Vat C. `Call`
# messages are sent from Vat A -> Vat B and Vat B -> Vat C. A `Return` message is sent from
# Vat B -> Vat A that contains a `redirect` to Vat C. When Vat A sends an `Accept` to Vat C,
# it receives back a `Return` containing the call's actual result. Vat C never sends a `Return`
# to Vat B, although `Finish` messages must still be sent corresponding to every `Call` as well
# as the `Accept`.
# Vat B -> Vat A that contains `acceptFromThirdParty` in place of results. When Vat A sends
# an `Accept` to Vat C, it receives back a `Return` containing the call's actual result. Vat C
# also sends a `Return` to Vat B with `resultsSentElsewhere`.
}
}
......@@ -352,12 +368,20 @@ struct Return {
# Indicates that the call was canceled due to the caller sending a Finish message
# before the call had completed.
redirect @5 :ThirdPartyCapId;
resultsSentElsewhere @5 :Void;
# This is set when returning from a `Call` which had `sendResultsTo` set to something other
# than `caller`.
takeFromOtherAnswer @6 :QuestionId;
# The sender has also sent (before this message) a `Call` with the given question ID and with
# `sendResultsTo.yourself` set, and the results of that other call should be used as the
# results here.
acceptFromThirdParty @7 :ThirdPartyCapId;
# **(level 3)**
#
# The call has been redirected to another vat, and the result should be obtained by connecting
# to that vat directly. An `Accept` message sent to the vat will return the result. See
# `Call.sendReturnTo.thirdParty`.
# The caller should contact a third-party vat to pick up the results. An `Accept` message
# sent to the vat will return the result. This pairs with `Call.sendResultsTo.thirdParty`.
}
}
......
This diff is collapsed.
This diff is collapsed.
......@@ -22,6 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "test-util.h"
#include <kj/debug.h>
#include <gtest/gtest.h>
namespace capnp {
......@@ -932,5 +933,43 @@ TestPipelineImpl::TestPipelineImpl(int& callCount): callCount(callCount) {}
});
}
kj::Promise<void> TestCallOrderImpl::getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) {
result.setN(count++);
return kj::READY_NOW;
}
TestTailCallerImpl::TestTailCallerImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCallerImpl::fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams,
test::TestTailCallee::TailResult> context) {
++callCount;
auto params = context.getParams();
auto tailRequest = params.getCallee().fooRequest();
tailRequest.setI(params.getI());
tailRequest.setT("from TestTailCaller");
return context.tailCall(kj::mv(tailRequest));
}
TestTailCalleeImpl::TestTailCalleeImpl(int& callCount): callCount(callCount) {}
kj::Promise<void> TestTailCalleeImpl::fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams,
test::TestTailCallee::TailResult> context) {
++callCount;
auto params = context.getParams();
auto results = context.getResults();
results.setI(params.getI());
results.setT(params.getT());
results.setC(kj::heap<TestCallOrderImpl>());
return kj::READY_NOW;
}
} // namespace _ (private)
} // namespace capnp
......@@ -187,6 +187,40 @@ private:
int& callCount;
};
class TestCallOrderImpl final: public test::TestCallOrder::Server {
public:
kj::Promise<void> getCallSequence(
test::TestCallOrder::GetCallSequenceParams::Reader params,
test::TestCallOrder::GetCallSequenceResults::Builder result) override;
private:
uint count = 0;
};
class TestTailCallerImpl final: public test::TestTailCaller::Server {
public:
TestTailCallerImpl(int& callCount);
kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCaller::FooParams,
test::TestTailCallee::TailResult> context) override;
private:
int& callCount;
};
class TestTailCalleeImpl final: public test::TestTailCallee::Server {
public:
TestTailCalleeImpl(int& callCount);
kj::Promise<void> fooAdvanced(
capnp::CallContext<test::TestTailCallee::FooParams,
test::TestTailCallee::TailResult> context) override;
private:
int& callCount;
};
} // namespace _ (private)
} // namespace capnp
......
......@@ -610,6 +610,25 @@ interface TestPipeline {
}
}
interface TestCallOrder {
getCallSequence @0 () -> (n: UInt32);
# First call returns 0, next returns 1, ...
}
interface TestTailCallee {
struct TailResult {
i @0 :UInt32;
t @1 :Text;
c @2 :TestCallOrder;
}
foo @0 (i :Int32, t :Text) -> TailResult;
}
interface TestTailCaller {
foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult;
}
struct TestSturdyRefHostId {
host @0 :Text;
}
......@@ -620,6 +639,8 @@ struct TestSturdyRefObjectId {
testInterface @0;
testExtends @1;
testPipeline @2;
testTailCallee @3;
testTailCaller @4;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment