Unverified Commit 08dd2e79 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #612 from capnproto/membrane-fixes

Various tweaks needed by Sandstorm HTTP changes
parents 02be7235 743e7cec
......@@ -321,27 +321,41 @@ public:
static kj::Own<ClientHook> wrap(ClientHook& cap, MembranePolicy& policy, bool reverse) {
if (cap.getBrand() == MEMBRANE_BRAND) {
auto& otherMembrane = kj::downcast<MembraneHook>(cap);
if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
auto& rootPolicy = policy.rootPolicy();
if (&otherMembrane.policy->rootPolicy() == &rootPolicy &&
otherMembrane.reverse == !reverse) {
// Capability that passed across the membrane one way is now passing back the other way.
// Unwrap it rather than double-wrap it.
return otherMembrane.inner->addRef();
Capability::Client unwrapped(otherMembrane.inner->addRef());
return ClientHook::from(
reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy)
: rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy));
}
}
return kj::refcounted<MembraneHook>(cap.addRef(), policy.addRef(), reverse);
return ClientHook::from(
reverse ? policy.importExternal(Capability::Client(cap.addRef()))
: policy.exportInternal(Capability::Client(cap.addRef())));
}
static kj::Own<ClientHook> wrap(kj::Own<ClientHook> cap, MembranePolicy& policy, bool reverse) {
if (cap->getBrand() == MEMBRANE_BRAND) {
auto& otherMembrane = kj::downcast<MembraneHook>(*cap);
if (otherMembrane.policy.get() == &policy && otherMembrane.reverse == !reverse) {
auto& rootPolicy = policy.rootPolicy();
if (&otherMembrane.policy->rootPolicy() == &rootPolicy &&
otherMembrane.reverse == !reverse) {
// Capability that passed across the membrane one way is now passing back the other way.
// Unwrap it rather than double-wrap it.
return otherMembrane.inner->addRef();
Capability::Client unwrapped(otherMembrane.inner->addRef());
return ClientHook::from(
reverse ? rootPolicy.importInternal(kj::mv(unwrapped), *otherMembrane.policy, policy)
: rootPolicy.exportExternal(kj::mv(unwrapped), *otherMembrane.policy, policy));
}
}
return kj::refcounted<MembraneHook>(kj::mv(cap), policy.addRef(), reverse);
return ClientHook::from(
reverse ? policy.importExternal(Capability::Client(kj::mv(cap)))
: policy.exportInternal(Capability::Client(kj::mv(cap))));
}
Request<AnyPointer, AnyPointer> newCall(
......@@ -359,7 +373,8 @@ public:
// something outside the membrane later. We have to wait before we actually redirect,
// otherwise behavior will differ depending on whether the promise is resolved.
KJ_IF_MAYBE(p, whenMoreResolved()) {
return newLocalPromiseClient(kj::mv(*p))->newCall(interfaceId, methodId, sizeHint);
return newLocalPromiseClient(p->attach(addRef()))
->newCall(interfaceId, methodId, sizeHint);
}
return ClientHook::from(kj::mv(*r))->newCall(interfaceId, methodId, sizeHint);
......@@ -386,7 +401,8 @@ public:
// something outside the membrane later. We have to wait before we actually redirect,
// otherwise behavior will differ depending on whether the promise is resolved.
KJ_IF_MAYBE(p, whenMoreResolved()) {
return newLocalPromiseClient(kj::mv(*p))->call(interfaceId, methodId, kj::mv(context));
return newLocalPromiseClient(p->attach(addRef()))
->call(interfaceId, methodId, kj::mv(context));
}
return ClientHook::from(kj::mv(*r))->call(interfaceId, methodId, kj::mv(context));
......@@ -467,6 +483,26 @@ kj::Own<ClientHook> membrane(kj::Own<ClientHook> inner, MembranePolicy& policy,
} // namespace
Capability::Client MembranePolicy::importExternal(Capability::Client external) {
return Capability::Client(kj::refcounted<MembraneHook>(
ClientHook::from(kj::mv(external)), addRef(), true));
}
Capability::Client MembranePolicy::exportInternal(Capability::Client internal) {
return Capability::Client(kj::refcounted<MembraneHook>(
ClientHook::from(kj::mv(internal)), addRef(), false));
}
Capability::Client MembranePolicy::importInternal(
Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy) {
return kj::mv(internal);
}
Capability::Client MembranePolicy::exportExternal(
Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy) {
return kj::mv(external);
}
Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy) {
return Capability::Client(membrane(
ClientHook::from(kj::mv(inner)), *policy, false));
......
......@@ -114,6 +114,56 @@ public:
// After the revocation promise has rejected, inboundCall() and outboundCall() will still be
// invoked for new calls, but the `target` passed to them will be a capability that always
// rethrows the revocation exception.
// ---------------------------------------------------------------------------
// Control over importing and exporting.
//
// Most membranes should not override these methods. The default behavior is that a capability
// that crosses the membrane is wrapped in it, and if the wrapped version crosses back the other
// way, it is unwrapped.
virtual Capability::Client importExternal(Capability::Client external);
// An external capability is crossing into the membrane. Returns the capability that should
// substitute for it when called from the inside.
//
// The default implementation creates a capability that invokes this MembranePolicy. E.g. all
// calls will invoke outboundCall().
//
// Note that reverseMembrane(cap, policy) normally calls policy->importExternal(cap), unless
// `cap` itself was originally returned by the default implementation of exportInternal(), in
// which case importInternal() is called instead.
virtual Capability::Client exportInternal(Capability::Client internal);
// An internal capability is crossing out of the membrane. Returns the capability that should
// substitute for it when called from the outside.
//
// The default implementation creates a capability that invokes this MembranePolicy. E.g. all
// calls will invoke inboundCall().
//
// Note that membrane(cap, policy) normally calls policy->exportInternal(cap), unless `cap`
// itself was originally returned by the default implementation of exportInternal(), in which
// case importInternal() is called instead.
virtual MembranePolicy& rootPolicy() { return *this; }
// If two policies return the same value for rootPolicy(), then a capability imported through
// one can be exported through the other, and vice versa. `importInternal()` and
// `exportExternal()` will always be called on the root policy, passing the two child policies
// as parameters. If you don't override rootPolicy(), then the policy references passed to
// importInternal() and exportExternal() will always be references to *this.
virtual Capability::Client importInternal(
Capability::Client internal, MembranePolicy& exportPolicy, MembranePolicy& importPolicy);
// An internal capability which was previously exported is now being re-imported, i.e. a
// capability passed out of the membrane and then back in.
//
// The default implementation simply returns `internal`.
virtual Capability::Client exportExternal(
Capability::Client external, MembranePolicy& importPolicy, MembranePolicy& exportPolicy);
// An external capability which was previously imported is now being re-exported, i.e. a
// capability passed into the membrane and then back out.
//
// The default implementation simply returns `external`.
};
Capability::Client membrane(Capability::Client inner, kj::Own<MembranePolicy> policy);
......
......@@ -2394,7 +2394,7 @@ private:
kj::Own<ClientHook> resolvedCap) {
auto vpap = startCall(interfaceId, methodId, kj::mv(resolvedCap), kj::mv(context));
return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline));
})).attach(addRef(*this)).split();
})).attach(addRef(*this), kj::mv(capability)).split();
return {
kj::mv(kj::get<0>(promises)),
......
......@@ -167,7 +167,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
Capability::Client bootstrapInterface, RealmGatewayClient gateway);
Capability::Client bootstrapInterface, RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC server for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format.
......@@ -186,7 +197,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcServer(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway);
BootstrapFactory<VatId>& bootstrapFactory, RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC server that can serve different bootstrap interfaces to different clients via a
// BootstrapInterface and communicates with a different realm than the application is in via a
// RealmGateway.
......@@ -232,7 +254,18 @@ template <typename VatId, typename ProvisionId, typename RecipientId,
typename ExternalRef = _::ExternalRefFromRealmGatewayClient<RealmGatewayClient>>
RpcSystem<VatId> makeRpcClient(
VatNetwork<VatId, ProvisionId, RecipientId, ThirdPartyCapId, JoinResult>& network,
RealmGatewayClient gateway);
RealmGatewayClient gateway)
CAPNP_DEPRECATED("Please transition to using MembranePolicy instead of RealmGateway.");
// ** DEPRECATED **
//
// This uses a RealmGateway to create a membrane between the external network and internal
// capabilites to translate save() requests. However, MembranePolicy (membrane.h) allows for the
// creation of much more powerful membranes and doesn't need to be tied to an RpcSystem.
// Applications should transition to using membranes instead of RealmGateway. RealmGateway will be
// removed in a future version of Cap'n Proto.
//
// Original description:
//
// Make an RPC client for a VatNetwork that resides in a different realm from the application.
// The given RealmGateway is used to translate SturdyRefs between the app's ("internal") format
// and the network's ("external") format.
......
......@@ -748,7 +748,6 @@ ChainPromiseNode::~ChainPromiseNode() noexcept(false) {}
void ChainPromiseNode::onReady(Event* event) noexcept {
switch (state) {
case STEP1:
KJ_REQUIRE(onReadyEvent == nullptr, "onReady() can only be called once.");
onReadyEvent = event;
return;
case STEP2:
......
......@@ -49,7 +49,8 @@ kj::Maybe<size_t> ReadyInputStreamWrapper::read(kj::ArrayPtr<byte> dst) {
} else {
content = kj::arrayPtr(buffer, n);
}
}).attach(kj::defer([this]() {isPumping = false;}));
isPumping = false;
});
}).fork();
}
......@@ -98,7 +99,7 @@ kj::Maybe<size_t> ReadyOutputStreamWrapper::write(kj::ArrayPtr<const byte> data)
if (!isPumping) {
isPumping = true;
pumpTask = kj::evalNow([&]() {
return pump().attach(kj::defer([this]() {isPumping = false;}));
return pump();
}).fork();
}
......@@ -130,6 +131,7 @@ kj::Promise<void> ReadyOutputStreamWrapper::pump() {
if (filled > 0) {
return pump();
} else {
isPumping = false;
return kj::READY_NOW;
}
});
......
......@@ -25,6 +25,7 @@
#include <kj/test.h>
#include <kj/async-io.h>
#include <stdlib.h>
#include <openssl/opensslv.h>
namespace kj {
namespace {
......@@ -344,6 +345,92 @@ KJ_TEST("TLS basics") {
KJ_ASSERT(kj::StringPtr(buf) == "foo");
}
KJ_TEST("TLS multiple messages") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writePromise = client->write("foo", 3)
.then([&]() { return client->write("bar", 3); });
char buf[4];
buf[3] = '\0';
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "foo");
writePromise = writePromise
.then([&]() { return client->write("baz", 3); });
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "bar");
server->read(&buf, 3).wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "baz");
auto readPromise = server->read(&buf, 3);
KJ_EXPECT(!readPromise.poll(test.io.waitScope));
writePromise = writePromise
.then([&]() { return client->write("qux", 3); });
readPromise.wait(test.io.waitScope);
KJ_ASSERT(kj::StringPtr(buf) == "qux");
}
kj::Promise<void> writeN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
return stream.write(text.begin(), text.size())
.then([&stream, text, count]() {
return writeN(stream, text, count);
});
}
kj::Promise<void> readN(kj::AsyncIoStream& stream, kj::StringPtr text, size_t count) {
if (count == 0) return kj::READY_NOW;
--count;
auto buf = kj::heapString(text.size());
auto promise = stream.read(buf.begin(), buf.size());
return promise.then(kj::mvCapture(buf, [&stream, text, count](kj::String buf) {
KJ_ASSERT(buf == text, buf, text, count);
return readN(stream, text, count);
}));
}
KJ_TEST("TLS full duplex") {
TlsTest test;
ErrorNexus e;
auto pipe = test.io.provider->newTwoWayPipe();
auto clientPromise = e.wrap(test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com"));
auto serverPromise = e.wrap(test.tlsServer.wrapServer(kj::mv(pipe.ends[1])));
auto client = clientPromise.wait(test.io.waitScope);
auto server = serverPromise.wait(test.io.waitScope);
auto writeUp = writeN(*client, "foo", 10000);
auto readDown = readN(*client, "bar", 10000);
KJ_EXPECT(!writeUp.poll(test.io.waitScope));
KJ_EXPECT(!readDown.poll(test.io.waitScope));
auto writeDown = writeN(*server, "bar", 10000);
auto readUp = readN(*server, "foo", 10000);
readUp.wait(test.io.waitScope);
readDown.wait(test.io.waitScope);
writeUp.wait(test.io.waitScope);
writeDown.wait(test.io.waitScope);
}
class TestSniCallback: public TlsSniCallback {
public:
kj::Maybe<TlsKeypair> getKey(kj::StringPtr hostname) override {
......@@ -411,6 +498,13 @@ KJ_TEST("TLS certificate validation") {
"self signed certificate");
}
// BoringSSL seems to print error messages differently.
#ifdef OPENSSL_IS_BORINGSSL
#define SSL_MESSAGE(interesting, boring) boring
#else
#define SSL_MESSAGE(interesting, boring) interesting
#endif
KJ_TEST("TLS client certificate verification") {
TlsContext::Options serverOptions = TlsTest::defaultServer();
TlsContext::Options clientOptions = TlsTest::defaultClient();
......@@ -427,9 +521,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("peer did not return a certificate",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("peer did not return a certificate",
"PEER_DID_NOT_RETURN_A_CERTIFICATE"),
serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert handshake failure",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert handshake failure",
"SSLV3_ALERT_HANDSHAKE_FAILURE"),
clientPromise.wait(test.io.waitScope));
}
......@@ -445,9 +543,13 @@ KJ_TEST("TLS client certificate verification") {
auto clientPromise = test.tlsClient.wrapClient(kj::mv(pipe.ends[0]), "example.com");
auto serverPromise = test.tlsServer.wrapServer(kj::mv(pipe.ends[1]));
KJ_EXPECT_THROW_MESSAGE("certificate verify failed",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("certificate verify failed",
"CERTIFICATE_VERIFY_FAILED"),
serverPromise.wait(test.io.waitScope));
KJ_EXPECT_THROW_MESSAGE("alert unknown ca",
KJ_EXPECT_THROW_MESSAGE(
SSL_MESSAGE("alert unknown ca",
"TLSV1_ALERT_UNKNOWN_CA"),
clientPromise.wait(test.io.waitScope));
}
......
......@@ -30,6 +30,7 @@
#include <openssl/x509v3.h>
#include <openssl/evp.h>
#include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <kj/debug.h>
#include <kj/vector.h>
......@@ -489,6 +490,13 @@ TlsContext::Options::Options()
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.
struct TlsContext::SniCallback {
// struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
// references OpenSSL types.
static int callback(SSL* ssl, int* ad, void* arg);
};
TlsContext::TlsContext(Options options) {
ensureOpenSslInitialized();
......@@ -573,21 +581,17 @@ TlsContext::TlsContext(Options options) {
// honor options.sniCallback
KJ_IF_MAYBE(sni, options.sniCallback) {
SSL_CTX_set_tlsext_servername_callback(ctx, &sniCallback);
SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
SSL_CTX_set_tlsext_servername_arg(ctx, sni);
}
this->ctx = ctx;
}
int TlsContext::sniCallback(void* sslp, int* ad, void* arg) {
// The first parameter is actually type SSL*, but we didn't want to include the OpenSSL headers
// from our header.
//
int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
// The third parameter is actually type TlsSniCallback*.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
SSL* ssl = reinterpret_cast<SSL*>(sslp);
TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
......@@ -671,14 +675,14 @@ TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
}
}
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem) {
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
ensureOpenSslInitialized();
// const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio));
pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr);
pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
if (pkey == nullptr) {
throwOpensslError();
}
......@@ -702,6 +706,18 @@ TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
KJ_IF_MAYBE(p, password) {
int result = kj::min(p->size(), size);
memcpy(buf, p->begin(), result);
return result;
} else {
return 0;
}
}
// =======================================================================================
// class TlsCertificate
......
......@@ -125,7 +125,7 @@ public:
private:
void* ctx; // actually type SSL_CTX, but we don't want to #include the OpenSSL headers here
static int sniCallback(void* ssl, int* ad, void* arg);
struct SniCallback;
};
class TlsPrivateKey {
......@@ -137,10 +137,10 @@ public:
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt.
TlsPrivateKey(kj::StringPtr pem);
TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password = nullptr);
// Parse a single PEM-encoded private key. Supports PKCS8 keys as well as "traditional format"
// RSA and DSA keys. Does not accept encrypted keys; it is the caller's responsibility to
// decrypt.
// RSA and DSA keys. A password may optionally be provided and will be used if the key is
// encrypted.
~TlsPrivateKey() noexcept(false);
......@@ -158,6 +158,8 @@ private:
void* pkey; // actually type EVP_PKEY*
friend class TlsContext;
static int passwordCallback(char* buf, int size, int rwflag, void* u);
};
class TlsCertificate {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment