Commit db844fd9 authored by gejun's avatar gejun

Misc improvements on SSL-related code

parent c813572d
...@@ -184,9 +184,9 @@ int main(int argc, char* argv[]) { ...@@ -184,9 +184,9 @@ int main(int argc, char* argv[]) {
// Start the server. // Start the server.
brpc::ServerOptions options; brpc::ServerOptions options;
options.idle_timeout_sec = FLAGS_idle_timeout_s; options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.ssl_options.default_cert.certificate = FLAGS_certificate; options.mutable_ssl_options()->default_cert.certificate = FLAGS_certificate;
options.ssl_options.default_cert.private_key = FLAGS_private_key; options.mutable_ssl_options()->default_cert.private_key = FLAGS_private_key;
options.ssl_options.ciphers = FLAGS_ciphers; options.mutable_ssl_options()->ciphers = FLAGS_ciphers;
if (server.Start(FLAGS_port, &options) != 0) { if (server.Start(FLAGS_port, &options) != 0) {
LOG(ERROR) << "Fail to start HttpServer"; LOG(ERROR) << "Fail to start HttpServer";
return -1; return -1;
......
...@@ -95,7 +95,9 @@ int main(int argc, char* argv[]) { ...@@ -95,7 +95,9 @@ int main(int argc, char* argv[]) {
// Initialize the channel, NULL means using default options. // Initialize the channel, NULL means using default options.
brpc::ChannelOptions options; brpc::ChannelOptions options;
options.ssl_options = std::make_shared<brpc::ChannelSSLOptions>(); if (FLAGS_enable_ssl) {
options.mutable_ssl_options();
}
options.protocol = FLAGS_protocol; options.protocol = FLAGS_protocol;
options.connection_type = FLAGS_connection_type; options.connection_type = FLAGS_connection_type;
options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100); options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100);
......
...@@ -82,9 +82,8 @@ int main(int argc, char* argv[]) { ...@@ -82,9 +82,8 @@ int main(int argc, char* argv[]) {
// Start the server. // Start the server.
brpc::ServerOptions options; brpc::ServerOptions options;
options.ssl_options = std::make_shared<brpc::ServerSSLOptions>(); options.mutable_ssl_options()->default_cert.certificate = "cert.pem";
options.ssl_options->default_cert.certificate = "cert.pem"; options.mutable_ssl_options()->default_cert.private_key = "key.pem";
options.ssl_options->default_cert.private_key = "key.pem";
options.idle_timeout_sec = FLAGS_idle_timeout_s; options.idle_timeout_sec = FLAGS_idle_timeout_s;
options.max_concurrency = FLAGS_max_concurrency; options.max_concurrency = FLAGS_max_concurrency;
options.internal_port = FLAGS_internal_port; options.internal_port = FLAGS_internal_port;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "butil/time.h" // milliseconds_from_now #include "butil/time.h" // milliseconds_from_now
#include "butil/logging.h" #include "butil/logging.h"
#include "butil/third_party/murmurhash3/murmurhash3.h" #include "butil/third_party/murmurhash3/murmurhash3.h"
#include "butil/strings/string_util.h"
#include "bthread/unstable.h" // bthread_timer_add #include "bthread/unstable.h" // bthread_timer_add
#include "brpc/socket_map.h" // SocketMapInsert #include "brpc/socket_map.h" // SocketMapInsert
#include "brpc/compress.h" #include "brpc/compress.h"
...@@ -50,13 +51,19 @@ ChannelOptions::ChannelOptions() ...@@ -50,13 +51,19 @@ ChannelOptions::ChannelOptions()
, auth(NULL) , auth(NULL)
, retry_policy(NULL) , retry_policy(NULL)
, ns_filter(NULL) , ns_filter(NULL)
, connection_group(0)
{} {}
ChannelSSLOptions* ChannelOptions::mutable_ssl_options() {
if (!_ssl_options) {
_ssl_options.reset(new ChannelSSLOptions);
}
return _ssl_options.get();
}
static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
if (opt.auth == NULL && if (opt.auth == NULL &&
opt.ssl_options == NULL && !opt.has_ssl_options() &&
opt.connection_group == 0) { opt.connection_group.empty()) {
// Returning zeroized result by default is more intuitive for users. // Returning zeroized result by default is more intuitive for users.
return ChannelSignature(); return ChannelSignature();
} }
...@@ -68,23 +75,23 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { ...@@ -68,23 +75,23 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
buf.clear(); buf.clear();
butil::MurmurHash3_x64_128_Init(&mm_ctx, seed); butil::MurmurHash3_x64_128_Init(&mm_ctx, seed);
if (opt.connection_group) { if (!opt.connection_group.empty()) {
buf.append("|conng="); buf.append("|conng=");
buf.append((char*)&opt.connection_group, sizeof(opt.connection_group)); buf.append(opt.connection_group);
} }
if (opt.auth) { if (opt.auth) {
buf.append("|auth="); buf.append("|auth=");
buf.append((char*)&opt.auth, sizeof(opt.auth)); buf.append((char*)&opt.auth, sizeof(opt.auth));
} }
const ChannelSSLOptions* ssl = opt.ssl_options.get(); if (opt.has_ssl_options()) {
if (ssl) { const ChannelSSLOptions& ssl = opt.ssl_options();
buf.push_back('|'); buf.push_back('|');
buf.append(ssl->ciphers); buf.append(ssl.ciphers);
buf.push_back('|'); buf.push_back('|');
buf.append(ssl->protocols); buf.append(ssl.protocols);
buf.push_back('|'); buf.push_back('|');
buf.append(ssl->sni_name); buf.append(ssl.sni_name);
const VerifyOptions& verify = ssl->verify; const VerifyOptions& verify = ssl.verify;
buf.push_back('|'); buf.push_back('|');
buf.append((char*)&verify.verify_depth, sizeof(verify.verify_depth)); buf.append((char*)&verify.verify_depth, sizeof(verify.verify_depth));
buf.push_back('|'); buf.push_back('|');
...@@ -95,8 +102,8 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) { ...@@ -95,8 +102,8 @@ static ChannelSignature ComputeChannelSignature(const ChannelOptions& opt) {
butil::MurmurHash3_x64_128_Update(&mm_ctx, buf.data(), buf.size()); butil::MurmurHash3_x64_128_Update(&mm_ctx, buf.data(), buf.size());
buf.clear(); buf.clear();
if (ssl) { if (opt.has_ssl_options()) {
const CertInfo& cert = ssl->client_cert; const CertInfo& cert = opt.ssl_options().client_cert;
if (!cert.certificate.empty()) { if (!cert.certificate.empty()) {
// Certificate may be too long (PEM string) to fit into `buf' // Certificate may be too long (PEM string) to fit into `buf'
butil::MurmurHash3_x64_128_Update( butil::MurmurHash3_x64_128_Update(
...@@ -187,19 +194,13 @@ int Channel::InitChannelOptions(const ChannelOptions* options) { ...@@ -187,19 +194,13 @@ int Channel::InitChannelOptions(const ChannelOptions* options) {
if (_options.auth == NULL) { if (_options.auth == NULL) {
_options.auth = policy::global_esp_authenticator(); _options.auth = policy::global_esp_authenticator();
} }
} else if (_options.protocol == brpc::PROTOCOL_HTTP) {
if (_raw_server_address.compare(0, 5, "https") == 0) {
if (_options.ssl_options == NULL) {
_options.ssl_options = std::make_shared<ChannelSSLOptions>();
}
if (_options.ssl_options->sni_name.empty()) {
int port;
ParseHostAndPortFromURL(_raw_server_address.c_str(),
&_options.ssl_options->sni_name, &port);
}
}
} }
// Normalize connection_group
std::string& cg = _options.connection_group;
if (!cg.empty() && (::isspace(cg.front()) || ::isspace(cg.back()))) {
butil::TrimWhitespace(cg, butil::TRIM_ALL, &cg);
}
return 0; return 0;
} }
...@@ -229,8 +230,7 @@ int Channel::Init(const char* server_addr_and_port, ...@@ -229,8 +230,7 @@ int Channel::Init(const char* server_addr_and_port,
return -1; return -1;
} }
} }
_raw_server_address.assign(server_addr_and_port); return InitSingle(point, server_addr_and_port, options);
return Init(point, options);
} }
int Channel::Init(const char* server_addr, int port, int Channel::Init(const char* server_addr, int port,
...@@ -252,25 +252,21 @@ int Channel::Init(const char* server_addr, int port, ...@@ -252,25 +252,21 @@ int Channel::Init(const char* server_addr, int port,
return -1; return -1;
} }
} }
_raw_server_address.assign(server_addr); return InitSingle(point, server_addr, options);
return Init(point, options);
} }
static int CreateSocketSSLContext(const ChannelOptions& options, static int CreateSocketSSLContext(const ChannelOptions& options,
ChannelSignature* sig,
std::shared_ptr<SocketSSLContext>* ssl_ctx) { std::shared_ptr<SocketSSLContext>* ssl_ctx) {
if (options.ssl_options != NULL) { if (options.has_ssl_options()) {
*sig = ComputeChannelSignature(options); SSL_CTX* raw_ctx = CreateClientSSLContext(options.ssl_options());
SSL_CTX* raw_ctx = CreateClientSSLContext(*options.ssl_options);
if (!raw_ctx) { if (!raw_ctx) {
LOG(ERROR) << "Fail to CreateClientSSLContext"; LOG(ERROR) << "Fail to CreateClientSSLContext";
return -1; return -1;
} }
*ssl_ctx = std::make_shared<SocketSSLContext>(); *ssl_ctx = std::make_shared<SocketSSLContext>();
(*ssl_ctx)->raw_ctx = raw_ctx; (*ssl_ctx)->raw_ctx = raw_ctx;
(*ssl_ctx)->sni_name = options.ssl_options->sni_name; (*ssl_ctx)->sni_name = options.ssl_options().sni_name;
} else { } else {
sig->Reset();
(*ssl_ctx) = NULL; (*ssl_ctx) = NULL;
} }
return 0; return 0;
...@@ -278,19 +274,32 @@ static int CreateSocketSSLContext(const ChannelOptions& options, ...@@ -278,19 +274,32 @@ static int CreateSocketSSLContext(const ChannelOptions& options,
int Channel::Init(butil::EndPoint server_addr_and_port, int Channel::Init(butil::EndPoint server_addr_and_port,
const ChannelOptions* options) { const ChannelOptions* options) {
return InitSingle(server_addr_and_port, "", options);
}
int Channel::InitSingle(const butil::EndPoint& server_addr_and_port,
const char* raw_server_address,
const ChannelOptions* options) {
GlobalInitializeOrDie(); GlobalInitializeOrDie();
if (InitChannelOptions(options) != 0) { if (InitChannelOptions(options) != 0) {
return -1; return -1;
} }
if (_options.protocol == brpc::PROTOCOL_HTTP &&
::strncmp(raw_server_address, "https://", 8) == 0) {
if (_options.mutable_ssl_options()->sni_name.empty()) {
ParseURL(raw_server_address,
NULL, &_options.mutable_ssl_options()->sni_name, NULL);
}
}
const int port = server_addr_and_port.port; const int port = server_addr_and_port.port;
if (port < 0 || port > 65535) { if (port < 0 || port > 65535) {
LOG(ERROR) << "Invalid port=" << port; LOG(ERROR) << "Invalid port=" << port;
return -1; return -1;
} }
_server_address = server_addr_and_port; _server_address = server_addr_and_port;
ChannelSignature sig; const ChannelSignature sig = ComputeChannelSignature(_options);
std::shared_ptr<SocketSSLContext> ssl_ctx; std::shared_ptr<SocketSSLContext> ssl_ctx;
if (CreateSocketSSLContext(_options, &sig, &ssl_ctx) != 0) { if (CreateSocketSSLContext(_options, &ssl_ctx) != 0) {
return -1; return -1;
} }
if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig), if (SocketMapInsert(SocketMapKey(server_addr_and_port, sig),
...@@ -312,6 +321,13 @@ int Channel::Init(const char* ns_url, ...@@ -312,6 +321,13 @@ int Channel::Init(const char* ns_url,
if (InitChannelOptions(options) != 0) { if (InitChannelOptions(options) != 0) {
return -1; return -1;
} }
if (_options.protocol == brpc::PROTOCOL_HTTP &&
::strncmp(ns_url, "https://", 8) == 0) {
if (_options.mutable_ssl_options()->sni_name.empty()) {
ParseURL(ns_url,
NULL, &_options.mutable_ssl_options()->sni_name, NULL);
}
}
LoadBalancerWithNaming* lb = new (std::nothrow) LoadBalancerWithNaming; LoadBalancerWithNaming* lb = new (std::nothrow) LoadBalancerWithNaming;
if (NULL == lb) { if (NULL == lb) {
LOG(FATAL) << "Fail to new LoadBalancerWithNaming"; LOG(FATAL) << "Fail to new LoadBalancerWithNaming";
...@@ -320,7 +336,8 @@ int Channel::Init(const char* ns_url, ...@@ -320,7 +336,8 @@ int Channel::Init(const char* ns_url,
GetNamingServiceThreadOptions ns_opt; GetNamingServiceThreadOptions ns_opt;
ns_opt.succeed_without_server = _options.succeed_without_server; ns_opt.succeed_without_server = _options.succeed_without_server;
ns_opt.log_succeed_without_server = _options.log_succeed_without_server; ns_opt.log_succeed_without_server = _options.log_succeed_without_server;
if (CreateSocketSSLContext(_options, &ns_opt.channel_signature, &ns_opt.ssl_ctx) != 0) { ns_opt.channel_signature = ComputeChannelSignature(_options);
if (CreateSocketSSLContext(_options, &ns_opt.ssl_ctx) != 0) {
return -1; return -1;
} }
if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) { if (lb->Init(ns_url, lb_name, _options.ns_filter, &ns_opt) != 0) {
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#include <ostream> // std::ostream #include <ostream> // std::ostream
#include "bthread/errno.h" // Redefine errno #include "bthread/errno.h" // Redefine errno
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr #include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
#include "brpc/ssl_option.h" // ChannelSSLOptions #include "butil/ptr_container.h"
#include "brpc/ssl_options.h" // ChannelSSLOptions
#include "brpc/channel_base.h" // ChannelBase #include "brpc/channel_base.h" // ChannelBase
#include "brpc/adaptive_protocol_type.h" // AdaptiveProtocolType #include "brpc/adaptive_protocol_type.h" // AdaptiveProtocolType
#include "brpc/adaptive_connection_type.h" // AdaptiveConnectionType #include "brpc/adaptive_connection_type.h" // AdaptiveConnectionType
...@@ -90,7 +91,9 @@ struct ChannelOptions { ...@@ -90,7 +91,9 @@ struct ChannelOptions {
bool log_succeed_without_server; bool log_succeed_without_server;
// SSL related options. Refer to `ChannelSSLOptions' for details // SSL related options. Refer to `ChannelSSLOptions' for details
std::shared_ptr<ChannelSSLOptions> ssl_options; bool has_ssl_options() const { return _ssl_options != NULL; }
const ChannelSSLOptions& ssl_options() const { return *_ssl_options.get(); }
ChannelSSLOptions* mutable_ssl_options();
// Turn on authentication for this channel if `auth' is not NULL. // Turn on authentication for this channel if `auth' is not NULL.
// Note `auth' will not be deleted by channel and must remain valid when // Note `auth' will not be deleted by channel and must remain valid when
...@@ -113,10 +116,16 @@ struct ChannelOptions { ...@@ -113,10 +116,16 @@ struct ChannelOptions {
// Default: NULL // Default: NULL
const NamingServiceFilter* ns_filter; const NamingServiceFilter* ns_filter;
// Channels with same connection_group share connections. In an another // Channels with same connection_group share connections.
// word, set to a different value to not share connections. // In other words, set to a different value to stop sharing connections.
// Default: 0 // Case-sensitive, leading and trailing spaces are ignored.
int connection_group; // Default: ""
std::string connection_group;
private:
// SSLOptions is large and not often used, allocate it on heap to
// prevent ChannelOptions from being bloated in most cases.
butil::PtrContainer<ChannelSSLOptions> _ssl_options;
}; };
// A Channel represents a communication line to one server or multiple servers // A Channel represents a communication line to one server or multiple servers
...@@ -195,8 +204,10 @@ protected: ...@@ -195,8 +204,10 @@ protected:
static void CallMethodImpl(Controller* controller, SharedLoadBalancer* lb); static void CallMethodImpl(Controller* controller, SharedLoadBalancer* lb);
int InitChannelOptions(const ChannelOptions* options); int InitChannelOptions(const ChannelOptions* options);
int InitSingle(const butil::EndPoint& server_addr_and_port,
const char* raw_server_address,
const ChannelOptions* options);
std::string _raw_server_address;
butil::EndPoint _server_address; butil::EndPoint _server_address;
SocketId _server_id; SocketId _server_id;
Protocol::SerializeRequest _serialize_request; Protocol::SerializeRequest _serialize_request;
......
...@@ -30,16 +30,26 @@ namespace brpc { ...@@ -30,16 +30,26 @@ namespace brpc {
struct NSKey { struct NSKey {
std::string protocol; std::string protocol;
std::string service_name; std::string service_name;
ChannelSignature channel_signature;
NSKey(const std::string& prot_in,
const std::string& service_in,
const ChannelSignature& sig)
: protocol(prot_in), service_name(service_in), channel_signature(sig) {
}
}; };
struct NSKeyHasher { struct NSKeyHasher {
size_t operator()(const NSKey& nskey) const { size_t operator()(const NSKey& nskey) const {
return butil::DefaultHasher<std::string>()(nskey.service_name) size_t h = butil::DefaultHasher<std::string>()(nskey.protocol);
* 101 + butil::DefaultHasher<std::string>()(nskey.protocol); h = h * 101 + butil::DefaultHasher<std::string>()(nskey.service_name);
h = h * 101 + nskey.channel_signature.data[1];
return h;
} }
}; };
inline bool operator==(const NSKey& k1, const NSKey& k2) { inline bool operator==(const NSKey& k1, const NSKey& k2) {
return k1.protocol == k2.protocol && return k1.protocol == k2.protocol &&
k1.service_name == k2.service_name; k1.service_name == k2.service_name &&
k1.channel_signature == k2.channel_signature;
} }
typedef butil::FlatMap<NSKey, NamingServiceThread*, NSKeyHasher> NamingServiceMap; typedef butil::FlatMap<NSKey, NamingServiceThread*, NSKeyHasher> NamingServiceMap;
...@@ -59,7 +69,7 @@ NamingServiceThread::Actions::~Actions() { ...@@ -59,7 +69,7 @@ NamingServiceThread::Actions::~Actions() {
// Remove all sockets from SocketMap // Remove all sockets from SocketMap
for (std::vector<ServerNode>::const_iterator it = _last_servers.begin(); for (std::vector<ServerNode>::const_iterator it = _last_servers.begin();
it != _last_servers.end(); ++it) { it != _last_servers.end(); ++it) {
const SocketMapKey key(it->addr, _owner->_options.channel_signature); const SocketMapKey key(*it, _owner->_options.channel_signature);
SocketMapRemove(key); SocketMapRemove(key);
} }
EndWait(0); EndWait(0);
...@@ -112,7 +122,7 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -112,7 +122,7 @@ void NamingServiceThread::Actions::ResetServers(
// TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new // TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new
// Socket. SocketMapKey may be passed through AddWatcher. Make sure // Socket. SocketMapKey may be passed through AddWatcher. Make sure
// to pick those Sockets with the right settings during OnAddedServers // to pick those Sockets with the right settings during OnAddedServers
const SocketMapKey key(_added[i].addr, _owner->_options.channel_signature); const SocketMapKey key(_added[i], _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx)); CHECK_EQ(0, SocketMapInsert(key, &tagged_id.id, _owner->_options.ssl_ctx));
_added_sockets.push_back(tagged_id); _added_sockets.push_back(tagged_id);
} }
...@@ -121,7 +131,7 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -121,7 +131,7 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) { for (size_t i = 0; i < _removed.size(); ++i) {
ServerNodeWithId tagged_id; ServerNodeWithId tagged_id;
tagged_id.node = _removed[i]; tagged_id.node = _removed[i];
const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature); const SocketMapKey key(_removed[i], _owner->_options.channel_signature);
CHECK_EQ(0, SocketMapFind(key, &tagged_id.id)); CHECK_EQ(0, SocketMapFind(key, &tagged_id.id));
_removed_sockets.push_back(tagged_id); _removed_sockets.push_back(tagged_id);
} }
...@@ -173,7 +183,7 @@ void NamingServiceThread::Actions::ResetServers( ...@@ -173,7 +183,7 @@ void NamingServiceThread::Actions::ResetServers(
for (size_t i = 0; i < _removed.size(); ++i) { for (size_t i = 0; i < _removed.size(); ++i) {
// TODO: Remove all Sockets that have the same address in SocketMapKey.peer // TODO: Remove all Sockets that have the same address in SocketMapKey.peer
// We may need another data structure to avoid linear cost // We may need another data structure to avoid linear cost
const SocketMapKey key(_removed[i].addr, _owner->_options.channel_signature); const SocketMapKey key(_removed[i], _owner->_options.channel_signature);
SocketMapRemove(key); SocketMapRemove(key);
} }
...@@ -220,7 +230,7 @@ NamingServiceThread::~NamingServiceThread() { ...@@ -220,7 +230,7 @@ NamingServiceThread::~NamingServiceThread() {
RPC_VLOG << "~NamingServiceThread(" << *this << ')'; RPC_VLOG << "~NamingServiceThread(" << *this << ')';
// Remove from g_nsthread_map first // Remove from g_nsthread_map first
if (!_protocol.empty()) { if (!_protocol.empty()) {
const NSKey key = { _protocol, _service_name }; const NSKey key(_protocol, _service_name, _options.channel_signature);
std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex); std::unique_lock<pthread_mutex_t> mu(g_nsthread_map_mutex);
if (g_nsthread_map != NULL) { if (g_nsthread_map != NULL) {
NamingServiceThread** ptr = g_nsthread_map->seek(key); NamingServiceThread** ptr = g_nsthread_map->seek(key);
...@@ -410,9 +420,8 @@ int GetNamingServiceThread( ...@@ -410,9 +420,8 @@ int GetNamingServiceThread(
LOG(ERROR) << "Unknown protocol=" << protocol; LOG(ERROR) << "Unknown protocol=" << protocol;
return -1; return -1;
} }
NSKey key; const NSKey key(protocol, service_name,
key.protocol = protocol; (options ? options->channel_signature : ChannelSignature()));
key.service_name = service_name;
bool new_thread = false; bool new_thread = false;
butil::intrusive_ptr<NamingServiceThread> nsthread; butil::intrusive_ptr<NamingServiceThread> nsthread;
{ {
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// For some versions of openssl, SSL_* are defined inside this header // For some versions of openssl, SSL_* are defined inside this header
#include <openssl/ossl_typ.h> #include <openssl/ossl_typ.h>
#include "brpc/socket_id.h" // SocketId #include "brpc/socket_id.h" // SocketId
#include "brpc/ssl_option.h" // ServerSSLOptions #include "brpc/ssl_options.h" // ServerSSLOptions
namespace brpc { namespace brpc {
......
...@@ -334,6 +334,7 @@ static void GlobalInitializeOrDieImpl() { ...@@ -334,6 +334,7 @@ static void GlobalInitializeOrDieImpl() {
NamingServiceExtension()->RegisterOrDie("file", &g_ext->fns); NamingServiceExtension()->RegisterOrDie("file", &g_ext->fns);
NamingServiceExtension()->RegisterOrDie("list", &g_ext->lns); NamingServiceExtension()->RegisterOrDie("list", &g_ext->lns);
NamingServiceExtension()->RegisterOrDie("http", &g_ext->dns); NamingServiceExtension()->RegisterOrDie("http", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("https", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("redis", &g_ext->dns); NamingServiceExtension()->RegisterOrDie("redis", &g_ext->dns);
NamingServiceExtension()->RegisterOrDie("remotefile", &g_ext->rfns); NamingServiceExtension()->RegisterOrDie("remotefile", &g_ext->rfns);
NamingServiceExtension()->RegisterOrDie("consul", &g_ext->cns); NamingServiceExtension()->RegisterOrDie("consul", &g_ext->cns);
......
...@@ -25,24 +25,10 @@ ...@@ -25,24 +25,10 @@
#include "brpc/describable.h" #include "brpc/describable.h"
#include "brpc/destroyable.h" #include "brpc/destroyable.h"
#include "brpc/extension.h" // Extension<T> #include "brpc/extension.h" // Extension<T>
#include "brpc/server_node.h" // ServerNode
namespace brpc { namespace brpc {
// Representing a server inside a NamingService.
struct ServerNode {
ServerNode() {}
ServerNode(butil::ip_t ip, int port, const std::string& tag2)
: addr(ip, port), tag(tag2) {}
ServerNode(const butil::EndPoint& pt, const std::string& tag2)
: addr(pt), tag(tag2) {}
ServerNode(butil::ip_t ip, int port) : addr(ip, port) {}
explicit ServerNode(const butil::EndPoint& pt) : addr(pt) {}
butil::EndPoint addr;
std::string tag;
};
// Continuing actions to added/removed servers. // Continuing actions to added/removed servers.
// NOTE: You don't have to implement this class. // NOTE: You don't have to implement this class.
class NamingServiceActions { class NamingServiceActions {
...@@ -84,21 +70,6 @@ inline Extension<const NamingService>* NamingServiceExtension() { ...@@ -84,21 +70,6 @@ inline Extension<const NamingService>* NamingServiceExtension() {
return Extension<const NamingService>::instance(); return Extension<const NamingService>::instance();
} }
inline bool operator<(const ServerNode& n1, const ServerNode& n2)
{ return n1.addr != n2.addr ? (n1.addr < n2.addr) : (n1.tag < n2.tag); }
inline bool operator==(const ServerNode& n1, const ServerNode& n2)
{ return n1.addr == n2.addr && n1.tag == n2.tag; }
inline bool operator!=(const ServerNode& n1, const ServerNode& n2)
{ return !(n1 == n2); }
inline std::ostream& operator<<(std::ostream& os, const ServerNode& n) {
os << n.addr;
if (!n.tag.empty()) {
os << "(tag=" << n.tag << ')';
}
return os;
}
} // namespace brpc } // namespace brpc
#endif // BRPC_NAMING_SERVICE_H #endif // BRPC_NAMING_SERVICE_H
...@@ -1297,14 +1297,28 @@ void ProcessHttpRequest(InputMessageBase *msg) { ...@@ -1297,14 +1297,28 @@ void ProcessHttpRequest(InputMessageBase *msg) {
} }
bool ParseHttpServerAddress(butil::EndPoint* point, const char* server_addr_and_port) { bool ParseHttpServerAddress(butil::EndPoint* point, const char* server_addr_and_port) {
std::string schema;
std::string host; std::string host;
int port = -1; int port = -1;
if (ParseHostAndPortFromURL(server_addr_and_port, &host, &port) != 0) { if (ParseURL(server_addr_and_port, &schema, &host, &port) != 0) {
LOG(ERROR) << "Invalid address=`" << server_addr_and_port << '\'';
return false;
}
if (schema.empty() || schema == "http") {
if (port < 0) {
port = 80;
}
} else if (schema == "https") {
if (port < 0) {
port = 443;
}
} else {
LOG(ERROR) << "Invalid schema=`" << schema << '\'';
return false; return false;
} }
if (str2endpoint(host.c_str(), port, point) != 0 && if (str2endpoint(host.c_str(), port, point) != 0 &&
hostname2endpoint(host.c_str(), port, point) != 0) { hostname2endpoint(host.c_str(), port, point) != 0) {
LOG(ERROR) << "Invalid address=`" << host << '\''; LOG(ERROR) << "Invalid host=" << host << " port=" << port;
return false; return false;
} }
return true; return true;
......
...@@ -145,6 +145,13 @@ ServerOptions::ServerOptions() ...@@ -145,6 +145,13 @@ ServerOptions::ServerOptions()
} }
} }
ServerSSLOptions* ServerOptions::mutable_ssl_options() {
if (!_ssl_options) {
_ssl_options.reset(new ServerSSLOptions);
}
return _ssl_options.get();
}
Server::MethodProperty::OpaqueParams::OpaqueParams() Server::MethodProperty::OpaqueParams::OpaqueParams()
: is_tabbed(false) : is_tabbed(false)
, allow_http_body_to_pb(true) , allow_http_body_to_pb(true)
...@@ -840,8 +847,8 @@ int Server::StartInternal(const butil::ip_t& ip, ...@@ -840,8 +847,8 @@ int Server::StartInternal(const butil::ip_t& ip,
// Free last SSL contexts // Free last SSL contexts
FreeSSLContexts(); FreeSSLContexts();
if (_options.ssl_options) { if (_options.has_ssl_options()) {
CertInfo& default_cert = _options.ssl_options->default_cert; CertInfo& default_cert = _options.mutable_ssl_options()->default_cert;
if (default_cert.certificate.empty()) { if (default_cert.certificate.empty()) {
LOG(ERROR) << "default_cert is empty"; LOG(ERROR) << "default_cert is empty";
return -1; return -1;
...@@ -851,7 +858,7 @@ int Server::StartInternal(const butil::ip_t& ip, ...@@ -851,7 +858,7 @@ int Server::StartInternal(const butil::ip_t& ip,
} }
_default_ssl_ctx = _ssl_ctx_map.begin()->second.ctx; _default_ssl_ctx = _ssl_ctx_map.begin()->second.ctx;
const std::vector<CertInfo>& certs = _options.ssl_options->certs; const std::vector<CertInfo>& certs = _options.mutable_ssl_options()->certs;
for (size_t i = 0; i < certs.size(); ++i) { for (size_t i = 0; i < certs.size(); ++i) {
if (AddCertificate(certs[i]) != 0) { if (AddCertificate(certs[i]) != 0) {
return -1; return -1;
...@@ -1795,7 +1802,7 @@ Server::FindServicePropertyByName(const butil::StringPiece& name) const { ...@@ -1795,7 +1802,7 @@ Server::FindServicePropertyByName(const butil::StringPiece& name) const {
} }
int Server::AddCertificate(const CertInfo& cert) { int Server::AddCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) { if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet"; LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1; return -1;
} }
...@@ -1810,7 +1817,7 @@ int Server::AddCertificate(const CertInfo& cert) { ...@@ -1810,7 +1817,7 @@ int Server::AddCertificate(const CertInfo& cert) {
ssl_ctx.filters = cert.sni_filters; ssl_ctx.filters = cert.sni_filters;
ssl_ctx.ctx = std::make_shared<SocketSSLContext>(); ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
SSL_CTX* raw_ctx = CreateServerSSLContext(cert.certificate, cert.private_key, SSL_CTX* raw_ctx = CreateServerSSLContext(cert.certificate, cert.private_key,
*_options.ssl_options, &ssl_ctx.filters); _options.ssl_options(), &ssl_ctx.filters);
if (raw_ctx == NULL) { if (raw_ctx == NULL) {
return -1; return -1;
} }
...@@ -1860,7 +1867,7 @@ bool Server::AddCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) { ...@@ -1860,7 +1867,7 @@ bool Server::AddCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
} }
int Server::RemoveCertificate(const CertInfo& cert) { int Server::RemoveCertificate(const CertInfo& cert) {
if (_options.ssl_options == NULL) { if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet"; LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1; return -1;
} }
...@@ -1905,7 +1912,7 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) { ...@@ -1905,7 +1912,7 @@ bool Server::RemoveCertMapping(CertMaps& bg, const SSLContext& ssl_ctx) {
} }
int Server::ResetCertificates(const std::vector<CertInfo>& certs) { int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
if (_options.ssl_options == NULL) { if (!_options.has_ssl_options()) {
LOG(ERROR) << "ServerOptions.ssl_options is not configured yet"; LOG(ERROR) << "ServerOptions.ssl_options is not configured yet";
return -1; return -1;
} }
...@@ -1918,8 +1925,8 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) { ...@@ -1918,8 +1925,8 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
// Add default certficiate into tmp_map first since it can't be reloaded // Add default certficiate into tmp_map first since it can't be reloaded
std::string default_cert_key = std::string default_cert_key =
_options.ssl_options->default_cert.certificate _options.ssl_options().default_cert.certificate
+ _options.ssl_options->default_cert.private_key; + _options.ssl_options().default_cert.private_key;
tmp_map[default_cert_key] = _ssl_ctx_map[default_cert_key]; tmp_map[default_cert_key] = _ssl_ctx_map[default_cert_key];
for (size_t i = 0; i < certs.size(); ++i) { for (size_t i = 0; i < certs.size(); ++i) {
...@@ -1935,7 +1942,7 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) { ...@@ -1935,7 +1942,7 @@ int Server::ResetCertificates(const std::vector<CertInfo>& certs) {
ssl_ctx.ctx = std::make_shared<SocketSSLContext>(); ssl_ctx.ctx = std::make_shared<SocketSSLContext>();
ssl_ctx.ctx->raw_ctx = CreateServerSSLContext( ssl_ctx.ctx->raw_ctx = CreateServerSSLContext(
certs[i].certificate, certs[i].private_key, certs[i].certificate, certs[i].private_key,
*_options.ssl_options, &ssl_ctx.filters); _options.ssl_options(), &ssl_ctx.filters);
if (ssl_ctx.ctx->raw_ctx == NULL) { if (ssl_ctx.ctx->raw_ctx == NULL) {
return -1; return -1;
} }
...@@ -2086,7 +2093,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl, ...@@ -2086,7 +2093,7 @@ int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server) { int* al, Server* server) {
(void)al; (void)al;
const char* hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); const char* hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
bool strict_sni = server->_options.ssl_options->strict_sni; bool strict_sni = server->_options.ssl_options().strict_sni;
if (hostname == NULL) { if (hostname == NULL) {
return strict_sni ? SSL_TLSEXT_ERR_ALERT_FATAL : SSL_TLSEXT_ERR_NOACK; return strict_sni ? SSL_TLSEXT_ERR_ALERT_FATAL : SSL_TLSEXT_ERR_NOACK;
} }
......
...@@ -29,8 +29,9 @@ ...@@ -29,8 +29,9 @@
#include "butil/containers/doubly_buffered_data.h" // DoublyBufferedData #include "butil/containers/doubly_buffered_data.h" // DoublyBufferedData
#include "bvar/bvar.h" #include "bvar/bvar.h"
#include "butil/containers/case_ignored_flat_map.h" // [CaseIgnored]FlatMap #include "butil/containers/case_ignored_flat_map.h" // [CaseIgnored]FlatMap
#include "butil/ptr_container.h"
#include "brpc/controller.h" // brpc::Controller #include "brpc/controller.h" // brpc::Controller
#include "brpc/ssl_option.h" // ServerSSLOptions #include "brpc/ssl_options.h" // ServerSSLOptions
#include "brpc/describable.h" // User often needs this #include "brpc/describable.h" // User often needs this
#include "brpc/data_factory.h" // DataFactory #include "brpc/data_factory.h" // DataFactory
#include "brpc/builtin/tabbed.h" #include "brpc/builtin/tabbed.h"
...@@ -199,7 +200,9 @@ struct ServerOptions { ...@@ -199,7 +200,9 @@ struct ServerOptions {
bool security_mode() const { return internal_port >= 0 || !has_builtin_services; } bool security_mode() const { return internal_port >= 0 || !has_builtin_services; }
// SSL related options. Refer to `ServerSSLOptions' for details // SSL related options. Refer to `ServerSSLOptions' for details
std::shared_ptr<ServerSSLOptions> ssl_options; bool has_ssl_options() const { return _ssl_options != NULL; }
const ServerSSLOptions& ssl_options() const { return *_ssl_options.get(); }
ServerSSLOptions* mutable_ssl_options();
// [CAUTION] This option is for implementing specialized http proxies, // [CAUTION] This option is for implementing specialized http proxies,
// most users don't need it. Don't change this option unless you fully // most users don't need it. Don't change this option unless you fully
...@@ -225,6 +228,11 @@ struct ServerOptions { ...@@ -225,6 +228,11 @@ struct ServerOptions {
// All names inside must be valid, check protocols name in global.cpp // All names inside must be valid, check protocols name in global.cpp
// Default: empty (all protocols) // Default: empty (all protocols)
std::string enabled_protocols; std::string enabled_protocols;
private:
// SSLOptions is large and not often used, allocate it on heap to
// prevent ServerOptions from being bloated in most cases.
butil::PtrContainer<ServerSSLOptions> _ssl_options;
}; };
// This struct is originally designed to contain basic statistics of the // This struct is originally designed to contain basic statistics of the
......
...@@ -229,7 +229,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id, ...@@ -229,7 +229,7 @@ int SocketMap::Insert(const SocketMapKey& key, SocketId* id,
} }
SocketId tmp_id; SocketId tmp_id;
SocketOptions opt; SocketOptions opt;
opt.remote_side = key.peer; opt.remote_side = key.peer.addr;
opt.initial_ssl_ctx = ssl_ctx; opt.initial_ssl_ctx = ssl_ctx;
if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) { if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) {
PLOG(FATAL) << "Fail to create socket to " << key.peer; PLOG(FATAL) << "Fail to create socket to " << key.peer;
......
...@@ -23,12 +23,11 @@ ...@@ -23,12 +23,11 @@
#include "brpc/socket_id.h" // SockdetId #include "brpc/socket_id.h" // SockdetId
#include "brpc/options.pb.h" // ProtocolType #include "brpc/options.pb.h" // ProtocolType
#include "brpc/input_messenger.h" // InputMessageHandler #include "brpc/input_messenger.h" // InputMessageHandler
#include "brpc/server_node.h" // ServerNode
namespace brpc { namespace brpc {
// Global mapping from remote-side to out-going sockets created by Channels. // Different signature means that the Channel needs separate sockets.
struct ChannelSignature { struct ChannelSignature {
uint64_t data[2]; uint64_t data[2];
...@@ -52,8 +51,11 @@ struct SocketMapKey { ...@@ -52,8 +51,11 @@ struct SocketMapKey {
SocketMapKey(const butil::EndPoint& pt, const ChannelSignature& cs) SocketMapKey(const butil::EndPoint& pt, const ChannelSignature& cs)
: peer(pt), channel_signature(cs) : peer(pt), channel_signature(cs)
{} {}
SocketMapKey(const ServerNode& sn, const ChannelSignature& cs)
: peer(sn), channel_signature(cs)
{}
butil::EndPoint peer; ServerNode peer;
ChannelSignature channel_signature; ChannelSignature channel_signature;
}; };
...@@ -62,9 +64,11 @@ inline bool operator==(const SocketMapKey& k1, const SocketMapKey& k2) { ...@@ -62,9 +64,11 @@ inline bool operator==(const SocketMapKey& k1, const SocketMapKey& k2) {
}; };
struct SocketMapKeyHasher { struct SocketMapKeyHasher {
std::size_t operator()(const SocketMapKey& key) const { size_t operator()(const SocketMapKey& key) const {
return butil::DefaultHasher<butil::EndPoint>()(key.peer) ^ size_t h = butil::DefaultHasher<butil::EndPoint>()(key.peer.addr);
key.channel_signature.data[1]; h = h * 101 + butil::DefaultHasher<std::string>()(key.peer.tag);
h = h * 101 + key.channel_signature.data[1];
return h;
} }
}; };
......
...@@ -224,8 +224,8 @@ int URI::SetHttpURL(const char* url) { ...@@ -224,8 +224,8 @@ int URI::SetHttpURL(const char* url) {
return 0; return 0;
} }
int ParseHostAndPortFromURL(const char* url, std::string* host_out, int ParseURL(const char* url,
int* port_out) { std::string* schema_out, std::string* host_out, int* port_out) {
const char* p = url; const char* p = url;
// skip heading blanks // skip heading blanks
if (*p == ' ') { if (*p == ' ') {
...@@ -235,7 +235,6 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out, ...@@ -235,7 +235,6 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
// Find end of host, locate schema and user_info during the searching // Find end of host, locate schema and user_info during the searching
bool need_schema = true; bool need_schema = true;
bool need_user_info = true; bool need_user_info = true;
butil::StringPiece schema;
for (; true; ++p) { for (; true; ++p) {
const char action = g_url_parsing_fast_action_map[(int)*p]; const char action = g_url_parsing_fast_action_map[(int)*p];
if (action == URI_PARSE_CONTINUE) { if (action == URI_PARSE_CONTINUE) {
...@@ -247,7 +246,9 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out, ...@@ -247,7 +246,9 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
if (*p == ':') { if (*p == ':') {
if (p[1] == '/' && p[2] == '/' && need_schema) { if (p[1] == '/' && p[2] == '/' && need_schema) {
need_schema = false; need_schema = false;
schema.set(start, p - start); if (schema_out) {
schema_out->assign(start, p - start);
}
p += 2; p += 2;
start = p + 1; start = p + 1;
} }
...@@ -266,15 +267,12 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out, ...@@ -266,15 +267,12 @@ int ParseHostAndPortFromURL(const char* url, std::string* host_out,
} }
int port = -1; int port = -1;
const char* host_end = SplitHostAndPort(start, p, &port); const char* host_end = SplitHostAndPort(start, p, &port);
if (port < 0) { if (host_out) {
if (schema.empty() || schema == "http") {
port = 80;
} else if (schema == "https") {
port = 443;
}
}
host_out->assign(start, host_end - start); host_out->assign(start, host_end - start);
}
if (port_out) {
*port_out = port; *port_out = port;
}
return 0; return 0;
} }
......
...@@ -155,10 +155,9 @@ friend class HttpMessage; ...@@ -155,10 +155,9 @@ friend class HttpMessage;
mutable QueryMap _query_map; mutable QueryMap _query_map;
}; };
// Parse host and port from `url'. // Parse host/port/schema from `url' if the corresponding parameter is not NULL.
// When port is absent, it's set to 80 for http and 443 for https.
// Returns 0 on success, -1 otherwise. // Returns 0 on success, -1 otherwise.
int ParseHostAndPortFromURL(const char* url, std::string* host, int* port); int ParseURL(const char* url, std::string* schema, std::string* host, int* port);
inline void URI::SetQuery(const std::string& key, const std::string& value) { inline void URI::SetQuery(const std::string& key, const std::string& value) {
get_query_map()[key] = value; get_query_map()[key] = value;
......
...@@ -141,6 +141,7 @@ class MyEchoService : public ::test::EchoService { ...@@ -141,6 +141,7 @@ class MyEchoService : public ::test::EchoService {
if (req->code() != 0) { if (req->code() != 0) {
res->add_code_list(req->code()); res->add_code_list(req->code());
} }
res->set_receiving_socket_id(cntl->_current_call.sending_sock->id());
} }
}; };
...@@ -258,14 +259,17 @@ protected: ...@@ -258,14 +259,17 @@ protected:
} }
void SetUpChannel(brpc::Channel* channel, void SetUpChannel(brpc::Channel* channel,
bool single_server, bool short_connection, bool single_server,
const brpc::Authenticator* auth = NULL) { bool short_connection,
const brpc::Authenticator* auth = NULL,
std::string connection_group = std::string()) {
brpc::ChannelOptions opt; brpc::ChannelOptions opt;
if (short_connection) { if (short_connection) {
opt.connection_type = brpc::CONNECTION_TYPE_SHORT; opt.connection_type = brpc::CONNECTION_TYPE_SHORT;
} }
opt.auth = auth; opt.auth = auth;
opt.max_retry = 0; opt.max_retry = 0;
opt.connection_group = connection_group;
if (single_server) { if (single_server) {
EXPECT_EQ(0, channel->Init(_ep, &opt)); EXPECT_EQ(0, channel->Init(_ep, &opt));
} else { } else {
...@@ -405,6 +409,7 @@ protected: ...@@ -405,6 +409,7 @@ protected:
EXPECT_EQ(0, cntl.ErrorCode()) EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection; << single_server << ", " << async << ", " << short_connection;
const uint64_t receiving_socket_id = res.receiving_socket_id();
EXPECT_EQ(0, cntl.sub_count()); EXPECT_EQ(0, cntl.sub_count());
EXPECT_TRUE(NULL == cntl.sub(-1)); EXPECT_TRUE(NULL == cntl.sub(-1));
EXPECT_TRUE(NULL == cntl.sub(0)); EXPECT_TRUE(NULL == cntl.sub(0));
...@@ -420,6 +425,47 @@ protected: ...@@ -420,6 +425,47 @@ protected:
} else { } else {
EXPECT_GE(1ul, _messenger.ConnectionCount()); EXPECT_GE(1ul, _messenger.ConnectionCount());
} }
if (single_server && !short_connection) {
// Reuse the connection
brpc::Channel channel2;
SetUpChannel(&channel2, single_server, short_connection);
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel2, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
EXPECT_EQ(receiving_socket_id, res.receiving_socket_id());
// A different connection_group does not reuse the connection
brpc::Channel channel3;
SetUpChannel(&channel3, single_server, short_connection,
NULL, "another_group");
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel3, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
const uint64_t receiving_socket_id2 = res.receiving_socket_id();
EXPECT_NE(receiving_socket_id, receiving_socket_id2);
// Channel in the same connection_group reuses the connection
// note that the leading/trailing spaces should be trimed.
brpc::Channel channel4;
SetUpChannel(&channel4, single_server, short_connection,
NULL, " another_group ");
cntl.Reset();
req.Clear();
res.Clear();
req.set_message(__FUNCTION__);
CallMethod(&channel4, &cntl, &req, &res, async);
EXPECT_EQ(0, cntl.ErrorCode())
<< single_server << ", " << async << ", " << short_connection;
EXPECT_EQ(receiving_socket_id2, res.receiving_socket_id());
}
StopAndJoin(); StopAndJoin();
} }
...@@ -1547,6 +1593,10 @@ protected: ...@@ -1547,6 +1593,10 @@ protected:
void TestAuthentication(bool single_server, void TestAuthentication(bool single_server,
bool async, bool short_connection) { bool async, bool short_connection) {
std::cout << " *** single=" << single_server
<< " async=" << async
<< " short=" << short_connection << std::endl;
ASSERT_EQ(0, StartAccept(_ep)); ASSERT_EQ(0, StartAccept(_ep));
MyAuthenticator auth; MyAuthenticator auth;
brpc::Channel channel; brpc::Channel channel;
...@@ -1809,7 +1859,7 @@ TEST_F(ChannelTest, init_as_single_server) { ...@@ -1809,7 +1859,7 @@ TEST_F(ChannelTest, init_as_single_server) {
ASSERT_EQ(ep, channel._server_address); ASSERT_EQ(ep, channel._server_address);
brpc::SocketId id; brpc::SocketId id;
ASSERT_EQ(0, brpc::SocketMapFind(ep, &id)); ASSERT_EQ(0, brpc::SocketMapFind(brpc::SocketMapKey(ep), &id));
ASSERT_EQ(id, channel._server_id); ASSERT_EQ(id, channel._server_id);
const int NUM = 10; const int NUM = 10;
......
...@@ -127,7 +127,7 @@ TEST_F(SocketMapTest, max_pool_size) { ...@@ -127,7 +127,7 @@ TEST_F(SocketMapTest, max_pool_size) {
} //namespace } //namespace
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
butil::str2endpoint("127.0.0.1:12345", &g_key.peer); butil::str2endpoint("127.0.0.1:12345", &g_key.peer.addr);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -95,7 +95,7 @@ TEST_F(SSLTest, sanity) { ...@@ -95,7 +95,7 @@ TEST_F(SSLTest, sanity) {
brpc::CertInfo cert; brpc::CertInfo cert;
cert.certificate = "cert1.crt"; cert.certificate = "cert1.crt";
cert.private_key = "cert1.key"; cert.private_key = "cert1.key";
options.ssl_options.default_cert = cert; options.mutable_ssl_options()->default_cert = cert;
EchoServiceImpl echo_svc; EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService( ASSERT_EQ(0, server.AddService(
...@@ -108,7 +108,7 @@ TEST_F(SSLTest, sanity) { ...@@ -108,7 +108,7 @@ TEST_F(SSLTest, sanity) {
{ {
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions coptions; brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true; coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("localhost", port, &coptions)); ASSERT_EQ(0, channel.Init("localhost", port, &coptions));
brpc::Controller cntl; brpc::Controller cntl;
...@@ -124,7 +124,7 @@ TEST_F(SSLTest, sanity) { ...@@ -124,7 +124,7 @@ TEST_F(SSLTest, sanity) {
{ {
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions coptions; brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true; coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
for (int i = 0; i < NUM; ++i) { for (int i = 0; i < NUM; ++i) {
google::protobuf::Closure* thrd_func = google::protobuf::Closure* thrd_func =
...@@ -140,7 +140,7 @@ TEST_F(SSLTest, sanity) { ...@@ -140,7 +140,7 @@ TEST_F(SSLTest, sanity) {
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions coptions; brpc::ChannelOptions coptions;
coptions.protocol = "http"; coptions.protocol = "http";
coptions.ssl_options.enable = true; coptions.mutable_ssl_options();
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
for (int i = 0; i < NUM; ++i) { for (int i = 0; i < NUM; ++i) {
google::protobuf::Closure* thrd_func = google::protobuf::Closure* thrd_func =
...@@ -160,8 +160,7 @@ void CheckCert(const char* cname, const char* cert) { ...@@ -160,8 +160,7 @@ void CheckCert(const char* cname, const char* cert) {
const int port = 8613; const int port = 8613;
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions coptions; brpc::ChannelOptions coptions;
coptions.ssl_options.enable = true; coptions.mutable_ssl_options()->sni_name = cname;
coptions.ssl_options.sni_name = cname;
ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions));
SendMultipleRPC(&channel, 1); SendMultipleRPC(&channel, 1);
...@@ -199,14 +198,14 @@ TEST_F(SSLTest, ssl_sni) { ...@@ -199,14 +198,14 @@ TEST_F(SSLTest, ssl_sni) {
cert.certificate = "cert1.crt"; cert.certificate = "cert1.crt";
cert.private_key = "cert1.key"; cert.private_key = "cert1.key";
cert.sni_filters.push_back("cert1.com"); cert.sni_filters.push_back("cert1.com");
options.ssl_options.default_cert = cert; options.mutable_ssl_options()->default_cert = cert;
} }
{ {
brpc::CertInfo cert; brpc::CertInfo cert;
cert.certificate = GetRawPemString("cert2.crt"); cert.certificate = GetRawPemString("cert2.crt");
cert.private_key = GetRawPemString("cert2.key"); cert.private_key = GetRawPemString("cert2.key");
cert.sni_filters.push_back("*.cert2.com"); cert.sni_filters.push_back("*.cert2.com");
options.ssl_options.certs.push_back(cert); options.mutable_ssl_options()->certs.push_back(cert);
} }
EchoServiceImpl echo_svc; EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService( ASSERT_EQ(0, server.AddService(
...@@ -230,7 +229,7 @@ TEST_F(SSLTest, ssl_reload) { ...@@ -230,7 +229,7 @@ TEST_F(SSLTest, ssl_reload) {
cert.certificate = "cert1.crt"; cert.certificate = "cert1.crt";
cert.private_key = "cert1.key"; cert.private_key = "cert1.key";
cert.sni_filters.push_back("cert1.com"); cert.sni_filters.push_back("cert1.com");
options.ssl_options.default_cert = cert; options.mutable_ssl_options()->default_cert = cert;
} }
EchoServiceImpl echo_svc; EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService( ASSERT_EQ(0, server.AddService(
...@@ -318,7 +317,6 @@ TEST_F(SSLTest, ssl_perf) { ...@@ -318,7 +317,6 @@ TEST_F(SSLTest, ssl_perf) {
ASSERT_GT(servfd, 0); ASSERT_GT(servfd, 0);
brpc::ChannelSSLOptions opt; brpc::ChannelSSLOptions opt;
opt.enable = true;
SSL_CTX* cli_ctx = brpc::CreateClientSSLContext(opt); SSL_CTX* cli_ctx = brpc::CreateClientSSLContext(opt);
SSL_CTX* serv_ctx = SSL_CTX* serv_ctx =
brpc::CreateServerSSLContext("cert1.crt", "cert1.key", brpc::CreateServerSSLContext("cert1.crt", "cert1.key",
......
...@@ -20,9 +20,11 @@ TEST(URITest, everything) { ...@@ -20,9 +20,11 @@ TEST(URITest, everything) {
ASSERT_EQ(*uri.GetQuery("wd"), "uri"); ASSERT_EQ(*uri.GetQuery("wd"), "uri");
ASSERT_FALSE(uri.GetQuery("nonkey")); ASSERT_FALSE(uri.GetQuery("nonkey"));
std::string schema;
std::string host_out; std::string host_out;
int port_out = -1; int port_out = -1;
brpc::ParseHostAndPortFromURL(uri_str.c_str(), &host_out, &port_out); brpc::ParseURL(uri_str.c_str(), &schema, &host_out, &port_out);
ASSERT_EQ("foobar", schema);
ASSERT_EQ("www.baidu.com", host_out); ASSERT_EQ("www.baidu.com", host_out);
ASSERT_EQ(80, port_out); ASSERT_EQ(80, port_out);
} }
......
...@@ -15,6 +15,7 @@ message EchoRequest { ...@@ -15,6 +15,7 @@ message EchoRequest {
message EchoResponse { message EchoResponse {
required string message = 1; required string message = 1;
repeated int32 code_list = 2; repeated int32 code_list = 2;
optional uint64 receiving_socket_id = 3;
}; };
message ComboRequest { message ComboRequest {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment