Commit 598eb978 authored by Harris Hancock's avatar Harris Hancock

Make newPromisedStream() generally available

This moves Promise{Io,Output}Stream from http.c++ to PromisedAsync{Io,Output}Stream in async-io.c++ and exposes them via newPromisedStream() so other code can use them.

No substantive code changes, though I changed a couple of the `public` access specifiers to `private`.
parent 000e2a99
...@@ -1821,6 +1821,210 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit) { ...@@ -1821,6 +1821,210 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit) {
return { { mv(branch1), mv(branch2) } }; return { { mv(branch1), mv(branch2) } };
} }
namespace {
class PromisedAsyncIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler {
// An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised
// stream.
public:
PromisedAsyncIoStream(kj::Promise<kj::Own<AsyncIoStream>> promise)
: promise(promise.then([this](kj::Own<AsyncIoStream> result) {
stream = kj::mv(result);
}).fork()),
tasks(*this) {}
kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->read(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes);
});
}
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryRead(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes);
});
}
}
kj::Maybe<uint64_t> tryGetLength() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryGetLength();
} else {
return nullptr;
}
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->pumpTo(output, amount);
} else {
return promise.addBranch().then([this,&output,amount]() {
return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount);
});
}
}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
// Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts
// or whatnot to detect stream types it can retry those on the inner stream.
return input.pumpTo(**s, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Here we actually have no choice but to call input.pumpTo() because if we called
// tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for
// us to return nullptr. But the thing about dynamic_cast also applies.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->shutdownWrite();
}));
}
}
void abortRead() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->abortRead();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->abortRead();
}));
}
}
private:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncIoStream>> stream;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
class PromisedAsyncOutputStream final: public kj::AsyncOutputStream {
// An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the
// promised stream.
//
// TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard.
public:
PromisedAsyncOutputStream(kj::Promise<kj::Own<AsyncOutputStream>> promise)
: promise(promise.then([this](kj::Own<AsyncOutputStream> result) {
stream = kj::mv(result);
}).fork()) {}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryPumpFrom(input, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Call input.pumpTo() on the resolved stream instead.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
private:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
};
} // namespace
Own<AsyncOutputStream> newPromisedStream(Promise<Own<AsyncOutputStream>> promise) {
return heap<PromisedAsyncOutputStream>(kj::mv(promise));
}
Own<AsyncIoStream> newPromisedStream(Promise<Own<AsyncIoStream>> promise) {
return heap<PromisedAsyncIoStream>(kj::mv(promise));
}
Promise<void> AsyncCapabilityStream::writeWithFds( Promise<void> AsyncCapabilityStream::writeWithFds(
ArrayPtr<const byte> data, ArrayPtr<const ArrayPtr<const byte>> moreData, ArrayPtr<const byte> data, ArrayPtr<const ArrayPtr<const byte>> moreData,
ArrayPtr<const AutoCloseFd> fds) { ArrayPtr<const AutoCloseFd> fds) {
......
...@@ -281,6 +281,11 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit = kj::maxValue); ...@@ -281,6 +281,11 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit = kj::maxValue);
// //
// It is recommended that you use a more conservative value for `limit` than the default. // It is recommended that you use a more conservative value for `limit` than the default.
Own<AsyncOutputStream> newPromisedStream(Promise<Own<AsyncOutputStream>> promise);
Own<AsyncIoStream> newPromisedStream(Promise<Own<AsyncIoStream>> promise);
// Constructs an Async*Stream which waits for a promise to resolve, then forwards all calls to the
// promised stream.
class ConnectionReceiver { class ConnectionReceiver {
// Represents a server socket listening on a port. // Represents a server socket listening on a port.
......
...@@ -3504,202 +3504,6 @@ kj::Own<HttpClient> newHttpClient( ...@@ -3504,202 +3504,6 @@ kj::Own<HttpClient> newHttpClient(
namespace { namespace {
class PromiseIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler {
// An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised
// stream.
//
// TODO(cleanup): Make this more broadly available.
public:
PromiseIoStream(kj::Promise<kj::Own<AsyncIoStream>> promise)
: promise(promise.then([this](kj::Own<AsyncIoStream> result) {
stream = kj::mv(result);
}).fork()),
tasks(*this) {}
kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->read(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes);
});
}
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryRead(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes);
});
}
}
kj::Maybe<uint64_t> tryGetLength() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryGetLength();
} else {
return nullptr;
}
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->pumpTo(output, amount);
} else {
return promise.addBranch().then([this,&output,amount]() {
return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount);
});
}
}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
// Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts
// or whatnot to detect stream types it can retry those on the inner stream.
return input.pumpTo(**s, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Here we actually have no choice but to call input.pumpTo() because if we called
// tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for
// us to return nullptr. But the thing about dynamic_cast also applies.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->shutdownWrite();
}));
}
}
void abortRead() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->abortRead();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->abortRead();
}));
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncIoStream>> stream;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
class PromiseOutputStream final: public kj::AsyncOutputStream {
// An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the
// promised stream.
//
// TODO(cleanup): Make this more broadly available.
// TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard.
public:
PromiseOutputStream(kj::Promise<kj::Own<AsyncOutputStream>> promise)
: promise(promise.then([this](kj::Own<AsyncOutputStream> result) {
stream = kj::mv(result);
}).fork()) {}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryPumpFrom(input, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Call input.pumpTo() on the resolved stream instead.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
};
class NetworkAddressHttpClient final: public HttpClient { class NetworkAddressHttpClient final: public HttpClient {
public: public:
NetworkAddressHttpClient(kj::Timer& timer, HttpHeaderTable& responseHeaderTable, NetworkAddressHttpClient(kj::Timer& timer, HttpHeaderTable& responseHeaderTable,
...@@ -3797,7 +3601,7 @@ private: ...@@ -3797,7 +3601,7 @@ private:
kj::Own<RefcountedClient> getClient() { kj::Own<RefcountedClient> getClient() {
for (;;) { for (;;) {
if (availableClients.empty()) { if (availableClients.empty()) {
auto stream = kj::heap<PromiseIoStream>(address->connect()); auto stream = newPromisedStream(address->connect());
return kj::refcounted<RefcountedClient>(*this, return kj::refcounted<RefcountedClient>(*this,
kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(stream), settings)); kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(stream), settings));
} else { } else {
...@@ -3898,7 +3702,7 @@ public: ...@@ -3898,7 +3702,7 @@ public:
auto split = combined.split(); auto split = combined.split();
return { return {
kj::heap<PromiseOutputStream>(kj::mv(kj::get<0>(split))), newPromisedStream(kj::mv(kj::get<0>(split))),
kj::mv(kj::get<1>(split)) kj::mv(kj::get<1>(split))
}; };
} }
...@@ -4110,7 +3914,7 @@ public: ...@@ -4110,7 +3914,7 @@ public:
auto split = combined.split(); auto split = combined.split();
pendingRequests.push(kj::mv(paf.fulfiller)); pendingRequests.push(kj::mv(paf.fulfiller));
fireCountChanged(); fireCountChanged();
return { kj::heap<PromiseOutputStream>(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) }; return { newPromisedStream(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
} }
kj::Promise<WebSocketResponse> openWebSocket( kj::Promise<WebSocketResponse> openWebSocket(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment