Commit 7c3496a7 authored by Richard Newton's avatar Richard Newton

Fix race condition and support multiple socket connects before bind.

parent 6fefa416
...@@ -64,6 +64,7 @@ tests/test_req_request_ids ...@@ -64,6 +64,7 @@ tests/test_req_request_ids
tests/test_req_strict tests/test_req_strict
tests/test_fork tests/test_fork
tests/test_conflate tests/test_conflate
tests/test_inproc_connect_before_bind
tests/test_linger tests/test_linger
tests/test_security_null tests/test_security_null
tests/test_security_null.opp tests/test_security_null.opp
......
...@@ -55,6 +55,7 @@ namespace zmq ...@@ -55,6 +55,7 @@ namespace zmq
term_ack, term_ack,
reap, reap,
reaped, reaped,
inproc_connected,
done done
} type; } type;
......
...@@ -396,28 +396,49 @@ void zmq::ctx_t::pend_connection (const char *addr_, const pending_connection_t ...@@ -396,28 +396,49 @@ void zmq::ctx_t::pend_connection (const char *addr_, const pending_connection_t
{ {
endpoints_sync.lock (); endpoints_sync.lock ();
// Todo, use multimap to support multiple pending connections endpoints_t::iterator it = endpoints.find (addr_);
pending_connections[addr_] = pending_connection_; if (it == endpoints.end ())
{
// Still no bind.
pending_connection_.socket->inc_seqnum ();
pending_connections.insert (pending_connections_t::value_type (std::string (addr_), pending_connection_));
}
else
{
// Bind has happened in the mean time, connect directly
pending_connection_t copy = pending_connection_;
it->second.socket->inc_seqnum();
copy.pipe->set_tid(it->second.socket->get_tid());
command_t cmd;
cmd.type = command_t::bind;
cmd.args.bind.pipe = copy.pipe;
it->second.socket->process_command(cmd);
}
endpoints_sync.unlock (); endpoints_sync.unlock ();
} }
zmq::pending_connection_t zmq::ctx_t::next_pending_connection(const char *addr_) void zmq::ctx_t::connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_)
{ {
endpoints_sync.lock (); endpoints_sync.lock ();
pending_connections_t::iterator it = pending_connections.find (addr_); std::pair<pending_connections_t::iterator, pending_connections_t::iterator> pending = pending_connections.equal_range(addr_);
if (it == pending_connections.end ()) {
endpoints_sync.unlock (); for (pending_connections_t::iterator p = pending.first; p != pending.second; ++p)
pending_connection_t empty = {NULL, NULL}; {
return empty; bind_socket_->inc_seqnum();
p->second.pipe->set_tid(bind_socket_->get_tid());
command_t cmd;
cmd.type = command_t::bind;
cmd.args.bind.pipe = p->second.pipe;
bind_socket_->process_command(cmd);
bind_socket_->send_inproc_connected(p->second.socket);
} }
pending_connection_t pending_connection = it->second;
pending_connections.erase(it); pending_connections.erase(pending.first, pending.second);
endpoints_sync.unlock (); endpoints_sync.unlock ();
return pending_connection;
} }
// The last used socket ID, or 0 if no socket was used so far. Note that this // The last used socket ID, or 0 if no socket was used so far. Note that this
......
...@@ -109,7 +109,7 @@ namespace zmq ...@@ -109,7 +109,7 @@ namespace zmq
void unregister_endpoints (zmq::socket_base_t *socket_); void unregister_endpoints (zmq::socket_base_t *socket_);
endpoint_t find_endpoint (const char *addr_); endpoint_t find_endpoint (const char *addr_);
void pend_connection (const char *addr_, const pending_connection_t &pending_connection_); void pend_connection (const char *addr_, const pending_connection_t &pending_connection_);
pending_connection_t next_pending_connection (const char *addr_); void connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_);
enum { enum {
term_tid = 0, term_tid = 0,
...@@ -166,7 +166,7 @@ namespace zmq ...@@ -166,7 +166,7 @@ namespace zmq
endpoints_t endpoints; endpoints_t endpoints;
// List of inproc connection endpoints pending a bind // List of inproc connection endpoints pending a bind
typedef std::map <std::string, pending_connection_t> pending_connections_t; typedef std::multimap <std::string, pending_connection_t> pending_connections_t;
pending_connections_t pending_connections; pending_connections_t pending_connections;
// Synchronisation of access to the list of inproc endpoints. // Synchronisation of access to the list of inproc endpoints.
......
...@@ -127,6 +127,10 @@ void zmq::object_t::process_command (command_t &cmd_) ...@@ -127,6 +127,10 @@ void zmq::object_t::process_command (command_t &cmd_)
process_reaped (); process_reaped ();
break; break;
case command_t::inproc_connected:
process_seqnum ();
break;
case command_t::done: case command_t::done:
default: default:
zmq_assert (false); zmq_assert (false);
...@@ -153,9 +157,9 @@ void zmq::object_t::pend_connection (const char *addr_, const pending_connection ...@@ -153,9 +157,9 @@ void zmq::object_t::pend_connection (const char *addr_, const pending_connection
ctx->pend_connection (addr_, pending_connection_); ctx->pend_connection (addr_, pending_connection_);
} }
zmq::pending_connection_t zmq::object_t::next_pending_connection (const char *addr_) void zmq::object_t::connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_)
{ {
return ctx->next_pending_connection(addr_); return ctx->connect_pending(addr_, bind_socket_);
} }
void zmq::object_t::destroy_socket (socket_base_t *socket_) void zmq::object_t::destroy_socket (socket_base_t *socket_)
...@@ -312,6 +316,14 @@ void zmq::object_t::send_reaped () ...@@ -312,6 +316,14 @@ void zmq::object_t::send_reaped ()
send_command (cmd); send_command (cmd);
} }
void zmq::object_t::send_inproc_connected (zmq::socket_base_t *socket_)
{
command_t cmd;
cmd.destination = socket_;
cmd.type = command_t::inproc_connected;
send_command (cmd);
}
void zmq::object_t::send_done () void zmq::object_t::send_done ()
{ {
command_t cmd; command_t cmd;
......
...@@ -51,6 +51,7 @@ namespace zmq ...@@ -51,6 +51,7 @@ namespace zmq
void set_tid(uint32_t id); void set_tid(uint32_t id);
ctx_t *get_ctx (); ctx_t *get_ctx ();
void process_command (zmq::command_t &cmd_); void process_command (zmq::command_t &cmd_);
void send_inproc_connected (zmq::socket_base_t *socket_);
protected: protected:
...@@ -60,7 +61,7 @@ namespace zmq ...@@ -60,7 +61,7 @@ namespace zmq
void unregister_endpoints (zmq::socket_base_t *socket_); void unregister_endpoints (zmq::socket_base_t *socket_);
zmq::endpoint_t find_endpoint (const char *addr_); zmq::endpoint_t find_endpoint (const char *addr_);
void pend_connection (const char *addr_, const pending_connection_t &pending_connection_); void pend_connection (const char *addr_, const pending_connection_t &pending_connection_);
zmq::pending_connection_t next_pending_connection (const char *addr_); void connect_pending (const char *addr_, zmq::socket_base_t *bind_socket_);
void destroy_socket (zmq::socket_base_t *socket_); void destroy_socket (zmq::socket_base_t *socket_);
......
...@@ -342,52 +342,8 @@ int zmq::socket_base_t::bind (const char *addr_) ...@@ -342,52 +342,8 @@ int zmq::socket_base_t::bind (const char *addr_)
endpoint_t endpoint = {this, options}; endpoint_t endpoint = {this, options};
int rc = register_endpoint (addr_, endpoint); int rc = register_endpoint (addr_, endpoint);
if (rc == 0) { if (rc == 0) {
// Save last endpoint URI connect_pending(addr_, this);
last_endpoint.assign (addr_); last_endpoint.assign (addr_);
pending_connection_t pending_connection = next_pending_connection(addr_);
while (pending_connection.pipe != NULL)
{
inc_seqnum();
//// If required, send the identity of the local socket to the peer.
//if (peer.options.recv_identity) {
// msg_t id;
// rc = id.init_size (options.identity_size);
// errno_assert (rc == 0);
// memcpy (id.data (), options.identity, options.identity_size);
// id.set_flags (msg_t::identity);
// bool written = new_pipes [0]->write (&id);
// zmq_assert (written);
// new_pipes [0]->flush ();
//}
//// If required, send the identity of the peer to the local socket.
//if (options.recv_identity) {
// msg_t id;
// rc = id.init_size (peer.options.identity_size);
// errno_assert (rc == 0);
// memcpy (id.data (), peer.options.identity, peer.options.identity_size);
// id.set_flags (msg_t::identity);
// bool written = new_pipes [1]->write (&id);
// zmq_assert (written);
// new_pipes [1]->flush ();
//}
//// Attach remote end of the pipe to the peer socket. Note that peer's
//// seqnum was incremented in find_endpoint function. We don't need it
//// increased here.
//send_bind (peer.socket, new_pipes [1], false);
pending_connection.pipe->set_tid(get_tid());
command_t cmd;
cmd.type = command_t::bind;
cmd.args.bind.pipe = pending_connection.pipe;
process_command(cmd);
pending_connection = next_pending_connection(addr_);
}
} }
return rc; return rc;
} }
......
...@@ -36,7 +36,8 @@ noinst_PROGRAMS = test_system \ ...@@ -36,7 +36,8 @@ noinst_PROGRAMS = test_system \
test_spec_pushpull \ test_spec_pushpull \
test_req_request_ids \ test_req_request_ids \
test_req_strict \ test_req_strict \
test_conflate test_conflate \
test_inproc_connect_before_bind
if !ON_MINGW if !ON_MINGW
noinst_PROGRAMS += test_shutdown_stress \ noinst_PROGRAMS += test_shutdown_stress \
...@@ -80,6 +81,7 @@ test_spec_pushpull_SOURCES = test_spec_pushpull.cpp ...@@ -80,6 +81,7 @@ test_spec_pushpull_SOURCES = test_spec_pushpull.cpp
test_req_request_ids_SOURCES = test_req_request_ids.cpp test_req_request_ids_SOURCES = test_req_request_ids.cpp
test_req_strict_SOURCES = test_req_strict.cpp test_req_strict_SOURCES = test_req_strict.cpp
test_conflate_SOURCES = test_conflate.cpp test_conflate_SOURCES = test_conflate.cpp
test_inproc_connect_before_bind_SOURCES = test_inproc_connect_before_bind.cpp
if !ON_MINGW if !ON_MINGW
test_shutdown_stress_SOURCES = test_shutdown_stress.cpp test_shutdown_stress_SOURCES = test_shutdown_stress.cpp
test_pair_ipc_SOURCES = test_pair_ipc.cpp testutil.hpp test_pair_ipc_SOURCES = test_pair_ipc.cpp testutil.hpp
......
...@@ -17,9 +17,27 @@ ...@@ -17,9 +17,27 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#include "../include/zmq_utils.h"
#include <stdio.h> #include <stdio.h>
#include "testutil.hpp" #include "testutil.hpp"
static void pusher (void *ctx)
{
// Connect first
void *connectSocket = zmq_socket (ctx, ZMQ_PAIR);
assert (connectSocket);
int rc = zmq_connect (connectSocket, "inproc://a");
assert (rc == 0);
// Queue up some data
rc = zmq_send_const (connectSocket, "foobar", 6, 0);
assert (rc == 6);
// Cleanup
rc = zmq_close (connectSocket);
assert (rc == 0);
}
void test_bind_before_connect() void test_bind_before_connect()
{ {
void *ctx = zmq_ctx_new (); void *ctx = zmq_ctx_new ();
...@@ -45,7 +63,7 @@ void test_bind_before_connect() ...@@ -45,7 +63,7 @@ void test_bind_before_connect()
zmq_msg_t msg; zmq_msg_t msg;
rc = zmq_msg_init (&msg); rc = zmq_msg_init (&msg);
assert (rc == 0); assert (rc == 0);
rc = zmq_msg_recv (&msg, bindSocket, ZMQ_NOBLOCK); rc = zmq_msg_recv (&msg, bindSocket, 0);
assert (rc == 6); assert (rc == 6);
void *data = zmq_msg_data (&msg); void *data = zmq_msg_data (&msg);
assert (memcmp ("foobar", data, 6) == 0); assert (memcmp ("foobar", data, 6) == 0);
...@@ -72,7 +90,6 @@ void test_connect_before_bind() ...@@ -72,7 +90,6 @@ void test_connect_before_bind()
int rc = zmq_connect (connectSocket, "inproc://a"); int rc = zmq_connect (connectSocket, "inproc://a");
assert (rc == 0); assert (rc == 0);
// Queue up some data // Queue up some data
rc = zmq_send_const (connectSocket, "foobar", 6, 0); rc = zmq_send_const (connectSocket, "foobar", 6, 0);
assert (rc == 6); assert (rc == 6);
...@@ -87,7 +104,7 @@ void test_connect_before_bind() ...@@ -87,7 +104,7 @@ void test_connect_before_bind()
zmq_msg_t msg; zmq_msg_t msg;
rc = zmq_msg_init (&msg); rc = zmq_msg_init (&msg);
assert (rc == 0); assert (rc == 0);
rc = zmq_msg_recv (&msg, bindSocket, ZMQ_NOBLOCK); rc = zmq_msg_recv (&msg, bindSocket, 0);
assert (rc == 6); assert (rc == 6);
void *data = zmq_msg_data (&msg); void *data = zmq_msg_data (&msg);
assert (memcmp ("foobar", data, 6) == 0); assert (memcmp ("foobar", data, 6) == 0);
...@@ -103,12 +120,167 @@ void test_connect_before_bind() ...@@ -103,12 +120,167 @@ void test_connect_before_bind()
assert (rc == 0); assert (rc == 0);
} }
void test_connect_before_bind_pub_sub()
{
void *ctx = zmq_ctx_new ();
assert (ctx);
// Connect first
void *connectSocket = zmq_socket (ctx, ZMQ_PUB);
assert (connectSocket);
int rc = zmq_connect (connectSocket, "inproc://a");
assert (rc == 0);
// Queue up some data, this will be dropped
rc = zmq_send_const (connectSocket, "before", 6, 0);
assert (rc == 6);
// Now bind
void *bindSocket = zmq_socket (ctx, ZMQ_SUB);
assert (bindSocket);
rc = zmq_setsockopt (bindSocket, ZMQ_SUBSCRIBE, "", 0);
assert (rc == 0);
rc = zmq_bind (bindSocket, "inproc://a");
assert (rc == 0);
// Wait for pub-sub connection to happen
zmq_sleep (1);
// Queue up some data, this not will be dropped
rc = zmq_send_const (connectSocket, "after", 6, 0);
assert (rc == 6);
// Read pending message
zmq_msg_t msg;
rc = zmq_msg_init (&msg);
assert (rc == 0);
rc = zmq_msg_recv (&msg, bindSocket, 0);
assert (rc == 6);
void *data = zmq_msg_data (&msg);
assert (memcmp ("after", data, 5) == 0);
// Cleanup
rc = zmq_close (connectSocket);
assert (rc == 0);
rc = zmq_close (bindSocket);
assert (rc == 0);
rc = zmq_ctx_term (ctx);
assert (rc == 0);
}
void test_multiple_connects()
{
const unsigned int no_of_connects = 10;
void *ctx = zmq_ctx_new ();
assert (ctx);
int rc;
void *connectSocket[no_of_connects];
// Connect first
for (unsigned int i = 0; i < no_of_connects; ++i)
{
connectSocket [i] = zmq_socket (ctx, ZMQ_PUSH);
assert (connectSocket [i]);
rc = zmq_connect (connectSocket [i], "inproc://a");
assert (rc == 0);
// Queue up some data
rc = zmq_send_const (connectSocket [i], "foobar", 6, 0);
assert (rc == 6);
}
// Now bind
void *bindSocket = zmq_socket (ctx, ZMQ_PULL);
assert (bindSocket);
rc = zmq_bind (bindSocket, "inproc://a");
assert (rc == 0);
for (unsigned int i = 0; i < no_of_connects; ++i)
{
// Read pending message
zmq_msg_t msg;
rc = zmq_msg_init (&msg);
assert (rc == 0);
rc = zmq_msg_recv (&msg, bindSocket, 0);
assert (rc == 6);
void *data = zmq_msg_data (&msg);
assert (memcmp ("foobar", data, 6) == 0);
}
// Cleanup
for (unsigned int i = 0; i < no_of_connects; ++i)
{
rc = zmq_close (connectSocket [i]);
assert (rc == 0);
}
rc = zmq_close (bindSocket);
assert (rc == 0);
rc = zmq_ctx_term (ctx);
assert (rc == 0);
}
void test_multiple_threads()
{
const unsigned int no_of_threads = 10;
void *ctx = zmq_ctx_new ();
assert (ctx);
int rc;
void *threads [no_of_threads];
// Connect first
for (unsigned int i = 0; i < no_of_threads; ++i)
{
threads [i] = zmq_threadstart (&pusher, ctx);
}
//zmq_sleep(1);
// Now bind
void *bindSocket = zmq_socket (ctx, ZMQ_PULL);
assert (bindSocket);
rc = zmq_bind (bindSocket, "inproc://a");
assert (rc == 0);
for (unsigned int i = 0; i < no_of_threads; ++i)
{
// Read pending message
zmq_msg_t msg;
rc = zmq_msg_init (&msg);
assert (rc == 0);
rc = zmq_msg_recv (&msg, bindSocket, 0);
assert (rc == 6);
void *data = zmq_msg_data (&msg);
assert (memcmp ("foobar", data, 6) == 0);
}
// Cleanup
for (unsigned int i = 0; i < no_of_threads; ++i)
{
zmq_threadclose (threads [i]);
}
rc = zmq_close (bindSocket);
assert (rc == 0);
rc = zmq_ctx_term (ctx);
assert (rc == 0);
}
int main (void) int main (void)
{ {
setup_test_environment(); setup_test_environment();
test_bind_before_connect(); test_bind_before_connect();
test_connect_before_bind(); test_connect_before_bind();
test_connect_before_bind_pub_sub();
test_multiple_connects();
test_multiple_threads();
return 0 ; return 0;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment