Commit 41feb767 authored by Kenton Varda's avatar Kenton Varda

More RPC protocol WIP.

parent 966d25a2
...@@ -110,7 +110,7 @@ public: ...@@ -110,7 +110,7 @@ public:
} }
return response->message.getRoot(); return response->message.getRoot();
} }
void allowAsyncCancellation(bool allow) override { void allowAsyncCancellation() override {
// ignored for local calls // ignored for local calls
} }
bool isCanceled() override { bool isCanceled() override {
......
...@@ -178,6 +178,8 @@ class CallContext: public kj::DisallowConstCopy { ...@@ -178,6 +178,8 @@ class CallContext: public kj::DisallowConstCopy {
// //
// Methods of this class may only be called from within the server's event loop, not from other // Methods of this class may only be called from within the server's event loop, not from other
// threads. // threads.
//
// The CallContext becomes invalid as soon as the call reports completion.
public: public:
explicit CallContext(CallContextHook& hook); explicit CallContext(CallContextHook& hook);
...@@ -205,7 +207,7 @@ public: ...@@ -205,7 +207,7 @@ public:
// `firstSegmentWordSize` indicates the suggested size of the message's first segment. This // `firstSegmentWordSize` indicates the suggested size of the message's first segment. This
// is a hint only. If not specified, the system will decide on its own. // is a hint only. If not specified, the system will decide on its own.
void allowAsyncCancellation(bool allow = true); void allowAsyncCancellation();
// Indicate that it is OK for the RPC system to discard its Promise for this call's result if // Indicate that it is OK for the RPC system to discard its Promise for this call's result if
// the caller cancels the call, thereby transitively canceling any asynchronous operations the // the caller cancels the call, thereby transitively canceling any asynchronous operations the
// call implementation was performing. This is not done by default because it could represent a // call implementation was performing. This is not done by default because it could represent a
...@@ -213,8 +215,6 @@ public: ...@@ -213,8 +215,6 @@ public:
// a bad state if an operation is canceled at an arbitrary point. However, for long-running // a bad state if an operation is canceled at an arbitrary point. However, for long-running
// method calls that hold significant resources, prompt cancellation is often useful. // method calls that hold significant resources, prompt cancellation is often useful.
// //
// You can also switch back to disallowing cancellation by passing `false` as the argument.
//
// Keep in mind that asynchronous cancellation cannot occur while the method is synchronously // Keep in mind that asynchronous cancellation cannot occur while the method is synchronously
// executing on a local thread. The method must perform an asynchronous operation or call // executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control. // `EventLoop::current().runLater()` to yield control.
...@@ -340,7 +340,7 @@ public: ...@@ -340,7 +340,7 @@ public:
virtual ObjectPointer::Reader getParams() = 0; virtual ObjectPointer::Reader getParams() = 0;
virtual void releaseParams() = 0; virtual void releaseParams() = 0;
virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0; virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0;
virtual void allowAsyncCancellation(bool allow) = 0; virtual void allowAsyncCancellation() = 0;
virtual bool isCanceled() = 0; virtual bool isCanceled() = 0;
virtual kj::Own<CallContextHook> addRef() = 0; virtual kj::Own<CallContextHook> addRef() = 0;
...@@ -551,8 +551,8 @@ inline Orphanage CallContext<Params, Results>::getResultsOrphanage(uint firstSeg ...@@ -551,8 +551,8 @@ inline Orphanage CallContext<Params, Results>::getResultsOrphanage(uint firstSeg
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize)); return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize));
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline void CallContext<Params, Results>::allowAsyncCancellation(bool allow) { inline void CallContext<Params, Results>::allowAsyncCancellation() {
hook->allowAsyncCancellation(allow); hook->allowAsyncCancellation();
} }
template <typename Params, typename Results> template <typename Params, typename Results>
inline bool CallContext<Params, Results>::isCanceled() { inline bool CallContext<Params, Results>::isCanceled() {
......
...@@ -534,7 +534,7 @@ public: ...@@ -534,7 +534,7 @@ public:
void setResults(DynamicStruct::Reader value); void setResults(DynamicStruct::Reader value);
void adoptResults(Orphan<DynamicStruct>&& value); void adoptResults(Orphan<DynamicStruct>&& value);
Orphanage getResultsOrphanage(uint firstSegmentWordSize = 0); Orphanage getResultsOrphanage(uint firstSegmentWordSize = 0);
void allowAsyncCancellation(bool allow = true); void allowAsyncCancellation();
bool isCanceled(); bool isCanceled();
private: private:
...@@ -1513,8 +1513,8 @@ inline Orphanage CallContext<DynamicStruct, DynamicStruct>::getResultsOrphanage( ...@@ -1513,8 +1513,8 @@ inline Orphanage CallContext<DynamicStruct, DynamicStruct>::getResultsOrphanage(
uint firstSegmentWordSize) { uint firstSegmentWordSize) {
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize)); return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize));
} }
inline void CallContext<DynamicStruct, DynamicStruct>::allowAsyncCancellation(bool allow) { inline void CallContext<DynamicStruct, DynamicStruct>::allowAsyncCancellation() {
hook->allowAsyncCancellation(allow); hook->allowAsyncCancellation();
} }
inline bool CallContext<DynamicStruct, DynamicStruct>::isCanceled() { inline bool CallContext<DynamicStruct, DynamicStruct>::isCanceled() {
return hook->isCanceled(); return hook->isCanceled();
......
...@@ -232,6 +232,9 @@ struct ObjectPointer { ...@@ -232,6 +232,9 @@ struct ObjectPointer {
kj::Own<const ClientHook> asCap() const; kj::Own<const ClientHook> asCap() const;
// Expect that the result is a capability and construct a pipelined version of it now. // Expect that the result is a capability and construct a pipelined version of it now.
inline kj::Own<const PipelineHook> releasePipelineHook() { return kj::mv(hook); }
// For use by RPC implementations.
private: private:
kj::Own<const PipelineHook> hook; kj::Own<const PipelineHook> hook;
kj::Array<PipelineOp> ops; kj::Array<PipelineOp> ops;
......
...@@ -85,6 +85,46 @@ Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps( ...@@ -85,6 +85,46 @@ Orphan<List<rpc::PromisedAnswer::Op>> fromPipelineOps(
return result; return result;
} }
kj::Exception toException(const rpc::Exception::Reader& exception) {
kj::Exception::Nature nature =
exception.getIsCallersFault()
? kj::Exception::Nature::PRECONDITION
: kj::Exception::Nature::LOCAL_BUG;
kj::Exception::Durability durability;
switch (exception.getDurability()) {
default:
case rpc::Exception::Durability::PERMANENT:
durability = kj::Exception::Durability::PERMANENT;
break;
case rpc::Exception::Durability::TEMPORARY:
durability = kj::Exception::Durability::TEMPORARY;
break;
case rpc::Exception::Durability::OVERLOADED:
durability = kj::Exception::Durability::OVERLOADED;
break;
}
return kj::Exception(nature, durability, "(remote)", 0, kj::heapString(exception.getReason()));
}
void fromException(const kj::Exception& exception, rpc::Exception::Builder builder) {
builder.setReason(exception.getDescription());
builder.setIsCallersFault(exception.getNature() == kj::Exception::Nature::PRECONDITION);
switch (exception.getDurability()) {
case kj::Exception::Durability::PERMANENT:
builder.setDurability(rpc::Exception::Durability::PERMANENT);
break;
case kj::Exception::Durability::TEMPORARY:
builder.setDurability(rpc::Exception::Durability::TEMPORARY);
break;
case kj::Exception::Durability::OVERLOADED:
builder.setDurability(rpc::Exception::Durability::OVERLOADED);
break;
}
}
// =======================================================================================
typedef uint32_t QuestionId; typedef uint32_t QuestionId;
typedef uint32_t ExportId; typedef uint32_t ExportId;
...@@ -141,6 +181,19 @@ public: ...@@ -141,6 +181,19 @@ public:
} }
} }
kj::Maybe<T&> find(Id id) {
if (id < kj::size(low)) {
return low[id];
} else {
auto iter = high.find(id);
if (iter == nullptr) {
return nullptr;
} else {
return iter->second;
}
}
}
void erase(Id id) { void erase(Id id) {
if (id < kj::size(low)) { if (id < kj::size(low)) {
low[id] = T(); low[id] = T();
...@@ -181,14 +234,21 @@ struct Question { ...@@ -181,14 +234,21 @@ struct Question {
inline bool operator!=(decltype(nullptr)) const { return isStarted; } inline bool operator!=(decltype(nullptr)) const { return isStarted; }
}; };
template <typename CallContext>
struct Answer { struct Answer {
bool active = false; bool active = false;
// True from the point when the Call message is received to the point when both the `Finish`
// message has been received and the `Return` has been sent.
kj::Own<const PipelineHook> pipeline; kj::Maybe<kj::Own<const PipelineHook>> pipeline;
// Send pipelined calls here. // Send pipelined calls here. Becomes null as soon as a `Finish` is received.
kj::Promise<void> asyncOp = nullptr; kj::Promise<void> asyncOp = nullptr;
// Delete this promise to cancel the call. // Delete this promise to cancel the call.
kj::Maybe<const CallContext&> callContext;
// The call context, if it's still active. Becomes null when the `Return` message is sent. This
// object, if non-null, is owned by `asyncOp`.
}; };
struct Export { struct Export {
...@@ -201,6 +261,8 @@ struct Export { ...@@ -201,6 +261,8 @@ struct Export {
inline bool operator!=(decltype(nullptr)) const { return refcount != 0; } inline bool operator!=(decltype(nullptr)) const { return refcount != 0; }
}; };
// =======================================================================================
class RpcConnectionState: public kj::TaskSet::ErrorHandler { class RpcConnectionState: public kj::TaskSet::ErrorHandler {
public: public:
RpcConnectionState(const kj::EventLoop& eventLoop, RpcConnectionState(const kj::EventLoop& eventLoop,
...@@ -222,10 +284,11 @@ private: ...@@ -222,10 +284,11 @@ private:
class CapInjectorImpl; class CapInjectorImpl;
class CapExtractorImpl; class CapExtractorImpl;
class RpcPipeline; class RpcPipeline;
class RpcCallContext;
struct Tables { struct Tables {
ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline>> questions; ExportTable<QuestionId, Question<CapInjectorImpl, RpcPipeline>> questions;
ImportTable<QuestionId, Answer> answers; ImportTable<QuestionId, Answer<RpcCallContext>> answers;
ExportTable<ExportId, Export> exports; ExportTable<ExportId, Export> exports;
ImportTable<ExportId, kj::Maybe<ImportClient&>> imports; ImportTable<ExportId, kj::Maybe<ImportClient&>> imports;
}; };
...@@ -282,9 +345,43 @@ private: ...@@ -282,9 +345,43 @@ private:
kj::Own<CallContextHook>&& context) const override { kj::Own<CallContextHook>&& context) const override {
auto params = context->getParams(); auto params = context->getParams();
newOutgoingMessage size_t sizeHint = params.targetSizeInWords();
// TODO(perf): Extend targetSizeInWords() to include a capability count? Here we increase
// the size by 1/16 to deal with cap descriptors possibly expanding. See also below, when
// handling the response.
sizeHint += sizeHint / 16;
newCall(interfaceId, methodId, params.targetSizeInWords() + CALL_MESSAGE_SIZE); // Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
auto request = newCall(interfaceId, methodId, sizeHint);
request.set(context->getParams());
context->releaseParams();
auto promise = request.send();
auto pipeline = promise.releasePipelineHook();
auto voidPromise = promise.then(kj::mvCapture(context,
[](kj::Own<CallContextHook>&& context, Response<ObjectPointer> response) {
size_t sizeHint = response.targetSizeInWords();
// See above TODO.
sizeHint += sizeHint / 16;
// Don't overflow.
if (uint(sizeHint) != sizeHint) {
sizeHint = ~uint(0);
}
context->getResults(sizeHint).set(response);
}));
return { kj::mv(voidPromise), kj::mv(pipeline) };
} }
kj::Own<const ClientHook> addRef() const override { kj::Own<const ClientHook> addRef() const override {
...@@ -298,10 +395,12 @@ private: ...@@ -298,10 +395,12 @@ private:
const RpcConnectionState& connectionState; const RpcConnectionState& connectionState;
}; };
class ImportClient final: public RpcClient { class ImportClient: public RpcClient {
protected:
ImportClient(const RpcConnectionState& connectionState, ExportId importId)
: RpcClient(connectionState), importId(importId) {}
public: public:
ImportClient(const RpcConnectionState& connectionState, ExportId importId, bool isPromise)
: RpcClient(connectionState), importId(importId), isPromise(isPromise) {}
~ImportClient() noexcept(false) { ~ImportClient() noexcept(false) {
{ {
// Remove self from the import table, if the table is still pointing at us. (It's possible // Remove self from the import table, if the table is still pointing at us. (It's possible
...@@ -322,6 +421,10 @@ private: ...@@ -322,6 +421,10 @@ private:
} }
} }
virtual bool settle(kj::Own<const ClientHook> replacement) = 0;
// Replace the PromiseImportClient with its resolution. Returns false if this is not a promise
// (i.e. it is a SettledImportClient).
kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() { kj::Maybe<kj::Own<ImportClient>> tryAddRemoteRef() {
// Add a new RemoteRef and return a new ref to this client representing it. Returns null // Add a new RemoteRef and return a new ref to this client representing it. Returns null
// if this client is being deleted in another thread, in which case the caller should // if this client is being deleted in another thread, in which case the caller should
...@@ -343,17 +446,69 @@ private: ...@@ -343,17 +446,69 @@ private:
// implements ClientHook ----------------------------------------- // implements ClientHook -----------------------------------------
Request<ObjectPointer, ObjectPointer> newCall( Request<ObjectPointer, ObjectPointer> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override; uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override; auto request = kj::heap<RpcRequest>(connectionState, firstSegmentWordSize);
auto callBuilder = request->getCall();
callBuilder.getTarget().setExportedCap(importId);
callBuilder.setInterfaceId(interfaceId);
callBuilder.setMethodId(methodId);
request->holdRef(writeTarget(callBuilder.getTarget()));
auto root = request->getRoot();
return Request<ObjectPointer, ObjectPointer>(root, kj::mv(request));
}
private: private:
ExportId importId; ExportId importId;
bool isPromise;
uint remoteRefcount = 0; uint remoteRefcount = 0;
// Number of times we've received this import from the peer. // Number of times we've received this import from the peer.
}; };
class SettledImportClient final: public ImportClient {
public:
inline SettledImportClient(const RpcConnectionState& connectionState, ExportId importId)
: ImportClient(connectionState, importId) {}
bool settle(kj::Own<const ClientHook> replacement) override {
return false;
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
};
class PromiseImportClient final: public ImportClient {
public:
PromiseImportClient(const RpcConnectionState& connectionState, ExportId importId)
: ImportClient(connectionState, importId),
fork(nullptr) {
auto paf = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>(connectionState.eventLoop);
fulfiller = kj::mv(paf.fulfiller);
fork = paf.promise.fork();
}
bool settle(kj::Own<const ClientHook> replacement) override {
fulfiller->fulfill(kj::mv(replacement));
return true;
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
// We need the returned promise to hold a reference back to this object, so that it doesn't
// disappear while the promise is still outstanding.
return fork.addBranch().thenInAnyThread(kj::mvCapture(kj::addRef(*this),
[](kj::Own<const PromiseImportClient>&&, kj::Own<const ClientHook>&& replacement) {
return kj::mv(replacement);
}));
}
private:
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller;
kj::ForkedPromise<kj::Own<const ClientHook>> fork;
};
class PromisedAnswerClient final: public RpcClient { class PromisedAnswerClient final: public RpcClient {
public: public:
PromisedAnswerClient(const RpcConnectionState& connectionState, QuestionId questionId, PromisedAnswerClient(const RpcConnectionState& connectionState, QuestionId questionId,
...@@ -470,7 +625,8 @@ private: ...@@ -470,7 +625,8 @@ private:
kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const override { kj::Own<const ClientHook> extractCap(rpc::CapDescriptor::Reader descriptor) const override {
switch (descriptor.which()) { switch (descriptor.which()) {
case rpc::CapDescriptor::SENDER_HOSTED: { case rpc::CapDescriptor::SENDER_HOSTED:
case rpc::CapDescriptor::SENDER_PROMISE: {
ExportId importId = descriptor.getSenderHosted(); ExportId importId = descriptor.getSenderHosted();
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
...@@ -488,7 +644,14 @@ private: ...@@ -488,7 +644,14 @@ private:
} }
// No import for this ID exists currently, so create one. // No import for this ID exists currently, so create one.
auto result = kj::refcounted<ImportClient>(connectionState, importId); kj::Own<ImportClient> result;
if (descriptor.which() == rpc::CapDescriptor::SENDER_PROMISE) {
// TODO(now): Check for pending `Resolve` messages replacing this import ID, and if
// one exists, use that client instead.
kj::refcounted<PromiseImportClient>(connectionState, importId);
} else {
kj::refcounted<SettledImportClient>(connectionState, importId);
}
lock->imports[importId] = *result; lock->imports[importId] = *result;
// Note that we need to retain this import later if it still exists. // Note that we need to retain this import later if it still exists.
...@@ -497,10 +660,6 @@ private: ...@@ -497,10 +660,6 @@ private:
return kj::mv(result); return kj::mv(result);
} }
case rpc::CapDescriptor::SENDER_PROMISE:
// TODO(now): Implement this or remove `senderPromise`.
return newBrokenCap("senderPromise not implemented");
case rpc::CapDescriptor::RECEIVER_HOSTED: { case rpc::CapDescriptor::RECEIVER_HOSTED: {
auto lock = connectionState.tables.lockExclusive(); // TODO(perf): shared? auto lock = connectionState.tables.lockExclusive(); // TODO(perf): shared?
KJ_IF_MAYBE(exp, lock->exports.find(descriptor.getReceiverHosted())) { KJ_IF_MAYBE(exp, lock->exports.find(descriptor.getReceiverHosted())) {
...@@ -613,7 +772,7 @@ private: ...@@ -613,7 +772,7 @@ private:
// ===================================================================================== // =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations // RequestHook/PipelineHook/ResponseHook implementations
class RpcRequest: public RequestHook { class RpcRequest final: public RequestHook {
public: public:
RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize) RpcRequest(const RpcConnectionState& connectionState, uint firstSegmentWordSize)
: connectionState(connectionState), : connectionState(connectionState),
...@@ -627,6 +786,9 @@ private: ...@@ -627,6 +786,9 @@ private:
inline ObjectPointer::Builder getRoot() { inline ObjectPointer::Builder getRoot() {
return paramsBuilder; return paramsBuilder;
} }
inline rpc::Call::Builder getCall() {
return callBuilder;
}
RemotePromise<ObjectPointer> send() override { RemotePromise<ObjectPointer> send() override {
auto paf = kj::newPromiseAndFulfiller<Response<ObjectPointer>>(connectionState.eventLoop); auto paf = kj::newPromiseAndFulfiller<Response<ObjectPointer>>(connectionState.eventLoop);
...@@ -731,9 +893,33 @@ private: ...@@ -731,9 +893,33 @@ private:
kj::Maybe<CapExtractorImpl&> capExtractor; kj::Maybe<CapExtractorImpl&> capExtractor;
}; };
class RpcResponse { class RpcResponse: public ResponseHook {
public: public:
RpcResponse(RpcConnectionState& connectionState, RpcResponse(const RpcConnectionState& connectionState,
kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results)
: message(kj::mv(message)),
extractor(connectionState),
context(extractor),
reader(context.imbue(results)) {}
ObjectPointer::Reader getResults() {
return reader;
}
private:
kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor;
CapReaderContext context;
ObjectPointer::Reader reader;
};
// =====================================================================================
// CallContextHook implementation
class RpcServerResponse {
public:
RpcServerResponse(const RpcConnectionState& connectionState,
kj::Own<OutgoingRpcMessage>&& message, kj::Own<OutgoingRpcMessage>&& message,
ObjectPointer::Builder results) ObjectPointer::Builder results)
: message(kj::mv(message)), : message(kj::mv(message)),
...@@ -756,9 +942,6 @@ private: ...@@ -756,9 +942,6 @@ private:
ObjectPointer::Builder builder; ObjectPointer::Builder builder;
}; };
// =====================================================================================
// CallContextHook implementation
class RpcCallContext final: public CallContextHook, public kj::Refcounted { class RpcCallContext final: public CallContextHook, public kj::Refcounted {
public: public:
RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId, RpcCallContext(RpcConnectionState& connectionState, QuestionId questionId,
...@@ -772,6 +955,7 @@ private: ...@@ -772,6 +955,7 @@ private:
returnMessage(nullptr) {} returnMessage(nullptr) {}
void sendReturn() { void sendReturn() {
if (isFirstResponder()) {
if (response == nullptr) getResults(1); // force initialization of response if (response == nullptr) getResults(1); // force initialization of response
returnMessage.setQuestionId(questionId); returnMessage.setQuestionId(questionId);
...@@ -780,7 +964,57 @@ private: ...@@ -780,7 +964,57 @@ private:
KJ_ASSERT_NONNULL(response)->send(); KJ_ASSERT_NONNULL(response)->send();
} }
void sendErrorReturn(kj::Exception&& exception); }
void sendErrorReturn(kj::Exception&& exception) {
if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage(
messageSizeHint<rpc::Return>() + sizeInWords<rpc::Exception>() +
exception.getDescription().size() / sizeof(word) + 1);
auto builder = message->getBody().initAs<rpc::Message>().initReturn();
builder.setQuestionId(questionId);
builder.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(returnMessage)));
fromException(exception, builder.initException());
message->send();
}
}
void sendCancel() {
if (isFirstResponder()) {
auto message = connectionState.connection->newOutgoingMessage(
messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn();
builder.setQuestionId(questionId);
builder.adoptRetainedCaps(requestCapExtractor.finalizeRetainedCaps(
Orphanage::getForMessageContaining(returnMessage)));
builder.setCanceled();
message->send();
}
}
void requestCancel() const {
// Hints that the caller wishes to cancel this call. At the next time when cancellation is
// deemed safe, the RpcCallContext shall send a canceled Return -- or if it never becomes
// safe, the RpcCallContext will send a normal return when the call completes. Either way
// the RpcCallContext is now responsible for cleaning up the entry in the answer table, since
// a Finish message was already received.
// Verify that we're holding the tables mutex. This is important because we're handing off
// responsibility for deleting the answer. Moreover, the callContext pointer in the answer
// table should not be null as this would indicate that we've already returned a result.
KJ_DASSERT(connectionState.tables.getAlreadyLockedExclusive()
.answers[questionId].callContext != nullptr);
if (__atomic_fetch_or(&cancellationFlags, CANCEL_REQUESTED, __ATOMIC_RELAXED) ==
CANCEL_ALLOWED) {
// We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Schedule
// the cancellation.
scheduleCancel();
}
}
// implements CallContextHook ------------------------------------ // implements CallContextHook ------------------------------------
...@@ -800,20 +1034,27 @@ private: ...@@ -800,20 +1034,27 @@ private:
firstSegmentWordSize + messageSizeHint<rpc::Return>() + firstSegmentWordSize + messageSizeHint<rpc::Return>() +
requestCapExtractor.retainedListSizeHint(request == nullptr)); requestCapExtractor.retainedListSizeHint(request == nullptr));
returnMessage = message->getBody().initAs<rpc::Message>().initReturn(); returnMessage = message->getBody().initAs<rpc::Message>().initReturn();
auto response = kj::heap<RpcResponse>(connectionState, kj::mv(message), auto response = kj::heap<RpcServerResponse>(
returnMessage.getAnswer()); connectionState, kj::mv(message), returnMessage.getAnswer());
auto results = response->getResults(); auto results = response->getResults();
this->response = kj::mv(response); this->response = kj::mv(response);
return results; return results;
} }
} }
void allowAsyncCancellation(bool allow) override { void allowAsyncCancellation() override {
// TODO(soon): Do we want this or not? if (threadAcceptingCancellation != nullptr) {
KJ_FAIL_REQUIRE("not implemented"); threadAcceptingCancellation = &kj::EventLoop::current();
if (__atomic_fetch_or(&cancellationFlags, CANCEL_ALLOWED, __ATOMIC_RELAXED) ==
CANCEL_REQUESTED) {
// We just set CANCEL_ALLOWED, and CANCEL_REQUESTED was already set previously. Schedule
// the cancellation.
scheduleCancel();
}
}
} }
bool isCanceled() override { bool isCanceled() override {
// TODO(soon): Do we want this or not? return __atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED;
KJ_FAIL_REQUIRE("not implemented");
} }
kj::Own<CallContextHook> addRef() override { kj::Own<CallContextHook> addRef() override {
return kj::addRef(*this); return kj::addRef(*this);
...@@ -822,14 +1063,109 @@ private: ...@@ -822,14 +1063,109 @@ private:
private: private:
RpcConnectionState& connectionState; RpcConnectionState& connectionState;
QuestionId questionId; QuestionId questionId;
kj::Maybe<kj::Own<IncomingRpcMessage>> request;
// Request ---------------------------------------------
kj::Maybe<kj::Own<IncomingRpcMessage>> request;
CapExtractorImpl requestCapExtractor; CapExtractorImpl requestCapExtractor;
CapReaderContext requestCapContext; CapReaderContext requestCapContext;
ObjectPointer::Reader params; ObjectPointer::Reader params;
kj::Maybe<kj::Own<RpcResponse>> response; // Response --------------------------------------------
kj::Maybe<kj::Own<RpcServerResponse>> response;
rpc::Return::Builder returnMessage; rpc::Return::Builder returnMessage;
bool responseSent = false;
// Cancellation state ----------------------------------
enum CancellationFlags {
CANCEL_REQUESTED = 1,
CANCEL_ALLOWED = 2
};
mutable uint8_t cancellationFlags = 0;
// When both flags are set, the cancellation process will begin. Must be manipulated atomically
// as it may be accessed from multiple threads.
mutable kj::Promise<void> deferredCancellation = nullptr;
// Cancellation operation scheduled by cancelLater(). Must only be scheduled once, from one
// thread.
kj::EventLoop* threadAcceptingCancellation = nullptr;
// EventLoop for the thread that first called allowAsyncCancellation(). We store this as an
// optimization: if the application thread is independent from the network thread, we'd rather
// perform the cancellation in the application thread, because otherwise we might block waiting
// on an application promise continuation callback to finish executing, which could take
// arbitrary time.
// -----------------------------------------------------
void scheduleCancel() const {
// Arranges for the answer's asyncOp to be deleted, thus canceling all processing related to
// this call, shortly. We have to do it asynchronously because the caller might hold
// arbitrary locks or might in fact be part of the task being canceled.
deferredCancellation = threadAcceptingCancellation->evalLater([this]() {
// Make sure we don't accidentally delete ourselves in the process of canceling, since the
// last reference to the context may be owned by the asyncOp.
auto self = kj::addRef(*this);
// Extract from the answer table the promise representing the executing call.
kj::Promise<void> asyncOp = nullptr;
{
auto lock = connectionState.tables.lockExclusive();
asyncOp = kj::mv(lock->answers[questionId].asyncOp);
}
// Delete the promise, thereby canceling the operation. Note that if a continuation is
// running in another thread, this line blocks waiting for it to complete. This is why
// we try to schedule doCancel() on the application thread, so that it won't need to block.
asyncOp = nullptr;
// OK, now that we know the call isn't running in another thread, we can drop our thread
// safety and send a return message.
const_cast<RpcCallContext*>(this)->sendCancel();
});
}
bool isFirstResponder() {
// The first time it is called, removes self from the answer table and returns true.
// On subsequent calls, returns false.
if (responseSent) {
return false;
} else {
responseSent = true;
// We need to remove the `callContext` pointer -- which points back to us -- from the
// answer table. Or we might even be responsible for removing the entire answer table
// entry.
auto lock = connectionState.tables.lockExclusive();
if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) {
// We are responsible for deleting the answer table entry. Awkwardly, however, the
// answer table may be the only thing holding a reference to the context, and we may even
// be called from the continuation represented by answer.asyncOp. So we have to do the
// actual deletion asynchronously. But we have to remove it from the table *now*, while
// we still hold the lock, because once we send the return message the answer ID is free
// for reuse.
connectionState.tasks.add(connectionState.eventLoop.evalLater(
kj::mvCapture(lock->answers[questionId],
[](Answer<RpcCallContext>&& answer) {
// Just let the answer be deleted.
})));
// Erase from the table.
lock->answers.erase(questionId);
} else {
// We just have to null out callContext.
lock->answers[questionId].callContext = nullptr;
}
return true;
}
}
}; };
// ===================================================================================== // =====================================================================================
...@@ -863,6 +1199,14 @@ private: ...@@ -863,6 +1199,14 @@ private:
doCall(kj::mv(message), reader.getCall()); doCall(kj::mv(message), reader.getCall());
break; break;
case rpc::Message::RETURN:
doReturn(kj::mv(message), reader.getReturn());
break;
case rpc::Message::FINISH:
doFinish(reader.getFinish());
break;
default: { default: {
auto message = connection->newOutgoingMessage( auto message = connection->newOutgoingMessage(
reader.totalSizeInWords() + messageSizeHint<void>()); reader.totalSizeInWords() + messageSizeHint<void>());
...@@ -904,11 +1248,17 @@ private: ...@@ -904,11 +1248,17 @@ private:
{ {
auto lock = tables.lockExclusive(); // TODO(perf): shared? auto lock = tables.lockExclusive(); // TODO(perf): shared?
const Answer& base = lock->answers[promisedAnswer.getQuestionId()]; auto& base = lock->answers[promisedAnswer.getQuestionId()];
KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") { KJ_REQUIRE(base.active, "PromisedAnswer.questionId is not a current question.") {
return; return;
} }
pipeline = base.pipeline->addRef(); KJ_IF_MAYBE(p, base.pipeline) {
pipeline = p->get()->addRef();
} else {
KJ_FAIL_REQUIRE("PromisedAnswer.questionId is already finished.") {
return;
}
}
} }
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
...@@ -927,8 +1277,6 @@ private: ...@@ -927,8 +1277,6 @@ private:
} }
} }
// TODO(now): Imbue the message!
QuestionId questionId = call.getQuestionId(); QuestionId questionId = call.getQuestionId();
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getRequest()); *this, questionId, kj::mv(message), call.getRequest());
...@@ -940,7 +1288,7 @@ private: ...@@ -940,7 +1288,7 @@ private:
{ {
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
Answer& answer = lock->answers[questionId]; auto& answer = lock->answers[questionId];
// We don't want to overwrite an active question because the destructors for the promise and // We don't want to overwrite an active question because the destructors for the promise and
// pipeline could try to lock our mutex. Of course, we did already fire off the new call // pipeline could try to lock our mutex. Of course, we did already fire off the new call
...@@ -951,6 +1299,7 @@ private: ...@@ -951,6 +1299,7 @@ private:
} }
answer.active = true; answer.active = true;
answer.callContext = *context;
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
// Hack: Both the success and error continuations need to use the context. We could // Hack: Both the success and error continuations need to use the context. We could
...@@ -971,13 +1320,67 @@ private: ...@@ -971,13 +1320,67 @@ private:
} }
} }
void doReturn(kj::Own<IncomingRpcMessage>&& message, const rpc::Return::Reader& ret) {
kj::Own<CapInjectorImpl> paramCapsToRelease;
auto lock = tables.lockExclusive();
KJ_IF_MAYBE(question, lock->questions.find(ret.getQuestionId())) {
KJ_REQUIRE(!question->isReturned, "Duplicate Return.") { return; }
question->isReturned = true;
for (ExportId retained: ret.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) {
++exp->refcount;
} else {
KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; }
}
}
// We can now release the caps in the parameter message. Well... we can release them once
// we release the lock, at least.
paramCapsToRelease = kj::mv(question->paramCaps);
// TODO(now): Handle exception/cancel response
auto response = kj::heap<RpcResponse>(*this, kj::mv(message), ret.getAnswer());
auto imbuedResults = response->getResults();
question->fulfiller->fulfill(Response<ObjectPointer>(imbuedResults, kj::mv(response)));
kj::Exception toException(const rpc::Exception::Reader& exception) { if (question->pipeline == nullptr) {
// TODO(now) lock->questions.erase(ret.getQuestionId());
}
} else {
KJ_FAIL_REQUIRE("Invalid question ID in Return message.") { return; }
}
} }
void doFinish(const rpc::Finish::Reader& finish) {
kj::Maybe<kj::Own<const PipelineHook>> pipelineToRelease;
auto lock = tables.lockExclusive();
for (ExportId retained: finish.getRetainedCaps()) {
KJ_IF_MAYBE(exp, lock->exports.find(retained)) {
++exp->refcount;
} else {
KJ_FAIL_REQUIRE("Invalid export ID in Return.retainedCaps list.") { return; }
}
}
auto& answer = lock->answers[finish.getQuestionId()];
// `Finish` indicates that no further pipeline requests will be made.
pipelineToRelease = kj::mv(answer.pipeline);
KJ_IF_MAYBE(context, answer.callContext) {
context->requestCancel();
} else {
lock->answers.erase(finish.getQuestionId());
}
}
// =====================================================================================
void sendReleaseLater(ExportId importId, uint remoteRefcount) const { void sendReleaseLater(ExportId importId, uint remoteRefcount) const {
tasks.add(eventLoop.evalLater([this,importId,remoteRefcount]() { tasks.add(eventLoop.evalLater([this,importId,remoteRefcount]() {
auto message = connection->newOutgoingMessage(messageSizeHint<rpc::Release>()); auto message = connection->newOutgoingMessage(messageSizeHint<rpc::Release>());
......
...@@ -746,16 +746,19 @@ struct Exception { ...@@ -746,16 +746,19 @@ struct Exception {
# automated bug report were to be generated for this error, should it be initially filed on the # automated bug report were to be generated for this error, should it be initially filed on the
# caller's code or the callee's? This is a guess. Generally guesses should err towards blaming # caller's code or the callee's? This is a guess. Generally guesses should err towards blaming
# the callee -- at the very least, the callee should be on the hook for improving their error # the callee -- at the very least, the callee should be on the hook for improving their error
# handling to be more confident. # handling to be more confident in assigning blame.
isPermanent @2 :Bool; durability @2 :Durability;
# In the best estimate of the error source, is this error likely to repeat if the same call is # In the best estimate of the error source, is this error likely to repeat if the same call is
# executed again? Callers might use this to decide when to retry a request. # executed again? Callers might use this to decide when to retry a request.
isOverloaded @3 :Bool; enum Durability {
# In the best estimate of the error source, is it likely this error was caused by the system permanent @0; # Retrying the exact same operation will fail in the same way.
# being overloaded? If so, the caller probably should not retry the request now, but may temporary @1; # Retrying the exact same operation might succeed.
# consider retrying it later. overloaded @2; # The error may be due to the system being overloaded. Retrying may work
# later on, but for now the caller should not retry right away as this will
# likely exacerbate the problem.
}
} }
# ======================================================================================== # ========================================================================================
......
...@@ -639,7 +639,7 @@ class ForkedPromise { ...@@ -639,7 +639,7 @@ class ForkedPromise {
// Like `Promise<T>`, this is a pass-by-move type. // Like `Promise<T>`, this is a pass-by-move type.
public: public:
inline ForkedPromise(decltype(nullptr)): hub(nullptr) {} inline ForkedPromise(decltype(nullptr)) {}
Promise<_::Forked<T>> addBranch() const; Promise<_::Forked<T>> addBranch() const;
// Add a new branch to the fork. The branch is equivalent to the original promise, except // Add a new branch to the fork. The branch is equivalent to the original promise, except
......
...@@ -131,8 +131,9 @@ ArrayPtr<const char> KJ_STRINGIFY(Exception::Nature nature) { ...@@ -131,8 +131,9 @@ ArrayPtr<const char> KJ_STRINGIFY(Exception::Nature nature) {
ArrayPtr<const char> KJ_STRINGIFY(Exception::Durability durability) { ArrayPtr<const char> KJ_STRINGIFY(Exception::Durability durability) {
static const char* DURABILITY_STRINGS[] = { static const char* DURABILITY_STRINGS[] = {
"permanent",
"temporary", "temporary",
"permanent" "overloaded"
}; };
const char* s = DURABILITY_STRINGS[static_cast<uint>(durability)]; const char* s = DURABILITY_STRINGS[static_cast<uint>(durability)];
......
...@@ -61,8 +61,11 @@ public: ...@@ -61,8 +61,11 @@ public:
}; };
enum class Durability { enum class Durability {
PERMANENT, // Retrying the exact same operation will fail in exactly the same way.
TEMPORARY, // Retrying the exact same operation might succeed. TEMPORARY, // Retrying the exact same operation might succeed.
PERMANENT // Retrying the exact same operation will fail in exactly the same way. OVERLOADED // The error was possibly caused by the system being overloaded. Retrying the
// operation might work at a later point in time, but the caller should NOT retry
// immediately as this will probably exacerbate the problem.
// Make sure to update the stringifier if you add a new durability. // Make sure to update the stringifier if you add a new durability.
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment