Commit 5183fcaa authored by Kenton Varda's avatar Kenton Varda

Make TLS work with latest BoringSSL and fix readiness I/O stalling bug.

parent 422b6a27
......@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {}
void ChainPromiseNode::onReady(Event* event) noexcept {
switch (state) {
case STEP1:
KJ_REQUIRE(onReadyEvent == nullptr, "onReady() can only be called once.");
onReadyEvent = event;
return;
case STEP2:
......
......@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) {
} else {
content = kj::arrayPtr(buffer, n);
}
}).attach(kj::defer([this]() {isPumping = false;}));
isPumping = false;
});
}).fork();
}
......@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data)
if (!isPumping) {
isPumping = true;
pumpTask = kj::evalNow([&]() {
return pump().attach(kj::defer([this]() {isPumping = false;}));
return pump();
}).fork();
}
......@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() {
if (filled > 0) {
return pump();
} else {
isPumping = false;
return kj::READY_NOW;
}
});
......
......@@ -25,6 +25,7 @@
#include <kj/test.h>
#include <kj/async-io.h>
#include <stdlib.h>
#include <openssl/opensslv.h>
namespace kj {
namespace {
......@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") {
KJ_ASSERT(kj::StringPtr(buf) == "foo");
}
KJ_TEST("TLS multiple messages") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writePromise = client->write("foo", 3)
.then([&]() { return client->write("bar", 3); });
char buf[4];
buf[3] = '\0';
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "foo");
writePromise = writePromise
.then([&]() { return client->write("baz", 3); });
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "bar");
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "baz");
auto readPromise = server->read(&buf, 3);
KJ_EXPECT(!readPromise.poll(test.io.waitScope));
writePromise = writePromise
.then([&]() { return client->write("qux", 3); });
readPromise.wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "qux");
}
kj::Promise<void> writeN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
return stream.write(text.begin(), text.size())
.then([&stream, text, count]() {
return writeN(stream, text, count);
});
}
kj::Promise<void> readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
auto buf = kj::heapString(text.size());
auto promise = stream.read(buf.begin(), buf.size());
return promise.then(kj::mvCapture(buf, [&stream, text, count](kj::String buf) {
KJ_ASSERT(buf == text, buf, text, count);
return readN(stream, text, count);
}));
}
KJ_TEST("TLS full duplex") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writeUp = writeN(*client, "foo", 10000);
auto readDown = readN(*client, "bar", 10000);
KJ_EXPECT(!writeUp.poll(test.io.waitScope));
KJ_EXPECT(!readDown.poll(test.io.waitScope));
auto writeDown = writeN(*server, "bar", 10000);
auto readUp = readN(*server, "foo", 10000);
readUp.wait(test.io.waitScope);
readDown.wait(test.io.waitScope);
writeUp.wait(test.io.waitScope);
writeDown.wait(test.io.waitScope);
}
class TestSniCallback: public TlsSniCallback {
public:
kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override {
......@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") {
"self signed certificate");
}
// BoringSSL seems to print error messages differently.
#ifdef OPENSSL_IS_BORINGSSL
#define SSL_MESSAGE(interesting, boring) boring
#else
#define SSL_MESSAGE(interesting, boring) interesting
#endif
KJ_TEST("TLS client certificate verification") {
TlsContext::Options serverOptions = TlsTest::defaultServer();
TlsContext::Options clientOptions = TlsTest::defaultClient();
......@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("peer did not return a certificate",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("peer did not return a certificate",
"PEER_DID_NOT_RETURN_A_CERTIFICATE"),
serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert handshake failure",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert handshake failure",
"SSLV3_ALERT_HANDSHAKE_FAILURE"),
clientPromise.wait(test.io.waitScope));
}
......@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("certificate verify failed",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("certificate verify failed",
"CERTIFICATE_VERIFY_FAILED"),
serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert unknown ca",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert unknown ca",
"TLSV1_ALERT_UNKNOWN_CA"),
clientPromise.wait(test.io.waitScope));
}
......
......@@ -30,6 +30,7 @@
#include <openssl/x509v3.h>
#include <openssl/evp.h>
#include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <kj/debug.h>
#include <kj/vector.h>
......@@ -489,6 +490,13 @@ TlsContext::Options::Options()
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.
struct TlsContext::SniCallback {
// struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
// references OpenSSL types.
static int callback(SSL* ssl, int* ad, void* arg);
};
TlsContext::TlsContext(Options options) {
ensureOpenSslInitialized();
......@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) {
// honor options.sniCallback
KJ_IF_MAYBE(sni, options.sniCallback) {
SSL_CTX_set_tlsext_servername_callback(ctx, &sniCallback);
SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
SSL_CTX_set_tlsext_servername_arg(ctx, sni);
}
this->ctx = ctx;
}
int TlsContext::sniCallback(void* sslp, int* ad, void* arg) {
// The first parameter is actually type SSL*, but we didn't want to include the OpenSSL headers
// from our header.
//
int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
// The third parameter is actually type TlsSniCallback*.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
SSL* ssl = reinterpret_cast<SSL*>(sslp);
TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
......@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
}
}
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem) {
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
ensureOpenSslInitialized();
// const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio));
pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr);
pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
if (pkey == nullptr) {
throwOpensslError();
}
......@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
KJ_IF_MAYBE(p, password) {
int result = kj::min(p->size(), size);
memcpy(buf, p->begin(), result);
return result;
} else {
return 0;
}
}
// =======================================================================================
// class TlsCertificate
......
......@@ -125,7 +125,7 @@ public:
private:
void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here
static int sniCallback(void* ssl, int* ad, void* arg);
struct SniCallback;
};
class TlsPrivateKey {
......@@ -137,10 +137,10 @@ public:
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt.
TlsPrivateKey(kj::StringPtr pem);
TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password = nullptr);
// Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format"
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt.
// RSA and DSA keys. A password may optionally be provided and will be used if the key is
// encrypted.
~TlsPrivateKey() noexcept(false);
......@@ -158,6 +158,8 @@ private:
void* pkey; // actually type EVP_PKEY*
friend class TlsContext;
static int passwordCallback(char* buf, int size, int rwflag, void* u);
};
class TlsCertificate {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment