Unverified Commit 22c0fdf6 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #917 from capnproto/harris/move-promise-streams-to-async-io

Make newPromisedStream() generally available
parents 000e2a99 598eb978
......@@ -1821,6 +1821,210 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit) {
return { { mv(branch1), mv(branch2) } };
}
namespace {
class PromisedAsyncIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler {
// An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised
// stream.
public:
PromisedAsyncIoStream(kj::Promise<kj::Own<AsyncIoStream>> promise)
: promise(promise.then([this](kj::Own<AsyncIoStream> result) {
stream = kj::mv(result);
}).fork()),
tasks(*this) {}
kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->read(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes);
});
}
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryRead(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes);
});
}
}
kj::Maybe<uint64_t> tryGetLength() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryGetLength();
} else {
return nullptr;
}
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->pumpTo(output, amount);
} else {
return promise.addBranch().then([this,&output,amount]() {
return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount);
});
}
}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
// Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts
// or whatnot to detect stream types it can retry those on the inner stream.
return input.pumpTo(**s, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Here we actually have no choice but to call input.pumpTo() because if we called
// tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for
// us to return nullptr. But the thing about dynamic_cast also applies.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->shutdownWrite();
}));
}
}
void abortRead() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->abortRead();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->abortRead();
}));
}
}
private:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncIoStream>> stream;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
class PromisedAsyncOutputStream final: public kj::AsyncOutputStream {
// An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the
// promised stream.
//
// TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard.
public:
PromisedAsyncOutputStream(kj::Promise<kj::Own<AsyncOutputStream>> promise)
: promise(promise.then([this](kj::Own<AsyncOutputStream> result) {
stream = kj::mv(result);
}).fork()) {}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryPumpFrom(input, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Call input.pumpTo() on the resolved stream instead.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
private:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
};
} // namespace
Own<AsyncOutputStream> newPromisedStream(Promise<Own<AsyncOutputStream>> promise) {
return heap<PromisedAsyncOutputStream>(kj::mv(promise));
}
Own<AsyncIoStream> newPromisedStream(Promise<Own<AsyncIoStream>> promise) {
return heap<PromisedAsyncIoStream>(kj::mv(promise));
}
Promise<void> AsyncCapabilityStream::writeWithFds(
ArrayPtr<const byte> data, ArrayPtr<const ArrayPtr<const byte>> moreData,
ArrayPtr<const AutoCloseFd> fds) {
......
......@@ -281,6 +281,11 @@ Tee newTee(Own<AsyncInputStream> input, uint64_t limit = kj::maxValue);
//
// It is recommended that you use a more conservative value for `limit` than the default.
Own<AsyncOutputStream> newPromisedStream(Promise<Own<AsyncOutputStream>> promise);
Own<AsyncIoStream> newPromisedStream(Promise<Own<AsyncIoStream>> promise);
// Constructs an Async*Stream which waits for a promise to resolve, then forwards all calls to the
// promised stream.
class ConnectionReceiver {
// Represents a server socket listening on a port.
......
......@@ -3504,202 +3504,6 @@ kj::Own<HttpClient> newHttpClient(
namespace {
class PromiseIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler {
// An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised
// stream.
//
// TODO(cleanup): Make this more broadly available.
public:
PromiseIoStream(kj::Promise<kj::Own<AsyncIoStream>> promise)
: promise(promise.then([this](kj::Own<AsyncIoStream> result) {
stream = kj::mv(result);
}).fork()),
tasks(*this) {}
kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->read(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes);
});
}
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryRead(buffer, minBytes, maxBytes);
} else {
return promise.addBranch().then([this,buffer,minBytes,maxBytes]() {
return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes);
});
}
}
kj::Maybe<uint64_t> tryGetLength() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryGetLength();
} else {
return nullptr;
}
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->pumpTo(output, amount);
} else {
return promise.addBranch().then([this,&output,amount]() {
return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount);
});
}
}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
// Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts
// or whatnot to detect stream types it can retry those on the inner stream.
return input.pumpTo(**s, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Here we actually have no choice but to call input.pumpTo() because if we called
// tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for
// us to return nullptr. But the thing about dynamic_cast also applies.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->shutdownWrite();
}));
}
}
void abortRead() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->abortRead();
} else {
tasks.add(promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->abortRead();
}));
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncIoStream>> stream;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
class PromiseOutputStream final: public kj::AsyncOutputStream {
// An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the
// promised stream.
//
// TODO(cleanup): Make this more broadly available.
// TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard.
public:
PromiseOutputStream(kj::Promise<kj::Own<AsyncOutputStream>> promise)
: promise(promise.then([this](kj::Own<AsyncOutputStream> result) {
stream = kj::mv(result);
}).fork()) {}
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(buffer, size);
} else {
return promise.addBranch().then([this,buffer,size]() {
return KJ_ASSERT_NONNULL(stream)->write(buffer, size);
});
}
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->write(pieces);
} else {
return promise.addBranch().then([this,pieces]() {
return KJ_ASSERT_NONNULL(stream)->write(pieces);
});
}
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
KJ_IF_MAYBE(s, stream) {
return s->get()->tryPumpFrom(input, amount);
} else {
return promise.addBranch().then([this,&input,amount]() {
// Call input.pumpTo() on the resolved stream instead.
return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount);
});
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
};
class NetworkAddressHttpClient final: public HttpClient {
public:
NetworkAddressHttpClient(kj::Timer& timer, HttpHeaderTable& responseHeaderTable,
......@@ -3797,7 +3601,7 @@ private:
kj::Own<RefcountedClient> getClient() {
for (;;) {
if (availableClients.empty()) {
auto stream = kj::heap<PromiseIoStream>(address->connect());
auto stream = newPromisedStream(address->connect());
return kj::refcounted<RefcountedClient>(*this,
kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(stream), settings));
} else {
......@@ -3898,7 +3702,7 @@ public:
auto split = combined.split();
return {
kj::heap<PromiseOutputStream>(kj::mv(kj::get<0>(split))),
newPromisedStream(kj::mv(kj::get<0>(split))),
kj::mv(kj::get<1>(split))
};
}
......@@ -4110,7 +3914,7 @@ public:
auto split = combined.split();
pendingRequests.push(kj::mv(paf.fulfiller));
fireCountChanged();
return { kj::heap<PromiseOutputStream>(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
return { newPromisedStream(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
}
kj::Promise<WebSocketResponse> openWebSocket(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment