Commit 988efbc7 authored by skaller's avatar skaller

Thread Safe Sockets.

1. Reorganise C API socket functions to eliminate bad practice
of public functions calling other public functions. This should
be done for msg's too but hasn't been in this patch.

2. Reorganise code in C API socket functions so that the
socket is cast on one line, the C++ function called on
the next with the result retained, then the result is returned.

This makes the code much simpler to read and also allows
pre- and post- call hooks to be inserted easily.

3. Insert pre- and post- call hooks which set and release
a mutex iff the thread_safe flag is on.

4. Add the thread_safe_flag to base_socket_t initialised to
false to preserve existing semantics. Add an accessor for
the flag, add a mutex, and add lock and unlock functions.

Note: as yet no code to actually set the flag.
parent 4dd6ce06
......@@ -121,7 +121,8 @@ zmq::socket_base_t::socket_base_t (ctx_t *parent_, uint32_t tid_) :
destroyed (false),
last_tsc (0),
ticks (0),
rcvmore (false)
rcvmore (false),
thread_safe_flag (false)
{
}
......@@ -873,3 +874,13 @@ void zmq::socket_base_t::extract_flags (msg_t *msg_)
rcvmore = msg_->flags () & msg_t::more ? true : false;
}
void zmq::socket_base_t::lock()
{
sync.lock();
}
void zmq::socket_base_t::unlock()
{
sync.unlock();
}
......@@ -95,7 +95,9 @@ namespace zmq
void write_activated (pipe_t *pipe_);
void hiccuped (pipe_t *pipe_);
void terminated (pipe_t *pipe_);
bool thread_safe() const { return thread_safe_flag; }
void lock();
void unlock();
protected:
socket_base_t (zmq::ctx_t *parent_, uint32_t tid_);
......@@ -195,6 +197,8 @@ namespace zmq
socket_base_t (const socket_base_t&);
const socket_base_t &operator = (const socket_base_t&);
bool thread_safe_flag;
mutex_t sync;
};
}
......
......@@ -194,8 +194,11 @@ int zmq_setsockopt (void *s_, int option_, const void *optval_,
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->setsockopt (option_, optval_,
optvallen_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->setsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
......@@ -204,8 +207,11 @@ int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->getsockopt (option_, optval_,
optvallen_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->getsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_bind (void *s_, const char *addr_)
......@@ -214,7 +220,11 @@ int zmq_bind (void *s_, const char *addr_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->bind (addr_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->bind (addr_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_connect (void *s_, const char *addr_)
......@@ -223,7 +233,34 @@ int zmq_connect (void *s_, const char *addr_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->connect (addr_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->connect (addr_);
if(s->thread_safe()) s->unlock();
return result;
}
// sending functions
static int inner_sendmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int sz = (int) zmq_msg_size (msg_);
int rc = s_->send ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_sendmsg (s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
......@@ -234,7 +271,10 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return -1;
memcpy (zmq_msg_data (&msg), buf_, len_);
rc = zmq_sendmsg (s_, &msg, flags_);
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
rc = inner_sendmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (rc < 0)) {
int err = errno;
int rc2 = zmq_msg_close (&msg);
......@@ -248,13 +288,43 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return rc;
}
// receiving functions
static int inner_recvmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int rc = s_->recv ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_recvmsg(s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq_msg_t msg;
int rc = zmq_msg_init (&msg);
errno_assert (rc == 0);
int nbytes = zmq_recvmsg (s_, &msg, flags_);
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int nbytes = inner_recvmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (nbytes < 0)) {
int err = errno;
rc = zmq_msg_close (&msg);
......@@ -274,31 +344,7 @@ int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
return nbytes;
}
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int sz = (int) zmq_msg_size (msg_);
int rc = (((zmq::socket_base_t*) s_)->send ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int rc = (((zmq::socket_base_t*) s_)->recv ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
// message manipulators
int zmq_msg_init (zmq_msg_t *msg_)
{
return ((zmq::msg_t*) msg_)->init ();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment