Commit f044b0a6 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #386 from sandstorm-io/async-win32

Support RPC on Windows (MinGW)
parents 907b508e eff2da3d
......@@ -139,6 +139,7 @@ includekj_HEADERS = \
src/kj/async-inl.h \
src/kj/time.h \
src/kj/async-unix.h \
src/kj/async-win32.h \
src/kj/async-io.h \
src/kj/main.h \
src/kj/test.h \
......@@ -225,12 +226,14 @@ libkj_test_la_LDFLAGS = -release $(VERSION) -no-undefined
libkj_test_la_SOURCES = src/kj/test.c++
if !LITE_MODE
libkj_async_la_LIBADD = libkj.la $(PTHREAD_LIBS)
libkj_async_la_LIBADD = libkj.la $(ASYNC_LIBS) $(PTHREAD_LIBS)
libkj_async_la_LDFLAGS = -release $(SO_VERSION) -no-undefined
libkj_async_la_SOURCES= \
src/kj/async.c++ \
src/kj/async-unix.c++ \
src/kj/async-win32.c++ \
src/kj/async-io.c++ \
src/kj/async-io-win32.c++ \
src/kj/time.c++
endif !LITE_MODE
......@@ -260,7 +263,7 @@ libcapnp_la_SOURCES= \
if !LITE_MODE
libcapnp_rpc_la_LIBADD = libcapnp.la libkj-async.la libkj.la $(PTHREAD_LIBS)
libcapnp_rpc_la_LIBADD = libcapnp.la libkj-async.la libkj.la $(ASYNC_LIBS) $(PTHREAD_LIBS)
libcapnp_rpc_la_LDFLAGS = -release $(SO_VERSION) -no-undefined
libcapnp_rpc_la_SOURCES= \
src/capnp/serialize-async.c++ \
......@@ -386,6 +389,7 @@ check_PROGRAMS = capnp-test capnp-evolution-test
heavy_tests = \
src/kj/async-test.c++ \
src/kj/async-unix-test.c++ \
src/kj/async-win32-test.c++ \
src/kj/async-io-test.c++ \
src/kj/parse/common-test.c++ \
src/kj/parse/char-test.c++ \
......
......@@ -57,11 +57,15 @@ AS_CASE("${host_os}", *mingw*, [
PTHREAD_CFLAGS="-mthreads"
PTHREAD_LIBS=""
PTHREAD_CC=""
ASYNC_LIBS="-lws2_32"
AC_SUBST(PTHREAD_LIBS)
AC_SUBST(PTHREAD_CFLAGS)
AC_SUBST(PTHREAD_CC)
AC_SUBST(ASYNC_LIBS)
], *, [
ACX_PTHREAD
ASYNC_LIBS=""
AC_SUBST(ASYNC_LIBS)
])
LT_INIT
......
......@@ -28,6 +28,12 @@
#include <kj/miniposix.h>
#include "test-util.h"
#if !_WIN32
// This test is super-slow on Windows seemingly due to generating exception stack traces being
// expensive.
//
// TODO(perf): Maybe create an API to disable stack traces, and use it here.
namespace capnp {
namespace _ { // private
namespace {
......@@ -257,3 +263,5 @@ KJ_TEST("fuzz-test double-far pointer") {
} // namespace
} // namespace _ (private)
} // namespace capnp
#endif
......@@ -22,7 +22,6 @@
#include "rpc-twoparty.h"
#include "test-util.h"
#include <capnp/rpc.capnp.h>
#include <kj/async-unix.h>
#include <kj/debug.h>
#include <kj/thread.h>
#include <kj/compat/gtest.h>
......
......@@ -27,10 +27,22 @@
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "test-util.h"
#include <kj/compat/gtest.h>
#if _WIN32
#define WIN32_LEAN_AND_MEAN
#include <winsock2.h>
#include <kj/windows-sanity.h>
namespace kj {
namespace _ {
int win32Socketpair(SOCKET socks[2]);
}
}
#else
#include <sys/socket.h>
#endif
namespace capnp {
namespace _ { // private
namespace {
......@@ -86,11 +98,21 @@ private:
class PipeWithSmallBuffer {
public:
#ifdef _WIN32
#define KJ_SOCKCALL KJ_WINSOCK
#ifndef SHUT_WR
#define SHUT_WR SD_SEND
#endif
#define socketpair(family, type, flags, fds) kj::_::win32Socketpair(fds)
#else
#define KJ_SOCKCALL KJ_SYSCALL
#endif
PipeWithSmallBuffer() {
// Use a socketpair rather than a pipe so that we can set the buffer size extremely small.
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
KJ_SOCKCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
KJ_SYSCALL(shutdown(fds[0], SHUT_WR));
KJ_SOCKCALL(shutdown(fds[0], SHUT_WR));
// Note: OSX reports ENOTCONN if we also try to shutdown(fds[1], SHUT_RD).
// Request that the buffer size be as small as possible, to force the event loop to kick in.
......@@ -106,25 +128,82 @@ public:
// Anyway, we now use 127 to avoid these issues (but also to screw around with non-word-boundary
// writes).
uint small = 127;
KJ_SYSCALL(setsockopt(fds[0], SOL_SOCKET, SO_RCVBUF, &small, sizeof(small)));
KJ_SYSCALL(setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &small, sizeof(small)));
KJ_SOCKCALL(setsockopt(fds[0], SOL_SOCKET, SO_RCVBUF, (const char*)&small, sizeof(small)));
KJ_SOCKCALL(setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, (const char*)&small, sizeof(small)));
}
~PipeWithSmallBuffer() {
#if _WIN32
closesocket(fds[0]);
closesocket(fds[1]);
#else
close(fds[0]);
close(fds[1]);
#endif
}
inline int operator[](uint index) { return fds[index]; }
private:
#ifdef _WIN32
SOCKET fds[2];
#else
int fds[2];
#endif
};
#if _WIN32
// Sockets on win32 are not file descriptors. Ugh.
//
// TODO(cleanup): Maybe put these somewhere reusable? kj/io.h is inappropriate since we don't
// really want to link against winsock.
class SocketOutputStream: public kj::OutputStream {
public:
explicit SocketOutputStream(SOCKET fd): fd(fd) {}
void write(const void* buffer, size_t size) override {
const char* ptr = reinterpret_cast<const char*>(buffer);
while (size > 0) {
ssize_t n;
KJ_SOCKCALL(n = send(fd, ptr, size, 0));
size -= n;
ptr += n;
}
}
private:
SOCKET fd;
};
class SocketInputStream: public kj::InputStream {
public:
explicit SocketInputStream(SOCKET fd): fd(fd) {}
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
char* ptr = reinterpret_cast<char*>(buffer);
size_t total = 0;
while (total < minBytes) {
ssize_t n;
KJ_SOCKCALL(n = recv(fd, ptr, maxBytes, 0));
total += n;
maxBytes -= n;
ptr += n;
}
}
private:
SOCKET fd;
};
#else // _WIN32
typedef kj::FdOutputStream SocketOutputStream;
typedef kj::FdInputStream SocketInputStream;
#endif // _WIN32, else
TEST(SerializeAsyncTest, ParseAsync) {
PipeWithSmallBuffer fds;
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
SocketOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
TestMessageBuilder message(1);
......@@ -143,7 +222,7 @@ TEST(SerializeAsyncTest, ParseAsyncOddSegmentCount) {
PipeWithSmallBuffer fds;
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
SocketOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
TestMessageBuilder message(7);
......@@ -162,7 +241,7 @@ TEST(SerializeAsyncTest, ParseAsyncEvenSegmentCount) {
PipeWithSmallBuffer fds;
auto ioContext = kj::setupAsyncIo();
auto input = ioContext.lowLevelProvider->wrapInputFd(fds[0]);
kj::FdOutputStream rawOutput(fds[1]);
SocketOutputStream rawOutput(fds[1]);
FragmentingOutputStream output(rawOutput);
TestMessageBuilder message(10);
......@@ -190,7 +269,8 @@ TEST(SerializeAsyncTest, WriteAsync) {
}
kj::Thread thread([&]() {
StreamFdMessageReader reader(fds[0]);
SocketInputStream input(fds[0]);
InputStreamMessageReader reader(input);
auto listReader = reader.getRoot<TestAllTypes>().getStructList();
EXPECT_EQ(list.size(), listReader.size());
for (auto element: listReader) {
......@@ -214,7 +294,8 @@ TEST(SerializeAsyncTest, WriteAsyncOddSegmentCount) {
}
kj::Thread thread([&]() {
StreamFdMessageReader reader(fds[0]);
SocketInputStream input(fds[0]);
InputStreamMessageReader reader(input);
auto listReader = reader.getRoot<TestAllTypes>().getStructList();
EXPECT_EQ(list.size(), listReader.size());
for (auto element: listReader) {
......@@ -238,7 +319,8 @@ TEST(SerializeAsyncTest, WriteAsyncEvenSegmentCount) {
}
kj::Thread thread([&]() {
StreamFdMessageReader reader(fds[0]);
SocketInputStream input(fds[0]);
InputStreamMessageReader reader(input);
auto listReader = reader.getRoot<TestAllTypes>().getStructList();
EXPECT_EQ(list.size(), listReader.size());
for (auto element: listReader) {
......
......@@ -20,12 +20,15 @@
// THE SOFTWARE.
#include "async-io.h"
#include "async-unix.h"
#include "debug.h"
#include <kj/compat/gtest.h>
#include <sys/types.h>
#include <sys/socket.h>
#if _WIN32
#include <ws2tcpip.h>
#include "windows-sanity.h"
#else
#include <netdb.h>
#endif
namespace kj {
namespace {
......@@ -95,7 +98,9 @@ TEST(AsyncIo, AddressParsing) {
EXPECT_EQ("0.0.0.0:0", tryParse(w, network, "0.0.0.0"));
EXPECT_EQ("1.2.3.4:5678", tryParse(w, network, "1.2.3.4", 5678));
#if !_WIN32
EXPECT_EQ("unix:foo/bar/baz", tryParse(w, network, "unix:foo/bar/baz"));
#endif
// We can parse services by name...
#if !__ANDROID__ // Service names not supported on Android for some reason?
......@@ -220,6 +225,8 @@ TEST(AsyncIo, Timeouts) {
EXPECT_EQ(123, promise2.wait(ioContext.waitScope));
}
#if !_WIN32 // datagrams not implemented on win32 yet
TEST(AsyncIo, Udp) {
auto ioContext = setupAsyncIo();
......@@ -366,5 +373,7 @@ TEST(AsyncIo, Udp) {
}
}
#endif // !_WIN32
} // namespace
} // namespace kj
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// For Unix implementation, see async-io.c++.
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include "async-io.h"
#include "async-win32.h"
#include "debug.h"
#include "thread.h"
#include "io.h"
#include "vector.h"
#include <set>
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include <mswsock.h>
#ifndef IPV6_V6ONLY
// MinGW's headers are missing this.
#define IPV6_V6ONLY 27
#endif
namespace kj {
namespace _ { // private
int win32Socketpair(SOCKET socks[2]) {
// This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
//
// Copyright notice:
//
// Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// The name of the author must not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Note: This function is called from some Cap'n Proto unit tests, despite not having a public
// header declaration.
// TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
// it needs to be in the kj-async library.
union {
struct sockaddr_in inaddr;
struct sockaddr addr;
} a;
SOCKET listener;
int e;
socklen_t addrlen = sizeof(a.inaddr);
int reuse = 1;
if (socks == 0) {
WSASetLastError(WSAEINVAL);
return SOCKET_ERROR;
}
socks[0] = socks[1] = -1;
listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (listener == -1)
return SOCKET_ERROR;
memset(&a, 0, sizeof(a));
a.inaddr.sin_family = AF_INET;
a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
a.inaddr.sin_port = 0;
for (;;) {
if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
(char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
break;
if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
break;
memset(&a, 0, sizeof(a));
if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
break;
// win32 getsockname may only set the port number, p=0.0005.
// ( http://msdn.microsoft.com/library/ms738543.aspx ):
a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
a.inaddr.sin_family = AF_INET;
if (listen(listener, 1) == SOCKET_ERROR)
break;
socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
if (socks[0] == -1)
break;
if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
break;
socks[1] = accept(listener, NULL, NULL);
if (socks[1] == -1)
break;
closesocket(listener);
return 0;
}
e = WSAGetLastError();
closesocket(listener);
closesocket(socks[0]);
closesocket(socks[1]);
WSASetLastError(e);
socks[0] = socks[1] = -1;
return SOCKET_ERROR;
}
} // namespace _
namespace {
bool detectWine() {
HMODULE hntdll = GetModuleHandle("ntdll.dll");
if(hntdll == NULL) return false;
return GetProcAddress(hntdll, "wine_get_version") != nullptr;
}
bool isWine() {
static bool result = detectWine();
return result;
}
// =======================================================================================
static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
class OwnedFd {
public:
OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
// TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
// delivering an event when the operation completes inline. Not currently implemented on
// Wine, though.
}
~OwnedFd() noexcept(false) {
if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
KJ_WINSOCK(closesocket(fd)) { break; }
}
static_assert(sizeof(SOCKET) == 8, "nope");
}
protected:
SOCKET fd;
private:
uint flags;
};
// =======================================================================================
class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
public:
AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
: OwnedFd(fd, flags),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
virtual ~AsyncStreamFd() noexcept(false) {}
Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
KJ_REQUIRE(result >= minBytes, "Premature EOF") {
// Pretend we read zeros from the input.
memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
return minBytes;
}
return result;
});
}
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
auto bufs = heapArray<WSABUF>(1);
bufs[0].buf = reinterpret_cast<char*>(buffer);
bufs[0].len = maxBytes;
ArrayPtr<WSABUF> ref = bufs;
return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
}
Promise<void> write(const void* buffer, size_t size) override {
auto bufs = heapArray<WSABUF>(1);
bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
bufs[0].len = size;
ArrayPtr<WSABUF> ref = bufs;
return writeInternal(ref).attach(kj::mv(bufs));
}
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
auto bufs = heapArray<WSABUF>(pieces.size());
for (auto i: kj::indices(pieces)) {
bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
bufs[i].len = pieces[i].size();
}
ArrayPtr<WSABUF> ref = bufs;
return writeInternal(ref).attach(kj::mv(bufs));
}
kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
// In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
// to query the socket for it dynamically, I guess because of the insanity in which winsock
// can be implemented in userspace and old implementations may not support it.
GUID guid = WSAID_CONNECTEX;
LPFN_CONNECTEX connectEx = nullptr;
DWORD n = 0;
KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
&connectEx, sizeof(connectEx), &n, NULL, NULL)) {
goto fail; // avoid memory leak due to compiler bugs
}
if (false) {
fail:
return kj::READY_NOW;
}
// OK, phew, we now have our ConnectEx function pointer. Call it.
auto op = observer->newOperation(0);
if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
DWORD error = WSAGetLastError();
if (error != ERROR_IO_PENDING) {
KJ_FAIL_WIN32("ConnectEx()", error) { break; }
return kj::READY_NOW;
}
}
return op->onComplete().then([this](Win32EventPort::IoResult result) {
if (result.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
}
// Enable shutdown() to work.
setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
});
}
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
KJ_WINSOCK(shutdown(fd, SD_SEND));
}
void abortRead() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
}
void getsockopt(int level, int option, void* value, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockopt(fd, level, option,
reinterpret_cast<char*>(value), &socklen));
*length = socklen;
}
void setsockopt(int level, int option, const void* value, uint length) override {
KJ_WINSOCK(::setsockopt(fd, level, option,
reinterpret_cast<const char*>(value), length));
}
void getsockname(struct sockaddr* addr, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockname(fd, addr, &socklen));
*length = socklen;
}
void getpeername(struct sockaddr* addr, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getpeername(fd, addr, &socklen));
*length = socklen;
}
private:
Own<Win32EventPort::IoObserver> observer;
Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
// `bufs` will remain valid until the promise completes and may be freely modified.
//
// `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
// and buffer have already been adjusted to account for them, but this count must be included
// in the final return value.
auto op = observer->newOperation(0);
DWORD flags = 0;
if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
op->getOverlapped(), NULL) == SOCKET_ERROR) {
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
KJ_FAIL_WIN32("WSARecv()", error) { break; }
return alreadyRead;
}
}
return op->onComplete()
.then([this,bufs,minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
-> Promise<size_t> {
if (result.errorCode != ERROR_SUCCESS) {
if (alreadyRead > 0) {
// Report what we already read.
return alreadyRead;
} else {
KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
return size_t(0);
}
}
if (result.bytesTransferred == 0) {
return alreadyRead;
}
alreadyRead += result.bytesTransferred;
if (result.bytesTransferred >= minBytes) {
// We can stop here.
return alreadyRead;
}
minBytes -= result.bytesTransferred;
while (result.bytesTransferred >= bufs[0].len) {
result.bytesTransferred -= bufs[0].len;
bufs = bufs.slice(1, bufs.size());
}
if (result.bytesTransferred > 0) {
bufs[0].buf += result.bytesTransferred;
bufs[0].len -= result.bytesTransferred;
}
return tryReadInternal(bufs, minBytes, alreadyRead);
}).attach(kj::mv(bufs));
}
Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
// `bufs` will remain valid until the promise completes and may be freely modified.
auto op = observer->newOperation(0);
if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
op->getOverlapped(), NULL) == SOCKET_ERROR) {
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
KJ_FAIL_WIN32("WSASend()", error) { break; }
return kj::READY_NOW;
}
}
return op->onComplete()
.then([this,bufs](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
if (result.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
return kj::READY_NOW;
}
while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
result.bytesTransferred -= bufs[0].len;
bufs = bufs.slice(1, bufs.size());
}
if (result.bytesTransferred > 0) {
bufs[0].buf += result.bytesTransferred;
bufs[0].len -= result.bytesTransferred;
}
if (bufs.size() > 0) {
return writeInternal(bufs);
} else {
return kj::READY_NOW;
}
}).attach(kj::mv(bufs));
}
};
// =======================================================================================
class SocketAddress {
public:
SocketAddress(const void* sockaddr, uint len): addrlen(len) {
KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
memcpy(&addr.generic, sockaddr, len);
}
bool operator<(const SocketAddress& other) const {
// So we can use std::set<SocketAddress>... see DNS lookup code.
if (wildcard < other.wildcard) return true;
if (wildcard > other.wildcard) return false;
if (addrlen < other.addrlen) return true;
if (addrlen > other.addrlen) return false;
return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
}
const struct sockaddr* getRaw() const { return &addr.generic; }
int getRawSize() const { return addrlen; }
SOCKET socket(int type) const {
bool isStream = type == SOCK_STREAM;
SOCKET result = ::socket(addr.generic.sa_family, type, 0);
if (result == INVALID_SOCKET) {
KJ_FAIL_WIN32("WSASocket()", WSAGetLastError()) { return INVALID_SOCKET; }
}
if (isStream && (addr.generic.sa_family == AF_INET ||
addr.generic.sa_family == AF_INET6)) {
// TODO(perf): As a hack for the 0.4 release we are always setting
// TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
// RPC protocol. Later, we should extend the interface to provide more
// control over this. Perhaps write() should have a flag which
// specifies whether to pass MSG_MORE.
BOOL one = TRUE;
KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
}
return result;
}
void bind(SOCKET sockfd) const {
if (wildcard) {
// Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket. (The
// default value of this option varies across platforms.)
DWORD value = 0;
KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
reinterpret_cast<char*>(&value), sizeof(value)));
}
KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
}
uint getPort() const {
switch (addr.generic.sa_family) {
case AF_INET: return ntohs(addr.inet4.sin_port);
case AF_INET6: return ntohs(addr.inet6.sin6_port);
default: return 0;
}
}
String toString() const {
if (wildcard) {
return str("*:", getPort());
}
switch (addr.generic.sa_family) {
case AF_INET: {
char buffer[16];
if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
buffer, sizeof(buffer)) == nullptr) {
KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
return heapString("(inet_ntop error)");
}
return str(buffer, ':', ntohs(addr.inet4.sin_port));
}
case AF_INET6: {
char buffer[46];
if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
buffer, sizeof(buffer)) == nullptr) {
KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
return heapString("(inet_ntop error)");
}
return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
}
default:
return str("(unknown address family ", addr.generic.sa_family, ")");
}
}
static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
// TODO(someday): Allow commas in `str`.
SocketAddress result;
// Try to separate the address and port.
ArrayPtr<const char> addrPart;
Maybe<StringPtr> portPart;
int af;
if (str.startsWith("[")) {
// Address starts with a bracket, which is a common way to write an ip6 address with a port,
// since without brackets around the address part, the port looks like another segment of
// the address.
af = AF_INET6;
size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
"Unclosed '[' in address string.", str);
addrPart = str.slice(1, closeBracket);
if (str.size() > closeBracket + 1) {
KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
"Expected port suffix after ']'.", str);
portPart = str.slice(closeBracket + 2);
}
} else {
KJ_IF_MAYBE(colon, str.findFirst(':')) {
if (str.slice(*colon + 1).findFirst(':') == nullptr) {
// There is exactly one colon and no brackets, so it must be an ip4 address with port.
af = AF_INET;
addrPart = str.slice(0, *colon);
portPart = str.slice(*colon + 1);
} else {
// There are two or more colons and no brackets, so the whole thing must be an ip6
// address with no port.
af = AF_INET6;
addrPart = str;
}
} else {
// No colons, so it must be an ip4 address without port.
af = AF_INET;
addrPart = str;
}
}
// Parse the port.
unsigned long port;
KJ_IF_MAYBE(portText, portPart) {
char* endptr;
port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
} else {
port = portHint;
}
// Check for wildcard.
if (addrPart.size() == 1 && addrPart[0] == '*') {
result.wildcard = true;
// Create an ip6 socket and set IPV6_V6ONLY to 0 later.
result.addrlen = sizeof(addr.inet6);
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
void* addrTarget;
if (af == AF_INET6) {
result.addrlen = sizeof(addr.inet6);
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
addrTarget = &result.addr.inet6.sin6_addr;
} else {
result.addrlen = sizeof(addr.inet4);
result.addr.inet4.sin_family = AF_INET;
result.addr.inet4.sin_port = htons(port);
addrTarget = &result.addr.inet4.sin_addr;
}
// addrPart is not necessarily NUL-terminated so we have to make a copy. :(
char buffer[64];
KJ_REQUIRE(addrPart.size() < sizeof(buffer) - 1, "IP address too long.", addrPart);
memcpy(buffer, addrPart.begin(), addrPart.size());
buffer[addrPart.size()] = '\0';
// OK, parse it!
switch (InetPtonA(af, buffer, addrTarget)) {
case 1: {
// success.
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0:
// It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
default:
KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
}
}
static SocketAddress getLocalAddress(int sockfd) {
SocketAddress result;
result.addrlen = sizeof(addr);
KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
return result;
}
static SocketAddress getWildcardForFamily(int family) {
SocketAddress result;
switch (family) {
case AF_INET:
result.addrlen = sizeof(addr.inet4);
result.addr.inet4.sin_family = AF_INET;
return result;
case AF_INET6:
result.addrlen = sizeof(addr.inet6);
result.addr.inet6.sin6_family = AF_INET6;
return result;
default:
KJ_FAIL_REQUIRE("unknown address family", family);
}
}
private:
SocketAddress(): addrlen(0) {
memset(&addr, 0, sizeof(addr));
}
socklen_t addrlen;
bool wildcard = false;
union {
struct sockaddr generic;
struct sockaddr_in inet4;
struct sockaddr_in6 inet6;
struct sockaddr_storage storage;
} addr;
struct LookupParams;
class LookupReader;
};
class SocketAddress::LookupReader {
// Reads SocketAddresses off of a pipe coming from another thread that is performing
// getaddrinfo.
public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
~LookupReader() {
if (thread) thread->detach();
}
Promise<Array<SocketAddress>> read() {
return input->tryRead(&current, sizeof(current), sizeof(current)).then(
[this](size_t n) -> Promise<Array<SocketAddress>> {
if (n < sizeof(current)) {
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
// A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
// it may return two copies of the same address, one for each type, unless it explicitly
// knows that the service name given is specific to one type. But we can't tell it a type,
// because we don't actually know which one the user wants, and if we specify SOCK_STREAM
// while the user specified a UDP service name then they'll get a resolution error which
// is lame. (At least, I think that's how it works.)
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
addresses.add(current);
}
return read();
}
});
}
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
};
struct SocketAddress::LookupParams {
kj::String host;
kj::String service;
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
// TODO(perf): Use GetAddrInfoEx(). But there are problems:
// - Not implemented in Wine.
// - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
// with a handle. Could signal completion as an APC instead, but that requires the IOCP code
// to use GetQueuedCompletionStatusEx() which it doesn't right now becaues it's not available
// in Wine.
// - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
// docs. Never mind that DNS itself is ASCII...
SOCKET fds[2];
KJ_WINSOCK(_::win32Socketpair(fds));
auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
int outFd = fds[1];
LookupParams params = { kj::mv(host), kj::mv(service) };
auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
KJ_DEFER(closesocket(outFd));
struct addrinfo* list;
int status = getaddrinfo(
params.host == "*" ? nullptr : params.host.cStr(),
params.service == nullptr ? nullptr : params.service.cStr(),
nullptr, &list);
if (status == 0) {
KJ_DEFER(freeaddrinfo(list));
struct addrinfo* cur = list;
while (cur != nullptr) {
if (params.service == nullptr) {
switch (cur->ai_addr->sa_family) {
case AF_INET:
((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
break;
case AF_INET6:
((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
break;
default:
break;
}
}
SocketAddress addr;
memset(&addr, 0, sizeof(addr)); // mollify valgrind
if (params.host == "*") {
// Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo().
addr.wildcard = true;
addr.addrlen = sizeof(addr.addr.inet6);
addr.addr.inet6.sin6_family = AF_INET6;
switch (cur->ai_addr->sa_family) {
case AF_INET:
addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
break;
case AF_INET6:
addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
break;
default:
addr.addr.inet6.sin6_port = portHint;
break;
}
} else {
addr.addrlen = cur->ai_addrlen;
memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
}
static_assert(canMemcpy<SocketAddress>(), "Can't write() SocketAddress...");
const char* data = reinterpret_cast<const char*>(&addr);
size_t size = sizeof(addr);
while (size > 0) {
int n;
KJ_WINSOCK(n = send(outFd, data, size, 0));
data += n;
size -= n;
}
cur = cur->ai_next;
}
} else {
KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
return;
}
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
return reader->read().attach(kj::mv(reader));
}
// =======================================================================================
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
public:
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
address(SocketAddress::getLocalAddress(fd)) {
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
// to query the socket for it dynamically, I guess because of the insanity in which winsock
// can be implemented in userspace and old implementations may not support it.
GUID guid = WSAID_ACCEPTEX;
DWORD n = 0;
KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
&acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
acceptEx = nullptr;
return;
}
}
Promise<Own<AsyncIoStream>> accept() override {
SOCKET newFd = address.socket(SOCK_STREAM);
KJ_ASSERT(newFd != INVALID_SOCKET);
auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
auto scratch = heapArray<byte>(256);
DWORD dummy;
auto op = observer->newOperation(0);
if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
DWORD error = WSAGetLastError();
if (error != ERROR_IO_PENDING) {
KJ_FAIL_WIN32("AcceptEx()", error) { break; }
return Own<AsyncIoStream>(kj::mv(result)); // dummy, won't be used
}
}
return op->onComplete().attach(kj::mv(scratch)).then(mvCapture(result,
[this](Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult) {
if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
} else {
SOCKET me = fd;
stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
reinterpret_cast<char*>(&me), sizeof(me));
}
return kj::mv(stream);
}));
}
uint getPort() override {
return address.getPort();
}
void getsockopt(int level, int option, void* value, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockopt(fd, level, option,
reinterpret_cast<char*>(value), &socklen));
*length = socklen;
}
void setsockopt(int level, int option, const void* value, uint length) override {
KJ_WINSOCK(::setsockopt(fd, level, option,
reinterpret_cast<const char*>(value), length));
}
public:
Win32EventPort& eventPort;
Own<Win32EventPort::IoObserver> observer;
LPFN_ACCEPTEX acceptEx = nullptr;
SocketAddress address;
};
// TODO(someday): DatagramPortImpl
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
LowLevelAsyncIoProviderImpl()
: eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; }
Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
return heap<AsyncStreamFd>(eventPort, fd, flags);
}
Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
// ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd);
auto connected = result->connect(addr, addrlen);
return connected.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
return kj::mv(result);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(SOCKET fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
}
Timer& getTimer() override { return eventPort.getTimer(); }
Win32EventPort& getEventPort() { return eventPort; }
private:
Win32IocpEventPort eventPort;
EventLoop eventLoop;
WaitScope waitScope;
};
// =======================================================================================
class NetworkAddressImpl final: public NetworkAddress {
public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy);
return promise.attach(kj::mv(addrsCopy));
}
Own<ConnectionReceiver> listen() override {
if (addrs.size() > 1) {
KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
"be used. If this is incorrect, specify the address numerically. This may be fixed "
"in the future.", addrs[0].toString());
}
int fd = addrs[0].socket(SOCK_STREAM);
{
KJ_ON_SCOPE_FAILURE(closesocket(fd));
// We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks.
int optval = 1;
KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
reinterpret_cast<char*>(&optval), sizeof(optval)));
addrs[0].bind(fd);
// TODO(someday): Let queue size be specified explicitly in string addresses.
KJ_WINSOCK(::listen(fd, SOMAXCONN));
}
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
if (addrs.size() > 1) {
KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
"be used. If this is incorrect, specify the address numerically. This may be fixed "
"in the future.", addrs[0].toString());
}
int fd = addrs[0].socket(SOCK_DGRAM);
{
KJ_ON_SCOPE_FAILURE(closesocket(fd));
// We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks.
int optval = 1;
KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
reinterpret_cast<char*>(&optval), sizeof(optval)));
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr()));
}
String toString() override {
return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
}
const SocketAddress& chooseOneAddress() {
KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
return addrs[counter++ % addrs.size()];
}
private:
LowLevelAsyncIoProvider& lowLevel;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
}
});
}
};
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) {
return SocketAddress::parse(lowLevelCopy, addr, portHint);
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
});
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
}
private:
LowLevelAsyncIoProvider& lowLevel;
};
// =======================================================================================
class AsyncIoProviderImpl final: public AsyncIoProvider {
public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
: lowLevel(lowLevel), network(lowLevel) {}
OneWayPipe newOneWayPipe() override {
SOCKET fds[2];
KJ_WINSOCK(_::win32Socketpair(fds));
auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
in->shutdownWrite();
return { kj::mv(in), kj::mv(out) };
}
TwoWayPipe newTwoWayPipe() override {
SOCKET fds[2];
KJ_WINSOCK(_::win32Socketpair(fds));
return TwoWayPipe { {
lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
} };
}
Network& getNetwork() override {
return network;
}
PipeThread newPipeThread(
Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
SOCKET fds[2];
KJ_WINSOCK(_::win32Socketpair(fds));
int threadFd = fds[1];
KJ_ON_SCOPE_FAILURE(closesocket(threadFd));
auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
auto thread = heap<Thread>(kj::mvCapture(startFunc,
[threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
LowLevelAsyncIoProviderImpl lowLevel;
auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
AsyncIoProviderImpl ioProvider(lowLevel);
startFunc(ioProvider, *stream, lowLevel.getWaitScope());
}));
return { kj::mv(thread), kj::mv(pipe) };
}
Timer& getTimer() override { return lowLevel.getTimer(); }
private:
LowLevelAsyncIoProvider& lowLevel;
SocketNetwork network;
};
} // namespace
Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) {
return read(buffer, bytes, bytes).then([](size_t) {});
}
void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void AsyncIoStream::setsockopt(int level, int option, const void* value, uint length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void AsyncIoStream::getsockname(struct sockaddr* addr, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void AsyncIoStream::getpeername(struct sockaddr* addr, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void ConnectionReceiver::getsockopt(int level, int option, void* value, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void ConnectionReceiver::setsockopt(int level, int option, const void* value, uint length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void DatagramPort::getsockopt(int level, int option, void* value, uint* length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
void DatagramPort::setsockopt(int level, int option, const void* value, uint length) {
KJ_UNIMPLEMENTED("Not a socket.");
}
Own<DatagramPort> NetworkAddress::bindDatagramPort() {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(SOCKET fd, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
return kj::heap<AsyncIoProviderImpl>(lowLevel);
}
AsyncIoContext setupAsyncIo() {
WSADATA dontcare;
int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
if (result != 0) {
KJ_FAIL_WIN32("WSAStartup()", result);
}
auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
auto& waitScope = lowLevel->getWaitScope();
auto& eventPort = lowLevel->getEventPort();
return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
}
} // namespace kj
#endif // _WIN32
......@@ -19,6 +19,9 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if !_WIN32
// For Win32 implementation, see async-io-win32.c++.
#include "async-io.h"
#include "async-unix.h"
#include "debug.h"
......@@ -452,7 +455,8 @@ public:
char buffer[INET6_ADDRSTRLEN];
if (inet_ntop(addr.inet4.sin_family, &addr.inet4.sin_addr,
buffer, sizeof(buffer)) == nullptr) {
KJ_FAIL_SYSCALL("inet_ntop", errno) { return heapString("(inet_ntop error)"); }
KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
return heapString("(inet_ntop error)");
}
return str(buffer, ':', ntohs(addr.inet4.sin_port));
}
......@@ -460,7 +464,8 @@ public:
char buffer[INET6_ADDRSTRLEN];
if (inet_ntop(addr.inet6.sin6_family, &addr.inet6.sin6_addr,
buffer, sizeof(buffer)) == nullptr) {
KJ_FAIL_SYSCALL("inet_ntop", errno) { return heapString("(inet_ntop error)"); }
KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
return heapString("(inet_ntop error)");
}
return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
}
......@@ -880,28 +885,10 @@ public:
UnixEventPort::FdObserver observer;
};
class TimerImpl final: public Timer {
public:
TimerImpl(UnixEventPort& eventPort): eventPort(eventPort) {}
TimePoint now() override { return eventPort.steadyTime(); }
Promise<void> atTime(TimePoint time) override {
return eventPort.atSteadyTime(time);
}
Promise<void> afterDelay(Duration delay) override {
return eventPort.atSteadyTime(eventPort.steadyTime() + delay);
}
private:
UnixEventPort& eventPort;
};
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
LowLevelAsyncIoProviderImpl()
: eventLoop(eventPort), timer(eventPort), waitScope(eventLoop) {}
: eventLoop(eventPort), waitScope(eventLoop) {}
inline WaitScope& getWaitScope() { return waitScope; }
......@@ -935,14 +922,13 @@ public:
return heap<DatagramPortImpl>(*this, eventPort, fd, flags);
}
Timer& getTimer() override { return timer; }
Timer& getTimer() override { return eventPort.getTimer(); }
UnixEventPort& getEventPort() { return eventPort; }
private:
UnixEventPort eventPort;
EventLoop eventLoop;
TimerImpl timer;
WaitScope waitScope;
};
......@@ -1397,3 +1383,5 @@ AsyncIoContext setupAsyncIo() {
}
} // namespace kj
#endif // !_WIN32
......@@ -35,7 +35,12 @@ struct sockaddr;
namespace kj {
#if _WIN32
class Win32EventPort;
#else
class UnixEventPort;
#endif
class NetworkAddress;
// =======================================================================================
......@@ -377,6 +382,7 @@ public:
// If this flag is not used, then the file descriptor is not automatically closed and the
// close-on-exec flag is not modified.
#if !_WIN32
ALREADY_CLOEXEC = 1 << 1,
// Indicates that the close-on-exec flag is known already to be set, so need not be set again.
// Only relevant when combined with TAKE_OWNERSHIP.
......@@ -391,37 +397,56 @@ public:
//
// On Linux, all system calls which yield new file descriptors have flags or variants which
// enable non-blocking mode immediately. Unfortunately, other OS's do not.
#endif
};
virtual Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) = 0;
#if _WIN32
typedef uintptr_t Fd;
// On Windows, the `fd` parameter to each of these methods must be a SOCKET, and must have the
// flag WSA_FLAG_OVERLAPPED (which socket() uses by default, but WSASocket() wants you to specify
// explicitly).
#else
typedef int Fd;
// On Unix, any arbitrary file descriptor is supported.
#endif
virtual Own<AsyncInputStream> wrapInputFd(Fd fd, uint flags = 0) = 0;
// Create an AsyncInputStream wrapping a file descriptor.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) = 0;
virtual Own<AsyncOutputStream> wrapOutputFd(Fd fd, uint flags = 0) = 0;
// Create an AsyncOutputStream wrapping a file descriptor.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) = 0;
virtual Own<AsyncIoStream> wrapSocketFd(Fd fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a socket file descriptor.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) = 0;
#if _WIN32
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
Fd fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) = 0;
#else
virtual Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(Fd fd, uint flags = 0) = 0;
#endif
// Create an AsyncIoStream wrapping a socket that is in the process of connecting. The returned
// promise should not resolve until connection has completed -- traditionally indicated by the
// descriptor becoming writable.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
//
// On Windows, the callee initiates connect rather than the caller.
// TODO(now): Maybe on all systems?
virtual Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) = 0;
virtual Own<ConnectionReceiver> wrapListenSocketFd(Fd fd, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<DatagramPort> wrapDatagramSocketFd(int fd, uint flags = 0);
virtual Own<DatagramPort> wrapDatagramSocketFd(Fd fd, uint flags = 0);
virtual Timer& getTimer() = 0;
// Returns a `Timer` based on real time. Time does not pass while event handlers are running --
......@@ -440,9 +465,13 @@ struct AsyncIoContext {
Own<AsyncIoProvider> provider;
WaitScope& waitScope;
#if _WIN32
Win32EventPort& win32EventPort;
#else
UnixEventPort& unixEventPort;
// TEMPORARY: Direct access to underlying UnixEventPort, mainly for waiting on signals. This
// field will go away at some point when we have a chance to improve these interfaces.
#endif
};
AsyncIoContext setupAsyncIo();
......
......@@ -19,6 +19,8 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if !_WIN32
#include "async-unix.h"
#include "thread.h"
#include "debug.h"
......@@ -34,6 +36,7 @@
#include <algorithm>
namespace kj {
namespace {
inline void delay() { usleep(10000); }
......@@ -529,14 +532,16 @@ TEST(AsyncUnixTest, SteadyTimers) {
EventLoop loop(port);
WaitScope waitScope(loop);
auto start = port.steadyTime();
auto& timer = port.getTimer();
auto start = timer.now();
kj::Vector<TimePoint> expected;
kj::Vector<TimePoint> actual;
auto addTimer = [&](Duration delay) {
expected.add(max(start + delay, start));
port.atSteadyTime(start + delay).then([&]() {
actual.add(port.steadyTime());
timer.atTime(start + delay).then([&]() {
actual.add(timer.now());
}).detach([](Exception&& e) { ADD_FAILURE() << str(e).cStr(); });
};
......@@ -547,7 +552,7 @@ TEST(AsyncUnixTest, SteadyTimers) {
addTimer(-10 * MILLISECONDS);
std::sort(expected.begin(), expected.end());
port.atSteadyTime(expected.back() + MILLISECONDS).wait(waitScope);
timer.atTime(expected.back() + MILLISECONDS).wait(waitScope);
ASSERT_EQ(expected.size(), actual.size());
for (int i = 0; i < expected.size(); ++i) {
......@@ -571,7 +576,7 @@ TEST(AsyncUnixTest, Wake) {
EXPECT_TRUE(port.wait());
{
auto promise = port.atSteadyTime(port.steadyTime());
auto promise = port.getTimer().atTime(port.getTimer().now());
EXPECT_FALSE(port.wait());
}
......@@ -585,4 +590,7 @@ TEST(AsyncUnixTest, Wake) {
EXPECT_TRUE(port.wait());
}
} // namespace
} // namespace kj
#endif // !_WIN32
......@@ -19,6 +19,8 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if !_WIN32
#include "async-unix.h"
#include "debug.h"
#include "threadlocal.h"
......@@ -26,7 +28,6 @@
#include <errno.h>
#include <inttypes.h>
#include <limits>
#include <set>
#include <chrono>
#include <pthread.h>
......@@ -44,64 +45,11 @@ namespace kj {
// =======================================================================================
// Timer code common to multiple implementations
struct UnixEventPort::TimerSet {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class UnixEventPort::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, UnixEventPort& port, TimePoint time)
: time(time), fulfiller(fulfiller), port(port) {
pos = port.timers->timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != port.timers->timers.end()) {
port.timers->timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
port.timers->timers.erase(pos);
pos = port.timers->timers.end();
}
const TimePoint time;
PromiseFulfiller<void>& fulfiller;
UnixEventPort& port;
TimerSet::Timers::const_iterator pos;
};
bool UnixEventPort::TimerSet::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> UnixEventPort::atSteadyTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*this, time);
}
TimePoint UnixEventPort::currentSteadyTime() {
TimePoint UnixEventPort::readClock() {
return origin<TimePoint>() + std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS;
}
void UnixEventPort::processTimers() {
frozenSteadyTime = currentSteadyTime();
for (;;) {
auto front = timers->timers.begin();
if (front == timers->timers.end() || (*front)->time > frozenSteadyTime) {
break;
}
(*front)->fulfill();
}
}
// =======================================================================================
// Signal code common to multiple implementations
......@@ -247,8 +195,7 @@ void UnixEventPort::gotSignal(const siginfo_t& siginfo) {
// epoll FdObserver implementation
UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()),
: timerImpl(readClock()),
epollFd(-1),
signalFd(-1),
eventFd(-1) {
......@@ -358,27 +305,10 @@ Promise<void> UnixEventPort::FdObserver::whenUrgentDataAvailable() {
}
bool UnixEventPort::wait() {
// epoll_wait()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below.
constexpr Duration MAX_TIMEOUT =
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int epollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
epollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
epollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
epollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
return doEpollWait(epollTimeout);
return doEpollWait(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
.map([](uint64_t t) -> int { return t; })
.orDefault(-1));
}
bool UnixEventPort::poll() {
......@@ -552,7 +482,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
}
}
processTimers();
timerImpl.advanceTo(readClock());
return woken;
}
......@@ -566,8 +496,7 @@ bool UnixEventPort::doEpollWait(int timeout) {
#endif
UnixEventPort::UnixEventPort()
: timers(kj::heap<TimerSet>()),
frozenSteadyTime(currentSteadyTime()) {
: timerImpl(readClock()) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
*reinterpret_cast<pthread_t*>(&threadId) = pthread_self();
......@@ -769,33 +698,17 @@ bool UnixEventPort::wait() {
threadCapture = &capture;
sigprocmask(SIG_UNBLOCK, &newMask, &origMask);
// poll()'s timeout is an `int` count of milliseconds, so truncate to that.
// Also, make sure that we aren't within a millisecond of overflowing a `Duration` since that
// will break the math below.
constexpr Duration MAX_TIMEOUT =
min(int(maxValue) * MILLISECONDS, Duration(maxValue) - MILLISECONDS);
int pollTimeout = -1;
auto timer = timers->timers.begin();
if (timer != timers->timers.end()) {
Duration timeout = (*timer)->time - currentSteadyTime();
if (timeout < 0 * SECONDS) {
pollTimeout = 0;
} else if (timeout < MAX_TIMEOUT) {
// Round up to the next millisecond
pollTimeout = (timeout + 1 * MILLISECONDS - unit<Duration>()) / MILLISECONDS;
} else {
pollTimeout = MAX_TIMEOUT / MILLISECONDS;
}
}
pollContext.run(pollTimeout);
pollContext.run(
timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, int(maxValue))
.map([](uint64_t t) -> int { return t; })
.orDefault(-1));
sigprocmask(SIG_SETMASK, &origMask, nullptr);
threadCapture = nullptr;
// Queue events.
pollContext.processResults();
processTimers();
timerImpl.advanceTo(readClock());
return false;
}
......@@ -857,7 +770,7 @@ bool UnixEventPort::poll() {
pollContext.run(0);
pollContext.processResults();
}
processTimers();
timerImpl.advanceTo(readClock());
return woken;
}
......@@ -872,3 +785,5 @@ void UnixEventPort::wake() const {
#endif // KJ_USE_EPOLL, else
} // namespace kj
#endif // !_WIN32
......@@ -22,6 +22,10 @@
#ifndef KJ_ASYNC_UNIX_H_
#define KJ_ASYNC_UNIX_H_
#if _WIN32
#error "This file is Unix-specific. On Windows, include async-win32.h instead."
#endif
#if defined(__GNUC__) && !KJ_HEADER_WARNINGS
#pragma GCC system_header
#endif
......@@ -93,8 +97,7 @@ public:
// needs to use SIGUSR1, call this at startup (before any calls to `captureSignal()` and before
// constructing an `UnixEventPort`) to offer a different signal.
TimePoint steadyTime() { return frozenSteadyTime; }
Promise<void> atSteadyTime(TimePoint time);
Timer& getTimer() { return timerImpl; }
// implements EventPort ------------------------------------------------------
bool wait() override;
......@@ -106,14 +109,12 @@ private:
class TimerPromiseAdapter;
class SignalPromiseAdapter;
Own<TimerSet> timers;
TimePoint frozenSteadyTime;
TimerImpl timerImpl;
SignalPromiseAdapter* signalHead = nullptr;
SignalPromiseAdapter** signalTail = &signalHead;
TimePoint currentSteadyTime();
void processTimers();
TimePoint readClock();
void gotSignal(const siginfo_t& siginfo);
friend class TimerPromiseAdapter;
......
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
#include "async-win32.h"
#include "thread.h"
#include "test.h"
namespace kj {
namespace {
KJ_TEST("Win32IocpEventPort I/O operations") {
Win32IocpEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
auto pipeName = kj::str("\\\\.\\Pipe\\kj-async-win32-test.", GetCurrentProcessId());
HANDLE readEnd_, writeEnd_;
KJ_WIN32(readEnd_ = CreateNamedPipeA(pipeName.cStr(),
PIPE_ACCESS_INBOUND | FILE_FLAG_OVERLAPPED,
PIPE_TYPE_BYTE | PIPE_WAIT,
1, 0, 0, 0, NULL));
AutoCloseHandle readEnd(readEnd_);
KJ_WIN32(writeEnd_ = CreateFileA(pipeName.cStr(), GENERIC_WRITE, 0, NULL, OPEN_EXISTING,
FILE_ATTRIBUTE_NORMAL, NULL));
AutoCloseHandle writeEnd(writeEnd_);
auto observer = port.observeIo(readEnd);
auto op = observer->newOperation(0);
byte buffer[256];
KJ_ASSERT(!ReadFile(readEnd, buffer, sizeof(buffer), NULL, op->getOverlapped()));
DWORD error = GetLastError();
if (error != ERROR_IO_PENDING) {
KJ_FAIL_WIN32("ReadFile()", error);
}
bool done = false;
auto promise = op->onComplete().then([&](Win32EventPort::IoResult result) {
done = true;
return result;
}).eagerlyEvaluate(nullptr);
KJ_EXPECT(!done);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
KJ_EXPECT(!done);
DWORD bytesWritten;
KJ_WIN32(WriteFile(writeEnd, "foo", 3, &bytesWritten, NULL));
KJ_EXPECT(bytesWritten == 3);
auto result = promise.wait(waitScope);
KJ_EXPECT(result.errorCode == ERROR_SUCCESS);
KJ_EXPECT(result.bytesTransferred == 3);
KJ_EXPECT(kj::str(kj::arrayPtr(buffer, 3).asChars()) == "foo");
}
KJ_TEST("Win32IocpEventPort::wake()") {
Win32IocpEventPort port;
Thread thread([&]() {
Sleep(10);
port.wake();
});
KJ_EXPECT(port.wait());
}
KJ_TEST("Win32IocpEventPort::wake() on poll()") {
Win32IocpEventPort port;
volatile bool woken = false;
Thread thread([&]() {
Sleep(10);
port.wake();
woken = true;
});
KJ_EXPECT(!port.poll());
while (!woken) Sleep(10);
KJ_EXPECT(port.poll());
}
KJ_TEST("Win32IocpEventPort timer") {
Win32IocpEventPort port;
EventLoop loop(port);
WaitScope waitScope(loop);
auto start = port.getTimer().now();
bool done = false;
auto promise = port.getTimer().afterDelay(10 * MILLISECONDS).then([&]() {
done = true;
}).eagerlyEvaluate(nullptr);
KJ_EXPECT(!done);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
evalLater([]() {}).wait(waitScope);
KJ_EXPECT(!done);
promise.wait(waitScope);
KJ_EXPECT(done);
KJ_EXPECT(port.getTimer().now() - start >= 10 * MILLISECONDS);
}
} // namespace
} // namespace kj
#endif // _WIN32
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include "async-win32.h"
#include "debug.h"
#include <chrono>
#include "refcount.h"
#undef ERROR // dammit windows.h
namespace kj {
Win32IocpEventPort::Win32IocpEventPort()
: iocp(newIocpHandle()), thread(openCurrentThread()), timerImpl(readClock()) {}
Win32IocpEventPort::~Win32IocpEventPort() noexcept(false) {}
class Win32IocpEventPort::IoPromiseAdapter final: public OVERLAPPED {
public:
IoPromiseAdapter(PromiseFulfiller<IoResult>& fulfiller, Win32IocpEventPort& port,
uint64_t offset, IoPromiseAdapter** selfPtr)
: fulfiller(fulfiller), port(port) {
*selfPtr = this;
memset(implicitCast<OVERLAPPED*>(this), 0, sizeof(OVERLAPPED));
this->Offset = offset & 0x00000000FFFFFFFFull;
this->OffsetHigh = offset >> 32;
}
~IoPromiseAdapter() {
if (handle != INVALID_HANDLE_VALUE) {
// Need to cancel the I/O.
//
// Note: Even if HasOverlappedIoCompleted(this) is true, CancelIoEx() still seems needed to
// force the completion event.
if (!CancelIoEx(handle, this)) {
DWORD error = GetLastError();
// ERROR_NOT_FOUND probably means the operation already completed and is enqueued on the
// IOCP.
//
// ERROR_INVALID_HANDLE probably means that, amid a mass of destructors, the HANDLE was
// closed before all of the I/O promises were destroyed. We tolerate this so long as the
// I/O promises are also destroyed before returning to the event loop, hence the I/O
// tasks won't actually continue on a dead handle.
//
// TODO(cleanup): ERROR_INVALID_HANDLE really shouldn't be allowed. Unfortunately, the
// refcounted nature of capabilities and the RPC system seems to mean that objects
// are unwound in the wrong order in several of Cap'n Proto's tests. So we live with this
// for now. Note that even if a new handle is opened with the same numeric value, it
// should be hardless to call CancelIoEx() on it because it couldn't possibly be using
// the same OVERLAPPED structure.
if (error != ERROR_NOT_FOUND && error != ERROR_INVALID_HANDLE) {
KJ_FAIL_WIN32("CancelIoEx()", error, handle);
}
}
// We have to wait for the IOCP to poop out the event, so that we can safely destroy the
// OVERLAPPED.
while (handle != INVALID_HANDLE_VALUE) {
port.waitIocp(INFINITE);
}
}
}
void start(HANDLE handle) {
KJ_ASSERT(this->handle == INVALID_HANDLE_VALUE);
this->handle = handle;
}
void done(IoResult result) {
KJ_ASSERT(handle != INVALID_HANDLE_VALUE);
handle = INVALID_HANDLE_VALUE;
fulfiller.fulfill(kj::mv(result));
}
private:
PromiseFulfiller<IoResult>& fulfiller;
Win32IocpEventPort& port;
HANDLE handle = INVALID_HANDLE_VALUE;
// If an I/O operation is currently enqueued, the handle on which it is enqueued.
};
class Win32IocpEventPort::IoOperationImpl final: public Win32EventPort::IoOperation {
public:
explicit IoOperationImpl(Win32IocpEventPort& port, HANDLE handle, uint64_t offset)
: handle(handle),
promise(newAdaptedPromise<IoResult, IoPromiseAdapter>(port, offset, &promiseAdapter)) {}
LPOVERLAPPED getOverlapped() override {
KJ_REQUIRE(promiseAdapter != nullptr, "already called onComplete()");
return promiseAdapter;
}
Promise<IoResult> onComplete() override {
KJ_REQUIRE(promiseAdapter != nullptr, "can only call onComplete() once");
promiseAdapter->start(handle);
promiseAdapter = nullptr;
return kj::mv(promise);
}
private:
HANDLE handle;
IoPromiseAdapter* promiseAdapter;
Promise<IoResult> promise;
};
class Win32IocpEventPort::IoObserverImpl final: public Win32EventPort::IoObserver {
public:
IoObserverImpl(Win32IocpEventPort& port, HANDLE handle)
: port(port), handle(handle) {
KJ_WIN32(CreateIoCompletionPort(handle, port.iocp, 0, 1), handle, port.iocp.get());
}
Own<IoOperation> newOperation(uint64_t offset) {
return heap<IoOperationImpl>(port, handle, offset);
}
private:
Win32IocpEventPort& port;
HANDLE handle;
};
Own<Win32EventPort::IoObserver> Win32IocpEventPort::observeIo(HANDLE handle) {
return heap<IoObserverImpl>(*this, handle);
}
Own<Win32EventPort::SignalObserver> Win32IocpEventPort::observeSignalState(HANDLE handle) {
return waitThreads.observeSignalState(handle);
}
TimePoint Win32IocpEventPort::readClock() {
return origin<TimePoint>() + std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count() * NANOSECONDS;
}
bool Win32IocpEventPort::wait() {
waitIocp(timerImpl.timeoutToNextEvent(readClock(), MILLISECONDS, INFINITE - 1)
.map([](uint64_t t) -> DWORD { return t; })
.orDefault(INFINITE));
timerImpl.advanceTo(readClock());
return receivedWake();
}
bool Win32IocpEventPort::poll() {
waitIocp(0);
return receivedWake();
}
void Win32IocpEventPort::wake() const {
if (!__atomic_load_n(&sentWake, __ATOMIC_ACQUIRE)) {
__atomic_store_n(&sentWake, true, __ATOMIC_RELEASE);
KJ_WIN32(PostQueuedCompletionStatus(iocp, 0, 0, nullptr));
}
}
void Win32IocpEventPort::waitIocp(DWORD timeoutMs) {
DWORD bytesTransferred;
ULONG_PTR completionKey;
LPOVERLAPPED overlapped = nullptr;
// TODO(someday): Should we use GetQueuedCompletionStatusEx()? It would allow us to read multiple
// events in one call and would let us wait in an alertable state, which would allow users to
// use APCs. However, it currently isn't implemented on Wine (as of 1.9.22).
BOOL success = GetQueuedCompletionStatus(
iocp, &bytesTransferred, &completionKey, &overlapped, timeoutMs);
if (overlapped == nullptr) {
if (success) {
// wake() called in another thread.
} else {
DWORD error = GetLastError();
if (error == WAIT_TIMEOUT) {
// Great, nothing to do. (Why this is WAIT_TIMEOUT and not ERROR_TIMEOUT I'm not sure.)
} else {
KJ_FAIL_WIN32("GetQueuedCompletionStatus()", error, error, overlapped);
}
}
} else {
DWORD error = success ? ERROR_SUCCESS : GetLastError();
static_cast<IoPromiseAdapter*>(overlapped)->done(IoResult { error, bytesTransferred });
}
}
bool Win32IocpEventPort::receivedWake() {
if (__atomic_load_n(&sentWake, __ATOMIC_ACQUIRE)) {
__atomic_store_n(&sentWake, false, __ATOMIC_RELEASE);
return true;
} else {
return false;
}
}
AutoCloseHandle Win32IocpEventPort::newIocpHandle() {
HANDLE h;
KJ_WIN32(h = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 1));
return AutoCloseHandle(h);
}
AutoCloseHandle Win32IocpEventPort::openCurrentThread() {
HANDLE process = GetCurrentProcess();
HANDLE result;
KJ_WIN32(DuplicateHandle(process, GetCurrentThread(), process, &result,
0, FALSE, DUPLICATE_SAME_ACCESS));
return AutoCloseHandle(result);
}
// =======================================================================================
Win32WaitObjectThreadPool::Win32WaitObjectThreadPool(uint mainThreadCount) {}
Own<Win32EventPort::SignalObserver> Win32WaitObjectThreadPool::observeSignalState(HANDLE handle) {
KJ_UNIMPLEMENTED("wait for win32 handles");
}
uint Win32WaitObjectThreadPool::prepareMainThreadWait(HANDLE* handles[]) {
KJ_UNIMPLEMENTED("wait for win32 handles");
}
bool Win32WaitObjectThreadPool::finishedMainThreadWait(DWORD returnCode) {
KJ_UNIMPLEMENTED("wait for win32 handles");
}
} // namespace kj
#endif // _WIN32
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef KJ_ASYNC_WIN32_H_
#define KJ_ASYNC_WIN32_H_
#if !_WIN32
#error "This file is Windows-specific. On Unix, include async-unix.h instead."
#endif
#include "async.h"
#include "time.h"
#include "io.h"
#include <inttypes.h>
// Include windows.h as lean as possible. (If you need more of the Windows API for your app,
// #include windows.h yourself before including this header.)
#define WIN32_LEAN_AND_MEAN 1
#define NOSERVICE 1
#define NOMCX 1
#define NOIME 1
#include <windows.h>
#include "windows-sanity.h"
namespace kj {
class Win32EventPort: public EventPort {
// Abstract base interface for EventPorts that can listen on Win32 event types. Due to the
// absurd complexity of the Win32 API, it's not possible to standardize on a single
// implementation of EventPort. In particular, there is no way for a single thread to use I/O
// completion ports (the most efficient way of handling I/O) while at the same time waiting for
// signalable handles or UI messages.
//
// Note that UI messages are not supported at all by this interface because the message queue
// is implemented by user32.dll and we want libkj to depend only on kernel32.dll. A separate
// compat library could provide a Win32EventPort implementation that works with the UI message
// queue.
public:
// ---------------------------------------------------------------------------
// overlapped I/O
struct IoResult {
DWORD errorCode;
DWORD bytesTransferred;
};
class IoOperation {
public:
virtual LPOVERLAPPED getOverlapped() = 0;
// Gets the OVERLAPPED structure to pass to the Win32 I/O call. Do NOT modify it; just pass it
// on.
virtual Promise<IoResult> onComplete() = 0;
// After making the Win32 call, if the return value indicates that the operation was
// successfully queued (i.e. the completion event will definitely occur), call this to wait
// for completion.
//
// You MUST call this if the operation was successfully queued, and you MUST NOT call this
// otherwise. If the Win32 call failed (without queuing any operation or event) then you should
// simply drop the IoOperation object.
//
// Dropping the returned Promise cancels the operation via Win32's CancelIoEx(). The destructor
// will wait for the cancellation to complete, such that after dropping the proimse it is safe
// to free the buffer that the operation was reading from / writing to.
//
// You may safely drop the `IoOperation` while still waiting for this promise. You may not,
// however, drop the `IoObserver`.
};
class IoObserver {
public:
virtual Own<IoOperation> newOperation(uint64_t offset) = 0;
// Begin an I/O operation. For file operations, `offset` is the offset within the file at
// which the operation will start. For stream operations, `offset` is ignored.
};
virtual Own<IoObserver> observeIo(HANDLE handle) = 0;
// Given a handle which supports overlapped I/O, arrange to receive I/O completion events via
// this EventPort.
//
// Different Win32EventPort implementations may handle this in different ways, such as by using
// completion routines (APCs) or by using I/O completion ports. The caller should not assume
// any particular technique.
//
// WARNING: It is only safe to call observeIo() on a particular handle once during its lifetime.
// You cannot observe the same handle from multiple Win32EventPorts, even if not at the same
// time. This is because the Win32 API provides no way to disassociate a handle from an I/O
// completion port once it is associated.
// ---------------------------------------------------------------------------
// signalable handles
//
// Warning: Due to limitations in the Win32 API, implementations of EventPort may be forced to
// spawn additional threads to wait for signaled objects. This is necessary if the EventPort
// implementation is based on I/O completion ports, or if you need to wait on more than 64
// handles at once.
class SignalObserver {
public:
virtual Promise<void> onSignaled() = 0;
// Returns a promise that completes the next time the handle enters the signaled state.
//
// Depending on the type of handle, the handle may automatically be reset to a non-signaled
// state before the promise resolves. The underlying implementaiton uses WaitForSingleObject()
// or an equivalent wait call, so check the documentation for that to understand the semantics.
//
// If the handle is a mutex and it is abandoned without being unlocked, the promise breaks with
// an exception.
virtual Promise<bool> onSignaledOrAbandoned() = 0;
// Like onSingaled(), but instead of throwing when a mutex is abandoned, resolves to `true`.
// Resolves to `false` for non-abandoned signals.
};
virtual Own<SignalObserver> observeSignalState(HANDLE handle) = 0;
// Given a handle that supports waiting for it to become "signaled" via WaitForSingleObject(),
// return an object that can wait for this state using the EventPort.
// ---------------------------------------------------------------------------
// time
virtual Timer& getTimer() = 0;
};
class Win32WaitObjectThreadPool {
// Helper class that implements Win32EventPort::observeSignalState() by spawning additional
// threads as needed to perform the actual waiting.
//
// This class is intended to be used to assist in building Win32EventPort implementations.
public:
Win32WaitObjectThreadPool(uint mainThreadCount = 0);
// `mainThreadCount` indicates the number of objects the main thread is able to listen on
// directly. Typically this would be zero (e.g. if the main thread watches an I/O completion
// port) or MAXIMUM_WAIT_OBJECTS (e.g. if the main thread is a UI thread but can use
// MsgWaitForMultipleObjectsEx() to wait on some handles at the same time as messages).
Own<Win32EventPort::SignalObserver> observeSignalState(HANDLE handle);
// Implemetns Win32EventPort::observeSignalState().
uint prepareMainThreadWait(HANDLE* handles[]);
// Call immediately before invoking WaitForMultipleObjects() or similar in the main thread.
// Fills in `handles` with the handle pointers to wait on, and returns the number of handles
// in this array. (The array should be allocated to be at least the size passed to the
// constructor).
//
// There's no need to call this if `mainThreadCount` as passed to the constructor was zero.
bool finishedMainThreadWait(DWORD returnCode);
// Call immediately after invoking WaitForMultipleObjects() or similar in the main thread,
// passing the value returend by that call. Returns true if the event indicated by `returnCode`
// has been handled (i.e. it was WAIT_OBJECT_n or WAIT_ABANDONED_n where n is in-range for the
// last call to prepareMainThreadWait()).
};
class Win32IocpEventPort final: public Win32EventPort {
// An EventPort implementation which uses Windows I/O completion ports to listen for events.
//
// With this implementation, observeSignalState() requires spawning a separate thread.
public:
Win32IocpEventPort();
~Win32IocpEventPort() noexcept(false);
// implements EventPort ------------------------------------------------------
bool wait() override;
bool poll() override;
void wake() const override;
// implements Win32IocpEventPort ---------------------------------------------
Own<IoObserver> observeIo(HANDLE handle) override;
Own<SignalObserver> observeSignalState(HANDLE handle) override;
Timer& getTimer() override { return timerImpl; }
private:
class IoPromiseAdapter;
class IoOperationImpl;
class IoObserverImpl;
AutoCloseHandle iocp;
AutoCloseHandle thread;
Win32WaitObjectThreadPool waitThreads;
TimerImpl timerImpl;
mutable bool sentWake = false;
static TimePoint readClock();
void waitIocp(DWORD timeoutMs);
// Wait on the I/O completion port for up to timeoutMs and pump events. Does not advance the
// timer; caller must do that.
bool receivedWake();
static AutoCloseHandle newIocpHandle();
static AutoCloseHandle openCurrentThread();
};
} // namespace kj
#endif // KJ_ASYNC_WIN32_H_
......@@ -27,6 +27,13 @@
#if _WIN32
#define strerror_r(errno,buf,len) strerror_s(buf,len,errno)
#define NOMINMAX 1
#define WIN32_LEAN_AND_MEAN 1
#define NOSERVICE 1
#define NOMCX 1
#define NOIME 1
#include <windows.h>
#include "windows-sanity.h"
#endif
namespace kj {
......@@ -125,6 +132,38 @@ Exception::Type typeOfErrno(int error) {
}
}
#if _WIN32
Exception::Type typeOfWin32Error(DWORD error) {
switch (error) {
// TODO(now): This needs more work.
case WSAETIMEDOUT:
return Exception::Type::OVERLOADED;
case WSAENOTCONN:
case WSAECONNABORTED:
case WSAECONNREFUSED:
case WSAECONNRESET:
case WSAEHOSTDOWN:
case WSAEHOSTUNREACH:
case WSAENETDOWN:
case WSAENETRESET:
case WSAENETUNREACH:
return Exception::Type::DISCONNECTED;
case WSAEOPNOTSUPP:
case WSAENOPROTOOPT:
case WSAENOTSOCK: // This is really saying "syscall not implemented for non-sockets".
return Exception::Type::UNIMPLEMENTED;
default:
return Exception::Type::FAILED;
}
}
#endif // _WIN32
enum DescriptionStyle {
LOG,
ASSERTION,
......@@ -132,7 +171,8 @@ enum DescriptionStyle {
};
static String makeDescriptionImpl(DescriptionStyle style, const char* code, int errorNumber,
const char* macroArgs, ArrayPtr<String> argValues) {
const char* sysErrorString, const char* macroArgs,
ArrayPtr<String> argValues) {
KJ_STACK_ARRAY(ArrayPtr<const char>, argNames, argValues.size(), 8, 64);
if (argValues.size() > 0) {
......@@ -205,13 +245,21 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int
#if __USE_GNU
char buffer[256];
if (style == SYSCALL) {
sysErrorArray = strerror_r(errorNumber, buffer, sizeof(buffer));
if (sysErrorString == nullptr) {
sysErrorArray = strerror_r(errorNumber, buffer, sizeof(buffer));
} else {
sysErrorArray = sysErrorString;
}
}
#else
char buffer[256];
if (style == SYSCALL) {
strerror_r(errorNumber, buffer, sizeof(buffer));
sysErrorArray = buffer;
if (sysErrorString == nullptr) {
strerror_r(errorNumber, buffer, sizeof(buffer));
sysErrorArray = buffer;
} else {
sysErrorArray = sysErrorString;
}
}
#endif
......@@ -250,6 +298,7 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int
pos = _::fill(pos, codeArray, colon, sysErrorArray);
break;
}
for (size_t i = 0; i < argValues.size(); i++) {
if (i > 0 || style != LOG) {
pos = _::fill(pos, delim);
......@@ -269,7 +318,7 @@ static String makeDescriptionImpl(DescriptionStyle style, const char* code, int
void Debug::logInternal(const char* file, int line, LogSeverity severity, const char* macroArgs,
ArrayPtr<String> argValues) {
getExceptionCallback().logMessage(severity, trimSourceFilename(file).cStr(), line, 0,
makeDescriptionImpl(LOG, nullptr, 0, macroArgs, argValues));
makeDescriptionImpl(LOG, nullptr, 0, nullptr, macroArgs, argValues));
}
Debug::Fault::~Fault() noexcept(false) {
......@@ -292,18 +341,46 @@ void Debug::Fault::init(
const char* file, int line, Exception::Type type,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
exception = new Exception(type, file, line,
makeDescriptionImpl(ASSERTION, condition, 0, macroArgs, argValues));
makeDescriptionImpl(ASSERTION, condition, 0, nullptr, macroArgs, argValues));
}
void Debug::Fault::init(
const char* file, int line, int osErrorNumber,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
exception = new Exception(typeOfErrno(osErrorNumber), file, line,
makeDescriptionImpl(SYSCALL, condition, osErrorNumber, macroArgs, argValues));
makeDescriptionImpl(SYSCALL, condition, osErrorNumber, nullptr, macroArgs, argValues));
}
#if _WIN32
void Debug::Fault::init(
const char* file, int line, Win32Error osErrorNumber,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
LPVOID ptr;
// TODO(now): Use FormatMessageW() instead.
// TODO(now): Why doesn't this work for winsock errors?
DWORD result = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, osErrorNumber.number,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
(LPTSTR) &ptr, 0, NULL);
if (result > 0) {
KJ_DEFER(LocalFree(ptr));
exception = new Exception(typeOfWin32Error(osErrorNumber.number), file, line,
makeDescriptionImpl(SYSCALL, condition, 0, reinterpret_cast<char*>(ptr),
macroArgs, argValues));
} else {
auto message = kj::str("win32 error code: ", osErrorNumber.number);
exception = new Exception(typeOfWin32Error(osErrorNumber.number), file, line,
makeDescriptionImpl(SYSCALL, condition, 0, message.cStr(),
macroArgs, argValues));
}
}
#endif
String Debug::makeDescriptionInternal(const char* macroArgs, ArrayPtr<String> argValues) {
return makeDescriptionImpl(LOG, nullptr, 0, macroArgs, argValues);
return makeDescriptionImpl(LOG, nullptr, 0, nullptr, macroArgs, argValues);
}
int Debug::getOsErrorNumber(bool nonblocking) {
......@@ -316,6 +393,12 @@ int Debug::getOsErrorNumber(bool nonblocking) {
: result;
}
#if _WIN32
Debug::Win32Error Debug::getWin32Error() {
return Win32Error(::GetLastError());
}
#endif
Debug::Context::Context(): logged(false) {}
Debug::Context::~Context() noexcept(false) {}
......
......@@ -165,6 +165,24 @@ namespace kj {
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
errorNumber, code, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal())
#if _WIN32
#define KJ_WIN32(call, ...) \
if (::kj::_::Debug::isWin32Success(call)) {} else \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::getWin32Error(), #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal())
#define KJ_WINSOCK(call, ...) \
if ((call) != SOCKET_ERROR) {} else \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::getWin32Error(), #call, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal())
#define KJ_FAIL_WIN32(code, errorNumber, ...) \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::Win32Error(errorNumber), code, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal())
#endif
#define KJ_UNIMPLEMENTED(...) \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::UNIMPLEMENTED, \
nullptr, "" #__VA_ARGS__, __VA_ARGS__);; f.fatal())
......@@ -223,6 +241,24 @@ namespace kj {
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
errorNumber, code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#if _WIN32
#define KJ_WIN32(call, ...) \
if (::kj::_::Debug::isWin32Success(call)) {} else \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::getWin32Error(), #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#define KJ_WINSOCK(call, ...) \
if ((call) != SOCKET_ERROR) {} else \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::getWin32Error(), #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#define KJ_FAIL_WIN32(code, errorNumber, ...) \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, \
::kj::_::Debug::Win32Error(errorNumber), code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#endif
#define KJ_UNIMPLEMENTED(...) \
for (::kj::_::Debug::Fault f(__FILE__, __LINE__, ::kj::Exception::Type::UNIMPLEMENTED, \
nullptr, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
......@@ -275,6 +311,14 @@ public:
typedef LogSeverity Severity; // backwards-compatibility
#if _WIN32
struct Win32Error {
// Hack for overloading purposes.
uint number;
inline explicit Win32Error(uint number): number(number) {}
};
#endif
static inline bool shouldLog(LogSeverity severity) { return severity >= minSeverity; }
// Returns whether messages of the given severity should be logged.
......@@ -289,16 +333,17 @@ public:
class Fault {
public:
template <typename... Params>
Fault(const char* file, int line, Exception::Type type,
const char* condition, const char* macroArgs, Params&&... params);
template <typename... Params>
Fault(const char* file, int line, int osErrorNumber,
template <typename Code, typename... Params>
Fault(const char* file, int line, Code code,
const char* condition, const char* macroArgs, Params&&... params);
Fault(const char* file, int line, Exception::Type type,
const char* condition, const char* macroArgs);
Fault(const char* file, int line, int osErrorNumber,
const char* condition, const char* macroArgs);
#if _WIN32
Fault(const char* file, int line, Win32Error osErrorNumber,
const char* condition, const char* macroArgs);
#endif
~Fault() noexcept(false);
KJ_NOINLINE KJ_NORETURN(void fatal());
......@@ -309,6 +354,10 @@ public:
const char* condition, const char* macroArgs, ArrayPtr<String> argValues);
void init(const char* file, int line, int osErrorNumber,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues);
#if _WIN32
void init(const char* file, int line, Win32Error osErrorNumber,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues);
#endif
Exception* exception;
};
......@@ -326,6 +375,12 @@ public:
template <typename Call>
static SyscallResult syscall(Call&& call, bool nonblocking);
#if _WIN32
static bool isWin32Success(int boolean);
static bool isWin32Success(void* handle);
static Win32Error getWin32Error();
#endif
class Context: public ExceptionCallback {
public:
Context();
......@@ -394,21 +449,12 @@ inline void Debug::log<>(const char* file, int line, LogSeverity severity, const
logInternal(file, line, severity, macroArgs, nullptr);
}
template <typename... Params>
Debug::Fault::Fault(const char* file, int line, Exception::Type type,
template <typename Code, typename... Params>
Debug::Fault::Fault(const char* file, int line, Code code,
const char* condition, const char* macroArgs, Params&&... params)
: exception(nullptr) {
String argValues[sizeof...(Params)] = {str(params)...};
init(file, line, type, condition, macroArgs,
arrayPtr(argValues, sizeof...(Params)));
}
template <typename... Params>
Debug::Fault::Fault(const char* file, int line, int osErrorNumber,
const char* condition, const char* macroArgs, Params&&... params)
: exception(nullptr) {
String argValues[sizeof...(Params)] = {str(params)...};
init(file, line, osErrorNumber, condition, macroArgs,
init(file, line, code, condition, macroArgs,
arrayPtr(argValues, sizeof...(Params)));
}
......@@ -424,6 +470,22 @@ inline Debug::Fault::Fault(const char* file, int line, kj::Exception::Type type,
init(file, line, type, condition, macroArgs, nullptr);
}
#if _WIN32
inline Debug::Fault::Fault(const char* file, int line, Win32Error osErrorNumber,
const char* condition, const char* macroArgs)
: exception(nullptr) {
init(file, line, osErrorNumber, condition, macroArgs, nullptr);
}
inline bool Debug::isWin32Success(int boolean) {
return boolean;
}
inline bool Debug::isWin32Success(void* handle) {
// Assume null and INVALID_HANDLE_VALUE mean failure.
return handle != nullptr && handle != (void*)-1;
}
#endif
template <typename Call>
Debug::SyscallResult Debug::syscall(Call&& call, bool nonblocking) {
while (call() < 0) {
......
......@@ -38,6 +38,13 @@
#include <execinfo.h>
#endif
#if _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include "windows-sanity.h"
#include <dbghelp.h>
#endif
#if (__linux__ || __APPLE__) && defined(KJ_DEBUG)
#include <stdio.h>
#include <pthread.h>
......@@ -57,19 +64,142 @@ StringPtr KJ_STRINGIFY(LogSeverity severity) {
return SEVERITY_STRINGS[static_cast<uint>(severity)];
}
#if _WIN32 && _M_X64
// Currently the Win32 stack-trace code only supports x86_64. We could easily extend it to support
// i386 as well but it requires some code changes around how we read the context to start the
// trace.
namespace {
struct Dbghelp {
// Load dbghelp.dll dynamically since we don't really need it, it's just for debugging.
HINSTANCE lib;
WINBOOL WINAPI (*symInitialize)(HANDLE hProcess,PCSTR UserSearchPath,WINBOOL fInvadeProcess);
WINBOOL WINAPI (*stackWalk64)(
DWORD MachineType,HANDLE hProcess,HANDLE hThread,
LPSTACKFRAME64 StackFrame,PVOID ContextRecord,
PREAD_PROCESS_MEMORY_ROUTINE64 ReadMemoryRoutine,
PFUNCTION_TABLE_ACCESS_ROUTINE64 FunctionTableAccessRoutine,
PGET_MODULE_BASE_ROUTINE64 GetModuleBaseRoutine,
PTRANSLATE_ADDRESS_ROUTINE64 TranslateAddress);
PVOID WINAPI (*symFunctionTableAccess64)(HANDLE hProcess,DWORD64 AddrBase);
DWORD64 WINAPI (*symGetModuleBase64)(HANDLE hProcess,DWORD64 qwAddr);
WINBOOL WINAPI (*symGetLineFromAddr64)(
HANDLE hProcess,DWORD64 qwAddr,PDWORD pdwDisplacement,PIMAGEHLP_LINE64 Line64);
Dbghelp()
: lib(LoadLibraryA("dbghelp.dll")),
symInitialize(lib == nullptr ? nullptr :
reinterpret_cast<decltype(symInitialize)>(
GetProcAddress(lib, "SymInitialize"))),
stackWalk64(symInitialize == nullptr ? nullptr :
reinterpret_cast<decltype(stackWalk64)>(
GetProcAddress(lib, "StackWalk64"))),
symFunctionTableAccess64(symInitialize == nullptr ? nullptr :
reinterpret_cast<decltype(symFunctionTableAccess64)>(
GetProcAddress(lib, "SymFunctionTableAccess64"))),
symGetModuleBase64(symInitialize == nullptr ? nullptr :
reinterpret_cast<decltype(symGetModuleBase64)>(
GetProcAddress(lib, "SymGetModuleBase64"))),
symGetLineFromAddr64(symInitialize == nullptr ? nullptr :
reinterpret_cast<decltype(symGetLineFromAddr64)>(
GetProcAddress(lib, "SymGetLineFromAddr64"))) {
if (symInitialize != nullptr) {
symInitialize(GetCurrentProcess(), NULL, TRUE);
}
}
};
const Dbghelp& getDbghelp() {
static Dbghelp dbghelp;
return dbghelp;
}
ArrayPtr<void* const> getStackTrace(ArrayPtr<void*> space, uint ignoreCount,
HANDLE thread, CONTEXT& context) {
const Dbghelp& dbghelp = getDbghelp();
if (dbghelp.stackWalk64 == nullptr ||
dbghelp.symFunctionTableAccess64 == nullptr ||
dbghelp.symGetModuleBase64 == nullptr) {
return nullptr;
}
STACKFRAME64 frame;
memset(&frame, 0, sizeof(frame));
frame.AddrPC.Offset = context.Rip;
frame.AddrPC.Mode = AddrModeFlat;
frame.AddrStack.Offset = context.Rsp;
frame.AddrStack.Mode = AddrModeFlat;
frame.AddrFrame.Offset = context.Rbp;
frame.AddrFrame.Mode = AddrModeFlat;
HANDLE process = GetCurrentProcess();
uint count = 0;
for (; count < space.size(); count++) {
if (!dbghelp.stackWalk64(IMAGE_FILE_MACHINE_AMD64, process, thread,
&frame, &context, NULL, dbghelp.symFunctionTableAccess64,
dbghelp.symGetModuleBase64, NULL)){
break;
}
space[count] = reinterpret_cast<void*>(frame.AddrPC.Offset);
}
return space.slice(kj::min(ignoreCount, count), count);
}
} // namespace
#endif
ArrayPtr<void* const> getStackTrace(ArrayPtr<void*> space, uint ignoreCount) {
#ifndef KJ_HAS_BACKTRACE
return nullptr;
#else
#if _WIN32 && _M_X64
CONTEXT context;
RtlCaptureContext(&context);
return getStackTrace(space, ignoreCount, GetCurrentThread(), context);
#elif KJ_HAS_BACKTRACE
size_t size = backtrace(space.begin(), space.size());
return space.slice(kj::min(ignoreCount + 1, size), size);
#else
return nullptr;
#endif
}
String stringifyStackTrace(ArrayPtr<void* const> trace) {
if (trace.size() == 0) return nullptr;
#if (__linux__ || __APPLE__) && !__ANDROID__ && defined(KJ_DEBUG)
#ifndef KJ_DEBUG
return nullptr;
#elif _WIN32 && _M_X64 && _MSC_VER
// Try to get file/line using SymGetLineFromAddr64(). We don't bother if we aren't on MSVC since
// this requires MSVC debug info.
//
// TODO(someday): We could perhaps shell out to addr2line on MinGW.
const Dbghelp& dbghelp = getDbghelp();
if (dbghelp.symGetLineFromAddr64 == nullptr) return nullptr;
HANDLE process = GetCurrentProcess();
KJ_STACK_ARRAY(String, lines, trace.size(), 32, 32);
for (auto i: kj::indices(trace)) {
IMAGEHLP_LINE64 lineInfo;
memset(&lineInfo, 0, sizeof(lineInfo));
lineInfo.SizeOfStruct = sizeof(lineInfo);
if (dbghelp.symGetLineFromAddr64(process, reinterpret_cast<DWORD64>(trace[i]), NULL, &lineInfo)) {
lines[i] = kj::str('\n', lineInfo.FileName, ':', lineInfo.LineNumber);
}
}
return strArray(lines, "");
#elif (__linux__ || __APPLE__) && !__ANDROID__
// We want to generate a human-readable stack trace.
// TODO(someday): It would be really great if we could avoid farming out to another process
......@@ -144,12 +274,63 @@ String stringifyStackTrace(ArrayPtr<void* const> trace) {
pclose(p);
return strArray(arrayPtr(lines, i), "");
#else
return nullptr;
#endif
}
#if KJ_HAS_BACKTRACE
String getStackTrace() {
void* space[32];
auto trace = getStackTrace(space, 2);
return kj::str(kj::strArray(trace, " "), stringifyStackTrace(trace));
}
#if _WIN32 && _M_X64
namespace {
DWORD mainThreadId = 0;
BOOL WINAPI breakHandler(DWORD type) {
switch (type) {
case CTRL_C_EVENT:
case CTRL_BREAK_EVENT: {
HANDLE thread = OpenThread(THREAD_ALL_ACCESS, FALSE, mainThreadId);
if (thread != NULL) {
if (SuspendThread(thread) != (DWORD)-1) {
CONTEXT context;
memset(&context, 0, sizeof(context));
context.ContextFlags = CONTEXT_FULL;
if (GetThreadContext(thread, &context)) {
void* traceSpace[32];
auto trace = getStackTrace(traceSpace, 2, thread, context);
ResumeThread(thread);
auto message = kj::str("*** Received CTRL+C. stack: ", strArray(trace, " "),
stringifyStackTrace(trace), '\n');
FdOutputStream(STDERR_FILENO).write(message.begin(), message.size());
} else {
ResumeThread(thread);
}
}
CloseHandle(thread);
}
break;
}
default:
break;
}
return FALSE; // still crash
}
} // namespace
void printStackTraceOnCrash() {
mainThreadId = GetCurrentThreadId();
KJ_WIN32(SetConsoleCtrlHandler(breakHandler, TRUE));
}
#elif KJ_HAS_BACKTRACE
namespace {
void crashHandler(int signo, siginfo_t* info, void* context) {
......
......@@ -323,6 +323,9 @@ String stringifyStackTrace(ArrayPtr<void* const>);
// Convert the stack trace to a string with file names and line numbers. This may involve executing
// suprocesses.
String getStackTrace();
// Get a stack trace right now and stringify it. Useful for debugging.
void printStackTraceOnCrash();
// Registers signal handlers on common "crash" signals like SIGSEGV that will (attempt to) print
// a stack trace. You should call this as early as possible on program startup. Programs using
......
......@@ -25,7 +25,14 @@
#include <algorithm>
#include <errno.h>
#if !_WIN32
#if _WIN32
#ifndef NOMINMAX
#define NOMINMAX 1
#endif
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include "windows-sanity.h"
#else
#include <sys/uio.h>
#endif
......@@ -371,4 +378,49 @@ void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
#endif
}
// =======================================================================================
#if _WIN32
AutoCloseHandle::~AutoCloseHandle() noexcept(false) {
if (handle != (void*)-1) {
KJ_WIN32(CloseHandle(handle));
}
}
HandleInputStream::~HandleInputStream() noexcept(false) {}
size_t HandleInputStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) {
byte* pos = reinterpret_cast<byte*>(buffer);
byte* min = pos + minBytes;
byte* max = pos + maxBytes;
while (pos < min) {
DWORD n;
KJ_WIN32(ReadFile(handle, pos, kj::min(max - pos, DWORD(kj::maxValue)), &n, nullptr));
if (n == 0) {
break;
}
pos += n;
}
return pos - reinterpret_cast<byte*>(buffer);
}
HandleOutputStream::~HandleOutputStream() noexcept(false) {}
void HandleOutputStream::write(const void* buffer, size_t size) {
const char* pos = reinterpret_cast<const char*>(buffer);
while (size > 0) {
DWORD n;
KJ_WIN32(WriteFile(handle, pos, kj::min(size, DWORD(kj::maxValue)), &n, nullptr));
KJ_ASSERT(n > 0, "write() returned zero.");
pos += n;
size -= n;
}
}
#endif // _WIN32
} // namespace kj
......@@ -326,6 +326,90 @@ private:
AutoCloseFd autoclose;
};
// =======================================================================================
// Win32 Handle I/O
#ifdef _WIN32
class AutoCloseHandle {
// A wrapper around a Win32 HANDLE which automatically closes the handle when destroyed.
// The wrapper supports move construction for transferring ownership of the handle. If
// CloseHandle() returns an error, the destructor throws an exception, UNLESS the destructor is
// being called during unwind from another exception, in which case the close error is ignored.
//
// If your code is not exception-safe, you should not use AutoCloseHandle. In this case you will
// have to call close() yourself and handle errors appropriately.
public:
inline AutoCloseHandle(): handle((void*)-1) {}
inline AutoCloseHandle(decltype(nullptr)): handle((void*)-1) {}
inline explicit AutoCloseHandle(void* handle): handle(handle) {}
inline AutoCloseHandle(AutoCloseHandle&& other) noexcept: handle(other.handle) {
other.handle = (void*)-1;
}
KJ_DISALLOW_COPY(AutoCloseHandle);
~AutoCloseHandle() noexcept(false);
inline AutoCloseHandle& operator=(AutoCloseHandle&& other) {
AutoCloseHandle old(kj::mv(*this));
handle = other.handle;
other.handle = (void*)-1;
return *this;
}
inline AutoCloseHandle& operator=(decltype(nullptr)) {
AutoCloseHandle old(kj::mv(*this));
return *this;
}
inline operator void*() const { return handle; }
inline void* get() const { return handle; }
operator bool() const = delete;
// Deleting this operator prevents accidental use in boolean contexts, which
// the void* conversion operator above would otherwise allow.
inline bool operator==(decltype(nullptr)) { return handle != (void*)-1; }
inline bool operator!=(decltype(nullptr)) { return handle == (void*)-1; }
private:
void* handle; // -1 (aka INVALID_HANDLE_VALUE) if not valid.
};
class HandleInputStream: public InputStream {
// An InputStream wrapping a Win32 HANDLE.
public:
explicit HandleInputStream(void* handle): handle(handle) {}
explicit HandleInputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {}
KJ_DISALLOW_COPY(HandleInputStream);
~HandleInputStream() noexcept(false);
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
private:
void* handle;
AutoCloseHandle autoclose;
};
class HandleOutputStream: public OutputStream {
// An OutputStream wrapping a Win32 HANDLE.
public:
explicit HandleOutputStream(void* handle): handle(handle) {}
explicit HandleOutputStream(AutoCloseHandle handle): handle(handle), autoclose(mv(handle)) {}
KJ_DISALLOW_COPY(HandleOutputStream);
~HandleOutputStream() noexcept(false);
void write(const void* buffer, size_t size) override;
private:
void* handle;
AutoCloseHandle autoclose;
};
#endif // _WIN32
} // namespace kj
#endif // KJ_IO_H_
......@@ -30,9 +30,12 @@
#include <limits.h>
#if _WIN32
#define NOMINMAX
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX 1
#endif
#include <windows.h>
#undef NOMINMAX
#include "windows-sanity.h"
#else
#include <sys/uio.h>
#endif
......
......@@ -293,12 +293,24 @@ void RemoveE0(char* buffer) {
// Remove redundant leading 0's after an e, e.g. 1e012. Seems to appear on
// Windows.
for (;;) {
buffer = strstr(buffer, "e0");
if (buffer == NULL || buffer[2] < '0' || buffer[2] > '9') {
return;
}
memmove(buffer + 1, buffer + 2, strlen(buffer + 2) + 1);
// Find and skip 'e'.
char* ptr = strchr(buffer, 'e');
if (ptr == nullptr) return;
++ptr;
// Skip '-'.
if (*ptr == '-') ++ptr;
// Skip '0's.
char* ptr2 = ptr;
while (*ptr2 == '0') ++ptr2;
// If we went past the last digit, back up one.
if (*ptr2 < '0' || *ptr2 > '9') --ptr2;
// Move bytes backwards.
if (ptr2 > ptr) {
memmove(ptr, ptr2, strlen(ptr2) + 1);
}
}
#endif
......@@ -398,6 +410,9 @@ char* FloatToBuffer(float value, char* buffer) {
DelocalizeRadix(buffer);
RemovePlus(buffer);
#if _WIN32
RemoveE0(buffer);
#endif // _WIN32
return buffer;
}
......
......@@ -22,6 +22,7 @@
#include "time.h"
#include "debug.h"
#include <set>
namespace kj {
......@@ -29,4 +30,96 @@ kj::Exception Timer::makeTimeoutException() {
return KJ_EXCEPTION(OVERLOADED, "operation timed out");
}
struct TimerImpl::Impl {
struct TimerBefore {
bool operator()(TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs);
};
using Timers = std::multiset<TimerPromiseAdapter*, TimerBefore>;
Timers timers;
};
class TimerImpl::TimerPromiseAdapter {
public:
TimerPromiseAdapter(PromiseFulfiller<void>& fulfiller, TimerImpl::Impl& impl, TimePoint time)
: time(time), fulfiller(fulfiller), impl(impl) {
pos = impl.timers.insert(this);
}
~TimerPromiseAdapter() {
if (pos != impl.timers.end()) {
impl.timers.erase(pos);
}
}
void fulfill() {
fulfiller.fulfill();
impl.timers.erase(pos);
pos = impl.timers.end();
}
const TimePoint time;
private:
PromiseFulfiller<void>& fulfiller;
TimerImpl::Impl& impl;
Impl::Timers::const_iterator pos;
};
inline bool TimerImpl::Impl::TimerBefore::operator()(
TimerPromiseAdapter* lhs, TimerPromiseAdapter* rhs) {
return lhs->time < rhs->time;
}
Promise<void> TimerImpl::atTime(TimePoint time) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time);
}
Promise<void> TimerImpl::afterDelay(Duration delay) {
return newAdaptedPromise<void, TimerPromiseAdapter>(*impl, time + delay);
}
TimerImpl::TimerImpl(TimePoint startTime)
: time(startTime), impl(heap<Impl>()) {}
TimerImpl::~TimerImpl() noexcept(false) {}
Maybe<TimePoint> TimerImpl::nextEvent() {
auto iter = impl->timers.begin();
if (iter == impl->timers.end()) {
return nullptr;
} else {
return (*iter)->time;
}
}
Maybe<uint64_t> TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max) {
return nextEvent().map([&](TimePoint nextTime) -> uint64_t {
if (nextTime <= start) return 0;
Duration timeout = nextTime - start;
uint64_t result = timeout / unit;
bool roundUp = timeout % unit > 0 * SECONDS;
if (result >= max) {
return max;
} else {
return result + roundUp;
}
});
}
void TimerImpl::advanceTo(TimePoint newTime) {
KJ_REQUIRE(newTime >= time, "can't advance backwards in time") { return; }
time = newTime;
for (;;) {
auto front = impl->timers.begin();
if (front == impl->timers.end() || (*front)->time > time) {
break;
}
(*front)->fulfill();
}
}
} // namespace kj
......@@ -97,6 +97,49 @@ private:
static kj::Exception makeTimeoutException();
};
class TimerImpl final: public Timer {
// Implementation of Timer that expects an external caller -- usually, the EventPort
// implementation -- to tell it when time has advanced.
public:
TimerImpl(TimePoint startTime);
~TimerImpl() noexcept(false);
Maybe<TimePoint> nextEvent();
// Returns the time at which the next scheduled timer event will occur, or null if no timer
// events are scheduled.
Maybe<uint64_t> timeoutToNextEvent(TimePoint start, Duration unit, uint64_t max);
// Convenience method which computes a timeout value to pass to an event-waiting system call to
// cause it to time out when the next timer event occurs.
//
// `start` is the time at which the timeout starts counting. This is typically not the same as
// now() since some time may have passed since the last time advanceTo() was called.
//
// `unit` is the time unit in which the timeout is measured. This is often MILLISECONDS. Note
// that this method will fractional values *up*, to guarantee that the returned timeout waits
// until just *after* the time the event is scheduled.
//
// The timeout will be clamped to `max`. Use this to avoid an overflow if e.g. the OS wants a
// 32-bit value or a signed value.
//
// Returns nullptr if there are no future events.
void advanceTo(TimePoint newTime);
// Set the time to `time` and fire any at() events that have been passed.
// implements Timer ----------------------------------------------------------
TimePoint now() override;
Promise<void> atTime(TimePoint time) override;
Promise<void> afterDelay(Duration delay) override;
private:
struct Impl;
class TimerPromiseAdapter;
TimePoint time;
Own<Impl> impl;
};
// =======================================================================================
// inline implementation details
......@@ -114,6 +157,8 @@ Promise<T> Timer::timeoutAfter(Duration delay, Promise<T>&& promise) {
}));
}
inline TimePoint TimerImpl::now() { return time; }
} // namespace kj
#endif // KJ_TIME_H_
......@@ -264,9 +264,9 @@ public:
return value / other.value;
}
template <typename OtherNumber>
inline constexpr decltype(Number(1) % OtherNumber(1))
inline constexpr Quantity<decltype(Number(1) % OtherNumber(1)), Unit>
operator%(const Quantity<OtherNumber, Unit>& other) const {
return value % other.value;
return Quantity<decltype(Number(1) % OtherNumber(1)), Unit>(value % other.value);
}
template <typename OtherNumber, typename OtherUnit>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment