Commit d32b388c authored by Kenton Varda's avatar Kenton Varda

Add AsyncOutputStream::whenWriteDisconnected().

This will allow the HTTP server to detect when a client has disconnected and cancel processing the response.
parent 3fcd0f46
......@@ -2103,5 +2103,94 @@ KJ_TEST("Userland tee buffer size limit") {
}
}
KJ_TEST("Userspace OneWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newOneWayPipe();
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
KJ_TEST("Userspace TwoWayPipe whenWriteDisconnected()") {
kj::EventLoop loop;
WaitScope ws(loop);
auto pipe = newTwoWayPipe();
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(ws));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(ws));
abortedPromise.wait(ws);
}
#if !_WIN32 // IOCP doesn't support detecting disconnect AFAICT.
KJ_TEST("OS OneWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newOneWayPipe();
pipe.out->write("foo", 3).wait(io.waitScope);
auto abortedPromise = pipe.out->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.in = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
}
KJ_TEST("OS TwoWayPipe whenWriteDisconnected()") {
auto io = setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
pipe.ends[0]->write("foo", 3).wait(io.waitScope);
pipe.ends[1]->write("bar", 3).wait(io.waitScope);
auto abortedPromise = pipe.ends[0]->whenWriteDisconnected();
KJ_ASSERT(!abortedPromise.poll(io.waitScope));
pipe.ends[1] = nullptr;
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(pipe.ends[0]->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "bar"_kj);
}
KJ_TEST("import socket FD that's already broken") {
auto io = setupAsyncIo();
int fds[2];
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
KJ_SYSCALL(write(fds[1], "foo", 3));
KJ_SYSCALL(close(fds[1]));
auto stream = io.lowLevelProvider->wrapSocketFd(fds[0], LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
auto abortedPromise = stream->whenWriteDisconnected();
KJ_ASSERT(abortedPromise.poll(io.waitScope));
abortedPromise.wait(io.waitScope);
char buffer[4];
KJ_ASSERT(stream->tryRead(&buffer, sizeof(buffer), sizeof(buffer)).wait(io.waitScope) == 3);
buffer[3] = '\0';
KJ_EXPECT(buffer == "foo"_kj);
}
#endif
} // namespace
} // namespace kj
......@@ -179,6 +179,17 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
KJ_IF_MAYBE(p, writeDisconnectedPromise) {
return p->addBranch();
} else {
auto fork = observer.whenWriteDisconnected().fork();
auto result = fork.addBranch();
writeDisconnectedPromise = kj::mv(fork);
return kj::mv(result);
}
}
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// UnixAsyncIoProvider interface.
......@@ -290,6 +301,7 @@ public:
private:
UnixEventPort& eventPort;
UnixEventPort::FdObserver observer;
Maybe<ForkedPromise<void>> writeDisconnectedPromise;
Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
size_t alreadyRead) {
......
......@@ -311,6 +311,23 @@ public:
});
}
Promise<void> whenWriteDisconnected() override {
// Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
// without actually doing a read or write. However, there is an undocoumented-but-stable
// ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
// implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
// to put only one socket into the list and perform the ioctl asynchronously. Here's the
// source code for select() in Windows 2000 (not sure how this became public...):
//
// https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
//
// And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
//
// TODO(soon): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented because
// I added this method for a Linux-only use case.
return NEVER_DONE;
}
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
......
......@@ -228,6 +228,12 @@ public:
} else {
ownState = kj::heap<AbortedRead>();
state = *ownState;
readAborted = true;
KJ_IF_MAYBE(f, readAbortFulfiller) {
f->get()->fulfill();
readAbortFulfiller = nullptr;
}
}
}
......@@ -268,6 +274,21 @@ public:
}
}
Promise<void> whenWriteDisconnected() override {
if (readAborted) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(p, readAbortPromise) {
return p->addBranch();
} else {
auto paf = newPromiseAndFulfiller<void>();
readAbortFulfiller = kj::mv(paf.fulfiller);
auto fork = paf.promise.fork();
auto result = fork.addBranch();
readAbortPromise = kj::mv(fork);
return result;
}
}
void shutdownWrite() override {
KJ_IF_MAYBE(s, state) {
s->shutdownWrite();
......@@ -285,6 +306,10 @@ private:
kj::Own<AsyncIoStream> ownState;
bool readAborted = false;
Maybe<Own<PromiseFulfiller<void>>> readAbortFulfiller = nullptr;
Maybe<ForkedPromise<void>> readAbortPromise = nullptr;
void endState(AsyncIoStream& obj) {
KJ_IF_MAYBE(s, state) {
if (s == &obj) {
......@@ -443,6 +468,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes");
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<void>& fulfiller;
AsyncPipe& pipe;
......@@ -562,6 +591,10 @@ private:
KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes");
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe;
......@@ -733,6 +766,10 @@ private:
pipe.shutdownWrite();
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<size_t>& fulfiller;
AsyncPipe& pipe;
......@@ -901,6 +938,10 @@ private:
pipe.shutdownWrite();
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
private:
PromiseFulfiller<uint64_t>& fulfiller;
AsyncPipe& pipe;
......@@ -937,6 +978,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// which is not an error even if reads have been aborted.
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
};
class ShutdownedWrite final: public AsyncIoStream {
......@@ -966,6 +1010,9 @@ private:
// ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
// so it will only be called once anyhow.
}
Promise<void> whenWriteDisconnected() override {
KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
}
};
};
......@@ -1013,6 +1060,10 @@ public:
return pipe->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return pipe->whenWriteDisconnected();
}
private:
Own<AsyncPipe> pipe;
UnwindDetector unwind;
......@@ -1049,6 +1100,9 @@ public:
AsyncInputStream& input, uint64_t amount) override {
return out->tryPumpFrom(input, amount);
}
Promise<void> whenWriteDisconnected() override {
return out->whenWriteDisconnected();
}
void shutdownWrite() override {
out->shutdownWrite();
}
......
......@@ -106,6 +106,17 @@ public:
// output stream. If it finds one, it performs the pump. Otherwise, it returns null.
//
// The default implementation always returns null.
virtual Promise<void> whenWriteDisconnected() = 0;
// Returns a promise that resolves when the stream has become disconnected such that new write()s
// will fail with a DISCONNECTED exception. This is particularly useful, for example, to cancel
// work early when it is detected that no one will receive the result.
//
// Note that not all streams are able to detect this condition without actually performing a
// write(); such stream implementations may return a promise that never resolves.
//
// Unlike most other asynchronous stream methods, it is safe to call whenWriteDisconnected()
// multiple times without canceling the previous promises.
};
class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream {
......
......@@ -365,6 +365,13 @@ void UnixEventPort::FdObserver::fire(short events) {
}
}
if (events & (EPOLLHUP | EPOLLERR)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & EPOLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill();
......@@ -398,6 +405,12 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise);
}
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
bool UnixEventPort::wait() {
return doEpollWait(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
......@@ -652,6 +665,13 @@ void UnixEventPort::FdObserver::fire(short events) {
}
}
if (events & (POLLHUP | POLLERR | POLLNVAL)) {
KJ_IF_MAYBE(f, hupFulfiller) {
f->get()->fulfill();
hupFulfiller = nullptr;
}
}
if (events & POLLPRI) {
KJ_IF_MAYBE(f, urgentFulfiller) {
f->get()->fulfill();
......@@ -724,6 +744,19 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
return kj::mv(paf.promise);
}
Promise<void> UnixEventPort::FdObserver::whenWriteDisconnected() {
if (prev == nullptr) {
KJ_DASSERT(next == nullptr);
prev = eventPort.observersTail;
*prev = this;
eventPort.observersTail = &next;
}
auto paf = newPromiseAndFulfiller<void>();
hupFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
class UnixEventPort::PollContext {
public:
PollContext(FdObserver* ptr) {
......
......@@ -278,6 +278,9 @@ public:
// WARNING: This has some known weird behavior on macOS. See
// https://github.com/sandstorm-io/capnproto/issues/374.
Promise<void> whenWriteDisconnected();
// Resolves when poll() on the file descriptor reports POLLHUP or POLLERR.
private:
UnixEventPort& eventPort;
int fd;
......@@ -286,6 +289,7 @@ private:
kj::Maybe<Own<PromiseFulfiller<void>>> readFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> writeFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> urgentFulfiller;
kj::Maybe<Own<PromiseFulfiller<void>>> hupFulfiller;
// Replaced each time `whenBecomesReadable()` or `whenBecomesWritable()` is called. Reverted to
// null every time an event is fired.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment