Commit 9dab56c1 authored by Pieter Hintjens's avatar Pieter Hintjens

Merge pull request #235 from skaller/master

Thread Safe Sockets.
parents cbf6126b 520ad3c2
...@@ -147,6 +147,7 @@ ZMQ_EXPORT int zmq_getmsgopt (zmq_msg_t *msg, int option, void *optval, ...@@ -147,6 +147,7 @@ ZMQ_EXPORT int zmq_getmsgopt (zmq_msg_t *msg, int option, void *optval,
/******************************************************************************/ /******************************************************************************/
ZMQ_EXPORT void *zmq_init (int io_threads); ZMQ_EXPORT void *zmq_init (int io_threads);
ZMQ_EXPORT void *zmq_init_thread_safe (int io_threads);
ZMQ_EXPORT int zmq_term (void *context); ZMQ_EXPORT int zmq_term (void *context);
/******************************************************************************/ /******************************************************************************/
......
...@@ -81,6 +81,16 @@ zmq::ctx_t::ctx_t (uint32_t io_threads_) : ...@@ -81,6 +81,16 @@ zmq::ctx_t::ctx_t (uint32_t io_threads_) :
zmq_assert (rc == 0); zmq_assert (rc == 0);
} }
void zmq::ctx_t::set_thread_safe()
{
thread_safe_flag = true;
}
bool zmq::ctx_t::get_thread_safe() const
{
return thread_safe_flag;
}
bool zmq::ctx_t::check_tag () bool zmq::ctx_t::check_tag ()
{ {
return tag == 0xbadcafe0; return tag == 0xbadcafe0;
......
...@@ -99,6 +99,10 @@ namespace zmq ...@@ -99,6 +99,10 @@ namespace zmq
reaper_tid = 1 reaper_tid = 1
}; };
// create thread safe sockets
void set_thread_safe();
bool get_thread_safe() const;
~ctx_t (); ~ctx_t ();
private: private:
...@@ -151,6 +155,8 @@ namespace zmq ...@@ -151,6 +155,8 @@ namespace zmq
zmq::socket_base_t *log_socket; zmq::socket_base_t *log_socket;
mutex_t log_sync; mutex_t log_sync;
bool thread_safe_flag;
ctx_t (const ctx_t&); ctx_t (const ctx_t&);
const ctx_t &operator = (const ctx_t&); const ctx_t &operator = (const ctx_t&);
}; };
......
...@@ -121,7 +121,8 @@ zmq::socket_base_t::socket_base_t (ctx_t *parent_, uint32_t tid_) : ...@@ -121,7 +121,8 @@ zmq::socket_base_t::socket_base_t (ctx_t *parent_, uint32_t tid_) :
destroyed (false), destroyed (false),
last_tsc (0), last_tsc (0),
ticks (0), ticks (0),
rcvmore (false) rcvmore (false),
thread_safe_flag (false)
{ {
} }
...@@ -873,3 +874,18 @@ void zmq::socket_base_t::extract_flags (msg_t *msg_) ...@@ -873,3 +874,18 @@ void zmq::socket_base_t::extract_flags (msg_t *msg_)
rcvmore = msg_->flags () & msg_t::more ? true : false; rcvmore = msg_->flags () & msg_t::more ? true : false;
} }
void zmq::socket_base_t::set_thread_safe()
{
thread_safe_flag = true;
}
void zmq::socket_base_t::lock()
{
sync.lock();
}
void zmq::socket_base_t::unlock()
{
sync.unlock();
}
...@@ -95,7 +95,10 @@ namespace zmq ...@@ -95,7 +95,10 @@ namespace zmq
void write_activated (pipe_t *pipe_); void write_activated (pipe_t *pipe_);
void hiccuped (pipe_t *pipe_); void hiccuped (pipe_t *pipe_);
void terminated (pipe_t *pipe_); void terminated (pipe_t *pipe_);
bool thread_safe() const { return thread_safe_flag; }
void set_thread_safe(); // should be in constructor, here for compat
void lock();
void unlock();
protected: protected:
socket_base_t (zmq::ctx_t *parent_, uint32_t tid_); socket_base_t (zmq::ctx_t *parent_, uint32_t tid_);
...@@ -195,6 +198,8 @@ namespace zmq ...@@ -195,6 +198,8 @@ namespace zmq
socket_base_t (const socket_base_t&); socket_base_t (const socket_base_t&);
const socket_base_t &operator = (const socket_base_t&); const socket_base_t &operator = (const socket_base_t&);
bool thread_safe_flag;
mutex_t sync;
}; };
} }
......
...@@ -90,7 +90,7 @@ const char *zmq_strerror (int errnum_) ...@@ -90,7 +90,7 @@ const char *zmq_strerror (int errnum_)
return zmq::errno_to_string (errnum_); return zmq::errno_to_string (errnum_);
} }
void *zmq_init (int io_threads_) static zmq::ctx_t *inner_init (int io_threads_)
{ {
if (io_threads_ < 0) { if (io_threads_ < 0) {
errno = EINVAL; errno = EINVAL;
...@@ -139,6 +139,18 @@ void *zmq_init (int io_threads_) ...@@ -139,6 +139,18 @@ void *zmq_init (int io_threads_)
// Create 0MQ context. // Create 0MQ context.
zmq::ctx_t *ctx = new (std::nothrow) zmq::ctx_t ((uint32_t) io_threads_); zmq::ctx_t *ctx = new (std::nothrow) zmq::ctx_t ((uint32_t) io_threads_);
alloc_assert (ctx); alloc_assert (ctx);
return ctx;
}
void *zmq_init (int io_threads_)
{
return (void*) inner_init (io_threads_);
}
void *zmq_init_thread_safe (int io_threads_)
{
zmq::ctx_t *ctx = inner_init (io_threads_);
ctx->set_thread_safe();
return (void*) ctx; return (void*) ctx;
} }
...@@ -174,7 +186,10 @@ void *zmq_socket (void *ctx_, int type_) ...@@ -174,7 +186,10 @@ void *zmq_socket (void *ctx_, int type_)
errno = EFAULT; errno = EFAULT;
return NULL; return NULL;
} }
return (void*) (((zmq::ctx_t*) ctx_)->create_socket (type_)); zmq::ctx_t *ctx = (zmq::ctx_t*) ctx_;
zmq::socket_base_t *s = ctx->create_socket (type_);
if (ctx->get_thread_safe ()) s->set_thread_safe ();
return (void*) s;
} }
int zmq_close (void *s_) int zmq_close (void *s_)
...@@ -194,8 +209,11 @@ int zmq_setsockopt (void *s_, int option_, const void *optval_, ...@@ -194,8 +209,11 @@ int zmq_setsockopt (void *s_, int option_, const void *optval_,
errno = ENOTSOCK; errno = ENOTSOCK;
return -1; return -1;
} }
return (((zmq::socket_base_t*) s_)->setsockopt (option_, optval_, zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
optvallen_)); if(s->thread_safe()) s->lock();
int result = s->setsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
} }
int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_) int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
...@@ -204,8 +222,11 @@ int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_) ...@@ -204,8 +222,11 @@ int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
errno = ENOTSOCK; errno = ENOTSOCK;
return -1; return -1;
} }
return (((zmq::socket_base_t*) s_)->getsockopt (option_, optval_, zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
optvallen_)); if(s->thread_safe()) s->lock();
int result = s->getsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
} }
int zmq_bind (void *s_, const char *addr_) int zmq_bind (void *s_, const char *addr_)
...@@ -214,7 +235,11 @@ int zmq_bind (void *s_, const char *addr_) ...@@ -214,7 +235,11 @@ int zmq_bind (void *s_, const char *addr_)
errno = ENOTSOCK; errno = ENOTSOCK;
return -1; return -1;
} }
return (((zmq::socket_base_t*) s_)->bind (addr_)); zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->bind (addr_);
if(s->thread_safe()) s->unlock();
return result;
} }
int zmq_connect (void *s_, const char *addr_) int zmq_connect (void *s_, const char *addr_)
...@@ -223,7 +248,34 @@ int zmq_connect (void *s_, const char *addr_) ...@@ -223,7 +248,34 @@ int zmq_connect (void *s_, const char *addr_)
errno = ENOTSOCK; errno = ENOTSOCK;
return -1; return -1;
} }
return (((zmq::socket_base_t*) s_)->connect (addr_)); zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->connect (addr_);
if(s->thread_safe()) s->unlock();
return result;
}
// sending functions
static int inner_sendmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int sz = (int) zmq_msg_size (msg_);
int rc = s_->send ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_sendmsg (s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
} }
int zmq_send (void *s_, const void *buf_, size_t len_, int flags_) int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
...@@ -234,7 +286,10 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_) ...@@ -234,7 +286,10 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return -1; return -1;
memcpy (zmq_msg_data (&msg), buf_, len_); memcpy (zmq_msg_data (&msg), buf_, len_);
rc = zmq_sendmsg (s_, &msg, flags_); zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
rc = inner_sendmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (rc < 0)) { if (unlikely (rc < 0)) {
int err = errno; int err = errno;
int rc2 = zmq_msg_close (&msg); int rc2 = zmq_msg_close (&msg);
...@@ -248,13 +303,43 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_) ...@@ -248,13 +303,43 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return rc; return rc;
} }
// receiving functions
static int inner_recvmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int rc = s_->recv ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_recvmsg(s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_recv (void *s_, void *buf_, size_t len_, int flags_) int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
{ {
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq_msg_t msg; zmq_msg_t msg;
int rc = zmq_msg_init (&msg); int rc = zmq_msg_init (&msg);
errno_assert (rc == 0); errno_assert (rc == 0);
int nbytes = zmq_recvmsg (s_, &msg, flags_); zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int nbytes = inner_recvmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (nbytes < 0)) { if (unlikely (nbytes < 0)) {
int err = errno; int err = errno;
rc = zmq_msg_close (&msg); rc = zmq_msg_close (&msg);
...@@ -274,31 +359,7 @@ int zmq_recv (void *s_, void *buf_, size_t len_, int flags_) ...@@ -274,31 +359,7 @@ int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
return nbytes; return nbytes;
} }
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_) // message manipulators
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int sz = (int) zmq_msg_size (msg_);
int rc = (((zmq::socket_base_t*) s_)->send ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int rc = (((zmq::socket_base_t*) s_)->recv ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
int zmq_msg_init (zmq_msg_t *msg_) int zmq_msg_init (zmq_msg_t *msg_)
{ {
return ((zmq::msg_t*) msg_)->init (); return ((zmq::msg_t*) msg_)->init ();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment