Unverified Commit 08dd2e79 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #612 from capnproto/membrane-fixes

Various tweaks needed by Sandstorm HTTP changes
parents 02be7235 743e7cec
...@@ -321,27 +321,41 @@ public: ...@@ -321,27 +321,41 @@ public:
static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) { static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) {
if (cap.getBrand() == MEMBRANE_BRAND) { if (cap.getBrand() == MEMBRANE_BRAND) {
auto& otherMembrane = kj::downcast<MembraneHook>(cap); auto& otherMembrane = kj::downcast<MembraneHook>(cap);
if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) { auto& rootPolicy = policy.rootPolicy();
if (&otherMembrane.policy->rootPolicy() == &rootPolicy &&
otherMembrane.reverse == !reverse) {
// Capability that passed across the membrane one way is now passing back the other way. // Capability that passed across the membrane one way is now passing back the other way.
// Unwrap it rather than double-wrap it. // Unwrap it rather than double-wrap it.
return otherMembrane.inner->addRef(); Capability::Client unwrapped(otherMembrane.inner->addRef());
return ClientHook::from(
reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy)
: rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy));
} }
} }
return kj::refcounted<MembraneHook>(cap.addRef(), policy.addRef(), reverse); return ClientHook::from(
reverse ? policy.importExternal(Capability::Client(cap.addRef()))
: policy.exportInternal(Capability::Client(cap.addRef())));
} }
static kj::Own<ClientHook> wrap(kj::Own<ClientHook> cap, MembranePolicy& policy, bool reverse) { static kj::Own<ClientHook> wrap(kj::Own<ClientHook> cap, MembranePolicy& policy, bool reverse) {
if (cap->getBrand() == MEMBRANE_BRAND) { if (cap->getBrand() == MEMBRANE_BRAND) {
auto& otherMembrane = kj::downcast<MembraneHook>(*cap); auto& otherMembrane = kj::downcast<MembraneHook>(*cap);
if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) { auto& rootPolicy = policy.rootPolicy();
if (&otherMembrane.policy->rootPolicy() == &rootPolicy &&
otherMembrane.reverse == !reverse) {
// Capability that passed across the membrane one way is now passing back the other way. // Capability that passed across the membrane one way is now passing back the other way.
// Unwrap it rather than double-wrap it. // Unwrap it rather than double-wrap it.
return otherMembrane.inner->addRef(); Capability::Client unwrapped(otherMembrane.inner->addRef());
return ClientHook::from(
reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy)
: rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy));
} }
} }
return kj::refcounted<MembraneHook>(kj::mv(cap), policy.addRef(), reverse); return ClientHook::from(
reverse ? policy.importExternal(Capability::Client(kj::mv(cap)))
: policy.exportInternal(Capability::Client(kj::mv(cap))));
} }
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
...@@ -359,7 +373,8 @@ public: ...@@ -359,7 +373,8 @@ public:
// something outside the membrane later. We have to wait before we actually redirect, // something outside the membrane later. We have to wait before we actually redirect,
// otherwise behavior will differ depending on whether the promise is resolved. // otherwise behavior will differ depending on whether the promise is resolved.
KJ_IF_MAYBE(p, whenMoreResolved()) { KJ_IF_MAYBE(p, whenMoreResolved()) {
return newLocalPromiseClient(kj::mv(*p))->newCall(interfaceId, methodId, sizeHint); return newLocalPromiseClient(p->attach(addRef()))
->newCall(interfaceId, methodId, sizeHint);
} }
return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint); return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint);
...@@ -386,7 +401,8 @@ public: ...@@ -386,7 +401,8 @@ public:
// something outside the membrane later. We have to wait before we actually redirect, // something outside the membrane later. We have to wait before we actually redirect,
// otherwise behavior will differ depending on whether the promise is resolved. // otherwise behavior will differ depending on whether the promise is resolved.
KJ_IF_MAYBE(p, whenMoreResolved()) { KJ_IF_MAYBE(p, whenMoreResolved()) {
return newLocalPromiseClient(kj::mv(*p))->call(interfaceId, methodId, kj::mv(context)); return newLocalPromiseClient(p->attach(addRef()))
->call(interfaceId, methodId, kj::mv(context));
} }
return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context)); return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context));
...@@ -467,6 +483,26 @@ kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy, ...@@ -467,6 +483,26 @@ kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy,
} // namespace } // namespace
Capability::Client MembranePolicy::importExternal(Capability::Client external) {
return Capability::Client(kj::refcounted<MembraneHook>(
ClientHook::from(kj::mv(external)), addRef(), true));
}
Capability::Client MembranePolicy::exportInternal(Capability::Client internal) {
return Capability::Client(kj::refcounted<MembraneHook>(
ClientHook::from(kj::mv(internal)), addRef(), false));
}
Capability::Client MembranePolicy::importInternal(
Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy) {
return kj::mv(internal);
}
Capability::Client MembranePolicy::exportExternal(
Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy) {
return kj::mv(external);
}
Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy) { Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy) {
return Capability::Client(membrane( return Capability::Client(membrane(
ClientHook::from(kj::mv(inner)), *policy, false)); ClientHook::from(kj::mv(inner)), *policy, false));
......
...@@ -114,6 +114,56 @@ public: ...@@ -114,6 +114,56 @@ public:
// After the revocation promise has rejected, inboundCall() and outboundCall() will still be // After the revocation promise has rejected, inboundCall() and outboundCall() will still be
// invoked for new calls, but the `target` passed to them will be a capability that always // invoked for new calls, but the `target` passed to them will be a capability that always
// rethrows the revocation exception. // rethrows the revocation exception.
// ---------------------------------------------------------------------------
// Control over importing and exporting.
//
// Most membranes should not override these methods. The default behavior is that a capability
// that crosses the membrane is wrapped in it, and if the wrapped version crosses back the other
// way, it is unwrapped.
virtual Capability::Client importExternal(Capability::Client external);
// An external capability is crossing into the membrane. Returns the capability that should
// substitute for it when called from the inside.
//
// The default implementation creates a capability that invokes this MembranePolicy. E.g. all
// calls will invoke outboundCall().
//
// Note that reverseMembrane(cap, policy) normally calls policy->importExternal(cap), unless
// `cap` itself was originally returned by the default implementation of exportInternal(), in
// which case importInternal() is called instead.
virtual Capability::Client exportInternal(Capability::Client internal);
// An internal capability is crossing out of the membrane. Returns the capability that should
// substitute for it when called from the outside.
//
// The default implementation creates a capability that invokes this MembranePolicy. E.g. all
// calls will invoke inboundCall().
//
// Note that membrane(cap, policy) normally calls policy->exportInternal(cap), unless `cap`
// itself was originally returned by the default implementation of exportInternal(), in which
// case importInternal() is called instead.
virtual MembranePolicy& rootPolicy() { return *this; }
// If two policies return the same value for rootPolicy(), then a capability imported through
// one can be exported through the other, and vice versa. `importInternal()` and
// `exportExternal()` will always be called on the root policy, passing the two child policies
// as parameters. If you don't override rootPolicy(), then the policy references passed to
// importInternal() and exportExternal() will always be references to *this.
virtual Capability::Client importInternal(
Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy);
// An internal capability which was previously exported is now being re-imported, i.e. a
// capability passed out of the membrane and then back in.
//
// The default implementation simply returns `internal`.
virtual Capability::Client exportExternal(
Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy);
// An external capability which was previously imported is now being re-exported, i.e. a
// capability passed into the membrane and then back out.
//
// The default implementation simply returns `external`.
}; };
Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy); Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy);
......
...@@ -2394,7 +2394,7 @@ private: ...@@ -2394,7 +2394,7 @@ private:
kj::Own<ClientHook> resolvedCap) { kj::Own<ClientHook> resolvedCap) {
auto vpap = startCall(interfaceId, methodId, kj::mv(resolvedCap), kj::mv(context)); auto vpap = startCall(interfaceId, methodId, kj::mv(resolvedCap), kj::mv(context));
return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline)); return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline));
})).attach(addRef(*this)).split(); })).attach(addRef(*this), kj::mv(capability)).split();
return { return {
kj::mv(kj::get<0>(promises)), kj::mv(kj::get<0>(promises)),
......
...@@ -167,7 +167,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId, ...@@ -167,7 +167,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>> typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer( RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface, RealmGatewayClient gateway); Capability::Client bootstrapInterface, RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC server for a VatNetwork that resides in a different realm from the application. // Make an RPC server for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format // The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format. // and the network's ("external") format.
...@@ -186,7 +197,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId, ...@@ -186,7 +197,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>> typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer( RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway); BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC server that can serve different bootstrap interfaces to different clients via a // Make an RPC server that can serve different bootstrap interfaces to different clients via a
// BootstrapInterface and communicates with a different realm than the application is in via a // BootstrapInterface and communicates with a different realm than the application is in via a
// RealmGateway. // RealmGateway.
...@@ -232,7 +254,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId, ...@@ -232,7 +254,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>> typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcClient( RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network, VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
RealmGatewayClient gateway); RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC client for a VatNetwork that resides in a different realm from the application. // Make an RPC client for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format // The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format. // and the network's ("external") format.
......
...@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {} ...@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {}
void ChainPromiseNode::onReady(Event* event) noexcept { void ChainPromiseNode::onReady(Event* event) noexcept {
switch (state) { switch (state) {
case STEP1: case STEP1:
KJ_REQUIRE(onReadyEvent == nullptr, "onReady() can only be called once.");
onReadyEvent = event; onReadyEvent = event;
return; return;
case STEP2: case STEP2:
......
...@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) { ...@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) {
} else { } else {
content = kj::arrayPtr(buffer, n); content = kj::arrayPtr(buffer, n);
} }
}).attach(kj::defer([this]() {isPumping = false;})); isPumping = false;
});
}).fork(); }).fork();
} }
...@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data) ...@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data)
if (!isPumping) { if (!isPumping) {
isPumping = true; isPumping = true;
pumpTask = kj::evalNow([&]() { pumpTask = kj::evalNow([&]() {
return pump().attach(kj::defer([this]() {isPumping = false;})); return pump();
}).fork(); }).fork();
} }
...@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() { ...@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() {
if (filled > 0) { if (filled > 0) {
return pump(); return pump();
} else { } else {
isPumping = false;
return kj::READY_NOW; return kj::READY_NOW;
} }
}); });
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <kj/test.h> #include <kj/test.h>
#include <kj/async-io.h> #include <kj/async-io.h>
#include <stdlib.h> #include <stdlib.h>
#include <openssl/opensslv.h>
namespace kj { namespace kj {
namespace { namespace {
...@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") { ...@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") {
KJ_ASSERT(kj::StringPtr(buf) == "foo"); KJ_ASSERT(kj::StringPtr(buf) == "foo");
} }
KJ_TEST("TLS multiple messages") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writePromise = client->write("foo", 3)
.then([&]() { return client->write("bar", 3); });
char buf[4];
buf[3] = '\0';
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "foo");
writePromise = writePromise
.then([&]() { return client->write("baz", 3); });
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "bar");
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "baz");
auto readPromise = server->read(&buf, 3);
KJ_EXPECT(!readPromise.poll(test.io.waitScope));
writePromise = writePromise
.then([&]() { return client->write("qux", 3); });
readPromise.wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "qux");
}
kj::Promise<void> writeN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
return stream.write(text.begin(), text.size())
.then([&stream, text, count]() {
return writeN(stream, text, count);
});
}
kj::Promise<void> readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
auto buf = kj::heapString(text.size());
auto promise = stream.read(buf.begin(), buf.size());
return promise.then(kj::mvCapture(buf, [&stream, text, count](kj::String buf) {
KJ_ASSERT(buf == text, buf, text, count);
return readN(stream, text, count);
}));
}
KJ_TEST("TLS full duplex") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writeUp = writeN(*client, "foo", 10000);
auto readDown = readN(*client, "bar", 10000);
KJ_EXPECT(!writeUp.poll(test.io.waitScope));
KJ_EXPECT(!readDown.poll(test.io.waitScope));
auto writeDown = writeN(*server, "bar", 10000);
auto readUp = readN(*server, "foo", 10000);
readUp.wait(test.io.waitScope);
readDown.wait(test.io.waitScope);
writeUp.wait(test.io.waitScope);
writeDown.wait(test.io.waitScope);
}
class TestSniCallback: public TlsSniCallback { class TestSniCallback: public TlsSniCallback {
public: public:
kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override { kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override {
...@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") { ...@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") {
"self signed certificate"); "self signed certificate");
} }
// BoringSSL seems to print error messages differently.
#ifdef OPENSSL_IS_BORINGSSL
#define SSL_MESSAGE(interesting, boring) boring
#else
#define SSL_MESSAGE(interesting, boring) interesting
#endif
KJ_TEST("TLS client certificate verification") { KJ_TEST("TLS client certificate verification") {
TlsContext::Options serverOptions = TlsTest::defaultServer(); TlsContext::Options serverOptions = TlsTest::defaultServer();
TlsContext::Options clientOptions = TlsTest::defaultClient(); TlsContext::Options clientOptions = TlsTest::defaultClient();
...@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") { ...@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"); auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("peer did not return a certificate", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("peer did not return a certificate",
"PEER_DID_NOT_RETURN_A_CERTIFICATE"),
serverPromise.wait(test.io.waitScope)); serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert handshake failure", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert handshake failure",
"SSLV3_ALERT_HANDSHAKE_FAILURE"),
clientPromise.wait(test.io.waitScope)); clientPromise.wait(test.io.waitScope));
} }
...@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") { ...@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"); auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1])); auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("certificate verify failed", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("certificate verify failed",
"CERTIFICATE_VERIFY_FAILED"),
serverPromise.wait(test.io.waitScope)); serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert unknown ca", KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert unknown ca",
"TLSV1_ALERT_UNKNOWN_CA"),
clientPromise.wait(test.io.waitScope)); clientPromise.wait(test.io.waitScope));
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <openssl/x509v3.h> #include <openssl/x509v3.h>
#include <openssl/evp.h> #include <openssl/evp.h>
#include <openssl/conf.h> #include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h> #include <openssl/tls1.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/vector.h> #include <kj/vector.h>
...@@ -489,6 +490,13 @@ TlsContext::Options::Options() ...@@ -489,6 +490,13 @@ TlsContext::Options::Options()
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll // Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother. // never bother.
struct TlsContext::SniCallback {
// struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
// references OpenSSL types.
static int callback(SSL* ssl, int* ad, void* arg);
};
TlsContext::TlsContext(Options options) { TlsContext::TlsContext(Options options) {
ensureOpenSslInitialized(); ensureOpenSslInitialized();
...@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) { ...@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) {
// honor options.sniCallback // honor options.sniCallback
KJ_IF_MAYBE(sni, options.sniCallback) { KJ_IF_MAYBE(sni, options.sniCallback) {
SSL_CTX_set_tlsext_servername_callback(ctx, &sniCallback); SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
SSL_CTX_set_tlsext_servername_arg(ctx, sni); SSL_CTX_set_tlsext_servername_arg(ctx, sni);
} }
this->ctx = ctx; this->ctx = ctx;
} }
int TlsContext::sniCallback(void* sslp, int* ad, void* arg) { int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
// The first parameter is actually type SSL*, but we didn't want to include the OpenSSL headers
// from our header.
//
// The third parameter is actually type TlsSniCallback*. // The third parameter is actually type TlsSniCallback*.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
SSL* ssl = reinterpret_cast<SSL*>(sslp);
TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg); TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
...@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) { ...@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
} }
} }
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem) { TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
ensureOpenSslInitialized(); ensureOpenSslInitialized();
// const_cast apparently needed for older versions of OpenSSL. // const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size()); BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio)); KJ_DEFER(BIO_free(bio));
pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr); pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
if (pkey == nullptr) { if (pkey == nullptr) {
throwOpensslError(); throwOpensslError();
} }
...@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) { ...@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey)); EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
} }
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
KJ_IF_MAYBE(p, password) {
int result = kj::min(p->size(), size);
memcpy(buf, p->begin(), result);
return result;
} else {
return 0;
}
}
// ======================================================================================= // =======================================================================================
// class TlsCertificate // class TlsCertificate
......
...@@ -125,7 +125,7 @@ public: ...@@ -125,7 +125,7 @@ public:
private: private:
void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here
static int sniCallback(void* ssl, int* ad, void* arg); struct SniCallback;
}; };
class TlsPrivateKey { class TlsPrivateKey {
...@@ -137,10 +137,10 @@ public: ...@@ -137,10 +137,10 @@ public:
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to // RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt. // decrypt.
TlsPrivateKey(kj::StringPtr pem); TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password = nullptr);
// Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format" // Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format"
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to // RSA and DSA keys. A password may optionally be provided and will be used if the key is
// decrypt. // encrypted.
~TlsPrivateKey() noexcept(false); ~TlsPrivateKey() noexcept(false);
...@@ -158,6 +158,8 @@ private: ...@@ -158,6 +158,8 @@ private:
void* pkey; // actually type EVP_PKEY* void* pkey; // actually type EVP_PKEY*
friend class TlsContext; friend class TlsContext;
static int passwordCallback(char* buf, int size, int rwflag, void* u);
}; };
class TlsCertificate { class TlsCertificate {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment