Unverified Commit eefc0dd3 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #788 from capnproto/cancel-http

Make HttpServer detect client disconnect and cancel request processing
parents bd54ed6e 1325f3c2
...@@ -2103,5 +2103,98 @@ KJ_TEST("Userland tee buffer size limit") { ...@@ -2103,5 +2103,98 @@ KJ_TEST("Userland tee buffer size limit") {
} }
} }
KJ_TEST("Userspace OneWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newOneWayPipe();
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
KJ_TEST("Userspace TwoWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newTwoWayPipe();
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
#if !_WIN32 // We don't currently support detecting disconnect with IOCP.
#if !__CYGWIN__ // TODO(soon): Figure out why whenWriteDisconnected() doesn't work on Cygwin.
KJ_TEST("OS OneWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newOneWayPipe();
pipe.out->write("foo", 3).wait(io.waitScope);
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
}
KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
pipe.ends[0]->write("foo", 3).wait(io.waitScope);
pipe.ends[1]->write("bar", 3).wait(io.waitScope);
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "bar"_kj);
}
KJ_TEST("import socket FD that's already broken") {
auto io = setupAsyncIo();
int fds[2];
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
KJ_SYSCALL(write(fds[1], "foo", 3));
KJ_SYSCALL(close(fds[1]));
auto stream = io.lowLevelProvider->wrapSocketFd(fds[0], LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
auto abortedPromise = stream->whenWriteDisconnected();
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(stream->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "foo"_kj);
}
#endif // !__CYGWIN__
#endif // !_WIN32
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -179,6 +179,17 @@ public: ...@@ -179,6 +179,17 @@ public:
} }
} }
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(p, writeDisconnectedPromise) {
return p->addBranch();
} else {
auto fork = observer.whenWriteDisconnected().fork();
auto result = fork.addBranch();
writeDisconnectedPromise = kj::mv(fork);
return kj::mv(result);
}
}
void shutdownWrite() override { void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// UnixAsyncIoProvider interface. // UnixAsyncIoProvider interface.
...@@ -290,6 +301,7 @@ public: ...@@ -290,6 +301,7 @@ public:
private: private:
UnixEventPort& eventPort; UnixEventPort& eventPort;
UnixEventPort::FdObserver observer; UnixEventPort::FdObserver observer;
Maybe<ForkedPromise<void>> writeDisconnectedPromise;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
size_t alreadyRead) { size_t alreadyRead) {
......
...@@ -311,6 +311,23 @@ public: ...@@ -311,6 +311,23 @@ public:
}); });
} }
Promise<void> whenWriteDisconnected() override {
// Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
// without actually doing a read or write. However, there is an undocoumented-but-stable
// ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
// implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
// to put only one socket into the list and perform the ioctl asynchronously. Here's the
// source code for select() in Windows 2000 (not sure how this became public...):
//
// https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
//
// And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
//
// TODO(soon): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented because
// I added this method for a Linux-only use case.
return NEVER_DONE;
}
void shutdownWrite() override { void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface. // Win32AsyncIoProvider interface.
......
...@@ -228,6 +228,12 @@ public: ...@@ -228,6 +228,12 @@ public:
} else { } else {
ownState = kj::heap<AbortedRead>(); ownState = kj::heap<AbortedRead>();
state = *ownState; state = *ownState;
readAborted = true;
KJ_IF_MAYBE(f, readAbortFulfiller) {
f->get()->fulfill();
readAbortFulfiller = nullptr;
}
} }
} }
...@@ -268,6 +274,21 @@ public: ...@@ -268,6 +274,21 @@ public:
} }
} }
Promise<void> whenWriteDisconnected() override {
if (readAborted) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(p, readAbortPromise) {
return p->addBranch();
} else {
auto paf = newPromiseAndFulfiller<void>();
readAbortFulfiller = kj::mv(paf.fulfiller);
auto fork = paf.promise.fork();
auto result = fork.addBranch();
readAbortPromise = kj::mv(fork);
return result;
}
}
void shutdownWrite() override { void shutdownWrite() override {
KJ_IF_MAYBE(s, state) { KJ_IF_MAYBE(s, state) {
s->shutdownWrite(); s->shutdownWrite();
...@@ -285,6 +306,10 @@ private: ...@@ -285,6 +306,10 @@ private:
kj::Own<AsyncIoStream> ownState; kj::Own<AsyncIoStream> ownState;
bool readAborted = false;
Maybe<Own<PromiseFulfiller<void>>> readAbortFulfiller = nullptr;
Maybe<ForkedPromise<void>> readAbortPromise = nullptr;
void endState(AsyncIoStream& obj) { void endState(AsyncIoStream& obj) {
KJ_IF_MAYBE(s, state) { KJ_IF_MAYBE(s, state) {
if (s == &obj) { if (s == &obj) {
...@@ -443,6 +468,10 @@ private: ...@@ -443,6 +468,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes"); KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes");
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private: private:
PromiseFulfiller<void>& fulfiller; PromiseFulfiller<void>& fulfiller;
AsyncPipe& pipe; AsyncPipe& pipe;
...@@ -562,6 +591,10 @@ private: ...@@ -562,6 +591,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes"); KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes");
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private: private:
PromiseFulfiller<uint64_t>& fulfiller; PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe; AsyncPipe& pipe;
...@@ -733,6 +766,10 @@ private: ...@@ -733,6 +766,10 @@ private:
pipe.shutdownWrite(); pipe.shutdownWrite();
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private: private:
PromiseFulfiller<size_t>& fulfiller; PromiseFulfiller<size_t>& fulfiller;
AsyncPipe& pipe; AsyncPipe& pipe;
...@@ -901,6 +938,10 @@ private: ...@@ -901,6 +938,10 @@ private:
pipe.shutdownWrite(); pipe.shutdownWrite();
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private: private:
PromiseFulfiller<uint64_t>& fulfiller; PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe; AsyncPipe& pipe;
...@@ -937,6 +978,9 @@ private: ...@@ -937,6 +978,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped, // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// which is not an error even if reads have been aborted. // which is not an error even if reads have been aborted.
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
}; };
class ShutdownedWrite final: public AsyncIoStream { class ShutdownedWrite final: public AsyncIoStream {
...@@ -966,6 +1010,9 @@ private: ...@@ -966,6 +1010,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped, // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// so it will only be called once anyhow. // so it will only be called once anyhow.
} }
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
}; };
}; };
...@@ -1013,6 +1060,10 @@ public: ...@@ -1013,6 +1060,10 @@ public:
return pipe->tryPumpFrom(input, amount); return pipe->tryPumpFrom(input, amount);
} }
Promise<void> whenWriteDisconnected() override {
return pipe->whenWriteDisconnected();
}
private: private:
Own<AsyncPipe> pipe; Own<AsyncPipe> pipe;
UnwindDetector unwind; UnwindDetector unwind;
...@@ -1049,6 +1100,9 @@ public: ...@@ -1049,6 +1100,9 @@ public:
AsyncInputStream& input, uint64_t amount) override { AsyncInputStream& input, uint64_t amount) override {
return out->tryPumpFrom(input, amount); return out->tryPumpFrom(input, amount);
} }
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override { void shutdownWrite() override {
out->shutdownWrite(); out->shutdownWrite();
} }
......
...@@ -106,6 +106,20 @@ public: ...@@ -106,6 +106,20 @@ public:
// output stream. If it finds one, it performs the pump. Otherwise, it returns null. // output stream. If it finds one, it performs the pump. Otherwise, it returns null.
// //
// The default implementation always returns null. // The default implementation always returns null.
virtual Promise<void> whenWriteDisconnected() = 0;
// Returns a promise that resolves when the stream has become disconnected such that new write()s
// will fail with a DISCONNECTED exception. This is particularly useful, for example, to cancel
// work early when it is detected that no one will receive the result.
//
// Note that not all streams are able to detect this condition without actually performing a
// write(); such stream implementations may return a promise that never resolves. (In particular,
// as of this writing, whenWriteDisconnected() is not implemented on Windows. Also, for TCP
// streams, not all disconnects are detectable -- a power or network failure may lead the
// connection to hang forever, or until configured socket options lead to a timeout.)
//
// Unlike most other asynchronous stream methods, it is safe to call whenWriteDisconnected()
// multiple times without canceling the previous promises.
}; };
class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream {
......
...@@ -365,6 +365,13 @@ void UnixEventPort::FdObserver::fire(short events) { ...@@ -365,6 +365,13 @@ void UnixEventPort::FdObserver::fire(short events) {
} }
} }
if (events & (EPOLLHUP | EPOLLERR)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & EPOLLPRI) { if (events & EPOLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) { KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill(); f->get()->fulfill();
...@@ -398,6 +405,12 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() { ...@@ -398,6 +405,12 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
bool UnixEventPort::wait() { bool UnixEventPort::wait() {
return doEpollWait( return doEpollWait(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue)) timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
...@@ -652,6 +665,13 @@ void UnixEventPort::FdObserver::fire(short events) { ...@@ -652,6 +665,13 @@ void UnixEventPort::FdObserver::fire(short events) {
} }
} }
if (events & (POLLHUP | POLLERR | POLLNVAL)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & POLLPRI) { if (events & POLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) { KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill(); f->get()->fulfill();
...@@ -675,7 +695,16 @@ void UnixEventPort::FdObserver::fire(short events) { ...@@ -675,7 +695,16 @@ void UnixEventPort::FdObserver::fire(short events) {
short UnixEventPort::FdObserver::getEventMask() { short UnixEventPort::FdObserver::getEventMask() {
return (readFulfiller == nullptr ? 0 : (POLLIN | POLLRDHUP)) | return (readFulfiller == nullptr ? 0 : (POLLIN | POLLRDHUP)) |
(writeFulfiller == nullptr ? 0 : POLLOUT) | (writeFulfiller == nullptr ? 0 : POLLOUT) |
(urgentFulfiller == nullptr ? 0 : POLLPRI); (urgentFulfiller == nullptr ? 0 : POLLPRI) |
// The POSIX standard says POLLHUP and POLLERR will be reported even if not requested.
// But on MacOS, if `events` is 0, then POLLHUP apparently will not be reported:
// https://openradar.appspot.com/37537852
// It seems that by settingc any non-zero value -- even one documented as ignored -- we
// cause POLLHUP to be reported. Both POLLHUP and POLLERR are documented as being ignored.
// So, we'll go ahead and set them. This has no effect on non-broken OSs, causes MacOS to
// do the right thing, and sort of looks as if we're explicitly requesting notification of
// these two conditions, which we do after all want to know about.
POLLHUP | POLLERR;
} }
Promise<void> UnixEventPort::FdObserver::whenBecomesReadable() { Promise<void> UnixEventPort::FdObserver::whenBecomesReadable() {
...@@ -724,6 +753,19 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() { ...@@ -724,6 +753,19 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise); return kj::mv(paf.promise);
} }
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
if (prev == nullptr) {
KJ_DASSERT(next == nullptr);
prev = eventPort.observersTail;
*prev = this;
eventPort.observersTail = &next;
}
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
class UnixEventPort::PollContext { class UnixEventPort::PollContext {
public: public:
PollContext(FdObserver* ptr) { PollContext(FdObserver* ptr) {
......
...@@ -278,6 +278,9 @@ public: ...@@ -278,6 +278,9 @@ public:
// WARNING: This has some known weird behavior on macOS. See // WARNING: This has some known weird behavior on macOS. See
// https://github.com/sandstorm-io/capnproto/issues/374. // https://github.com/sandstorm-io/capnproto/issues/374.
Promise<void> whenWriteDisconnected();
// Resolves when poll() on the file descriptor reports POLLHUP or POLLERR.
private: private:
UnixEventPort& eventPort; UnixEventPort& eventPort;
int fd; int fd;
...@@ -286,6 +289,7 @@ private: ...@@ -286,6 +289,7 @@ private:
kj::Maybe<Own<PromiseFulfiller<void>>> readFulfiller; kj::Maybe<Own<PromiseFulfiller<void>>> readFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> writeFulfiller; kj::Maybe<Own<PromiseFulfiller<void>>> writeFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> urgentFulfiller; kj::Maybe<Own<PromiseFulfiller<void>>> urgentFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> hupFulfiller;
// Replaced each time `whenBecomesReadable()` or `whenBecomesWritable()` is called. Reverted to // Replaced each time `whenBecomesReadable()` or `whenBecomesWritable()` is called. Reverted to
// null every time an event is fired. // null every time an event is fired.
......
...@@ -126,6 +126,8 @@ public: ...@@ -126,6 +126,8 @@ public:
} }
return kj::READY_NOW; return kj::READY_NOW;
} }
Promise<void> whenWriteDisconnected() override { KJ_UNIMPLEMENTED("not used"); }
}; };
KJ_TEST("gzip decompression") { KJ_TEST("gzip decompression") {
......
...@@ -118,6 +118,8 @@ public: ...@@ -118,6 +118,8 @@ public:
Promise<void> write(const void* buffer, size_t size) override; Promise<void> write(const void* buffer, size_t size) override;
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override; Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override;
Promise<void> whenWriteDisconnected() override { return inner.whenWriteDisconnected(); }
inline Promise<void> flush() { inline Promise<void> flush() {
return pump(Z_SYNC_FLUSH); return pump(Z_SYNC_FLUSH);
} }
......
// Copyright (c) 2019 Cloudflare, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// Run http-test, but use real OS socketpairs to connect rather than using in-process pipes.
// This is essentially an integration test between KJ HTTP and KJ OS socket handling.
#define KJ_HTTP_TEST_USE_OS_PIPE 1
#include "http-test.c++"
...@@ -26,6 +26,22 @@ ...@@ -26,6 +26,22 @@
#include <kj/test.h> #include <kj/test.h>
#include <map> #include <map>
#if KJ_HTTP_TEST_USE_OS_PIPE
// Run the test using OS-level socketpairs. (See http-socketpair-test.c++.)
#define KJ_HTTP_TEST_SETUP_IO \
auto io = kj::setupAsyncIo(); \
auto& waitScope = io.waitScope
#define KJ_HTTP_TEST_CREATE_2PIPE \
io.provider->newTwoWayPipe()
#else
// Run the test using in-process two-way pipes.
#define KJ_HTTP_TEST_SETUP_IO \
kj::EventLoop eventLoop; \
kj::WaitScope waitScope(eventLoop)
#define KJ_HTTP_TEST_CREATE_2PIPE \
kj::newTwoWayPipe()
#endif
namespace kj { namespace kj {
namespace { namespace {
...@@ -288,6 +304,10 @@ public: ...@@ -288,6 +304,10 @@ public:
return inner.tryPumpFrom(input, amount); return inner.tryPumpFrom(input, amount);
} }
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override { void shutdownWrite() override {
return inner.shutdownWrite(); return inner.shutdownWrite();
} }
...@@ -428,8 +448,8 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::ArrayPtr<const byte> ...@@ -428,8 +448,8 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::ArrayPtr<const byte>
})); }));
} }
void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& testCase) { void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& testCase,
auto pipe = kj::newTwoWayPipe(); kj::TwoWayPipe pipe) {
auto serverTask = expectRead(*pipe.ends[1], testCase.raw).then([&]() { auto serverTask = expectRead(*pipe.ends[1], testCase.raw).then([&]() {
static const char SIMPLE_RESPONSE[] = static const char SIMPLE_RESPONSE[] =
...@@ -469,8 +489,7 @@ void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase& ...@@ -469,8 +489,7 @@ void testHttpClientRequest(kj::WaitScope& waitScope, const HttpRequestTestCase&
} }
void testHttpClientResponse(kj::WaitScope& waitScope, const HttpResponseTestCase& testCase, void testHttpClientResponse(kj::WaitScope& waitScope, const HttpResponseTestCase& testCase,
size_t readFragmentSize) { size_t readFragmentSize, kj::TwoWayPipe pipe) {
auto pipe = kj::newTwoWayPipe();
ReadFragmenter fragmenter(*pipe.ends[0], readFragmentSize); ReadFragmenter fragmenter(*pipe.ends[0], readFragmentSize);
auto expectedReqText = testCase.method == HttpMethod::GET || testCase.method == HttpMethod::HEAD auto expectedReqText = testCase.method == HttpMethod::GET || testCase.method == HttpMethod::HEAD
...@@ -610,9 +629,8 @@ private: ...@@ -610,9 +629,8 @@ private:
void testHttpServerRequest(kj::WaitScope& waitScope, kj::Timer& timer, void testHttpServerRequest(kj::WaitScope& waitScope, kj::Timer& timer,
const HttpRequestTestCase& requestCase, const HttpRequestTestCase& requestCase,
const HttpResponseTestCase& responseCase) { const HttpResponseTestCase& responseCase,
auto pipe = kj::newTwoWayPipe(); kj::TwoWayPipe pipe) {
HttpHeaderTable table; HttpHeaderTable table;
TestHttpService service(requestCase, responseCase, table); TestHttpService service(requestCase, responseCase, table);
HttpServer server(timer, table, service); HttpServer server(timer, table, service);
...@@ -857,35 +875,32 @@ kj::ArrayPtr<const HttpResponseTestCase> responseTestCases() { ...@@ -857,35 +875,32 @@ kj::ArrayPtr<const HttpResponseTestCase> responseTestCases() {
} }
KJ_TEST("HttpClient requests") { KJ_TEST("HttpClient requests") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
for (auto& testCase: requestTestCases()) { for (auto& testCase: requestTestCases()) {
if (testCase.side == SERVER_ONLY) continue; if (testCase.side == SERVER_ONLY) continue;
KJ_CONTEXT(testCase.raw); KJ_CONTEXT(testCase.raw);
testHttpClientRequest(waitScope, testCase); testHttpClientRequest(waitScope, testCase, KJ_HTTP_TEST_CREATE_2PIPE);
} }
} }
KJ_TEST("HttpClient responses") { KJ_TEST("HttpClient responses") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
size_t FRAGMENT_SIZES[] = { 1, 2, 3, 4, 5, 6, 7, 8, 16, 31, kj::maxValue }; size_t FRAGMENT_SIZES[] = { 1, 2, 3, 4, 5, 6, 7, 8, 16, 31, kj::maxValue };
for (auto& testCase: responseTestCases()) { for (auto& testCase: responseTestCases()) {
if (testCase.side == SERVER_ONLY) continue; if (testCase.side == SERVER_ONLY) continue;
for (size_t fragmentSize: FRAGMENT_SIZES) { for (size_t fragmentSize: FRAGMENT_SIZES) {
KJ_CONTEXT(testCase.raw, fragmentSize); KJ_CONTEXT(testCase.raw, fragmentSize);
testHttpClientResponse(waitScope, testCase, fragmentSize); testHttpClientResponse(waitScope, testCase, fragmentSize, KJ_HTTP_TEST_CREATE_2PIPE);
} }
} }
} }
KJ_TEST("HttpClient canceled write") { KJ_TEST("HttpClient canceled write") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText(); auto serverPromise = pipe.ends[1]->readAllText();
...@@ -918,10 +933,9 @@ KJ_TEST("HttpClient canceled write") { ...@@ -918,10 +933,9 @@ KJ_TEST("HttpClient canceled write") {
} }
KJ_TEST("HttpClient chunked body gather-write") { KJ_TEST("HttpClient chunked body gather-write") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText(); auto serverPromise = pipe.ends[1]->readAllText();
...@@ -969,10 +983,9 @@ KJ_TEST("HttpClient chunked body pump from fixed length stream") { ...@@ -969,10 +983,9 @@ KJ_TEST("HttpClient chunked body pump from fixed length stream") {
kj::StringPtr body = "foo bar baz"; kj::StringPtr body = "foo bar baz";
}; };
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto serverPromise = pipe.ends[1]->readAllText(); auto serverPromise = pipe.ends[1]->readAllText();
...@@ -1021,15 +1034,15 @@ KJ_TEST("HttpServer requests") { ...@@ -1021,15 +1034,15 @@ KJ_TEST("HttpServer requests") {
3, {"foo"} 3, {"foo"}
}; };
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
for (auto& testCase: requestTestCases()) { for (auto& testCase: requestTestCases()) {
if (testCase.side == CLIENT_ONLY) continue; if (testCase.side == CLIENT_ONLY) continue;
KJ_CONTEXT(testCase.raw); KJ_CONTEXT(testCase.raw);
testHttpServerRequest(waitScope, timer, testCase, testHttpServerRequest(waitScope, timer, testCase,
testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE); testCase.method == HttpMethod::HEAD ? HEAD_RESPONSE : RESPONSE,
KJ_HTTP_TEST_CREATE_2PIPE);
} }
} }
...@@ -1054,15 +1067,15 @@ KJ_TEST("HttpServer responses") { ...@@ -1054,15 +1067,15 @@ KJ_TEST("HttpServer responses") {
uint64_t(0), {}, uint64_t(0), {},
}; };
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
for (auto& testCase: responseTestCases()) { for (auto& testCase: responseTestCases()) {
if (testCase.side == CLIENT_ONLY) continue; if (testCase.side == CLIENT_ONLY) continue;
KJ_CONTEXT(testCase.raw); KJ_CONTEXT(testCase.raw);
testHttpServerRequest(waitScope, timer, testHttpServerRequest(waitScope, timer,
testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase); testCase.method == HttpMethod::HEAD ? HEAD_REQUEST : REQUEST, testCase,
KJ_HTTP_TEST_CREATE_2PIPE);
} }
} }
...@@ -1217,9 +1230,8 @@ kj::ArrayPtr<const HttpTestCase> pipelineTestCases() { ...@@ -1217,9 +1230,8 @@ kj::ArrayPtr<const HttpTestCase> pipelineTestCases() {
KJ_TEST("HttpClient pipeline") { KJ_TEST("HttpClient pipeline") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
kj::Promise<void> writeResponsesPromise = kj::READY_NOW; kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
for (auto& testCase: PIPELINE_TESTS) { for (auto& testCase: PIPELINE_TESTS) {
...@@ -1247,9 +1259,8 @@ KJ_TEST("HttpClient pipeline") { ...@@ -1247,9 +1259,8 @@ KJ_TEST("HttpClient pipeline") {
KJ_TEST("HttpClient parallel pipeline") { KJ_TEST("HttpClient parallel pipeline") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
kj::Promise<void> readRequestsPromise = kj::READY_NOW; kj::Promise<void> readRequestsPromise = kj::READY_NOW;
kj::Promise<void> writeResponsesPromise = kj::READY_NOW; kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
...@@ -1311,10 +1322,9 @@ KJ_TEST("HttpClient parallel pipeline") { ...@@ -1311,10 +1322,9 @@ KJ_TEST("HttpClient parallel pipeline") {
KJ_TEST("HttpServer pipeline") { KJ_TEST("HttpServer pipeline") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table); TestHttpService service(PIPELINE_TESTS, table);
...@@ -1340,10 +1350,9 @@ KJ_TEST("HttpServer pipeline") { ...@@ -1340,10 +1350,9 @@ KJ_TEST("HttpServer pipeline") {
KJ_TEST("HttpServer parallel pipeline") { KJ_TEST("HttpServer parallel pipeline") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto allRequestText = auto allRequestText =
kj::strArray(KJ_MAP(testCase, PIPELINE_TESTS) { return testCase.request.raw; }, ""); kj::strArray(KJ_MAP(testCase, PIPELINE_TESTS) { return testCase.request.raw; }, "");
...@@ -1370,10 +1379,9 @@ KJ_TEST("HttpServer parallel pipeline") { ...@@ -1370,10 +1379,9 @@ KJ_TEST("HttpServer parallel pipeline") {
KJ_TEST("HttpClient <-> HttpServer") { KJ_TEST("HttpClient <-> HttpServer") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table); TestHttpService service(PIPELINE_TESTS, table);
...@@ -1395,8 +1403,7 @@ KJ_TEST("HttpClient <-> HttpServer") { ...@@ -1395,8 +1403,7 @@ KJ_TEST("HttpClient <-> HttpServer") {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
KJ_TEST("HttpInputStream requests") { KJ_TEST("HttpInputStream requests") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::HttpHeaderTable table; kj::HttpHeaderTable table;
...@@ -1434,8 +1441,7 @@ KJ_TEST("HttpInputStream requests") { ...@@ -1434,8 +1441,7 @@ KJ_TEST("HttpInputStream requests") {
} }
KJ_TEST("HttpInputStream responses") { KJ_TEST("HttpInputStream responses") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::HttpHeaderTable table; kj::HttpHeaderTable table;
...@@ -1475,8 +1481,7 @@ KJ_TEST("HttpInputStream responses") { ...@@ -1475,8 +1481,7 @@ KJ_TEST("HttpInputStream responses") {
} }
KJ_TEST("HttpInputStream bare messages") { KJ_TEST("HttpInputStream bare messages") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::HttpHeaderTable table; kj::HttpHeaderTable table;
...@@ -1530,9 +1535,8 @@ KJ_TEST("HttpInputStream bare messages") { ...@@ -1530,9 +1535,8 @@ KJ_TEST("HttpInputStream bare messages") {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
KJ_TEST("WebSocket core protocol") { KJ_TEST("WebSocket core protocol") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr); auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1591,9 +1595,8 @@ KJ_TEST("WebSocket core protocol") { ...@@ -1591,9 +1595,8 @@ KJ_TEST("WebSocket core protocol") {
} }
KJ_TEST("WebSocket fragmented") { KJ_TEST("WebSocket fragmented") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1629,9 +1632,8 @@ public: ...@@ -1629,9 +1632,8 @@ public:
}; };
KJ_TEST("WebSocket masked") { KJ_TEST("WebSocket masked") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
FakeEntropySource maskGenerator; FakeEntropySource maskGenerator;
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
...@@ -1657,9 +1659,8 @@ KJ_TEST("WebSocket masked") { ...@@ -1657,9 +1659,8 @@ KJ_TEST("WebSocket masked") {
} }
KJ_TEST("WebSocket unsolicited pong") { KJ_TEST("WebSocket unsolicited pong") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1684,9 +1685,8 @@ KJ_TEST("WebSocket unsolicited pong") { ...@@ -1684,9 +1685,8 @@ KJ_TEST("WebSocket unsolicited pong") {
} }
KJ_TEST("WebSocket ping") { KJ_TEST("WebSocket ping") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1722,9 +1722,8 @@ KJ_TEST("WebSocket ping") { ...@@ -1722,9 +1722,8 @@ KJ_TEST("WebSocket ping") {
} }
KJ_TEST("WebSocket ping mid-send") { KJ_TEST("WebSocket ping mid-send") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1791,6 +1790,10 @@ public: ...@@ -1791,6 +1790,10 @@ public:
return out->tryPumpFrom(input, amount); return out->tryPumpFrom(input, amount);
} }
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override { void shutdownWrite() override {
out = nullptr; out = nullptr;
} }
...@@ -1801,8 +1804,7 @@ private: ...@@ -1801,8 +1804,7 @@ private:
}; };
KJ_TEST("WebSocket double-ping mid-send") { KJ_TEST("WebSocket double-ping mid-send") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
auto upPipe = newOneWayPipe(); auto upPipe = newOneWayPipe();
auto downPipe = newOneWayPipe(); auto downPipe = newOneWayPipe();
...@@ -1839,9 +1841,8 @@ KJ_TEST("WebSocket double-ping mid-send") { ...@@ -1839,9 +1841,8 @@ KJ_TEST("WebSocket double-ping mid-send") {
} }
KJ_TEST("WebSocket ping received during pong send") { KJ_TEST("WebSocket ping received during pong send") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]); auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr); auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
...@@ -1875,10 +1876,9 @@ KJ_TEST("WebSocket ping received during pong send") { ...@@ -1875,10 +1876,9 @@ KJ_TEST("WebSocket ping received during pong send") {
} }
KJ_TEST("WebSocket pump disconnect on send") { KJ_TEST("WebSocket pump disconnect on send") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe1 = kj::newTwoWayPipe(); auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe2 = kj::newTwoWayPipe();
auto client1 = newWebSocket(kj::mv(pipe1.ends[0]), nullptr); auto client1 = newWebSocket(kj::mv(pipe1.ends[0]), nullptr);
auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr);
...@@ -1900,10 +1900,9 @@ KJ_TEST("WebSocket pump disconnect on send") { ...@@ -1900,10 +1900,9 @@ KJ_TEST("WebSocket pump disconnect on send") {
} }
KJ_TEST("WebSocket pump disconnect on receive") { KJ_TEST("WebSocket pump disconnect on receive") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe1 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe1 = kj::newTwoWayPipe(); auto pipe2 = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe2 = kj::newTwoWayPipe();
auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr); auto server1 = newWebSocket(kj::mv(pipe1.ends[1]), nullptr);
auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), nullptr); auto client2 = newWebSocket(kj::mv(pipe2.ends[0]), nullptr);
...@@ -2053,9 +2052,8 @@ void testWebSocketClient(kj::WaitScope& waitScope, HttpHeaderTable& headerTable, ...@@ -2053,9 +2052,8 @@ void testWebSocketClient(kj::WaitScope& waitScope, HttpHeaderTable& headerTable,
} }
KJ_TEST("HttpClient WebSocket handshake") { KJ_TEST("HttpClient WebSocket handshake") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
...@@ -2084,9 +2082,8 @@ KJ_TEST("HttpClient WebSocket handshake") { ...@@ -2084,9 +2082,8 @@ KJ_TEST("HttpClient WebSocket handshake") {
} }
KJ_TEST("HttpClient WebSocket error") { KJ_TEST("HttpClient WebSocket error") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto pipe = kj::newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
...@@ -2131,10 +2128,9 @@ KJ_TEST("HttpClient WebSocket error") { ...@@ -2131,10 +2128,9 @@ KJ_TEST("HttpClient WebSocket error") {
} }
KJ_TEST("HttpServer WebSocket handshake") { KJ_TEST("HttpServer WebSocket handshake") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder; HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
...@@ -2158,10 +2154,9 @@ KJ_TEST("HttpServer WebSocket handshake") { ...@@ -2158,10 +2154,9 @@ KJ_TEST("HttpServer WebSocket handshake") {
} }
KJ_TEST("HttpServer WebSocket handshake error") { KJ_TEST("HttpServer WebSocket handshake error") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder; HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
...@@ -2189,10 +2184,9 @@ KJ_TEST("HttpServer WebSocket handshake error") { ...@@ -2189,10 +2184,9 @@ KJ_TEST("HttpServer WebSocket handshake error") {
KJ_TEST("HttpServer request timeout") { KJ_TEST("HttpServer request timeout") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table); TestHttpService service(PIPELINE_TESTS, table);
...@@ -2215,10 +2209,9 @@ KJ_TEST("HttpServer request timeout") { ...@@ -2215,10 +2209,9 @@ KJ_TEST("HttpServer request timeout") {
KJ_TEST("HttpServer pipeline timeout") { KJ_TEST("HttpServer pipeline timeout") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
TestHttpService service(PIPELINE_TESTS, table); TestHttpService service(PIPELINE_TESTS, table);
...@@ -2269,10 +2262,9 @@ private: ...@@ -2269,10 +2262,9 @@ private:
KJ_TEST("HttpServer no response") { KJ_TEST("HttpServer no response") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
BrokenHttpService service; BrokenHttpService service;
...@@ -2297,10 +2289,9 @@ KJ_TEST("HttpServer no response") { ...@@ -2297,10 +2289,9 @@ KJ_TEST("HttpServer no response") {
KJ_TEST("HttpServer disconnected") { KJ_TEST("HttpServer disconnected") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected")); BrokenHttpService service(KJ_EXCEPTION(DISCONNECTED, "disconnected"));
...@@ -2319,10 +2310,9 @@ KJ_TEST("HttpServer disconnected") { ...@@ -2319,10 +2310,9 @@ KJ_TEST("HttpServer disconnected") {
KJ_TEST("HttpServer overloaded") { KJ_TEST("HttpServer overloaded") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded")); BrokenHttpService service(KJ_EXCEPTION(OVERLOADED, "overloaded"));
...@@ -2341,10 +2331,9 @@ KJ_TEST("HttpServer overloaded") { ...@@ -2341,10 +2331,9 @@ KJ_TEST("HttpServer overloaded") {
KJ_TEST("HttpServer unimplemented") { KJ_TEST("HttpServer unimplemented") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented")); BrokenHttpService service(KJ_EXCEPTION(UNIMPLEMENTED, "unimplemented"));
...@@ -2363,10 +2352,9 @@ KJ_TEST("HttpServer unimplemented") { ...@@ -2363,10 +2352,9 @@ KJ_TEST("HttpServer unimplemented") {
KJ_TEST("HttpServer threw exception") { KJ_TEST("HttpServer threw exception") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed")); BrokenHttpService service(KJ_EXCEPTION(FAILED, "failed"));
...@@ -2407,10 +2395,9 @@ private: ...@@ -2407,10 +2395,9 @@ private:
KJ_TEST("HttpServer threw exception after starting response") { KJ_TEST("HttpServer threw exception after starting response") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
PartialResponseService service; PartialResponseService service;
...@@ -2455,10 +2442,9 @@ private: ...@@ -2455,10 +2442,9 @@ private:
KJ_TEST("HttpServer failed to write complete response but didn't throw") { KJ_TEST("HttpServer failed to write complete response but didn't throw") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
PartialResponseNoThrowService service; PartialResponseNoThrowService service;
...@@ -2526,10 +2512,9 @@ private: ...@@ -2526,10 +2512,9 @@ private:
KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table; HttpHeaderTable table;
PumpResponseService service; PumpResponseService service;
...@@ -2550,16 +2535,72 @@ KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") { ...@@ -2550,16 +2535,72 @@ KJ_TEST("HttpFixedLengthEntityWriter correctly implements tryPumpFrom") {
"Hello, World!", text); "Hello, World!", text);
} }
class HangingHttpService final: public HttpService {
// HttpService that hangs forever.
public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& responseSender) override {
kj::Promise<void> result = kj::NEVER_DONE;
++inFlight;
return result.attach(kj::defer([this]() {
if (--inFlight == 0) {
KJ_IF_MAYBE(f, onCancelFulfiller) {
f->get()->fulfill();
}
}
}));
}
kj::Promise<void> onCancel() {
auto paf = kj::newPromiseAndFulfiller<void>();
onCancelFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
uint inFlight = 0;
private:
kj::Maybe<kj::Exception> exception;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> onCancelFulfiller;
};
KJ_TEST("HttpServer cancels request when client disconnects") {
KJ_HTTP_TEST_SETUP_IO;
kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable table;
HangingHttpService service;
HttpServer server(timer, table, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
KJ_EXPECT(service.inFlight == 0);
static constexpr auto request = "GET / HTTP/1.1\r\n\r\n"_kj;
pipe.ends[1]->write(request.begin(), request.size()).wait(waitScope);
auto cancelPromise = service.onCancel();
KJ_EXPECT(!cancelPromise.poll(waitScope));
KJ_EXPECT(service.inFlight == 1);
// Disconnect client and verify server cancels.
pipe.ends[1] = nullptr;
KJ_ASSERT(cancelPromise.poll(waitScope));
KJ_EXPECT(service.inFlight == 0);
cancelPromise.wait(waitScope);
}
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
KJ_TEST("newHttpService from HttpClient") { KJ_TEST("newHttpService from HttpClient") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe(); auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = kj::newTwoWayPipe(); auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
kj::Promise<void> writeResponsesPromise = kj::READY_NOW; kj::Promise<void> writeResponsesPromise = kj::READY_NOW;
for (auto& testCase: PIPELINE_TESTS) { for (auto& testCase: PIPELINE_TESTS) {
...@@ -2596,11 +2637,10 @@ KJ_TEST("newHttpService from HttpClient") { ...@@ -2596,11 +2637,10 @@ KJ_TEST("newHttpService from HttpClient") {
} }
KJ_TEST("newHttpService from HttpClient WebSockets") { KJ_TEST("newHttpService from HttpClient WebSockets") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe(); auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = kj::newTwoWayPipe(); auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
...@@ -2639,11 +2679,10 @@ KJ_TEST("newHttpService from HttpClient WebSockets") { ...@@ -2639,11 +2679,10 @@ KJ_TEST("newHttpService from HttpClient WebSockets") {
} }
KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { KJ_TEST("newHttpService from HttpClient WebSockets disconnect") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto frontPipe = kj::newTwoWayPipe(); auto frontPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto backPipe = kj::newTwoWayPipe(); auto backPipe = KJ_HTTP_TEST_CREATE_2PIPE;
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE); auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request) auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
...@@ -2683,8 +2722,7 @@ KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { ...@@ -2683,8 +2722,7 @@ KJ_TEST("newHttpService from HttpClient WebSockets disconnect") {
KJ_TEST("newHttpClient from HttpService") { KJ_TEST("newHttpClient from HttpService") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
HttpHeaderTable table; HttpHeaderTable table;
...@@ -2697,10 +2735,9 @@ KJ_TEST("newHttpClient from HttpService") { ...@@ -2697,10 +2735,9 @@ KJ_TEST("newHttpClient from HttpService") {
} }
KJ_TEST("newHttpClient from HttpService WebSockets") { KJ_TEST("newHttpClient from HttpService WebSockets") {
kj::EventLoop eventLoop; KJ_HTTP_TEST_SETUP_IO;
kj::WaitScope waitScope(eventLoop);
kj::TimerImpl timer(kj::origin<kj::TimePoint>()); kj::TimerImpl timer(kj::origin<kj::TimePoint>());
auto pipe = kj::newTwoWayPipe(); auto pipe = KJ_HTTP_TEST_CREATE_2PIPE;
HttpHeaderTable::Builder tableBuilder; HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header"); HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
...@@ -2746,6 +2783,9 @@ public: ...@@ -2746,6 +2783,9 @@ public:
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
return inner->tryPumpFrom(input, amount); return inner->tryPumpFrom(input, amount);
} }
Promise<void> whenWriteDisconnected() override {
return inner->whenWriteDisconnected();
}
void shutdownWrite() override { void shutdownWrite() override {
return inner->shutdownWrite(); return inner->shutdownWrite();
} }
......
...@@ -1760,6 +1760,10 @@ public: ...@@ -1760,6 +1760,10 @@ public:
return fork.addBranch(); return fork.addBranch();
} }
Promise<void> whenWriteDisconnected() {
return inner.whenWriteDisconnected();
}
private: private:
AsyncOutputStream& inner; AsyncOutputStream& inner;
kj::Promise<void> writeQueue = kj::READY_NOW; kj::Promise<void> writeQueue = kj::READY_NOW;
...@@ -1787,6 +1791,9 @@ public: ...@@ -1787,6 +1791,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()"); return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
} }
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
}; };
class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream { class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream {
...@@ -1797,6 +1804,9 @@ public: ...@@ -1797,6 +1804,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW; return kj::READY_NOW;
} }
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
}; };
class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream { class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream {
...@@ -1877,6 +1887,10 @@ public: ...@@ -1877,6 +1887,10 @@ public:
return kj::mv(promise); return kj::mv(promise);
} }
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private: private:
HttpOutputStream& inner; HttpOutputStream& inner;
uint64_t length; uint64_t length;
...@@ -1960,6 +1974,10 @@ public: ...@@ -1960,6 +1974,10 @@ public:
} }
} }
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private: private:
HttpOutputStream& inner; HttpOutputStream& inner;
}; };
...@@ -2022,6 +2040,18 @@ public: ...@@ -2022,6 +2040,18 @@ public:
return kj::READY_NOW; return kj::READY_NOW;
} }
void abort() override {
queuedPong = nullptr;
sendingPong = nullptr;
disconnected = true;
stream->abortRead();
stream->shutdownWrite();
}
kj::Promise<void> whenAborted() override {
return stream->whenWriteDisconnected();
}
kj::Promise<Message> receive() override { kj::Promise<Message> receive() override {
size_t headerSize = Header::headerSize(recvData.begin(), recvData.size()); size_t headerSize = Header::headerSize(recvData.begin(), recvData.size());
...@@ -2506,7 +2536,12 @@ kj::Promise<void> WebSocket::pumpTo(WebSocket& other) { ...@@ -2506,7 +2536,12 @@ kj::Promise<void> WebSocket::pumpTo(WebSocket& other) {
} else { } else {
// Fall back to default implementation. // Fall back to default implementation.
return kj::evalNow([&]() { return kj::evalNow([&]() {
return pumpWebSocketLoop(*this, other); auto cancelPromise = other.whenAborted().then([this]() -> kj::Promise<void> {
this->abort();
return KJ_EXCEPTION(DISCONNECTED,
"destination of WebSocket pump disconnected prematurely");
});
return pumpWebSocketLoop(*this, other).exclusiveJoin(kj::mv(cancelPromise));
}); });
} }
} }
...@@ -2517,12 +2552,7 @@ kj::Maybe<kj::Promise<void>> WebSocket::tryPumpFrom(WebSocket& other) { ...@@ -2517,12 +2552,7 @@ kj::Maybe<kj::Promise<void>> WebSocket::tryPumpFrom(WebSocket& other) {
namespace { namespace {
class AbortableWebSocket: public WebSocket { class WebSocketPipeImpl final: public WebSocket, public kj::Refcounted {
public:
virtual void abort() = 0;
};
class WebSocketPipeImpl final: public AbortableWebSocket, public kj::Refcounted {
// Represents one direction of a WebSocket pipe. // Represents one direction of a WebSocket pipe.
// //
// This class behaves as a "loopback" WebSocket: a message sent using send() is received using // This class behaves as a "loopback" WebSocket: a message sent using send() is received using
...@@ -2548,6 +2578,12 @@ public: ...@@ -2548,6 +2578,12 @@ public:
} else { } else {
ownState = heap<Aborted>(); ownState = heap<Aborted>();
state = *ownState; state = *ownState;
aborted = true;
KJ_IF_MAYBE(f, abortedFulfiller) {
f->get()->fulfill();
abortedFulfiller = nullptr;
}
} }
} }
...@@ -2581,6 +2617,20 @@ public: ...@@ -2581,6 +2617,20 @@ public:
return kj::READY_NOW; return kj::READY_NOW;
} }
} }
kj::Promise<void> whenAborted() override {
if (aborted) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(p, abortedPromise) {
return p->addBranch();
} else {
auto paf = newPromiseAndFulfiller<void>();
abortedFulfiller = kj::mv(paf.fulfiller);
auto fork = paf.promise.fork();
auto result = fork.addBranch();
abortedPromise = kj::mv(fork);
return result;
}
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override { kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
KJ_IF_MAYBE(s, state) { KJ_IF_MAYBE(s, state) {
return s->tryPumpFrom(other); return s->tryPumpFrom(other);
...@@ -2605,12 +2655,16 @@ public: ...@@ -2605,12 +2655,16 @@ public:
} }
private: private:
kj::Maybe<AbortableWebSocket&> state; kj::Maybe<WebSocket&> state;
// Object-oriented state! If any method call is blocked waiting on activity from the other end, // Object-oriented state! If any method call is blocked waiting on activity from the other end,
// then `state` is non-null and method calls should be forwarded to it. If no calls are // then `state` is non-null and method calls should be forwarded to it. If no calls are
// outstanding, `state` is null. // outstanding, `state` is null.
kj::Own<AbortableWebSocket> ownState; kj::Own<WebSocket> ownState;
bool aborted = false;
Maybe<Own<PromiseFulfiller<void>>> abortedFulfiller = nullptr;
Maybe<ForkedPromise<void>> abortedPromise = nullptr;
void endState(WebSocket& obj) { void endState(WebSocket& obj) {
KJ_IF_MAYBE(s, state) { KJ_IF_MAYBE(s, state) {
...@@ -2626,7 +2680,7 @@ private: ...@@ -2626,7 +2680,7 @@ private:
}; };
typedef kj::OneOf<kj::ArrayPtr<const char>, kj::ArrayPtr<const byte>, ClosePtr> MessagePtr; typedef kj::OneOf<kj::ArrayPtr<const char>, kj::ArrayPtr<const byte>, ClosePtr> MessagePtr;
class BlockedSend final: public AbortableWebSocket { class BlockedSend final: public WebSocket {
public: public:
BlockedSend(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message) BlockedSend(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message)
: fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) { : fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) {
...@@ -2643,6 +2697,9 @@ private: ...@@ -2643,6 +2697,9 @@ private:
pipe.endState(*this); pipe.endState(*this);
pipe.abort(); pipe.abort();
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress"); KJ_FAIL_ASSERT("another message send is already in progress");
...@@ -2713,7 +2770,7 @@ private: ...@@ -2713,7 +2770,7 @@ private:
Canceler canceler; Canceler canceler;
}; };
class BlockedPumpFrom final: public AbortableWebSocket { class BlockedPumpFrom final: public WebSocket {
public: public:
BlockedPumpFrom(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, BlockedPumpFrom(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe,
WebSocket& input) WebSocket& input)
...@@ -2731,6 +2788,9 @@ private: ...@@ -2731,6 +2788,9 @@ private:
pipe.endState(*this); pipe.endState(*this);
pipe.abort(); pipe.abort();
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress"); KJ_FAIL_ASSERT("another message send is already in progress");
...@@ -2788,7 +2848,7 @@ private: ...@@ -2788,7 +2848,7 @@ private:
Canceler canceler; Canceler canceler;
}; };
class BlockedReceive final: public AbortableWebSocket { class BlockedReceive final: public WebSocket {
public: public:
BlockedReceive(kj::PromiseFulfiller<Message>& fulfiller, WebSocketPipeImpl& pipe) BlockedReceive(kj::PromiseFulfiller<Message>& fulfiller, WebSocketPipeImpl& pipe)
: fulfiller(fulfiller), pipe(pipe) { : fulfiller(fulfiller), pipe(pipe) {
...@@ -2805,6 +2865,9 @@ private: ...@@ -2805,6 +2865,9 @@ private:
pipe.endState(*this); pipe.endState(*this);
pipe.abort(); pipe.abort();
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_REQUIRE(canceler.isEmpty(), "already pumping"); KJ_REQUIRE(canceler.isEmpty(), "already pumping");
...@@ -2860,7 +2923,7 @@ private: ...@@ -2860,7 +2923,7 @@ private:
Canceler canceler; Canceler canceler;
}; };
class BlockedPumpTo final: public AbortableWebSocket { class BlockedPumpTo final: public WebSocket {
public: public:
BlockedPumpTo(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output) BlockedPumpTo(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output)
: fulfiller(fulfiller), pipe(pipe), output(output) { : fulfiller(fulfiller), pipe(pipe), output(output) {
...@@ -2881,6 +2944,9 @@ private: ...@@ -2881,6 +2944,9 @@ private:
pipe.endState(*this); pipe.endState(*this);
pipe.abort(); pipe.abort();
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress"); KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
...@@ -2926,11 +2992,14 @@ private: ...@@ -2926,11 +2992,14 @@ private:
Canceler canceler; Canceler canceler;
}; };
class Disconnected final: public AbortableWebSocket { class Disconnected final: public WebSocket {
public: public:
void abort() override { void abort() override {
// can ignore // can ignore
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_REQUIRE("can't send() after disconnect()"); KJ_FAIL_REQUIRE("can't send() after disconnect()");
...@@ -2956,11 +3025,14 @@ private: ...@@ -2956,11 +3025,14 @@ private:
} }
}; };
class Aborted final: public AbortableWebSocket { class Aborted final: public WebSocket {
public: public:
void abort() override { void abort() override {
// can ignore // can ignore
} }
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"); return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
...@@ -3009,6 +3081,13 @@ public: ...@@ -3009,6 +3081,13 @@ public:
kj::Promise<void> disconnect() override { kj::Promise<void> disconnect() override {
return out->disconnect(); return out->disconnect();
} }
void abort() override {
in->abort();
out->abort();
}
kj::Promise<void> whenAborted() override {
return out->whenAborted();
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override { kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
return out->tryPumpFrom(other); return out->tryPumpFrom(other);
} }
...@@ -3382,6 +3461,22 @@ public: ...@@ -3382,6 +3461,22 @@ public:
} }
} }
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
void shutdownWrite() override { void shutdownWrite() override {
KJ_IF_MAYBE(s, stream) { KJ_IF_MAYBE(s, stream) {
return s->get()->shutdownWrite(); return s->get()->shutdownWrite();
...@@ -3456,6 +3551,22 @@ public: ...@@ -3456,6 +3551,22 @@ public:
} }
} }
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(s, stream) {
return s->get()->whenWriteDisconnected();
} else {
return promise.addBranch().then([this]() {
return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected();
}, [](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return kj::READY_NOW;
} else {
return kj::mv(e);
}
});
}
}
public: public:
kj::ForkedPromise<void> promise; kj::ForkedPromise<void> promise;
kj::Maybe<kj::Own<AsyncOutputStream>> stream; kj::Maybe<kj::Own<AsyncOutputStream>> stream;
...@@ -3865,6 +3976,9 @@ public: ...@@ -3865,6 +3976,9 @@ public:
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW; return kj::READY_NOW;
} }
Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
// We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method.
}; };
...@@ -4079,6 +4193,13 @@ private: ...@@ -4079,6 +4193,13 @@ private:
kj::Promise<void> disconnect() override { kj::Promise<void> disconnect() override {
return inner->disconnect(); return inner->disconnect();
} }
void abort() override {
// Don't need to worry about completion task in this case -- cancelling it is reasonable.
inner->abort();
}
kj::Promise<void> whenAborted() override {
return inner->whenAborted();
}
kj::Promise<Message> receive() override { kj::Promise<Message> receive() override {
return inner->receive().then([this](Message&& message) -> kj::Promise<Message> { return inner->receive().then([this](Message&& message) -> kj::Promise<Message> {
if (message.is<WebSocket::Close>()) { if (message.is<WebSocket::Close>()) {
...@@ -4712,6 +4833,12 @@ private: ...@@ -4712,6 +4833,12 @@ private:
kj::Promise<void> disconnect() override { kj::Promise<void> disconnect() override {
return kj::cp(exception); return kj::cp(exception);
} }
void abort() override {
kj::throwRecoverableException(kj::cp(exception));
}
kj::Promise<void> whenAborted() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override { kj::Promise<Message> receive() override {
return kj::cp(exception); return kj::cp(exception);
} }
...@@ -4796,7 +4923,10 @@ kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection ...@@ -4796,7 +4923,10 @@ kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection
} }
} }
auto promise = obj->loop(true); // Start reading requests and responding to them, but immediately cancel processing if the client
// disconnects.
auto promise = obj->loop(true)
.exclusiveJoin(connection.whenWriteDisconnected().then([]() {return false;}));
// Eagerly evaluate so that we drop the connection when the promise resolves, even if the caller // Eagerly evaluate so that we drop the connection when the promise resolves, even if the caller
// doesn't eagerly evaluate. // doesn't eagerly evaluate.
......
...@@ -489,6 +489,17 @@ public: ...@@ -489,6 +489,17 @@ public:
// shutdown, but is sometimes useful when you want the other end to trigger whatever behavior // shutdown, but is sometimes useful when you want the other end to trigger whatever behavior
// it normally triggers when a connection is dropped. // it normally triggers when a connection is dropped.
virtual void abort() = 0;
// Forcefully close this WebSocket, such that the remote end should get a DISCONNECTED error if
// it continues to write. This differs from disconnect(), which only closes the sending
// direction, but still allows receives.
virtual kj::Promise<void> whenAborted() = 0;
// Resolves when the remote side aborts the connection such that send() would throw DISCONNECTED,
// if this can be detected without actually writing a message. (If not, this promise never
// resolves, but send() or receive() will throw DISCONNECTED when appropriate. See also
// kj::AsyncOutputStream::whenWriteDisconnected().)
struct Close { struct Close {
uint16_t code; uint16_t code;
kj::String reason; kj::String reason;
...@@ -629,6 +640,9 @@ public: ...@@ -629,6 +640,9 @@ public:
// //
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned // `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first. // promise resolves, whichever comes first.
//
// Request processing can be canceled by dropping the returned promise. HttpServer may do so if
// the client disconnects prematurely.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host); virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
......
...@@ -180,6 +180,10 @@ public: ...@@ -180,6 +180,10 @@ public:
return writeInternal(pieces[0], pieces.slice(1, pieces.size())); return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
} }
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override { void shutdownWrite() override {
KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()"); KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment