Commit 5183fcaa authored by Kenton Varda's avatar Kenton Varda

Make TLS work with latest BoringSSL and fix readiness I/O stalling bug.

parent 422b6a27
...@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {} ...@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {}
void ChainPromiseNode::onReady(Event* event) noexcept { void ChainPromiseNode::onReady(Event* event) noexcept {
switch (state) { switch (state) {
case STEP1: case STEP1:
KJ_REQUIRE(onReadyEvent == nullptr, "onReady() can only be called once.");
onReadyEvent = event; onReadyEvent = event;
return; return;
case STEP2: case STEP2:
......
...@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) { ...@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) {
} else { } else {
content = kj::arrayPtr(buffer, n); content = kj::arrayPtr(buffer, n);
} }
}).attach(kj::defer([this]() {isPumping = false;})); isPumping = false;
});
}).fork(); }).fork();
} }
...@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data) ...@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data)
if (!isPumping) { if (!isPumping) {
isPumping = true; isPumping = true;
pumpTask = kj::evalNow([&]() { pumpTask = kj::evalNow([&]() {
return pump().attach(kj::defer([this]() {isPumping = false;})); return pump();
}).fork(); }).fork();
} }
...@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() { ...@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() {
if (filled > 0) { if (filled > 0) {
return pump(); return pump();
} else { } else {
isPumping = false;
return kj::READY_NOW; return kj::READY_NOW;
} }
}); });
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <kj/test.h> #include <kj/test.h>
#include <kj/async-io.h> #include <kj/async-io.h>
#include <stdlib.h> #include <stdlib.h>
#include <openssl/opensslv.h>
namespace kj { namespace kj {
namespace { namespace {
...@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") { ...@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") {
KJ_ASSERT(kj::StringPtr(buf) == "foo"); KJ_ASSERT(kj::StringPtr(buf) == "foo");
} }
KJ_TEST("TLS multiple messages") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writePromise = client->write("foo", 3)
.then([&]() { return client->write("bar", 3); });
char buf[4];
buf[3] = '\0';
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "foo");
writePromise = writePromise
.then([&]() { return client->write("baz", 3); });
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "bar");
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "baz");
auto readPromise = server->read(&buf, 3);
KJ_EXPECT(!readPromise.poll(test.io.waitScope));
writePromise = writePromise
.then([&]() { return client->write("qux", 3); });
readPromise.wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "qux");
}
kj::Promise<void> writeN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
return stream.write(text.begin(), text.size())
.then([&stream, text, count]() {
return writeN(stream, text, count);
});
}
kj::Promise<void> readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
auto buf = kj::heapString(text.size());
auto promise = stream.read(buf.begin(), buf.size());
return promise.then(kj::mvCapture(buf, [&stream, text, count](kj::String buf) {
KJ_ASSERT(buf == text, buf, text, count);
return readN(stream, text, count);
}));
}
KJ_TEST("TLS full duplex") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writeUp = writeN(*client, "foo", 10000);
auto readDown = readN(*client, "bar", 10000);
KJ_EXPECT(!writeUp.poll(test.io.waitScope));
KJ_EXPECT(!readDown.poll(test.io.waitScope));
auto writeDown = writeN(*server, "bar", 10000);
auto readUp = readN(*server, "foo", 10000);
readUp.wait(test.io.waitScope);
readDown.wait(test.io.waitScope);
writeUp.wait(test.io.waitScope);
writeDown.wait(test.io.waitScope);
}
class TestSniCallback: public TlsSniCallback { class TestSniCallback: public TlsSniCallback {
public: public:
kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override { kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override {
...@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") { ...@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") {
"self signed certificate"); "self signed certificate");
} }
// BoringSSL seems to print error messages differently.
#ifdef OPENSSL_IS_BORINGSSL
#define SSL_MESSAGE(interesting, boring) boring
#else
#define SSL_MESSAGE(interesting, boring) interesting
#endif
KJ_TEST("TLS client certificate verification") { KJ_TEST("TLS client certificate verification") {
TlsContext::Options serverOptions = TlsTest::defaultServer(); TlsContext::Options serverOptions = TlsTest::defaultServer();
TlsContext::Options clientOptions = TlsTest::defaultClient(); TlsContext::Options clientOptions = TlsTest::defaultClient();
...@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") { ...@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"); auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("peer did not return a certificate", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("peer did not return a certificate",
"PEER_DID_NOT_RETURN_A_CERTIFICATE"),
serverPromise.wait(test.io.waitScope)); serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert handshake failure", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert handshake failure",
"SSLV3_ALERT_HANDSHAKE_FAILURE"),
clientPromise.wait(test.io.waitScope)); clientPromise.wait(test.io.waitScope));
} }
...@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") { ...@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"); auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("certificate verify failed", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("certificate verify failed",
"CERTIFICATE_VERIFY_FAILED"),
serverPromise.wait(test.io.waitScope)); serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert unknown ca", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert unknown ca",
"TLSV1_ALERT_UNKNOWN_CA"),
clientPromise.wait(test.io.waitScope)); clientPromise.wait(test.io.waitScope));
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <openssl/x509v3.h> #include <openssl/x509v3.h>
#include <openssl/evp.h> #include <openssl/evp.h>
#include <openssl/conf.h> #include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h> #include <openssl/tls1.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/vector.h> #include <kj/vector.h>
...@@ -489,6 +490,13 @@ TlsContext::Options::Options() ...@@ -489,6 +490,13 @@ TlsContext::Options::Options()
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll // Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother. // never bother.
struct TlsContext::SniCallback {
// struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
// references OpenSSL types.
static int callback(SSL* ssl, int* ad, void* arg);
};
TlsContext::TlsContext(Options options) { TlsContext::TlsContext(Options options) {
ensureOpenSslInitialized(); ensureOpenSslInitialized();
...@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) { ...@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) {
// honor options.sniCallback // honor options.sniCallback
KJ_IF_MAYBE(sni, options.sniCallback) { KJ_IF_MAYBE(sni, options.sniCallback) {
SSL_CTX_set_tlsext_servername_callback(ctx, &sniCallback); SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
SSL_CTX_set_tlsext_servername_arg(ctx, sni); SSL_CTX_set_tlsext_servername_arg(ctx, sni);
} }
this->ctx = ctx; this->ctx = ctx;
} }
int TlsContext::sniCallback(void* sslp, int* ad, void* arg) { int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
// The first parameter is actually type SSL*, but we didn't want to include the OpenSSL headers
// from our header.
//
// The third parameter is actually type TlsSniCallback*. // The third parameter is actually type TlsSniCallback*.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
SSL* ssl = reinterpret_cast<SSL*>(sslp);
TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg); TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
...@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) { ...@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
} }
} }
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem) { TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
ensureOpenSslInitialized(); ensureOpenSslInitialized();
// const_cast apparently needed for older versions of OpenSSL. // const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size()); BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio)); KJ_DEFER(BIO_free(bio));
pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr); pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
if (pkey == nullptr) { if (pkey == nullptr) {
throwOpensslError(); throwOpensslError();
} }
...@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) { ...@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey)); EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
} }
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
KJ_IF_MAYBE(p, password) {
int result = kj::min(p->size(), size);
memcpy(buf, p->begin(), result);
return result;
} else {
return 0;
}
}
// ======================================================================================= // =======================================================================================
// class TlsCertificate // class TlsCertificate
......
...@@ -125,7 +125,7 @@ public: ...@@ -125,7 +125,7 @@ public:
private: private:
void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here
static int sniCallback(void* ssl, int* ad, void* arg); struct SniCallback;
}; };
class TlsPrivateKey { class TlsPrivateKey {
...@@ -137,10 +137,10 @@ public: ...@@ -137,10 +137,10 @@ public:
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to // RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt. // decrypt.
TlsPrivateKey(kj::StringPtr pem); TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password = nullptr);
// Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format" // Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format"
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to // RSA and DSA keys. A password may optionally be provided and will be used if the key is
// decrypt. // encrypted.
~TlsPrivateKey() noexcept(false); ~TlsPrivateKey() noexcept(false);
...@@ -158,6 +158,8 @@ private: ...@@ -158,6 +158,8 @@ private:
void* pkey; // actually type EVP_PKEY* void* pkey; // actually type EVP_PKEY*
friend class TlsContext; friend class TlsContext;
static int passwordCallback(char* buf, int size, int rwflag, void* u);
}; };
class TlsCertificate { class TlsCertificate {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment