Unverified Commit eefc0dd3 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #788 from capnproto/cancel-http

Make HttpServer detect client disconnect and cancel request processing
parents bd54ed6e 1325f3c2
......@@ -2103,5 +2103,98 @@ KJ_TEST("Userland tee buffer size limit") {
}
}
KJ_TEST("Userspace OneWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newOneWayPipe();
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
KJ_TEST("Userspace TwoWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newTwoWayPipe();
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
#if !_WIN32 // We don't currently support detecting disconnect with IOCP.
#if !__CYGWIN__ // TODO(soon): Figure out why whenWriteDisconnected() doesn't work on Cygwin.
KJ_TEST("OS OneWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newOneWayPipe();
pipe.out->write("foo", 3).wait(io.waitScope);
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
}
KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
pipe.ends[0]->write("foo", 3).wait(io.waitScope);
pipe.ends[1]->write("bar", 3).wait(io.waitScope);
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "bar"_kj);
}
KJ_TEST("import socket FD that's already broken") {
auto io = setupAsyncIo();
int fds[2];
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
KJ_SYSCALL(write(fds[1], "foo", 3));
KJ_SYSCALL(close(fds[1]));
auto stream = io.lowLevelProvider->wrapSocketFd(fds[0], LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
auto abortedPromise = stream->whenWriteDisconnected();
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(stream->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "foo"_kj);
}
#endif // !__CYGWIN__
#endif // !_WIN32
} // namespace
} // namespace kj
......@@ -179,6 +179,17 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(p, writeDisconnectedPromise) {
return p->addBranch();
} else {
auto fork = observer.whenWriteDisconnected().fork();
auto result = fork.addBranch();
writeDisconnectedPromise = kj::mv(fork);
return kj::mv(result);
}
}
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// UnixAsyncIoProvider interface.
......@@ -290,6 +301,7 @@ public:
private:
UnixEventPort& eventPort;
UnixEventPort::FdObserver observer;
Maybe<ForkedPromise<void>> writeDisconnectedPromise;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
size_t alreadyRead) {
......
......@@ -311,6 +311,23 @@ public:
});
}
Promise<void> whenWriteDisconnected() override {
// Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
// without actually doing a read or write. However, there is an undocoumented-but-stable
// ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
// implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
// to put only one socket into the list and perform the ioctl asynchronously. Here's the
// source code for select() in Windows 2000 (not sure how this became public...):
//
// https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
//
// And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
//
// TODO(soon): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented because
// I added this method for a Linux-only use case.
return NEVER_DONE;
}
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
......
......@@ -228,6 +228,12 @@ public:
} else {
ownState = kj::heap<AbortedRead>();
state = *ownState;
readAborted = true;
KJ_IF_MAYBE(f, readAbortFulfiller) {
f->get()->fulfill();
readAbortFulfiller = nullptr;
}
}
}
......@@ -268,6 +274,21 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
if (readAborted) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(p, readAbortPromise) {
return p->addBranch();
} else {
auto paf = newPromiseAndFulfiller<void>();
readAbortFulfiller = kj::mv(paf.fulfiller);
auto fork = paf.promise.fork();
auto result = fork.addBranch();
readAbortPromise = kj::mv(fork);
return result;
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, state) {
s->shutdownWrite();
......@@ -285,6 +306,10 @@ private:
kj::Own<AsyncIoStream> ownState;
bool readAborted = false;
Maybe<Own<PromiseFulfiller<void>>> readAbortFulfiller = nullptr;
Maybe<ForkedPromise<void>> readAbortPromise = nullptr;
void endState(AsyncIoStream& obj) {
KJ_IF_MAYBE(s, state) {
if (s == &obj) {
......@@ -443,6 +468,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes");
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<void>& fulfiller;
AsyncPipe& pipe;
......@@ -562,6 +591,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes");
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe;
......@@ -733,6 +766,10 @@ private:
pipe.shutdownWrite();
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<size_t>& fulfiller;
AsyncPipe& pipe;
......@@ -901,6 +938,10 @@ private:
pipe.shutdownWrite();
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe;
......@@ -937,6 +978,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// which is not an error even if reads have been aborted.
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
};
class ShutdownedWrite final: public AsyncIoStream {
......@@ -966,6 +1010,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// so it will only be called once anyhow.
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
};
};
......@@ -1013,6 +1060,10 @@ public:
return pipe->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return pipe->whenWriteDisconnected();
}
private:
Own<AsyncPipe> pipe;
UnwindDetector unwind;
......@@ -1049,6 +1100,9 @@ public:
AsyncInputStream& input, uint64_t amount) override {
return out->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override {
out->shutdownWrite();
}
......
......@@ -106,6 +106,20 @@ public:
// output stream. If it finds one, it performs the pump. Otherwise, it returns null.
//
// The default implementation always returns null.
virtual Promise<void> whenWriteDisconnected() = 0;
// Returns a promise that resolves when the stream has become disconnected such that new write()s
// will fail with a DISCONNECTED exception. This is particularly useful, for example, to cancel
// work early when it is detected that no one will receive the result.
//
// Note that not all streams are able to detect this condition without actually performing a
// write(); such stream implementations may return a promise that never resolves. (In particular,
// as of this writing, whenWriteDisconnected() is not implemented on Windows. Also, for TCP
// streams, not all disconnects are detectable -- a power or network failure may lead the
// connection to hang forever, or until configured socket options lead to a timeout.)
//
// Unlike most other asynchronous stream methods, it is safe to call whenWriteDisconnected()
// multiple times without canceling the previous promises.
};
class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream {
......
......@@ -365,6 +365,13 @@ void UnixEventPort::FdObserver::fire(short events) {
}
}
if (events & (EPOLLHUP | EPOLLERR)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & EPOLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill();
......@@ -398,6 +405,12 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise);
}
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
bool UnixEventPort::wait() {
return doEpollWait(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
......@@ -652,6 +665,13 @@ void UnixEventPort::FdObserver::fire(short events) {
}
}
if (events & (POLLHUP | POLLERR | POLLNVAL)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & POLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill();
......@@ -675,7 +695,16 @@ void UnixEventPort::FdObserver::fire(short events) {
short UnixEventPort::FdObserver::getEventMask() {
return (readFulfiller == nullptr ? 0 : (POLLIN | POLLRDHUP)) |
(writeFulfiller == nullptr ? 0 : POLLOUT) |
(urgentFulfiller == nullptr ? 0 : POLLPRI);
(urgentFulfiller == nullptr ? 0 : POLLPRI) |
// The POSIX standard says POLLHUP and POLLERR will be reported even if not requested.
// But on MacOS, if `events` is 0, then POLLHUP apparently will not be reported:
// https://openradar.appspot.com/37537852
// It seems that by settingc any non-zero value -- even one documented as ignored -- we
// cause POLLHUP to be reported. Both POLLHUP and POLLERR are documented as being ignored.
// So, we'll go ahead and set them. This has no effect on non-broken OSs, causes MacOS to
// do the right thing, and sort of looks as if we're explicitly requesting notification of
// these two conditions, which we do after all want to know about.
POLLHUP | POLLERR;
}
Promise<void> UnixEventPort::FdObserver::whenBecomesReadable() {
......@@ -724,6 +753,19 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise);
}
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
if (prev == nullptr) {
KJ_DASSERT(next == nullptr);
prev = eventPort.observersTail;
*prev = this;
eventPort.observersTail = &next;
}
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
class UnixEventPort::PollContext {
public:
PollContext(FdObserver* ptr) {
......
......@@ -278,6 +278,9 @@ public:
// WARNING: This has some known weird behavior on macOS. See
// https://github.com/sandstorm-io/capnproto/issues/374.
Promise<void> whenWriteDisconnected();
// Resolves when poll() on the file descriptor reports POLLHUP or POLLERR.
private:
UnixEventPort& eventPort;
int fd;
......@@ -286,6 +289,7 @@ private:
kj::Maybe<Own<PromiseFulfiller<void>>> readFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> writeFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> urgentFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> hupFulfiller;
// Replaced each time `whenBecomesReadable()` or `whenBecomesWritable()` is called. Reverted to
// null every time an event is fired.
......
......@@ -126,6 +126,8 @@ public:
}
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); }
};
KJ_TEST("gzip decompression") {
......
......@@ -118,6 +118,8 @@ public:
Promise<void> write(const void* buffer, size_t size) override;
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override;
Promise<void> whenWriteDisconnected() override { return inner.whenWriteDisconnected(); }
inline Promise<void> flush() {
return pump(Z_SYNC_FLUSH);
}
......
// Copyright (c) 2019 Cloudflare, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// Run http-test, but use real OS socketpairs to connect rather than using in-process pipes.
// This is essentially an integration test between KJ HTTP and KJ OS socket handling.
#define KJ_HTTP_TEST_USE_OS_PIPE 1
#include "http-test.c++"
......@@ -26,6 +26,22 @@
#include <kj/test.h>
#include <map>
#if KJ_HTTP_TEST_USE_OS_PIPE
// Run the test using OS-level socketpairs. (See http-socketpair-test.c++.)
#define KJ_HTTP_TEST_SETUP_IO \
auto io = kj::setupAsyncIo(); \
auto& waitScope = io.waitScope
#define KJ_HTTP_TEST_CREATE_2PIPE \
io.provider->newTwoWayPipe()
#else
// Run the test using in-process two-way pipes.
#define KJ_HTTP_TEST_SETUP_IO \
kj::EventLoop eventLoop; \
kj::WaitScope waitScope(eventLoop)
#define KJ_HTTP_TEST_CREATE_2PIPE \
kj::newTwoWayPipe()
#endif
namespace kj {
namespace {
......@@ -288,6 +304,10 @@ public:
return inner.tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override {
return inner.shutdownWrite();
}
......@@ -428,8 +448,8 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::ArrayPtr<const byte>
}));
}
void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& testCase) {
auto pipe = kj::newTwoWayPipe();
void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& testCase,
kj::TwoWayPipe pipe) {
auto serverTask = expectRead(*pipe.ends[1], testCase.raw).then([&]() {
static const char SIMPLE_RESPONSE[] =
......@@ -469,8 +489,7 @@ void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase&
}
void testHttpClientResponse(kj::WaitScope& waitScope, const HttpResponseTestCase& testCase,
size_t readFragmentSize) {
auto pipe = kj::newTwoWayPipe();
size_t readFragmentSize, kj::TwoWayPipe pipe) {
ReadFragmenter fragmenter(*pipe.ends[0], readFragmentSize);
auto expectedReqText = testCase.method == HttpMethod::GET || testCase.method == HttpMethod::HEAD
......@@ -610,9 +629,8 @@ private:
void testHttpServerRequest(kj::WaitScope& waitScope, kj::Timer& timer,
const HttpRequestTestCase& requestCase,
const HttpResponseTestCase& responseCase) {
auto pipe = kj::newTwoWayPipe();
const HttpResponseTestCase& responseCase,
kj::TwoWayPipe pipe) {
HttpHeaderTable table;
TestHttpService service(requestCase, responseCase, table);
HttpServer server(timer, table, service);
......@@ -857,35 +875,32 @@ kj::ArrayPtr<const HttpResponseTestCase> responseTestCases() {
}
KJ_TEST("HttpClient requests") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
for (auto& testCase: requestTestCases()) {
if (testCase.side == SERVER_ONLY) continue;
KJ_CONTEXT(testCase.raw);
testHttpClientRequest(waitScope, testCase);
testHttpClientRequest(waitScope, testCase, KJ_HTTP_TEST_CREATE_2PIPE);
}
}
KJ_TEST("HttpClient responses") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
size_t FRAGMENT_SIZES[] = { 1, 2, 3, 4, 5, 6, 7, 8, 16, 31, kj::maxValue };
for (auto& testCase: responseTestCases()) {
if (testCase.side == SERVER_ONLY) continue;
for (size_t fragmentSize: FRAGMENT_SIZES) {
KJ_CONTEXT(testCase.raw, fragmentSize);
testHttpClientResponse(waitScope, testCase, fragmentSize);
testHttpClientResponse(waitScope, testCase, fragmentSize, KJ_HTTP_TEST_CREATE_2PIPE);
}
}
}
KJ_TEST("HttpClient canceled write") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText();
......@@ -918,10 +933,9 @@ KJ_TEST("HttpClient canceled write") {
}
KJ_TEST("HttpClient chunked body gather-write") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText();
......@@ -969,10 +983,9 @@ KJ_TEST("HttpClient chunked body pump from fixed length stream") {
kj::StringPtr body = "foo bar baz";
};
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText();
......@@ -1021,15 +1034,15 @@ KJ_TEST("HttpServer requests") {
3, {"foo"}
};
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
for (auto& testCase: requestTestCases()) {
if (testCase.side == CLIENT_ONLY) continue;
KJ_CONTEXT(testCase.raw);
testHttpServerRequest(waitScope, timer, testCase,
testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE);
testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE,
KJ_HTTP_TEST_CREATE_2PIPE);
}
}
......@@ -1054,15 +1067,15 @@ KJ_TEST("HttpServer responses") {
uint64_t(0), {},
};
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
for (auto& testCase: responseTestCases()) {
if (testCase.side == CLIENT_ONLY) continue;
KJ_CONTEXT(testCase.raw);
testHttpServerRequest(waitScope, timer,
testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase);
testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase,
KJ_HTTP_TEST_CREATE_2PIPE);
}
}
......@@ -1217,9 +1230,8 @@ kj::ArrayPtr<const HttpTestCase> pipelineTestCases() {
KJ_TEST("HttpClient pipeline") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
for (auto& testCase: PIPELINE_TESTS) {
......@@ -1247,9 +1259,8 @@ KJ_TEST("HttpClient pipeline") {
KJ_TEST("HttpClient parallel pipeline") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
kj::Promise<void> readRequestsPromise = kj::READY_NOW;
kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
......@@ -1311,10 +1322,9 @@ KJ_TEST("HttpClient parallel pipeline") {
KJ_TEST("HttpServer pipeline") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table);
......@@ -1340,10 +1350,9 @@ KJ_TEST("HttpServer pipeline") {
KJ_TEST("HttpServer parallel pipeline") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto allRequestText =
kj::strArray(KJ_MAP(testCase, PIPELINE_TESTS) { return testCase.request.raw; }, "");
......@@ -1370,10 +1379,9 @@ KJ_TEST("HttpServer parallel pipeline") {
KJ_TEST("HttpClient <-> HttpServer") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table);
......@@ -1395,8 +1403,7 @@ KJ_TEST("HttpClient <-> HttpServer") {
// -----------------------------------------------------------------------------
KJ_TEST("HttpInputStream requests") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::HttpHeaderTable table;
......@@ -1434,8 +1441,7 @@ KJ_TEST("HttpInputStream requests") {
}
KJ_TEST("HttpInputStream responses") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::HttpHeaderTable table;
......@@ -1475,8 +1481,7 @@ KJ_TEST("HttpInputStream responses") {
}
KJ_TEST("HttpInputStream bare messages") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::HttpHeaderTable table;
......@@ -1530,9 +1535,8 @@ KJ_TEST("HttpInputStream bare messages") {
// -----------------------------------------------------------------------------
KJ_TEST("WebSocket core protocol") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1591,9 +1595,8 @@ KJ_TEST("WebSocket core protocol") {
}
KJ_TEST("WebSocket fragmented") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1629,9 +1632,8 @@ public:
};
KJ_TEST("WebSocket masked") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
FakeEntropySource maskGenerator;
auto client = kj::mv(pipe.ends[0]);
......@@ -1657,9 +1659,8 @@ KJ_TEST("WebSocket masked") {
}
KJ_TEST("WebSocket unsolicited pong") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1684,9 +1685,8 @@ KJ_TEST("WebSocket unsolicited pong") {
}
KJ_TEST("WebSocket ping") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1722,9 +1722,8 @@ KJ_TEST("WebSocket ping") {
}
KJ_TEST("WebSocket ping mid-send") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1791,6 +1790,10 @@ public:
return out->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override {
out = nullptr;
}
......@@ -1801,8 +1804,7 @@ private:
};
KJ_TEST("WebSocket double-ping mid-send") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
auto upPipe = newOneWayPipe();
auto downPipe = newOneWayPipe();
......@@ -1839,9 +1841,8 @@ KJ_TEST("WebSocket double-ping mid-send") {
}
KJ_TEST("WebSocket ping received during pong send") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
......@@ -1875,10 +1876,9 @@ KJ_TEST("WebSocket ping received during pong send") {
}
KJ_TEST("WebSocket pump disconnect on send") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe1 = kj::newTwoWayPipe();
auto pipe2 = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE;
auto client1 = newWebSocket(kj::mv(pipe1.ends[0]), nullptr);
auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr);
......@@ -1900,10 +1900,9 @@ KJ_TEST("WebSocket pump disconnect on send") {
}
KJ_TEST("WebSocket pump disconnect on receive") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe1 = kj::newTwoWayPipe();
auto pipe2 = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE;
auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr);
auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), nullptr);
......@@ -2053,9 +2052,8 @@ void testWebSocketClient(kj::WaitScope& waitScope, HttpHeaderTable& headerTable,
}
KJ_TEST("HttpClient WebSocket handshake") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
......@@ -2084,9 +2082,8 @@ KJ_TEST("HttpClient WebSocket handshake") {
}
KJ_TEST("HttpClient WebSocket error") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
KJ_HTTP_TEST_SETUP_IO;
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
......@@ -2131,10 +2128,9 @@ KJ_TEST("HttpClient WebSocket error") {
}
KJ_TEST("HttpServer WebSocket handshake") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
......@@ -2158,10 +2154,9 @@ KJ_TEST("HttpServer WebSocket handshake") {
}
KJ_TEST("HttpServer WebSocket handshake error") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
......@@ -2189,10 +2184,9 @@ KJ_TEST("HttpServer WebSocket handshake error") {
KJ_TEST("HttpServer request timeout") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table);
......@@ -2215,10 +2209,9 @@ KJ_TEST("HttpServer request timeout") {
KJ_TEST("HttpServer pipeline timeout") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table);
......@@ -2269,10 +2262,9 @@ private:
KJ_TEST("HttpServer no response") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
BrokenHttpService service;
......@@ -2297,10 +2289,9 @@ KJ_TEST("HttpServer no response") {
KJ_TEST("HttpServer disconnected") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected"));
......@@ -2319,10 +2310,9 @@ KJ_TEST("HttpServer disconnected") {
KJ_TEST("HttpServer overloaded") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded"));
......@@ -2341,10 +2331,9 @@ KJ_TEST("HttpServer overloaded") {
KJ_TEST("HttpServer unimplemented") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented"));
......@@ -2363,10 +2352,9 @@ KJ_TEST("HttpServer unimplemented") {
KJ_TEST("HttpServer threw exception") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed"));
......@@ -2407,10 +2395,9 @@ private:
KJ_TEST("HttpServer threw exception after starting response") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
PartialResponseService service;
......@@ -2455,10 +2442,9 @@ private:
KJ_TEST("HttpServer failed to write complete response but didn't throw") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
PartialResponseNoThrowService service;
......@@ -2526,10 +2512,9 @@ private:
KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
PumpResponseService service;
......@@ -2550,16 +2535,72 @@ KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") {
"Hello, World!", text);
}
class HangingHttpService final: public HttpService {
// HttpService that hangs forever.
public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& responseSender) override {
kj::Promise<void> result = kj::NEVER_DONE;
++inFlight;
return result.attach(kj::defer([this]() {
if (--inFlight == 0) {
KJ_IF_MAYBE(f, onCancelFulfiller) {
f->get()->fulfill();
}
}
}));
}
kj::Promise<void> onCancel() {
auto paf = kj::newPromiseAndFulfiller<void>();
onCancelFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
uint inFlight = 0;
private:
kj::Maybe<kj::Exception> exception;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> onCancelFulfiller;
};
KJ_TEST("HttpServer cancels request when client disconnects") {
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
HangingHttpService service;
HttpServer server(timer, table, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
KJ_EXPECT(service.inFlight == 0);
static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj;
pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope);
auto cancelPromise = service.onCancel();
KJ_EXPECT(!cancelPromise.poll(waitScope));
KJ_EXPECT(service.inFlight == 1);
// Disconnect client and verify server cancels.
pipe.ends[1] = nullptr;
KJ_ASSERT(cancelPromise.poll(waitScope));
KJ_EXPECT(service.inFlight == 0);
cancelPromise.wait(waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("newHttpService from HttpClient") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe();
auto backPipe = kj::newTwoWayPipe();
auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
for (auto& testCase: PIPELINE_TESTS) {
......@@ -2596,11 +2637,10 @@ KJ_TEST("newHttpService from HttpClient") {
}
KJ_TEST("newHttpService from HttpClient WebSockets") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe();
auto backPipe = kj::newTwoWayPipe();
auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
......@@ -2639,11 +2679,10 @@ KJ_TEST("newHttpService from HttpClient WebSockets") {
}
KJ_TEST("newHttpService from HttpClient WebSockets disconnect") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe();
auto backPipe = kj::newTwoWayPipe();
auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
......@@ -2683,8 +2722,7 @@ KJ_TEST("newHttpService from HttpClient WebSockets disconnect") {
KJ_TEST("newHttpClient from HttpService") {
auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
HttpHeaderTable table;
......@@ -2697,10 +2735,9 @@ KJ_TEST("newHttpClient from HttpService") {
}
KJ_TEST("newHttpClient from HttpService WebSockets") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe();
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
......@@ -2746,6 +2783,9 @@ public:
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
return inner->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return inner->whenWriteDisconnected();
}
void shutdownWrite() override {
return inner->shutdownWrite();
}
......
......@@ -1760,6 +1760,10 @@ public:
return fork.addBranch();
}
Promise<void> whenWriteDisconnected() {
return inner.whenWriteDisconnected();
}
private:
AsyncOutputStream& inner;
kj::Promise<void> writeQueue = kj::READY_NOW;
......@@ -1787,6 +1791,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
};
class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream {
......@@ -1797,6 +1804,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
};
class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream {
......@@ -1877,6 +1887,10 @@ public:
return kj::mv(promise);
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private:
HttpOutputStream& inner;
uint64_t length;
......@@ -1960,6 +1974,10 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private:
HttpOutputStream& inner;
};
......@@ -2022,6 +2040,18 @@ public:
return kj::READY_NOW;
}
void abort() override {
queuedPong = nullptr;
sendingPong = nullptr;
disconnected = true;
stream->abortRead();
stream->shutdownWrite();
}
kj::Promise<void> whenAborted() override {
return stream->whenWriteDisconnected();
}
kj::Promise<Message> receive() override {
size_t headerSize = Header::headerSize(recvData.begin(), recvData.size());
......@@ -2506,7 +2536,12 @@ kj::Promise<void> WebSocket::pumpTo(WebSocket& other) {
} else {
// Fall back to default implementation.
return kj::evalNow([&]() {
return pumpWebSocketLoop(*this, other);
auto cancelPromise = other.whenAborted().then([this]() -> kj::Promise<void> {
this->abort();
return KJ_EXCEPTION(DISCONNECTED,
"destination of WebSocket pump disconnected prematurely");
});
return pumpWebSocketLoop(*this, other).exclusiveJoin(kj::mv(cancelPromise));
});
}
}
......@@ -2517,12 +2552,7 @@ kj::Maybe<kj::Promise<void>> WebSocket::tryPumpFrom(WebSocket& other) {
namespace {
class AbortableWebSocket: public WebSocket {
public:
virtual void abort() = 0;
};
class WebSocketPipeImpl final: public AbortableWebSocket, public kj::Refcounted {
class WebSocketPipeImpl final: public WebSocket, public kj::Refcounted {
// Represents one direction of a WebSocket pipe.
//
// This class behaves as a "loopback" WebSocket: a message sent using send() is received using
......@@ -2548,6 +2578,12 @@ public:
} else {
ownState = heap<Aborted>();
state = *ownState;
aborted = true;
KJ_IF_MAYBE(f, abortedFulfiller) {
f->get()->fulfill();
abortedFulfiller = nullptr;
}
}
}
......@@ -2581,6 +2617,20 @@ public:
return kj::READY_NOW;
}
}
kj::Promise<void> whenAborted() override {
if (aborted) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(p, abortedPromise) {
return p->addBranch();
} else {
auto paf = newPromiseAndFulfiller<void>();
abortedFulfiller = kj::mv(paf.fulfiller);
auto fork = paf.promise.fork();
auto result = fork.addBranch();
abortedPromise = kj::mv(fork);
return result;
}
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
KJ_IF_MAYBE(s, state) {
return s->tryPumpFrom(other);
......@@ -2605,12 +2655,16 @@ public:
}
private:
kj::Maybe<AbortableWebSocket&> state;
kj::Maybe<WebSocket&> state;
// Object-oriented state! If any method call is blocked waiting on activity from the other end,
// then `state` is non-null and method calls should be forwarded to it. If no calls are
// outstanding, `state` is null.
kj::Own<AbortableWebSocket> ownState;
kj::Own<WebSocket> ownState;
bool aborted = false;
Maybe<Own<PromiseFulfiller<void>>> abortedFulfiller = nullptr;
Maybe<ForkedPromise<void>> abortedPromise = nullptr;
void endState(WebSocket& obj) {
KJ_IF_MAYBE(s, state) {
......@@ -2626,7 +2680,7 @@ private:
};
typedef kj::OneOf<kj::ArrayPtr<const char>, kj::ArrayPtr<const byte>, ClosePtr> MessagePtr;
class BlockedSend final: public AbortableWebSocket {
class BlockedSend final: public WebSocket {
public:
BlockedSend(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message)
: fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) {
......@@ -2643,6 +2697,9 @@ private:
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
......@@ -2713,7 +2770,7 @@ private:
Canceler canceler;
};
class BlockedPumpFrom final: public AbortableWebSocket {
class BlockedPumpFrom final: public WebSocket {
public:
BlockedPumpFrom(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe,
WebSocket& input)
......@@ -2731,6 +2788,9 @@ private:
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
......@@ -2788,7 +2848,7 @@ private:
Canceler canceler;
};
class BlockedReceive final: public AbortableWebSocket {
class BlockedReceive final: public WebSocket {
public:
BlockedReceive(kj::PromiseFulfiller<Message>& fulfiller, WebSocketPipeImpl& pipe)
: fulfiller(fulfiller), pipe(pipe) {
......@@ -2805,6 +2865,9 @@ private:
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
......@@ -2860,7 +2923,7 @@ private:
Canceler canceler;
};
class BlockedPumpTo final: public AbortableWebSocket {
class BlockedPumpTo final: public WebSocket {
public:
BlockedPumpTo(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output)
: fulfiller(fulfiller), pipe(pipe), output(output) {
......@@ -2881,6 +2944,9 @@ private:
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
......@@ -2926,11 +2992,14 @@ private:
Canceler canceler;
};
class Disconnected final: public AbortableWebSocket {
class Disconnected final: public WebSocket {
public:
void abort() override {
// can ignore
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_REQUIRE("can't send() after disconnect()");
......@@ -2956,11 +3025,14 @@ private:
}
};
class Aborted final: public AbortableWebSocket {
class Aborted final: public WebSocket {
public:
void abort() override {
// can ignore
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
......@@ -3009,6 +3081,13 @@ public:
kj::Promise<void> disconnect() override {
return out->disconnect();
}
void abort() override {
in->abort();
out->abort();
}
kj::Promise<void> whenAborted() override {
return out->whenAborted();
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
return out->tryPumpFrom(other);
}
......@@ -3382,6 +3461,22 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite();
......@@ -3456,6 +3551,22 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
public:
kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream;
......@@ -3865,6 +3976,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW;
}
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
// We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method.
};
......@@ -4079,6 +4193,13 @@ private:
kj::Promise<void> disconnect() override {
return inner->disconnect();
}
void abort() override {
// Don't need to worry about completion task in this case -- cancelling it is reasonable.
inner->abort();
}
kj::Promise<void> whenAborted() override {
return inner->whenAborted();
}
kj::Promise<Message> receive() override {
return inner->receive().then([this](Message&& message) -> kj::Promise<Message> {
if (message.is<WebSocket::Close>()) {
......@@ -4712,6 +4833,12 @@ private:
kj::Promise<void> disconnect() override {
return kj::cp(exception);
}
void abort() override {
kj::throwRecoverableException(kj::cp(exception));
}
kj::Promise<void> whenAborted() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override {
return kj::cp(exception);
}
......@@ -4796,7 +4923,10 @@ kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection
}
}
auto promise = obj->loop(true);
// Start reading requests and responding to them, but immediately cancel processing if the client
// disconnects.
auto promise = obj->loop(true)
.exclusiveJoin(connection.whenWriteDisconnected().then([]() {return false;}));
// Eagerly evaluate so that we drop the connection when the promise resolves, even if the caller
// doesn't eagerly evaluate.
......
......@@ -489,6 +489,17 @@ public:
// shutdown, but is sometimes useful when you want the other end to trigger whatever behavior
// it normally triggers when a connection is dropped.
virtual void abort() = 0;
// Forcefully close this WebSocket, such that the remote end should get a DISCONNECTED error if
// it continues to write. This differs from disconnect(), which only closes the sending
// direction, but still allows receives.
virtual kj::Promise<void> whenAborted() = 0;
// Resolves when the remote side aborts the connection such that send() would throw DISCONNECTED,
// if this can be detected without actually writing a message. (If not, this promise never
// resolves, but send() or receive() will throw DISCONNECTED when appropriate. See also
// kj::AsyncOutputStream::whenWriteDisconnected().)
struct Close {
uint16_t code;
kj::String reason;
......@@ -629,6 +640,9 @@ public:
//
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first.
//
// Request processing can be canceled by dropping the returned promise. HttpServer may do so if
// the client disconnects prematurely.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
......
......@@ -180,6 +180,10 @@ public:
return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override {
KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment