Commit 15f599c8 authored by Kenton Varda's avatar Kenton Varda

Implement handling of 'Release' messages.

parent 98696a57
...@@ -281,6 +281,11 @@ public: ...@@ -281,6 +281,11 @@ public:
} }
void taskFailed(kj::Exception&& exception) override { void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, "Closing connection due to protocol error.", exception);
disconnect(kj::mv(exception));
}
void disconnect(kj::Exception&& exception) {
{ {
kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease; kj::Vector<kj::Own<const PipelineHook>> pipelinesToRelease;
kj::Vector<kj::Own<const ClientHook>> clientsToRelease; kj::Vector<kj::Own<const ClientHook>> clientsToRelease;
...@@ -468,8 +473,8 @@ private: ...@@ -468,8 +473,8 @@ private:
// ===================================================================================== // =====================================================================================
class ResolutionChain: public kj::Refcounted { class ResolutionChain: public kj::Refcounted {
// A chain of pending promise resolutions which may affect messages that are still being // A chain of pending promise resolutions and export releases which may affect messages that
// processed. // are still being processed.
// //
// When a `Resolve` message comes in, we can't just handle it and then release the original // When a `Resolve` message comes in, we can't just handle it and then release the original
// promise import all at once, because it's possible that the application is still processing // promise import all at once, because it's possible that the application is still processing
...@@ -477,12 +482,17 @@ private: ...@@ -477,12 +482,17 @@ private:
// the promise as it does. We need to hold off on the release until we've actually gotten // the promise as it does. We need to hold off on the release until we've actually gotten
// through all outstanding messages. // through all outstanding messages.
// //
// Similarly, when a `Release` message comes in that causes an export's refcount to hit zero,
// we can't actually release until all previous messages are consumed because any of them
// could contain a `CapDescriptor` with `receiverHosted`. Oh god, and I suppose
// `receiverAnswer` is affected too. So we have to delay cleanup of pipelines. Sigh.
//
// To that end, we have the resolution chain. Each time a `CapExtractorImpl` is created -- // To that end, we have the resolution chain. Each time a `CapExtractorImpl` is created --
// representing a message to be consumed by the application -- it takes a reference to the // representing a message to be consumed by the application -- it takes a reference to the
// current end of the chain. When a `Resolve` message arrives, it is added to the end of the // current end of the chain. When a `Resolve` message or whatever arrives, it is added to the
// chain, and thus all `CapExtractorImpl`s that exist at that point now hold a reference to it, // end of the chain, and thus all `CapExtractorImpl`s that exist at that point now hold a
// but new `CapExtractorImpl`s will not. Once all references are dropped, the original promise // reference to it, but new `CapExtractorImpl`s will not. Once all references are dropped,
// can be released. // the original promise can be released.
// //
// The connection state actually holds one instance of ResolutionChain which doesn't yet have // The connection state actually holds one instance of ResolutionChain which doesn't yet have
// a promise attached to it, representing the end of the chain. This is what allows a new // a promise attached to it, representing the end of the chain. This is what allows a new
...@@ -490,28 +500,79 @@ private: ...@@ -490,28 +500,79 @@ private:
// holding a reference to it. // holding a reference to it.
public: public:
kj::Own<ResolutionChain> add(ExportId importId, kj::Own<ResolutionChain> addResolve(ExportId importId,
kj::Own<const ClientHook>&& replacement) { kj::Own<const ClientHook>&& replacement) {
// Add the a new resolution to the chain. Returns the new end-of-chain. // Add the a new resolution to the chain. Returns the new end-of-chain.
value.init<ResolvedImport>(ResolvedImport { importId, kj::mv(replacement) });
auto result = kj::refcounted<ResolutionChain>();
next = kj::addRef(*result);
return kj::mv(result);
}
kj::Maybe<const ClientHook&> findImport(ExportId importId) const {
// Look for the given import ID in the resolution chain.
this->importId = importId; const ResolutionChain* ptr = this;
this->replacement = kj::mv(replacement);
while (ptr->value != nullptr) {
if (ptr->value.is<ResolvedImport>()) {
auto& ri = ptr->value.get<ResolvedImport>();
if (ri.importId == importId) {
return *ri.replacement;
}
}
ptr = ptr->next;
}
return nullptr;
}
kj::Own<ResolutionChain> addRelease(ExportId exportId, kj::Own<const ClientHook>&& cap) {
// Add the a new release to the chain. Returns the new end-of-chain.
value.init<DelayedRelease>(DelayedRelease { exportId, kj::mv(cap) });
auto result = kj::refcounted<ResolutionChain>(); auto result = kj::refcounted<ResolutionChain>();
next = kj::addRef(*result); next = kj::addRef(*result);
return kj::mv(result);
}
kj::Maybe<const ClientHook&> findExport(ExportId exportId) const {
// Look for the given export ID in the resolution chain.
const ResolutionChain* ptr = this;
filled = true; while (ptr->value != nullptr) {
if (ptr->value.is<DelayedRelease>()) {
auto& ri = ptr->value.get<DelayedRelease>();
if (ri.exportId == exportId) {
return *ri.cap;
}
}
ptr = ptr->next;
}
return nullptr;
}
kj::Own<ResolutionChain> addFinish(QuestionId answerId,
kj::Own<const PipelineHook>&& pipeline) {
// Add the a new finish to the chain. Returns the new end-of-chain.
value.init<DelayedFinish>(DelayedFinish { answerId, kj::mv(pipeline) });
auto result = kj::refcounted<ResolutionChain>();
next = kj::addRef(*result);
return kj::mv(result); return kj::mv(result);
} }
kj::Maybe<kj::Own<const ClientHook>> find(ExportId importId) const { kj::Maybe<const PipelineHook&> findPipeline(QuestionId answerId) const {
// Look for the given import ID in the resolution chain. // Look for the given answer ID in the resolution chain.
const ResolutionChain* ptr = this; const ResolutionChain* ptr = this;
while (ptr->filled) { while (ptr->value != nullptr) {
if (ptr->importId == importId) { if (ptr->value.is<DelayedFinish>()) {
return ptr->replacement->addRef(); auto& ri = ptr->value.get<DelayedFinish>();
if (ri.answerId == answerId) {
return *ri.pipeline;
}
} }
ptr = ptr->next; ptr = ptr->next;
} }
...@@ -522,10 +583,21 @@ private: ...@@ -522,10 +583,21 @@ private:
private: private:
kj::Own<const ResolutionChain> next; kj::Own<const ResolutionChain> next;
bool filled = false; struct ResolvedImport {
ExportId importId; ExportId importId;
kj::Own<const ClientHook> replacement; kj::Own<const ClientHook> replacement;
}; };
struct DelayedRelease {
ExportId exportId;
kj::Own<const ClientHook> cap;
};
struct DelayedFinish {
QuestionId answerId;
kj::Own<const PipelineHook> pipeline;
};
kj::OneOf<ResolvedImport, DelayedRelease, DelayedFinish> value;
};
// ===================================================================================== // =====================================================================================
// ClientHook implementations // ClientHook implementations
...@@ -641,6 +713,8 @@ private: ...@@ -641,6 +713,8 @@ private:
} }
// Send a message releasing our remote references. // Send a message releasing our remote references.
// TODO(perf): Is there any good reason to delay this until later? I can't remember why
// I did this.
if (remoteRefcount > 0) { if (remoteRefcount > 0) {
connectionState->sendReleaseLater(importId, remoteRefcount); connectionState->sendReleaseLater(importId, remoteRefcount);
} }
...@@ -1059,8 +1133,8 @@ private: ...@@ -1059,8 +1133,8 @@ private:
// message was received. In this case, the original import ID will already have been // message was received. In this case, the original import ID will already have been
// dropped and could even have been reused for another capability. Luckily, the // dropped and could even have been reused for another capability. Luckily, the
// resolution chain holds the capability we actually want. // resolution chain holds the capability we actually want.
KJ_IF_MAYBE(resolution, resolutionChain.find(importId)) { KJ_IF_MAYBE(resolution, resolutionChain.findImport(importId)) {
return kj::mv(*resolution); return resolution->addRef();
} }
// No recent resolutions. Check the import table then. // No recent resolutions. Check the import table then.
...@@ -1110,6 +1184,11 @@ private: ...@@ -1110,6 +1184,11 @@ private:
} }
case rpc::CapDescriptor::RECEIVER_HOSTED: { case rpc::CapDescriptor::RECEIVER_HOSTED: {
// First check to see if this export ID was recently released.
KJ_IF_MAYBE(cap, resolutionChain.findExport(descriptor.getReceiverHosted())) {
return cap->addRef();
}
KJ_IF_MAYBE(exp, tables.exports.find(descriptor.getReceiverHosted())) { KJ_IF_MAYBE(exp, tables.exports.find(descriptor.getReceiverHosted())) {
return exp->clientHook->addRef(); return exp->clientHook->addRef();
} }
...@@ -1118,14 +1197,17 @@ private: ...@@ -1118,14 +1197,17 @@ private:
case rpc::CapDescriptor::RECEIVER_ANSWER: { case rpc::CapDescriptor::RECEIVER_ANSWER: {
auto promisedAnswer = descriptor.getReceiverAnswer(); auto promisedAnswer = descriptor.getReceiverAnswer();
const PipelineHook* pipeline;
// First check to see if this question ID was recently finished.
KJ_IF_MAYBE(p, resolutionChain.findPipeline(promisedAnswer.getQuestionId())) {
pipeline = p;
} else {
KJ_IF_MAYBE(answer, tables.answers.find(promisedAnswer.getQuestionId())) { KJ_IF_MAYBE(answer, tables.answers.find(promisedAnswer.getQuestionId())) {
if (answer->active) { if (answer->active) {
KJ_IF_MAYBE(pipeline, answer->pipeline) { KJ_IF_MAYBE(p, answer->pipeline) {
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) { pipeline = p->get();
return pipeline->get()->getPipelinedCap(*ops);
} else {
return newBrokenCap("unrecognized pipeline ops");
}
} }
} }
} }
...@@ -1133,6 +1215,13 @@ private: ...@@ -1133,6 +1215,13 @@ private:
return newBrokenCap("invalid 'receiverAnswer'"); return newBrokenCap("invalid 'receiverAnswer'");
} }
KJ_IF_MAYBE(ops, toPipelineOps(promisedAnswer.getTransform())) {
return pipeline->getPipelinedCap(*ops);
} else {
return newBrokenCap("unrecognized pipeline ops");
}
}
case rpc::CapDescriptor::THIRD_PARTY_HOSTED: case rpc::CapDescriptor::THIRD_PARTY_HOSTED:
return newBrokenCap("three-way introductions not implemented"); return newBrokenCap("three-way introductions not implemented");
...@@ -1151,18 +1240,13 @@ private: ...@@ -1151,18 +1240,13 @@ private:
CapInjectorImpl(const RpcConnectionState& connectionState) CapInjectorImpl(const RpcConnectionState& connectionState)
: connectionState(connectionState) {} : connectionState(connectionState) {}
~CapInjectorImpl() noexcept(false) { ~CapInjectorImpl() noexcept(false) {
kj::Vector<kj::Own<const ClientHook>> clientsToRelease(exports.size()); kj::Vector<kj::Own<const ResolutionChain>> thingsToRelease(exports.size());
auto lock = connectionState.tables.lockExclusive(); auto lock = connectionState.tables.lockExclusive();
if (lock->networkException == nullptr) { if (lock->networkException == nullptr) {
for (auto exportId: exports) { for (auto exportId: exports) {
auto& exp = KJ_ASSERT_NONNULL(lock->exports.find(exportId)); thingsToRelease.add(releaseExport(*lock, exportId, 1));
if (--exp.refcount == 0) {
lock->exportsByCap.erase(exp.clientHook);
clientsToRelease.add(kj::mv(exp.clientHook));
lock->exports.erase(exportId);
}
} }
} }
} }
...@@ -1851,7 +1935,7 @@ private: ...@@ -1851,7 +1935,7 @@ private:
break; break;
case rpc::Message::RELEASE: case rpc::Message::RELEASE:
// TODO(now) handleRelease(reader.getRelease());
break; break;
case rpc::Message::DISEMBARGO: case rpc::Message::DISEMBARGO:
...@@ -1874,9 +1958,25 @@ private: ...@@ -1874,9 +1958,25 @@ private:
void handleUnimplemented(const rpc::Message::Reader& message) { void handleUnimplemented(const rpc::Message::Reader& message) {
switch (message.which()) { switch (message.which()) {
case rpc::Message::RESOLVE: case rpc::Message::RESOLVE: {
// TODO(now): Release the resolution. auto cap = message.getResolve().getCap();
switch (cap.which()) {
case rpc::CapDescriptor::SENDER_HOSTED:
releaseExport(*tables.lockExclusive(), cap.getSenderHosted(), 1);
break; break;
case rpc::CapDescriptor::SENDER_PROMISE:
releaseExport(*tables.lockExclusive(), cap.getSenderPromise(), 1);
break;
case rpc::CapDescriptor::RECEIVER_ANSWER:
case rpc::CapDescriptor::RECEIVER_HOSTED:
// Nothing to do.
break;
case rpc::CapDescriptor::THIRD_PARTY_HOSTED:
releaseExport(*tables.lockExclusive(), cap.getThirdPartyHosted().getVineId(), 1);
break;
}
break;
}
default: default:
KJ_FAIL_ASSERT("Peer did not implement required RPC message type.", (uint)message.which()); KJ_FAIL_ASSERT("Peer did not implement required RPC message type.", (uint)message.which());
...@@ -2058,7 +2158,7 @@ private: ...@@ -2058,7 +2158,7 @@ private:
} }
void handleFinish(const rpc::Finish::Reader& finish) { void handleFinish(const rpc::Finish::Reader& finish) {
kj::Maybe<kj::Own<const PipelineHook>> pipelineToRelease; kj::Own<ResolutionChain> chainToRelease;
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
...@@ -2070,16 +2170,27 @@ private: ...@@ -2070,16 +2170,27 @@ private:
} }
} }
auto& answer = lock->answers[finish.getQuestionId()]; KJ_IF_MAYBE(answer, lock->answers.find(finish.getQuestionId())) {
KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; }
// `Finish` indicates that no further pipeline requests will be made. // `Finish` indicates that no further pipeline requests will be made.
pipelineToRelease = kj::mv(answer.pipeline); // However, previously-sent messages that are still being processed could still refer to this
// pipeline, so we have to move it into the resolution chain.
KJ_IF_MAYBE(p, kj::mv(answer->pipeline)) {
chainToRelease = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = chainToRelease->addFinish(finish.getQuestionId(), kj::mv(*p));
}
KJ_IF_MAYBE(context, answer.callContext) { // If the call isn't actually done yet, cancel it. Otherwise, we can go ahead and erase the
// question from the table.
KJ_IF_MAYBE(context, answer->callContext) {
context->requestCancel(); context->requestCancel();
} else { } else {
lock->answers.erase(finish.getQuestionId()); lock->answers.erase(finish.getQuestionId());
} }
} else {
KJ_REQUIRE(answer->active, "'Finish' for invalid question ID.") { return; }
}
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -2114,7 +2225,7 @@ private: ...@@ -2114,7 +2225,7 @@ private:
// Extend the resolution chain. // Extend the resolution chain.
auto oldTail = kj::mv(lock->resolutionChainTail); auto oldTail = kj::mv(lock->resolutionChainTail);
lock->resolutionChainTail = oldTail->add(resolve.getPromiseId(), kj::mv(replacement)); lock->resolutionChainTail = oldTail->addResolve(resolve.getPromiseId(), kj::mv(replacement));
lock.release(); // in case oldTail is destroyed lock.release(); // in case oldTail is destroyed
// If the import is on the table, fulfill it. // If the import is on the table, fulfill it.
...@@ -2130,6 +2241,35 @@ private: ...@@ -2130,6 +2241,35 @@ private:
} }
} }
void handleRelease(const rpc::Release::Reader& release) {
releaseExport(*tables.lockExclusive(), release.getId(), release.getReferenceCount());
}
static kj::Own<ResolutionChain> releaseExport(Tables& lockedTables, ExportId id, uint refcount) {
kj::Own<ResolutionChain> result;
KJ_IF_MAYBE(exp, lockedTables.exports.find(id)) {
KJ_REQUIRE(refcount <= exp->refcount, "Tried to drop export's refcount below zero.") {
return result;
}
exp->refcount -= refcount;
if (exp->refcount == 0) {
lockedTables.exportsByCap.erase(exp->clientHook);
result = kj::mv(lockedTables.resolutionChainTail);
lockedTables.resolutionChainTail = result->addRelease(id, kj::mv(exp->clientHook));
lockedTables.exports.erase(id);
return result;
} else {
return result;
}
} else {
KJ_FAIL_REQUIRE("Tried to release invalid export ID.") {
return result;
}
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Level 2 // Level 2
...@@ -2242,7 +2382,7 @@ public: ...@@ -2242,7 +2382,7 @@ public:
kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT, kj::Exception::Nature::LOCAL_BUG, kj::Exception::Durability::PERMANENT,
__FILE__, __LINE__, kj::str("RpcSystem was destroyed.")); __FILE__, __LINE__, kj::str("RpcSystem was destroyed."));
for (auto& entry: connectionMap) { for (auto& entry: connectionMap) {
entry.second->taskFailed(kj::cp(shutdownException)); entry.second->disconnect(kj::cp(shutdownException));
deleteMe.add(kj::mv(entry.second)); deleteMe.add(kj::mv(entry.second));
} }
} }
......
...@@ -52,6 +52,9 @@ public: ...@@ -52,6 +52,9 @@ public:
OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; } OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; }
OneOf& operator=(OneOf&& other) { if (tag != 0) destroy(); moveFrom(other); return *this; } OneOf& operator=(OneOf&& other) { if (tag != 0) destroy(); moveFrom(other); return *this; }
inline bool operator==(decltype(nullptr)) const { return tag == 0; }
inline bool operator!=(decltype(nullptr)) const { return tag != 0; }
template <typename T> template <typename T>
bool is() const { bool is() const {
return tag == typeIndex<T>(); return tag == typeIndex<T>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment