Commit db844fd9 authored by gejun's avatar gejun

Misc improvements on SSL-related code

parent c813572d
......@@ -184,9 +184,9 @@ int main(int argc, char* argv[]) {
// Start the server.
brpc::ServerOptions options;
options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.ssl_options.default_cert.certificate = FLAGS_certificate;
options.ssl_options.default_cert.private_key = FLAGS_private_key;
options.ssl_options.ciphers = FLAGS_ciphers;
options.mutable_ssl_options()->default_cert.certificate = FLAGS_certificate;
options.mutable_ssl_options()->default_cert.private_key = FLAGS_private_key;
options.mutable_ssl_options()->ciphers = FLAGS_ciphers;
if (server.Start(FLAGS_port, &options) != 0) {
LOG(ERROR) << "Fail to start HttpServer";
return -1;
......
......@@ -95,7 +95,9 @@ int main(int argc, char* argv[]) {
// Initialize the channel, NULL means using default options.
brpc::ChannelOptions options;
options.ssl_options = std::make_shared<brpc::ChannelSSLOptions>();
if (FLAGS_enable_ssl) {
options.mutable_ssl_options();
}
options.protocol = FLAGS_protocol;
options.connection_type = FLAGS_connection_type;
options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100);
......
......@@ -82,9 +82,8 @@ int main(int argc, char* argv[]) {
// Start the server.
brpc::ServerOptions options;
options.ssl_options = std::make_shared<brpc::ServerSSLOptions>();
options.ssl_options->default_cert.certificate = "cert.pem";
options.ssl_options->default_cert.private_key = "key.pem";
options.mutable_ssl_options()->default_cert.certificate = "cert.pem";
options.mutable_ssl_options()->default_cert.private_key = "key.pem";
options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.max_concurrency = FLAGS_max_concurrency;
options.internal_port = FLAGS_internal_port;
......
......@@ -22,6 +22,7 @@
#include "butil/time.h" // milliseconds_from_now
#include "butil/logging.h"
#include "butil/third_party/murmurhash3/murmurhash3.h"
#include "butil/strings/string_util.h"
#include "bthread/unstable.h" // bthread_timer_add
#include "brpc/socket_map.h" // SocketMapInsert
#include "brpc/compress.h"
......@@ -50,13 +51,19 @@ ChannelOptions::ChannelOptions()
, auth(NULL)
, retry_policy(NULL)
, ns_filter(NULL)
, connection_group(0)
{}
ChannelSSLOptions* ChannelOptions::mutable_ssl_options() {
if (!_ssl_options) {
_ssl_options.reset(new ChannelSSLOptions);
}
return _ssl_options.get();
}
static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
if (opt.auth == NULL &&
opt.ssl_options == NULL &&
opt.connection_group == 0) {
!opt.has_ssl_options() &&
opt.connection_group.empty()) {
// Returning zeroized result by default is more intuitive for users.
return ChannelSignature();
}
......@@ -68,23 +75,23 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
buf.clear();
butil::MurmurHash3_x64_128_Init(&mm_ctx, seed);
if (opt.connection_group) {
if (!opt.connection_group.empty()) {
buf.append("|conng=");
buf.append((char*)&opt.connection_group, sizeof(opt.connection_group));
buf.append(opt.connection_group);
}
if (opt.auth) {
buf.append("|auth=");
buf.append((char*)&opt.auth, sizeof(opt.auth));
}
const ChannelSSLOptions* ssl = opt.ssl_options.get();
if (ssl) {
if (opt.has_ssl_options()) {
const ChannelSSLOptions& ssl = opt.ssl_options();
buf.push_back('|');
buf.append(ssl->ciphers);
buf.append(ssl.ciphers);
buf.push_back('|');
buf.append(ssl->protocols);
buf.append(ssl.protocols);
buf.push_back('|');
buf.append(ssl->sni_name);
const VerifyOptions& verify = ssl->verify;
buf.append(ssl.sni_name);
const VerifyOptions& verify = ssl.verify;
buf.push_back('|');
buf.append((char*)&verify.verify_depth, sizeof(verify.verify_depth));
buf.push_back('|');
......@@ -95,8 +102,8 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
butil::MurmurHash3_x64_128_Update(&mm_ctx, buf.data(), buf.size());
buf.clear();
if (ssl) {
const CertInfo& cert = ssl->client_cert;
if (opt.has_ssl_options()) {
const CertInfo& cert = opt.ssl_options().client_cert;
if (!cert.certificate.empty()) {
// Certificate may be too long (PEM string) to fit into `buf'
butil::MurmurHash3_x64_128_Update(
......@@ -187,19 +194,13 @@ int Channel::InitChannelOptions(const ChannelOptions* options) {
if (_options.auth == NULL) {
_options.auth = policy::global_esp_authenticator();
}
} else if (_options.protocol == brpc::PROTOCOL_HTTP) {
if (_raw_server_address.compare(0, 5, "https") == 0) {
if (_options.ssl_options == NULL) {
_options.ssl_options = std::make_shared<ChannelSSLOptions>();
}
if (_options.ssl_options->sni_name.empty()) {
int port;
ParseHostAndPortFromURL(_raw_server_address.c_str(),
&_options.ssl_options->sni_name, &port);
}
}
}
// Normalize connection_group
std::string& cg = _options.connection_group;
if (!cg.empty() && (::isspace(cg.front()) || ::isspace(cg.back()))) {
butil::TrimWhitespace(cg, butil::TRIM_ALL, &cg);
}
return 0;
}
......@@ -229,8 +230,7 @@ int Channel::Init(const char* server_addr_and_port,
return -1;
}
}
_raw_server_address.assign(server_addr_and_port);
return Init(point, options);
return InitSingle(point, server_addr_and_port, options);
}
int Channel::Init(const char* server_addr, int port,
......@@ -252,25 +252,21 @@ int Channel::Init(const char* server_addr, int port,
return -1;
}
}
_raw_server_address.assign(server_addr);
return Init(point, options);
return InitSingle(point, server_addr, options);
}
static int CreateSocketSSLContext(const ChannelOptions& options,
ChannelSignature* sig,
std::shared_ptr<SocketSSLContext>* ssl_ctx) {
if (options.ssl_options != NULL) {
*sig = ComputeChannelSignature(options);
SSL_CTX* raw_ctx = CreateClientSSLContext(*options.ssl_options);
if (options.has_ssl_options()) {
SSL_CTX* raw_ctx = CreateClientSSLContext(options.ssl_options());
if (!raw_ctx) {
LOG(ERROR) << "Fail to CreateClientSSLContext";
return -1;
}
*ssl_ctx = std::make_shared<SocketSSLContext>();
(*ssl_ctx)->raw_ctx = raw_ctx;
(*ssl_ctx)->sni_name = options.ssl_options->sni_name;
(*ssl_ctx)->sni_name = options.ssl_options().sni_name;
} else {
sig->Reset();
(*ssl_ctx) = NULL;
}
return 0;
......@@ -278,19 +274,32 @@ static int CreateSocketSSLContext(const ChannelOptions& options,
int Channel::Init(butil::EndPoint server_addr_and_port,
const ChannelOptions* options) {
return InitSingle(server_addr_and_port, "", options);
}
int Channel::InitSingle(const butil::EndPoint& server_addr_and_port,
const char* raw_server_address,
const ChannelOptions* options) {
GlobalInitializeOrDie();
if (InitChannelOptions(options) != 0) {
return -1;
}
if (_options.protocol == brpc::PROTOCOL_HTTP &&
::strncmp(raw_server_address, "https://", 8) == 0) {
if (_options.mutable_ssl_options()->sni_name.empty()) {
ParseURL(raw_server_address,
NULL, &_options.mutable_ssl_options()->sni_name, NULL);
}
}
const int port = server_addr_and_port.port;
if (port < 0 || port > 65535) {
LOG(ERROR) << "Invalid port=" << port;
return -1;
}
_server_address = server_addr_and_port;
ChannelSignature sig;
const ChannelSignature sig = ComputeChannelSignature(_options);
std::shared_ptr<SocketSSLContext> ssl_ctx;
if (CreateSocketSSLContext(_options, &sig, &ssl_ctx) != 0) {
if (CreateSocketSSLContext(_options, &ssl_ctx) != 0) {
return -1;
}
if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig),
......@@ -312,6 +321,13 @@ int Channel::Init(const char* ns_url,
if (InitChannelOptions(options) != 0) {
return -1;
}
if (_options.protocol == brpc::PROTOCOL_HTTP &&
::strncmp(ns_url, "https://", 8) == 0) {
if (_options.mutable_ssl_options()->sni_name.empty()) {
ParseURL(ns_url,
NULL, &_options.mutable_ssl_options()->sni_name, NULL);
}
}
LoadBalancerWithNaming* lb = new (std::nothrow) LoadBalancerWithNaming;
if (NULL == lb) {
LOG(FATAL) << "Fail to new LoadBalancerWithNaming";
......@@ -320,7 +336,8 @@ int Channel::Init(const char* ns_url,
GetNamingServiceThreadOptions ns_opt;
ns_opt.succeed_without_server = _options.succeed_without_server;
ns_opt.log_succeed_without_server = _options.log_succeed_without_server;
if (CreateSocketSSLContext(_options, &ns_opt.channel_signature, &ns_opt.ssl_ctx) != 0) {
ns_opt.channel_signature = ComputeChannelSignature(_options);
if (CreateSocketSSLContext(_options, &ns_opt.ssl_ctx) != 0) {
return -1;
}
if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) {
......
......@@ -21,10 +21,11 @@
// To brpc developers: This is a header included by user, don't depend
// on internal structures, use opaque pointers instead.
#include <ostream> // std::ostream
#include "bthread/errno.h" // Redefine errno
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
#include "brpc/ssl_option.h" // ChannelSSLOptions
#include <ostream> // std::ostream
#include "bthread/errno.h" // Redefine errno
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
#include "butil/ptr_container.h"
#include "brpc/ssl_options.h" // ChannelSSLOptions
#include "brpc/channel_base.h" // ChannelBase
#include "brpc/adaptive_protocol_type.h" // AdaptiveProtocolType
#include "brpc/adaptive_connection_type.h" // AdaptiveConnectionType
......@@ -90,8 +91,10 @@ struct ChannelOptions {
bool log_succeed_without_server;
// SSL related options. Refer to `ChannelSSLOptions' for details
std::shared_ptr<ChannelSSLOptions> ssl_options;
bool has_ssl_options() const { return _ssl_options != NULL; }
const ChannelSSLOptions& ssl_options() const { return *_ssl_options.get(); }
ChannelSSLOptions* mutable_ssl_options();
// Turn on authentication for this channel if `auth' is not NULL.
// Note `auth' will not be deleted by channel and must remain valid when
// the channel is being used.
......@@ -113,10 +116,16 @@ struct ChannelOptions {
// Default: NULL
const NamingServiceFilter* ns_filter;
// Channels with same connection_group share connections. In an another
// word, set to a different value to not share connections.
// Default: 0
int connection_group;
// Channels with same connection_group share connections.
// In other words, set to a different value to stop sharing connections.
// Case-sensitive, leading and trailing spaces are ignored.
// Default: ""
std::string connection_group;
private:
// SSLOptions is large and not often used, allocate it on heap to
// prevent ChannelOptions from being bloated in most cases.
butil::PtrContainer<ChannelSSLOptions> _ssl_options;
};
// A Channel represents a communication line to one server or multiple servers
......@@ -195,8 +204,10 @@ protected:
static void CallMethodImpl(Controller* controller, SharedLoadBalancer* lb);
int InitChannelOptions(const ChannelOptions* options);
int InitSingle(const butil::EndPoint& server_addr_and_port,
const char* raw_server_address,
const ChannelOptions* options);
std::string _raw_server_address;
butil::EndPoint _server_address;
SocketId _server_id;
Protocol::SerializeRequest _serialize_request;
......
......@@ -30,16 +30,26 @@ namespace brpc {
struct NSKey {
std::string protocol;
std::string service_name;
ChannelSignature channel_signature;
NSKey(const std::string& prot_in,
const std::string& service_in,
const ChannelSignature& sig)
: protocol(prot_in), service_name(service_in), channel_signature(sig) {
}
};
struct NSKeyHasher {
size_t operator()(const NSKey& nskey) const {
return butil::DefaultHasher<std::string>()(nskey.service_name)
* 101 + butil::DefaultHasher<std::string>()(nskey.protocol);
size_t h = butil::DefaultHasher<std::string>()(nskey.protocol);
h = h * 101 + butil::DefaultHasher<std::string>()(nskey.service_name);
h = h * 101 + nskey.channel_signature.data[1];
return h;
}
};
inline bool operator==(const NSKey& k1, const NSKey& k2) {
return k1.protocol == k2.protocol &&
k1.service_name == k2.service_name;
k1.service_name == k2.service_name &&
k1.channel_signature == k2.channel_signature;
}
typedef butil::FlatMap<NSKey, NamingServiceThread*, NSKeyHasher> NamingServiceMap;
......@@ -59,7 +69,7 @@ NamingServiceThread::Actions::~Actions() {
// Remove all sockets from SocketMap
for (std::vector<ServerNode>::const_iterator it = _last_servers.begin();
it != _last_servers.end(); ++it) {
const SocketMapKey key(it->addr, _owner->_options.channel_signature);
const SocketMapKey key(*it, _owner->_options.channel_signature);
SocketMapRemove(key);
}
EndWait(0);
......@@ -112,7 +122,7 @@ void NamingServiceThread::Actions::ResetServers(
// TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new
// Socket. SocketMapKey may be passed through AddWatcher. Make sure
// to pick those Sockets with the right settings during OnAddedServers
const SocketMapKey key(_added[i].addr, _owner->_options.channel_signature);
const SocketMapKey key(_added[i], _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx));
_added_sockets.push_back(tagged_id);
}
......@@ -121,7 +131,7 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) {
ServerNodeWithId tagged_id;
tagged_id.node = _removed[i];
const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature);
const SocketMapKey key(_removed[i], _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapFind(key, &tagged_id.id));
_removed_sockets.push_back(tagged_id);
}
......@@ -173,7 +183,7 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) {
// TODO: Remove all Sockets that have the same address in SocketMapKey.peer
// We may need another data structure to avoid linear cost
const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature);
const SocketMapKey key(_removed[i], _owner->_options.channel_signature);
SocketMapRemove(key);
}
......@@ -220,7 +230,7 @@ NamingServiceThread::~NamingServiceThread() {
RPC_VLOG << "~NamingServiceThread(" << *this << ')';
// Remove from g_nsthread_map first
if (!_protocol.empty()) {
const NSKey key = { _protocol, _service_name };
const NSKey key(_protocol, _service_name, _options.channel_signature);
std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex);
if (g_nsthread_map != NULL) {
NamingServiceThread** ptr = g_nsthread_map->seek(key);
......@@ -410,9 +420,8 @@ int GetNamingServiceThread(
LOG(ERROR) << "Unknown protocol=" << protocol;
return -1;
}
NSKey key;
key.protocol = protocol;
key.service_name = service_name;
const NSKey key(protocol, service_name,
(options ? options->channel_signature : ChannelSignature()));
bool new_thread = false;
butil::intrusive_ptr<NamingServiceThread> nsthread;
{
......
......@@ -22,7 +22,7 @@
// For some versions of openssl, SSL_* are defined inside this header
#include <openssl/ossl_typ.h>
#include "brpc/socket_id.h" // SocketId
#include "brpc/ssl_option.h" // ServerSSLOptions
#include "brpc/ssl_options.h" // ServerSSLOptions
namespace brpc {
......
......@@ -333,8 +333,9 @@ static void GlobalInitializeOrDieImpl() {
#endif
NamingServiceExtension()->RegisterOrDie("file", &g_ext->fns);
NamingServiceExtension()->RegisterOrDie("list", &g_ext->lns);
NamingServiceExtension()->RegisterOrDie("http", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("redis", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("http", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("https", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("redis", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("remotefile", &g_ext->rfns);
NamingServiceExtension()->RegisterOrDie("consul", &g_ext->cns);
......
......@@ -20,29 +20,15 @@
#include <vector> // std::vector
#include <string> // std::string
#include <ostream> // std::ostream
#include "butil/endpoint.h" // butil::EndPoint
#include "butil/macros.h" // BAIDU_CONCAT
#include "butil/endpoint.h" // butil::EndPoint
#include "butil/macros.h" // BAIDU_CONCAT
#include "brpc/describable.h"
#include "brpc/destroyable.h"
#include "brpc/extension.h" // Extension<T>
#include "brpc/extension.h" // Extension<T>
#include "brpc/server_node.h" // ServerNode
namespace brpc {
// Representing a server inside a NamingService.
struct ServerNode {
ServerNode() {}
ServerNode(butil::ip_t ip, int port, const std::string& tag2)
: addr(ip, port), tag(tag2) {}
ServerNode(const butil::EndPoint& pt, const std::string& tag2)
: addr(pt), tag(tag2) {}
ServerNode(butil::ip_t ip, int port) : addr(ip, port) {}
explicit ServerNode(const butil::EndPoint& pt) : addr(pt) {}
butil::EndPoint addr;
std::string tag;
};
// Continuing actions to added/removed servers.
// NOTE: You don't have to implement this class.
class NamingServiceActions {
......@@ -84,21 +70,6 @@ inline Extension<const NamingService>* NamingServiceExtension() {
return Extension<const NamingService>::instance();
}
inline bool operator<(const ServerNode& n1, const ServerNode& n2)
{ return n1.addr != n2.addr ? (n1.addr < n2.addr) : (n1.tag < n2.tag); }
inline bool operator==(const ServerNode& n1, const ServerNode& n2)
{ return n1.addr == n2.addr && n1.tag == n2.tag; }
inline bool operator!=(const ServerNode& n1, const ServerNode& n2)
{ return !(n1 == n2); }
inline std::ostream& operator<<(std::ostream& os, const ServerNode& n) {
os << n.addr;
if (!n.tag.empty()) {
os << "(tag=" << n.tag << ')';
}
return os;
}
} // namespace brpc
#endif // BRPC_NAMING_SERVICE_H
......@@ -1297,14 +1297,28 @@ void ProcessHttpRequest(InputMessageBase *msg) {
}
bool ParseHttpServerAddress(butil::EndPoint* point, const char* server_addr_and_port) {
std::string schema;
std::string host;
int port = -1;
if (ParseHostAndPortFromURL(server_addr_and_port, &host, &port) != 0) {
if (ParseURL(server_addr_and_port, &schema, &host, &port) != 0) {
LOG(ERROR) << "Invalid address=`" << server_addr_and_port << '\'';
return false;
}
if (schema.empty() || schema == "http") {
if (port < 0) {
port = 80;
}
} else if (schema == "https") {
if (port < 0) {
port = 443;
}
} else {
LOG(ERROR) << "Invalid schema=`" << schema << '\'';
return false;
}
if (str2endpoint(host.c_str(), port, point) != 0 &&
hostname2endpoint(host.c_str(), port, point) != 0) {
LOG(ERROR) << "Invalid address=`" << host << '\'';
LOG(ERROR) << "Invalid host=" << host << " port=" << port;
return false;
}
return true;
......
......@@ -145,6 +145,13 @@ ServerOptions::ServerOptions()
}
}
ServerSSLOptions* ServerOptions::mutable_ssl_options() {
if (!_ssl_options) {
_ssl_options.reset(new ServerSSLOptions);
}
return _ssl_options.get();
}
Server::MethodProperty::OpaqueParams::OpaqueParams()
: is_tabbed(false)
, allow_http_body_to_pb(true)
......@@ -840,8 +847,8 @@ int Server::StartInternal(const butil::ip_t& ip,
// Free last SSL contexts
FreeSSLContexts();
if (_options.ssl_options) {
CertInfo& default_cert = _options.ssl_options->default_cert;
if (_options.has_ssl_options()) {
CertInfo& default_cert = _options.mutable_ssl_options()->default_cert;
if (default_cert.certificate.empty()) {
LOG(ERROR) << "default_cert is empty";
return -1;
......@@ -851,7 +858,7 @@ int Server::StartInternal(const butil::ip_t& ip,
}
_default_ssl_ctx = _ssl_ctx_map.begin()->second.ctx;
const std::vector<CertInfo>& certs = _options.ssl_options->certs;
const std::vector<CertInfo>& certs = _options.mutable_ssl_options()->certs;
for (size_t i = 0; i < certs.size(); ++i) {
if (AddCertificate(certs[i]) != 0) {
return -1;
......@@ -1795,7 +1802,7 @@ Server::FindServicePropertyByName(const butil::StringPiece& name) const {
}
int Server::AddCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) {
if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
......@@ -1810,7 +1817,7 @@ int Server::AddCertificate(const CertInfo& cert) {
ssl_ctx.filters = cert.sni_filters;
ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
SSL_CTX* raw_ctx = CreateServerSSLContext(cert.certificate, cert.private_key,
*_options.ssl_options, &ssl_ctx.filters);
_options.ssl_options(), &ssl_ctx.filters);
if (raw_ctx == NULL) {
return -1;
}
......@@ -1860,7 +1867,7 @@ bool Server::AddCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
}
int Server::RemoveCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) {
if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
......@@ -1905,7 +1912,7 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
}
int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
if (_options.ssl_options == NULL) {
if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
......@@ -1918,8 +1925,8 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
// Add default certficiate into tmp_map first since it can't be reloaded
std::string default_cert_key =
_options.ssl_options->default_cert.certificate
+ _options.ssl_options->default_cert.private_key;
_options.ssl_options().default_cert.certificate
+ _options.ssl_options().default_cert.private_key;
tmp_map[default_cert_key] = _ssl_ctx_map[default_cert_key];
for (size_t i = 0; i < certs.size(); ++i) {
......@@ -1935,7 +1942,7 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
ssl_ctx.ctx->raw_ctx = CreateServerSSLContext(
certs[i].certificate, certs[i].private_key,
*_options.ssl_options, &ssl_ctx.filters);
_options.ssl_options(), &ssl_ctx.filters);
if (ssl_ctx.ctx->raw_ctx == NULL) {
return -1;
}
......@@ -2086,7 +2093,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server) {
(void)al;
const char* hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
bool strict_sni = server->_options.ssl_options->strict_sni;
bool strict_sni = server->_options.ssl_options().strict_sni;
if (hostname == NULL) {
return strict_sni ? SSL_TLSEXT_ERR_ALERT_FATAL : SSL_TLSEXT_ERR_NOACK;
}
......
......@@ -24,13 +24,14 @@
#include "bthread/errno.h" // Redefine errno
#include "bthread/bthread.h" // Server may need some bthread functions,
// e.g. bthread_usleep
#include <google/protobuf/service.h> // google::protobuf::Service
#include <google/protobuf/service.h> // google::protobuf::Service
#include "butil/macros.h" // DISALLOW_COPY_AND_ASSIGN
#include "butil/containers/doubly_buffered_data.h" // DoublyBufferedData
#include "bvar/bvar.h"
#include "butil/containers/case_ignored_flat_map.h" // [CaseIgnored]FlatMap
#include "butil/ptr_container.h"
#include "brpc/controller.h" // brpc::Controller
#include "brpc/ssl_option.h" // ServerSSLOptions
#include "brpc/ssl_options.h" // ServerSSLOptions
#include "brpc/describable.h" // User often needs this
#include "brpc/data_factory.h" // DataFactory
#include "brpc/builtin/tabbed.h"
......@@ -199,7 +200,9 @@ struct ServerOptions {
bool security_mode() const { return internal_port >= 0 || !has_builtin_services; }
// SSL related options. Refer to `ServerSSLOptions' for details
std::shared_ptr<ServerSSLOptions> ssl_options;
bool has_ssl_options() const { return _ssl_options != NULL; }
const ServerSSLOptions& ssl_options() const { return *_ssl_options.get(); }
ServerSSLOptions* mutable_ssl_options();
// [CAUTION] This option is for implementing specialized http proxies,
// most users don't need it. Don't change this option unless you fully
......@@ -225,6 +228,11 @@ struct ServerOptions {
// All names inside must be valid, check protocols name in global.cpp
// Default: empty (all protocols)
std::string enabled_protocols;
private:
// SSLOptions is large and not often used, allocate it on heap to
// prevent ServerOptions from being bloated in most cases.
butil::PtrContainer<ServerSSLOptions> _ssl_options;
};
// This struct is originally designed to contain basic statistics of the
......
......@@ -229,7 +229,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id,
}
SocketId tmp_id;
SocketOptions opt;
opt.remote_side = key.peer;
opt.remote_side = key.peer.addr;
opt.initial_ssl_ctx = ssl_ctx;
if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) {
PLOG(FATAL) << "Fail to create socket to " << key.peer;
......
......@@ -17,18 +17,17 @@
#ifndef BRPC_SOCKET_MAP_H
#define BRPC_SOCKET_MAP_H
#include <vector> // std::vector
#include "bvar/bvar.h" // bvar::PassiveStatus
#include "butil/containers/flat_map.h" // FlatMap
#include <vector> // std::vector
#include "bvar/bvar.h" // bvar::PassiveStatus
#include "butil/containers/flat_map.h" // FlatMap
#include "brpc/socket_id.h" // SockdetId
#include "brpc/options.pb.h" // ProtocolType
#include "brpc/input_messenger.h" // InputMessageHandler
#include "brpc/server_node.h" // ServerNode
namespace brpc {
// Global mapping from remote-side to out-going sockets created by Channels.
// Different signature means that the Channel needs separate sockets.
struct ChannelSignature {
uint64_t data[2];
......@@ -52,8 +51,11 @@ struct SocketMapKey {
SocketMapKey(const butil::EndPoint& pt, const ChannelSignature& cs)
: peer(pt), channel_signature(cs)
{}
butil::EndPoint peer;
SocketMapKey(const ServerNode& sn, const ChannelSignature& cs)
: peer(sn), channel_signature(cs)
{}
ServerNode peer;
ChannelSignature channel_signature;
};
......@@ -62,9 +64,11 @@ inline bool operator==(const SocketMapKey& k1, const SocketMapKey& k2) {
};
struct SocketMapKeyHasher {
std::size_t operator()(const SocketMapKey& key) const {
return butil::DefaultHasher<butil::EndPoint>()(key.peer) ^
key.channel_signature.data[1];
size_t operator()(const SocketMapKey& key) const {
size_t h = butil::DefaultHasher<butil::EndPoint>()(key.peer.addr);
h = h * 101 + butil::DefaultHasher<std::string>()(key.peer.tag);
h = h * 101 + key.channel_signature.data[1];
return h;
}
};
......
......@@ -224,8 +224,8 @@ int URI::SetHttpURL(const char* url) {
return 0;
}
int ParseHostAndPortFromURL(const char* url, std::string* host_out,
int* port_out) {
int ParseURL(const char* url,
std::string* schema_out, std::string* host_out, int* port_out) {
const char* p = url;
// skip heading blanks
if (*p == ' ') {
......@@ -235,7 +235,6 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
// Find end of host, locate schema and user_info during the searching
bool need_schema = true;
bool need_user_info = true;
butil::StringPiece schema;
for (; true; ++p) {
const char action = g_url_parsing_fast_action_map[(int)*p];
if (action == URI_PARSE_CONTINUE) {
......@@ -247,7 +246,9 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
if (*p == ':') {
if (p[1] == '/' && p[2] == '/' && need_schema) {
need_schema = false;
schema.set(start, p - start);
if (schema_out) {
schema_out->assign(start, p - start);
}
p += 2;
start = p + 1;
}
......@@ -266,15 +267,12 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
}
int port = -1;
const char* host_end = SplitHostAndPort(start, p, &port);
if (port < 0) {
if (schema.empty() || schema == "http") {
port = 80;
} else if (schema == "https") {
port = 443;
}
if (host_out) {
host_out->assign(start, host_end - start);
}
if (port_out) {
*port_out = port;
}
host_out->assign(start, host_end - start);
*port_out = port;
return 0;
}
......
......@@ -155,10 +155,9 @@ friend class HttpMessage;
mutable QueryMap _query_map;
};
// Parse host and port from `url'.
// When port is absent, it's set to 80 for http and 443 for https.
// Parse host/port/schema from `url' if the corresponding parameter is not NULL.
// Returns 0 on success, -1 otherwise.
int ParseHostAndPortFromURL(const char* url, std::string* host, int* port);
int ParseURL(const char* url, std::string* schema, std::string* host, int* port);
inline void URI::SetQuery(const std::string& key, const std::string& value) {
get_query_map()[key] = value;
......
......@@ -141,6 +141,7 @@ class MyEchoService : public ::test::EchoService {
if (req->code() != 0) {
res->add_code_list(req->code());
}
res->set_receiving_socket_id(cntl->_current_call.sending_sock->id());
}
};
......@@ -258,14 +259,17 @@ protected:
}
void SetUpChannel(brpc::Channel* channel,
bool single_server, bool short_connection,
const brpc::Authenticator* auth = NULL) {
bool single_server,
bool short_connection,
const brpc::Authenticator* auth = NULL,
std::string connection_group = std::string()) {
brpc::ChannelOptions opt;
if (short_connection) {
opt.connection_type = brpc::CONNECTION_TYPE_SHORT;
}
opt.auth = auth;
opt.max_retry = 0;
opt.connection_group = connection_group;
if (single_server) {
EXPECT_EQ(0, channel->Init(_ep, &opt));
} else {
......@@ -405,6 +409,7 @@ protected:
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
const uint64_t receiving_socket_id = res.receiving_socket_id();
EXPECT_EQ(0, cntl.sub_count());
EXPECT_TRUE(NULL == cntl.sub(-1));
EXPECT_TRUE(NULL == cntl.sub(0));
......@@ -419,7 +424,48 @@ protected:
}
} else {
EXPECT_GE(1ul, _messenger.ConnectionCount());
}
}
if (single_server && !short_connection) {
// Reuse the connection
brpc::Channel channel2;
SetUpChannel(&channel2, single_server, short_connection);
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel2, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
EXPECT_EQ(receiving_socket_id, res.receiving_socket_id());
// A different connection_group does not reuse the connection
brpc::Channel channel3;
SetUpChannel(&channel3, single_server, short_connection,
NULL, "another_group");
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel3, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
const uint64_t receiving_socket_id2 = res.receiving_socket_id();
EXPECT_NE(receiving_socket_id, receiving_socket_id2);
// Channel in the same connection_group reuses the connection
// note that the leading/trailing spaces should be trimed.
brpc::Channel channel4;
SetUpChannel(&channel4, single_server, short_connection,
NULL, " another_group ");
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel4, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
EXPECT_EQ(receiving_socket_id2, res.receiving_socket_id());
}
StopAndJoin();
}
......@@ -1547,6 +1593,10 @@ protected:
void TestAuthentication(bool single_server,
bool async, bool short_connection) {
std::cout << " *** single=" << single_server
<< " async=" << async
<< " short=" << short_connection << std::endl;
ASSERT_EQ(0, StartAccept(_ep));
MyAuthenticator auth;
brpc::Channel channel;
......@@ -1809,7 +1859,7 @@ TEST_F(ChannelTest, init_as_single_server) {
ASSERT_EQ(ep, channel._server_address);
brpc::SocketId id;
ASSERT_EQ(0, brpc::SocketMapFind(ep, &id));
ASSERT_EQ(0, brpc::SocketMapFind(brpc::SocketMapKey(ep), &id));
ASSERT_EQ(id, channel._server_id);
const int NUM = 10;
......
......@@ -127,7 +127,7 @@ TEST_F(SocketMapTest, max_pool_size) {
} //namespace
int main(int argc, char* argv[]) {
butil::str2endpoint("127.0.0.1:12345", &g_key.peer);
butil::str2endpoint("127.0.0.1:12345", &g_key.peer.addr);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -95,7 +95,7 @@ TEST_F(SSLTest, sanity) {
brpc::CertInfo cert;
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
options.ssl_options.default_cert = cert;
options.mutable_ssl_options()->default_cert = cert;
EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
......@@ -108,7 +108,7 @@ TEST_F(SSLTest, sanity) {
{
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true;
coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("localhost", port, &coptions));
brpc::Controller cntl;
......@@ -124,7 +124,7 @@ TEST_F(SSLTest, sanity) {
{
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true;
coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
for (int i = 0; i < NUM; ++i) {
google::protobuf::Closure* thrd_func =
......@@ -140,7 +140,7 @@ TEST_F(SSLTest, sanity) {
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.protocol = "http";
coptions.ssl_options.enable = true;
coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
for (int i = 0; i < NUM; ++i) {
google::protobuf::Closure* thrd_func =
......@@ -160,8 +160,7 @@ void CheckCert(const char* cname, const char* cert) {
const int port = 8613;
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true;
coptions.ssl_options.sni_name = cname;
coptions.mutable_ssl_options()->sni_name = cname;
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
SendMultipleRPC(&channel, 1);
......@@ -199,14 +198,14 @@ TEST_F(SSLTest, ssl_sni) {
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
cert.sni_filters.push_back("cert1.com");
options.ssl_options.default_cert = cert;
options.mutable_ssl_options()->default_cert = cert;
}
{
brpc::CertInfo cert;
cert.certificate = GetRawPemString("cert2.crt");
cert.private_key = GetRawPemString("cert2.key");
cert.sni_filters.push_back("*.cert2.com");
options.ssl_options.certs.push_back(cert);
options.mutable_ssl_options()->certs.push_back(cert);
}
EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
......@@ -230,7 +229,7 @@ TEST_F(SSLTest, ssl_reload) {
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
cert.sni_filters.push_back("cert1.com");
options.ssl_options.default_cert = cert;
options.mutable_ssl_options()->default_cert = cert;
}
EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
......@@ -318,7 +317,6 @@ TEST_F(SSLTest, ssl_perf) {
ASSERT_GT(servfd, 0);
brpc::ChannelSSLOptions opt;
opt.enable = true;
SSL_CTX* cli_ctx = brpc::CreateClientSSLContext(opt);
SSL_CTX* serv_ctx =
brpc::CreateServerSSLContext("cert1.crt", "cert1.key",
......
......@@ -20,9 +20,11 @@ TEST(URITest, everything) {
ASSERT_EQ(*uri.GetQuery("wd"), "uri");
ASSERT_FALSE(uri.GetQuery("nonkey"));
std::string schema;
std::string host_out;
int port_out = -1;
brpc::ParseHostAndPortFromURL(uri_str.c_str(), &host_out, &port_out);
brpc::ParseURL(uri_str.c_str(), &schema, &host_out, &port_out);
ASSERT_EQ("foobar", schema);
ASSERT_EQ("www.baidu.com", host_out);
ASSERT_EQ(80, port_out);
}
......
......@@ -15,6 +15,7 @@ message EchoRequest {
message EchoResponse {
required string message = 1;
repeated int32 code_list = 2;
optional uint64 receiving_socket_id = 3;
};
message ComboRequest {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment