Commit 0b9f906e authored by Kenton Varda's avatar Kenton Varda

Implement whenWriteDisconnected() everywhere.

I decided to make the new method pure-virtual as I wanted to make sure that all wrapper streams properly delegate to the inner stream. We wouldn't want e.g. proactive HTTP cancellation to unexpectedly not work when running over TLS.
parent d32b388c
......@@ -126,6 +126,8 @@ public:
}
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); }
};
KJ_TEST("gzip decompression") {
......
......@@ -118,6 +118,8 @@ public:
Promise<void> write(const void* buffer, size_t size) override;
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override;
Promise<void> whenWriteDisconnected() override { return inner.whenWriteDisconnected(); }
inline Promise<void> flush() {
return pump(Z_SYNC_FLUSH);
}
......
......@@ -288,6 +288,10 @@ public:
return inner.tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override {
return inner.shutdownWrite();
}
......@@ -1791,6 +1795,10 @@ public:
return out->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override {
out = nullptr;
}
......@@ -2746,6 +2754,9 @@ public:
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
return inner->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return inner->whenWriteDisconnected();
}
void shutdownWrite() override {
return inner->shutdownWrite();
}
......
......@@ -1760,6 +1760,10 @@ public:
return fork.addBranch();
}
Promise<void> whenWriteDisconnected() {
return inner.whenWriteDisconnected();
}
private:
AsyncOutputStream& inner;
kj::Promise<void> writeQueue = kj::READY_NOW;
......@@ -1787,6 +1791,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
};
class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream {
......@@ -1797,6 +1804,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
};
class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream {
......@@ -1877,6 +1887,10 @@ public:
return kj::mv(promise);
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private:
HttpOutputStream& inner;
uint64_t length;
......@@ -1960,6 +1974,10 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private:
HttpOutputStream& inner;
};
......@@ -3382,6 +3400,22 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
......@@ -3456,6 +3490,22 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
......@@ -3865,6 +3915,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
// We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method.
};
......
......@@ -180,6 +180,10 @@ public:
return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override {
KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment