// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#if KJ_HAS_OPENSSL

#include "tls.h"
#include "readiness-io.h"
#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <openssl/evp.h>
#include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <kj/debug.h>
#include <kj/vector.h>

#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define BIO_set_init(x,v)          (x->init=v)
#define BIO_get_data(x)            (x->ptr)
#define BIO_set_data(x,v)          (x->ptr=v)
#endif

namespace kj {
namespace {

// =======================================================================================
// misc helpers

KJ_NORETURN(void throwOpensslError());
void throwOpensslError() {
  // Call when an OpenSSL function returns an error code to convert that into an exception and
  // throw it.

  kj::Vector<kj::String> lines;
  while (unsigned long long error = ERR_get_error()) {
    char message[1024];
    ERR_error_string_n(error, message, sizeof(message));
    lines.add(kj::heapString(message));
  }
  kj::String message = kj::strArray(lines, "\n");
  KJ_FAIL_ASSERT("OpenSSL error", message);
}

#if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL)
// Older versions of OpenSSL don't define _up_ref() functions.

void EVP_PKEY_up_ref(EVP_PKEY* pkey) {
  CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
}

void X509_up_ref(X509* x509) {
  CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
}

#endif

#if OPENSSL_VERSION_NUMBER < 0x10100000L
class OpenSslInit {
  // Initializes the OpenSSL library.
public:
  OpenSslInit() {
    SSL_library_init();
    SSL_load_error_strings();
    OPENSSL_config(nullptr);
  }
};

void ensureOpenSslInitialized() {
  // Initializes the OpenSSL library the first time it is called.
  static OpenSslInit init;
}
#else
inline void ensureOpenSslInitialized() {
  // As of 1.1.0, no initialization is needed.
}
#endif

// =======================================================================================
// Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream.
//
// TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is
//   completion-based. This forces us to use an intermediate buffer which wastes memory and incurs
//   redundant copies. We could improve the situation by creating a way to detect if the underlying
//   AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use
//   that directly if so.

class TlsConnection final: public kj::AsyncIoStream {
public:
  TlsConnection(kj::Own<kj::AsyncIoStream> stream, SSL_CTX* ctx)
      : TlsConnection(*stream, ctx) {
    ownInner = kj::mv(stream);
  }

  TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx)
      : inner(stream), readBuffer(stream), writeBuffer(stream) {
    ssl = SSL_new(ctx);
    if (ssl == nullptr) {
      throwOpensslError();
    }

    BIO* bio = BIO_new(const_cast<BIO_METHOD*>(getBioVtable()));
    if (bio == nullptr) {
      SSL_free(ssl);
      throwOpensslError();
    }

    BIO_set_data(bio, this);
    BIO_set_init(bio, 1);
    SSL_set_bio(ssl, bio, bio);
  }

  kj::Promise<void> connect(kj::StringPtr expectedServerHostname) {
    if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) {
      throwOpensslError();
    }

    X509_VERIFY_PARAM* verify = SSL_get0_param(ssl);
    if (verify == nullptr) {
      throwOpensslError();
    }

    if (X509_VERIFY_PARAM_set1_host(
        verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) {
      throwOpensslError();
    }

    return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) {
      X509* cert = SSL_get_peer_certificate(ssl);
      KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate");
      X509_free(cert);

      auto result = SSL_get_verify_result(ssl);
      if (result != X509_V_OK) {
        const char* reason = X509_verify_cert_error_string(result);
        KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason);
      }
    });
  }
  kj::Promise<void> accept() {
    // We are the server. Set SSL options to prefer server's cipher choice.
    SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE);

    return sslCall([this]() { return SSL_accept(ssl); }).ignoreResult();
  }

  ~TlsConnection() noexcept(false) {
    SSL_free(ssl);
  }

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0);
  }

  Promise<void> write(const void* buffer, size_t size) override {
    return writeInternal(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size), nullptr);
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
  }

  void shutdownWrite() override {
    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

    // TODO(soon): shutdownWrite() is problematic because it doesn't return a promise. It was
    //   designed to assume that it would only be called after all writes are finished and that
    //   there was no reason to block at that point, but SSL sessions don't fit this since they
    //   actually have to send a shutdown message.
    shutdownTask = sslCall([this]() {
      // The first SSL_shutdown() call is expected to return 0 and may flag a misleading error.
      int result = SSL_shutdown(ssl);
      return result == 0 ? 1 : result;
    }).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) {
      KJ_LOG(ERROR, e);
    });
  }

  void abortRead() override {
    inner.abortRead();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    inner.getsockopt(level, option, value, length);
  }
  void setsockopt(int level, int option, const void* value, uint length) override {
    inner.setsockopt(level, option, value, length);
  }

  void getsockname(struct sockaddr* addr, uint* length) override {
    inner.getsockname(addr, length);
  }
  void getpeername(struct sockaddr* addr, uint* length) override {
    inner.getpeername(addr, length);
  }

private:
  SSL* ssl;
  kj::AsyncIoStream& inner;
  kj::Own<kj::AsyncIoStream> ownInner;

  bool disconnected = false;
  kj::Maybe<kj::Promise<void>> shutdownTask;

  ReadyInputStreamWrapper readBuffer;
  ReadyOutputStreamWrapper writeBuffer;

  kj::Promise<size_t> tryReadInternal(
      void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) {
    if (disconnected) return alreadyDone;

    return sslCall([this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); })
        .then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise<size_t> {
      if (n >= minBytes || n == 0) {
        return alreadyDone + n;
      } else {
        return tryReadInternal(reinterpret_cast<byte*>(buffer) + n,
                               minBytes - n, maxBytes - n, alreadyDone + n);
      }
    });
  }

  Promise<void> writeInternal(kj::ArrayPtr<const byte> first,
                              kj::ArrayPtr<const kj::ArrayPtr<const byte>> rest) {
    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

    return sslCall([this,first]() { return SSL_write(ssl, first.begin(), first.size()); })
        .then([this,first,rest](size_t n) -> kj::Promise<void> {
      if (n == 0) {
        return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write");
      } else if (n < first.size()) {
        return writeInternal(first.slice(n, first.size()), rest);
      } else if (rest.size() > 0) {
        return writeInternal(rest[0], rest.slice(1, rest.size()));
      } else {
        return kj::READY_NOW;
      }
    });
  }

  template <typename Func>
  kj::Promise<size_t> sslCall(Func&& func) {
    if (disconnected) return size_t(0);

    ssize_t result = func();

    if (result > 0) {
      return result;
    } else {
      int error = SSL_get_error(ssl, result);
      switch (error) {
        case SSL_ERROR_ZERO_RETURN:
          disconnected = true;
          return size_t(0);
        case SSL_ERROR_WANT_READ:
          return readBuffer.whenReady().then(kj::mvCapture(func,
              [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
        case SSL_ERROR_WANT_WRITE:
          return writeBuffer.whenReady().then(kj::mvCapture(func,
              [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
        case SSL_ERROR_SSL:
          throwOpensslError();
        case SSL_ERROR_SYSCALL:
          if (result == 0) {
            disconnected = true;
            return size_t(0);
          } else {
            // According to documentation we shouldn't get here, because our BIO never returns an
            // "error". But in practice we do get here sometimes when the peer disconnects
            // prematurely.
            KJ_FAIL_ASSERT("TLS protocol error");
          }
        default:
          KJ_FAIL_ASSERT("unexpected SSL error code", error);
      }
    }
  }

  static int bioRead(BIO* b, char* out, int outl) {
    BIO_clear_retry_flags(b);
    KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->readBuffer
        .read(kj::arrayPtr(out, outl).asBytes())) {
      return *n;
    } else {
      BIO_set_retry_read(b);
      return -1;
    }
  }

  static int bioWrite(BIO* b, const char* in, int inl) {
    BIO_clear_retry_flags(b);
    KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->writeBuffer
        .write(kj::arrayPtr(in, inl).asBytes())) {
      return *n;
    } else {
      BIO_set_retry_write(b);
      return -1;
    }
  }

  static long bioCtrl(BIO* b, int cmd, long num, void* ptr) {
    switch (cmd) {
      case BIO_CTRL_FLUSH:
        return 1;
      case BIO_CTRL_PUSH:
      case BIO_CTRL_POP:
        // Informational?
        return 0;
      default:
        KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd);
        return 0;
    }
  }

  static int bioCreate(BIO* b) {
    BIO_set_data(b, nullptr);
    return 1;
  }

  static int bioDestroy(BIO* b) {
    // The BIO does NOT own the TlsConnection.
    return 1;
  }

#if OPENSSL_VERSION_NUMBER < 0x10100000L
  static const BIO_METHOD* getBioVtable() {
    static const BIO_METHOD VTABLE {
      BIO_TYPE_SOURCE_SINK,
      "KJ stream",
      TlsConnection::bioWrite,
      TlsConnection::bioRead,
      nullptr,  // puts
      nullptr,  // gets
      TlsConnection::bioCtrl,
      TlsConnection::bioCreate,
      TlsConnection::bioDestroy,
      nullptr
    };
    return &VTABLE;
  }
#else
  static const BIO_METHOD* getBioVtable() {
    static const BIO_METHOD* const vtable = makeBioVtable();
    return vtable;
  }
  static const BIO_METHOD* makeBioVtable() {
    BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream");
    BIO_meth_set_write(vtable, TlsConnection::bioWrite);
    BIO_meth_set_read(vtable, TlsConnection::bioRead);
    BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl);
    BIO_meth_set_create(vtable, TlsConnection::bioCreate);
    BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy);
    return vtable;
  }
#endif
};

// =======================================================================================
// Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS.

class TlsConnectionReceiver final: public kj::ConnectionReceiver {
public:
  TlsConnectionReceiver(TlsContext& tls, kj::Own<kj::ConnectionReceiver> inner)
      : tls(tls), inner(kj::mv(inner)) {}

  Promise<Own<AsyncIoStream>> accept() override {
    return inner->accept().then([this](kj::Own<AsyncIoStream> stream) {
      return tls.wrapServer(kj::mv(stream));
    });
  }

  uint getPort() override {
    return inner->getPort();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    return inner->getsockopt(level, option, value, length);
  }

  void setsockopt(int level, int option, const void* value, uint length) override {
    return inner->setsockopt(level, option, value, length);
  }

private:
  TlsContext& tls;
  kj::Own<kj::ConnectionReceiver> inner;
};

class TlsNetworkAddress final: public kj::NetworkAddress {
public:
  TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own<kj::NetworkAddress>&& inner)
      : tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {}

  Promise<Own<AsyncIoStream>> connect() override {
    // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    //   as soon as connect() returns, and this works with the native network implementation.
    //   So, we make some copies here.
    auto& tlsRef = tls;
    auto hostnameCopy = kj::str(hostname);
    return inner->connect().then(kj::mvCapture(hostnameCopy,
        [&tlsRef](kj::String&& hostname, Own<AsyncIoStream>&& stream) {
      return tlsRef.wrapClient(kj::mv(stream), hostname);
    }));
  }

  Own<ConnectionReceiver> listen() override {
    return tls.wrapPort(inner->listen());
  }

  Own<NetworkAddress> clone() override {
    return kj::heap<TlsNetworkAddress>(tls, kj::str(hostname), inner->clone());
  }

  String toString() override {
    return kj::str("tls:", inner->toString());
  }

private:
  TlsContext& tls;
  kj::String hostname;
  kj::Own<kj::NetworkAddress> inner;
};

class TlsNetwork final: public kj::Network {
public:
  TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
  TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
      : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}

  Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
    kj::String hostname;
    KJ_IF_MAYBE(pos, addr.findFirst(':')) {
      hostname = kj::heapString(addr.slice(0, *pos));
    } else {
      hostname = kj::heapString(addr);
    }

    return inner.parseAddress(addr, portHint)
        .then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own<NetworkAddress>&& addr)
            -> kj::Own<kj::NetworkAddress> {
      return kj::heap<TlsNetworkAddress>(tls, kj::mv(hostname), kj::mv(addr));
    }));
  }

  Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
  }

  Own<Network> restrictPeers(
      kj::ArrayPtr<const kj::StringPtr> allow,
      kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
    // TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
    //   Or is it better to let people do that via the TlsContext? A neat thing about
    //   restrictPeers() is that it's easy to make user-configurable.
    return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
  }

private:
  TlsContext& tls;
  kj::Network& inner;
  kj::Own<kj::Network> ownInner;
};

}  // namespace

// =======================================================================================
// class TlsContext

TlsContext::Options::Options()
    : useSystemTrustStore(true),
      verifyClients(false),
      minVersion(TlsVersion::TLS_1_0),
      cipherList("ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:ECDHE-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA:!DSS") {}
// Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't
// currently support setting dhparams. See:
//     https://mozilla.github.io/server-side-tls/ssl-config-generator/
//
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.

struct TlsContext::SniCallback {
  // struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
  // references OpenSSL types.

  static int callback(SSL* ssl, int* ad, void* arg);
};

TlsContext::TlsContext(Options options) {
  ensureOpenSslInitialized();

#if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL)
  SSL_CTX* ctx = SSL_CTX_new(TLS_method());
#else
  SSL_CTX* ctx = SSL_CTX_new(SSLv23_method());
#endif

  if (ctx == nullptr) {
    throwOpensslError();
  }
  KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx));

  // honor options.useSystemTrustStore
  if (options.useSystemTrustStore) {
    if (!SSL_CTX_set_default_verify_paths(ctx)) {
      throwOpensslError();
    }
  }

  // honor options.trustedCertificates
  if (options.trustedCertificates.size() > 0) {
    X509_STORE* store = SSL_CTX_get_cert_store(ctx);
    if (store == nullptr) {
      throwOpensslError();
    }
    for (auto& cert: options.trustedCertificates) {
      if (!X509_STORE_add_cert(store, reinterpret_cast<X509*>(cert.chain[0]))) {
        throwOpensslError();
      }
    }
  }

  if (options.verifyClients) {
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
  }

  // honor options.minVersion
  long optionFlags = 0;
  if (options.minVersion > TlsVersion::SSL_3) {
    optionFlags |= SSL_OP_NO_SSLv3;
  }
  if (options.minVersion > TlsVersion::TLS_1_0) {
    optionFlags |= SSL_OP_NO_TLSv1;
  }
  if (options.minVersion > TlsVersion::TLS_1_1) {
    optionFlags |= SSL_OP_NO_TLSv1_1;
  }
  if (options.minVersion > TlsVersion::TLS_1_2) {
    optionFlags |= SSL_OP_NO_TLSv1_2;
  }
  SSL_CTX_set_options(ctx, optionFlags);  // note: never fails; returns new options bitmask

  // honor options.cipherList
  if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) {
    throwOpensslError();
  }

  // honor options.defaultKeypair
  KJ_IF_MAYBE(kp, options.defaultKeypair) {
    if (!SSL_CTX_use_PrivateKey(ctx, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
      throwOpensslError();
    }

    if (!SSL_CTX_use_certificate(ctx, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
      throwOpensslError();
    }

    for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
      X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
      if (x509 == nullptr) break;  // end of chain

      if (!SSL_CTX_add_extra_chain_cert(ctx, x509)) {
        throwOpensslError();
      }

      // SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself.
      X509_up_ref(x509);
    }
  }

  // honor options.sniCallback
  KJ_IF_MAYBE(sni, options.sniCallback) {
    SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
    SSL_CTX_set_tlsext_servername_arg(ctx, sni);
  }

  this->ctx = ctx;
}

int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
  // The third parameter is actually type TlsSniCallback*.

  KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
    TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);

    const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
    if (name != nullptr) {
      KJ_IF_MAYBE(kp, sni.getKey(name)) {
        if (!SSL_use_PrivateKey(ssl, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
          throwOpensslError();
        }

        if (!SSL_use_certificate(ssl, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
          throwOpensslError();
        }

        if (!SSL_clear_chain_certs(ssl)) {
          throwOpensslError();
        }

        for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
          X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
          if (x509 == nullptr) break;  // end of chain

          if (!SSL_add0_chain_cert(ssl, x509)) {
            throwOpensslError();
          }

          // SSL_add0_chain_cert() does NOT up the refcount itself.
          X509_up_ref(x509);
        }
      }
    }
  })) {
    KJ_LOG(ERROR, "exception when invoking SNI callback", *exception);
    *ad = SSL_AD_INTERNAL_ERROR;
    return SSL_TLSEXT_ERR_ALERT_FATAL;
  }

  return SSL_TLSEXT_ERR_OK;
}

TlsContext::~TlsContext() noexcept(false) {
  SSL_CTX_free(reinterpret_cast<SSL_CTX*>(ctx));
}

kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapClient(
    kj::Own<kj::AsyncIoStream> stream, kj::StringPtr expectedServerHostname) {
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->connect(expectedServerHostname);
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapServer(kj::Own<kj::AsyncIoStream> stream) {
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->accept();
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

kj::Own<kj::ConnectionReceiver> TlsContext::wrapPort(kj::Own<kj::ConnectionReceiver> port) {
  return kj::heap<TlsConnectionReceiver>(*this, kj::mv(port));
}

kj::Own<kj::Network> TlsContext::wrapNetwork(kj::Network& network) {
  return kj::heap<TlsNetwork>(*this, network);
}

// =======================================================================================
// class TlsPrivateKey

TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
  ensureOpenSslInitialized();

  const byte* ptr = asn1.begin();
  pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size());
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
  ensureOpenSslInitialized();

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  KJ_DEFER(BIO_free(bio));

  pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other)
    : pkey(other.pkey) {
  if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
}

TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) {
  if (pkey != other.pkey) {
    EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
    pkey = other.pkey;
    if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
  }
  return *this;
}

TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
  EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}

int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
  auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);

  KJ_IF_MAYBE(p, password) {
    int result = kj::min(p->size(), size);
    memcpy(buf, p->begin(), result);
    return result;
  } else {
    return 0;
  }
}

// =======================================================================================
// class TlsCertificate

TlsCertificate::TlsCertificate(kj::ArrayPtr<const kj::ArrayPtr<const byte>> asn1) {
  ensureOpenSslInitialized();

  KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain");
  KJ_REQUIRE(asn1.size() <= kj::size(chain),
      "exceeded maximum certificate chain length of 10");

  memset(chain, 0, sizeof(chain));

  for (auto i: kj::indices(asn1)) {
    auto p = asn1[i].begin();

    // "_AUX" apparently refers to some auxilliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size())
                      : d2i_X509(nullptr, &p, asn1[i].size());

    if (chain[i] == nullptr) {
      for (size_t j = 0; j < i; j++) {
        X509_free(reinterpret_cast<X509*>(chain[j]));
      }
      throwOpensslError();
    }
  }
}

TlsCertificate::TlsCertificate(kj::ArrayPtr<const byte> asn1)
    : TlsCertificate(kj::arrayPtr(&asn1, 1)) {}

TlsCertificate::TlsCertificate(kj::StringPtr pem) {
  ensureOpenSslInitialized();

  memset(chain, 0, sizeof(chain));

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  KJ_DEFER(BIO_free(bio));

  for (auto i: kj::indices(chain)) {
    // "_AUX" apparently refers to some auxilliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr)
                      : PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);

    if (chain[i] == nullptr) {
      auto error = ERR_peek_last_error();
      if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM &&
          ERR_GET_REASON(error) == PEM_R_NO_START_LINE) {
        // EOF; we're done.
        ERR_clear_error();
        return;
      } else {
        for (size_t j = 0; j < i; j++) {
          X509_free(reinterpret_cast<X509*>(chain[j]));
        }
        throwOpensslError();
      }
    }
  }

  // We reached the chain length limit. Try to read one more to verify that the chain ends here.
  X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  if (dummy != nullptr) {
    X509_free(dummy);
    for (auto i: kj::indices(chain)) {
      X509_free(reinterpret_cast<X509*>(chain[i]));
    }
    KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10");
  }
}

TlsCertificate::TlsCertificate(const TlsCertificate& other) {
  memcpy(chain, other.chain, sizeof(chain));
  for (void* p: chain) {
    if (p == nullptr) break;  // end of chain; quit early
    X509_up_ref(reinterpret_cast<X509*>(p));
  }
}

TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) {
  for (auto i: kj::indices(chain)) {
    if (chain[i] != other.chain[i]) {
      EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(chain[i]));
      chain[i] = other.chain[i];
      if (chain[i] != nullptr) X509_up_ref(reinterpret_cast<X509*>(chain[i]));
    } else if (chain[i] == nullptr) {
      // end of both chains; quit early
      break;
    }
  }
  return *this;
}

TlsCertificate::~TlsCertificate() noexcept(false) {
  for (void* p: chain) {
    if (p == nullptr) break;  // end of chain; quit early
    X509_free(reinterpret_cast<X509*>(p));
  }
}

}  // namespace kj

#endif  // KJ_HAS_OPENSSL