Commit e4a5344b authored by Kenton Varda

Refactor QuestionRef.

parent d080e158
......@@ -239,22 +239,26 @@ public:
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller)
: eventLoop(eventLoop), restorer(restorer), connection(kj::mv(connection)),
disconnectFulfiller(kj::mv(disconnectFulfiller)),
tasks(eventLoop, *this), exportDisposer(*this) {
tasks(eventLoop, *this) {
tasks.add(messageLoop());
}
kj::Own<const ClientHook> restore(ObjectPointer::Reader objectId) {
QuestionId questionId;
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(eventLoop);
kj::Own<QuestionRef> questionRef;
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>(eventLoop);
{
auto lock = tables.lockExclusive();
auto& question = lock->questions.next(questionId);
question.isStarted = true;
question.fulfiller = kj::mv(paf.fulfiller);
// We need a dummy paramCaps since null normally indicates that the question has completed.
question.paramCaps = kj::heap<CapInjectorImpl>(*this);
questionRef = kj::refcounted<QuestionRef>(*this, questionId, kj::mv(paf.fulfiller));
question.selfRef = *questionRef;
paf.promise.attach(kj::addRef(*questionRef));
}
{
......@@ -268,18 +272,8 @@ public:
message->send();
}
auto questionRef = kj::heap<QuestionRef>(*this, questionId);
auto promiseWithQuestionRef = eventLoop.there(kj::mv(paf.promise),
kj::mvCapture(questionRef,
[](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response);
}));
auto pipeline = kj::refcounted<RpcPipeline>(
*this, questionId, eventLoop.fork(kj::mv(promiseWithQuestionRef)));
*this, kj::mv(questionRef), eventLoop.fork(kj::mv(paf.promise)));
return pipeline->getPipelinedCap(kj::Array<const PipelineOp>(nullptr));
}
......@@ -303,7 +297,13 @@ public:
// All current questions complete with exceptions.
lock->questions.forEach([&](QuestionId id, Question& question) {
question.fulfiller->reject(kj::cp(networkException));
KJ_IF_MAYBE(questionRef, question.selfRef) {
// QuestionRef still present. Make sure it's not in the midst of being destroyed, then
// reject it.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
questionRef->reject(kj::cp(networkException));
}
}
KJ_IF_MAYBE(pc, question.paramCaps) {
paramCapsToRelease.add(kj::mv(*pc));
}
......@@ -351,6 +351,7 @@ private:
class PromiseClient;
class CapInjectorImpl;
class CapExtractorImpl;
class QuestionRef;
class RpcPipeline;
class RpcCallContext;
class RpcResponse;
......@@ -364,23 +365,19 @@ private:
typedef uint32_t ExportId;
struct Question {
kj::Own<kj::PromiseFulfiller<kj::Own<RpcResponse>>> fulfiller;
// Fulfill with the response.
kj::Maybe<kj::Own<CapInjectorImpl>> paramCaps;
// CapInjector from the parameter struct. This will be released once the `Return` message is
// received and `retainedCaps` processed. (If this is non-null, then the call has not returned
// yet.)
bool isStarted = false;
// Is this Question ID currently in-use? (This is true until both `Return` has been received and
// `Finish` has been sent.)
bool isFinished = false;
// Has the `Finish` message been sent?
kj::Maybe<QuestionRef&> selfRef;
// The local QuestionRef, set to nullptr when it is destroyed, which is also when `Finish` is
// sent.
inline bool operator==(decltype(nullptr)) const { return !isStarted; }
inline bool operator!=(decltype(nullptr)) const { return isStarted; }
inline bool operator==(decltype(nullptr)) const {
return paramCaps == nullptr && selfRef == nullptr;
}
inline bool operator!=(decltype(nullptr)) const { return !operator==(nullptr); }
};
struct Answer {
......@@ -449,34 +446,6 @@ private:
kj::TaskSet tasks;
class ExportDisposer final: public kj::Disposer {
// Disposer whose "pointer" is not a real pointer at all: the export table index
// (ExportId) is smuggled through the void* slot. Disposing decrements the export's
// refcount under the table lock and erases the entry when it reaches zero.
public:
inline ExportDisposer(const RpcConnectionState& connectionState)
: connectionState(connectionState) {}
protected:
void disposeImpl(void* pointer) const override {
// Take an exclusive lock on the connection's tables before touching the export map.
auto lock = connectionState.tables.lockExclusive();
// Recover the export ID that was packed into the pointer value at creation time.
ExportId id = reinterpret_cast<intptr_t>(pointer);
KJ_IF_MAYBE(exp, lock->exports.find(id)) {
if (--exp->refcount == 0) {
// Last reference: remove the entry. The brace block is the macro's recovery
// path, taken if the assertion fails in no-exceptions builds; `break` exits it.
KJ_ASSERT(lock->exports.erase(id)) {
break;
}
}
} else {
// Unknown export ID — report and recover rather than crash (destructors must not throw).
KJ_FAIL_REQUIRE("invalid export ID", id) { break; }
}
}
private:
// Non-owning back-reference; assumes the connection state outlives every export — TODO confirm.
const RpcConnectionState& connectionState;
};
// TODO(now): unused?
ExportDisposer exportDisposer;
// =====================================================================================
// ClientHook implementations
......@@ -652,20 +621,25 @@ private:
public:
PipelineClient(const RpcConnectionState& connectionState,
kj::Own<const RpcPipeline>&& pipeline,
kj::Own<const QuestionRef>&& questionRef,
kj::Array<PipelineOp>&& ops)
: RpcClient(connectionState), pipeline(kj::mv(pipeline)), ops(kj::mv(ops)) {}
: RpcClient(connectionState), questionRef(kj::mv(questionRef)), ops(kj::mv(ops)) {}
kj::Maybe<ExportId> writeDescriptor(
rpc::CapDescriptor::Builder descriptor, Tables& tables) const override {
return pipeline->writeDescriptor(descriptor, tables, ops);
auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(questionRef->getId());
promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
return nullptr;
}
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target) const override {
// TODO(now): The pipeline may redirect to the resolution before PromiseClient has resolved.
// This could lead to a race condition if PromiseClient implements embargoes.
return pipeline->writeTarget(target, ops);
auto builder = target.initPromisedAnswer();
builder.setQuestionId(questionRef->getId());
builder.adoptTransform(fromPipelineOps(Orphanage::getForMessageContaining(builder), ops));
return nullptr;
}
// implements ClientHook -----------------------------------------
......@@ -688,7 +662,7 @@ private:
}
private:
kj::Own<const RpcPipeline> pipeline;
kj::Own<const QuestionRef> questionRef;
kj::Array<PipelineOp> ops;
};
......@@ -779,7 +753,6 @@ private:
exp.refcount = 1;
exp.clientHook = kj::mv(cap);
descriptor.setSenderHosted(exportId);
KJ_DBG(this, exportId);
return exportId;
}
}
......@@ -1001,7 +974,6 @@ private:
entry.second.cap->addRef(), entry.second.builder, tables);
KJ_IF_MAYBE(exportId, maybeExportId) {
KJ_ASSERT(tables.exports.find(*exportId) != nullptr);
KJ_DBG(&connectionState, *exportId);
exports.add(*exportId);
}
}
......@@ -1064,24 +1036,25 @@ private:
// =====================================================================================
// RequestHook/PipelineHook/ResponseHook implementations
class QuestionRef {
class QuestionRef: public kj::Refcounted {
// A reference to an entry on the question table. Used to detect when the `Finish` message
// can be sent.
public:
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id)
: connectionState(kj::addRef(connectionState)), id(id) {}
inline QuestionRef(const RpcConnectionState& connectionState, QuestionId id,
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller)
: connectionState(kj::addRef(connectionState)), id(id), fulfiller(kj::mv(fulfiller)),
resultCaps(connectionState) {}
~QuestionRef() {
// Send the "Finish" message.
auto message = connectionState->connection->newOutgoingMessage(
messageSizeHint<rpc::Finish>() +
resultCaps.map([](CapExtractorImpl& ce) { return ce.retainedListSizeHint(true); })
.orDefault(0));
messageSizeHint<rpc::Finish>() + resultCaps.retainedListSizeHint(true));
auto builder = message->getBody().getAs<rpc::Message>().initFinish();
builder.setQuestionId(id);
KJ_IF_MAYBE(r, resultCaps) {
builder.adoptRetainedCaps(r->finalizeRetainedCaps(
builder.adoptRetainedCaps(resultCaps.finalizeRetainedCaps(
Orphanage::getForMessageContaining(builder)));
}
message->send();
......@@ -1096,21 +1069,27 @@ private:
// Call has already returned, so we can now remove it from the table.
KJ_ASSERT(lock->questions.erase(id));
} else {
question.isFinished = true;
question.selfRef = nullptr;
}
}
}
inline QuestionId getId() const { return id; }
inline CapExtractorImpl& getCapExtractor() { return resultCaps; }
void fulfill(kj::Own<const RpcResponse>&& response) {
fulfiller->fulfill(kj::mv(response));
}
void setResultCapExtractor(CapExtractorImpl& extractor) {
resultCaps = extractor;
void reject(kj::Exception&& exception) {
fulfiller->reject(kj::mv(exception));
}
private:
kj::Own<const RpcConnectionState> connectionState;
QuestionId id;
kj::Maybe<CapExtractorImpl&> resultCaps;
kj::Own<kj::PromiseFulfiller<kj::Own<const RpcResponse>>> fulfiller;
CapExtractorImpl resultCaps;
};
class RpcRequest final: public RequestHook {
......@@ -1135,7 +1114,8 @@ private:
RemotePromise<ObjectPointer> send() override {
QuestionId questionId;
kj::Promise<kj::Own<RpcResponse>> promise = nullptr;
kj::Own<QuestionRef> questionRef;
kj::Promise<kj::Own<const RpcResponse>> promise = nullptr;
{
auto lock = connectionState->tables.lockExclusive();
......@@ -1169,30 +1149,25 @@ private:
} else {
injector->finishDescriptors(*lock);
auto paf = kj::newPromiseAndFulfiller<kj::Own<RpcResponse>>(connectionState->eventLoop);
auto paf = kj::newPromiseAndFulfiller<kj::Own<const RpcResponse>>(
connectionState->eventLoop);
auto& question = lock->questions.next(questionId);
callBuilder.setQuestionId(questionId);
question.isStarted = true;
question.paramCaps = kj::mv(injector);
question.fulfiller = kj::mv(paf.fulfiller);
questionRef = kj::refcounted<QuestionRef>(
*connectionState, questionId, kj::mv(paf.fulfiller));
question.selfRef = *questionRef;
message->send();
promise = kj::mv(paf.promise);
promise.attach(kj::addRef(*questionRef));
}
}
auto questionRef = kj::heap<QuestionRef>(*connectionState, questionId);
auto promiseWithQuestionRef = promise.thenInAnyThread(kj::mvCapture(questionRef,
[](kj::Own<QuestionRef>&& questionRef, kj::Own<RpcResponse>&& response)
-> kj::Own<const RpcResponse> {
response->setQuestionRef(kj::mv(questionRef));
return kj::mv(response);
}));
auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promiseWithQuestionRef));
auto forkedPromise = connectionState->eventLoop.fork(kj::mv(promise));
auto appPromise = forkedPromise.addBranch().thenInAnyThread(
[](kj::Own<const RpcResponse>&& response) {
......@@ -1201,7 +1176,7 @@ private:
});
auto pipeline = kj::refcounted<RpcPipeline>(
*connectionState, questionId, kj::mv(forkedPromise));
*connectionState, kj::mv(questionRef), kj::mv(forkedPromise));
return RemotePromise<ObjectPointer>(
kj::mv(appPromise),
......@@ -1221,7 +1196,7 @@ private:
class RpcPipeline final: public PipelineHook, public kj::Refcounted {
public:
RpcPipeline(const RpcConnectionState& connectionState, QuestionId questionId,
RpcPipeline(const RpcConnectionState& connectionState, kj::Own<const QuestionRef> questionRef,
kj::ForkedPromise<kj::Own<const RpcResponse>>&& redirectLaterParam)
: connectionState(kj::addRef(connectionState)),
redirectLater(kj::mv(redirectLaterParam)),
......@@ -1234,53 +1209,13 @@ private:
// Construct a new RpcPipeline.
resolveSelfPromise.eagerlyEvaluate(connectionState.eventLoop);
state.getWithoutLock().init<Waiting>(questionId);
state.getWithoutLock().init<Waiting>(kj::mv(questionRef));
}
kj::Promise<kj::Own<const RpcResponse>> onResponse() const {
return redirectLater.addBranch();
}
kj::Maybe<kj::Own<const ClientHook>> writeTarget(
rpc::Call::Target::Builder target, kj::ArrayPtr<const PipelineOp> ops) const {
// Initializes `target` to a PromisedAnswer for a pipelined capability, *or* returns the
// specified capability directly.
//
// The caller *should* have a lock on the connection state's tables while calling this, so
// that a Finish message cannot be sent before the caller manages to send its `Call`.
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
auto builder = target.initPromisedAnswer();
builder.setQuestionId(lock->get<Waiting>());
builder.adoptTransform(fromPipelineOps(Orphanage::getForMessageContaining(builder), ops));
return nullptr;
} else if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getResults().getPipelinedCap(ops);
} else {
return newBrokenCap(kj::cp(lock->get<Broken>()));
}
}
kj::Maybe<ExportId> writeDescriptor(rpc::CapDescriptor::Builder descriptor, Tables& tables,
kj::ArrayPtr<const PipelineOp> ops) const {
auto lock = state.lockExclusive();
if (lock->is<Waiting>()) {
auto promisedAnswer = descriptor.initReceiverAnswer();
promisedAnswer.setQuestionId(lock->get<Waiting>());
promisedAnswer.adoptTransform(fromPipelineOps(
Orphanage::getForMessageContaining(descriptor), ops));
return nullptr;
} else if (lock->is<Resolved>()) {
return connectionState->writeDescriptor(
lock->get<Resolved>()->getResults().getPipelinedCap(ops),
descriptor, tables);
} else {
return connectionState->writeDescriptor(
newBrokenCap(kj::cp(lock->get<Broken>())), descriptor, tables);
}
}
// implements PipelineHook ---------------------------------------
kj::Own<const PipelineHook> addRef() const override {
......@@ -1300,7 +1235,7 @@ private:
if (lock->is<Waiting>()) {
// Wrap a PipelineClient in a PromiseClient.
auto pipelineClient = kj::refcounted<PipelineClient>(
*connectionState, kj::addRef(*this), kj::heapArray(ops.asPtr()));
*connectionState, kj::addRef(*lock->get<Waiting>()), kj::heapArray(ops.asPtr()));
auto resolutionPromise = connectionState->eventLoop.there(redirectLater.addBranch(),
kj::mvCapture(ops,
......@@ -1322,7 +1257,7 @@ private:
kj::Maybe<CapExtractorImpl&> capExtractor;
kj::ForkedPromise<kj::Own<const RpcResponse>> redirectLater;
typedef QuestionId Waiting;
typedef kj::Own<const QuestionRef> Waiting;
typedef kj::Own<const RpcResponse> Resolved;
typedef kj::Exception Broken;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, Broken>> state;
......@@ -1347,13 +1282,14 @@ private:
class RpcResponse final: public ResponseHook, public kj::Refcounted {
public:
RpcResponse(const RpcConnectionState& connectionState,
kj::Own<QuestionRef>&& questionRef,
kj::Own<IncomingRpcMessage>&& message,
ObjectPointer::Reader results)
: connectionState(kj::addRef(connectionState)),
message(kj::mv(message)),
extractor(connectionState),
context(extractor),
reader(context.imbue(results)) {}
context(questionRef->getCapExtractor()),
reader(context.imbue(results)),
questionRef(kj::mv(questionRef)) {}
ObjectPointer::Reader getResults() const {
return reader;
......@@ -1363,18 +1299,12 @@ private:
return kj::addRef(*this);
}
void setQuestionRef(kj::Own<QuestionRef>&& questionRef) {
this->questionRef = kj::mv(questionRef);
this->questionRef->setResultCapExtractor(extractor);
}
private:
kj::Own<const RpcConnectionState> connectionState;
kj::Own<IncomingRpcMessage> message;
CapExtractorImpl extractor;
CapReaderContext context;
ObjectPointer::Reader reader;
kj::Own<QuestionRef> questionRef;
kj::Own<const QuestionRef> questionRef;
};
// =====================================================================================
......@@ -1853,16 +1783,27 @@ private:
switch (ret.which()) {
case rpc::Return::RESULTS:
question->fulfiller->fulfill(
kj::refcounted<RpcResponse>(*this, kj::mv(message), ret.getResults()));
KJ_IF_MAYBE(questionRef, question->selfRef) {
// The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
// Not being deleted.
questionRef->fulfill(kj::refcounted<RpcResponse>(
*this, kj::mv(*ownRef), kj::mv(message), ret.getResults()));
}
}
break;
case rpc::Return::EXCEPTION:
question->fulfiller->reject(toException(ret.getException()));
KJ_IF_MAYBE(questionRef, question->selfRef) {
// The questionRef still exists, but could be being deleted in another thread.
KJ_IF_MAYBE(ownRef, kj::tryAddRef(*questionRef)) {
questionRef->reject(toException(ret.getException()));
}
}
break;
case rpc::Return::CANCELED:
KJ_REQUIRE(question->isFinished,
KJ_REQUIRE(question->selfRef == nullptr,
"Return message falsely claims call was canceled.") { return; }
// We don't bother fulfilling the result. If someone is somehow still waiting on it
// (shouldn't be possible), that's OK: they'll get an exception due to the fulfiller
......@@ -1873,7 +1814,7 @@ private:
KJ_FAIL_REQUIRE("Unknown return type (not answer, exception, or canceled).") { return; }
}
if (question->isFinished) {
if (question->selfRef == nullptr) {
lock->questions.erase(ret.getQuestionId());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment