Commit a9e95487 authored by gejun's avatar gejun

Remove Socket._options and make app_connect shared

parent 79dcf456
...@@ -1019,8 +1019,8 @@ private: ...@@ -1019,8 +1019,8 @@ private:
class RtmpConnect : public AppConnect { class RtmpConnect : public AppConnect {
public: public:
// @AppConnect // @AppConnect
void StartConnect(const Socket* s, void (*done)(int, void*), void* data); void StartConnect(const Socket* s, void (*done)(int, void*), void* data) override;
void StopConnect(Socket* s); void StopConnect(Socket* s) override;
}; };
void RtmpConnect::StartConnect( void RtmpConnect::StartConnect(
...@@ -1067,7 +1067,6 @@ void RtmpConnect::StopConnect(Socket* s) { ...@@ -1067,7 +1067,6 @@ void RtmpConnect::StopConnect(Socket* s) {
} else { } else {
ctx->OnConnected(EFAILEDSOCKET); ctx->OnConnected(EFAILEDSOCKET);
} }
delete this;
} }
class RtmpSocketCreator : public SocketCreator { class RtmpSocketCreator : public SocketCreator {
...@@ -1078,7 +1077,7 @@ public: ...@@ -1078,7 +1077,7 @@ public:
int CreateSocket(const SocketOptions& opt, SocketId* id) { int CreateSocket(const SocketOptions& opt, SocketId* id) {
SocketOptions sock_opt = opt; SocketOptions sock_opt = opt;
sock_opt.app_connect = new RtmpConnect; sock_opt.app_connect = std::make_shared<RtmpConnect>();
sock_opt.initial_parsing_context = new policy::RtmpContext(&_connect_options, NULL); sock_opt.initial_parsing_context = new policy::RtmpContext(&_connect_options, NULL);
return get_client_side_messenger()->Create(sock_opt, id); return get_client_side_messenger()->Create(sock_opt, id);
} }
...@@ -1090,7 +1089,7 @@ private: ...@@ -1090,7 +1089,7 @@ private:
int RtmpClientImpl::CreateSocket(const butil::EndPoint& pt, SocketId* id) { int RtmpClientImpl::CreateSocket(const butil::EndPoint& pt, SocketId* id) {
SocketOptions sock_opt; SocketOptions sock_opt;
sock_opt.remote_side = pt; sock_opt.remote_side = pt;
sock_opt.app_connect = new RtmpConnect; sock_opt.app_connect = std::make_shared<RtmpConnect>();
sock_opt.initial_parsing_context = new policy::RtmpContext(&_connect_options, NULL); sock_opt.initial_parsing_context = new policy::RtmpContext(&_connect_options, NULL);
return get_client_side_messenger()->Create(sock_opt, id); return get_client_side_messenger()->Create(sock_opt, id);
} }
......
...@@ -439,7 +439,6 @@ Socket::Socket(Forbidden) ...@@ -439,7 +439,6 @@ Socket::Socket(Forbidden)
, _on_edge_triggered_events(NULL) , _on_edge_triggered_events(NULL)
, _user(NULL) , _user(NULL)
, _conn(NULL) , _conn(NULL)
, _app_connect(NULL)
, _this_id(0) , _this_id(0)
, _preferred_index(-1) , _preferred_index(-1)
, _hc_count(0) , _hc_count(0)
...@@ -639,6 +638,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { ...@@ -639,6 +638,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
// Disable SSL check if there is no SSL context // Disable SSL check if there is no SSL context
m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
m->_ssl_session = NULL; m->_ssl_session = NULL;
m->_ssl_ctx = options.initial_ssl_ctx;
m->_connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN; m->_connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN;
m->_controller_released_socket.store(false, butil::memory_order_relaxed); m->_controller_released_socket.store(false, butil::memory_order_relaxed);
m->_overcrowded = false; m->_overcrowded = false;
...@@ -657,7 +657,6 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { ...@@ -657,7 +657,6 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
} }
m->_last_writetime_us.store(cpuwide_now, butil::memory_order_relaxed); m->_last_writetime_us.store(cpuwide_now, butil::memory_order_relaxed);
m->_unwritten_bytes.store(0, butil::memory_order_relaxed); m->_unwritten_bytes.store(0, butil::memory_order_relaxed);
m->_options = options;
CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed)); CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed));
// Must be last one! Internal fields of this Socket may be access // Must be last one! Internal fields of this Socket may be access
// just after calling ResetFileDescriptor. // just after calling ResetFileDescriptor.
...@@ -1010,9 +1009,9 @@ bool HealthCheckTask::OnTriggeringTask(timespec* next_abstime) { ...@@ -1010,9 +1009,9 @@ bool HealthCheckTask::OnTriggeringTask(timespec* next_abstime) {
void Socket::OnRecycle() { void Socket::OnRecycle() {
const bool create_by_connect = CreatedByConnect(); const bool create_by_connect = CreatedByConnect();
if (_app_connect) { if (_app_connect) {
AppConnect* const saved_app_connect = _app_connect; std::shared_ptr<AppConnect> tmp;
_app_connect = NULL; _app_connect.swap(tmp);
saved_app_connect->StopConnect(this); tmp->StopConnect(this);
} }
if (_conn) { if (_conn) {
SocketConnection* const saved_conn = _conn; SocketConnection* const saved_conn = _conn;
...@@ -1051,7 +1050,7 @@ void Socket::OnRecycle() { ...@@ -1051,7 +1050,7 @@ void Socket::OnRecycle() {
_ssl_session = NULL; _ssl_session = NULL;
} }
_options.initial_ssl_ctx = NULL; _ssl_ctx = NULL;
delete _pipeline_q; delete _pipeline_q;
_pipeline_q = NULL; _pipeline_q = NULL;
...@@ -1163,7 +1162,7 @@ int Socket::WaitEpollOut(int fd, bool pollin, const timespec* abstime) { ...@@ -1163,7 +1162,7 @@ int Socket::WaitEpollOut(int fd, bool pollin, const timespec* abstime) {
int Socket::Connect(const timespec* abstime, int Socket::Connect(const timespec* abstime,
int (*on_connect)(int, int, void*), void* data) { int (*on_connect)(int, int, void*), void* data) {
if (_options.initial_ssl_ctx) { if (_ssl_ctx) {
_ssl_state = SSL_CONNECTING; _ssl_state = SSL_CONNECTING;
} else { } else {
_ssl_state = SSL_OFF; _ssl_state = SSL_OFF;
...@@ -1781,7 +1780,7 @@ ssize_t Socket::DoWrite(WriteRequest* req) { ...@@ -1781,7 +1780,7 @@ ssize_t Socket::DoWrite(WriteRequest* req) {
} }
int Socket::SSLHandshake(int fd, bool server_mode) { int Socket::SSLHandshake(int fd, bool server_mode) {
if (_options.initial_ssl_ctx == NULL) { if (_ssl_ctx == NULL) {
if (server_mode) { if (server_mode) {
LOG(ERROR) << "Lack SSL configuration to handle SSL request"; LOG(ERROR) << "Lack SSL configuration to handle SSL request";
return -1; return -1;
...@@ -1794,14 +1793,14 @@ int Socket::SSLHandshake(int fd, bool server_mode) { ...@@ -1794,14 +1793,14 @@ int Socket::SSLHandshake(int fd, bool server_mode) {
// Free the last session, which may be deprecated when socket failed // Free the last session, which may be deprecated when socket failed
SSL_free(_ssl_session); SSL_free(_ssl_session);
} }
_ssl_session = CreateSSLSession(_options.initial_ssl_ctx->raw_ctx, id(), fd, server_mode); _ssl_session = CreateSSLSession(_ssl_ctx->raw_ctx, id(), fd, server_mode);
if (_ssl_session == NULL) { if (_ssl_session == NULL) {
LOG(ERROR) << "Fail to CreateSSLSession"; LOG(ERROR) << "Fail to CreateSSLSession";
return -1; return -1;
} }
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
if (!_options.initial_ssl_ctx->sni_name.empty()) { if (!_ssl_ctx->sni_name.empty()) {
SSL_set_tlsext_host_name(_ssl_session, _options.initial_ssl_ctx->sni_name.c_str()); SSL_set_tlsext_host_name(_ssl_session, _ssl_ctx->sni_name.c_str());
} }
#endif #endif
...@@ -2174,7 +2173,7 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) { ...@@ -2174,7 +2173,7 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) {
<< "\ncid=" << ptr->_correlation_id << "\ncid=" << ptr->_correlation_id
<< "\nwrite_head=" << ptr->_write_head.load(butil::memory_order_relaxed) << "\nwrite_head=" << ptr->_write_head.load(butil::memory_order_relaxed)
<< "\nssl_state=" << SSLStateToString(ssl_state); << "\nssl_state=" << SSLStateToString(ssl_state);
const SocketSSLContext* ssl_ctx = ptr->_options.initial_ssl_ctx.get(); const SocketSSLContext* ssl_ctx = ptr->_ssl_ctx.get();
if (ssl_ctx) { if (ssl_ctx) {
os << "\ninitial_ssl_ctx=" << ssl_ctx->raw_ctx; os << "\ninitial_ssl_ctx=" << ssl_ctx->raw_ctx;
if (!ssl_ctx->sni_name.empty()) { if (!ssl_ctx->sni_name.empty()) {
...@@ -2461,7 +2460,14 @@ int Socket::GetPooledSocket(Socket* main_socket, ...@@ -2461,7 +2460,14 @@ int Socket::GetPooledSocket(Socket* main_socket,
// Create socket_pool optimistically. // Create socket_pool optimistically.
SocketPool* socket_pool = main_sp->socket_pool.load(butil::memory_order_consume); SocketPool* socket_pool = main_sp->socket_pool.load(butil::memory_order_consume);
if (socket_pool == NULL) { if (socket_pool == NULL) {
socket_pool = new SocketPool(main_socket->_options); SocketOptions opt;
opt.remote_side = main_socket->remote_side();
opt.user = main_socket->user();
opt.on_edge_triggered_events = main_socket->_on_edge_triggered_events;
opt.initial_ssl_ctx = main_socket->_ssl_ctx;
opt.keytable_pool = main_socket->_keytable_pool;
opt.app_connect = main_socket->_app_connect;
socket_pool = new SocketPool(opt);
SocketPool* expected = NULL; SocketPool* expected = NULL;
if (!main_sp->socket_pool.compare_exchange_strong( if (!main_sp->socket_pool.compare_exchange_strong(
expected, socket_pool, butil::memory_order_acq_rel)) { expected, socket_pool, butil::memory_order_acq_rel)) {
...@@ -2543,8 +2549,13 @@ int Socket::GetShortSocket(Socket* main_socket, ...@@ -2543,8 +2549,13 @@ int Socket::GetShortSocket(Socket* main_socket,
return -1; return -1;
} }
SocketId id; SocketId id;
SocketOptions opt = main_socket->_options; SocketOptions opt;
opt.health_check_interval_s = -1; opt.remote_side = main_socket->remote_side();
opt.user = main_socket->user();
opt.on_edge_triggered_events = main_socket->_on_edge_triggered_events;
opt.initial_ssl_ctx = main_socket->_ssl_ctx;
opt.keytable_pool = main_socket->_keytable_pool;
opt.app_connect = main_socket->_app_connect;
if (get_client_side_messenger()->Create(opt, &id) != 0) { if (get_client_side_messenger()->Create(opt, &id) != 0) {
return -1; return -1;
} }
......
...@@ -92,6 +92,8 @@ public: ...@@ -92,6 +92,8 @@ public:
// Application-level connect. After TCP connected, the client sends some // Application-level connect. After TCP connected, the client sends some
// sort of "connect" message to the server to establish application-level // sort of "connect" message to the server to establish application-level
// connection. // connection.
// Instances of AppConnect may be shared by multiple sockets and often
// created by std::make_shared<T>() where T implements AppConnect
class AppConnect { class AppConnect {
public: public:
virtual ~AppConnect() {} virtual ~AppConnect() {}
...@@ -108,7 +110,6 @@ public: ...@@ -108,7 +110,6 @@ public:
// Called when the host socket is setfailed or about to be recycled. // Called when the host socket is setfailed or about to be recycled.
// If the AppConnect is still in-progress, it should be canceled properly. // If the AppConnect is still in-progress, it should be canceled properly.
// This callback can delete self.
virtual void StopConnect(Socket*) = 0; virtual void StopConnect(Socket*) = 0;
}; };
...@@ -165,7 +166,7 @@ struct SocketOptions { ...@@ -165,7 +166,7 @@ struct SocketOptions {
std::shared_ptr<SocketSSLContext> initial_ssl_ctx; std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
bthread_keytable_pool_t* keytable_pool; bthread_keytable_pool_t* keytable_pool;
SocketConnection* conn; SocketConnection* conn;
AppConnect* app_connect; std::shared_ptr<AppConnect> app_connect;
// The created socket will set parsing_context with this value. // The created socket will set parsing_context with this value.
Destroyable* initial_parsing_context; Destroyable* initial_parsing_context;
}; };
...@@ -267,7 +268,6 @@ public: ...@@ -267,7 +268,6 @@ public:
// `conn' parameter passed to Create() // `conn' parameter passed to Create()
void set_conn(SocketConnection* conn) { _conn = conn; } void set_conn(SocketConnection* conn) { _conn = conn; }
SocketConnection* conn() const { return _conn; } SocketConnection* conn() const { return _conn; }
AppConnect* app_connect() const { return _app_connect; }
// Saved contexts for parsing. Reset before trying new protocols and // Saved contexts for parsing. Reset before trying new protocols and
// recycling of the socket. // recycling of the socket.
...@@ -648,9 +648,6 @@ private: ...@@ -648,9 +648,6 @@ private:
// carefully before implementing the callback. // carefully before implementing the callback.
void (*_on_edge_triggered_events)(Socket*); void (*_on_edge_triggered_events)(Socket*);
// Original options used to create this Socket
SocketOptions _options;
// A set of callbacks to monitor important events of this socket. // A set of callbacks to monitor important events of this socket.
// Initialized by SocketOptions.user // Initialized by SocketOptions.user
SocketUser* _user; SocketUser* _user;
...@@ -660,7 +657,7 @@ private: ...@@ -660,7 +657,7 @@ private:
// User-level connection after TCP-connected. // User-level connection after TCP-connected.
// Initialized by SocketOptions.app_connect. // Initialized by SocketOptions.app_connect.
AppConnect* _app_connect; std::shared_ptr<AppConnect> _app_connect;
// Identifier of this Socket in ResourcePool // Identifier of this Socket in ResourcePool
SocketId _this_id; SocketId _this_id;
...@@ -718,6 +715,7 @@ private: ...@@ -718,6 +715,7 @@ private:
SSLState _ssl_state; SSLState _ssl_state;
SSL* _ssl_session; // owner SSL* _ssl_session; // owner
std::shared_ptr<SocketSSLContext> _ssl_ctx;
// Pass from controller, for progressive reading. // Pass from controller, for progressive reading.
ConnectionType _connection_type_for_progressive_read; ConnectionType _connection_type_for_progressive_read;
......
...@@ -283,7 +283,6 @@ public: ...@@ -283,7 +283,6 @@ public:
} }
void StopConnect(brpc::Socket*) { void StopConnect(brpc::Socket*) {
LOG(INFO) << "Stop application-level connect"; LOG(INFO) << "Stop application-level connect";
delete this;
} }
void MakeConnectDone() { void MakeConnectDone() {
_done(0, _data); _done(0, _data);
...@@ -313,7 +312,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) { ...@@ -313,7 +312,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) {
brpc::SocketId id = 8888; brpc::SocketId id = 8888;
brpc::SocketOptions options; brpc::SocketOptions options;
options.remote_side = point; options.remote_side = point;
MyConnect* my_connect = new MyConnect; std::shared_ptr<MyConnect> my_connect = std::make_shared<MyConnect>();
options.app_connect = my_connect; options.app_connect = my_connect;
options.user = new CheckRecycle; options.user = new CheckRecycle;
ASSERT_EQ(0, brpc::Socket::Create(options, &id)); ASSERT_EQ(0, brpc::Socket::Create(options, &id));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment