Unverified Commit bd54ed6e authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #787 from capnproto/fix-http-client-adapter-cancelation

Fix HttpClient-from-HttpService wrapper prematurely cancelling service promise.
parents 59267d01 3fcd0f46
......@@ -1459,7 +1459,7 @@ KJ_TEST("HttpInputStream responses") {
KJ_CONTEXT(testCase.raw);
KJ_ASSERT(input->awaitNextMessage().wait(waitScope));
auto resp = input->readResponse(testCase.method).wait(waitScope);
KJ_EXPECT(resp.statusCode == testCase.statusCode);
KJ_EXPECT(resp.statusText == testCase.statusText);
......@@ -2610,13 +2610,6 @@ KJ_TEST("newHttpService from HttpClient WebSockets") {
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_REPLY_MESSAGE}); })
.then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_CLOSE); })
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_REPLY_CLOSE}); })
// expect EOF
.then([&]() { return backPipe.ends[1]->readAllBytes(); })
.then([&](kj::ArrayPtr<byte> content) {
KJ_EXPECT(content.size() == 0);
// Send EOF.
backPipe.ends[1]->shutdownWrite();
})
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
{
......
......@@ -2484,9 +2484,9 @@ static kj::Promise<void> pumpWebSocketLoop(WebSocket& from, WebSocket& to) {
.then([&from,&to]() { return pumpWebSocketLoop(from, to); });
}
KJ_CASE_ONEOF(close, WebSocket::Close) {
// Once a close has passed through, the pump is complete.
return to.close(close.code, close.reason)
.attach(kj::mv(close))
.then([&from,&to]() { return pumpWebSocketLoop(from, to); });
.attach(kj::mv(close));
}
}
KJ_UNREACHABLE;
......@@ -3884,9 +3884,16 @@ public:
auto pipe = newOneWayPipe(expectedBodySize);
auto paf = kj::newPromiseAndFulfiller<Response>();
auto responder = kj::refcounted<ResponseImpl>(method, kj::mv(paf.fulfiller));
auto promise = service.request(method, urlCopy, *headersCopy, *pipe.in, *responder);
responder->setPromise(promise.attach(kj::mv(pipe.in), kj::mv(urlCopy), kj::mv(headersCopy)));
auto responder = kj::heap<ResponseImpl>(method, kj::mv(paf.fulfiller));
auto promise = kj::evalLater([this, method,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy),
pipeIn = kj::mv(pipe.in),
&responder = *responder]() mutable {
auto promise = service.request(method, urlCopy, *headersCopy, *pipeIn, responder);
return promise.attach(kj::mv(pipeIn), kj::mv(urlCopy), kj::mv(headersCopy));
});
responder->setPromise(kj::mv(promise));
return {
kj::mv(pipe.out),
......@@ -3906,10 +3913,16 @@ public:
KJ_DASSERT(headersCopy->isWebSocket());
auto paf = kj::newPromiseAndFulfiller<WebSocketResponse>();
auto responder = kj::refcounted<WebSocketResponseImpl>(kj::mv(paf.fulfiller));
auto in = kj::heap<NullInputStream>();
auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder);
responder->setPromise(promise.attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)));
auto responder = kj::heap<WebSocketResponseImpl>(kj::mv(paf.fulfiller));
auto promise = kj::evalLater([this,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy),
&responder = *responder]() mutable {
auto in = kj::heap<NullInputStream>();
auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, responder);
return promise.attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy));
});
responder->setPromise(kj::mv(promise));
return paf.promise.attach(kj::mv(responder));
}
......@@ -3921,7 +3934,64 @@ public:
private:
HttpService& service;
class ResponseImpl final: public HttpService::Response, public kj::Refcounted {
class DelayedEofInputStream final: public kj::AsyncInputStream {
// An AsyncInputStream wrapper that, when it reaches EOF, delays the final read until some
// promise completes.
public:
DelayedEofInputStream(kj::Own<kj::AsyncInputStream> inner, kj::Promise<void> completionTask)
: inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return inner->tryRead(buffer, minBytes, maxBytes)
.then([this, minBytes](size_t amount) -> kj::Promise<size_t> {
if (amount < minBytes) {
// Must have reached EOF.
KJ_IF_MAYBE(t, completionTask) {
// Delay until completion.
auto result = t->then([amount]() { return amount; });
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF. Fine.
return amount;
}
} else {
return amount;
}
});
}
kj::Maybe<uint64_t> tryGetLength() override {
return inner->tryGetLength();
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
return inner->pumpTo(output, amount)
.then([this,amount](uint64_t actual) -> kj::Promise<uint64_t> {
if (actual < amount) {
// Must have reached EOF.
KJ_IF_MAYBE(t, completionTask) {
// Delay until completion.
auto result = t->then([amount]() { return amount; });
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF. Fine.
return amount;
}
} else {
return amount;
}
});
}
private:
kj::Own<kj::AsyncInputStream> inner;
kj::Maybe<kj::Promise<void>> completionTask;
};
class ResponseImpl final: public HttpService::Response {
public:
ResponseImpl(kj::HttpMethod method,
kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller)
......@@ -3947,18 +4017,30 @@ private:
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
if (method == kj::HttpMethod::HEAD) {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
});
if (method == kj::HttpMethod::HEAD || expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
auto wrapper = kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), kj::mv(task));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
pipe.in.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
......@@ -3974,7 +4056,82 @@ private:
kj::Promise<void> task = nullptr;
};
class WebSocketResponseImpl final: public HttpService::Response, public kj::Refcounted {
class DelayedCloseWebSocket final: public WebSocket {
// A WebSocket wrapper that, when it reaches Close (in both directions), delays the final close
// operation until some promise completes.
public:
DelayedCloseWebSocket(kj::Own<kj::WebSocket> inner, kj::Promise<void> completionTask)
: inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return inner->send(message);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return inner->send(message);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return inner->close(code, reason)
.then([this]() {
return afterSendClosed();
});
}
kj::Promise<void> disconnect() override {
return inner->disconnect();
}
kj::Promise<Message> receive() override {
return inner->receive().then([this](Message&& message) -> kj::Promise<Message> {
if (message.is<WebSocket::Close>()) {
return afterReceiveClosed()
.then([message = kj::mv(message)]() mutable { return kj::mv(message); });
}
return kj::mv(message);
});
}
kj::Promise<void> pumpTo(WebSocket& other) override {
return inner->pumpTo(other).then([this]() {
return afterReceiveClosed();
});
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
return other.pumpTo(*inner).then([this]() {
return afterSendClosed();
});
}
private:
kj::Own<kj::WebSocket> inner;
kj::Maybe<kj::Promise<void>> completionTask;
bool sentClose = false;
bool receivedClose = false;
kj::Promise<void> afterSendClosed() {
sentClose = true;
if (receivedClose) {
KJ_IF_MAYBE(t, completionTask) {
auto result = kj::mv(*t);
completionTask = nullptr;
return result;
}
}
return kj::READY_NOW;
}
kj::Promise<void> afterReceiveClosed() {
receivedClose = true;
if (sentClose) {
KJ_IF_MAYBE(t, completionTask) {
auto result = kj::mv(*t);
completionTask = nullptr;
return result;
}
}
return kj::READY_NOW;
}
};
class WebSocketResponseImpl final: public HttpService::Response {
public:
WebSocketResponseImpl(kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller)
: fulfiller(kj::mv(fulfiller)) {}
......@@ -3999,12 +4156,34 @@ private:
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
auto pipe = newOneWayPipe(expectedBodySize);
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
pipe.in.attach(kj::addRef(*this), kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
if (expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::Own<AsyncInputStream>(kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy)))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
kj::Own<AsyncInputStream> wrapper =
kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), kj::mv(task));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
......@@ -4014,9 +4193,14 @@ private:
auto headersCopy = kj::heap(headers.clone());
auto pipe = newWebSocketPipe();
// Wrap the client-side WebSocket in a wrapper that delays clean close of the WebSocket until
// the service's request promise has finished.
kj::Own<WebSocket> wrapper =
kj::heap<DelayedCloseWebSocket>(kj::mv(pipe.ends[0]), kj::mv(task));
fulfiller->fulfill({
101, "Switching Protocols", headersCopy.get(),
pipe.ends[0].attach(kj::addRef(*this), kj::mv(headersCopy))
wrapper.attach(kj::mv(headersCopy))
});
return kj::mv(pipe.ends[1]);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment