Commit 23b792a6 authored by Kenton Varda's avatar Kenton Varda

Support DNS lookup.

parent 06fd35c4
...@@ -30,12 +30,12 @@ namespace _ { ...@@ -30,12 +30,12 @@ namespace _ {
namespace { namespace {
TEST(EzRpc, Basic) { TEST(EzRpc, Basic) {
EzRpcServer server("127.0.0.1"); EzRpcServer server("localhost");
int callCount = 0; int callCount = 0;
server.exportCap("cap1", kj::heap<TestInterfaceImpl>(callCount)); server.exportCap("cap1", kj::heap<TestInterfaceImpl>(callCount));
server.exportCap("cap2", kj::heap<TestCallOrderImpl>()); server.exportCap("cap2", kj::heap<TestCallOrderImpl>());
EzRpcClient client("127.0.0.1", server.getPort().wait()); EzRpcClient client("localhost", server.getPort().wait());
auto cap = client.importCap<test::TestInterface>("cap1"); auto cap = client.importCap<test::TestInterface>("cap1");
auto request = cap.fooRequest(); auto request = cap.fooRequest();
......
...@@ -102,8 +102,8 @@ struct EzRpcClient::Impl { ...@@ -102,8 +102,8 @@ struct EzRpcClient::Impl {
Impl(kj::StringPtr serverAddress, uint defaultPort) Impl(kj::StringPtr serverAddress, uint defaultPort)
: context(EzRpcContext::getThreadLocal()), : context(EzRpcContext::getThreadLocal()),
setupPromise(context->getIoProvider().getNetwork() setupPromise(context->getIoProvider().getNetwork()
.parseRemoteAddress(serverAddress, defaultPort) .parseAddress(serverAddress, defaultPort)
.then([](kj::Own<kj::RemoteAddress>&& addr) { .then([](kj::Own<kj::NetworkAddress>&& addr) {
return addr->connect(); return addr->connect();
}).then([this](kj::Own<kj::AsyncIoStream>&& stream) { }).then([this](kj::Own<kj::AsyncIoStream>&& stream) {
clientContext = kj::heap<ClientContext>(kj::mv(stream)); clientContext = kj::heap<ClientContext>(kj::mv(stream));
...@@ -112,7 +112,7 @@ struct EzRpcClient::Impl { ...@@ -112,7 +112,7 @@ struct EzRpcClient::Impl {
Impl(struct sockaddr* serverAddress, uint addrSize) Impl(struct sockaddr* serverAddress, uint addrSize)
: context(EzRpcContext::getThreadLocal()), : context(EzRpcContext::getThreadLocal()),
setupPromise(context->getIoProvider().getNetwork() setupPromise(context->getIoProvider().getNetwork()
.getRemoteSockaddr(serverAddress, addrSize)->connect() .getSockaddr(serverAddress, addrSize)->connect()
.then([this](kj::Own<kj::AsyncIoStream>&& stream) { .then([this](kj::Own<kj::AsyncIoStream>&& stream) {
clientContext = kj::heap<ClientContext>(kj::mv(stream)); clientContext = kj::heap<ClientContext>(kj::mv(stream));
}).fork()) {} }).fork()) {}
...@@ -196,10 +196,10 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS ...@@ -196,10 +196,10 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS
auto paf = kj::newPromiseAndFulfiller<uint>(); auto paf = kj::newPromiseAndFulfiller<uint>();
portPromise = paf.promise.fork(); portPromise = paf.promise.fork();
tasks.add(context->getIoProvider().getNetwork().parseLocalAddress(bindAddress, defaultPort) tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort)
.then(kj::mvCapture(paf.fulfiller, .then(kj::mvCapture(paf.fulfiller,
[this](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller, [this](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller,
kj::Own<kj::LocalAddress>&& addr) { kj::Own<kj::NetworkAddress>&& addr) {
auto listener = addr->listen(); auto listener = addr->listen();
portFulfiller->fulfill(listener->getPort()); portFulfiller->fulfill(listener->getPort());
acceptLoop(kj::mv(listener)); acceptLoop(kj::mv(listener));
...@@ -209,7 +209,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS ...@@ -209,7 +209,7 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<Text>, public kj::TaskS
Impl(struct sockaddr* bindAddress, uint addrSize) Impl(struct sockaddr* bindAddress, uint addrSize)
: context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { : context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
auto listener = context->getIoProvider().getNetwork() auto listener = context->getIoProvider().getNetwork()
.getLocalSockaddr(bindAddress, addrSize)->listen(); .getSockaddr(bindAddress, addrSize)->listen();
portPromise = kj::Promise<uint>(listener->getPort()).fork(); portPromise = kj::Promise<uint>(listener->getPort()).fork();
acceptLoop(kj::mv(listener)); acceptLoop(kj::mv(listener));
} }
......
...@@ -42,8 +42,8 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -42,8 +42,8 @@ TEST(AsyncIo, SimpleNetwork) {
auto port = newPromiseAndFulfiller<uint>(); auto port = newPromiseAndFulfiller<uint>();
port.promise.then([&](uint portnum) { port.promise.then([&](uint portnum) {
return network.parseRemoteAddress("127.0.0.1", portnum); return network.parseAddress("localhost", portnum);
}).then([&](Own<RemoteAddress>&& result) { }).then([&](Own<NetworkAddress>&& result) {
return result->connect(); return result->connect();
}).then([&](Own<AsyncIoStream>&& result) { }).then([&](Own<AsyncIoStream>&& result) {
client = kj::mv(result); client = kj::mv(result);
...@@ -52,7 +52,7 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -52,7 +52,7 @@ TEST(AsyncIo, SimpleNetwork) {
ADD_FAILURE() << kj::str(exception).cStr(); ADD_FAILURE() << kj::str(exception).cStr();
}); });
kj::String result = network.parseLocalAddress("*").then([&](Own<LocalAddress>&& result) { kj::String result = network.parseAddress("*").then([&](Own<NetworkAddress>&& result) {
listener = result->listen(); listener = result->listen();
port.fulfiller->fulfill(listener->getPort()); port.fulfiller->fulfill(listener->getPort());
return listener->accept(); return listener->accept();
...@@ -67,28 +67,32 @@ TEST(AsyncIo, SimpleNetwork) { ...@@ -67,28 +67,32 @@ TEST(AsyncIo, SimpleNetwork) {
EXPECT_EQ("foo", result); EXPECT_EQ("foo", result);
} }
String tryParseLocal(Network& network, StringPtr text, uint portHint = 0) { String tryParse(Network& network, StringPtr text, uint portHint = 0) {
return network.parseLocalAddress(text, portHint).wait()->toString(); return network.parseAddress(text, portHint).wait()->toString();
}
String tryParseRemote(Network& network, StringPtr text, uint portHint = 0) {
return network.parseRemoteAddress(text, portHint).wait()->toString();
} }
TEST(AsyncIo, AddressParsing) { TEST(AsyncIo, AddressParsing) {
auto ioContext = setupAsyncIo(); auto ioContext = setupAsyncIo();
auto& network = ioContext.provider->getNetwork(); auto& network = ioContext.provider->getNetwork();
EXPECT_EQ("*:0", tryParseLocal(network, "*")); EXPECT_EQ("*:0", tryParse(network, "*"));
EXPECT_EQ("*:123", tryParseLocal(network, "123")); EXPECT_EQ("*:123", tryParse(network, "*:123"));
EXPECT_EQ("*:123", tryParseLocal(network, ":123")); EXPECT_EQ("[::]:123", tryParse(network, "0::0", 123));
EXPECT_EQ("[::]:123", tryParseLocal(network, "0::0", 123)); EXPECT_EQ("0.0.0.0:0", tryParse(network, "0.0.0.0"));
EXPECT_EQ("0.0.0.0:0", tryParseLocal(network, "0.0.0.0")); EXPECT_EQ("1.2.3.4:5678", tryParse(network, "1.2.3.4", 5678));
EXPECT_EQ("1.2.3.4:5678", tryParseRemote(network, "1.2.3.4", 5678)); EXPECT_EQ("[12ab:cd::34]:321", tryParse(network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("[12ab:cd::34]:321", tryParseRemote(network, "[12ab:cd:0::0:34]:321", 432));
EXPECT_EQ("unix:foo/bar/baz", tryParse(network, "unix:foo/bar/baz"));
// We can parse services by name...
EXPECT_EQ("1.2.3.4:80", tryParse(network, "1.2.3.4:http", 5678));
EXPECT_EQ("[::]:80", tryParse(network, "[::]:http", 5678));
EXPECT_EQ("[12ab:cd::34]:80", tryParse(network, "[12ab:cd::34]:http", 5678));
EXPECT_EQ("*:80", tryParse(network, "*:http", 5678));
EXPECT_EQ("unix:foo/bar/baz", tryParseLocal(network, "unix:foo/bar/baz")); // It would be nice to test DNS lookup here but the test would not be very hermetic. Even
EXPECT_EQ("unix:foo/bar/baz", tryParseRemote(network, "unix:foo/bar/baz")); // localhost can map to different addresses depending on whether IPv6 is enabled. We do
// connect to "localhost" in a different test, though.
} }
TEST(AsyncIo, OneWayPipe) { TEST(AsyncIo, OneWayPipe) {
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "async-unix.h" #include "async-unix.h"
#include "debug.h" #include "debug.h"
#include "thread.h" #include "thread.h"
#include "io.h"
#include <unistd.h> #include <unistd.h>
#include <sys/uio.h> #include <sys/uio.h>
#include <errno.h> #include <errno.h>
...@@ -36,6 +37,8 @@ ...@@ -36,6 +37,8 @@
#include <stddef.h> #include <stddef.h>
#include <stdlib.h> #include <stdlib.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h>
#include <set>
#ifndef POLLRDHUP #ifndef POLLRDHUP
// Linux-only optimization. If not available, define to 0, as this will make it a no-op. // Linux-only optimization. If not available, define to 0, as this will make it a no-op.
...@@ -62,6 +65,14 @@ void setCloseOnExec(int fd) { ...@@ -62,6 +65,14 @@ void setCloseOnExec(int fd) {
} }
} }
static constexpr uint NEW_FD_FLAGS =
#if __linux__
LowLevelAsyncIoProvider::ALREADY_CLOEXEC || LowLevelAsyncIoProvider::ALREADY_NONBLOCK ||
#endif
LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
// We always try to open FDs with CLOEXEC and NONBLOCK already set on Linux, but on other platforms
// this is not possible.
class OwnedFileDescriptor { class OwnedFileDescriptor {
public: public:
OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) { OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) {
...@@ -261,6 +272,18 @@ public: ...@@ -261,6 +272,18 @@ public:
memcpy(&addr.generic, sockaddr, len); memcpy(&addr.generic, sockaddr, len);
} }
bool operator<(const SocketAddress& other) const {
// So we can use std::set<SocketAddress>... see DNS lookup code.
if (wildcard < other.wildcard) return true;
if (wildcard > other.wildcard) return false;
if (addrlen < other.addrlen) return true;
if (addrlen > other.addrlen) return false;
return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
}
int socket(int type) const { int socket(int type) const {
int result; int result;
#if __linux__ #if __linux__
...@@ -340,7 +363,14 @@ public: ...@@ -340,7 +363,14 @@ public:
} }
} }
static SocketAddress parse(StringPtr str, uint portHint, bool requirePort = true) { static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
// TODO(someday): Allow commas in `str`.
SocketAddress result; SocketAddress result;
if (str.startsWith("unix:")) { if (str.startsWith("unix:")) {
...@@ -350,7 +380,9 @@ public: ...@@ -350,7 +380,9 @@ public:
result.addr.unixDomain.sun_family = AF_UNIX; result.addr.unixDomain.sun_family = AF_UNIX;
strcpy(result.addr.unixDomain.sun_path, path.cStr()); strcpy(result.addr.unixDomain.sun_path, path.cStr());
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
return result; auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
} }
// Try to separate the address and port. // Try to separate the address and port.
...@@ -399,16 +431,25 @@ public: ...@@ -399,16 +431,25 @@ public:
char* endptr; char* endptr;
port = strtoul(portText->cStr(), &endptr, 0); port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') { if (portText->size() == 0 || *endptr != '\0') {
KJ_FAIL_REQUIRE("Invalid IP port number.", *portText); // Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
} }
KJ_REQUIRE(port < 65536, "Port number too large."); KJ_REQUIRE(port < 65536, "Port number too large.");
} else { } else {
if (requirePort) {
KJ_REQUIRE(portHint != 0, "You must specify a port with this address.", str);
}
port = portHint; port = portHint;
} }
// Check for wildcard.
if (addrPart.size() == 1 && addrPart[0] == '*') {
result.wildcard = true;
result.addrlen = sizeof(addr.inet6);
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
void* addrTarget; void* addrTarget;
if (af == AF_INET6) { if (af == AF_INET6) {
result.addrlen = sizeof(addr.inet6); result.addrlen = sizeof(addr.inet6);
...@@ -430,50 +471,20 @@ public: ...@@ -430,50 +471,20 @@ public:
// OK, parse it! // OK, parse it!
switch (inet_pton(af, buffer, addrTarget)) { switch (inet_pton(af, buffer, addrTarget)) {
case 1: case 1: {
// success. // success.
return result; auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0: case 0:
KJ_FAIL_REQUIRE("Invalid IP address.", addrPart); // It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
default: default:
KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart); KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
} }
} }
static SocketAddress parseLocal(StringPtr str, uint portHint) {
// If the address contains no colons, or only a leading colon, and no periods, then it is a
// port only. If is empty, then it is a total wildcard. Otherwise, it is a full address
// specified the same as any remote address.
if (str == "*" ||
(str.findLast(':').orDefault(0) <= 1 &&
str.findFirst('.') == nullptr)) {
unsigned long port;
if (str == "*") {
port = portHint;
} else {
if (str[0] == ':') {
str = str.slice(1);
}
char* endptr;
port = strtoul(str.cStr(), &endptr, 0);
if (str.size() == 0 || *endptr != '\0') {
KJ_FAIL_REQUIRE("Invalid IP port number.", str);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
}
// Prepare to bind to ALL IP interfaces. SocketAddress is zero'd by default.
SocketAddress result;
result.wildcard = true;
result.addrlen = sizeof(addr.inet6);
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
return result;
} else {
return parse(str, portHint, false);
}
}
static SocketAddress getLocalAddress(int sockfd) { static SocketAddress getLocalAddress(int sockfd) {
SocketAddress result; SocketAddress result;
result.addrlen = sizeof(addr); result.addrlen = sizeof(addr);
...@@ -495,15 +506,155 @@ private: ...@@ -495,15 +506,155 @@ private:
struct sockaddr_un unixDomain; struct sockaddr_un unixDomain;
struct sockaddr_storage storage; struct sockaddr_storage storage;
} addr; } addr;
struct LookupParams;
class LookupReader;
}; };
// ======================================================================================= class SocketAddress::LookupReader {
// Reads SocketAddresses off of a pipe coming from another thread that is performing
// getaddrinfo.
static constexpr uint NEW_FD_FLAGS = public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
~LookupReader() {
if (thread) thread->detach();
}
Promise<Array<SocketAddress>> read() {
return input->tryRead(&current, sizeof(current), sizeof(current)).then(
[this](size_t n) -> Promise<Array<SocketAddress>> {
if (n < sizeof(current)) {
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
// A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
// it may return two copies of the same address, one for each type, unless it explicitly
// knows that the service name given is specific to one type. But we can't tell it a type,
// because we don't actually know which one the user wants, and if we specify SOCK_STREAM
// while the user specified a UDP service name then they'll get a resolution error which
// is lame. (At least, I think that's how it works.)
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
addresses.add(current);
}
return read();
}
});
}
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
};
struct SocketAddress::LookupParams {
kj::String host;
kj::String service;
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
// TODO(perf): Use a thread pool? Maybe kj::Thread should use a thread pool automatically?
// Maybe use the various platform-specific asynchronous DNS libraries? Please do not implement
// a custom DNS resolver...
int fds[2];
#if __linux__ #if __linux__
LowLevelAsyncIoProvider::ALREADY_CLOEXEC || LowLevelAsyncIoProvider::ALREADY_NONBLOCK || KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
#else
KJ_SYSCALL(pipe(fds));
#endif #endif
LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
int outFd = fds[1];
LookupParams params = { kj::mv(host), kj::mv(service) };
auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
FdOutputStream output((AutoCloseFd(outFd)));
struct addrinfo* list;
int status = getaddrinfo(
params.host == "*" ? nullptr : params.host.cStr(),
params.service == nullptr ? nullptr : params.service.cStr(),
nullptr, &list);
if (status == 0) {
KJ_DEFER(freeaddrinfo(list));
struct addrinfo* cur = list;
while (cur != nullptr) {
if (params.service == nullptr) {
switch (cur->ai_addr->sa_family) {
case AF_INET:
((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
break;
case AF_INET6:
((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
break;
default:
break;
}
}
SocketAddress addr;
if (params.host == "*") {
// Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo().
addr.wildcard = true;
addr.addrlen = sizeof(addr.addr.inet6);
addr.addr.inet6.sin6_family = AF_INET6;
switch (cur->ai_addr->sa_family) {
case AF_INET:
addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
break;
case AF_INET6:
addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
break;
default:
addr.addr.inet6.sin6_port = portHint;
break;
}
} else {
addr.addrlen = cur->ai_addrlen;
memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
}
static_assert(__has_trivial_copy(SocketAddress), "Can't write() SocketAddress...");
output.write(&addr, sizeof(addr));
cur = cur->ai_next;
}
} else if (status == EAI_SYSTEM) {
KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) {
return;
}
} else {
KJ_FAIL_REQUIRE("DNS lookup failed.",
params.host, params.service, gai_strerror(status)) {
return;
}
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
auto result = reader->read();
result.attach(kj::mv(reader));
return kj::mv(result);
}
// =======================================================================================
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor { class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public: public:
...@@ -600,13 +751,20 @@ private: ...@@ -600,13 +751,20 @@ private:
// ======================================================================================= // =======================================================================================
class LocalSocketAddress final: public LocalAddress { class NetworkAddressImpl final: public NetworkAddress {
public: public:
LocalSocketAddress(LowLevelAsyncIoProvider& lowLevel, SocketAddress addr) NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addr(addr) {} : lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
return connectImpl(0);
}
Own<ConnectionReceiver> listen() override { Own<ConnectionReceiver> listen() override {
int fd = addr.socket(SOCK_STREAM); KJ_ASSERT(addrs.size() == 1,
"Sorry, unimplemented: Binding listen socket to multiple addresses.");
int fd = addrs[0].socket(SOCK_STREAM);
{ {
KJ_ON_SCOPE_FAILURE(close(fd)); KJ_ON_SCOPE_FAILURE(close(fd));
...@@ -616,7 +774,7 @@ public: ...@@ -616,7 +774,7 @@ public:
int optval = 1; int optval = 1;
KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval))); KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
addr.bind(fd); addrs[0].bind(fd);
// TODO(someday): Let queue size be specified explicitly in string addresses. // TODO(someday): Let queue size be specified explicitly in string addresses.
KJ_SYSCALL(::listen(fd, SOMAXCONN)); KJ_SYSCALL(::listen(fd, SOMAXCONN));
...@@ -626,63 +784,67 @@ public: ...@@ -626,63 +784,67 @@ public:
} }
String toString() override { String toString() override {
return addr.toString(); return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
} }
private: private:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
SocketAddress addr; Array<SocketAddress> addrs;
};
class RemoteSocketAddress final: public RemoteAddress { Promise<Own<AsyncIoStream>> connectImpl(uint index) {
public: KJ_ASSERT(index < addrs.size());
RemoteSocketAddress(LowLevelAsyncIoProvider& lowLevel, SocketAddress addr)
: lowLevel(lowLevel), addr(addr) {}
Promise<Own<AsyncIoStream>> connect() override { int fd = addrs[index].socket(SOCK_STREAM);
int fd = addr.socket(SOCK_STREAM);
{ KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
KJ_ON_SCOPE_FAILURE(close(fd)); addrs[index].connect(fd);
addr.connect(fd); })) {
// Connect failed.
close(fd);
if (index + 1 < addrs.size()) {
// Try the next address instead.
return connectImpl(index + 1);
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(*exception);
}
} }
return lowLevel.wrapConnectingSocketFd(fd, NEW_FD_FLAGS); return lowLevel.wrapConnectingSocketFd(fd, NEW_FD_FLAGS).then(
} [](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
String toString() override { return kj::mv(stream);
return addr.toString(); }, [this,index](Exception&& exception) -> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (index + 1 < addrs.size()) {
// Try the next address instead.
return connectImpl(index + 1);
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
}
});
} }
private:
LowLevelAsyncIoProvider& lowLevel;
SocketAddress addr;
}; };
class SocketNetwork final: public Network { class SocketNetwork final: public Network {
public: public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {} explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
Promise<Own<LocalAddress>> parseLocalAddress(StringPtr addr, uint portHint = 0) override { Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) -> Own<LocalAddress> {
return heap<LocalSocketAddress>(lowLevelCopy, SocketAddress::parseLocal(addr, portHint));
}));
}
Promise<Own<RemoteAddress>> parseRemoteAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel; auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr), return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) -> Own<RemoteAddress> { [&lowLevelCopy,portHint](String&& addr) {
return heap<RemoteSocketAddress>(lowLevelCopy, SocketAddress::parse(addr, portHint)); return SocketAddress::parse(lowLevelCopy, addr, portHint);
})); })).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
});
} }
Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) override { Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
return Own<LocalAddress>(heap<LocalSocketAddress>(lowLevel, SocketAddress(sockaddr, len))); auto array = kj::heapArrayBuilder<SocketAddress>(1);
} array.add(SocketAddress(sockaddr, len));
Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) override { return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
return Own<RemoteAddress>(heap<RemoteSocketAddress>(lowLevel, SocketAddress(sockaddr, len)));
} }
private: private:
......
...@@ -68,12 +68,19 @@ public: ...@@ -68,12 +68,19 @@ public:
// specify a port when constructing the LocalAddress -- one will have been assigned automatically. // specify a port when constructing the LocalAddress -- one will have been assigned automatically.
}; };
class RemoteAddress { class NetworkAddress {
// Represents a remote address to which the application can connect. // Represents a remote address to which the application can connect.
public: public:
virtual Promise<Own<AsyncIoStream>> connect() = 0; virtual Promise<Own<AsyncIoStream>> connect() = 0;
// Make a new connection to this address. // Make a new connection to this address.
//
// The address must not be a wildcard ("*"). If it is an IP address, it must have a port number.
virtual Own<ConnectionReceiver> listen() = 0;
// Listen for incoming connections on this address.
//
// The address must be local.
virtual String toString() = 0; virtual String toString() = 0;
// Produce a human-readable string which hopefully can be passed to Network::parseRemoteAddress() // Produce a human-readable string which hopefully can be passed to Network::parseRemoteAddress()
...@@ -86,9 +93,6 @@ class LocalAddress { ...@@ -86,9 +93,6 @@ class LocalAddress {
// Represents a local address on which the application can potentially accept connections. // Represents a local address on which the application can potentially accept connections.
public: public:
virtual Own<ConnectionReceiver> listen() = 0;
// Listen for incoming connections on this address.
virtual String toString() = 0; virtual String toString() = 0;
// Produce a human-readable string which hopefully can be passed to Network::parseRemoteAddress() // Produce a human-readable string which hopefully can be passed to Network::parseRemoteAddress()
// to reproduce this address, although whether or not that works of course depends on the Network // to reproduce this address, although whether or not that works of course depends on the Network
...@@ -105,30 +109,24 @@ class Network { ...@@ -105,30 +109,24 @@ class Network {
// LocalAddress and/or RemoteAddress instances directly and work from there, if at all possible. // LocalAddress and/or RemoteAddress instances directly and work from there, if at all possible.
public: public:
virtual Promise<Own<LocalAddress>> parseLocalAddress( virtual Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) = 0;
StringPtr addr, uint portHint = 0) = 0; // Construct a network address from a user-provided string. The format of the address
virtual Promise<Own<RemoteAddress>> parseRemoteAddress(
StringPtr addr, uint portHint = 0) = 0;
// Construct a local or remote address from a user-provided string. The format of the address
// strings is not specified at the API level, and application code should make no assumptions // strings is not specified at the API level, and application code should make no assumptions
// about them. These strings should always be provided by humans, and said humans will know // about them. These strings should always be provided by humans, and said humans will know
// what format to use in their particular context. // what format to use in their particular context.
// //
// `portHint`, if provided, specifies the "standard" IP port number for the application-level // `portHint`, if provided, specifies the "standard" IP port number for the application-level
// service in play. If the address turns out to be an IP address (v4 or v6), and it lacks a // service in play. If the address turns out to be an IP address (v4 or v6), and it lacks a
// port number, this port will be used. // port number, this port will be used. If `addr` lacks a port number *and* `portHint` is
// // omitted, then the returned address will only support listen() (not connect()), and a port
// In practice, a local address is usually just a port number (or even an empty string, if a // will be chosen when listen() is called.
// reasonable `portHint` is provided), whereas a remote address usually requires a hostname.
virtual Own<LocalAddress> getLocalSockaddr(const void* sockaddr, uint len) = 0; virtual Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) = 0;
virtual Own<RemoteAddress> getRemoteSockaddr(const void* sockaddr, uint len) = 0; // Construct a network address from a legacy struct sockaddr.
// Construct a local or remote address from a legacy struct sockaddr.
}; };
struct OneWayPipe { struct OneWayPipe {
// A data pipe with an input end and an output end. The two ends are safe to use in different // A data pipe with an input end and an output end. (Typically backed by pipe() system call.)
// threads. (Typically backed by pipe() system call.)
Own<AsyncInputStream> in; Own<AsyncInputStream> in;
Own<AsyncOutputStream> out; Own<AsyncOutputStream> out;
...@@ -136,8 +134,7 @@ struct OneWayPipe { ...@@ -136,8 +134,7 @@ struct OneWayPipe {
struct TwoWayPipe { struct TwoWayPipe {
// A data pipe that supports sending in both directions. Each end's output sends data to the // A data pipe that supports sending in both directions. Each end's output sends data to the
// other end's input. The ends can be used in separate threads. (Typically backed by // other end's input. (Typically backed by socketpair() system call.)
// socketpair() system call.)
Own<AsyncIoStream> ends[2]; Own<AsyncIoStream> ends[2];
}; };
...@@ -175,7 +172,8 @@ public: ...@@ -175,7 +172,8 @@ public:
// With that said, KJ currently supports the following string address formats: // With that said, KJ currently supports the following string address formats:
// - IPv4: "1.2.3.4", "1.2.3.4:80" // - IPv4: "1.2.3.4", "1.2.3.4:80"
// - IPv6: "1234:5678::abcd", "[1234:5678::abcd]:80" // - IPv6: "1234:5678::abcd", "[1234:5678::abcd]:80"
// - Local IP wildcard (local addresses only; covers both v4 and v6): "*", "*:80", ":80", "80" // - Local IP wildcard (covers both v4 and v6): "*", "*:80"
// - Symbolic names: "example.com", "example.com:80", "example.com:http", "1.2.3.4:http"
// - Unix domain: "unix:/path/to/socket" // - Unix domain: "unix:/path/to/socket"
struct PipeThread { struct PipeThread {
...@@ -284,7 +282,7 @@ struct AsyncIoContext { ...@@ -284,7 +282,7 @@ struct AsyncIoContext {
AsyncIoContext setupAsyncIo(); AsyncIoContext setupAsyncIo();
// Convenience method which sets up the current thread with everything it needs to do async I/O. // Convenience method which sets up the current thread with everything it needs to do async I/O.
// The returned object contains an `EventLoop` which is wrapping an appropriate `EventPort` for // The returned objects contain an `EventLoop` which is wrapping an appropriate `EventPort` for
// doing I/O on the host system, so everything is ready for the thread to start making async calls // doing I/O on the host system, so everything is ready for the thread to start making async calls
// and waiting on promises. // and waiting on promises.
// //
...@@ -292,10 +290,10 @@ AsyncIoContext setupAsyncIo(); ...@@ -292,10 +290,10 @@ AsyncIoContext setupAsyncIo();
// Example: // Example:
// //
// int main() { // int main() {
// auto ioSystem = kj::setupIoEventLoop(); // auto ioContext = kj::setupAsyncIo();
// //
// // Now we can call an async function. // // Now we can call an async function.
// Promise<String> textPromise = getHttp(ioSystem->getNetwork(), "http://example.com"); // Promise<String> textPromise = getHttp(*ioContext.provider, "http://example.com");
// //
// // And we can wait for the promise to complete. Note that you can only use `wait()` // // And we can wait for the promise to complete. Note that you can only use `wait()`
// // from the top level, not from inside a promise callback. // // from the top level, not from inside a promise callback.
......
...@@ -40,13 +40,15 @@ Thread::Thread(Function<void()> func): func(kj::mv(func)) { ...@@ -40,13 +40,15 @@ Thread::Thread(Function<void()> func): func(kj::mv(func)) {
} }
Thread::~Thread() noexcept(false) { Thread::~Thread() noexcept(false) {
int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr); if (!detached) {
if (pthreadResult != 0) { int pthreadResult = pthread_join(*reinterpret_cast<pthread_t*>(&threadId), nullptr);
KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; } if (pthreadResult != 0) {
} KJ_FAIL_SYSCALL("pthread_join", pthreadResult) { break; }
}
KJ_IF_MAYBE(e, exception) { KJ_IF_MAYBE(e, exception) {
kj::throwRecoverableException(kj::mv(*e)); kj::throwRecoverableException(kj::mv(*e));
}
} }
} }
...@@ -57,6 +59,14 @@ void Thread::sendSignal(int signo) { ...@@ -57,6 +59,14 @@ void Thread::sendSignal(int signo) {
} }
} }
void Thread::detach() {
int pthreadResult = pthread_detach(*reinterpret_cast<pthread_t*>(&threadId));
if (pthreadResult != 0) {
KJ_FAIL_SYSCALL("pthread_detach", pthreadResult) { break; }
}
detached = true;
}
void* Thread::runThread(void* ptr) { void* Thread::runThread(void* ptr) {
Thread* thread = reinterpret_cast<Thread*>(ptr); Thread* thread = reinterpret_cast<Thread*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
......
...@@ -43,10 +43,14 @@ public: ...@@ -43,10 +43,14 @@ public:
void sendSignal(int signo); void sendSignal(int signo);
// Send a Unix signal to the given thread, using pthread_kill or an equivalent. // Send a Unix signal to the given thread, using pthread_kill or an equivalent.
void detach();
// Don't join the thread in ~Thread().
private: private:
Function<void()> func; Function<void()> func;
unsigned long long threadId; // actually pthread_t unsigned long long threadId; // actually pthread_t
kj::Maybe<kj::Exception> exception; kj::Maybe<kj::Exception> exception;
bool detached = false;
static void* runThread(void* ptr); static void* runThread(void* ptr);
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment