Commit 90d48343 authored by Kenton Varda's avatar Kenton Varda

Fix HttpClient-from-HttpService wrapper prematurely cancelling service promise.

The client app will typically discard the returned response body upon reading EOF. However, the server app may not actually be "done" with the service callback yet at this point. Usually it completes very soon after, but it may need another turn or two of the event loop. If the client discards the response body stream, the server-side promise is discarded, cancelling whatever was left. This is awkward, so we should instead delay the client from seeing EOF until the server has actually finished up.
parent 59267d01
......@@ -3884,9 +3884,16 @@ public:
auto pipe = newOneWayPipe(expectedBodySize);
auto paf = kj::newPromiseAndFulfiller<Response>();
auto responder = kj::refcounted<ResponseImpl>(method, kj::mv(paf.fulfiller));
auto promise = service.request(method, urlCopy, *headersCopy, *pipe.in, *responder);
responder->setPromise(promise.attach(kj::mv(pipe.in), kj::mv(urlCopy), kj::mv(headersCopy)));
auto responder = kj::heap<ResponseImpl>(method, kj::mv(paf.fulfiller));
auto promise = kj::evalLater([this, method,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy),
pipeIn = kj::mv(pipe.in),
&responder = *responder]() mutable {
auto promise = service.request(method, urlCopy, *headersCopy, *pipeIn, responder);
return promise.attach(kj::mv(pipeIn), kj::mv(urlCopy), kj::mv(headersCopy));
});
responder->setPromise(kj::mv(promise));
return {
kj::mv(pipe.out),
......@@ -3906,10 +3913,16 @@ public:
KJ_DASSERT(headersCopy->isWebSocket());
auto paf = kj::newPromiseAndFulfiller<WebSocketResponse>();
auto responder = kj::refcounted<WebSocketResponseImpl>(kj::mv(paf.fulfiller));
auto responder = kj::heap<WebSocketResponseImpl>(kj::mv(paf.fulfiller));
auto promise = kj::evalLater([this,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy),
&responder = *responder]() mutable {
auto in = kj::heap<NullInputStream>();
auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder);
responder->setPromise(promise.attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)));
auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, responder);
return promise.attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy));
});
responder->setPromise(kj::mv(promise));
return paf.promise.attach(kj::mv(responder));
}
......@@ -3921,7 +3934,64 @@ public:
private:
HttpService& service;
class ResponseImpl final: public HttpService::Response, public kj::Refcounted {
class DelayedEofInputStream final: public kj::AsyncInputStream {
// An AsyncInputStream wrapper that, when it reaches EOF, delays the final read until some
// promise completes.
public:
DelayedEofInputStream(kj::Own<kj::AsyncInputStream> inner, kj::Promise<void> completionTask)
: inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return inner->tryRead(buffer, minBytes, maxBytes)
.then([this, minBytes](size_t amount) -> kj::Promise<size_t> {
if (amount < minBytes) {
// Must have reached EOF.
KJ_IF_MAYBE(t, completionTask) {
// Delay until completion.
auto result = t->then([amount]() { return amount; });
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF. Fine.
return amount;
}
} else {
return amount;
}
});
}
kj::Maybe<uint64_t> tryGetLength() override {
return inner->tryGetLength();
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
return inner->pumpTo(output, amount)
.then([this,amount](uint64_t actual) -> kj::Promise<uint64_t> {
if (actual < amount) {
// Must have reached EOF.
KJ_IF_MAYBE(t, completionTask) {
// Delay until completion.
auto result = t->then([amount]() { return amount; });
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF. Fine.
return amount;
}
} else {
return amount;
}
});
}
private:
kj::Own<kj::AsyncInputStream> inner;
kj::Maybe<kj::Promise<void>> completionTask;
};
class ResponseImpl final: public HttpService::Response {
public:
ResponseImpl(kj::HttpMethod method,
kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller)
......@@ -3947,18 +4017,30 @@ private:
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
if (method == kj::HttpMethod::HEAD) {
if (method == kj::HttpMethod::HEAD || expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
auto wrapper = kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), kj::mv(task));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
pipe.in.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
......@@ -3974,7 +4056,82 @@ private:
kj::Promise<void> task = nullptr;
};
class WebSocketResponseImpl final: public HttpService::Response, public kj::Refcounted {
class DelayedCloseWebSocket final: public WebSocket {
// A WebSocket wrapper that, when it reaches Close (in both directions), delays the final close
// operation until some promise completes.
public:
DelayedCloseWebSocket(kj::Own<kj::WebSocket> inner, kj::Promise<void> completionTask)
: inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return inner->send(message);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return inner->send(message);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return inner->close(code, reason)
.then([this]() {
return afterSendClosed();
});
}
kj::Promise<void> disconnect() override {
return inner->disconnect();
}
kj::Promise<Message> receive() override {
return inner->receive().then([this](Message&& message) -> kj::Promise<Message> {
if (message.is<WebSocket::Close>()) {
return afterReceiveClosed()
.then([message = kj::mv(message)]() mutable { return kj::mv(message); });
}
return kj::mv(message);
});
}
kj::Promise<void> pumpTo(WebSocket& other) override {
return inner->pumpTo(other).then([this]() {
return afterReceiveClosed();
});
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
return other.pumpTo(*inner).then([this]() {
return afterSendClosed();
});
}
private:
kj::Own<kj::WebSocket> inner;
kj::Maybe<kj::Promise<void>> completionTask;
bool sentClose = false;
bool receivedClose = false;
kj::Promise<void> afterSendClosed() {
sentClose = true;
if (receivedClose) {
KJ_IF_MAYBE(t, completionTask) {
auto result = kj::mv(*t);
completionTask = nullptr;
return result;
}
}
return kj::READY_NOW;
}
kj::Promise<void> afterReceiveClosed() {
receivedClose = true;
if (sentClose) {
KJ_IF_MAYBE(t, completionTask) {
auto result = kj::mv(*t);
completionTask = nullptr;
return result;
}
}
return kj::READY_NOW;
}
};
class WebSocketResponseImpl final: public HttpService::Response {
public:
WebSocketResponseImpl(kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller)
: fulfiller(kj::mv(fulfiller)) {}
......@@ -3999,13 +4156,35 @@ private:
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
if (expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::Own<AsyncInputStream>(kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy)))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
kj::Own<AsyncInputStream> wrapper =
kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), kj::mv(task));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
pipe.in.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
// The caller of HttpClient is allowed to assume that the headers remain valid until the body
......@@ -4014,9 +4193,14 @@ private:
auto headersCopy = kj::heap(headers.clone());
auto pipe = newWebSocketPipe();
// Wrap the client-side WebSocket in a wrapper that delays clean close of the WebSocket until
// the service's request promise has finished.
kj::Own<WebSocket> wrapper =
kj::heap<DelayedCloseWebSocket>(kj::mv(pipe.ends[0]), kj::mv(task));
fulfiller->fulfill({
101, "Switching Protocols", headersCopy.get(),
pipe.ends[0].attach(kj::addRef(*this), kj::mv(headersCopy))
wrapper.attach(kj::mv(headersCopy))
});
return kj::mv(pipe.ends[1]);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment