// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "async-io.h"
#include "async-unix.h"
#include "debug.h"
25
#include "thread.h"
Kenton Varda's avatar
Kenton Varda committed
26
#include "io.h"
27 28 29 30 31 32 33
#include <unistd.h>
#include <sys/uio.h>
#include <errno.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
34
#include <netinet/in.h>
35
#include <netinet/tcp.h>
36 37 38
#include <stddef.h>
#include <stdlib.h>
#include <arpa/inet.h>
Kenton Varda's avatar
Kenton Varda committed
39 40
#include <netdb.h>
#include <set>
41 42 43 44 45 46 47 48 49 50

#ifndef POLLRDHUP
// Linux-only optimization.  If not available, define to 0, as this will make it a no-op.
#define POLLRDHUP 0
#endif

namespace kj {

namespace {

51 52 53 54 55 56 57 58
void setNonblocking(int fd) {
  // Make sure O_NONBLOCK is on for `fd`.  The file status flags are only rewritten when the bit
  // is actually missing, so an already-non-blocking FD costs a single fcntl(F_GETFL).
  int flags;
  KJ_SYSCALL(flags = fcntl(fd, F_GETFL));
  if (flags & O_NONBLOCK) {
    return;
  }
  KJ_SYSCALL(fcntl(fd, F_SETFL, flags | O_NONBLOCK));
}

void setCloseOnExec(int fd) {
  // Make sure FD_CLOEXEC is on for `fd`.  Like setNonblocking(), this skips the F_SETFD call
  // when the descriptor flag is already present.
  int flags;
  KJ_SYSCALL(flags = fcntl(fd, F_GETFD));
  if (flags & FD_CLOEXEC) {
    return;
  }
  KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC));
}

// Flags passed along with every FD this file creates itself.  We always take ownership.  On Linux
// we additionally promise that CLOEXEC and NONBLOCK are already set, because there the FDs are
// created with pipe2(), accept4(), or SOCK_NONBLOCK | SOCK_CLOEXEC; other platforms lack those
// atomic variants, so the flags get applied afterwards by OwnedFileDescriptor.
static constexpr uint NEW_FD_FLAGS =
    LowLevelAsyncIoProvider::TAKE_OWNERSHIP
#if __linux__
    | LowLevelAsyncIoProvider::ALREADY_CLOEXEC | LowLevelAsyncIoProvider::ALREADY_NONBLOCK
#endif
    ;

class OwnedFileDescriptor {
public:
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
  OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) {
    if (flags & LowLevelAsyncIoProvider::ALREADY_NONBLOCK) {
      KJ_DREQUIRE(fcntl(fd, F_GETFL) & O_NONBLOCK, "You claimed you set NONBLOCK, but you didn't.");
    } else {
      setNonblocking(fd);
    }

    if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
      if (flags & LowLevelAsyncIoProvider::ALREADY_CLOEXEC) {
        KJ_DREQUIRE(fcntl(fd, F_GETFD) & FD_CLOEXEC,
                    "You claimed you set CLOEXEC, but you didn't.");
      } else {
        setCloseOnExec(fd);
      }
    }
92 93 94 95
  }

  ~OwnedFileDescriptor() noexcept(false) {
    // Don't use SYSCALL() here because close() should not be repeated on EINTR.
96
    if ((flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) && close(fd) < 0) {
97 98 99 100 101 102 103 104 105
      KJ_FAIL_SYSCALL("close", errno, fd) {
        // Recoverable exceptions are safe in destructors.
        break;
      }
    }
  }

protected:
  const int fd;
106 107 108

private:
  uint flags;
109 110 111 112
};

// =======================================================================================

113
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncIoStream {
114
public:
115 116
  AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
      : OwnedFileDescriptor(fd, flags), eventPort(eventPort) {}
117
  virtual ~AsyncStreamFd() noexcept(false) {}
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135

  Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0).then([=](size_t result) {
      KJ_REQUIRE(result >= minBytes, "Premature EOF") {
        // Pretend we read zeros from the input.
        memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
        return minBytes;
      }
      return result;
    });
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0);
  }

  Promise<void> write(const void* buffer, size_t size) override {
    ssize_t writeResult;
136
    KJ_NONBLOCKING_SYSCALL(writeResult = ::write(fd, buffer, size)) {
137 138 139 140 141 142 143 144 145 146 147 148 149
      return READY_NOW;
    }

    // A negative result means EAGAIN, which we can treat the same as having written zero bytes.
    size_t n = writeResult < 0 ? 0 : writeResult;

    if (n == size) {
      return READY_NOW;
    } else {
      buffer = reinterpret_cast<const byte*>(buffer) + n;
      size -= n;
    }

150
    return eventPort.onFdEvent(fd, POLLOUT).then([=](short) {
151 152 153 154 155 156 157 158 159 160 161 162
      return write(buffer, size);
    });
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    if (pieces.size() == 0) {
      return writeInternal(nullptr, nullptr);
    } else {
      return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
    }
  }

163 164 165
  void shutdownWrite() override {
    // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
    // UnixAsyncIoProvider interface.
166
    KJ_SYSCALL(shutdown(fd, SHUT_WR));
167 168
  }

169
private:
170
  UnixEventPort& eventPort;
171 172 173 174 175 176 177 178 179
  bool gotHup = false;

  Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
                                  size_t alreadyRead) {
    // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes,
    // maxBytes, and buffer have already been adjusted to account for them, but this count must
    // be included in the final return value.

    ssize_t n;
180
    KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) {
181 182 183 184 185
      return alreadyRead;
    }

    if (n < 0) {
      // Read would block.
186
      return eventPort.onFdEvent(fd, POLLIN | POLLRDHUP).then([=](short events) {
187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
        gotHup = events & (POLLHUP | POLLRDHUP);
        return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
      });
    } else if (n == 0) {
      // EOF -OR- maxBytes == 0.
      return alreadyRead;
    } else if (implicitCast<size_t>(n) < minBytes) {
      // The kernel returned fewer bytes than we asked for (and fewer than we need).
      if (gotHup) {
        // We've already received an indication that the next read() will return EOF, so there's
        // nothing to wait for.
        return alreadyRead + n;
      } else {
        // We know that calling read() again will simply fail with EAGAIN (unless a new packet just
        // arrived, which is unlikely), so let's not bother to call read() again but instead just
        // go strait to polling.
        //
        // Note:  Actually, if we haven't done any polls yet, then we haven't had a chance to
        //   receive POLLRDHUP yet, so it's possible we're at EOF.  But that seems like a
        //   sufficiently unusual case that we're better off skipping straight to polling here.
        buffer = reinterpret_cast<byte*>(buffer) + n;
        minBytes -= n;
        maxBytes -= n;
        alreadyRead += n;
211
        return eventPort.onFdEvent(fd, POLLIN | POLLRDHUP).then([=](short events) {
212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
          gotHup = events & (POLLHUP | POLLRDHUP);
          return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
        });
      }
    } else {
      // We read enough to stop here.
      return alreadyRead + n;
    }
  }

  Promise<void> writeInternal(ArrayPtr<const byte> firstPiece,
                              ArrayPtr<const ArrayPtr<const byte>> morePieces) {
    KJ_STACK_ARRAY(struct iovec, iov, 1 + morePieces.size(), 16, 128);

    // writev() interface is not const-correct.  :(
    iov[0].iov_base = const_cast<byte*>(firstPiece.begin());
    iov[0].iov_len = firstPiece.size();
    for (uint i = 0; i < morePieces.size(); i++) {
      iov[i + 1].iov_base = const_cast<byte*>(morePieces[i].begin());
      iov[i + 1].iov_len = morePieces[i].size();
    }

    ssize_t writeResult;
235
    KJ_NONBLOCKING_SYSCALL(writeResult = ::writev(fd, iov.begin(), iov.size())) {
Kenton Varda's avatar
Kenton Varda committed
236 237 238 239 240 241 242 243 244 245 246
      // Error.

      // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
      // a bug that exists in both Clang and GCC:
      //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
      //   http://llvm.org/bugs/show_bug.cgi?id=12286
      goto error;
    }
    if (false) {
    error:
      return kj::READY_NOW;
247 248 249 250 251 252 253 254 255 256
    }

    // A negative result means EAGAIN, which we can treat the same as having written zero bytes.
    size_t n = writeResult < 0 ? 0 : writeResult;

    // Discard all data that was written, then issue a new write for what's left (if any).
    for (;;) {
      if (n < firstPiece.size()) {
        // Only part of the first piece was consumed.  Wait for POLLOUT and then write again.
        firstPiece = firstPiece.slice(n, firstPiece.size());
257
        return eventPort.onFdEvent(fd, POLLOUT).then([=](short) {
258 259 260 261
          return writeInternal(firstPiece, morePieces);
        });
      } else if (morePieces.size() == 0) {
        // First piece was fully-consumed and there are no more pieces, so we're done.
262
        KJ_DASSERT(n == firstPiece.size(), n);
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
        return READY_NOW;
      } else {
        // First piece was fully consumed, so move on to the next piece.
        n -= firstPiece.size();
        firstPiece = morePieces[0];
        morePieces = morePieces.slice(1, morePieces.size());
      }
    }
  }
};

// =======================================================================================

class SocketAddress {
  // A single socket address (IPv4, IPv6, or unix domain) stored by value in a sockaddr union, so
  // it can be copied freely and even written raw through a pipe (see lookupHost()).  `wildcard`
  // marks the "*" address, which is represented as an IPv6 address and bound with IPV6_V6ONLY
  // disabled.
public:
  SocketAddress(const void* sockaddr, uint len): addrlen(len) {
    KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
    memcpy(&addr.generic, sockaddr, len);
  }

  bool operator<(const SocketAddress& other) const {
    // So we can use std::set<SocketAddress>...  see DNS lookup code.

    if (wildcard < other.wildcard) return true;
    if (wildcard > other.wildcard) return false;

    if (addrlen < other.addrlen) return true;
    if (addrlen > other.addrlen) return false;

    return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
  }

  int socket(int type) const {
    bool isStream = type == SOCK_STREAM;

    int result;
#if __linux__
    type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif
    KJ_SYSCALL(result = ::socket(addr.generic.sa_family, type, 0));

    if (isStream && (addr.generic.sa_family == AF_INET ||
                     addr.generic.sa_family == AF_INET6)) {
      // TODO(0.5):  As a hack for the 0.4 release we are always setting
      //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
      //   RPC protocol.  Later, we should extend the interface to provide more
      //   control over this.  Perhaps write() should have a flag which
      //   specifies whether to pass MSG_MORE.
      int one = 1;
      KJ_SYSCALL(setsockopt(
          result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
    }

    return result;
  }

  void bind(int sockfd) const {
    if (wildcard) {
      // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket.  (The
      // default value of this option varies across platforms.)
      int value = 0;
      KJ_SYSCALL(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value)));
    }

    KJ_SYSCALL(::bind(sockfd, &addr.generic, addrlen), toString());
  }

  void connect(int sockfd) const {
    // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates
    // non-blocking using EINPROGRESS.
    for (;;) {
      if (::connect(sockfd, &addr.generic, addrlen) < 0) {
        int error = errno;
        if (error == EINPROGRESS) {
          return;
        } else if (error != EINTR) {
          KJ_FAIL_SYSCALL("connect()", error, toString()) {
            // Recover by returning, since reads/writes will simply fail.
            return;
          }
        }
      } else {
        // no error
        return;
      }
    }
  }

  uint getPort() const {
    switch (addr.generic.sa_family) {
      case AF_INET: return ntohs(addr.inet4.sin_port);
      case AF_INET6: return ntohs(addr.inet6.sin6_port);
      default: return 0;
    }
  }

  String toString() const {
    if (wildcard) {
      return str("*:", getPort());
    }

    switch (addr.generic.sa_family) {
      case AF_INET: {
        char buffer[INET6_ADDRSTRLEN];
        if (inet_ntop(addr.inet4.sin_family, &addr.inet4.sin_addr,
                      buffer, sizeof(buffer)) == nullptr) {
          KJ_FAIL_SYSCALL("inet_ntop", errno) { return heapString("(inet_ntop error)"); }
        }
        return str(buffer, ':', ntohs(addr.inet4.sin_port));
      }
      case AF_INET6: {
        char buffer[INET6_ADDRSTRLEN];
        if (inet_ntop(addr.inet6.sin6_family, &addr.inet6.sin6_addr,
                      buffer, sizeof(buffer)) == nullptr) {
          KJ_FAIL_SYSCALL("inet_ntop", errno) { return heapString("(inet_ntop error)"); }
        }
        return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
      }
      case AF_UNIX: {
        return str("unix:", addr.unixDomain.sun_path);
      }
      default:
        return str("(unknown address family ", addr.generic.sa_family, ")");
    }
  }

  static Promise<Array<SocketAddress>> lookupHost(
      LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
  // Perform a DNS lookup.

  static Promise<Array<SocketAddress>> parse(
      LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
    // Parse "unix:<path>", "*[:port]", "<ipv4>[:port]", "[<ipv6>][:port]", "<ipv6>", or a
    // hostname / service name.  Anything that isn't a numeric address falls back to DNS via
    // lookupHost().  `portHint` is used when `str` carries no port.
    //
    // TODO(someday):  Allow commas in `str`.

    SocketAddress result;

    if (str.startsWith("unix:")) {
      StringPtr path = str.slice(strlen("unix:"));
      KJ_REQUIRE(path.size() < sizeof(addr.unixDomain.sun_path),
                 "Unix domain socket address is too long.", str);
      result.addr.unixDomain.sun_family = AF_UNIX;
      strcpy(result.addr.unixDomain.sun_path, path.cStr());
      result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
      auto array = kj::heapArrayBuilder<SocketAddress>(1);
      array.add(result);
      return array.finish();
    }

    // Try to separate the address and port.
    ArrayPtr<const char> addrPart;
    Maybe<StringPtr> portPart;

    int af;

    if (str.startsWith("[")) {
      // Address starts with a bracket, which is a common way to write an ip6 address with a port,
      // since without brackets around the address part, the port looks like another segment of
      // the address.
      af = AF_INET6;
      size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
          "Unclosed '[' in address string.", str);

      addrPart = str.slice(1, closeBracket);
      if (str.size() > closeBracket + 1) {
        KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
                   "Expected port suffix after ']'.", str);
        portPart = str.slice(closeBracket + 2);
      }
    } else {
      KJ_IF_MAYBE(colon, str.findFirst(':')) {
        if (str.slice(*colon + 1).findFirst(':') == nullptr) {
          // There is exactly one colon and no brackets, so it must be an ip4 address with port.
          af = AF_INET;
          addrPart = str.slice(0, *colon);
          portPart = str.slice(*colon + 1);
        } else {
          // There are two or more colons and no brackets, so the whole thing must be an ip6
          // address with no port.
          af = AF_INET6;
          addrPart = str;
        }
      } else {
        // No colons, so it must be an ip4 address without port.
        af = AF_INET;
        addrPart = str;
      }
    }

    // Parse the port.
    unsigned long port;
    KJ_IF_MAYBE(portText, portPart) {
      char* endptr;
      port = strtoul(portText->cStr(), &endptr, 0);
      if (portText->size() == 0 || *endptr != '\0') {
        // Not a number.  Maybe it's a service name.  Fall back to DNS.
        return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
      }
      KJ_REQUIRE(port < 65536, "Port number too large.");
    } else {
      port = portHint;
    }

    // Check for wildcard.
    if (addrPart.size() == 1 && addrPart[0] == '*') {
      result.wildcard = true;
      result.addrlen = sizeof(addr.inet6);
      result.addr.inet6.sin6_family = AF_INET6;
      result.addr.inet6.sin6_port = htons(port);
      auto array = kj::heapArrayBuilder<SocketAddress>(1);
      array.add(result);
      return array.finish();
    }

    void* addrTarget;
    if (af == AF_INET6) {
      result.addrlen = sizeof(addr.inet6);
      result.addr.inet6.sin6_family = AF_INET6;
      result.addr.inet6.sin6_port = htons(port);
      addrTarget = &result.addr.inet6.sin6_addr;
    } else {
      result.addrlen = sizeof(addr.inet4);
      result.addr.inet4.sin_family = AF_INET;
      result.addr.inet4.sin_port = htons(port);
      addrTarget = &result.addr.inet4.sin_addr;
    }

    // addrPart is not necessarily NUL-terminated so we have to make a copy.  :(
    //
    // Text that doesn't fit in the buffer (plus NUL) cannot be a numeric address, but it can
    // still be a perfectly valid long hostname, so in that case we skip inet_pton() and go
    // straight to DNS rather than failing.  (The bound is the full buffer size: the longest
    // IPv6 text form, e.g. "ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255", is
    // INET6_ADDRSTRLEN - 1 characters.)
    char buffer[INET6_ADDRSTRLEN];
    if (addrPart.size() < sizeof(buffer)) {
      memcpy(buffer, addrPart.begin(), addrPart.size());
      buffer[addrPart.size()] = '\0';

      // OK, parse it!
      switch (inet_pton(af, buffer, addrTarget)) {
        case 1: {
          // success.
          auto array = kj::heapArrayBuilder<SocketAddress>(1);
          array.add(result);
          return array.finish();
        }
        case 0:
          // It's apparently not a simple address...  fall back to DNS below.
          break;
        default:
          KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
      }
    }

    return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
  }

  static SocketAddress getLocalAddress(int sockfd) {
    SocketAddress result;
    result.addrlen = sizeof(addr);
    KJ_SYSCALL(getsockname(sockfd, &result.addr.generic, &result.addrlen));
    return result;
  }

private:
  SocketAddress(): addrlen(0) {
    memset(&addr, 0, sizeof(addr));
  }

  socklen_t addrlen;
  bool wildcard = false;
  union {
    struct sockaddr generic;
    struct sockaddr_in inet4;
    struct sockaddr_in6 inet6;
    struct sockaddr_un unixDomain;
    struct sockaddr_storage storage;
  } addr;

  struct LookupParams;
  class LookupReader;
};

class SocketAddress::LookupReader {
  // Consumer side of lookupHost(): reads raw SocketAddress structs, one at a time, from the pipe
  // that the getaddrinfo() worker thread writes, until the pipe hits EOF.

public:
  LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
      : thread(kj::mv(thread)), input(kj::mv(input)) {}

  ~LookupReader() {
    // If we are torn down while the worker may still be running, detach it rather than keeping
    // a handle to it.
    if (thread) thread->detach();
  }

  Promise<Array<SocketAddress>> read() {
    return input->tryRead(&current, sizeof(current), sizeof(current)).then(
        [this](size_t n) -> Promise<Array<SocketAddress>> {
      if (n >= sizeof(current)) {
        // A full record arrived.
        //
        // getaddrinfo() can return multiple copies of the same address for several reasons.
        // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
        // it may return two copies of the same address, one for each type, unless it explicitly
        // knows that the service name given is specific to one type.  But we can't tell it a type,
        // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
        // while the user specified a UDP service name then they'll get a resolution error which
        // is lame.  (At least, I think that's how it works.)
        //
        // So we instead resort to de-duping results.
        if (alreadySeen.insert(current).second) {
          addresses.add(current);
        }
        return read();
      }

      // Short read: the worker closed its end of the pipe, so the lookup is complete.
      thread = nullptr;

      // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
      // anyway.
      KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
      return addresses.releaseAsArray();
    });
  }

private:
  kj::Own<Thread> thread;
  kj::Own<AsyncInputStream> input;
  SocketAddress current;                 // Receive buffer for the record currently being read.
  kj::Vector<SocketAddress> addresses;   // Unique results, in arrival order.
  std::set<SocketAddress> alreadySeen;   // Used only for de-duplication.
};

struct SocketAddress::LookupParams {
  // Arguments moved into the getaddrinfo() worker thread by lookupHost().
  kj::String host;     // "*" requests a wildcard address.
  kj::String service;  // Null when the caller supplied no port / service name.
};

Promise<Array<SocketAddress>> SocketAddress::lookupHost(
    LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
  // This function spawns a thread to run getaddrinfo().  Unfortunately, getaddrinfo() is the only
  // cross-platform DNS API and it is blocking.
  //
  // The worker writes each result as a raw SocketAddress into a pipe; LookupReader reads them back
  // on this thread.  When `service` is null, `portHint` is stored (in network byte order) as the
  // port of every IPv4/IPv6 result.
  //
  // TODO(perf):  Use a thread pool?  Maybe kj::Thread should use a thread pool automatically?
  //   Maybe use the various platform-specific asynchronous DNS libraries?  Please do not implement
  //   a custom DNS resolver...

  int fds[2];
#if __linux__
  KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
#else
  KJ_SYSCALL(pipe(fds));
#endif

  auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);

  int outFd = fds[1];

  LookupParams params = { kj::mv(host), kj::mv(service) };

  auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
    FdOutputStream output((AutoCloseFd(outFd)));

    struct addrinfo* list;
    int status = getaddrinfo(
        params.host == "*" ? nullptr : params.host.cStr(),
        params.service == nullptr ? nullptr : params.service.cStr(),
        nullptr, &list);
    if (status == 0) {
      KJ_DEFER(freeaddrinfo(list));

      struct addrinfo* cur = list;
      while (cur != nullptr) {
        if (params.service == nullptr) {
          switch (cur->ai_addr->sa_family) {
            case AF_INET:
              ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
              break;
            case AF_INET6:
              ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
              break;
            default:
              break;
          }
        }

        SocketAddress addr;
        memset(&addr, 0, sizeof(addr));  // mollify valgrind
        if (params.host == "*") {
          // Set up a wildcard SocketAddress.  Only use the port number returned by getaddrinfo().
          addr.wildcard = true;
          addr.addrlen = sizeof(addr.addr.inet6);
          addr.addr.inet6.sin6_family = AF_INET6;
          switch (cur->ai_addr->sa_family) {
            case AF_INET:
              addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
              break;
            case AF_INET6:
              addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
              break;
            default:
              // sin6_port is in network byte order, like every other port assignment here.
              addr.addr.inet6.sin6_port = htons(portHint);
              break;
          }
        } else {
          addr.addrlen = cur->ai_addrlen;
          memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
        }
        static_assert(canMemcpy<SocketAddress>(), "Can't write() SocketAddress...");
        output.write(&addr, sizeof(addr));
        cur = cur->ai_next;
      }
    } else if (status == EAI_SYSTEM) {
      KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) {
        return;
      }
    } else {
      KJ_FAIL_REQUIRE("DNS lookup failed.",
                      params.host, params.service, gai_strerror(status)) {
        return;
      }
    }
  }));

  auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
  return reader->read().attach(kj::mv(reader));
}

// =======================================================================================

class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public:
683 684
  FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags)
      : OwnedFileDescriptor(fd, flags), eventPort(eventPort) {}
685 686 687 688

  Promise<Own<AsyncIoStream>> accept() override {
    int newFd;

689
  retry:
690
#if __linux__
691
    newFd = ::accept4(fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC);
692
#else
693
    newFd = ::accept(fd, nullptr, nullptr);
694 695
#endif

696
    if (newFd >= 0) {
697
      return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717
    } else {
      int error = errno;

      switch (error) {
        case EAGAIN:
#if EAGAIN != EWOULDBLOCK
        case EWOULDBLOCK:
#endif
          // Not ready yet.
          return eventPort.onFdEvent(fd, POLLIN).then([this](short) {
            return accept();
          });

        case EINTR:
        case ENETDOWN:
        case EPROTO:
        case EHOSTDOWN:
        case EHOSTUNREACH:
        case ENETUNREACH:
        case ECONNABORTED:
Kenton Varda's avatar
Kenton Varda committed
718 719 720 721 722
        case ETIMEDOUT:
          // According to the Linux man page, accept() may report an error if the accepted
          // connection is already broken.  In this case, we really ought to just ignore it and
          // keep waiting.  But it's hard to say exactly what errors are such network errors and
          // which ones are permanent errors.  We've made a guess here.
723 724 725 726 727 728
          goto retry;

        default:
          KJ_FAIL_SYSCALL("accept", error);
      }

729 730 731 732 733 734
    }
  }

  uint getPort() override {
    return SocketAddress::getLocalAddress(fd).getPort();
  }
735 736 737

public:
  UnixEventPort& eventPort;
738 739
};

740 741 742 743
class TimerImpl final: public Timer {
public:
  TimerImpl(UnixEventPort& eventPort): eventPort(eventPort) {}

744
  TimePoint now() override { return eventPort.steadyTime(); }
745

746
  Promise<void> atTime(TimePoint time) override {
747 748 749
    return eventPort.atSteadyTime(time);
  }

750
  Promise<void> afterDelay(Duration delay) override {
751 752 753 754 755 756 757
    return eventPort.atSteadyTime(eventPort.steadyTime() + delay);
  }

private:
  UnixEventPort& eventPort;
};

758 759
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
public:
760 761
  LowLevelAsyncIoProviderImpl()
      : eventLoop(eventPort), timer(eventPort), waitScope(eventLoop) {}
762 763

  inline WaitScope& getWaitScope() { return waitScope; }
764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790

  Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(int fd, uint flags = 0) override {
    auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
    return eventPort.onFdEvent(fd, POLLOUT).then(kj::mvCapture(result,
        [fd](Own<AsyncIoStream>&& stream, short events) {
          int err;
          socklen_t errlen = sizeof(err);
          KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen));
          if (err != 0) {
            KJ_FAIL_SYSCALL("connect()", err) { break; }
          }
          return kj::mv(stream);
        }));
  }
  Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override {
    return heap<FdConnectionReceiver>(eventPort, fd, flags);
  }

791 792
  Timer& getTimer() override { return timer; }

793 794
  UnixEventPort& getEventPort() { return eventPort; }

795 796 797
private:
  UnixEventPort eventPort;
  EventLoop eventLoop;
798
  TimerImpl timer;
799
  WaitScope waitScope;
800 801
};

802 803
// =======================================================================================

Kenton Varda's avatar
Kenton Varda committed
804
class NetworkAddressImpl final: public NetworkAddress {
805
public:
Kenton Varda's avatar
Kenton Varda committed
806 807 808 809 810 811
  NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
      : lowLevel(lowLevel), addrs(kj::mv(addrs)) {}

  Promise<Own<AsyncIoStream>> connect() override {
    return connectImpl(0);
  }
812 813

  Own<ConnectionReceiver> listen() override {
Kenton Varda's avatar
Kenton Varda committed
814 815 816 817 818
    if (addrs.size() > 1) {
      KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
          "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
          "in the future.", addrs[0].toString());
    }
Kenton Varda's avatar
Kenton Varda committed
819 820

    int fd = addrs[0].socket(SOCK_STREAM);
821

822 823 824 825 826 827 828
    {
      KJ_ON_SCOPE_FAILURE(close(fd));

      // We always enable SO_REUSEADDR because having to take your server down for five minutes
      // before it can restart really sucks.
      int optval = 1;
      KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
829

Kenton Varda's avatar
Kenton Varda committed
830
      addrs[0].bind(fd);
831

832 833 834
      // TODO(someday):  Let queue size be specified explicitly in string addresses.
      KJ_SYSCALL(::listen(fd, SOMAXCONN));
    }
835

836
    return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
837 838 839
  }

  String toString() override {
Kenton Varda's avatar
Kenton Varda committed
840
    return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
841 842 843
  }

private:
844
  LowLevelAsyncIoProvider& lowLevel;
Kenton Varda's avatar
Kenton Varda committed
845
  Array<SocketAddress> addrs;
846

Kenton Varda's avatar
Kenton Varda committed
847 848
  Promise<Own<AsyncIoStream>> connectImpl(uint index) {
    KJ_ASSERT(index < addrs.size());
849

Kenton Varda's avatar
Kenton Varda committed
850
    int fd = addrs[index].socket(SOCK_STREAM);
851

Kenton Varda's avatar
Kenton Varda committed
852 853 854 855 856 857 858 859 860 861 862 863
    KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
      addrs[index].connect(fd);
    })) {
      // Connect failed.
      close(fd);
      if (index + 1 < addrs.size()) {
        // Try the next address instead.
        return connectImpl(index + 1);
      } else {
        // No more addresses to try, so propagate the exception.
        return kj::mv(*exception);
      }
864 865
    }

Kenton Varda's avatar
Kenton Varda committed
866 867 868 869 870 871 872 873 874 875 876 877 878 879
    return lowLevel.wrapConnectingSocketFd(fd, NEW_FD_FLAGS).then(
        [](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
      // Success, pass along.
      return kj::mv(stream);
    }, [this,index](Exception&& exception) -> Promise<Own<AsyncIoStream>> {
      // Connect failed.
      if (index + 1 < addrs.size()) {
        // Try the next address instead.
        return connectImpl(index + 1);
      } else {
        // No more addresses to try, so propagate the exception.
        return kj::mv(exception);
      }
    });
880 881 882 883 884
  }
};

class SocketNetwork final: public Network {
public:
885
  explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
886

Kenton Varda's avatar
Kenton Varda committed
887
  Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
888
    auto& lowLevelCopy = lowLevel;
889
    return evalLater(mvCapture(heapString(addr),
Kenton Varda's avatar
Kenton Varda committed
890 891 892 893 894
        [&lowLevelCopy,portHint](String&& addr) {
      return SocketAddress::parse(lowLevelCopy, addr, portHint);
    })).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
      return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
    });
895 896
  }

Kenton Varda's avatar
Kenton Varda committed
897 898 899 900
  Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    auto array = kj::heapArrayBuilder<SocketAddress>(1);
    array.add(SocketAddress(sockaddr, len));
    return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
901 902
  }

903
private:
904
  LowLevelAsyncIoProvider& lowLevel;
905
};
906

907
// =======================================================================================
908

909
class AsyncIoProviderImpl final: public AsyncIoProvider {
910
public:
911 912
  AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
      : lowLevel(lowLevel), network(lowLevel) {}
Kenton Varda's avatar
Kenton Varda committed
913

914 915
  OneWayPipe newOneWayPipe() override {
    int fds[2];
916
#if __linux__
917
    KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
918
#else
919
    KJ_SYSCALL(pipe(fds));
920
#endif
921 922 923 924
    return OneWayPipe {
      lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS),
      lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS)
    };
925
  }
926

927 928 929
  TwoWayPipe newTwoWayPipe() override {
    int fds[2];
    int type = SOCK_STREAM;
930
#if __linux__
931
    type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
932
#endif
933
    KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
934 935 936 937
    return TwoWayPipe { {
      lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
      lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
    } };
938 939 940 941 942 943
  }

  Network& getNetwork() override {
    return network;
  }

944
  PipeThread newPipeThread(
945
      Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
946 947 948 949 950 951 952 953
    int fds[2];
    int type = SOCK_STREAM;
#if __linux__
    type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
#endif
    KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));

    int threadFd = fds[1];
954
    KJ_ON_SCOPE_FAILURE(close(threadFd));
955

956 957 958
    auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);

    auto thread = heap<Thread>(kj::mvCapture(startFunc,
959
        [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
960 961 962
      LowLevelAsyncIoProviderImpl lowLevel;
      auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
      AsyncIoProviderImpl ioProvider(lowLevel);
963
      startFunc(ioProvider, *stream, lowLevel.getWaitScope());
964
    }));
965

966
    return { kj::mv(thread), kj::mv(pipe) };
967
  }
968

969 970
  Timer& getTimer() override { return lowLevel.getTimer(); }

971
private:
972
  LowLevelAsyncIoProvider& lowLevel;
973 974 975 976 977 978 979 980
  SocketNetwork network;
};

}  // namespace

Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) {
  // Convenience overload: request exactly `bytes` (passed as both the minimum and the maximum)
  // and discard the byte count reported by the three-argument read().
  auto discardCount = [](size_t) {};
  return read(buffer, bytes, bytes).then(kj::mv(discardCount));
}
981

982 983 984 985 986 987 988
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
  // Layers the standard high-level provider over a caller-supplied low-level provider.  The
  // result holds `lowLevel` by reference, so `lowLevel` must outlive it.
  Own<AsyncIoProvider> provider = heap<AsyncIoProviderImpl>(lowLevel);
  return provider;
}

AsyncIoContext setupAsyncIo() {
  // The low-level provider owns the event port, event loop and wait scope; the high-level
  // provider is layered on top and only holds a reference to it.
  Own<LowLevelAsyncIoProviderImpl> lowLevel = heap<LowLevelAsyncIoProviderImpl>();
  Own<AsyncIoProviderImpl> ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);

  // Take references before the Owns are moved into the result; the heap objects themselves
  // do not move, so these references stay valid.
  WaitScope& waitScope = lowLevel->getWaitScope();
  UnixEventPort& eventPort = lowLevel->getEventPort();

  return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
}

994
}  // namespace kj