Commit 56493100 authored by Kenton Varda's avatar Kenton Varda

Add client-side streaming hooks.

Also, push harder on the code generator such that `StreamResult` doesn't show up in generated code at all.

So now we have `StreamingRequest<Params>` which is like `Request<Params, Results>`, and we have `StreamingCallContext<Params>` which is like `CallContext<Params, Results>`.
parent 34481c85
......@@ -1114,19 +1114,19 @@ KJ_TEST("Streaming calls block subsequent calls") {
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
promise1 = req.send();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
promise2 = req.send();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
promise3 = req.send();
}
auto promise4 = cap.finishStreamRequest().send();
......@@ -1187,19 +1187,19 @@ KJ_TEST("Streaming calls can be canceled") {
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
promise1 = req.send();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
promise2 = req.send();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
promise3 = req.send();
}
auto promise4 = cap.finishStreamRequest().send();
......@@ -1250,19 +1250,19 @@ KJ_TEST("Streaming call throwing cascades to following calls") {
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
promise1 = req.send();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
promise2 = req.send();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
promise3 = req.send();
}
auto promise4 = cap.finishStreamRequest().send();
......
......@@ -208,10 +208,6 @@ public:
RemotePromise<AnyPointer> send() override {
KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request.");
// For the lambda capture.
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto cancelPaf = kj::newPromiseAndFulfiller<void>();
auto context = kj::refcounted<LocalCallContext>(
......@@ -241,6 +237,12 @@ public:
kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
}
kj::Promise<void> sendStreaming() override {
// We don't do any special handling of streaming in RequestHook for local requests, because
// there is no latency to compensate for between the client and server in this case.
return send().ignoreResult();
}
const void* getBrand() override {
return nullptr;
}
......@@ -705,6 +707,10 @@ public:
AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
kj::Promise<void> sendStreaming() override {
return kj::cp(exception);
}
const void* getBrand() override {
return nullptr;
}
......
......@@ -97,6 +97,7 @@ class RequestHook;
class ResponseHook;
class PipelineHook;
class ClientHook;
class StreamResult;
template <typename Params, typename Results>
class Request: public Params::Builder {
......@@ -125,6 +126,27 @@ private:
friend class RequestHook;
};
template <typename Params>
class StreamingRequest: public Params::Builder {
// Like `Request` but for streaming requests.
public:
inline StreamingRequest(typename Params::Builder builder, kj::Own<RequestHook>&& hook)
: Params::Builder(builder), hook(kj::mv(hook)) {}
inline StreamingRequest(decltype(nullptr)): Params::Builder(nullptr) {}
kj::Promise<void> send() KJ_WARN_UNUSED_RESULT;
private:
kj::Own<RequestHook> hook;
friend class Capability::Client;
friend struct DynamicCapability;
template <typename, typename>
friend class CallContext;
friend class RequestHook;
};
template <typename Results>
class Response: public Results::Reader {
// A completed call. This class extends a Reader for the call's answer structure. The Response
......@@ -227,6 +249,9 @@ protected:
template <typename Params, typename Results>
Request<Params, Results> newCall(uint64_t interfaceId, uint16_t methodId,
kj::Maybe<MessageSize> sizeHint);
template <typename Params>
StreamingRequest<Params> newStreamingCall(uint64_t interfaceId, uint16_t methodId,
kj::Maybe<MessageSize> sizeHint);
private:
kj::Own<ClientHook> hook;
......@@ -330,6 +355,30 @@ private:
friend struct DynamicCapability;
};
template <typename Params>
class StreamingCallContext: public kj::DisallowConstCopy {
// Like CallContext but for streaming calls.
public:
explicit StreamingCallContext(CallContextHook& hook);
typename Params::Reader getParams();
void releaseParams();
// Note: tailCall() is not supported because:
// - It would significantly complicate the implementation of streaming.
// - It wouldn't be particularly useful since streaming calls don't return anything, and they
// already compensate for latency.
void allowCancellation();
private:
CallContextHook* hook;
friend class Capability::Server;
friend struct DynamicCapability;
};
class Capability::Server {
// Objects implementing a Cap'n Proto interface must subclass this. Typically, such objects
// will instead subclass a typed Server interface which will take care of implementing
......@@ -377,6 +426,9 @@ protected:
template <typename Params, typename Results>
CallContext<Params, Results> internalGetTypedContext(
CallContext<AnyPointer, AnyPointer> typeless);
template <typename Params>
StreamingCallContext<Params> internalGetTypedStreamingContext(
CallContext<AnyPointer, AnyPointer> typeless);
DispatchCallResult internalUnimplemented(const char* actualInterfaceName,
uint64_t requestedTypeId);
DispatchCallResult internalUnimplemented(const char* interfaceName,
......@@ -497,6 +549,9 @@ public:
virtual RemotePromise<AnyPointer> send() = 0;
// Send the call and return a promise for the result.
virtual kj::Promise<void> sendStreaming() = 0;
// Send a streaming call.
virtual const void* getBrand() = 0;
// Returns a void* that identifies who made this request. This can be used by an RPC adapter to
// discover when tail call is going to be sent over its own connection and therefore can be
......@@ -809,6 +864,13 @@ RemotePromise<Results> Request<Params, Results>::send() {
return RemotePromise<Results>(kj::mv(typedPromise), kj::mv(typedPipeline));
}
template <typename Params>
kj::Promise<void> StreamingRequest<Params>::send() {
auto promise = hook->sendStreaming();
hook = nullptr; // prevent reuse
return promise;
}
inline Capability::Client::Client(kj::Own<ClientHook>&& hook): hook(kj::mv(hook)) {}
template <typename T, typename>
inline Capability::Client::Client(kj::Own<T>&& server)
......@@ -839,17 +901,33 @@ inline Request<Params, Results> Capability::Client::newCall(
auto typeless = hook->newCall(interfaceId, methodId, sizeHint);
return Request<Params, Results>(typeless.template getAs<Params>(), kj::mv(typeless.hook));
}
template <typename Params>
inline StreamingRequest<Params> Capability::Client::newStreamingCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) {
auto typeless = hook->newCall(interfaceId, methodId, sizeHint);
return StreamingRequest<Params>(typeless.template getAs<Params>(), kj::mv(typeless.hook));
}
template <typename Params, typename Results>
inline CallContext<Params, Results>::CallContext(CallContextHook& hook): hook(&hook) {}
template <typename Params>
inline StreamingCallContext<Params>::StreamingCallContext(CallContextHook& hook): hook(&hook) {}
template <typename Params, typename Results>
inline typename Params::Reader CallContext<Params, Results>::getParams() {
return hook->getParams().template getAs<Params>();
}
template <typename Params>
inline typename Params::Reader StreamingCallContext<Params>::getParams() {
return hook->getParams().template getAs<Params>();
}
template <typename Params, typename Results>
inline void CallContext<Params, Results>::releaseParams() {
hook->releaseParams();
}
template <typename Params>
inline void StreamingCallContext<Params>::releaseParams() {
hook->releaseParams();
}
template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::getResults(
kj::Maybe<MessageSize> sizeHint) {
......@@ -885,6 +963,10 @@ template <typename Params, typename Results>
inline void CallContext<Params, Results>::allowCancellation() {
hook->allowCancellation();
}
template <typename Params>
inline void StreamingCallContext<Params>::allowCancellation() {
hook->allowCancellation();
}
template <typename Params, typename Results>
CallContext<Params, Results> Capability::Server::internalGetTypedContext(
......@@ -892,6 +974,12 @@ CallContext<Params, Results> Capability::Server::internalGetTypedContext(
return CallContext<Params, Results>(*typeless.hook);
}
template <typename Params>
StreamingCallContext<Params> Capability::Server::internalGetTypedStreamingContext(
CallContext<AnyPointer, AnyPointer> typeless) {
return StreamingCallContext<Params>(*typeless.hook);
}
Capability::Client Capability::Server::thisCap() {
return Client(thisHook->addRef());
}
......
......@@ -830,7 +830,14 @@ private:
}
kj::Maybe<kj::StringTree> makeBrandDepInitializer(Schema type) {
return makeBrandDepInitializer(type, cppFullName(type, nullptr));
// Be careful not to invoke cppFullName() if it would just be thrown away, as doing so will
// add the type's declaring file to `usedImports`. In particular, this causes `stream.capnp.h`
// to be #included unnecessarily.
if (type.isBranded()) {
return makeBrandDepInitializer(type, cppFullName(type, nullptr));
} else {
return nullptr;
}
}
kj::Maybe<kj::StringTree> makeBrandDepInitializer(
......@@ -2140,7 +2147,7 @@ private:
auto paramProto = paramSchema.getProto();
auto resultProto = resultSchema.getProto();
bool isStreaming = resultProto.getId() == typeId<StreamResult>();
bool isStreaming = method.isStreaming();
auto implicitParamsReader = proto.getImplicitParameters();
auto implicitParamsBuilder = kj::heapArrayBuilder<CppTypeName>(implicitParamsReader.size());
......@@ -2178,7 +2185,10 @@ private:
}
CppTypeName resultType;
CppTypeName genericResultType;
if (resultProto.getScopeId() == 0) {
if (isStreaming) {
// We don't use resultType or genericResultType in this case. We want to avoid computing them
// at all so that we don't end up marking stream.capnp.h in usedImports.
} else if (resultProto.getScopeId() == 0) {
resultType = interfaceTypeName;
if (implicitParams.size() == 0) {
resultType.addMemberType(kj::str(titleCase, "Results"));
......@@ -2196,7 +2206,7 @@ private:
kj::String shortParamType = paramProto.getScopeId() == 0 ?
kj::str(titleCase, "Params") : kj::str(genericParamType);
kj::String shortResultType = resultProto.getScopeId() == 0 ?
kj::String shortResultType = resultProto.getScopeId() == 0 || isStreaming ?
kj::str(titleCase, "Results") : kj::str(genericResultType);
auto interfaceProto = method.getContainingInterface().getProto();
......@@ -2218,10 +2228,13 @@ private:
templateContext.allDecls(),
implicitParamsTemplateDecl,
templateContext.isGeneric() ? "CAPNP_AUTO_IF_MSVC(" : "",
"::capnp::Request<", paramType, ", ", resultType, ">",
isStreaming ? kj::strTree("::capnp::StreamingRequest<", paramType, ">")
: kj::strTree("::capnp::Request<", paramType, ", ", resultType, ">"),
templateContext.isGeneric() ? ")\n" : "\n",
interfaceName, "::Client::", name, "Request(::kj::Maybe< ::capnp::MessageSize> sizeHint) {\n"
" return newCall<", paramType, ", ", resultType, ">(\n"
interfaceName, "::Client::", name, "Request(::kj::Maybe< ::capnp::MessageSize> sizeHint) {\n",
isStreaming
? kj::strTree(" return newStreamingCall<", paramType, ">(\n")
: kj::strTree(" return newCall<", paramType, ", ", resultType, ">(\n"),
" 0x", interfaceIdHex, "ull, ", methodId, ", sizeHint);\n"
"}\n");
......@@ -2229,7 +2242,8 @@ private:
kj::strTree(
implicitParamsTemplateDecl.size() == 0 ? "" : " ", implicitParamsTemplateDecl,
templateContext.isGeneric() ? " CAPNP_AUTO_IF_MSVC(" : " ",
"::capnp::Request<", paramType, ", ", resultType, ">",
isStreaming ? kj::strTree("::capnp::StreamingRequest<", paramType, ">")
: kj::strTree("::capnp::Request<", paramType, ", ", resultType, ">"),
templateContext.isGeneric() ? ")" : "",
" ", name, "Request(\n"
" ::kj::Maybe< ::capnp::MessageSize> sizeHint = nullptr);\n"),
......@@ -2239,8 +2253,11 @@ private:
" typedef ", genericParamType, " ", titleCase, "Params;\n"),
resultProto.getScopeId() != 0 ? kj::strTree() : kj::strTree(
" typedef ", genericResultType, " ", titleCase, "Results;\n"),
" typedef ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> ",
titleCase, "Context;\n"
isStreaming
? kj::strTree(" typedef ::capnp::StreamingCallContext<", shortParamType, "> ")
: kj::strTree(
" typedef ::capnp::CallContext<", shortParamType, ", ", shortResultType, "> "),
titleCase, "Context;\n"
" virtual ::kj::Promise<void> ", identifierName, "(", titleCase, "Context context);\n"),
implicitParams.size() == 0 ? kj::strTree() : kj::mv(requestMethodImpl),
......@@ -2265,8 +2282,8 @@ private:
// the exception.
" return {\n"
" kj::evalNow([&]() {\n"
" return ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n"
" ", genericParamType, ", ", genericResultType, ">(context));\n"
" return ", identifierName, "(::capnp::Capability::Server::internalGetTypedStreamingContext<\n"
" ", genericParamType, ">(context));\n"
" }),\n"
" true\n"
" };\n")
......
......@@ -76,6 +76,7 @@ Capability::Server::DispatchCallResult DynamicCapability::Server::dispatchCall(
RemotePromise<DynamicStruct> Request<DynamicStruct, DynamicStruct>::send() {
auto typelessPromise = hook->send();
hook = nullptr; // prevent reuse
auto resultSchemaCopy = resultSchema;
// Convert the Promise to return the correct response type.
......@@ -94,4 +95,12 @@ RemotePromise<DynamicStruct> Request<DynamicStruct, DynamicStruct>::send() {
return RemotePromise<DynamicStruct>(kj::mv(typedPromise), kj::mv(typedPipeline));
}
kj::Promise<void> Request<DynamicStruct, DynamicStruct>::sendStreaming() {
KJ_REQUIRE(resultSchema.isStreamResult());
auto promise = hook->sendStreaming();
hook = nullptr; // prevent reuse
return promise;
}
} // namespace capnp
......@@ -552,6 +552,10 @@ public:
RemotePromise<DynamicStruct> send();
// Send the call and return a promise for the results.
kj::Promise<void> sendStreaming();
// Use when the caller is aware that the response type is StreamResult and wants to invoke
// streaming behavior. It is an error to call this if the response type is not StreamResult.
private:
kj::Own<RequestHook> hook;
StructSchema resultSchema;
......
......@@ -403,3 +403,11 @@ inline constexpr uint sizeInWords() {
static constexpr ::capnp::_::RawSchema const* schema = &::capnp::schemas::s_##id;
#endif // CAPNP_LITE, else
namespace capnp {
namespace schemas {
CAPNP_DECLARE_SCHEMA(995f9a3377c0b16e);
// HACK: Forward-declare the RawSchema for StreamResult, from stream.capnp. This allows capnp
// files which declare steraming methods to avoid including stream.capnp.h.
}
}
......@@ -219,6 +219,18 @@ public:
return RemotePromise<AnyPointer>(kj::mv(newPromise), kj::mv(newPipeline));
}
kj::Promise<void> sendStreaming() override {
auto promise = inner->sendStreaming();
KJ_IF_MAYBE(r, policy->onRevoked()) {
promise = promise.exclusiveJoin(r->then([]() {
KJ_FAIL_REQUIRE("onRevoked() promise resolved; it should only reject");
}));
}
return promise;
}
const void* getBrand() override {
return MEMBRANE_BRAND;
}
......
......@@ -1510,6 +1510,10 @@ private:
}
}
kj::Promise<void> sendStreaming() override {
KJ_UNIMPLEMENTED("TODO(now)");
}
struct TailInfo {
QuestionId questionId;
kj::Promise<void> promise;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment