Commit 62d0abd0 authored by gejun's avatar gejun

SSL supports NS & connection_group support

parent 0f621581
...@@ -95,7 +95,7 @@ int main(int argc, char* argv[]) { ...@@ -95,7 +95,7 @@ int main(int argc, char* argv[]) {
// Initialize the channel, NULL means using default options. // Initialize the channel, NULL means using default options.
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.ssl_options.enable = FLAGS_enable_ssl; options.ssl_options = std::make_shared<brpc::ChannelSSLOptions>();
options.protocol = FLAGS_protocol; options.protocol = FLAGS_protocol;
options.connection_type = FLAGS_connection_type; options.connection_type = FLAGS_connection_type;
options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100); options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100);
......
...@@ -82,8 +82,9 @@ int main(int argc, char* argv[]) { ...@@ -82,8 +82,9 @@ int main(int argc, char* argv[]) {
// Start the server. // Start the server.
brpc::ServerOptions options; brpc::ServerOptions options;
options.ssl_options.default_cert.certificate = "cert.pem"; options.ssl_options = std::make_shared<brpc::ServerSSLOptions>();
options.ssl_options.default_cert.private_key = "key.pem"; options.ssl_options->default_cert.certificate = "cert.pem";
options.ssl_options->default_cert.private_key = "key.pem";
options.idle_timeout_sec = FLAGS_idle_timeout_s; options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.max_concurrency = FLAGS_max_concurrency; options.max_concurrency = FLAGS_max_concurrency;
options.internal_port = FLAGS_internal_port; options.internal_port = FLAGS_internal_port;
......
...@@ -44,8 +44,8 @@ Acceptor::~Acceptor() { ...@@ -44,8 +44,8 @@ Acceptor::~Acceptor() {
Join(); Join();
} }
int Acceptor::StartAccept( int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
int listened_fd, int idle_timeout_sec, SSL_CTX* ssl_ctx) { const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
if (listened_fd < 0) { if (listened_fd < 0) {
LOG(FATAL) << "Invalid listened_fd=" << listened_fd; LOG(FATAL) << "Invalid listened_fd=" << listened_fd;
return -1; return -1;
...@@ -271,7 +271,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) { ...@@ -271,7 +271,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) {
options.remote_side = butil::EndPoint(*(sockaddr_in*)&in_addr); options.remote_side = butil::EndPoint(*(sockaddr_in*)&in_addr);
options.user = acception->user(); options.user = acception->user();
options.on_edge_triggered_events = InputMessenger::OnNewMessages; options.on_edge_triggered_events = InputMessenger::OnNewMessages;
options.ssl_ctx = am->_ssl_ctx; options.initial_ssl_ctx = am->_ssl_ctx;
if (Socket::Create(options, &socket_id) != 0) { if (Socket::Create(options, &socket_id) != 0) {
LOG(ERROR) << "Fail to create Socket"; LOG(ERROR) << "Fail to create Socket";
continue; continue;
......
...@@ -53,7 +53,8 @@ public: ...@@ -53,7 +53,8 @@ public:
// transmission for `idle_timeout_sec' will be closed automatically iff // transmission for `idle_timeout_sec' will be closed automatically iff
// `idle_timeout_sec' > 0 // `idle_timeout_sec' > 0
// Return 0 on success, -1 otherwise. // Return 0 on success, -1 otherwise.
int StartAccept(int listened_fd, int idle_timeout_sec, SSL_CTX* ssl_ctx); int StartAccept(int listened_fd, int idle_timeout_sec,
const std::shared_ptr<SocketSSLContext>& ssl_ctx);
// [thread-safe] Stop accepting connections. // [thread-safe] Stop accepting connections.
// `closewait_ms' is not used anymore. // `closewait_ms' is not used anymore.
...@@ -104,8 +105,7 @@ private: ...@@ -104,8 +105,7 @@ private:
// The map containing all the accepted sockets // The map containing all the accepted sockets
SocketMap _socket_map; SocketMap _socket_map;
// Not owner std::shared_ptr<SocketSSLContext> _ssl_ctx;
SSL_CTX* _ssl_ctx;
}; };
} // namespace brpc } // namespace brpc
......
...@@ -192,10 +192,13 @@ void ConnectionsService::PrintConnections( ...@@ -192,10 +192,13 @@ void ConnectionsService::PrintConnections(
// slow (because we have many connections here). // slow (because we have many connections here).
int pref_index = ptr->preferred_index(); int pref_index = ptr->preferred_index();
SocketUniquePtr first_sub; SocketUniquePtr first_sub;
int pooled_count = -1;
if (ptr->fd() < 0) {
int numfree = 0; int numfree = 0;
int numinflight = 0; int numinflight = 0;
if (ptr->fd() < 0) { if (ptr->GetPooledSocketStats(&numfree, &numinflight)) {
ptr->GetPooledSocketStats(&numfree, &numinflight); pooled_count = numfree + numinflight;
}
// Check preferred_index of any pooled sockets. // Check preferred_index of any pooled sockets.
ptr->ListPooledSockets(&first_id, 1); ptr->ListPooledSockets(&first_id, 1);
if (!first_id.empty()) { if (!first_id.empty()) {
...@@ -263,11 +266,11 @@ void ConnectionsService::PrintConnections( ...@@ -263,11 +266,11 @@ void ConnectionsService::PrintConnections(
} }
os << SSLStateToYesNo(ptr->ssl_state(), use_html) << bar; os << SSLStateToYesNo(ptr->ssl_state(), use_html) << bar;
char protname[32]; char protname[32];
if (!ptr->CreatedByConnect()) { if (pooled_count < 0) {
snprintf(protname, sizeof(protname), "%s", pref_prot); snprintf(protname, sizeof(protname), "%s", pref_prot);
} else { } else {
snprintf(protname, sizeof(protname), "%s*%d", pref_prot, snprintf(protname, sizeof(protname), "%s*%d", pref_prot,
numfree + numinflight); pooled_count);
} }
os << min_width(protname, 12) << bar; os << min_width(protname, 12) << bar;
if (ptr->fd() >= 0) { if (ptr->fd() >= 0) {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "butil/time.h" // milliseconds_from_now #include "butil/time.h" // milliseconds_from_now
#include "butil/logging.h" #include "butil/logging.h"
#include "butil/third_party/murmurhash3/murmurhash3.h"
#include "bthread/unstable.h" // bthread_timer_add #include "bthread/unstable.h" // bthread_timer_add
#include "brpc/socket_map.h" // SocketMapInsert #include "brpc/socket_map.h" // SocketMapInsert
#include "brpc/compress.h" #include "brpc/compress.h"
...@@ -32,7 +33,6 @@ ...@@ -32,7 +33,6 @@
#include "brpc/details/usercode_backup_pool.h" // TooManyUserCode #include "brpc/details/usercode_backup_pool.h" // TooManyUserCode
#include "brpc/policy/esp_authenticator.h" #include "brpc/policy/esp_authenticator.h"
namespace brpc { namespace brpc {
DECLARE_bool(enable_rpcz); DECLARE_bool(enable_rpcz);
...@@ -50,8 +50,73 @@ ChannelOptions::ChannelOptions() ...@@ -50,8 +50,73 @@ ChannelOptions::ChannelOptions()
, auth(NULL) , auth(NULL)
, retry_policy(NULL) , retry_policy(NULL)
, ns_filter(NULL) , ns_filter(NULL)
, connection_group(0)
{} {}
static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
if (opt.auth == NULL &&
opt.ssl_options == NULL &&
opt.connection_group == 0) {
// Returning zeroized result by default is more intuitive for users.
return ChannelSignature();
}
uint32_t seed = 0;
std::string buf;
buf.reserve(1024);
butil::MurmurHash3_x64_128_Context mm_ctx;
do {
buf.clear();
butil::MurmurHash3_x64_128_Init(&mm_ctx, seed);
if (opt.connection_group) {
buf.append("|conng=");
buf.append((char*)&opt.connection_group, sizeof(opt.connection_group));
}
if (opt.auth) {
buf.append("|auth=");
buf.append((char*)&opt.auth, sizeof(opt.auth));
}
const ChannelSSLOptions* ssl = opt.ssl_options.get();
if (ssl) {
buf.push_back('|');
buf.append(ssl->ciphers);
buf.push_back('|');
buf.append(ssl->protocols);
buf.push_back('|');
buf.append(ssl->sni_name);
const VerifyOptions& verify = ssl->verify;
buf.push_back('|');
buf.append((char*)&verify.verify_depth, sizeof(verify.verify_depth));
buf.push_back('|');
buf.append(verify.ca_file_path);
} else {
// All disabled ChannelSSLOptions are the same
}
butil::MurmurHash3_x64_128_Update(&mm_ctx, buf.data(), buf.size());
buf.clear();
if (ssl) {
const CertInfo& cert = ssl->client_cert;
if (!cert.certificate.empty()) {
// Certificate may be too long (PEM string) to fit into `buf'
butil::MurmurHash3_x64_128_Update(
&mm_ctx, cert.certificate.data(), cert.certificate.size());
butil::MurmurHash3_x64_128_Update(
&mm_ctx, cert.private_key.data(), cert.private_key.size());
}
}
// sni_filters has no effect in ChannelSSLOptions
ChannelSignature result;
butil::MurmurHash3_x64_128_Final(result.data, &mm_ctx);
if (result != ChannelSignature()) {
// the empty result is reserved for default case and cannot
// be used, increment the seed and retry.
return result;
}
++seed;
} while (true);
}
Channel::Channel(ProfilerLinker) Channel::Channel(ProfilerLinker)
: _server_id((SocketId)-1) : _server_id((SocketId)-1)
, _serialize_request(NULL) , _serialize_request(NULL)
...@@ -62,9 +127,8 @@ Channel::Channel(ProfilerLinker) ...@@ -62,9 +127,8 @@ Channel::Channel(ProfilerLinker)
Channel::~Channel() { Channel::~Channel() {
if (_server_id != (SocketId)-1) { if (_server_id != (SocketId)-1) {
SocketMapRemove(SocketMapKey(_server_address, const ChannelSignature sig = ComputeChannelSignature(_options);
_options.ssl_options, SocketMapRemove(SocketMapKey(_server_address, sig));
_options.auth));
} }
} }
...@@ -125,11 +189,13 @@ int Channel::InitChannelOptions(const ChannelOptions* options) { ...@@ -125,11 +189,13 @@ int Channel::InitChannelOptions(const ChannelOptions* options) {
} }
} else if (_options.protocol == brpc::PROTOCOL_HTTP) { } else if (_options.protocol == brpc::PROTOCOL_HTTP) {
if (_raw_server_address.compare(0, 5, "https") == 0) { if (_raw_server_address.compare(0, 5, "https") == 0) {
_options.ssl_options.enable = true; if (_options.ssl_options == NULL) {
if (_options.ssl_options.sni_name.empty()) { _options.ssl_options = std::make_shared<ChannelSSLOptions>();
}
if (_options.ssl_options->sni_name.empty()) {
int port; int port;
ParseHostAndPortFromURL(_raw_server_address.c_str(), ParseHostAndPortFromURL(_raw_server_address.c_str(),
&_options.ssl_options.sni_name, &port); &_options.ssl_options->sni_name, &port);
} }
} }
} }
...@@ -190,6 +256,26 @@ int Channel::Init(const char* server_addr, int port, ...@@ -190,6 +256,26 @@ int Channel::Init(const char* server_addr, int port,
return Init(point, options); return Init(point, options);
} }
static int CreateSocketSSLContext(const ChannelOptions& options,
ChannelSignature* sig,
std::shared_ptr<SocketSSLContext>* ssl_ctx) {
if (options.ssl_options != NULL) {
*sig = ComputeChannelSignature(options);
SSL_CTX* raw_ctx = CreateClientSSLContext(*options.ssl_options);
if (!raw_ctx) {
LOG(ERROR) << "Fail to CreateClientSSLContext";
return -1;
}
*ssl_ctx = std::make_shared<SocketSSLContext>();
(*ssl_ctx)->raw_ctx = raw_ctx;
(*ssl_ctx)->sni_name = options.ssl_options->sni_name;
} else {
sig->Reset();
(*ssl_ctx) = NULL;
}
return 0;
}
int Channel::Init(butil::EndPoint server_addr_and_port, int Channel::Init(butil::EndPoint server_addr_and_port,
const ChannelOptions* options) { const ChannelOptions* options) {
GlobalInitializeOrDie(); GlobalInitializeOrDie();
...@@ -202,9 +288,13 @@ int Channel::Init(butil::EndPoint server_addr_and_port, ...@@ -202,9 +288,13 @@ int Channel::Init(butil::EndPoint server_addr_and_port,
return -1; return -1;
} }
_server_address = server_addr_and_port; _server_address = server_addr_and_port;
if (SocketMapInsert(SocketMapKey(server_addr_and_port, ChannelSignature sig;
_options.ssl_options, std::shared_ptr<SocketSSLContext> ssl_ctx;
_options.auth), &_server_id) != 0) { if (CreateSocketSSLContext(_options, &sig, &ssl_ctx) != 0) {
return -1;
}
if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig),
&_server_id, ssl_ctx) != 0) {
LOG(ERROR) << "Fail to insert into SocketMap"; LOG(ERROR) << "Fail to insert into SocketMap";
return -1; return -1;
} }
...@@ -230,6 +320,9 @@ int Channel::Init(const char* ns_url, ...@@ -230,6 +320,9 @@ int Channel::Init(const char* ns_url,
GetNamingServiceThreadOptions ns_opt; GetNamingServiceThreadOptions ns_opt;
ns_opt.succeed_without_server = _options.succeed_without_server; ns_opt.succeed_without_server = _options.succeed_without_server;
ns_opt.log_succeed_without_server = _options.log_succeed_without_server; ns_opt.log_succeed_without_server = _options.log_succeed_without_server;
if (CreateSocketSSLContext(_options, &ns_opt.channel_signature, &ns_opt.ssl_ctx) != 0) {
return -1;
}
if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) { if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) {
LOG(ERROR) << "Fail to initialize LoadBalancerWithNaming"; LOG(ERROR) << "Fail to initialize LoadBalancerWithNaming";
delete lb; delete lb;
......
...@@ -90,7 +90,7 @@ struct ChannelOptions { ...@@ -90,7 +90,7 @@ struct ChannelOptions {
bool log_succeed_without_server; bool log_succeed_without_server;
// SSL related options. Refer to `ChannelSSLOptions' for details // SSL related options. Refer to `ChannelSSLOptions' for details
ChannelSSLOptions ssl_options; std::shared_ptr<ChannelSSLOptions> ssl_options;
// Turn on authentication for this channel if `auth' is not NULL. // Turn on authentication for this channel if `auth' is not NULL.
// Note `auth' will not be deleted by channel and must remain valid when // Note `auth' will not be deleted by channel and must remain valid when
...@@ -99,9 +99,10 @@ struct ChannelOptions { ...@@ -99,9 +99,10 @@ struct ChannelOptions {
const Authenticator* auth; const Authenticator* auth;
// Customize the error code that should be retried. The interface is // Customize the error code that should be retried. The interface is
// defined src/brpc/retry_policy.h // defined in src/brpc/retry_policy.h
// This object is NOT owned by channel and should remain valid when // This object is NOT owned by channel and should remain valid when
// channel is used. // channel is used.
// Default: NULL
const RetryPolicy* retry_policy; const RetryPolicy* retry_policy;
// Filter ServerNodes (i.e. based on `tag' field of `ServerNode') // Filter ServerNodes (i.e. based on `tag' field of `ServerNode')
...@@ -109,7 +110,13 @@ struct ChannelOptions { ...@@ -109,7 +110,13 @@ struct ChannelOptions {
// in src/brpc/naming_service_filter.h // in src/brpc/naming_service_filter.h
// This object is NOT owned by channel and should remain valid when // This object is NOT owned by channel and should remain valid when
// channel is used. // channel is used.
// Default: NULL
const NamingServiceFilter* ns_filter; const NamingServiceFilter* ns_filter;
// Channels with same connection_group share connections. In an another
// word, set to a different value to not share connections.
// Default: 0
int connection_group;
}; };
// A Channel represents a communication line to one server or multiple servers // A Channel represents a communication line to one server or multiple servers
......
...@@ -28,17 +28,18 @@ ...@@ -28,17 +28,18 @@
namespace brpc { namespace brpc {
struct NSKey { struct NSKey {
const NamingService* ns; std::string protocol;
std::string service_name; std::string service_name;
}; };
struct NSKeyHasher { struct NSKeyHasher {
size_t operator()(const NSKey& nskey) const { size_t operator()(const NSKey& nskey) const {
return butil::DefaultHasher<std::string>()(nskey.service_name) return butil::DefaultHasher<std::string>()(nskey.service_name)
* 101 + (uintptr_t)nskey.ns; * 101 + butil::DefaultHasher<std::string>()(nskey.protocol);
} }
}; };
inline bool operator==(const NSKey& k1, const NSKey& k2) { inline bool operator==(const NSKey& k1, const NSKey& k2) {
return (k1.ns == k2.ns && k1.service_name == k2.service_name); return k1.protocol == k2.protocol &&
k1.service_name == k2.service_name;
} }
typedef butil::FlatMap<NSKey, NamingServiceThread*, NSKeyHasher> NamingServiceMap; typedef butil::FlatMap<NSKey, NamingServiceThread*, NSKeyHasher> NamingServiceMap;
...@@ -58,7 +59,8 @@ NamingServiceThread::Actions::~Actions() { ...@@ -58,7 +59,8 @@ NamingServiceThread::Actions::~Actions() {
// Remove all sockets from SocketMap // Remove all sockets from SocketMap
for (std::vector<ServerNode>::const_iterator it = _last_servers.begin(); for (std::vector<ServerNode>::const_iterator it = _last_servers.begin();
it != _last_servers.end(); ++it) { it != _last_servers.end(); ++it) {
SocketMapRemove(SocketMapKey(it->addr)); const SocketMapKey key(it->addr, _owner->_options.channel_signature);
SocketMapRemove(key);
} }
EndWait(0); EndWait(0);
} }
...@@ -110,7 +112,8 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -110,7 +112,8 @@ void NamingServiceThread::Actions::ResetServers(
// TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new // TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new
// Socket. SocketMapKey may be passed through AddWatcher. Make sure // Socket. SocketMapKey may be passed through AddWatcher. Make sure
// to pick those Sockets with the right settings during OnAddedServers // to pick those Sockets with the right settings during OnAddedServers
CHECK_EQ(SocketMapInsert(SocketMapKey(_added[i].addr), &tagged_id.id), 0); const SocketMapKey key(_added[i].addr, _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx));
_added_sockets.push_back(tagged_id); _added_sockets.push_back(tagged_id);
} }
...@@ -118,7 +121,8 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -118,7 +121,8 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) { for (size_t i = 0; i < _removed.size(); ++i) {
ServerNodeWithId tagged_id; ServerNodeWithId tagged_id;
tagged_id.node = _removed[i]; tagged_id.node = _removed[i];
CHECK_EQ(0, SocketMapFind(SocketMapKey(_removed[i].addr), &tagged_id.id)); const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapFind(key, &tagged_id.id));
_removed_sockets.push_back(tagged_id); _removed_sockets.push_back(tagged_id);
} }
...@@ -169,7 +173,8 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -169,7 +173,8 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) { for (size_t i = 0; i < _removed.size(); ++i) {
// TODO: Remove all Sockets that have the same address in SocketMapKey.peer // TODO: Remove all Sockets that have the same address in SocketMapKey.peer
// We may need another data structure to avoid linear cost // We may need another data structure to avoid linear cost
SocketMapRemove(SocketMapKey(_removed[i].addr)); const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature);
SocketMapRemove(key);
} }
if (!_removed.empty() || !_added.empty()) { if (!_removed.empty() || !_added.empty()) {
...@@ -207,7 +212,6 @@ int NamingServiceThread::Actions::WaitForFirstBatchOfServers() { ...@@ -207,7 +212,6 @@ int NamingServiceThread::Actions::WaitForFirstBatchOfServers() {
NamingServiceThread::NamingServiceThread() NamingServiceThread::NamingServiceThread()
: _tid(0) : _tid(0)
, _source_ns(NULL)
, _ns(NULL) , _ns(NULL)
, _actions(this) { , _actions(this) {
} }
...@@ -215,8 +219,8 @@ NamingServiceThread::NamingServiceThread() ...@@ -215,8 +219,8 @@ NamingServiceThread::NamingServiceThread()
NamingServiceThread::~NamingServiceThread() { NamingServiceThread::~NamingServiceThread() {
RPC_VLOG << "~NamingServiceThread(" << *this << ')'; RPC_VLOG << "~NamingServiceThread(" << *this << ')';
// Remove from g_nsthread_map first // Remove from g_nsthread_map first
if (_source_ns != NULL) { if (!_protocol.empty()) {
const NSKey key = { _source_ns, _service_name }; const NSKey key = { _protocol, _service_name };
std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex); std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex);
if (g_nsthread_map != NULL) { if (g_nsthread_map != NULL) {
NamingServiceThread** ptr = g_nsthread_map->seek(key); NamingServiceThread** ptr = g_nsthread_map->seek(key);
...@@ -255,15 +259,16 @@ void* NamingServiceThread::RunThis(void* arg) { ...@@ -255,15 +259,16 @@ void* NamingServiceThread::RunThis(void* arg) {
return NULL; return NULL;
} }
int NamingServiceThread::Start(const NamingService* naming_service, int NamingServiceThread::Start(NamingService* naming_service,
const std::string& protocol,
const std::string& service_name, const std::string& service_name,
const GetNamingServiceThreadOptions* opt_in) { const GetNamingServiceThreadOptions* opt_in) {
if (naming_service == NULL) { if (naming_service == NULL) {
LOG(ERROR) << "Param[naming_service] is NULL"; LOG(ERROR) << "Param[naming_service] is NULL";
return -1; return -1;
} }
_source_ns = naming_service; _ns = naming_service;
_ns = naming_service->New(); _protocol = protocol;
_service_name = service_name; _service_name = service_name;
if (opt_in) { if (opt_in) {
_options = *opt_in; _options = *opt_in;
...@@ -400,13 +405,13 @@ int GetNamingServiceThread( ...@@ -400,13 +405,13 @@ int GetNamingServiceThread(
LOG(ERROR) << "Invalid naming service url=" << url; LOG(ERROR) << "Invalid naming service url=" << url;
return -1; return -1;
} }
const NamingService* ns = NamingServiceExtension()->Find(protocol); const NamingService* source_ns = NamingServiceExtension()->Find(protocol);
if (ns == NULL) { if (source_ns == NULL) {
LOG(ERROR) << "Unknown protocol=" << protocol; LOG(ERROR) << "Unknown protocol=" << protocol;
return -1; return -1;
} }
NSKey key; NSKey key;
key.ns = ns; key.protocol = protocol;
key.service_name = service_name; key.service_name = service_name;
bool new_thread = false; bool new_thread = false;
butil::intrusive_ptr<NamingServiceThread> nsthread; butil::intrusive_ptr<NamingServiceThread> nsthread;
...@@ -452,7 +457,7 @@ int GetNamingServiceThread( ...@@ -452,7 +457,7 @@ int GetNamingServiceThread(
} }
} }
if (new_thread) { if (new_thread) {
if (nsthread->Start(ns, key.service_name, options) != 0) { if (nsthread->Start(source_ns->New(), key.protocol, key.service_name, options) != 0) {
LOG(ERROR) << "Fail to start NamingServiceThread"; LOG(ERROR) << "Fail to start NamingServiceThread";
std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex); std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex);
g_nsthread_map->erase(key); g_nsthread_map->erase(key);
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "brpc/shared_object.h" // SharedObject #include "brpc/shared_object.h" // SharedObject
#include "brpc/naming_service.h" // NamingService #include "brpc/naming_service.h" // NamingService
#include "brpc/naming_service_filter.h" // NamingServiceFilter #include "brpc/naming_service_filter.h" // NamingServiceFilter
#include "brpc/socket_map.h"
namespace brpc { namespace brpc {
...@@ -46,6 +46,8 @@ struct GetNamingServiceThreadOptions { ...@@ -46,6 +46,8 @@ struct GetNamingServiceThreadOptions {
bool succeed_without_server; bool succeed_without_server;
bool log_succeed_without_server; bool log_succeed_without_server;
ChannelSignature channel_signature;
std::shared_ptr<SocketSSLContext> ssl_ctx;
}; };
// A dedicated thread to map a name to ServerIds // A dedicated thread to map a name to ServerIds
...@@ -86,7 +88,9 @@ public: ...@@ -86,7 +88,9 @@ public:
NamingServiceThread(); NamingServiceThread();
~NamingServiceThread(); ~NamingServiceThread();
int Start(const NamingService* ns, const std::string& service_name, int Start(NamingService* ns,
const std::string& protocol,
const std::string& service_name,
const GetNamingServiceThreadOptions* options); const GetNamingServiceThreadOptions* options);
int WaitForFirstBatchOfServers(); int WaitForFirstBatchOfServers();
...@@ -106,9 +110,8 @@ private: ...@@ -106,9 +110,8 @@ private:
butil::Mutex _mutex; butil::Mutex _mutex;
bthread_t _tid; bthread_t _tid;
// TODO: better use a name.
const NamingService* _source_ns;
NamingService* _ns; NamingService* _ns;
std::string _protocol;
std::string _service_name; std::string _service_name;
GetNamingServiceThreadOptions _options; GetNamingServiceThreadOptions _options;
std::vector<ServerNodeWithId> _last_sockets; std::vector<ServerNodeWithId> _last_sockets;
......
...@@ -437,10 +437,6 @@ static int SetSSLOptions(SSL_CTX* ctx, const std::string& ciphers, ...@@ -437,10 +437,6 @@ static int SetSSLOptions(SSL_CTX* ctx, const std::string& ciphers,
} }
SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options) { SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options) {
if (!options.enable) {
return NULL;
}
std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx( std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx(
SSL_CTX_new(SSLv23_client_method())); SSL_CTX_new(SSLv23_client_method()));
if (!ssl_ctx) { if (!ssl_ctx) {
...@@ -770,51 +766,66 @@ int SSLDHInit() { ...@@ -770,51 +766,66 @@ int SSLDHInit() {
return 0; return 0;
} }
} // namespace brpc static std::string GetNextLevelSeparator(const char* sep) {
if (sep[0] != '\n') {
std::ostream& operator<<(std::ostream& os, SSL* ssl) { return sep;
os << "[SSL HANDSHAKE]" }
<< "\n* cipher: " << SSL_get_cipher(ssl) const size_t left_len = strlen(sep + 1);
<< "\n* protocol: " << SSL_get_version(ssl) if (left_len == 0) {
<< "\n* verify: " << (SSL_get_verify_mode(ssl) & SSL_VERIFY_PEER return "\n ";
? "success" : "none") }
<< "\n"; std::string new_sep;
new_sep.reserve(left_len * 2 + 1);
new_sep.append(sep, left_len + 1);
new_sep.append(sep + 1, left_len);
return new_sep;
}
void Print(std::ostream& os, SSL* ssl, const char* sep) {
os << "cipher=" << SSL_get_cipher(ssl) << sep
<< "protocol=" << SSL_get_version(ssl) << sep
<< "verify=" << (SSL_get_verify_mode(ssl) & SSL_VERIFY_PEER
? "success" : "none");
X509* cert = SSL_get_peer_certificate(ssl); X509* cert = SSL_get_peer_certificate(ssl);
if (cert) { if (cert) {
os << "\n" << cert; os << sep << "peer_certificate={";
const std::string new_sep = GetNextLevelSeparator(sep);
if (sep[0] == '\n') {
os << new_sep;
}
Print(os, cert, new_sep.c_str());
if (sep[0] == '\n') {
os << sep;
}
os << '}';
} }
return os;
} }
std::ostream& operator<<(std::ostream& os, X509* cert) { void Print(std::ostream& os, X509* cert, const char* sep) {
BIO* buf = BIO_new(BIO_s_mem()); BIO* buf = BIO_new(BIO_s_mem());
if (buf == NULL) { if (buf == NULL) {
return os; return;
} }
BIO_printf(buf, "[CERTIFICATE]"); BIO_printf(buf, "subject=");
BIO_printf(buf, "\n* subject: ");
X509_NAME_print(buf, X509_get_subject_name(cert), 0); X509_NAME_print(buf, X509_get_subject_name(cert), 0);
BIO_printf(buf, "\n* start date: "); BIO_printf(buf, "%sstart_date=", sep);
ASN1_TIME_print(buf, X509_get_notBefore(cert)); ASN1_TIME_print(buf, X509_get_notBefore(cert));
BIO_printf(buf, "\n* expire date: "); BIO_printf(buf, "%sexpire_date=", sep);
ASN1_TIME_print(buf, X509_get_notAfter(cert)); ASN1_TIME_print(buf, X509_get_notAfter(cert));
BIO_printf(buf, "\n* common name: "); BIO_printf(buf, "%scommon_name=", sep);
std::vector<std::string> hostnames; std::vector<std::string> hostnames;
brpc::ExtractHostnames(cert, &hostnames); brpc::ExtractHostnames(cert, &hostnames);
for (size_t i = 0; i < hostnames.size(); ++i) { for (size_t i = 0; i < hostnames.size(); ++i) {
BIO_printf(buf, "%s; ", hostnames[i].c_str()); BIO_printf(buf, "%s;", hostnames[i].c_str());
} }
BIO_printf(buf, "\n* issuer: "); BIO_printf(buf, "%sissuer=", sep);
X509_NAME_print(buf, X509_get_issuer_name(cert), 0); X509_NAME_print(buf, X509_get_issuer_name(cert), 0);
BIO_printf(buf, "\n");
char* bufp = NULL; char* bufp = NULL;
int len = BIO_get_mem_data(buf, &bufp); int len = BIO_get_mem_data(buf, &bufp);
os << butil::StringPiece(bufp, len); os << butil::StringPiece(bufp, len);
return os;
} }
} // namespace brpc
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
// For some versions of openssl, SSL_* are defined inside this header // For some versions of openssl, SSL_* are defined inside this header
#include <openssl/ossl_typ.h> #include <openssl/ossl_typ.h>
#include "brpc/socket_id.h" // SocketId #include "brpc/socket_id.h" // SocketId
#include "brpc/ssl_option.h" // SSLOptions #include "brpc/ssl_option.h" // ServerSSLOptions
namespace brpc { namespace brpc {
...@@ -76,7 +75,7 @@ SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options); ...@@ -76,7 +75,7 @@ SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options);
// fields into `hostnames' // fields into `hostnames'
SSL_CTX* CreateServerSSLContext(const std::string& certificate_file, SSL_CTX* CreateServerSSLContext(const std::string& certificate_file,
const std::string& private_key_file, const std::string& private_key_file,
const SSLOptions& options, const ServerSSLOptions& options,
std::vector<std::string>* hostnames); std::vector<std::string>* hostnames);
// Create a new SSL (per connection object) using configurations in `ctx'. // Create a new SSL (per connection object) using configurations in `ctx'.
...@@ -92,9 +91,9 @@ void AddBIOBuffer(SSL* ssl, int fd, int bufsize); ...@@ -92,9 +91,9 @@ void AddBIOBuffer(SSL* ssl, int fd, int bufsize);
// set to indicate the reason (0 for EOF) // set to indicate the reason (0 for EOF)
SSLState DetectSSLState(int fd, int* error_code); SSLState DetectSSLState(int fd, int* error_code);
} // namespace brpc void Print(std::ostream& os, SSL* ssl, const char* sep);
void Print(std::ostream& os, X509* cert, const char* sep);
std::ostream& operator<<(std::ostream& os, SSL* ssl); } // namespace brpc
std::ostream& operator<<(std::ostream& os, X509* cert);
#endif // BRPC_SSL_HELPER_H #endif // BRPC_SSL_HELPER_H
...@@ -1683,7 +1683,7 @@ void RtmpClientStream::CleanupSocketForStream( ...@@ -1683,7 +1683,7 @@ void RtmpClientStream::CleanupSocketForStream(
Socket* prev_sock, Controller*, int /*error_code*/) { Socket* prev_sock, Controller*, int /*error_code*/) {
if (prev_sock) { if (prev_sock) {
if (_from_socketmap) { if (_from_socketmap) {
_client_impl->socket_map().Remove(prev_sock->remote_side(), _client_impl->socket_map().Remove(SocketMapKey(prev_sock->remote_side()),
prev_sock->id()); prev_sock->id());
} else { } else {
prev_sock->SetFailed(); // not necessary, already failed. prev_sock->SetFailed(); // not necessary, already failed.
...@@ -1888,7 +1888,7 @@ void RtmpClientStream::OnStopInternal() { ...@@ -1888,7 +1888,7 @@ void RtmpClientStream::OnStopInternal() {
LOG(FATAL) << "RtmpContext of " << *_rtmpsock << " is NULL"; LOG(FATAL) << "RtmpContext of " << *_rtmpsock << " is NULL";
} }
if (_from_socketmap) { if (_from_socketmap) {
_client_impl->socket_map().Remove(_rtmpsock->remote_side(), _client_impl->socket_map().Remove(SocketMapKey(_rtmpsock->remote_side()),
_rtmpsock->id()); _rtmpsock->id());
} else { } else {
_rtmpsock->ReleaseAdditionalReference(); _rtmpsock->ReleaseAdditionalReference();
......
...@@ -840,14 +840,18 @@ int Server::StartInternal(const butil::ip_t& ip, ...@@ -840,14 +840,18 @@ int Server::StartInternal(const butil::ip_t& ip,
// Free last SSL contexts // Free last SSL contexts
FreeSSLContexts(); FreeSSLContexts();
CertInfo& default_cert = _options.ssl_options.default_cert; if (_options.ssl_options) {
if (!default_cert.certificate.empty()) { CertInfo& default_cert = _options.ssl_options->default_cert;
if (default_cert.certificate.empty()) {
LOG(ERROR) << "default_cert is empty";
return -1;
}
if (AddCertificate(default_cert) != 0) { if (AddCertificate(default_cert) != 0) {
return -1; return -1;
} }
_default_ssl_ctx = _ssl_ctx_map.begin()->second.ctx; _default_ssl_ctx = _ssl_ctx_map.begin()->second.ctx;
const std::vector<CertInfo>& certs = _options.ssl_options.certs; const std::vector<CertInfo>& certs = _options.ssl_options->certs;
for (size_t i = 0; i < certs.size(); ++i) { for (size_t i = 0; i < certs.size(); ++i) {
if (AddCertificate(certs[i]) != 0) { if (AddCertificate(certs[i]) != 0) {
return -1; return -1;
...@@ -1791,6 +1795,10 @@ Server::FindServicePropertyByName(const butil::StringPiece& name) const { ...@@ -1791,6 +1795,10 @@ Server::FindServicePropertyByName(const butil::StringPiece& name) const {
} }
int Server::AddCertificate(const CertInfo& cert) { int Server::AddCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
std::string cert_key(cert.certificate); std::string cert_key(cert.certificate);
cert_key.append(cert.private_key); cert_key.append(cert.private_key);
if (_ssl_ctx_map.seek(cert_key) != NULL) { if (_ssl_ctx_map.seek(cert_key) != NULL) {
...@@ -1800,15 +1808,17 @@ int Server::AddCertificate(const CertInfo& cert) { ...@@ -1800,15 +1808,17 @@ int Server::AddCertificate(const CertInfo& cert) {
SSLContext ssl_ctx; SSLContext ssl_ctx;
ssl_ctx.filters = cert.sni_filters; ssl_ctx.filters = cert.sni_filters;
ssl_ctx.ctx = CreateServerSSLContext(cert.certificate, cert.private_key, ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
_options.ssl_options, &ssl_ctx.filters); SSL_CTX* raw_ctx = CreateServerSSLContext(cert.certificate, cert.private_key,
if (ssl_ctx.ctx == NULL) { *_options.ssl_options, &ssl_ctx.filters);
if (raw_ctx == NULL) {
return -1; return -1;
} }
ssl_ctx.ctx->raw_ctx = raw_ctx;
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
SSL_CTX_set_tlsext_servername_callback(ssl_ctx.ctx, SSLSwitchCTXByHostname); SSL_CTX_set_tlsext_servername_callback(ssl_ctx.ctx->raw_ctx, SSLSwitchCTXByHostname);
SSL_CTX_set_tlsext_servername_arg(ssl_ctx.ctx, this); SSL_CTX_set_tlsext_servername_arg(ssl_ctx.ctx->raw_ctx, this);
#endif #endif
if (!_reload_cert_maps.Modify(AddCertMapping, ssl_ctx)) { if (!_reload_cert_maps.Modify(AddCertMapping, ssl_ctx)) {
...@@ -1850,6 +1860,10 @@ bool Server::AddCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) { ...@@ -1850,6 +1860,10 @@ bool Server::AddCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
} }
int Server::RemoveCertificate(const CertInfo& cert) { int Server::RemoveCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
std::string cert_key(cert.certificate); std::string cert_key(cert.certificate);
cert_key.append(cert.private_key); cert_key.append(cert.private_key);
SSLContext* ssl_ctx = _ssl_ctx_map.seek(cert_key); SSLContext* ssl_ctx = _ssl_ctx_map.seek(cert_key);
...@@ -1868,8 +1882,6 @@ int Server::RemoveCertificate(const CertInfo& cert) { ...@@ -1868,8 +1882,6 @@ int Server::RemoveCertificate(const CertInfo& cert) {
return -1; return -1;
} }
// After a successful Modify, now it's safe to erase SSLContext
SSL_CTX_free(ssl_ctx->ctx);
_ssl_ctx_map.erase(cert_key); _ssl_ctx_map.erase(cert_key);
return 0; return 0;
} }
...@@ -1884,7 +1896,7 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) { ...@@ -1884,7 +1896,7 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
} else { } else {
cmap = &(bg.cert_map); cmap = &(bg.cert_map);
} }
SSL_CTX** ctx = cmap->seek(hostname); std::shared_ptr<SocketSSLContext>* ctx = cmap->seek(hostname);
if (ctx != NULL && *ctx == ssl_ctx.ctx) { if (ctx != NULL && *ctx == ssl_ctx.ctx) {
cmap->erase(hostname); cmap->erase(hostname);
} }
...@@ -1893,6 +1905,11 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) { ...@@ -1893,6 +1905,11 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
} }
int Server::ResetCertificates(const std::vector<CertInfo>& certs) { int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
if (_options.ssl_options == NULL) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1;
}
SSLContextMap tmp_map; SSLContextMap tmp_map;
if (tmp_map.init(INITIAL_CERT_MAP) != 0) { if (tmp_map.init(INITIAL_CERT_MAP) != 0) {
LOG(ERROR) << "Fail to initialize tmp_map"; LOG(ERROR) << "Fail to initialize tmp_map";
...@@ -1901,8 +1918,8 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) { ...@@ -1901,8 +1918,8 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
// Add default certficiate into tmp_map first since it can't be reloaded // Add default certficiate into tmp_map first since it can't be reloaded
std::string default_cert_key = std::string default_cert_key =
_options.ssl_options.default_cert.certificate _options.ssl_options->default_cert.certificate
+ _options.ssl_options.default_cert.private_key; + _options.ssl_options->default_cert.private_key;
tmp_map[default_cert_key] = _ssl_ctx_map[default_cert_key]; tmp_map[default_cert_key] = _ssl_ctx_map[default_cert_key];
for (size_t i = 0; i < certs.size(); ++i) { for (size_t i = 0; i < certs.size(); ++i) {
...@@ -1915,28 +1932,26 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) { ...@@ -1915,28 +1932,26 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
SSLContext ssl_ctx; SSLContext ssl_ctx;
ssl_ctx.filters = certs[i].sni_filters; ssl_ctx.filters = certs[i].sni_filters;
ssl_ctx.ctx = CreateServerSSLContext( ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
ssl_ctx.ctx->raw_ctx = CreateServerSSLContext(
certs[i].certificate, certs[i].private_key, certs[i].certificate, certs[i].private_key,
_options.ssl_options, &ssl_ctx.filters); *_options.ssl_options, &ssl_ctx.filters);
if (ssl_ctx.ctx == NULL) { if (ssl_ctx.ctx->raw_ctx == NULL) {
FreeSSLContextMap(tmp_map, true);
return -1; return -1;
} }
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
SSL_CTX_set_tlsext_servername_callback(ssl_ctx.ctx, SSLSwitchCTXByHostname); SSL_CTX_set_tlsext_servername_callback(ssl_ctx.ctx->raw_ctx, SSLSwitchCTXByHostname);
SSL_CTX_set_tlsext_servername_arg(ssl_ctx.ctx, this); SSL_CTX_set_tlsext_servername_arg(ssl_ctx.ctx->raw_ctx, this);
#endif #endif
tmp_map[cert_key] = ssl_ctx; tmp_map[cert_key] = ssl_ctx;
} }
if (!_reload_cert_maps.Modify(ResetCertMappings, tmp_map)) { if (!_reload_cert_maps.Modify(ResetCertMappings, tmp_map)) {
FreeSSLContextMap(tmp_map, true);
return -1; return -1;
} }
_ssl_ctx_map.swap(tmp_map); _ssl_ctx_map.swap(tmp_map);
FreeSSLContextMap(tmp_map, true);
return 0; return 0;
} }
...@@ -1976,19 +1991,8 @@ bool Server::ResetCertMappings(CertMaps& bg, const SSLContextMap& ctx_map) { ...@@ -1976,19 +1991,8 @@ bool Server::ResetCertMappings(CertMaps& bg, const SSLContextMap& ctx_map) {
return true; return true;
} }
void Server::FreeSSLContextMap(SSLContextMap& ctx_map, bool keep_default) {
for (SSLContextMap::iterator it =
ctx_map.begin(); it != ctx_map.end(); ++it) {
if (keep_default && it->second.ctx == _default_ssl_ctx) {
continue;
}
SSL_CTX_free(it->second.ctx);
}
ctx_map.clear();
}
void Server::FreeSSLContexts() { void Server::FreeSSLContexts() {
FreeSSLContextMap(_ssl_ctx_map, false); _ssl_ctx_map.clear();
_reload_cert_maps.Modify(ClearCertMapping); _reload_cert_maps.Modify(ClearCertMapping);
_default_ssl_ctx = NULL; _default_ssl_ctx = NULL;
} }
...@@ -2082,7 +2086,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl, ...@@ -2082,7 +2086,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server) { int* al, Server* server) {
(void)al; (void)al;
const char* hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char* hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
bool strict_sni = server->_options.ssl_options.strict_sni; bool strict_sni = server->_options.ssl_options->strict_sni;
if (hostname == NULL) { if (hostname == NULL) {
return strict_sni ? SSL_TLSEXT_ERR_ALERT_FATAL : SSL_TLSEXT_ERR_NOACK; return strict_sni ? SSL_TLSEXT_ERR_ALERT_FATAL : SSL_TLSEXT_ERR_NOACK;
} }
...@@ -2092,7 +2096,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl, ...@@ -2092,7 +2096,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
return SSL_TLSEXT_ERR_ALERT_FATAL; return SSL_TLSEXT_ERR_ALERT_FATAL;
} }
SSL_CTX** pctx = s->cert_map.seek(hostname); std::shared_ptr<SocketSSLContext>* pctx = s->cert_map.seek(hostname);
if (pctx == NULL) { if (pctx == NULL) {
const char* dot = hostname; const char* dot = hostname;
for (; *dot != '\0'; ++dot) { for (; *dot != '\0'; ++dot) {
...@@ -2114,7 +2118,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl, ...@@ -2114,7 +2118,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
} }
// Switch SSL_CTX to the one with correct hostname // Switch SSL_CTX to the one with correct hostname
SSL_set_SSL_CTX(ssl, *pctx); SSL_set_SSL_CTX(ssl, (*pctx)->raw_ctx);
return SSL_TLSEXT_ERR_OK; return SSL_TLSEXT_ERR_OK;
} }
#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME #endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
......
...@@ -38,10 +38,6 @@ ...@@ -38,10 +38,6 @@
#include "brpc/health_reporter.h" #include "brpc/health_reporter.h"
#include "brpc/adaptive_max_concurrency.h" #include "brpc/adaptive_max_concurrency.h"
extern "C" {
struct ssl_ctx_st;
}
namespace brpc { namespace brpc {
class Acceptor; class Acceptor;
...@@ -52,6 +48,7 @@ class SimpleDataPool; ...@@ -52,6 +48,7 @@ class SimpleDataPool;
class MongoServiceAdaptor; class MongoServiceAdaptor;
class RestfulMap; class RestfulMap;
class RtmpService; class RtmpService;
class SocketSSLContext;
struct ServerOptions { struct ServerOptions {
// Constructed with default options. // Constructed with default options.
...@@ -202,7 +199,7 @@ struct ServerOptions { ...@@ -202,7 +199,7 @@ struct ServerOptions {
bool security_mode() const { return internal_port >= 0 || !has_builtin_services; } bool security_mode() const { return internal_port >= 0 || !has_builtin_services; }
// SSL related options. Refer to `ServerSSLOptions' for details // SSL related options. Refer to `ServerSSLOptions' for details
ServerSSLOptions ssl_options; std::shared_ptr<ServerSSLOptions> ssl_options;
// [CAUTION] This option is for implementing specialized http proxies, // [CAUTION] This option is for implementing specialized http proxies,
// most users don't need it. Don't change this option unless you fully // most users don't need it. Don't change this option unless you fully
...@@ -573,21 +570,20 @@ friend class Controller; ...@@ -573,21 +570,20 @@ friend class Controller;
std::string ServerPrefix() const; std::string ServerPrefix() const;
// Mapping from hostname to corresponding SSL_CTX // Mapping from hostname to corresponding SSL_CTX
typedef butil::CaseIgnoredFlatMap<struct ssl_ctx_st*> CertMap; typedef butil::CaseIgnoredFlatMap<std::shared_ptr<SocketSSLContext> > CertMap;
struct CertMaps { struct CertMaps {
CertMap cert_map; CertMap cert_map;
CertMap wildcard_cert_map; CertMap wildcard_cert_map;
}; };
struct SSLContext { struct SSLContext {
struct ssl_ctx_st* ctx; std::shared_ptr<SocketSSLContext> ctx;
std::vector<std::string> filters; std::vector<std::string> filters;
}; };
// Mapping from [certficate + private-key] to SSLContext // Mapping from [certficate + private-key] to SSLContext
typedef butil::FlatMap<std::string, SSLContext> SSLContextMap; typedef butil::FlatMap<std::string, SSLContext> SSLContextMap;
void FreeSSLContexts(); void FreeSSLContexts();
void FreeSSLContextMap(SSLContextMap& ctx_map, bool keep_default);
static int SSLSwitchCTXByHostname(struct ssl_st* ssl, static int SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server); int* al, Server* server);
...@@ -636,7 +632,7 @@ friend class Controller; ...@@ -636,7 +632,7 @@ friend class Controller;
RestfulMap* _global_restful_map; RestfulMap* _global_restful_map;
// Default certficate which can't be reloaded // Default certficate which can't be reloaded
struct ssl_ctx_st* _default_ssl_ctx; std::shared_ptr<SocketSSLContext> _default_ssl_ctx;
// Reloadable SSL mappings // Reloadable SSL mappings
butil::DoublyBufferedData<CertMaps> _reload_cert_maps; butil::DoublyBufferedData<CertMaps> _reload_cert_maps;
......
...@@ -640,7 +640,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { ...@@ -640,7 +640,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
return -1; return -1;
} }
// Disable SSL check if there is no SSL context // Disable SSL check if there is no SSL context
m->_ssl_state = (options.ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
m->_ssl_session = NULL; m->_ssl_session = NULL;
m->_connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN; m->_connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN;
m->_controller_released_socket.store(false, butil::memory_order_relaxed); m->_controller_released_socket.store(false, butil::memory_order_relaxed);
...@@ -1048,9 +1048,7 @@ void Socket::OnRecycle() { ...@@ -1048,9 +1048,7 @@ void Socket::OnRecycle() {
_ssl_session = NULL; _ssl_session = NULL;
} }
if (_options.owns_ssl_ctx && _options.ssl_ctx) { _options.initial_ssl_ctx = NULL;
SSL_CTX_free(_options.ssl_ctx);
}
delete _pipeline_q; delete _pipeline_q;
_pipeline_q = NULL; _pipeline_q = NULL;
...@@ -1162,7 +1160,7 @@ int Socket::WaitEpollOut(int fd, bool pollin, const timespec* abstime) { ...@@ -1162,7 +1160,7 @@ int Socket::WaitEpollOut(int fd, bool pollin, const timespec* abstime) {
int Socket::Connect(const timespec* abstime, int Socket::Connect(const timespec* abstime,
int (*on_connect)(int, int, void*), void* data) { int (*on_connect)(int, int, void*), void* data) {
if (_options.ssl_ctx) { if (_options.initial_ssl_ctx) {
_ssl_state = SSL_CONNECTING; _ssl_state = SSL_CONNECTING;
} else { } else {
_ssl_state = SSL_OFF; _ssl_state = SSL_OFF;
...@@ -1780,7 +1778,7 @@ ssize_t Socket::DoWrite(WriteRequest* req) { ...@@ -1780,7 +1778,7 @@ ssize_t Socket::DoWrite(WriteRequest* req) {
} }
int Socket::SSLHandshake(int fd, bool server_mode) { int Socket::SSLHandshake(int fd, bool server_mode) {
if (_options.ssl_ctx == NULL) { if (_options.initial_ssl_ctx == NULL) {
if (server_mode) { if (server_mode) {
LOG(ERROR) << "Lack SSL configuration to handle SSL request"; LOG(ERROR) << "Lack SSL configuration to handle SSL request";
return -1; return -1;
...@@ -1793,15 +1791,17 @@ int Socket::SSLHandshake(int fd, bool server_mode) { ...@@ -1793,15 +1791,17 @@ int Socket::SSLHandshake(int fd, bool server_mode) {
// Free the last session, which may be deprecated when socket failed // Free the last session, which may be deprecated when socket failed
SSL_free(_ssl_session); SSL_free(_ssl_session);
} }
_ssl_session = CreateSSLSession(_options.ssl_ctx, id(), fd, server_mode); _ssl_session = CreateSSLSession(_options.initial_ssl_ctx->raw_ctx, id(), fd, server_mode);
if (_ssl_session == NULL) { if (_ssl_session == NULL) {
LOG(ERROR) << "Fail to CreateSSLSession";
return -1; return -1;
} }
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
if (!_options.sni_name.empty()) { if (!_options.initial_ssl_ctx->sni_name.empty()) {
SSL_set_tlsext_host_name(_ssl_session, _options.sni_name.c_str()); SSL_set_tlsext_host_name(_ssl_session, _options.initial_ssl_ctx->sni_name.c_str());
} }
#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME #endif
_ssl_state = SSL_CONNECTING; _ssl_state = SSL_CONNECTING;
// Loop until SSL handshake has completed. For SSL_ERROR_WANT_READ/WRITE, // Loop until SSL handshake has completed. For SSL_ERROR_WANT_READ/WRITE,
...@@ -2159,74 +2159,84 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) { ...@@ -2159,74 +2159,84 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) {
} else { } else {
os << "\nparsing_context=" << ShowObject(parsing_context); os << "\nparsing_context=" << ShowObject(parsing_context);
} }
const SSLState ssl_state = ptr->ssl_state();
os << "\npipeline_q=" << npipelined os << "\npipeline_q=" << npipelined
<< "\nhc_interval_s=" << ptr->_health_check_interval_s << "\nhc_interval_s=" << ptr->_health_check_interval_s
<< "\nninprocess=" << ptr->_ninprocess.load(butil::memory_order_relaxed) << "\nninprocess=" << ptr->_ninprocess.load(butil::memory_order_relaxed)
<< "\nauth_flag_error=" << ptr->_auth_flag_error.load(butil::memory_order_relaxed) << "\nauth_flag_error=" << ptr->_auth_flag_error.load(butil::memory_order_relaxed)
<< "\nauth_id=" << ptr->_auth_id.value << "\nauth_id=" << ptr->_auth_id.value
<< "\nauth_context=" << ptr->_auth_context << "\nauth_context=" << ptr->_auth_context
<< "\nssl_state=" << SSLStateToString(ptr->_ssl_state)
<< "\nssl_ctx=" << (void*)ptr->_options.ssl_ctx
<< "\nssl_session=" << (void*)ptr->_ssl_session
<< "\nlogoff_flag=" << ptr->_logoff_flag.load(butil::memory_order_relaxed) << "\nlogoff_flag=" << ptr->_logoff_flag.load(butil::memory_order_relaxed)
<< "\nrecycle_flag=" << ptr->_recycle_flag.load(butil::memory_order_relaxed) << "\nrecycle_flag=" << ptr->_recycle_flag.load(butil::memory_order_relaxed)
<< "\ncid=" << ptr->_correlation_id << "\ncid=" << ptr->_correlation_id
<< "\nwrite_head=" << ptr->_write_head.load(butil::memory_order_relaxed); << "\nwrite_head=" << ptr->_write_head.load(butil::memory_order_relaxed)
if (ptr->ssl_state() == SSL_CONNECTED) { << "\nssl_state=" << SSLStateToString(ssl_state);
os << "\n\n" << ptr->_ssl_session; const SocketSSLContext* ssl_ctx = ptr->_options.initial_ssl_ctx.get();
if (ssl_ctx) {
os << "\ninitial_ssl_ctx=" << ssl_ctx->raw_ctx;
if (!ssl_ctx->sni_name.empty()) {
os << "\nsni_name=" << ssl_ctx->sni_name;
}
}
if (ssl_state == SSL_CONNECTED) {
os << "\nssl_session={\n ";
Print(os, ptr->_ssl_session, "\n ");
os << "\n}";
} }
#if defined(OS_MACOSX) #if defined(OS_MACOSX)
struct tcp_connection_info ti; struct tcp_connection_info ti;
socklen_t len = sizeof(ti); socklen_t len = sizeof(ti);
if (fd >= 0 && getsockopt(fd, IPPROTO_TCP, TCP_CONNECTION_INFO, &ti, &len) == 0) { if (fd >= 0 && getsockopt(fd, IPPROTO_TCP, TCP_CONNECTION_INFO, &ti, &len) == 0) {
os << "\ntcpi_state=" << (uint32_t)ti.tcpi_state os << "\ntcpi={\n state=" << (uint32_t)ti.tcpi_state
<< "\ntcpi_snd_wscale=" << (uint32_t)ti.tcpi_snd_wscale << "\n snd_wscale=" << (uint32_t)ti.tcpi_snd_wscale
<< "\ntcpi_rcv_wscale=" << (uint32_t)ti.tcpi_rcv_wscale << "\n rcv_wscale=" << (uint32_t)ti.tcpi_rcv_wscale
<< "\ntcpi_options=" << (uint32_t)ti.tcpi_options << "\n options=" << (uint32_t)ti.tcpi_options
<< "\ntcpi_flags=" << (uint32_t)ti.tcpi_flags << "\n flags=" << (uint32_t)ti.tcpi_flags
<< "\ntcpi_rto=" << ti.tcpi_rto << "\n rto=" << ti.tcpi_rto
<< "\ntcpi_maxseg=" << ti.tcpi_maxseg << "\n maxseg=" << ti.tcpi_maxseg
<< "\ntcpi_snd_ssthresh=" << ti.tcpi_snd_ssthresh << "\n snd_ssthresh=" << ti.tcpi_snd_ssthresh
<< "\ntcpi_snd_cwnd=" << ti.tcpi_snd_cwnd << "\n snd_cwnd=" << ti.tcpi_snd_cwnd
<< "\ntcpi_snd_wnd=" << ti.tcpi_snd_wnd << "\n snd_wnd=" << ti.tcpi_snd_wnd
<< "\ntcpi_snd_sbbytes=" << ti.tcpi_snd_sbbytes << "\n snd_sbbytes=" << ti.tcpi_snd_sbbytes
<< "\ntcpi_rcv_wnd=" << ti.tcpi_rcv_wnd << "\n rcv_wnd=" << ti.tcpi_rcv_wnd
<< "\ntcpi_srtt=" << ti.tcpi_srtt << "\n srtt=" << ti.tcpi_srtt
<< "\ntcpi_rttvar=" << ti.tcpi_rttvar; << "\n rttvar=" << ti.tcpi_rttvar
<< "\n}";
} }
#elif defined(OS_LINUX) #elif defined(OS_LINUX)
struct tcp_info ti; struct tcp_info ti;
socklen_t len = sizeof(ti); socklen_t len = sizeof(ti);
if (fd >= 0 && getsockopt(fd, SOL_TCP, TCP_INFO, &ti, &len) == 0) { if (fd >= 0 && getsockopt(fd, SOL_TCP, TCP_INFO, &ti, &len) == 0) {
os << "\ntcpi_state=" << (uint32_t)ti.tcpi_state os << "\ntcpi={\n state=" << (uint32_t)ti.tcpi_state
<< "\ntcpi_ca_state=" << (uint32_t)ti.tcpi_ca_state << "\n ca_state=" << (uint32_t)ti.tcpi_ca_state
<< "\ntcpi_retransmits=" << (uint32_t)ti.tcpi_retransmits << "\n retransmits=" << (uint32_t)ti.tcpi_retransmits
<< "\ntcpi_probes=" << (uint32_t)ti.tcpi_probes << "\n probes=" << (uint32_t)ti.tcpi_probes
<< "\ntcpi_backoff=" << (uint32_t)ti.tcpi_backoff << "\n backoff=" << (uint32_t)ti.tcpi_backoff
<< "\ntcpi_options=" << (uint32_t)ti.tcpi_options << "\n options=" << (uint32_t)ti.tcpi_options
<< "\ntcpi_snd_wscale=" << (uint32_t)ti.tcpi_snd_wscale << "\n snd_wscale=" << (uint32_t)ti.tcpi_snd_wscale
<< "\ntcpi_rcv_wscale=" << (uint32_t)ti.tcpi_rcv_wscale << "\n rcv_wscale=" << (uint32_t)ti.tcpi_rcv_wscale
<< "\ntcpi_rto=" << ti.tcpi_rto << "\n rto=" << ti.tcpi_rto
<< "\ntcpi_ato=" << ti.tcpi_ato << "\n ato=" << ti.tcpi_ato
<< "\ntcpi_snd_mss=" << ti.tcpi_snd_mss << "\n snd_mss=" << ti.tcpi_snd_mss
<< "\ntcpi_rcv_mss=" << ti.tcpi_rcv_mss << "\n rcv_mss=" << ti.tcpi_rcv_mss
<< "\ntcpi_unacked=" << ti.tcpi_unacked << "\n unacked=" << ti.tcpi_unacked
<< "\ntcpi_sacked=" << ti.tcpi_sacked << "\n sacked=" << ti.tcpi_sacked
<< "\ntcpi_lost=" << ti.tcpi_lost << "\n lost=" << ti.tcpi_lost
<< "\ntcpi_retrans=" << ti.tcpi_retrans << "\n retrans=" << ti.tcpi_retrans
<< "\ntcpi_fackets=" << ti.tcpi_fackets << "\n fackets=" << ti.tcpi_fackets
<< "\ntcpi_last_data_sent=" << ti.tcpi_last_data_sent << "\n last_data_sent=" << ti.tcpi_last_data_sent
<< "\ntcpi_last_ack_sent=" << ti.tcpi_last_ack_sent << "\n last_ack_sent=" << ti.tcpi_last_ack_sent
<< "\ntcpi_last_data_recv=" << ti.tcpi_last_data_recv << "\n last_data_recv=" << ti.tcpi_last_data_recv
<< "\ntcpi_last_ack_recv=" << ti.tcpi_last_ack_recv << "\n last_ack_recv=" << ti.tcpi_last_ack_recv
<< "\ntcpi_pmtu=" << ti.tcpi_pmtu << "\n pmtu=" << ti.tcpi_pmtu
<< "\ntcpi_rcv_ssthresh=" << ti.tcpi_rcv_ssthresh << "\n rcv_ssthresh=" << ti.tcpi_rcv_ssthresh
<< "\ntcpi_rtt=" << ti.tcpi_rtt // smoothed << "\n rtt=" << ti.tcpi_rtt // smoothed
<< "\ntcpi_rttvar=" << ti.tcpi_rttvar << "\n rttvar=" << ti.tcpi_rttvar
<< "\ntcpi_snd_ssthresh=" << ti.tcpi_snd_ssthresh << "\n snd_ssthresh=" << ti.tcpi_snd_ssthresh
<< "\ntcpi_snd_cwnd=" << ti.tcpi_snd_cwnd << "\n snd_cwnd=" << ti.tcpi_snd_cwnd
<< "\ntcpi_advmss=" << ti.tcpi_advmss << "\n advmss=" << ti.tcpi_advmss
<< "\ntcpi_reordering=" << ti.tcpi_reordering; << "\n reordering=" << ti.tcpi_reordering
<< "\n}";
} }
#endif #endif
} }
...@@ -2356,8 +2366,6 @@ inline int SocketPool::GetSocket(SocketUniquePtr* ptr) { ...@@ -2356,8 +2366,6 @@ inline int SocketPool::GetSocket(SocketUniquePtr* ptr) {
} }
// Not found in pool // Not found in pool
SocketOptions opt = _options; SocketOptions opt = _options;
// Only main socket can be the owner of ssl_ctx
opt.owns_ssl_ctx = false;
opt.health_check_interval_s = -1; opt.health_check_interval_s = -1;
if (get_client_side_messenger()->Create(opt, &sid) == 0 && if (get_client_side_messenger()->Create(opt, &sid) == 0 &&
Socket::Address(sid, ptr) == 0) { Socket::Address(sid, ptr) == 0) {
...@@ -2533,8 +2541,6 @@ int Socket::GetShortSocket(Socket* main_socket, ...@@ -2533,8 +2541,6 @@ int Socket::GetShortSocket(Socket* main_socket,
} }
SocketId id; SocketId id;
SocketOptions opt = main_socket->_options; SocketOptions opt = main_socket->_options;
// Only main socket can be the owner of ssl_ctx
opt.owns_ssl_ctx = false;
opt.health_check_interval_s = -1; opt.health_check_interval_s = -1;
if (get_client_side_messenger()->Create(opt, &id) != 0) { if (get_client_side_messenger()->Create(opt, &id) != 0) {
return -1; return -1;
...@@ -2619,6 +2625,16 @@ std::string Socket::description() const { ...@@ -2619,6 +2625,16 @@ std::string Socket::description() const {
return result; return result;
} }
SocketSSLContext::SocketSSLContext()
: raw_ctx(NULL)
{}
SocketSSLContext::~SocketSSLContext() {
if (raw_ctx) {
SSL_CTX_free(raw_ctx);
}
}
} // namespace brpc } // namespace brpc
......
...@@ -38,7 +38,6 @@ ...@@ -38,7 +38,6 @@
#include "brpc/socket_id.h" // SocketId #include "brpc/socket_id.h" // SocketId
#include "brpc/socket_message.h" // SocketMessagePtr #include "brpc/socket_message.h" // SocketMessagePtr
namespace brpc { namespace brpc {
namespace policy { namespace policy {
class ConsistentHashingLoadBalancer; class ConsistentHashingLoadBalancer;
...@@ -137,6 +136,14 @@ struct PipelinedInfo { ...@@ -137,6 +136,14 @@ struct PipelinedInfo {
bthread_id_t id_wait; bthread_id_t id_wait;
}; };
struct SocketSSLContext {
SocketSSLContext();
~SocketSSLContext();
SSL_CTX* raw_ctx; // owned
std::string sni_name; // useful for clients
};
// TODO: Comment fields // TODO: Comment fields
struct SocketOptions { struct SocketOptions {
SocketOptions(); SocketOptions();
...@@ -155,9 +162,7 @@ struct SocketOptions { ...@@ -155,9 +162,7 @@ struct SocketOptions {
// one thread at any time. // one thread at any time.
void (*on_edge_triggered_events)(Socket*); void (*on_edge_triggered_events)(Socket*);
int health_check_interval_s; int health_check_interval_s;
bool owns_ssl_ctx; std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
SSL_CTX* ssl_ctx;
std::string sni_name;
bthread_keytable_pool_t* keytable_pool; bthread_keytable_pool_t* keytable_pool;
SocketConnection* conn; SocketConnection* conn;
AppConnect* app_connect; AppConnect* app_connect;
......
...@@ -54,8 +54,6 @@ inline SocketOptions::SocketOptions() ...@@ -54,8 +54,6 @@ inline SocketOptions::SocketOptions()
, user(NULL) , user(NULL)
, on_edge_triggered_events(NULL) , on_edge_triggered_events(NULL)
, health_check_interval_s(-1) , health_check_interval_s(-1)
, owns_ssl_ctx(false)
, ssl_ctx(NULL)
, keytable_pool(NULL) , keytable_pool(NULL)
, conn(NULL) , conn(NULL)
, app_connect(NULL) , app_connect(NULL)
......
...@@ -21,14 +21,11 @@ ...@@ -21,14 +21,11 @@
#include "butil/time.h" #include "butil/time.h"
#include "butil/scoped_lock.h" #include "butil/scoped_lock.h"
#include "butil/logging.h" #include "butil/logging.h"
#include "butil/third_party/murmurhash3/murmurhash3.h"
#include "brpc/log.h" #include "brpc/log.h"
#include "brpc/protocol.h" #include "brpc/protocol.h"
#include "brpc/input_messenger.h" #include "brpc/input_messenger.h"
#include "brpc/reloadable_flags.h" #include "brpc/reloadable_flags.h"
#include "brpc/socket_map.h" #include "brpc/socket_map.h"
#include "brpc/details/ssl_helper.h" // CreateClientSSLContext
namespace brpc { namespace brpc {
...@@ -88,62 +85,9 @@ SocketMap* get_or_new_client_side_socket_map() { ...@@ -88,62 +85,9 @@ SocketMap* get_or_new_client_side_socket_map() {
return g_socket_map.load(butil::memory_order_consume); return g_socket_map.load(butil::memory_order_consume);
} }
void ComputeSocketMapKeyChecksum(const SocketMapKey& key, int SocketMapInsert(const SocketMapKey& key, SocketId* id,
unsigned char* checksum) { const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
butil::MurmurHash3_x64_128_Context mm_ctx; return get_or_new_client_side_socket_map()->Insert(key, id, ssl_ctx);
butil::MurmurHash3_x64_128_Init(&mm_ctx, 0);
const int BUFSIZE = 1024; // Should be enough
char buf[BUFSIZE];
int cur_len = 0;
#define SAFE_MEMCOPY(dst, cur_len, src, size) \
do { \
int copy_len = std::min((int)size, BUFSIZE - cur_len); \
if (copy_len > 0) { \
memcpy(dst + cur_len, src, copy_len); \
cur_len += copy_len; \
} \
} while (0);
std::size_t ephash = butil::DefaultHasher<butil::EndPoint>()(key.peer);
SAFE_MEMCOPY(buf, cur_len, &ephash, sizeof(ephash));
SAFE_MEMCOPY(buf, cur_len, &key.auth, sizeof(key.auth));
const ChannelSSLOptions& ssl = key.ssl_options;
SAFE_MEMCOPY(buf, cur_len, &ssl.enable, sizeof(ssl.enable));
if (ssl.enable) {
SAFE_MEMCOPY(buf, cur_len, ssl.ciphers.data(), ssl.ciphers.size());
SAFE_MEMCOPY(buf, cur_len, ssl.protocols.data(), ssl.protocols.size());
SAFE_MEMCOPY(buf, cur_len, ssl.sni_name.data(), ssl.sni_name.size());
const VerifyOptions& verify = ssl.verify;
SAFE_MEMCOPY(buf, cur_len, &verify.verify_depth,
sizeof(verify.verify_depth));
if (verify.verify_depth > 0) {
SAFE_MEMCOPY(buf, cur_len, verify.ca_file_path.data(),
verify.ca_file_path.size());
}
} else {
// All disabled ChannelSSLOptions are the same
}
#undef SAFE_MEMCOPY
butil::MurmurHash3_x64_128_Update(&mm_ctx, buf, cur_len);
const CertInfo& cert = ssl.client_cert;
if (ssl.enable && !cert.certificate.empty()) {
// Certificate may be too long (PEM string) to fit into `buf'
butil::MurmurHash3_x64_128_Update(
&mm_ctx, cert.certificate.data(), cert.certificate.size());
butil::MurmurHash3_x64_128_Update(
&mm_ctx, cert.private_key.data(), cert.private_key.size());
// sni_filters has no effect in ChannelSSLOptions
}
butil::MurmurHash3_x64_128_Final(checksum, &mm_ctx);
}
int SocketMapInsert(const SocketMapKey& key, SocketId* id) {
return get_or_new_client_side_socket_map()->Insert(key, id);
} }
int SocketMapFind(const SocketMapKey& key, SocketId* id) { int SocketMapFind(const SocketMapKey& key, SocketId* id) {
...@@ -264,10 +208,10 @@ void SocketMap::PrintSocketMap(std::ostream& os, void* arg) { ...@@ -264,10 +208,10 @@ void SocketMap::PrintSocketMap(std::ostream& os, void* arg) {
static_cast<SocketMap*>(arg)->Print(os); static_cast<SocketMap*>(arg)->Print(os);
} }
int SocketMap::Insert(const SocketMapKey& key, SocketId* id) { int SocketMap::Insert(const SocketMapKey& key, SocketId* id,
SocketMapKeyChecksum ck(key); const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
std::unique_lock<butil::Mutex> mu(_mutex); std::unique_lock<butil::Mutex> mu(_mutex);
SingleConnection* sc = _map.seek(ck); SingleConnection* sc = _map.seek(key);
if (sc) { if (sc) {
if (!sc->socket->Failed() || if (!sc->socket->Failed() ||
sc->socket->health_check_interval() > 0/*HC enabled*/) { sc->socket->health_check_interval() > 0/*HC enabled*/) {
...@@ -277,30 +221,20 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id) { ...@@ -277,30 +221,20 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id) {
} }
// A socket w/o HC is failed (permanently), replace it. // A socket w/o HC is failed (permanently), replace it.
SocketUniquePtr ptr(sc->socket); // Remove the ref added at insertion. SocketUniquePtr ptr(sc->socket); // Remove the ref added at insertion.
_map.erase(ck); // in principle, we can override the entry in map w/o _map.erase(key); // in principle, we can override the entry in map w/o
// removing and inserting it again. But this would make error branches // removing and inserting it again. But this would make error branches
// below have to remove the entry before returning, which is // below have to remove the entry before returning, which is
// error-prone. We prefer code maintainability here. // error-prone. We prefer code maintainability here.
sc = NULL; sc = NULL;
} }
std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx(
CreateClientSSLContext(key.ssl_options));
if (key.ssl_options.enable && !ssl_ctx) {
return -1;
}
SocketId tmp_id; SocketId tmp_id;
SocketOptions opt; SocketOptions opt;
opt.remote_side = key.peer; opt.remote_side = key.peer;
// Can't save SSL_CTX in SocketMap since SingleConnection's desctruction opt.initial_ssl_ctx = ssl_ctx;
// may happen before Socket's destruction (remove Channel before RPC complete)
opt.owns_ssl_ctx = true;
opt.ssl_ctx = ssl_ctx.get();
opt.sni_name = key.ssl_options.sni_name;
if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) { if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) {
PLOG(FATAL) << "Fail to create socket to " << key.peer; PLOG(FATAL) << "Fail to create socket to " << key.peer;
return -1; return -1;
} }
ssl_ctx.release();
// Add a reference to make sure that sc->socket is always accessible. Not // Add a reference to make sure that sc->socket is always accessible. Not
// use SocketUniquePtr which cannot put into containers before c++11. // use SocketUniquePtr which cannot put into containers before c++11.
// The ref will be removed at entry's removal. // The ref will be removed at entry's removal.
...@@ -310,7 +244,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id) { ...@@ -310,7 +244,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id) {
return -1; return -1;
} }
SingleConnection new_sc = { 1, ptr.release(), 0 }; SingleConnection new_sc = { 1, ptr.release(), 0 };
_map[ck] = new_sc; _map[key] = new_sc;
*id = tmp_id; *id = tmp_id;
bool need_to_create_bvar = false; bool need_to_create_bvar = false;
if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) { if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) {
...@@ -334,9 +268,8 @@ void SocketMap::Remove(const SocketMapKey& key, SocketId expected_id) { ...@@ -334,9 +268,8 @@ void SocketMap::Remove(const SocketMapKey& key, SocketId expected_id) {
void SocketMap::RemoveInternal(const SocketMapKey& key, void SocketMap::RemoveInternal(const SocketMapKey& key,
SocketId expected_id, SocketId expected_id,
bool remove_orphan) { bool remove_orphan) {
SocketMapKeyChecksum ck(key);
std::unique_lock<butil::Mutex> mu(_mutex); std::unique_lock<butil::Mutex> mu(_mutex);
SingleConnection* sc = _map.seek(ck); SingleConnection* sc = _map.seek(key);
if (!sc) { if (!sc) {
return; return;
} }
...@@ -354,7 +287,7 @@ void SocketMap::RemoveInternal(const SocketMapKey& key, ...@@ -354,7 +287,7 @@ void SocketMap::RemoveInternal(const SocketMapKey& key,
sc->no_ref_us = butil::cpuwide_time_us(); sc->no_ref_us = butil::cpuwide_time_us();
} else { } else {
Socket* const s = sc->socket; Socket* const s = sc->socket;
_map.erase(ck); _map.erase(key);
bool need_to_create_bvar = false; bool need_to_create_bvar = false;
if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) { if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) {
_exposed_in_bvar = true; _exposed_in_bvar = true;
...@@ -374,9 +307,8 @@ void SocketMap::RemoveInternal(const SocketMapKey& key, ...@@ -374,9 +307,8 @@ void SocketMap::RemoveInternal(const SocketMapKey& key,
} }
int SocketMap::Find(const SocketMapKey& key, SocketId* id) { int SocketMap::Find(const SocketMapKey& key, SocketId* id) {
SocketMapKeyChecksum ck(key);
BAIDU_SCOPED_LOCK(_mutex); BAIDU_SCOPED_LOCK(_mutex);
SingleConnection* sc = _map.seek(ck); SingleConnection* sc = _map.seek(key);
if (sc) { if (sc) {
*id = sc->socket->id(); *id = sc->socket->id();
return 0; return 0;
...@@ -400,14 +332,14 @@ void SocketMap::List(std::vector<butil::EndPoint>* pts) { ...@@ -400,14 +332,14 @@ void SocketMap::List(std::vector<butil::EndPoint>* pts) {
} }
} }
void SocketMap::ListOrphans(int64_t defer_us, std::vector<butil::EndPoint>* out) { void SocketMap::ListOrphans(int64_t defer_us, std::vector<SocketMapKey>* out) {
out->clear(); out->clear();
const int64_t now = butil::cpuwide_time_us(); const int64_t now = butil::cpuwide_time_us();
BAIDU_SCOPED_LOCK(_mutex); BAIDU_SCOPED_LOCK(_mutex);
for (Map::iterator it = _map.begin(); it != _map.end(); ++it) { for (Map::iterator it = _map.begin(); it != _map.end(); ++it) {
SingleConnection& sc = it->second; SingleConnection& sc = it->second;
if (sc.ref_count == 0 && now - sc.no_ref_us >= defer_us) { if (sc.ref_count == 0 && now - sc.no_ref_us >= defer_us) {
out->push_back(it->first.peer); out->push_back(it->first);
} }
} }
} }
...@@ -420,7 +352,7 @@ void* SocketMap::RunWatchConnections(void* arg) { ...@@ -420,7 +352,7 @@ void* SocketMap::RunWatchConnections(void* arg) {
void SocketMap::WatchConnections() { void SocketMap::WatchConnections() {
std::vector<SocketId> main_sockets; std::vector<SocketId> main_sockets;
std::vector<SocketId> pooled_sockets; std::vector<SocketId> pooled_sockets;
std::vector<butil::EndPoint> orphan_sockets; std::vector<SocketMapKey> orphan_sockets;
const uint64_t CHECK_INTERVAL_US = 1000000UL; const uint64_t CHECK_INTERVAL_US = 1000000UL;
while (bthread_usleep(CHECK_INTERVAL_US) == 0) { while (bthread_usleep(CHECK_INTERVAL_US) == 0) {
// NOTE: save the gflag which may be reloaded at any time. // NOTE: save the gflag which may be reloaded at any time.
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "butil/containers/flat_map.h" // FlatMap #include "butil/containers/flat_map.h" // FlatMap
#include "brpc/socket_id.h" // SockdetId #include "brpc/socket_id.h" // SockdetId
#include "brpc/options.pb.h" // ProtocolType #include "brpc/options.pb.h" // ProtocolType
#include "brpc/ssl_option.h" // ChannelSSLOptions
#include "brpc/input_messenger.h" // InputMessageHandler #include "brpc/input_messenger.h" // InputMessageHandler
...@@ -30,29 +29,57 @@ namespace brpc { ...@@ -30,29 +29,57 @@ namespace brpc {
// Global mapping from remote-side to out-going sockets created by Channels. // Global mapping from remote-side to out-going sockets created by Channels.
struct ChannelSignature {
uint64_t data[2];
ChannelSignature() { Reset(); }
void Reset() { data[0] = data[1] = 0; }
};
inline bool operator==(const ChannelSignature& s1, const ChannelSignature& s2) {
return s1.data[0] == s2.data[0] && s1.data[1] == s2.data[1];
}
inline bool operator!=(const ChannelSignature& s1, const ChannelSignature& s2) {
return !(s1 == s2);
}
// The following fields uniquely define a Socket. In other word, // The following fields uniquely define a Socket. In other word,
// Socket can't be shared between 2 different SocketMapKeys // Socket can't be shared between 2 different SocketMapKeys
struct SocketMapKey { struct SocketMapKey {
SocketMapKey(const butil::EndPoint& pt, explicit SocketMapKey(const butil::EndPoint& pt)
ChannelSSLOptions ssl = ChannelSSLOptions(), : peer(pt)
const Authenticator* auth2 = NULL) {}
: peer(pt), ssl_options(ssl), auth(auth2) SocketMapKey(const butil::EndPoint& pt, const ChannelSignature& cs)
: peer(pt), channel_signature(cs)
{} {}
butil::EndPoint peer; butil::EndPoint peer;
ChannelSSLOptions ssl_options; ChannelSignature channel_signature;
const Authenticator* auth; };
inline bool operator==(const SocketMapKey& k1, const SocketMapKey& k2) {
return k1.peer == k2.peer && k1.channel_signature == k2.channel_signature;
};
struct SocketMapKeyHasher {
std::size_t operator()(const SocketMapKey& key) const {
return butil::DefaultHasher<butil::EndPoint>()(key.peer) ^
key.channel_signature.data[1];
}
}; };
// Calculate an 128-bit hashcode for SocketMapKey
void ComputeSocketMapKeyChecksum(const SocketMapKey& key,
unsigned char* checksum);
// Try to share the Socket to `key'. If the Socket does not exist, create one. // Try to share the Socket to `key'. If the Socket does not exist, create one.
// The corresponding SocketId is written to `*id'. If this function returns // The corresponding SocketId is written to `*id'. If this function returns
// successfully, SocketMapRemove() MUST be called when the Socket is not needed. // successfully, SocketMapRemove() MUST be called when the Socket is not needed.
// Return 0 on success, -1 otherwise. // Return 0 on success, -1 otherwise.
int SocketMapInsert(const SocketMapKey& key, SocketId* id); int SocketMapInsert(const SocketMapKey& key, SocketId* id,
const std::shared_ptr<SocketSSLContext>& ssl_ctx);
inline int SocketMapInsert(const SocketMapKey& key, SocketId* id) {
std::shared_ptr<SocketSSLContext> empty_ptr;
return SocketMapInsert(key, id, empty_ptr);
}
// Find the SocketId associated with `key'. // Find the SocketId associated with `key'.
// Return 0 on found, -1 otherwise. // Return 0 on found, -1 otherwise.
...@@ -110,7 +137,13 @@ public: ...@@ -110,7 +137,13 @@ public:
SocketMap(); SocketMap();
~SocketMap(); ~SocketMap();
int Init(const SocketMapOptions&); int Init(const SocketMapOptions&);
int Insert(const SocketMapKey& key, SocketId* id); int Insert(const SocketMapKey& key, SocketId* id,
const std::shared_ptr<SocketSSLContext>& ssl_ctx);
int Insert(const SocketMapKey& key, SocketId* id) {
std::shared_ptr<SocketSSLContext> empty_ptr;
return Insert(key, id, empty_ptr);
}
void Remove(const SocketMapKey& key, SocketId expected_id); void Remove(const SocketMapKey& key, SocketId expected_id);
int Find(const SocketMapKey& key, SocketId* id); int Find(const SocketMapKey& key, SocketId* id);
void List(std::vector<SocketId>* ids); void List(std::vector<SocketId>* ids);
...@@ -120,7 +153,7 @@ public: ...@@ -120,7 +153,7 @@ public:
private: private:
void RemoveInternal(const SocketMapKey& key, SocketId id, void RemoveInternal(const SocketMapKey& key, SocketId id,
bool remove_orphan); bool remove_orphan);
void ListOrphans(int64_t defer_us, std::vector<butil::EndPoint>* out); void ListOrphans(int64_t defer_us, std::vector<SocketMapKey>* out);
void WatchConnections(); void WatchConnections();
static void* RunWatchConnections(void*); static void* RunWatchConnections(void*);
void Print(std::ostream& os); void Print(std::ostream& os);
...@@ -133,39 +166,10 @@ private: ...@@ -133,39 +166,10 @@ private:
int64_t no_ref_us; int64_t no_ref_us;
}; };
// Store checksum of SocketMapKey instead of itself in order to:
// 1. Save precious space of key field in FlatMap
// 2. Simplify equivalence logic between SocketMapKeys
// (regard the hash collision to be zero)
struct SocketMapKeyChecksum {
explicit SocketMapKeyChecksum(const SocketMapKey& key)
: peer(key.peer) {
ComputeSocketMapKeyChecksum(key, checksum);
}
butil::EndPoint peer;
unsigned char checksum[16];
inline bool operator==(const SocketMapKeyChecksum& rhs) const {
return this->peer == rhs.peer
&& memcmp(this->checksum, rhs.checksum, sizeof(checksum)) == 0;
}
};
struct Checksum2Hash {
std::size_t operator()(const SocketMapKeyChecksum& key) const {
// Slice a subset of checksum over an evenly distributed hash
// won't affect the overall balance
std::size_t hash;
memcpy(&hash, key.checksum, sizeof(hash));
return hash;
}
};
// TODO: When RpcChannels connecting to one EndPoint are frequently created // TODO: When RpcChannels connecting to one EndPoint are frequently created
// and destroyed, a single map+mutex may become hot-spots. // and destroyed, a single map+mutex may become hot-spots.
typedef butil::FlatMap<SocketMapKeyChecksum, typedef butil::FlatMap<SocketMapKey, SingleConnection,
SingleConnection, Checksum2Hash> Map; SocketMapKeyHasher> Map;
SocketMapOptions _options; SocketMapOptions _options;
butil::Mutex _mutex; butil::Mutex _mutex;
Map _map; Map _map;
......
...@@ -21,8 +21,7 @@ namespace brpc { ...@@ -21,8 +21,7 @@ namespace brpc {
VerifyOptions::VerifyOptions() : verify_depth(0) {} VerifyOptions::VerifyOptions() : verify_depth(0) {}
ChannelSSLOptions::ChannelSSLOptions() ChannelSSLOptions::ChannelSSLOptions()
: enable(false) : ciphers("DEFAULT")
, ciphers("DEFAULT")
, protocols("TLSv1, TLSv1.1, TLSv1.2") , protocols("TLSv1, TLSv1.1, TLSv1.2")
{} {}
......
...@@ -59,10 +59,6 @@ struct ChannelSSLOptions { ...@@ -59,10 +59,6 @@ struct ChannelSSLOptions {
// Constructed with default options // Constructed with default options
ChannelSSLOptions(); ChannelSSLOptions();
// Whether to enable SSL on the channel.
// Default: false
bool enable;
// Cipher suites used for SSL handshake. // Cipher suites used for SSL handshake.
// The format of this string should follow that in `man 1 cipers'. // The format of this string should follow that in `man 1 cipers'.
// Default: "DEFAULT" // Default: "DEFAULT"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment