Commit 966d25a2 authored by Kenton Varda's avatar Kenton Varda

More RPC protocol WIP.

parent c1e51108
......@@ -93,7 +93,7 @@ public:
return kj::addRef(*this);
}
void* getBrand() const override {
const void* getBrand() const override {
return nullptr;
}
......
......@@ -276,7 +276,7 @@ public:
return kj::addRef(*this);
}
void* getBrand() const override {
const void* getBrand() const override {
return nullptr;
}
......@@ -414,7 +414,7 @@ public:
return kj::addRef(*this);
}
void* getBrand() const override {
const void* getBrand() const override {
// We have no need to detect local objects.
return nullptr;
}
......
......@@ -326,7 +326,7 @@ public:
virtual kj::Own<const ClientHook> addRef() const = 0;
// Return a new reference to the same capability.
virtual void* getBrand() const = 0;
virtual const void* getBrand() const = 0;
// Returns a void* that identifies who made this client. This can be used by an RPC adapter to
// discover when a capability it needs to marshal is one that it created in the first place, and
// therefore it can transfer the capability without proxying.
......
......@@ -219,6 +219,15 @@ inline constexpr uint64_t typeId() { return _::TypeId_<T>::typeId; }
// typeId<MyType>() returns the type ID as defined in the schema. Works with structs, enums, and
// interfaces.
template <typename T>
inline constexpr uint sizeInWords() {
// Return the size, in words, of a Struct type, if allocated free-standing (not in a list).
// May be useful for pre-computing space needed in order to precisely allocate messages.
return (WordCount32(_::structSize<T>().data) +
_::structSize<T>().pointers * WORDS_PER_POINTER) / WORDS;
}
} // namespace capnp
#define CAPNP_DECLARE_ENUM(type, id) \
......
......@@ -2247,6 +2247,10 @@ const word* PointerReader::getUnchecked() const {
return reinterpret_cast<const word*>(pointer);
}
WordCount64 PointerReader::targetSize() const {
return WireHelpers::totalSize(segment, pointer, nestingLimit);
}
bool PointerReader::isNull() const {
return pointer == nullptr || pointer->isNull();
}
......
......@@ -355,6 +355,13 @@ public:
static inline PointerReader getRootUnchecked(const word* location);
// Get a PointerReader for an unchecked message.
WordCount64 targetSize() const;
// Return the total size of the target object and everything to which it points. Does not count
// far pointer overhead. This is useful for deciding how much space is needed to copy the object
// into a flat array. However, the caller is advised NOT to treat this value as secure. Instead,
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
bool isNull() const;
StructReader getStruct(const word* defaultValue) const;
......
......@@ -89,6 +89,9 @@ struct ObjectPointer {
Reader() = default;
inline Reader(_::PointerReader reader): reader(reader) {}
inline size_t targetSizeInWords() const;
// Get the total size, in words, of the target object and all its children.
inline bool isNull() const;
template <typename T>
......@@ -126,6 +129,9 @@ struct ObjectPointer {
inline Builder(decltype(nullptr)) {}
inline Builder(_::PointerBuilder builder): builder(builder) {}
inline size_t targetSizeInWords() const;
// Get the total size, in words, of the target object and all its children.
inline bool isNull();
inline void clear();
......@@ -296,6 +302,10 @@ private:
// =======================================================================================
// Inline implementation details
inline size_t ObjectPointer::Reader::targetSizeInWords() const {
return reader.targetSize() / WORDS;
}
inline bool ObjectPointer::Reader::isNull() const {
return reader.isNull();
}
......@@ -305,6 +315,10 @@ inline ReaderFor<T> ObjectPointer::Reader::getAs() const {
return _::PointerHelpers<T>::get(reader);
}
inline size_t ObjectPointer::Builder::targetSizeInWords() const {
return asReader().targetSizeInWords();
}
inline bool ObjectPointer::Builder::isNull() {
return builder.isNull();
}
......
......@@ -26,6 +26,7 @@
#include <kj/debug.h>
#include <kj/vector.h>
#include <kj/async.h>
#include <kj/one-of.h>
#include <unordered_map>
#include <queue>
#include <capnp/rpc.capnp.h>
......@@ -35,6 +36,56 @@ namespace _ { // private
namespace {
template <typename T>
inline constexpr uint messageSizeHint() {
return 1 + sizeInWords<rpc::Message>() + sizeInWords<T>();
}
template <>
inline constexpr uint messageSizeHint<void>() {
return 1 + sizeInWords<rpc::Message>();
}
kj::Maybe<kj::Array<PipelineOp>> toPipelineOps(List<rpc::PromisedAnswer::Op>::Reader ops) {
auto result = kj::heapArrayBuilder<PipelineOp>(ops.size());
for (auto opReader: ops) {
PipelineOp op;
switch (opReader.which()) {
case rpc::PromisedAnswer::Op::NOOP:
op.type = PipelineOp::NOOP;
break;
case rpc::PromisedAnswer::Op::GET_POINTER_FIELD:
op.type = PipelineOp::GET_POINTER_FIELD;
op.pointerIndex = opReader.getGetPointerField();
break;
default:
// TODO(soon): Handle better?
KJ_FAIL_REQUIRE("Unsupported pipeline op.", (uint)opReader.which()) {
return nullptr;
}
}
}
return result.finish();
}
Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps(
Orphanage orphanage, const kj::Array<PipelineOp>& ops) {
auto result = orphanage.newOrphan<List<rpc::PromisedAnswer::Op>>(ops.size());
auto builder = result.get();
for (uint i: kj::indices(ops)) {
rpc::PromisedAnswer::Op::Builder opBuilder = builder[i];
switch (ops[i].type) {
case PipelineOp::NOOP:
opBuilder.setNoop();
break;
case PipelineOp::GET_POINTER_FIELD:
opBuilder.setGetPointerField(ops[i].pointerIndex);
break;
}
}
return result;
}
typedef uint32_t QuestionId;
typedef uint32_t ExportId;
......@@ -103,24 +154,29 @@ private:
std::unordered_map<Id, T> high;
};
template <typename ParamCaps, typename RpcPipeline>
struct Question {
kj::Array<ExportId> exportsInParams;
// Exports embedded in the call message which should be implicitly released on return (unless
// they are in the retain list).
kj::Own<ParamCaps> paramCaps;
// A handle representing the capabilities in the parameter struct. This will be dropped as soon
// as the call returns.
kj::Own<kj::PromiseFulfiller<Response<ObjectPointer>>> fulfiller;
// Fulfill with the response.
kj::Maybe<RpcPipeline&> pipeline;
// The local pipeline object. The RpcPipeline's own destructor sets this value to null and then
// sends the Finish message.
//
// TODO(cleanup): We only have this pointer here because CapInjectorImpl::getInjectedCap() needs
// it, but perhaps CapInjectorImpl should instead hold on to the ClientHook it got in the first
// place.
bool isStarted = false;
// Is this Question currently in-use?
bool isReturned = false;
// Has the call returned?
bool isReleased = false;
// Has the call been released locally, and the ReleaseAnswer message sent? Note that this could
// occur *before* the call returns.
inline bool operator==(decltype(nullptr)) const { return !isStarted; }
inline bool operator!=(decltype(nullptr)) const { return isStarted; }
};
......@@ -145,23 +201,12 @@ struct Export {
inline bool operator!=(decltype(nullptr)) const { return refcount != 0; }
};
struct Import {
uint remoteRefcount = 0;
// Number of times we've received this import from the peer.
uint localRefcount = 0;
// Number of local proxies that currently exist wrapping this import. Once this reaches zero,
// a Release message should be sent for `remoteRefcount` references and the import should be
// removed from the table. (It would be nice to construct only one proxy object and use its
// own reference count, but it would be hard to prevent it from being destroyed in another thread
// at exactly the moment that we call addRef() on it.)
};
class RpcConnectionState: public kj::TaskSet::ErrorHandler {
public:
RpcConnectionState(const kj::EventLoop& eventLoop,
kj::Own<VatNetworkBase::Connection>&& connection)
: eventLoop(eventLoop), connection(kj::mv(connection)), tasks(eventLoop, *this) {
: eventLoop(eventLoop), connection(kj::mv(connection)), tasks(eventLoop, *this),
exportDisposer(*this) {
tasks.add(messageLoop());
}
......@@ -173,95 +218,249 @@ private:
const kj::EventLoop& eventLoop;
kj::Own<VatNetworkBase::Connection> connection;
class ImportClient;
class CapInjectorImpl;
class CapExtractorImpl;
class RpcPipeline;
struct Tables {
ExportTable<QuestionId, Question> questions;
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline>> questions;
ImportTable<QuestionId, Answer> answers;
ExportTable<ExportId, Export> exports;
ImportTable<ExportId, Import> imports;
ImportTable<ExportId, kj::Maybe<ImportClient&>> imports;
};
kj::MutexGuarded<Tables> tables;
kj::TaskSet tasks;
class ClientHookImpl final: public ClientHook {
class ExportDisposer final: public kj::Disposer {
public:
ClientHookImpl(RpcConnectionState& connectionState, ExportId importId);
inline ExportDisposer(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
protected:
void disposeImpl(void* pointer) const override {
auto lock = connectionState.tables.lockExclusive();
ExportId id = reinterpret_cast<intptr_t>(pointer);
KJ_IF_MAYBE(exp, lock->exports.find(id)) {
if (--exp->refcount == 0) {
KJ_ASSERT(lock->exports.erase(id)) {
break;
}
}
} else {
KJ_FAIL_REQUIRE("invalid export ID", id) { break; }
}
}
private:
const RpcConnectionState& connectionState;
};
ExportDisposer exportDisposer;
// =====================================================================================
// ClientHook implementations
class CapExtractorImpl final: public CapExtractor<rpc::CapDescriptor> {
class RpcClient: public ClientHook, public kj::Refcounted {
public:
CapExtractorImpl(const RpcConnectionState& connectionState)
RpcClient(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
~CapExtractorImpl() {
if (retainedCaps.getWithoutLock().size() > 0) {
// Oops, we were deleted without finalizeRetainedCaps. We really need to make sure that
// the references we kept get unreferenced.
virtual kj::Own<const kj::Refcounted> writeDescriptor(
rpc::CapDescriptor::Builder descriptor) const = 0;
// Writes a CapDescriptor referencing this client. Returns a reference to some object which
// must be held at least until the message containing `descriptor` has been sent.
//
// TODO(cleanup): Specialize Own<void> so that we can return it here instead of
// Own<Refcounted>.
// implements ClientHook -----------------------------------------
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
auto params = context->getParams();
newOutgoingMessage
newCall(interfaceId, methodId, params.targetSizeInWords() + CALL_MESSAGE_SIZE);
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
const void* getBrand() const override {
return &connectionState;
}
protected:
const RpcConnectionState& connectionState;
};
class ImportClient final: public RpcClient {
public:
ImportClient(const RpcConnectionState& connectionState, ExportId importId, bool isPromise)
: RpcClient(connectionState), importId(importId), isPromise(isPromise) {}
~ImportClient() noexcept(false) {
{
// Remove self from the import table, if the table is still pointing at us. (It's possible
// that another thread attempted to obtain this import just as the destructor started, in
// which case that other thread will have constructed a new ImportClient and placed it in
// the import table.)
auto lock = connectionState.tables.lockExclusive();
for (auto importId: retainedCaps.getWithoutLock()) {
Import& import = lock->imports[importId];
if (--lock->imports[importId].localRefcount == 0) {
if (import.remoteRefcount != 0) {
connectionState.sendReleaseLater(importId, import.remoteRefcount);
import.remoteRefcount = 0;
}
KJ_IF_MAYBE(ptr, lock->imports[importId]) {
if (ptr == this) {
lock->imports[importId] = nullptr;
}
}
}
// Send a message releasing our remote references.
if (remoteRefcount > 0) {
connectionState.sendReleaseLater(importId, remoteRefcount);
}
}
Orphan<List<ExportId>> finalizeRetainedCaps(Orphanage orphanage) {
// TODO(now): Go back through the caps and decrement their localRefcounts. Then go through
// them again, and for each whose refcount is now zero, remove the import from the table
// and don't retain it after all (if remoteRefcount is non-zero, arrange for a release
// message to be sent). Otherwise, retain it and increment the remoteRefcount.
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should
// construct a new one.
kj::Vector<ExportId> retainedCaps = kj::mv(*this->retainedCaps.lockExclusive());
KJ_IF_MAYBE(ref, kj::tryAddRef(*this)) {
++remoteRefcount;
return kj::mv(*ref);
} else {
return nullptr;
}
}
auto lock = connectionState.tables.lockExclusive();
kj::Own<const kj::Refcounted> writeDescriptor(
rpc::CapDescriptor::Builder descriptor) const override {
descriptor.setReceiverHosted(importId);
return kj::addRef(*this);
}
// implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override;
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override;
private:
ExportId importId;
bool isPromise;
uint remoteRefcount = 0;
// Number of times we've received this import from the peer.
};
class PromisedAnswerClient final: public RpcClient {
public:
PromisedAnswerClient(const RpcConnectionState& connectionState, QuestionId questionId,
kj::Array<PipelineOp>&& ops, kj::Own<const RpcPipeline> pipeline);
kj::Own<const kj::Refcounted> writeDescriptor(rpc::CapDescriptor::Builder descriptor) const override {
auto lock = state.lockShared();
if (lock->is<Waiting>()) {
auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(questionId);
promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
// Return a ref to the RpcPipeline to ensure that we don't send a Finish message for this
// call before the message containing this CapDescriptor is sent.
return kj::addRef(*lock->get<Waiting>());
} else {
// TODO(now): Problem: This won't necessarily be a remote cap!
return connectionState.writeDescriptor(
lock->get<Resolved>().getPipelinedCap(ops), descriptor);
}
}
// implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override;
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override;
private:
const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Array<PipelineOp> ops;
typedef kj::Own<const RpcPipeline> Waiting;
typedef Response<ObjectPointer> Resolved;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved>> state;
};
// Remove the extra refcount we kept on each retained cap.
for (auto importId: retainedCaps) {
--lock->imports[importId].localRefcount;
kj::Own<const kj::Refcounted> writeDescriptor(
kj::Own<const ClientHook>&& cap, rpc::CapDescriptor::Builder descriptor) const {
// Write a descriptor for the given capability. Returns a reference to something which must
// be held at least until the message containing the descriptor is sent.
if (cap->getBrand() == this) {
return kj::downcast<const RpcClient>(*cap).writeDescriptor(descriptor);
} else {
// TODO(now): We have to figure out if the client is already in our table.
// TODO(now): We have to add a refcount to the export, and return an object that decrements
// that refcount later.
}
}
// =====================================================================================
// CapExtractor / CapInjector implementations
class CapExtractorImpl final: public CapExtractor<rpc::CapDescriptor> {
// Reads CapDescriptors from a received message.
public:
CapExtractorImpl(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
~CapExtractorImpl() noexcept(false) {
KJ_ASSERT(retainedCaps.getWithoutLock().size() > 0,
"CapExtractorImpl destroyed without getting a chance to retain the caps!") {
break;
}
}
uint retainedListSizeHint(bool final) {
// Get the expected size of the retained caps list, in words. If `final` is true, then it
// is known that no more caps will be extracted after this point, so an exact value can be
// returned. Otherwise, the returned size includes room for error.
// If `final` is true then there's no need to lock. If it is false, then asynchronous
// access is possible. It's probably not worth taking the lock to look; we'll just return
// a silly estimate.
uint count = final ? retainedCaps.getWithoutLock().size() : 32;
return (count * sizeof(ExportId) + (sizeof(ExportId) - 1)) / sizeof(word);
}
// Un-retain all of the ones that now have a refcount of zero.
uint count = 0;
for (auto importId: retainedCaps) {
Import& import = lock->imports[importId];
if (import.localRefcount == 0) {
if (import.remoteRefcount != 0) {
// localRefcount reached zero but remoteRefcount is not zero. `extractCap()` only
// adds the cap to `retainedCaps` at all if it provided the first local reference.
// So, the only way to get here is the following sequence:
// - extractCap() extracts the first instance of this import ID. `remoteRefcount`
// is zero at that time.
// - Some parallel message introduces the same import ID again, incrementing its
// `remoteRefcount`.
// - Both references are discarded by the application.
//
// In any case, since we're dropping the last local reference but the import has a
// non-zero remote refcount, we have to arrange for a `Release` message to be sent.
connectionState.sendReleaseLater(importId, import.remoteRefcount);
import.remoteRefcount = 0;
Orphan<List<ExportId>> finalizeRetainedCaps(Orphanage orphanage) {
// Called on finalization, when the lock is no longer needed.
kj::Vector<ExportId> retainedCaps = kj::mv(this->retainedCaps.getWithoutLock());
auto lock = connectionState.tables.lockExclusive();
auto actualRetained = retainedCaps.begin();
for (ExportId importId: retainedCaps) {
// Check if the import still exists under this ID.
KJ_IF_MAYBE(import, lock->imports[importId]) {
if (import->tryAddRemoteRef() != nullptr) {
// Import indeed still exists! We are responsible for retaining it.
*actualRetained++ = importId;
}
} else {
++count;
}
}
// Finally, build the retain list out of the imports that had non-zero refcounts.
uint count = actualRetained - retainedCaps.begin();
// Build the retain list out of the imports that had non-zero refcounts.
auto result = orphanage.newOrphan<List<ExportId>>(count);
auto resultBuilder = result.get();
count = 0;
for (auto importId: retainedCaps) {
Import& import = lock->imports[importId];
if (import.localRefcount != 0) {
resultBuilder.set(count++, importId);
}
for (auto iter = retainedCaps.begin(); iter < actualRetained; ++iter) {
resultBuilder.set(count++, *iter);
}
return kj::mv(result);
......@@ -273,22 +472,29 @@ private:
switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: {
ExportId importId = descriptor.getSenderHosted();
{
auto lock = connectionState.tables.lockExclusive();
Import& import = lock->imports[importId];
if (import.localRefcount == 0) {
// We haven't seen this import before, so we'll need to flag it as retained. For
// now, increment its local refcount so that it can't possibly be released before
// we get to `finalizeRetainedCaps()`.
retainedCaps.lockExclusive()->add(importId);
import.localRefcount = 1;
}
// Increment the local refcount for the ClientHook that we're about to return.
++import.localRefcount;
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[importId]) {
// The import is already on the table, but it could be being deleted in another
// thread.
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) {
// We successfully grabbed a reference to the import without it being deleted in
// another thread. Since this import already exists, we don't have to take
// responsibility for retaining it. We can just return the existing object and
// be done with it.
return kj::mv(*ref);
}
}
return kj::refcounted<ClientHookImpl>(importId);
// No import for this ID exists currently, so create one.
auto result = kj::refcounted<ImportClient>(connectionState, importId);
lock->imports[importId] = *result;
// Note that we need to retain this import later if it still exists.
retainedCaps.lockExclusive()->add(importId);
return kj::mv(result);
}
case rpc::CapDescriptor::SENDER_PROMISE:
......@@ -318,88 +524,262 @@ private:
private:
const RpcConnectionState& connectionState;
kj::MutexGuarded<kj::Vector<ExportId>> retainedCaps;
// Imports which we are responsible for retaining, should they still exist at the time that
// this message is released.
};
// =====================================================================================
// -----------------------------------------------------------------
class CapInjectorImpl final: public CapInjector<rpc::CapDescriptor> {
// Write CapDescriptors into a message as it is being built, before sending it.
public:
CapInjectorImpl(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
~CapInjectorImpl() {}
// implements CapInjector ----------------------------------------
void injectCap(rpc::CapDescriptor::Builder descriptor,
kj::Own<const ClientHook>&& cap) const override {
if (cap->getBrand() == &connectionState) {
kj::downcast<const ClientHookImpl&>(*cap).writeDescriptor(descriptor);
state.lockExclusive()->receiverHosted.add(kj::mv(cap));
} else {
// TODO(now): We have to figure out if the client is already in our table.
}
auto ref = connectionState.writeDescriptor(kj::mv(cap), descriptor);
refs.lockExclusive()->add(kj::mv(ref));
}
kj::Own<const ClientHook> getInjectedCap(rpc::CapDescriptor::Reader descriptor) const override {
}
void dropCap(rpc::CapDescriptor::Reader descriptor) const override {
switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: {
state.lockExclusive()->dropped.add(descriptor.getSenderHosted());
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(exp, lock->exports.find(descriptor.getSenderHosted())) {
if (--exp->refcount == 0) {
exp->clientHook = nullptr;
}
return exp->clientHook->addRef();
} else {
KJ_FAIL_REQUIRE("Dropped descriptor had invalid 'senderHosted'.") { break; }
KJ_FAIL_REQUIRE("Dropped descriptor had invalid 'senderHosted'.") {
return newBrokenCap("Calling invalid CapDescriptor found in builder.");
}
}
break;
}
case rpc::CapDescriptor::RECEIVER_HOSTED:
case rpc::CapDescriptor::RECEIVER_ANSWER:
// No big deal if we hold on to the ClientHooks a little longer until this message
// is sent.
break;
case rpc::CapDescriptor::RECEIVER_HOSTED: {
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(import, lock->imports[descriptor.getReceiverHosted()]) {
KJ_IF_MAYBE(ref, kj::tryAddRef(*import)) {
return kj::mv(*ref);
}
}
// If we wrote this CapDescriptor then we should hold a reference to the import in
// our `receiverHosted` table, yet it seems that the import ID is not valid. Something
// is wrong.
return newBrokenCap("CapDescriptor in builder had invalid 'receiverHosted'.");
}
case rpc::CapDescriptor::RECEIVER_ANSWER: {
auto promisedAnswer = descriptor.getReceiverAnswer();
auto lock = connectionState.tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(promisedAnswer.getQuestionId())) {
KJ_IF_MAYBE(pipeline, question->pipeline) {
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
return kj::refcounted<PromisedAnswerClient>(
connectionState, promisedAnswer.getQuestionId(),
kj::mv(*ops), kj::addRef(*pipeline));
}
}
}
return newBrokenCap("CapDescriptor in builder had invalid PromisedAnswer.");
}
default:
KJ_FAIL_REQUIRE("I don't think I wrote this descriptor.") { break; }
KJ_FAIL_REQUIRE("I don't think I wrote this descriptor.") {
return newBrokenCap("CapDescriptor in builder was invalid.");
}
break;
}
}
void dropCap(rpc::CapDescriptor::Reader descriptor) const override {
// TODO(someday): We could implement this by maintaining a map from CapDescriptors to
// the corresponding refs, but is it worth it?
}
private:
const RpcConnectionState& connectionState;
struct State {
kj::Vector<ExportId> senderHosted;
// Local capabilities that are being exported with this message. These have had their
// refcounts in the exports table increased by one while the CapInjector exists, but those
// refs will be released in the destructor, so the receiver will have to explicitly retain
// them before that point to keep them live.
kj::Vector<ExportId> dropped;
// Exports that were injected but then subsequently dropped. Each ID in this list also
// appears in senderHosted -- the instance in `dropped` essentially negates its existence in
// `senderHosted`.
kj::Vector<kj::Own<const ClientHook>> receiverHosted;
// Capabilities (exports and promised-answers) hosted by the receiver which have been injected
// into this message. This vector exists only to hold references to these caps to prevent
// them from being prematurely released before the message can be sent.
};
kj::MutexGuarded<State> state;
kj::MutexGuarded<kj::Vector<kj::Own<const kj::Refcounted>>> refs;
// List of references that need to be held until the message is destroyed.
};
// =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations
class RpcCallContext final: public CallContextHook,
public CapExtractor<rpc::CapDescriptor>,
public CapInjector<rpc::CapDescriptor>,
public kj::Refcounted {
class RpcRequest: public RequestHook {
public:
RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params);
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize)
: connectionState(connectionState),
message(connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + messageSizeHint<rpc::Call>())),
injector(kj::heap<CapInjectorImpl>(connectionState)),
context(*injector),
callBuilder(message->getBody().getAs<rpc::Message>().initCall()),
paramsBuilder(context.imbue(callBuilder.getRequest())) {}
inline ObjectPointer::Builder getRoot() {
return paramsBuilder;
}
RemotePromise<ObjectPointer> send() override {
auto paf = kj::newPromiseAndFulfiller<Response<ObjectPointer>>(connectionState.eventLoop);
QuestionId questionId;
void sendReturn();
{
auto lock = connectionState.tables.lockExclusive();
auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId);
question.isStarted = true;
question.paramCaps = kj::mv(injector);
question.fulfiller = kj::mv(paf.fulfiller);
}
auto pipeline = kj::refcounted<RpcPipeline>(connectionState, questionId);
// If the caller discards the pipeline without discarding the promise, we need the pipeline
// to stay alive so that we don't cancel the call altogether.
auto promiseWithPipelineRef = paf.promise.then(kj::mvCapture(pipeline->addRef(),
[](kj::Own<const PipelineHook>&&, Response<ObjectPointer>&& response)
-> Response<ObjectPointer> {
return kj::mv(response);
}));
message->send();
return RemotePromise<ObjectPointer>(
kj::mv(promiseWithPipelineRef),
ObjectPointer::Pipeline(kj::mv(pipeline)));
}
private:
const RpcConnectionState& connectionState;
kj::Own<OutgoingRpcMessage> message;
kj::Own<CapInjectorImpl> injector;
CapBuilderContext context;
rpc::Call::Builder callBuilder;
ObjectPointer::Builder paramsBuilder;
};
class RpcPipeline: public PipelineHook, public kj::Refcounted {
public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId)
: connectionState(connectionState), questionId(questionId) {}
~RpcPipeline() noexcept(false) {
uint sizeHint = messageSizeHint<rpc::Finish>();
KJ_IF_MAYBE(ce, capExtractor) {
sizeHint += ce->retainedListSizeHint(true);
}
auto finishMessage = connectionState.connection->newOutgoingMessage(sizeHint);
rpc::Finish::Builder builder = finishMessage->getBody().initAs<rpc::Message>().initFinish();
builder.setQuestionId(questionId);
KJ_IF_MAYBE(ce, capExtractor) {
builder.adoptRetainedCaps(ce->finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
}
finishMessage->send();
{
auto lock = connectionState.tables.lockExclusive();
auto& question = KJ_ASSERT_NONNULL(lock->questions.find(questionId),
"RpcPipeline had invalid questionId?");
question.pipeline = nullptr;
if (question.isReturned) {
KJ_ASSERT(lock->questions.erase(questionId));
}
}
}
kj::Promise<Response<ObjectPointer>> getResponse();
// implements PipelineHook ---------------------------------------
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override {
auto copy = kj::heapArrayBuilder<PipelineOp>(ops.size());
for (auto& op: ops) {
copy.add(op);
}
return getPipelinedCap(copy.finish());
}
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override {
return kj::refcounted<PromisedAnswerClient>(
connectionState, questionId, kj::mv(ops), kj::addRef(*this));
}
private:
const RpcConnectionState& connectionState;
QuestionId questionId;
kj::Maybe<CapExtractorImpl&> capExtractor;
};
class RpcResponse {
public:
RpcResponse(RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results)
: message(kj::mv(message)),
injector(connectionState),
context(injector),
builder(context.imbue(results)) {}
ObjectPointer::Builder getResults() {
return builder;
}
void send() {
message->send();
}
private:
kj::Own<OutgoingRpcMessage> message;
CapInjectorImpl injector;
CapBuilderContext context;
ObjectPointer::Builder builder;
};
// =====================================================================================
// CallContextHook implementation
class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public:
RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId,
kj::Own<IncomingRpcMessage>&& request, const ObjectPointer::Reader& params)
: connectionState(connectionState),
questionId(questionId),
request(kj::mv(request)),
requestCapExtractor(connectionState),
requestCapContext(requestCapExtractor),
params(requestCapContext.imbue(params)),
returnMessage(nullptr) {}
void sendReturn() {
if (response == nullptr) getResults(1); // force initialization of response
returnMessage.setQuestionId(questionId);
returnMessage.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(returnMessage)));
KJ_ASSERT_NONNULL(response)->send();
}
void sendErrorReturn(kj::Exception&& exception);
// implements CallContextHook ------------------------------------
......@@ -413,13 +793,18 @@ private:
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
KJ_IF_MAYBE(r, response) {
return r->get()->getBody().getAs<rpc::Message>().getReturn().getAnswer();
return r->get()->getResults();
} else {
auto message = connectionState.connection->newOutgoingMessage(
firstSegmentWordSize == 0 ? 0 : firstSegmentWordSize + 10);
auto result = message->getBody().initAs<rpc::Message>().initReturn().getAnswer();
response = kj::mv(message);
return result;
firstSegmentWordSize == 0 ? 0 :
firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcResponse>(connectionState, kj::mv(message),
returnMessage.getAnswer());
auto results = response->getResults();
this->response = kj::mv(response);
return results;
}
}
void allowAsyncCancellation(bool allow) override {
......@@ -441,13 +826,15 @@ private:
CapExtractorImpl requestCapExtractor;
CapReaderContext requestCapContext;
CapBuilderContext responseContext;
ObjectPointer::Reader params;
kj::Maybe<kj::Own<OutgoingRpcMessage>> response;
kj::Maybe<kj::Own<RpcResponse>> response;
rpc::Return::Builder returnMessage;
};
// =====================================================================================
// Message handling
kj::Promise<void> messageLoop() {
return connection->receiveIncomingMessage().then(
[this](kj::Own<IncomingRpcMessage>&& message) {
......@@ -477,7 +864,8 @@ private:
break;
default: {
auto message = connection->newOutgoingMessage(reader.totalSizeInWords() + 6);
auto message = connection->newOutgoingMessage(
reader.totalSizeInWords() + messageSizeHint<void>());
message->getBody().initAs<rpc::Message>().setUnimplemented(reader);
message->send();
break;
......@@ -523,27 +911,12 @@ private:
pipeline = base.pipeline->addRef();
}
auto opsReader = promisedAnswer.getTransform();
auto ops = kj::heapArrayBuilder<PipelineOp>(opsReader.size());
for (auto opReader: opsReader) {
PipelineOp op;
switch (opReader.which()) {
case rpc::PromisedAnswer::Op::NOOP:
op.type = PipelineOp::NOOP;
break;
case rpc::PromisedAnswer::Op::GET_POINTER_FIELD:
op.type = PipelineOp::GET_POINTER_FIELD;
op.pointerIndex = opReader.getGetPointerField();
break;
default:
// TODO(soon): Handle better.
KJ_FAIL_REQUIRE("Unsupported pipeline op.", (uint)opReader.which()) {
return;
}
}
ops.add(op);
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
capability = pipeline->getPipelinedCap(*ops);
} else {
// Exception already thrown.
return;
}
capability = pipeline->getPipelinedCap(ops.finish());
break;
}
......@@ -607,7 +980,7 @@ private:
void sendReleaseLater(ExportId importId, uint remoteRefcount) const {
tasks.add(eventLoop.evalLater([this,importId,remoteRefcount]() {
auto message = connection->newOutgoingMessage(8);
auto message = connection->newOutgoingMessage(messageSizeHint<rpc::Release>());
rpc::Release::Builder builder = message->getBody().initAs<rpc::Message>().initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefcount);
......
......@@ -336,6 +336,10 @@ struct Finish {
# 3) If the answer has not returned yet, the caller no longer cares about the answer, so the
# callee may wish to immediately cancel the operation and send back a Return message with
# "canceled" set.
#
# TODO(soon): Should we separate (1) and (2)? It would be possible and useful to notify the
# server that it doesn't need to keep around the response to service pipeline requests even
# though the caller hasn't yet finished processing the response.
questionId @0 :QuestionId;
# ID of the question whose answer is to be released.
......
......@@ -54,4 +54,22 @@ TEST(Refcount, Basic) {
#endif
}
TEST(Refcount, Weak) {
  // tryAddRef() must refuse an object that was not allocated via
  // kj::refcounted() (here: a stack-allocated instance).
  {
    bool destroyed = false;
    SetTrueInDestructor stackObj(&destroyed);
    EXPECT_TRUE(tryAddRef(stackObj) == nullptr);
  }

  // For a live refcounted object, tryAddRef() must succeed and yield a second
  // owning reference to the very same underlying object.
  {
    bool destroyed = false;
    auto strong = kj::refcounted<SetTrueInDestructor>(&destroyed);
    KJ_IF_MAYBE(upgraded, tryAddRef(*strong)) {
      EXPECT_EQ(strong.get(), upgraded->get());
    } else {
      ADD_FAILURE() << "tryAddRef() failed.";
    }
  }
}
} // namespace kj
......@@ -22,24 +22,39 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "refcount.h"
#include "debug.h"
#include <memory>
namespace kj {
Refcounted::~Refcounted() noexcept(false) {}
Refcounted::~Refcounted() noexcept(false) {
// Catch misuse: the object must only be destroyed once the last reference is
// dropped (refcount back to zero), otherwise outstanding Owns would dangle.
KJ_ASSERT(refcount == 0, "Refcounted object deleted with non-zero refcount.");
}
void Refcounted::disposeImpl(void* pointer) const {
// Drops one reference; deletes the object when the last reference is gone.
// NOTE(review): this span interleaves pre-change and post-change diff lines —
// the first `if` condition (acquire-load fast path + acq_rel decrement) is the
// OLD logic being removed, and the second `if` (release decrement + acquire
// fence) is the NEW logic.  Confirm against the applied commit before editing.
// The load is a fast-path for the common case where this is the last reference. An acquire-load
// is just a regular load on x86. If there is more than one reference, then we need to do a full
// atomic decrement with full memory barrier, because:
// - If this is the final decrement then we need to acquire the object state in order to destroy
// it.
// - If this is not the final decrement then we need to release the object state so that another
// thread may destroy it.
if (__atomic_load_n(&refcount, __ATOMIC_ACQUIRE) == 1 ||
__atomic_sub_fetch(&refcount, 1, __ATOMIC_ACQ_REL) == 0) {
// Need to do a "release" decrement in order to release the object's state to any other thread
// which seeks to destroy it.
if (__atomic_sub_fetch(&refcount, 1, __ATOMIC_RELEASE) == 0) {
// This was the last reference. Acquire the memory so that we can destroy it.
__atomic_thread_fence(__ATOMIC_ACQUIRE);
delete this;
}
}
bool Refcounted::tryAddRefInternal() const {
// We want to increment the refcount, but only if it is non-zero. We have to use a cmpxchg for
// this.
uint old = __atomic_load_n(&refcount, __ATOMIC_RELAXED);
for (;;) {
if (old == 0) {
return false;
}
if (__atomic_compare_exchange_n(&refcount, &old, old + 1, true,
__ATOMIC_RELAXED, __ATOMIC_RELAXED)) {
return true;
}
}
}
} // namespace kj
......@@ -59,8 +59,12 @@ private:
template <typename T>
static Own<T> addRefInternal(T* object);
bool tryAddRefInternal() const;
template <typename T>
friend Own<T> addRef(T& object);
template <typename T>
friend Maybe<Own<T>> tryAddRef(T& object);
template <typename T, typename... Params>
friend Own<T> refcounted(Params&&... params);
};
......@@ -83,6 +87,23 @@ Own<T> addRef(T& object) {
return Refcounted::addRefInternal(&object);
}
template <typename T>
Maybe<Own<T>> tryAddRef(T& object) {
  // Like `addRef()`, but fails (returns nullptr) when the object's refcount
  // has already reached zero or the object was not allocated with
  // `refcounted`.  This enables thread-safe weak references: keep a raw
  // (non-owning) pointer to the object and have the object's destructor null
  // that pointer out; to upgrade it to a full reference, call tryAddRef() —
  // if it fails, destruction is already in progress.  Access to the raw
  // pointer must still be synchronized (e.g. by a mutex) so the destructor
  // can block while another thread is mid-upgrade.
  const Refcounted& refcounted = object;
  if (!object.Refcounted::tryAddRefInternal()) {
    // Refcount already zero (or never refcounted): cannot resurrect.
    return nullptr;
  }
  return Own<T>(&object, refcounted);
}
template <typename T>
Own<T> Refcounted::addRefInternal(T* object) {
const Refcounted* refcounted = object;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment