Commit 4a4fe65c authored by Kenton Varda's avatar Kenton Varda

CapabilityServerSet::getLocalServer() must wait for stream queue.

Consider a capnp streaming type that wraps a kj::AsyncOutputStream.

KJ streams require the caller to avoid doing multiple writes at once. Capnp streaming conveniently guarantees only one streaming call will be delivered at a time. This is great because it means the app does not have to do its own queuing of writes.

However, the app may want to use a CapabilityServerSet to unwrap the capability and get at the underlying KJ stream to optimize by writing to it directly. However, before it can issue a direct write, it has to wait for all RPC writes to complete. These RPC writes were probably issued by the same caller, before it realized it was talking to a local cap. Unfortunately, it can't just wait for those calls it issued to complete, because streaming flow control may have made them appear to complete long ago, when they're actually still in the server's queue. How does the app make sure that the directly-issued writes don't overlap with RPC writes?

We can solve this by making CapabilityServerSet::getLocalServer() delay until all in-flight stream calls are complete before unwrapping.

Now, the app can simply make sure that any requests it issued over RPC in the past completed before it starts issuing direct requests.
parent 7a0e0fd0
...@@ -544,16 +544,46 @@ public: ...@@ -544,16 +544,46 @@ public:
return kj::addRef(*this); return kj::addRef(*this);
} }
static const uint BRAND;
// Value is irrelevant; used for pointer.
const void* getBrand() override { const void* getBrand() override {
// We have no need to detect local objects. return &BRAND;
return nullptr;
} }
void* getLocalServer(_::CapabilityServerSetBase& capServerSet) override { kj::Promise<void*> getLocalServer(_::CapabilityServerSetBase& capServerSet) {
// If this is a local capability created through `capServerSet`, return the underlying Server.
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr.
if (this->capServerSet == &capServerSet) { if (this->capServerSet == &capServerSet) {
return ptr; if (blocked) {
// If streaming calls are in-flight, it could be the case that they were originally sent
// over RPC and reflected back, before the capability had resolved to a local object. In
// that case, the client may already perceive these calls as "done" because the RPC
// implementation caused the client promise to resolve early. However, the capability is
// now local, and the app is trying to break through the LocalClient wrapper and access
// the server directly, bypassing the stream queue. Since the app thinks that all
// previous calls already completed, it may then try to queue a new call directly on the
// server, jumping the queue.
//
// We can solve this by delaying getLocalServer() until all current streaming calls have
// finished. Note that if a new streaming call is started *after* this point, we need not
// worry about that, because in this case it is presumably a local call and the caller
// won't be informed of completion until the call actually does complete. Thus the caller
// is well-aware that this call is still in-flight.
//
// However, the app still cannot assume that there aren't multiple clients, perhaps even
// a malicious client that tries to send stream requests that overlap with the app's
// direct use of the server... so it's up to the app to check for and guard against
// concurrent calls after using getLocalServer().
return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(*this)
.then([this]() { return ptr; });
} else {
return ptr;
}
} else { } else {
return nullptr; return (void*)nullptr;
} }
} }
...@@ -577,15 +607,26 @@ private: ...@@ -577,15 +607,26 @@ private:
client.blockedCallsEnd = &next; client.blockedCallsEnd = &next;
} }
BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client)
: fulfiller(fulfiller), client(client), prev(client.blockedCallsEnd) {
*prev = *this;
client.blockedCallsEnd = &next;
}
~BlockedCall() noexcept(false) { ~BlockedCall() noexcept(false) {
unlink(); unlink();
} }
void unblock() { void unblock() {
unlink(); unlink();
fulfiller.fulfill(kj::evalNow([this]() { KJ_IF_MAYBE(c, context) {
return client.callInternal(interfaceId, methodId, context); fulfiller.fulfill(kj::evalNow([&]() {
})); return client.callInternal(interfaceId, methodId, *c);
}));
} else {
// This is just a barrier.
fulfiller.fulfill(kj::READY_NOW);
}
} }
private: private:
...@@ -593,7 +634,7 @@ private: ...@@ -593,7 +634,7 @@ private:
LocalClient& client; LocalClient& client;
uint64_t interfaceId; uint64_t interfaceId;
uint16_t methodId; uint16_t methodId;
CallContextHook& context; kj::Maybe<CallContextHook&> context;
kj::Maybe<BlockedCall&> next; kj::Maybe<BlockedCall&> next;
kj::Maybe<BlockedCall&>* prev; kj::Maybe<BlockedCall&>* prev;
...@@ -667,6 +708,8 @@ private: ...@@ -667,6 +708,8 @@ private:
} }
}; };
const uint LocalClient::BRAND = 0;
kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) { kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) {
return kj::refcounted<LocalClient>(kj::mv(server)); return kj::refcounted<LocalClient>(kj::mv(server));
} }
...@@ -864,8 +907,10 @@ kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::C ...@@ -864,8 +907,10 @@ kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::C
Capability::Client client(kj::mv(resolved)); Capability::Client client(kj::mv(resolved));
return getLocalServerInternal(client); return getLocalServerInternal(client);
}); });
} else if (hook->getBrand() == &LocalClient::BRAND) {
return kj::downcast<LocalClient>(*hook).getLocalServer(*this);
} else { } else {
return hook->getLocalServer(*this); return (void*)nullptr;
} }
} }
......
...@@ -641,11 +641,6 @@ public: ...@@ -641,11 +641,6 @@ public:
// Returns true if the capability was created as a result of assigning a Client to null or by // Returns true if the capability was created as a result of assigning a Client to null or by
// reading a null pointer out of a Cap'n Proto message. // reading a null pointer out of a Cap'n Proto message.
virtual void* getLocalServer(_::CapabilityServerSetBase& capServerSet);
// If this is a local capability created through `capServerSet`, return the underlying Server.
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr.
virtual kj::Maybe<int> getFd() = 0; virtual kj::Maybe<int> getFd() = 0;
// Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns // Implements Capability::Client::getFd(). If this returns null but whenMoreResolved() returns
// non-null, then Capability::Client::getFd() waits for resolution and tries again. // non-null, then Capability::Client::getFd() waits for resolution and tries again.
......
...@@ -55,11 +55,6 @@ void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) { ...@@ -55,11 +55,6 @@ void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory) {
const uint ClientHook::NULL_CAPABILITY_BRAND = 0; const uint ClientHook::NULL_CAPABILITY_BRAND = 0;
// Defined here rather than capability.c++ so that we can safely call isNull() in this file. // Defined here rather than capability.c++ so that we can safely call isNull() in this file.
void* ClientHook::getLocalServer(_::CapabilityServerSetBase& capServerSet) {
// Defined here rather than capability.c++ because otherwise building with -fsanitize=vptr fails.
return nullptr;
}
namespace _ { // private namespace _ { // private
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
......
...@@ -639,6 +639,99 @@ KJ_TEST("Streaming over RPC") { ...@@ -639,6 +639,99 @@ KJ_TEST("Streaming over RPC") {
} }
} }
KJ_TEST("Streaming over RPC then unwrap with CapabilitySet") {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto pipe = kj::newTwoWayPipe();
CapabilityServerSet<test::TestStreaming> capSet;
auto ownServer = kj::heap<TestStreamingImpl>();
auto& server = *ownServer;
auto serverCap = capSet.add(kj::mv(ownServer));
auto paf = kj::newPromiseAndFulfiller<test::TestStreaming::Client>();
TwoPartyClient tpClient(*pipe.ends[0], serverCap);
TwoPartyClient tpServer(*pipe.ends[1], kj::mv(paf.promise), rpc::twoparty::Side::SERVER);
auto clientCap = tpClient.bootstrap().castAs<test::TestStreaming>();
// Send stream requests until we can't anymore.
kj::Promise<void> promise = kj::READY_NOW;
uint count = 0;
while (promise.poll(waitScope)) {
promise.wait(waitScope);
auto req = clientCap.doStreamIRequest();
req.setI(++count);
promise = req.send();
}
// We should have sent... several.
KJ_EXPECT(count > 10);
// Now try to unwrap.
auto unwrapPromise = capSet.getLocalServer(clientCap);
// It won't work yet, obviously, because we haven't resolved the promise.
KJ_EXPECT(!unwrapPromise.poll(waitScope));
// So do that.
paf.fulfiller->fulfill(tpServer.bootstrap().castAs<test::TestStreaming>());
clientCap.whenResolved().wait(waitScope);
// But the unwrap still doesn't resolve because streaming requests are queued up.
KJ_EXPECT(!unwrapPromise.poll(waitScope));
// OK, let's resolve a streaming request.
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
// All of our call promises have now completed from the client's perspective.
promise.wait(waitScope);
// But we still can't unwrap, because calls are queued server-side.
KJ_EXPECT(!unwrapPromise.poll(waitScope));
// Let's even make one more call now. But this is actually a local call since the promise
// resolved.
{
auto req = clientCap.doStreamIRequest();
req.setI(++count);
promise = req.send();
}
// Because it's a local call, it doesn't resolve early. The window is no longer in effect.
KJ_EXPECT(!promise.poll(waitScope));
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(!promise.poll(waitScope));
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(!promise.poll(waitScope));
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(!promise.poll(waitScope));
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
KJ_EXPECT(!promise.poll(waitScope));
// Our unwrap promise is also still not resolved.
KJ_EXPECT(!unwrapPromise.poll(waitScope));
// Close out stream calls until it does resolve!
while (!unwrapPromise.poll(waitScope)) {
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
}
// Now we can unwrap!
KJ_EXPECT(&KJ_ASSERT_NONNULL(unwrapPromise.wait(waitScope)) == &server);
// But our last stream call still isn't done.
KJ_EXPECT(!promise.poll(waitScope));
// Finish it.
KJ_ASSERT_NONNULL(server.fulfiller)->fulfill();
promise.wait(waitScope);
}
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment