Commit 8b8f8a2f authored by Kenton Varda's avatar Kenton Varda

Implement CapabilityServerSet which provides a way to unwrap loopback…

Implement CapabilityServerSet which provides a way to unwrap loopback capabilities to get the underlying server object, and also implements weak references.
parent dd662130
...@@ -823,6 +823,69 @@ TEST(Capability, ImplicitParams) { ...@@ -823,6 +823,69 @@ TEST(Capability, ImplicitParams) {
promise.wait(waitScope); promise.wait(waitScope);
} }
TEST(Capability, CapabilityServerSet) {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
CapabilityServerSet<test::TestInterface> set1, set2;
int callCount = 0;
test::TestInterface::Client clientStandalone(kj::heap<TestInterfaceImpl>(callCount));
test::TestInterface::Client clientNull = nullptr;
auto ownServer1 = kj::heap<TestInterfaceImpl>(callCount);
auto& server1 = *ownServer1;
test::TestInterface::Client client1 = set1.add(kj::mv(ownServer1));
auto ownServer2 = kj::heap<TestInterfaceImpl>(callCount);
auto& server2 = *ownServer2;
auto client2AndWeak = set2.addWeak(kj::mv(ownServer2));
test::TestInterface::Client client2 = kj::mv(client2AndWeak.client);
kj::Own<WeakCapability<test::TestInterface>> client2Weak = kj::mv(client2AndWeak.weak);
// Getting the local server using the correct set works.
EXPECT_EQ(&server1, &KJ_ASSERT_NONNULL(set1.getLocalServer(client1).wait(waitScope)));
EXPECT_EQ(&server2, &KJ_ASSERT_NONNULL(set2.getLocalServer(client2).wait(waitScope)));
// Getting the local server using the wrong set doesn't work.
EXPECT_TRUE(set1.getLocalServer(client2).wait(waitScope) == nullptr);
EXPECT_TRUE(set2.getLocalServer(client1).wait(waitScope) == nullptr);
EXPECT_TRUE(set1.getLocalServer(clientStandalone).wait(waitScope) == nullptr);
EXPECT_TRUE(set1.getLocalServer(clientNull).wait(waitScope) == nullptr);
// A promise client waits to be resolved.
auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();
test::TestInterface::Client clientPromise = kj::mv(paf.promise);
bool resolved1 = false, resolved2 = false;
auto promise1 = set1.getLocalServer(clientPromise)
.then([&](kj::Maybe<test::TestInterface::Server&> server) {
resolved1 = true;
EXPECT_EQ(&server1, &KJ_ASSERT_NONNULL(server));
});
auto promise2 = set2.getLocalServer(clientPromise)
.then([&](kj::Maybe<test::TestInterface::Server&> server) {
resolved2 = true;
EXPECT_TRUE(server == nullptr);
});
kj::evalLater([](){}).wait(waitScope);
kj::evalLater([](){}).wait(waitScope);
kj::evalLater([](){}).wait(waitScope);
kj::evalLater([](){}).wait(waitScope);
EXPECT_FALSE(resolved1);
EXPECT_FALSE(resolved2);
paf.fulfiller->fulfill(kj::cp(client1));
promise1.wait(waitScope);
promise2.wait(waitScope);
EXPECT_TRUE(resolved1);
EXPECT_TRUE(resolved2);
}
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -59,6 +59,10 @@ ClientHook::ClientHook() { ...@@ -59,6 +59,10 @@ ClientHook::ClientHook() {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
} }
void* ClientHook::getLocalServer(_::CapabilityServerSetBase& capServerSet) {
return nullptr;
}
void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable) { void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable) {
setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
arena()->initCapTable(kj::mv(capTable)); arena()->initCapTable(kj::mv(capTable));
...@@ -457,8 +461,22 @@ private: ...@@ -457,8 +461,22 @@ private:
class LocalClient final: public ClientHook, public kj::Refcounted { class LocalClient final: public ClientHook, public kj::Refcounted {
public: public:
LocalClient(kj::Own<Capability::Server>&& server) LocalClient(kj::Own<Capability::Server>&& server): server(kj::mv(server)) {}
: server(kj::mv(server)) {} LocalClient(kj::Own<Capability::Server>&& server,
_::CapabilityServerSetBase& capServerSet, void* ptr)
: server(kj::mv(server)), capServerSet(&capServerSet), ptr(ptr) {}
~LocalClient() noexcept(false) {
KJ_IF_MAYBE(w, weak) {
w->client = nullptr;
}
}
void setWeak(_::WeakCapabilityBase& weak) {
KJ_REQUIRE(this->weak == nullptr && weak.client == nullptr);
weak.client = *this;
this->weak = weak;
}
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
...@@ -523,8 +541,20 @@ public: ...@@ -523,8 +541,20 @@ public:
return nullptr; return nullptr;
} }
void* getLocalServer(_::CapabilityServerSetBase& capServerSet) override {
if (this->capServerSet == &capServerSet) {
return ptr;
} else {
return nullptr;
}
}
private: private:
kj::Own<Capability::Server> server; kj::Own<Capability::Server> server;
_::CapabilityServerSetBase* capServerSet = nullptr;
void* ptr = nullptr;
kj::Maybe<_::WeakCapabilityBase&> weak;
friend class _::WeakCapabilityBase;
}; };
kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) { kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) {
...@@ -632,4 +662,58 @@ Request<AnyPointer, AnyPointer> newBrokenRequest( ...@@ -632,4 +662,58 @@ Request<AnyPointer, AnyPointer> newBrokenRequest(
return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
} }
// =======================================================================================
// CapabilityServerSet
namespace _ { // private
WeakCapabilityBase::~WeakCapabilityBase() noexcept(false) {
KJ_IF_MAYBE(c, client) {
c->weak = nullptr;
}
}
kj::Maybe<Capability::Client> WeakCapabilityBase::getInternal() {
return client.map([](LocalClient& client) {
return Capability::Client(client.addRef());
});
}
Capability::Client CapabilityServerSetBase::addInternal(
kj::Own<Capability::Server>&& server, void* ptr) {
return Capability::Client(kj::refcounted<LocalClient>(kj::mv(server), *this, ptr));
}
Capability::Client CapabilityServerSetBase::addWeakInternal(
kj::Own<Capability::Server>&& server, _::WeakCapabilityBase& weak, void* ptr) {
auto result = kj::refcounted<LocalClient>(kj::mv(server), *this, ptr);
result->setWeak(weak);
return Capability::Client(kj::mv(result));
}
kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::Client& client) {
ClientHook* hook = client.hook.get();
// Get the most-resolved-so-far version of the hook.
KJ_IF_MAYBE(h, hook->getResolved()) {
hook = h;
};
KJ_IF_MAYBE(p, hook->whenMoreResolved()) {
// This hook is an unresolved promise. We need to wait for it.
return p->attach(hook->addRef())
.then([this](kj::Own<ClientHook>&& resolved) {
Capability::Client client(kj::mv(resolved));
return getLocalServerInternal(client);
}, [](kj::Exception&&) -> void* {
// A broken promise is simply not a local capability.
return nullptr;
});
} else {
return hook->getLocalServer(*this);
}
}
} // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -61,10 +61,12 @@ public: ...@@ -61,10 +61,12 @@ public:
RemotePromise& operator=(RemotePromise&& other) = default; RemotePromise& operator=(RemotePromise&& other) = default;
}; };
class LocalClient;
namespace _ { // private namespace _ { // private
struct RawSchema; struct RawSchema;
struct RawBrandedSchema; struct RawBrandedSchema;
extern const RawSchema NULL_INTERFACE_SCHEMA; // defined in schema.c++ extern const RawSchema NULL_INTERFACE_SCHEMA; // defined in schema.c++
class CapabilityServerSetBase;
} // namespace _ (private) } // namespace _ (private)
struct Capability { struct Capability {
...@@ -141,6 +143,9 @@ class Capability::Client { ...@@ -141,6 +143,9 @@ class Capability::Client {
// Base type for capability clients. // Base type for capability clients.
public: public:
typedef Capability Reads;
typedef Capability Calls;
Client(decltype(nullptr)); Client(decltype(nullptr));
// If you need to declare a Client before you have anything to assign to it (perhaps because // If you need to declare a Client before you have anything to assign to it (perhaps because
// the assignment is going to occur in an if/else scope), you can start by initializing it to // the assignment is going to occur in an if/else scope), you can start by initializing it to
...@@ -220,6 +225,7 @@ private: ...@@ -220,6 +225,7 @@ private:
friend struct DynamicList; friend struct DynamicList;
template <typename, Kind> template <typename, Kind>
friend struct List; friend struct List;
friend class _::CapabilityServerSetBase;
}; };
// ======================================================================================= // =======================================================================================
...@@ -313,6 +319,8 @@ class Capability::Server { ...@@ -313,6 +319,8 @@ class Capability::Server {
// dispatchCall(). // dispatchCall().
public: public:
typedef Capability Serves;
virtual kj::Promise<void> dispatchCall(uint64_t interfaceId, uint16_t methodId, virtual kj::Promise<void> dispatchCall(uint64_t interfaceId, uint16_t methodId,
CallContext<AnyPointer, AnyPointer> context) = 0; CallContext<AnyPointer, AnyPointer> context) = 0;
// Call the given method. `params` is the input struct, and should be released as soon as it // Call the given method. `params` is the input struct, and should be released as soon as it
...@@ -334,6 +342,83 @@ protected: ...@@ -334,6 +342,83 @@ protected:
uint64_t typeId, uint16_t methodId); uint64_t typeId, uint16_t methodId);
}; };
// =======================================================================================
namespace _ { // private
class WeakCapabilityBase {
public:
~WeakCapabilityBase() noexcept(false);
kj::Maybe<Capability::Client> getInternal();
private:
kj::Maybe<LocalClient&> client;
friend class capnp::LocalClient;
};
class CapabilityServerSetBase {
public:
Capability::Client addInternal(kj::Own<Capability::Server>&& server, void* ptr);
Capability::Client addWeakInternal(kj::Own<Capability::Server>&& server,
_::WeakCapabilityBase& weak, void* ptr);
kj::Promise<void*> getLocalServerInternal(Capability::Client& client);
};
} // namespace _ (private)
template <typename T>
class WeakCapability: private _::WeakCapabilityBase {
public:
kj::Maybe<typename T::Client> get();
// If the server is still alive, get a live client to it.
private:
template <typename>
friend class CapabilityServerSet;
};
template <typename T>
class CapabilityServerSet: private _::CapabilityServerSetBase {
// Allows a server to:
// 1) Recognize its own capabilities when passed back to it, and obtain the underlying Server
// objects associated with them.
// 2) Obtain "weak" versions of these capabilities, which do not prevent the underlying Server
// from being destroyed but can be upgraded to normal Clients as long as the Server is still
// alive.
//
// All objects in the set must have the same interface type T. The objects may implement various
// interfaces derived from T (and in fact T can be `capnp::Capability` to accept all objects),
// but note that if you compile with RTTI disabled then you will not be able to down-cast through
// virtual inheritance, and all inheritance between server interfaces is virtual. So, with RTTI
// disabled, you will likely need to set T to be the most-derived Cap'n Proto interface type,
// and you server class will need to be directly derived from that, so that you can use
// static_cast (or kj::downcast) to cast to it after calling getLocalServer(). (If you compile
// with RTTI, then you can freely dynamic_cast and ignore this issue!)
public:
CapabilityServerSet() = default;
KJ_DISALLOW_COPY(CapabilityServerSet);
typename T::Client add(kj::Own<typename T::Server>&& server);
// Create a new capability Client for the given Server and also add this server to the set.
struct ClientAndWeak {
typename T::Client client;
kj::Own<WeakCapability<T>> weak;
};
ClientAndWeak addWeak(kj::Own<typename T::Server>&& server);
// Like add() but also creates a weak reference.
kj::Promise<kj::Maybe<typename T::Server&>> getLocalServer(typename T::Client& client);
// Given a Client pointing to a server previously passed to add(), return the corresponding
// Server. This returns a promise because if the input client is itself a promise, this must
// wait for it to resolve. Keep in mind that the server will be deleted when all clients are
// gone, so the caller should make sure to keep the client alive (hence why this method only
// accepts an lvalue input).
};
// ======================================================================================= // =======================================================================================
// Hook interfaces which must be implemented by the RPC system. Applications never call these // Hook interfaces which must be implemented by the RPC system. Applications never call these
// directly; the RPC system implements them and the types defined earlier in this file wrap them. // directly; the RPC system implements them and the types defined earlier in this file wrap them.
...@@ -421,6 +506,11 @@ public: ...@@ -421,6 +506,11 @@ public:
// Returns a void* that identifies who made this client. This can be used by an RPC adapter to // Returns a void* that identifies who made this client. This can be used by an RPC adapter to
// discover when a capability it needs to marshal is one that it created in the first place, and // discover when a capability it needs to marshal is one that it created in the first place, and
// therefore it can transfer the capability without proxying. // therefore it can transfer the capability without proxying.
virtual void* getLocalServer(_::CapabilityServerSetBase& capServerSet);
// If this is a local capability created through `capServerSet`, return the underlying Server.
// Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
// use) always returns nullptr.
}; };
class CallContextHook { class CallContextHook {
...@@ -689,6 +779,45 @@ CallContext<Params, Results> Capability::Server::internalGetTypedContext( ...@@ -689,6 +779,45 @@ CallContext<Params, Results> Capability::Server::internalGetTypedContext(
return CallContext<Params, Results>(*typeless.hook); return CallContext<Params, Results>(*typeless.hook);
} }
template <typename T>
kj::Maybe<typename T::Client> WeakCapability<T>::get() {
return getInternal().map([](Capability::Client&& client) {
return client.castAs<T>();
});
}
template <typename T>
typename T::Client CapabilityServerSet<T>::add(kj::Own<typename T::Server>&& server) {
void* ptr = reinterpret_cast<void*>(server.get());
// Clang insists that `castAs` is a template-dependent member and therefore we need the
// `template` keyword here, but AFAICT this is wrong: addImpl() is not a template.
return addInternal(kj::mv(server), ptr).template castAs<T>();
}
template <typename T>
typename CapabilityServerSet<T>::ClientAndWeak CapabilityServerSet<T>::addWeak(
kj::Own<typename T::Server>&& server) {
void* ptr = reinterpret_cast<void*>(server.get());
auto weak = kj::heap<WeakCapability<T>>();
// Clang insists that `castAs` is a template-dependent member and therefore we need the
// `template` keyword here, but AFAICT this is wrong: addWeakImpl() is not a template.
auto client = addWeakInternal(kj::mv(server), *weak, ptr).template castAs<T>();
return { kj::mv(client), kj::mv(weak) };
}
template <typename T>
kj::Promise<kj::Maybe<typename T::Server&>> CapabilityServerSet<T>::getLocalServer(
typename T::Client& client) {
return getLocalServerInternal(client)
.then([](void* server) -> kj::Maybe<typename T::Server&> {
if (server == nullptr) {
return nullptr;
} else {
return *reinterpret_cast<typename T::Server*>(server);
}
});
}
} // namespace capnp } // namespace capnp
#endif // CAPNP_CAPABILITY_H_ #endif // CAPNP_CAPABILITY_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment