Commit c3cfe9e5 authored by Kenton Varda's avatar Kenton Varda

Implement server side of streaming.

There are two things that every capability server must implement:

* When a streaming method is delivered, it blocks subsequent calls on the same capability. Although not strictly needed to achieve flow control, this simplifies the implementation of streaming servers -- many would otherwise need to implement such serialization manually.
* When a streaming method throws, all subsequent calls also throw the same exception. This is important because exceptions thrown by a streaming call might not actually be delivered to a client, since the client doesn't necessarily wait for the results before making the next call. Again, a streaming server could implement this manually, but almost all streaming servers will likely need it, and this makes things easier.
parent a784f2f7
......@@ -1101,6 +1101,198 @@ KJ_TEST("clone() with caps") {
KJ_EXPECT(ClientHook::from(copy[2]).get() != ClientHook::from(root[0]).get());
}
KJ_TEST("Streaming calls block subsequent calls") {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto ownServer = kj::heap<TestStreamingImpl>();
auto& server = *ownServer;
test::TestStreaming::Client cap = kj::mv(ownServer);
kj::Promise<void> promise1 = nullptr, promise2 = nullptr, promise3 = nullptr;
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
}
auto promise4 = cap.finishStreamRequest().send();
KJ_EXPECT(server.iSum == 0);
KJ_EXPECT(server.jSum == 0);
KJ_EXPECT(!promise1.poll(waitScope));
KJ_EXPECT(!promise2.poll(waitScope));
KJ_EXPECT(!promise3.poll(waitScope));
KJ_EXPECT(!promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 123);
KJ_EXPECT(server.jSum == 0);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(promise1.poll(waitScope));
KJ_EXPECT(!promise2.poll(waitScope));
KJ_EXPECT(!promise3.poll(waitScope));
KJ_EXPECT(!promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 123);
KJ_EXPECT(server.jSum == 321);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(promise1.poll(waitScope));
KJ_EXPECT(promise2.poll(waitScope));
KJ_EXPECT(!promise3.poll(waitScope));
KJ_EXPECT(!promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 579);
KJ_EXPECT(server.jSum == 321);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(promise1.poll(waitScope));
KJ_EXPECT(promise2.poll(waitScope));
KJ_EXPECT(promise3.poll(waitScope));
KJ_EXPECT(promise4.poll(waitScope));
auto result = promise4.wait(waitScope);
KJ_EXPECT(result.getTotalI() == 579);
KJ_EXPECT(result.getTotalJ() == 321);
}
KJ_TEST("Streaming calls can be canceled") {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto ownServer = kj::heap<TestStreamingImpl>();
auto& server = *ownServer;
test::TestStreaming::Client cap = kj::mv(ownServer);
kj::Promise<void> promise1 = nullptr, promise2 = nullptr, promise3 = nullptr;
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
}
auto promise4 = cap.finishStreamRequest().send();
// Cancel the streaming calls.
promise1 = nullptr;
promise2 = nullptr;
promise3 = nullptr;
KJ_EXPECT(server.iSum == 0);
KJ_EXPECT(server.jSum == 0);
KJ_EXPECT(!promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 123);
KJ_EXPECT(server.jSum == 0);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(!promise4.poll(waitScope));
// The call to doStreamJ() opted into cancellation so the next call to doStreamI() happens
// immediately.
KJ_EXPECT(server.iSum == 579);
KJ_EXPECT(server.jSum == 321);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(promise4.poll(waitScope));
auto result = promise4.wait(waitScope);
KJ_EXPECT(result.getTotalI() == 579);
KJ_EXPECT(result.getTotalJ() == 321);
}
KJ_TEST("Streaming call throwing cascades to following calls") {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto ownServer = kj::heap<TestStreamingImpl>();
auto& server = *ownServer;
test::TestStreaming::Client cap = kj::mv(ownServer);
server.jShouldThrow = true;
kj::Promise<void> promise1 = nullptr, promise2 = nullptr, promise3 = nullptr;
{
auto req = cap.doStreamIRequest();
req.setI(123);
promise1 = req.send().ignoreResult();
}
{
auto req = cap.doStreamJRequest();
req.setJ(321);
promise2 = req.send().ignoreResult();
}
{
auto req = cap.doStreamIRequest();
req.setI(456);
promise3 = req.send().ignoreResult();
}
auto promise4 = cap.finishStreamRequest().send();
KJ_EXPECT(server.iSum == 0);
KJ_EXPECT(server.jSum == 0);
KJ_EXPECT(!promise1.poll(waitScope));
KJ_EXPECT(!promise2.poll(waitScope));
KJ_EXPECT(!promise3.poll(waitScope));
KJ_EXPECT(!promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 123);
KJ_EXPECT(server.jSum == 0);
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(promise1.poll(waitScope));
KJ_EXPECT(promise2.poll(waitScope));
KJ_EXPECT(promise3.poll(waitScope));
KJ_EXPECT(promise4.poll(waitScope));
KJ_EXPECT(server.iSum == 123);
KJ_EXPECT(server.jSum == 321);
KJ_EXPECT_THROW_MESSAGE("throw requested", promise2.wait(waitScope));
KJ_EXPECT_THROW_MESSAGE("throw requested", promise3.wait(waitScope));
KJ_EXPECT_THROW_MESSAGE("throw requested", promise4.wait(waitScope));
}
} // namespace
} // namespace _
} // namespace capnp
......@@ -82,15 +82,21 @@ kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
}
}
kj::Promise<void> Capability::Server::internalUnimplemented(
Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented(
const char* actualInterfaceName, uint64_t requestedTypeId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
actualInterfaceName, requestedTypeId);
return {
KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
actualInterfaceName, requestedTypeId),
false
};
}
kj::Promise<void> Capability::Server::internalUnimplemented(
Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented(
const char* interfaceName, uint64_t typeId, uint16_t methodId) {
return KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId);
return {
KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId),
false
};
}
kj::Promise<void> Capability::Server::internalUnimplemented(
......@@ -495,8 +501,12 @@ public:
// Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't
// complete before 'whenMoreResolved()' promises resolve.
auto promise = kj::evalLater([this,interfaceId,methodId,contextPtr]() {
return server->dispatchCall(interfaceId, methodId,
CallContext<AnyPointer, AnyPointer>(*contextPtr));
if (blocked) {
return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(
*this, interfaceId, methodId, *contextPtr);
} else {
return callInternal(interfaceId, methodId, *contextPtr);
}
}).attach(kj::addRef(*this));
// We have to fork this promise for the pipeline to receive a copy of the answer.
......@@ -553,6 +563,106 @@ private:
kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr;
void* ptr = nullptr;
class BlockedCall {
public:
BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client,
uint64_t interfaceId, uint16_t methodId, CallContextHook& context)
: fulfiller(fulfiller), client(client),
interfaceId(interfaceId), methodId(methodId), context(context),
prev(client.blockedCallsEnd) {
*prev = *this;
client.blockedCallsEnd = &next;
}
~BlockedCall() noexcept(false) {
unlink();
}
void unblock() {
unlink();
fulfiller.fulfill(kj::evalNow([this]() {
return client.callInternal(interfaceId, methodId, context);
}));
}
private:
kj::PromiseFulfiller<kj::Promise<void>>& fulfiller;
LocalClient& client;
uint64_t interfaceId;
uint16_t methodId;
CallContextHook& context;
kj::Maybe<BlockedCall&> next;
kj::Maybe<BlockedCall&>* prev;
void unlink() {
if (prev != nullptr) {
*prev = next;
KJ_IF_MAYBE(n, next) {
n->prev = prev;
} else {
client.blockedCallsEnd = prev;
}
prev = nullptr;
}
}
};
class BlockingScope {
public:
BlockingScope(LocalClient& client): client(client) { client.blocked = true; }
BlockingScope(): client(nullptr) {}
BlockingScope(BlockingScope&& other): client(other.client) { other.client = nullptr; }
KJ_DISALLOW_COPY(BlockingScope);
~BlockingScope() noexcept(false) {
KJ_IF_MAYBE(c, client) {
c->unblock();
}
}
private:
kj::Maybe<LocalClient&> client;
};
bool blocked = false;
kj::Maybe<kj::Exception> brokenException;
kj::Maybe<BlockedCall&> blockedCalls;
kj::Maybe<BlockedCall&>* blockedCallsEnd = &blockedCalls;
void unblock() {
blocked = false;
while (!blocked) {
KJ_IF_MAYBE(t, blockedCalls) {
t->unblock();
} else {
break;
}
}
}
kj::Promise<void> callInternal(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) {
KJ_ASSERT(!blocked);
KJ_IF_MAYBE(e, brokenException) {
// Previous streaming call threw, so everything fails from now on.
return kj::cp(*e);
}
auto result = server->dispatchCall(interfaceId, methodId,
CallContext<AnyPointer, AnyPointer>(context));
if (result.isStreaming) {
return result.promise
.catch_([this](kj::Exception&& e) {
brokenException = kj::cp(e);
kj::throwRecoverableException(kj::mv(e));
}).attach(BlockingScope(*this));
} else {
return kj::mv(result.promise);
}
}
};
kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) {
......
......@@ -338,8 +338,18 @@ class Capability::Server {
public:
typedef Capability Serves;
virtual kj::Promise<void> dispatchCall(uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) = 0;
struct DispatchCallResult {
kj::Promise<void> promise;
// Promise for completion of the call.
bool isStreaming;
// If true, this method was declared as `-> stream;`. No other calls should be permitted until
// this call finishes, and if this call throws an exception, all future calls will throw the
// same exception.
};
virtual DispatchCallResult dispatchCall(uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) = 0;
// Call the given method. `params` is the input struct, and should be released as soon as it
// is no longer needed. `context` may be used to allocate the output struct and deal with
// cancellation.
......@@ -367,10 +377,10 @@ protected:
template <typename Params, typename Results>
CallContext<Params, Results> internalGetTypedContext(
CallContext<AnyPointer, AnyPointer> typeless);
kj::Promise<void> internalUnimplemented(const char* actualInterfaceName,
uint64_t requestedTypeId);
kj::Promise<void> internalUnimplemented(const char* interfaceName,
uint64_t typeId, uint16_t methodId);
DispatchCallResult internalUnimplemented(const char* actualInterfaceName,
uint64_t requestedTypeId);
DispatchCallResult internalUnimplemented(const char* interfaceName,
uint64_t typeId, uint16_t methodId);
kj::Promise<void> internalUnimplemented(const char* interfaceName, const char* methodName,
uint64_t typeId, uint16_t methodId);
......
......@@ -37,6 +37,7 @@
#include <set>
#include <kj/main.h>
#include <algorithm>
#include <capnp/stream.capnp.h>
#if _WIN32
#define WIN32_LEAN_AND_MEAN // ::eyeroll::
......@@ -2139,6 +2140,8 @@ private:
auto paramProto = paramSchema.getProto();
auto resultProto = resultSchema.getProto();
bool isStreaming = resultProto.getId() == typeId<StreamResult>();
auto implicitParamsReader = proto.getImplicitParameters();
auto implicitParamsBuilder = kj::heapArrayBuilder<CppTypeName>(implicitParamsReader.size());
for (auto param: implicitParamsReader) {
......@@ -2252,9 +2255,29 @@ private:
"}\n"),
kj::strTree(
" case ", methodId, ":\n"
" return ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n"
" ", genericParamType, ", ", genericResultType, ">(context));\n")
" case ", methodId, ":\n",
isStreaming
? kj::strTree(
// For streaming calls, we need to add an evalNow() here so that exceptions thrown
// directly from the call can propagate to later calls. If we don't capture the
// exception properly then the caller will never find out that this is a streaming
// call (indicated by the boolean in the return value) so won't know to propagate
// the exception.
" return {\n"
" kj::evalNow([&]() {\n"
" return ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n"
" ", genericParamType, ", ", genericResultType, ">(context));\n"
" }),\n"
" true\n"
" };\n")
: kj::strTree(
// For non-streaming calls we let exceptions just flow through for a little more
// efficiency.
" return {\n"
" ", identifierName, "(::capnp::Capability::Server::internalGetTypedContext<\n"
" ", genericParamType, ", ", genericResultType, ">(context)),\n"
" false\n"
" };\n"))
};
}
......@@ -2403,7 +2426,8 @@ private:
"public:\n",
" typedef ", name, " Serves;\n"
"\n"
" ::kj::Promise<void> dispatchCall(uint64_t interfaceId, uint16_t methodId,\n"
" ::capnp::Capability::Server::DispatchCallResult dispatchCall(\n"
" uint64_t interfaceId, uint16_t methodId,\n"
" ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context)\n"
" override;\n"
"\n"
......@@ -2415,7 +2439,8 @@ private:
" .template castAs<", typeName, ">();\n"
" }\n"
"\n"
" ::kj::Promise<void> dispatchCallInternal(uint16_t methodId,\n"
" ::capnp::Capability::Server::DispatchCallResult dispatchCallInternal(\n"
" uint16_t methodId,\n"
" ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context);\n"
"};\n"
"#endif // !CAPNP_LITE\n"
......@@ -2459,7 +2484,7 @@ private:
"#if !CAPNP_LITE\n",
KJ_MAP(m, methods) { return kj::mv(m.sourceDefs); },
templateContext.allDecls(),
"::kj::Promise<void> ", fullName, "::Server::dispatchCall(\n"
"::capnp::Capability::Server::DispatchCallResult ", fullName, "::Server::dispatchCall(\n"
" uint64_t interfaceId, uint16_t methodId,\n"
" ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) {\n"
" switch (interfaceId) {\n"
......@@ -2476,7 +2501,7 @@ private:
" }\n"
"}\n",
templateContext.allDecls(),
"::kj::Promise<void> ", fullName, "::Server::dispatchCallInternal(\n"
"::capnp::Capability::Server::DispatchCallResult ", fullName, "::Server::dispatchCallInternal(\n"
" uint16_t methodId,\n"
" ::capnp::CallContext< ::capnp::AnyPointer, ::capnp::AnyPointer> context) {\n"
" switch (methodId) {\n",
......
......@@ -52,15 +52,19 @@ Request<DynamicStruct, DynamicStruct> DynamicCapability::Client::newRequest(
return newRequest(schema.getMethodByName(methodName), sizeHint);
}
kj::Promise<void> DynamicCapability::Server::dispatchCall(
Capability::Server::DispatchCallResult DynamicCapability::Server::dispatchCall(
uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) {
KJ_IF_MAYBE(interface, schema.findSuperclass(interfaceId)) {
auto methods = interface->getMethods();
if (methodId < methods.size()) {
auto method = methods[methodId];
return call(method, CallContext<DynamicStruct, DynamicStruct>(*context.hook,
method.getParamType(), method.getResultType()));
auto resultType = method.getResultType();
return {
call(method, CallContext<DynamicStruct, DynamicStruct>(*context.hook,
method.getParamType(), resultType)),
resultType.isStreamResult()
};
} else {
return internalUnimplemented(
interface->getProto().getDisplayName().cStr(), interfaceId, methodId);
......
......@@ -531,8 +531,8 @@ public:
virtual kj::Promise<void> call(InterfaceSchema::Method method,
CallContext<DynamicStruct, DynamicStruct> context) = 0;
kj::Promise<void> dispatchCall(uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) override final;
DispatchCallResult dispatchCall(uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) override final;
inline InterfaceSchema getSchema() const { return schema; }
......
......@@ -318,6 +318,42 @@ private:
kj::AutoCloseFd fd;
};
class TestStreamingImpl final: public test::TestStreaming::Server {
public:
uint iSum = 0;
uint jSum = 0;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> fulfiller;
bool jShouldThrow = false;
kj::Promise<void> doStreamI(DoStreamIContext context) override {
iSum += context.getParams().getI();
auto paf = kj::newPromiseAndFulfiller<void>();
fulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
kj::Promise<void> doStreamJ(DoStreamJContext context) override {
context.allowCancellation();
jSum += context.getParams().getJ();
if (jShouldThrow) {
KJ_FAIL_ASSERT("throw requested") { break; }
return kj::READY_NOW;
}
auto paf = kj::newPromiseAndFulfiller<void>();
fulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
kj::Promise<void> finishStream(FinishStreamContext context) override {
auto results = context.getResults();
results.setTotalI(iSum);
results.setTotalJ(jSum);
return kj::READY_NOW;
}
};
#endif // !CAPNP_LITE
} // namespace _ (private)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment