tls.c++ 27.1 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

22 23
#if KJ_HAS_OPENSSL

Kenton Varda's avatar
Kenton Varda committed
24 25 26 27 28 29 30 31 32
#include "tls.h"
#include "readiness-io.h"
#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <openssl/evp.h>
#include <openssl/conf.h>
33
#include <openssl/ssl.h>
Kenton Varda's avatar
Kenton Varda committed
34 35 36 37
#include <openssl/tls1.h>
#include <kj/debug.h>
#include <kj/vector.h>

38 39 40 41 42 43
#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define BIO_set_init(x,v)          (x->init=v)
#define BIO_get_data(x)            (x->ptr)
#define BIO_set_data(x,v)          (x->ptr=v)
#endif

Kenton Varda's avatar
Kenton Varda committed
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
namespace kj {
namespace {

// =======================================================================================
// misc helpers

KJ_NORETURN(void throwOpensslError());
void throwOpensslError() {
  // Call when an OpenSSL function returns an error code to convert that into an exception and
  // throw it.

  kj::Vector<kj::String> lines;
  while (unsigned long long error = ERR_get_error()) {
    char message[1024];
    ERR_error_string_n(error, message, sizeof(message));
    lines.add(kj::heapString(message));
  }
  kj::String message = kj::strArray(lines, "\n");
  KJ_FAIL_ASSERT("OpenSSL error", message);
}

#if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL)
// Older versions of OpenSSL don't define _up_ref() functions.

// Compatibility shim, compiled only when the surrounding version check selects it: adds one
// reference to `pkey` by incrementing its `references` counter through CRYPTO_add().
void EVP_PKEY_up_ref(EVP_PKEY* pkey) {
  CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
}

// Compatibility shim, compiled only when the surrounding version check selects it: adds one
// reference to `x509` by incrementing its `references` counter through CRYPTO_add().
void X509_up_ref(X509* x509) {
  CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
}

#endif

78
#if OPENSSL_VERSION_NUMBER < 0x10100000L
Kenton Varda's avatar
Kenton Varda committed
79 80 81 82 83 84 85 86 87 88 89 90 91 92
class OpenSslInit {
  // Initializes the OpenSSL library. All of the work happens in the constructor, so constructing
  // one instance performs the initialization; the object itself carries no state.
public:
  OpenSslInit() {
    SSL_library_init();
    SSL_load_error_strings();
    OPENSSL_config(nullptr);
  }
};

void ensureOpenSslInitialized() {
  // Initializes the OpenSSL library the first time it is called. The function-local static is
  // constructed only once, so subsequent calls do nothing.
  static OpenSslInit init;
}
93 94 95 96 97
#else
inline void ensureOpenSslInitialized() {
  // As of 1.1.0, no initialization is needed. This empty variant exists so that callers can
  // invoke ensureOpenSslInitialized() unconditionally regardless of the OpenSSL version.
}
#endif
Kenton Varda's avatar
Kenton Varda committed
98 99 100 101 102 103 104 105 106 107

// =======================================================================================
// Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream.
//
// TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is
//   completion-based. This forces us to use an intermediate buffer which wastes memory and incurs
//   redundant copies. We could improve the situation by creating a way to detect if the underlying
//   AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use
//   that directly if so.

108
class TlsConnection final: public kj::AsyncIoStream {
Kenton Varda's avatar
Kenton Varda committed
109 110 111 112 113 114 115 116 117 118 119 120 121
public:
  // Owning constructor: delegates to the by-reference constructor to set up the session on
  // `*stream`, then stores the Own<> in `ownInner` so the stream stays alive as long as this
  // connection does.
  TlsConnection(kj::Own<kj::AsyncIoStream> stream, SSL_CTX* ctx)
      : TlsConnection(*stream, ctx) {
    ownInner = kj::mv(stream);
  }

  TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx)
      : inner(stream), readBuffer(stream), writeBuffer(stream) {
    // Non-owning constructor. Creates an SSL session from `ctx` and attaches a custom BIO whose
    // callbacks (see getBioVtable()) are handed `this` as their data pointer.

    ssl = SSL_new(ctx);
    if (ssl == nullptr) throwOpensslError();

    BIO* streamBio = BIO_new(const_cast<BIO_METHOD*>(getBioVtable()));
    if (streamBio == nullptr) {
      // Release the session before throwing out of the constructor.
      SSL_free(ssl);
      throwOpensslError();
    }

    // Point the BIO back at this connection and flag it as initialized.
    BIO_set_data(streamBio, this);
    BIO_set_init(streamBio, 1);

    // The same BIO is installed for both the read side and the write side.
    SSL_set_bio(ssl, streamBio, streamBio);
  }

  kj::Promise<void> connect(kj::StringPtr expectedServerHostname) {
    // Performs the client side of the handshake. The expected hostname is sent via the SNI
    // extension and registered with the session's X509 verification parameters. Once
    // SSL_connect() finishes, the peer must have presented a certificate and the verification
    // result must be X509_V_OK, otherwise the returned promise is rejected.

    const char* hostname = expectedServerHostname.cStr();

    if (!SSL_set_tlsext_host_name(ssl, hostname)) {
      throwOpensslError();
    }

    X509_VERIFY_PARAM* verify = SSL_get0_param(ssl);
    if (verify == nullptr) {
      throwOpensslError();
    }

    if (X509_VERIFY_PARAM_set1_host(verify, hostname, expectedServerHostname.size()) <= 0) {
      throwOpensslError();
    }

    auto handshake = sslCall([this]() { return SSL_connect(ssl); });
    return handshake.then([this](size_t) {
      // SSL_get_peer_certificate() hands back a reference we must release; we only care whether
      // a certificate exists at all.
      X509* cert = SSL_get_peer_certificate(ssl);
      KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate");
      X509_free(cert);

      auto result = SSL_get_verify_result(ssl);
      if (result != X509_V_OK) {
        const char* reason = X509_verify_cert_error_string(result);
        KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason);
      }
    });
  }
  // Performs the server side of the handshake by driving SSL_accept() through sslCall(). The
  // numeric result is discarded; the promise resolves once sslCall() completes.
  kj::Promise<void> accept() {
    // We are the server. Set SSL options to prefer server's cipher choice.
    SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE);

    return sslCall([this]() { return SSL_accept(ssl); }).ignoreResult();
  }

  // Releases the SSL session created in the constructor.
  ~TlsConnection() noexcept(false) {
    SSL_free(ssl);
  }

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Entry point for reads: delegates to tryReadInternal() with a running total of zero bytes.
    const size_t alreadyDone = 0;
    return tryReadInternal(buffer, minBytes, maxBytes, alreadyDone);
  }

  Promise<void> write(const void* buffer, size_t size) override {
    // Single-buffer write: view the buffer as bytes and hand it to writeInternal() with no
    // further pieces queued behind it.
    auto bytes = kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size);
    return writeInternal(bytes, nullptr);
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Gather write: sends each piece in order by passing the first piece plus the remaining
    // pieces to writeInternal().
    //
    // An empty `pieces` array is a no-op. It must be handled here because indexing pieces[0]
    // on an empty array would read out of bounds.
    if (pieces.size() == 0) {
      // Keep the same precondition that writeInternal() enforces for non-empty writes.
      KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
      return kj::READY_NOW;
    }

    return writeInternal(pieces[0], pieces.slice(1, pieces.size()));
  }

  void shutdownWrite() override {
184 185
    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

186
    // TODO(0.8): shutdownWrite() is problematic because it doesn't return a promise. It was
Kenton Varda's avatar
Kenton Varda committed
187 188 189
    //   designed to assume that it would only be called after all writes are finished and that
    //   there was no reason to block at that point, but SSL sessions don't fit this since they
    //   actually have to send a shutdown message.
190 191 192 193 194 195 196
    shutdownTask = sslCall([this]() {
      // The first SSL_shutdown() call is expected to return 0 and may flag a misleading error.
      int result = SSL_shutdown(ssl);
      return result == 0 ? 1 : result;
    }).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) {
      KJ_LOG(ERROR, e);
    });
Kenton Varda's avatar
Kenton Varda committed
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
  }

  // Forwards directly to the wrapped stream.
  void abortRead() override {
    inner.abortRead();
  }

  // Forwards directly to the wrapped stream.
  void getsockopt(int level, int option, void* value, uint* length) override {
    inner.getsockopt(level, option, value, length);
  }
  // Forwards directly to the wrapped stream.
  void setsockopt(int level, int option, const void* value, uint length) override {
    inner.setsockopt(level, option, value, length);
  }

  // Forwards directly to the wrapped stream.
  void getsockname(struct sockaddr* addr, uint* length) override {
    inner.getsockname(addr, length);
  }
  // Forwards directly to the wrapped stream.
  void getpeername(struct sockaddr* addr, uint* length) override {
    inner.getpeername(addr, length);
  }

private:
  SSL* ssl;
  kj::AsyncIoStream& inner;
  kj::Own<kj::AsyncIoStream> ownInner;

  bool disconnected = false;
223
  kj::Maybe<kj::Promise<void>> shutdownTask;
Kenton Varda's avatar
Kenton Varda committed
224 225 226 227 228 229 230 231 232 233

  ReadyInputStreamWrapper readBuffer;
  ReadyOutputStreamWrapper writeBuffer;

  kj::Promise<size_t> tryReadInternal(
      void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) {
    // Reads decrypted bytes into `buffer` until at least `minBytes` have arrived or a read
    // yields zero bytes. `alreadyDone` counts bytes delivered by earlier iterations; the promise
    // resolves to the overall total.
    if (disconnected) return alreadyDone;

    auto readOnce = [this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); };

    return sslCall(kj::mv(readOnce))
        .then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise<size_t> {
      size_t total = alreadyDone + n;

      bool satisfied = n >= minBytes;
      bool nothingRead = n == 0;
      if (satisfied || nothingRead) {
        return total;
      }

      // Short read: continue filling the rest of the buffer.
      byte* remainder = reinterpret_cast<byte*>(buffer) + n;
      return tryReadInternal(remainder, minBytes - n, maxBytes - n, total);
    });
  }

  Promise<void> writeInternal(kj::ArrayPtr<const byte> first,
                              kj::ArrayPtr<const kj::ArrayPtr<const byte>> rest) {
    // Writes `first` followed by every piece in `rest`, in order. Partial writes are resumed
    // from where they stopped. Rejects with DISCONNECTED if the connection ends mid-write.
    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

    // Never hand a zero-length buffer to SSL_write(): a 0 result is interpreted below as the
    // connection having ended, so an empty piece would spuriously fail the whole write. Skip
    // empty pieces, and finish immediately if nothing remains.
    while (first.size() == 0) {
      if (rest.size() == 0) {
        return kj::READY_NOW;
      }
      first = rest[0];
      rest = rest.slice(1, rest.size());
    }

    return sslCall([this,first]() { return SSL_write(ssl, first.begin(), first.size()); })
        .then([this,first,rest](size_t n) -> kj::Promise<void> {
      if (n == 0) {
        return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write");
      } else if (n < first.size()) {
        // Partial write: retry with the unsent tail of this piece.
        return writeInternal(first.slice(n, first.size()), rest);
      } else if (rest.size() > 0) {
        // This piece is done; move on to the next one.
        return writeInternal(rest[0], rest.slice(1, rest.size()));
      } else {
        return kj::READY_NOW;
      }
    });
  }

  template <typename Func>
  kj::Promise<size_t> sslCall(Func&& func) {
    // Runs one OpenSSL I/O operation (`func`) and translates its outcome:
    //   - positive result: resolve with that value;
    //   - WANT_READ / WANT_WRITE: wait on the corresponding buffer, then run `func` again;
    //   - clean close (or SYSCALL with a 0 result): mark the connection disconnected and
    //     resolve with 0;
    //   - anything else: throw.
    if (disconnected) return size_t(0);

    ssize_t result = func();
    if (result > 0) return result;

    int error = SSL_get_error(ssl, result);
    switch (error) {
      case SSL_ERROR_ZERO_RETURN:
        disconnected = true;
        return size_t(0);

      case SSL_ERROR_WANT_READ:
        return readBuffer.whenReady().then(kj::mvCapture(func,
            [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));

      case SSL_ERROR_WANT_WRITE:
        return writeBuffer.whenReady().then(kj::mvCapture(func,
            [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));

      case SSL_ERROR_SSL:
        throwOpensslError();

      case SSL_ERROR_SYSCALL:
        if (result != 0) {
          // According to documentation we shouldn't get here, because our BIO never returns an
          // "error". But in practice we do get here sometimes when the peer disconnects
          // prematurely.
          KJ_FAIL_ASSERT("TLS protocol error");
        }
        disconnected = true;
        return size_t(0);

      default:
        KJ_FAIL_ASSERT("unexpected SSL error code", error);
    }
  }

  static int bioRead(BIO* b, char* out, int outl) {
    // BIO read hook. Pulls up to `outl` bytes from the owning connection's readBuffer. If the
    // buffer yields no value, flags the BIO for a read retry and returns -1.
    BIO_clear_retry_flags(b);

    auto& self = *reinterpret_cast<TlsConnection*>(BIO_get_data(b));
    KJ_IF_MAYBE(n, self.readBuffer.read(kj::arrayPtr(out, outl).asBytes())) {
      return *n;
    }

    BIO_set_retry_read(b);
    return -1;
  }

  static int bioWrite(BIO* b, const char* in, int inl) {
    // BIO write hook. Pushes up to `inl` bytes into the owning connection's writeBuffer. If the
    // buffer yields no value, flags the BIO for a write retry and returns -1.
    BIO_clear_retry_flags(b);

    auto& self = *reinterpret_cast<TlsConnection*>(BIO_get_data(b));
    KJ_IF_MAYBE(n, self.writeBuffer.write(kj::arrayPtr(in, inl).asBytes())) {
      return *n;
    }

    BIO_set_retry_write(b);
    return -1;
  }

  // BIO control hook. Reports success (1) for flush requests, returns 0 for push/pop
  // notifications, and logs a warning and returns 0 for every other command. `num` and `ptr`
  // are not used.
  static long bioCtrl(BIO* b, int cmd, long num, void* ptr) {
    switch (cmd) {
      case BIO_CTRL_FLUSH:
        return 1;
      case BIO_CTRL_PUSH:
      case BIO_CTRL_POP:
        // Informational?
        return 0;
      default:
        KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd);
        return 0;
    }
  }

  static int bioCreate(BIO* b) {
    // BIO creation hook. Starts the BIO with a null data pointer; the TlsConnection constructor
    // assigns the real one after BIO_new() returns. Always reports success.
    BIO_set_data(b, nullptr);
    return 1;
  }

  // BIO destruction hook. Nothing to release; always reports success.
  static int bioDestroy(BIO* b) {
    // The BIO does NOT own the TlsConnection.
    return 1;
  }

345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
#if OPENSSL_VERSION_NUMBER < 0x10100000L
  // Variant for OpenSSL versions below 1.1.0 (selected by the surrounding #if): returns a
  // statically-initialized BIO_METHOD table wired to this class's BIO hooks.
  static const BIO_METHOD* getBioVtable() {
    static const BIO_METHOD VTABLE {
      BIO_TYPE_SOURCE_SINK,
      "KJ stream",
      TlsConnection::bioWrite,
      TlsConnection::bioRead,
      nullptr,  // puts
      nullptr,  // gets
      TlsConnection::bioCtrl,
      TlsConnection::bioCreate,
      TlsConnection::bioDestroy,
      nullptr
    };
    return &VTABLE;
  }
#else
  static const BIO_METHOD* getBioVtable() {
    // Returns the BIO method table, building it with makeBioVtable() on first use and reusing
    // the same table for every later call.
    static const BIO_METHOD* const vtable = makeBioVtable();
    return vtable;
  }
366
  static const BIO_METHOD* makeBioVtable() {
    // Allocates a BIO method table and wires it to this class's BIO hooks. Throws if OpenSSL
    // cannot allocate the table, rather than passing a null table to the setters below.
    BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream");
    if (vtable == nullptr) {
      throwOpensslError();
    }

    BIO_meth_set_write(vtable, TlsConnection::bioWrite);
    BIO_meth_set_read(vtable, TlsConnection::bioRead);
    BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl);
    BIO_meth_set_create(vtable, TlsConnection::bioCreate);
    BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy);
    return vtable;
  }
375
#endif
Kenton Varda's avatar
Kenton Varda committed
376 377 378 379 380
};

// =======================================================================================
// Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS.

381
class TlsConnectionReceiver final: public kj::ConnectionReceiver {
Kenton Varda's avatar
Kenton Varda committed
382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408
public:
  TlsConnectionReceiver(TlsContext& tls, kj::Own<kj::ConnectionReceiver> inner)
      : tls(tls), inner(kj::mv(inner)) {}

  Promise<Own<AsyncIoStream>> accept() override {
    return inner->accept().then([this](kj::Own<AsyncIoStream> stream) {
      return tls.wrapServer(kj::mv(stream));
    });
  }

  uint getPort() override {
    return inner->getPort();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    return inner->getsockopt(level, option, value, length);
  }

  void setsockopt(int level, int option, const void* value, uint length) override {
    return inner->setsockopt(level, option, value, length);
  }

private:
  TlsContext& tls;
  kj::Own<kj::ConnectionReceiver> inner;
};

409
class TlsNetworkAddress final: public kj::NetworkAddress {
Kenton Varda's avatar
Kenton Varda committed
410 411 412 413 414
public:
  TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own<kj::NetworkAddress>&& inner)
      : tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {}

  Promise<Own<AsyncIoStream>> connect() override {
415 416 417 418 419 420 421 422 423
    // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    //   as soon as connect() returns, and this works with the native network implementation.
    //   So, we make some copies here.
    auto& tlsRef = tls;
    auto hostnameCopy = kj::str(hostname);
    return inner->connect().then(kj::mvCapture(hostnameCopy,
        [&tlsRef](kj::String&& hostname, Own<AsyncIoStream>&& stream) {
      return tlsRef.wrapClient(kj::mv(stream), hostname);
    }));
Kenton Varda's avatar
Kenton Varda committed
424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443
  }

  Own<ConnectionReceiver> listen() override {
    return tls.wrapPort(inner->listen());
  }

  Own<NetworkAddress> clone() override {
    return kj::heap<TlsNetworkAddress>(tls, kj::str(hostname), inner->clone());
  }

  String toString() override {
    return kj::str("tls:", inner->toString());
  }

private:
  TlsContext& tls;
  kj::String hostname;
  kj::Own<kj::NetworkAddress> inner;
};

444
class TlsNetwork final: public kj::Network {
Kenton Varda's avatar
Kenton Varda committed
445 446
public:
  TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
447 448
  TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
      : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}
Kenton Varda's avatar
Kenton Varda committed
449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468

  Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
    kj::String hostname;
    KJ_IF_MAYBE(pos, addr.findFirst(':')) {
      hostname = kj::heapString(addr.slice(0, *pos));
    } else {
      hostname = kj::heapString(addr);
    }

    return inner.parseAddress(addr, portHint)
        .then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own<NetworkAddress>&& addr)
            -> kj::Own<kj::NetworkAddress> {
      return kj::heap<TlsNetworkAddress>(tls, kj::mv(hostname), kj::mv(addr));
    }));
  }

  Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
  }

469 470 471 472 473 474 475 476 477
  Own<Network> restrictPeers(
      kj::ArrayPtr<const kj::StringPtr> allow,
      kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
    // TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
    //   Or is it better to let people do that via the TlsContext? A neat thing about
    //   restrictPeers() is that it's easy to make user-configurable.
    return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
  }

Kenton Varda's avatar
Kenton Varda committed
478 479 480
private:
  TlsContext& tls;
  kj::Network& inner;
481
  kj::Own<kj::Network> ownInner;
Kenton Varda's avatar
Kenton Varda committed
482 483 484 485 486 487 488 489 490
};

}  // namespace

// =======================================================================================
// class TlsContext

TlsContext::Options::Options()
    : useSystemTrustStore(true),
      verifyClients(false),
      minVersion(TlsVersion::TLS_1_0),
      // One cipher per line; adjacent literals concatenate into a single colon-separated list.
      cipherList(
          "ECDHE-ECDSA-CHACHA20-POLY1305:"
          "ECDHE-RSA-CHACHA20-POLY1305:"
          "ECDHE-ECDSA-AES128-GCM-SHA256:"
          "ECDHE-RSA-AES128-GCM-SHA256:"
          "ECDHE-ECDSA-AES256-GCM-SHA384:"
          "ECDHE-RSA-AES256-GCM-SHA384:"
          "ECDHE-ECDSA-AES128-SHA256:"
          "ECDHE-RSA-AES128-SHA256:"
          "ECDHE-ECDSA-AES128-SHA:"
          "ECDHE-RSA-AES256-SHA384:"
          "ECDHE-RSA-AES128-SHA:"
          "ECDHE-ECDSA-AES256-SHA384:"
          "ECDHE-ECDSA-AES256-SHA:"
          "ECDHE-RSA-AES256-SHA:"
          "ECDHE-ECDSA-DES-CBC3-SHA:"
          "ECDHE-RSA-DES-CBC3-SHA:"
          "AES128-GCM-SHA256:"
          "AES256-GCM-SHA384:"
          "AES128-SHA256:"
          "AES256-SHA256:"
          "AES128-SHA:"
          "AES256-SHA:"
          "DES-CBC3-SHA:"
          "!DSS") {}
// Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't
// currently support setting dhparams. See:
//     https://mozilla.github.io/server-side-tls/ssl-config-generator/
//
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.

501 502 503 504 505 506 507
struct TlsContext::SniCallback {
  // struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
  // references OpenSSL types.

  // Server-name callback registered with the SSL_CTX by the TlsContext constructor. `arg` is
  // the pointer registered alongside it; `ad` receives the alert code on failure.
  static int callback(SSL* ssl, int* ad, void* arg);
};

Kenton Varda's avatar
Kenton Varda committed
508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541
TlsContext::TlsContext(Options options) {
  // Builds an SSL_CTX configured according to `options` and stores it in `this->ctx`. If any
  // step throws, the partially-configured context is freed.
  ensureOpenSslInitialized();

#if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL)
  const SSL_METHOD* method = TLS_method();
#else
  const SSL_METHOD* method = SSLv23_method();
#endif

  SSL_CTX* ctx = SSL_CTX_new(method);
  if (ctx == nullptr) throwOpensslError();
  KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx));

  // honor options.useSystemTrustStore
  if (options.useSystemTrustStore && !SSL_CTX_set_default_verify_paths(ctx)) {
    throwOpensslError();
  }

  // honor options.trustedCertificates
  if (options.trustedCertificates.size() > 0) {
    X509_STORE* store = SSL_CTX_get_cert_store(ctx);
    if (store == nullptr) throwOpensslError();

    for (auto& cert: options.trustedCertificates) {
      // Only the first certificate of each entry is added to the store.
      X509* trusted = reinterpret_cast<X509*>(cert.chain[0]);
      if (!X509_STORE_add_cert(store, trusted)) throwOpensslError();
    }
  }

  // honor options.verifyClients
  if (options.verifyClients) {
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
  }

  // honor options.minVersion: disable every protocol version below the requested minimum.
  long optionFlags = 0;
  auto minVersion = options.minVersion;
  if (minVersion > TlsVersion::SSL_3)   optionFlags |= SSL_OP_NO_SSLv3;
  if (minVersion > TlsVersion::TLS_1_0) optionFlags |= SSL_OP_NO_TLSv1;
  if (minVersion > TlsVersion::TLS_1_1) optionFlags |= SSL_OP_NO_TLSv1_1;
  if (minVersion > TlsVersion::TLS_1_2) optionFlags |= SSL_OP_NO_TLSv1_2;
  SSL_CTX_set_options(ctx, optionFlags);  // note: never fails; returns new options bitmask

  // honor options.cipherList
  if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) {
    throwOpensslError();
  }

  // honor options.defaultKeypair
  KJ_IF_MAYBE(kp, options.defaultKeypair) {
    EVP_PKEY* privateKey = reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey);
    if (!SSL_CTX_use_PrivateKey(ctx, privateKey)) throwOpensslError();

    X509* leaf = reinterpret_cast<X509*>(kp->certificate.chain[0]);
    if (!SSL_CTX_use_certificate(ctx, leaf)) throwOpensslError();

    // Remaining entries are the intermediate certificates; a null entry ends the chain.
    for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
      X509* intermediate = reinterpret_cast<X509*>(kp->certificate.chain[i]);
      if (intermediate == nullptr) break;  // end of chain

      if (!SSL_CTX_add_extra_chain_cert(ctx, intermediate)) throwOpensslError();

      // SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself.
      X509_up_ref(intermediate);
    }
  }

  // honor options.sniCallback
  KJ_IF_MAYBE(sni, options.sniCallback) {
    SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
    SSL_CTX_set_tlsext_servername_arg(ctx, sni);
  }

  this->ctx = ctx;
}

599
int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
  // The third parameter is actually type TlsSniCallback*.
  //
  // Looks up the hostname requested by the client and, if the TlsSniCallback supplies a keypair
  // for it, installs that keypair (private key, certificate, and chain) on this session. When
  // no hostname was sent or no keypair is supplied, the session is left untouched. Exceptions
  // are caught here, logged, and reported to OpenSSL as a fatal internal-error alert.

  auto selectKeypair = [&]() {
    TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);

    const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
    if (name == nullptr) return;

    KJ_IF_MAYBE(kp, sni.getKey(name)) {
      EVP_PKEY* privateKey = reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey);
      if (!SSL_use_PrivateKey(ssl, privateKey)) throwOpensslError();

      X509* leaf = reinterpret_cast<X509*>(kp->certificate.chain[0]);
      if (!SSL_use_certificate(ssl, leaf)) throwOpensslError();

      if (!SSL_clear_chain_certs(ssl)) throwOpensslError();

      for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
        X509* intermediate = reinterpret_cast<X509*>(kp->certificate.chain[i]);
        if (intermediate == nullptr) break;  // end of chain

        if (!SSL_add0_chain_cert(ssl, intermediate)) throwOpensslError();

        // SSL_add0_chain_cert() does NOT up the refcount itself.
        X509_up_ref(intermediate);
      }
    }
  };

  KJ_IF_MAYBE(exception, kj::runCatchingExceptions(selectKeypair)) {
    KJ_LOG(ERROR, "exception when invoking SNI callback", *exception);
    *ad = SSL_AD_INTERNAL_ERROR;
    return SSL_TLSEXT_ERR_ALERT_FATAL;
  }

  return SSL_TLSEXT_ERR_OK;
}

// Frees the SSL_CTX stored in `ctx` by the constructor.
TlsContext::~TlsContext() noexcept(false) {
  SSL_CTX_free(reinterpret_cast<SSL_CTX*>(ctx));
}

// Wraps `stream` in a TlsConnection and performs the client handshake against
// `expectedServerHostname`. The returned promise resolves to the connection once connect()
// completes.
kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapClient(
    kj::Own<kj::AsyncIoStream> stream, kj::StringPtr expectedServerHostname) {
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  // Start the handshake before `conn` is moved into the continuation below; the continuation
  // owns the connection until the handshake finishes, then hands it to the caller.
  auto promise = conn->connect(expectedServerHostname);
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

// Wraps `stream` in a TlsConnection and performs the server handshake. The returned promise
// resolves to the connection once accept() completes.
kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapServer(kj::Own<kj::AsyncIoStream> stream) {
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  // Start the handshake before `conn` is moved into the continuation below; the continuation
  // owns the connection until the handshake finishes, then hands it to the caller.
  auto promise = conn->accept();
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

// Returns a TlsConnectionReceiver that takes ownership of `port` and keeps a reference to this
// context.
kj::Own<kj::ConnectionReceiver> TlsContext::wrapPort(kj::Own<kj::ConnectionReceiver> port) {
  return kj::heap<TlsConnectionReceiver>(*this, kj::mv(port));
}

// Returns a TlsNetwork that holds references (not ownership) to both `network` and this
// context.
kj::Own<kj::Network> TlsContext::wrapNetwork(kj::Network& network) {
  return kj::heap<TlsNetwork>(*this, network);
}

// =======================================================================================
// class TlsPrivateKey

// Parses a private key from the ASN.1 bytes in `asn1` using d2i_AutoPrivateKey(). Throws if
// parsing fails.
TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
  ensureOpenSslInitialized();

  // d2i_AutoPrivateKey() takes a pointer-to-pointer, so it needs a local copy of the start
  // pointer.
  const byte* ptr = asn1.begin();
  pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size());
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

686
// Parses a private key from PEM text. If the key is encrypted, `password` is supplied to
// OpenSSL through passwordCallback(). Throws if the memory BIO cannot be created or the key
// cannot be parsed.
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
  ensureOpenSslInitialized();

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  if (bio == nullptr) {
    // Report the allocation failure directly instead of passing a null BIO to the PEM reader.
    throwOpensslError();
  }
  KJ_DEFER(BIO_free(bio));

  pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

// Copy constructor: shares the same underlying key object and, if there is one, takes an
// additional reference on it.
TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other)
    : pkey(other.pkey) {
  if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
}

// Copy assignment: releases the currently-held key and shares `other`'s key, taking an
// additional reference on it. Nothing happens when both already hold the same key, which also
// covers self-assignment.
TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) {
  if (pkey != other.pkey) {
    EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
    pkey = other.pkey;
    if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
  }
  return *this;
}

// Releases this object's reference to the key.
TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
  EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}

717 718 719 720 721 722 723 724 725 726 727 728
// Password callback passed to PEM_read_bio_PrivateKey(). `u` points at the
// kj::Maybe<kj::StringPtr> supplied by the constructor. Copies at most `size` bytes of the
// password into `buf` and returns the number of bytes copied, or 0 when no password was given.
// `rwflag` is not used.
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
  auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);

  KJ_IF_MAYBE(p, password) {
    // Truncate to the space available in OpenSSL's buffer.
    int result = kj::min(p->size(), size);
    memcpy(buf, p->begin(), result);
    return result;
  }

  return 0;
}

Kenton Varda's avatar
Kenton Varda committed
729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
// =======================================================================================
// class TlsCertificate

// Parses a certificate chain from ASN.1 bytes, one array element per certificate, with the
// first element stored in chain[0]. Unused trailing slots of `chain` stay null. If any
// certificate fails to parse, the ones parsed so far are freed and an exception is thrown.
TlsCertificate::TlsCertificate(kj::ArrayPtr<const kj::ArrayPtr<const byte>> asn1) {
  ensureOpenSslInitialized();

  KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain");
  KJ_REQUIRE(asn1.size() <= kj::size(chain),
      "exceeded maximum certificate chain length of 10");

  memset(chain, 0, sizeof(chain));

  for (auto i: kj::indices(asn1)) {
    const byte* cursor = asn1[i].begin();

    // "_AUX" apparently refers to some auxiliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    X509* parsed;
    if (i == 0) {
      parsed = d2i_X509_AUX(nullptr, &cursor, asn1[i].size());
    } else {
      parsed = d2i_X509(nullptr, &cursor, asn1[i].size());
    }
    chain[i] = parsed;

    if (parsed == nullptr) {
      // Release everything parsed before the failing entry.
      for (size_t j = 0; j < i; j++) {
        X509_free(reinterpret_cast<X509*>(chain[j]));
      }
      throwOpensslError();
    }
  }
}

// Single-certificate convenience constructor: delegates to the chain constructor with a
// one-element array.
TlsCertificate::TlsCertificate(kj::ArrayPtr<const byte> asn1)
    : TlsCertificate(kj::arrayPtr(&asn1, 1)) {}

// Parses a certificate chain from concatenated PEM text. Certificates are read in order into
// `chain`; unused trailing slots stay null. Throws if the first certificate cannot be read, if
// a later read fails for any reason other than running out of PEM blocks, or if the text holds
// more certificates than `chain` can store.
TlsCertificate::TlsCertificate(kj::StringPtr pem) {
  ensureOpenSslInitialized();

  memset(chain, 0, sizeof(chain));

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  KJ_DEFER(BIO_free(bio));

  for (auto i: kj::indices(chain)) {
    // "_AUX" apparently refers to some auxiliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr)
                      : PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);

    if (chain[i] == nullptr) {
      auto error = ERR_peek_last_error();
      if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM &&
          ERR_GET_REASON(error) == PEM_R_NO_START_LINE) {
        // EOF; we're done.
        ERR_clear_error();
        return;
      } else {
        for (size_t j = 0; j < i; j++) {
          X509_free(reinterpret_cast<X509*>(chain[j]));
        }
        throwOpensslError();
      }
    }
  }

  // We reached the chain length limit. Try to read one more to verify that the chain ends here.
  X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  if (dummy != nullptr) {
    X509_free(dummy);
    for (auto i: kj::indices(chain)) {
      X509_free(reinterpret_cast<X509*>(chain[i]));
    }
    KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10");
  } else {
    // The failed probe read pushed an error onto OpenSSL's error queue. Discard it, as the
    // EOF path above does, so that it cannot be misattributed to a later OpenSSL operation.
    ERR_clear_error();
  }
}

// Copy constructor: shares the same certificate objects as `other`, taking an additional
// reference on each non-null entry.
TlsCertificate::TlsCertificate(const TlsCertificate& other) {
  memcpy(chain, other.chain, sizeof(chain));

  for (auto i: kj::indices(chain)) {
    if (chain[i] == nullptr) break;  // end of chain; quit early
    X509_up_ref(reinterpret_cast<X509*>(chain[i]));
  }
}

// Copy assignment: makes this chain share the same certificate objects as `other`. Slots that
// already hold the same certificate are left alone; differing slots release the old
// certificate and take a reference on the new one.
TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) {
  for (auto i: kj::indices(chain)) {
    if (chain[i] != other.chain[i]) {
      // The chain holds X509 objects, so they must be released with X509_free() (which accepts
      // null), not with the EVP_PKEY deallocator.
      X509_free(reinterpret_cast<X509*>(chain[i]));
      chain[i] = other.chain[i];
      if (chain[i] != nullptr) X509_up_ref(reinterpret_cast<X509*>(chain[i]));
    } else if (chain[i] == nullptr) {
      // end of both chains; quit early
      break;
    }
  }
  return *this;
}

// Releases this object's reference to every certificate in the chain.
TlsCertificate::~TlsCertificate() noexcept(false) {
  for (auto i: kj::indices(chain)) {
    if (chain[i] == nullptr) break;  // end of chain; quit early
    X509_free(reinterpret_cast<X509*>(chain[i]));
  }
}

}  // namespace kj

836
#endif  // KJ_HAS_OPENSSL