Commit e46ec312 authored by Martin Hurton's avatar Martin Hurton

Implement socket_base_t::get_credential member function

The get_credential () member function returns
credential for the last peer we received message for.
The idea is that this function is used to implement user-level API.
parent 5c4f3cc6
...@@ -613,6 +613,9 @@ int zmq::curve_server_t::receive_and_process_zap_reply () ...@@ -613,6 +613,9 @@ int zmq::curve_server_t::receive_and_process_zap_reply ()
goto error; goto error;
} }
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame // Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()), rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ()); msg [6].size ());
......
...@@ -98,6 +98,12 @@ bool zmq::dealer_t::xhas_out () ...@@ -98,6 +98,12 @@ bool zmq::dealer_t::xhas_out ()
return lb.has_out (); return lb.has_out ();
} }
zmq::blob_t zmq::dealer_t::get_credential () const
{
return fq.get_credential ();
}
void zmq::dealer_t::xread_activated (pipe_t *pipe_) void zmq::dealer_t::xread_activated (pipe_t *pipe_)
{ {
fq.activated (pipe_); fq.activated (pipe_);
......
...@@ -51,6 +51,7 @@ namespace zmq ...@@ -51,6 +51,7 @@ namespace zmq
int xrecv (zmq::msg_t *msg_); int xrecv (zmq::msg_t *msg_);
bool xhas_in (); bool xhas_in ();
bool xhas_out (); bool xhas_out ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_); void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_);
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
zmq::fq_t::fq_t () : zmq::fq_t::fq_t () :
active (0), active (0),
last_in (NULL),
current (0), current (0),
more (false) more (false)
{ {
...@@ -54,6 +55,11 @@ void zmq::fq_t::pipe_terminated (pipe_t *pipe_) ...@@ -54,6 +55,11 @@ void zmq::fq_t::pipe_terminated (pipe_t *pipe_)
current = 0; current = 0;
} }
pipes.erase (pipe_); pipes.erase (pipe_);
if (last_in == pipe_) {
saved_credential = last_in->get_credential ();
last_in = NULL;
}
} }
void zmq::fq_t::activated (pipe_t *pipe_) void zmq::fq_t::activated (pipe_t *pipe_)
...@@ -88,8 +94,10 @@ int zmq::fq_t::recvpipe (msg_t *msg_, pipe_t **pipe_) ...@@ -88,8 +94,10 @@ int zmq::fq_t::recvpipe (msg_t *msg_, pipe_t **pipe_)
if (pipe_) if (pipe_)
*pipe_ = pipes [current]; *pipe_ = pipes [current];
more = msg_->flags () & msg_t::more? true: false; more = msg_->flags () & msg_t::more? true: false;
if (!more) if (!more) {
last_in = pipes [current];
current = (current + 1) % active; current = (current + 1) % active;
}
return 0; return 0;
} }
...@@ -136,3 +144,9 @@ bool zmq::fq_t::has_in () ...@@ -136,3 +144,9 @@ bool zmq::fq_t::has_in ()
return false; return false;
} }
zmq::blob_t zmq::fq_t::get_credential () const
{
return last_in?
last_in->get_credential (): saved_credential;
}
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#define __ZMQ_FQ_HPP_INCLUDED__ #define __ZMQ_FQ_HPP_INCLUDED__
#include "array.hpp" #include "array.hpp"
#include "blob.hpp"
#include "pipe.hpp" #include "pipe.hpp"
#include "msg.hpp" #include "msg.hpp"
...@@ -45,6 +46,7 @@ namespace zmq ...@@ -45,6 +46,7 @@ namespace zmq
int recv (msg_t *msg_); int recv (msg_t *msg_);
int recvpipe (msg_t *msg_, pipe_t **pipe_); int recvpipe (msg_t *msg_, pipe_t **pipe_);
bool has_in (); bool has_in ();
blob_t get_credential () const;
private: private:
...@@ -56,6 +58,11 @@ namespace zmq ...@@ -56,6 +58,11 @@ namespace zmq
// beginning of the pipes array. // beginning of the pipes array.
pipes_t::size_type active; pipes_t::size_type active;
// Pointer to the last pipe we received message from.
// NULL when no message has been received or the pipe
// has terminated.
pipe_t *last_in;
// Index of the next bound pipe to read a message from. // Index of the next bound pipe to read a message from.
pipes_t::size_type current; pipes_t::size_type current;
...@@ -63,6 +70,9 @@ namespace zmq ...@@ -63,6 +70,9 @@ namespace zmq
// there are following parts still waiting in the current pipe. // there are following parts still waiting in the current pipe.
bool more; bool more;
// Holds credential after the last_acive_pipe has terminated.
blob_t saved_credential;
fq_t (const fq_t&); fq_t (const fq_t&);
const fq_t &operator = (const fq_t&); const fq_t &operator = (const fq_t&);
}; };
......
...@@ -47,6 +47,16 @@ void zmq::mechanism_t::peer_identity (msg_t *msg_) ...@@ -47,6 +47,16 @@ void zmq::mechanism_t::peer_identity (msg_t *msg_)
msg_->set_flags (msg_t::identity); msg_->set_flags (msg_t::identity);
} }
void zmq::mechanism_t::set_user_id (const void *data_, size_t size_)
{
user_id = blob_t (static_cast <const unsigned char*> (data_), size_);
}
zmq::blob_t zmq::mechanism_t::get_user_id () const
{
return user_id;
}
const char *zmq::mechanism_t::socket_type_string (int socket_type) const const char *zmq::mechanism_t::socket_type_string (int socket_type) const
{ {
static const char *names [] = {"PAIR", "PUB", "SUB", "REQ", "REP", static const char *names [] = {"PAIR", "PUB", "SUB", "REQ", "REP",
......
...@@ -60,6 +60,10 @@ namespace zmq ...@@ -60,6 +60,10 @@ namespace zmq
void peer_identity (msg_t *msg_); void peer_identity (msg_t *msg_);
void set_user_id (const void *user_id, size_t size);
blob_t get_user_id () const;
protected: protected:
// Only used to identify the socket for the Socket-Type // Only used to identify the socket for the Socket-Type
...@@ -91,6 +95,8 @@ namespace zmq ...@@ -91,6 +95,8 @@ namespace zmq
blob_t identity; blob_t identity;
blob_t user_id;
// Returns true iff socket associated with the mechanism // Returns true iff socket associated with the mechanism
// is compatible with a given socket type 'type_'. // is compatible with a given socket type 'type_'.
bool check_socket_type (const std::string type_) const; bool check_socket_type (const std::string type_) const;
......
...@@ -267,6 +267,11 @@ bool zmq::msg_t::is_identity () const ...@@ -267,6 +267,11 @@ bool zmq::msg_t::is_identity () const
return (u.base.flags & identity) == identity; return (u.base.flags & identity) == identity;
} }
bool zmq::msg_t::is_credential () const
{
return (u.base.flags & credential) == credential;
}
bool zmq::msg_t::is_delimiter () const bool zmq::msg_t::is_delimiter () const
{ {
return u.base.type == type_delimiter; return u.base.type == type_delimiter;
......
...@@ -49,6 +49,7 @@ namespace zmq ...@@ -49,6 +49,7 @@ namespace zmq
{ {
more = 1, // Followed by more parts more = 1, // Followed by more parts
command = 2, // Command frame (see ZMTP spec) command = 2, // Command frame (see ZMTP spec)
credential = 32,
identity = 64, identity = 64,
shared = 128 shared = 128
}; };
...@@ -70,6 +71,7 @@ namespace zmq ...@@ -70,6 +71,7 @@ namespace zmq
int64_t fd (); int64_t fd ();
void set_fd (int64_t fd_); void set_fd (int64_t fd_);
bool is_identity () const; bool is_identity () const;
bool is_credential () const;
bool is_delimiter () const; bool is_delimiter () const;
bool is_vsm (); bool is_vsm ();
bool is_cmsg (); bool is_cmsg ();
......
...@@ -268,6 +268,9 @@ int zmq::null_mechanism_t::receive_and_process_zap_reply () ...@@ -268,6 +268,9 @@ int zmq::null_mechanism_t::receive_and_process_zap_reply ()
goto error; goto error;
} }
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame // Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()), rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ()); msg [6].size ());
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
zmq::pair_t::pair_t (class ctx_t *parent_, uint32_t tid_, int sid_) : zmq::pair_t::pair_t (class ctx_t *parent_, uint32_t tid_, int sid_) :
socket_base_t (parent_, tid_, sid_), socket_base_t (parent_, tid_, sid_),
pipe (NULL) pipe (NULL),
last_in (NULL)
{ {
options.type = ZMQ_PAIR; options.type = ZMQ_PAIR;
} }
...@@ -51,8 +52,13 @@ void zmq::pair_t::xattach_pipe (pipe_t *pipe_, bool subscribe_to_all_) ...@@ -51,8 +52,13 @@ void zmq::pair_t::xattach_pipe (pipe_t *pipe_, bool subscribe_to_all_)
void zmq::pair_t::xpipe_terminated (pipe_t *pipe_) void zmq::pair_t::xpipe_terminated (pipe_t *pipe_)
{ {
if (pipe_ == pipe) if (pipe_ == pipe) {
if (last_in == pipe) {
saved_credential = last_in->get_credential ();
last_in = NULL;
}
pipe = NULL; pipe = NULL;
}
} }
void zmq::pair_t::xread_activated (pipe_t *) void zmq::pair_t::xread_activated (pipe_t *)
...@@ -99,6 +105,7 @@ int zmq::pair_t::xrecv (msg_t *msg_) ...@@ -99,6 +105,7 @@ int zmq::pair_t::xrecv (msg_t *msg_)
errno = EAGAIN; errno = EAGAIN;
return -1; return -1;
} }
last_in = pipe;
return 0; return 0;
} }
...@@ -117,3 +124,8 @@ bool zmq::pair_t::xhas_out () ...@@ -117,3 +124,8 @@ bool zmq::pair_t::xhas_out ()
return pipe->check_write (); return pipe->check_write ();
} }
zmq::blob_t zmq::pair_t::get_credential () const
{
return last_in? last_in->get_credential (): saved_credential;
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#ifndef __ZMQ_PAIR_HPP_INCLUDED__ #ifndef __ZMQ_PAIR_HPP_INCLUDED__
#define __ZMQ_PAIR_HPP_INCLUDED__ #define __ZMQ_PAIR_HPP_INCLUDED__
#include "blob.hpp"
#include "socket_base.hpp" #include "socket_base.hpp"
#include "session_base.hpp" #include "session_base.hpp"
...@@ -45,6 +46,7 @@ namespace zmq ...@@ -45,6 +46,7 @@ namespace zmq
int xrecv (zmq::msg_t *msg_); int xrecv (zmq::msg_t *msg_);
bool xhas_in (); bool xhas_in ();
bool xhas_out (); bool xhas_out ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_); void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_);
...@@ -53,6 +55,10 @@ namespace zmq ...@@ -53,6 +55,10 @@ namespace zmq
zmq::pipe_t *pipe; zmq::pipe_t *pipe;
zmq::pipe_t *last_in;
blob_t saved_credential;
pair_t (const pair_t&); pair_t (const pair_t&);
const pair_t &operator = (const pair_t&); const pair_t &operator = (const pair_t&);
}; };
......
...@@ -110,6 +110,11 @@ zmq::blob_t zmq::pipe_t::get_identity () ...@@ -110,6 +110,11 @@ zmq::blob_t zmq::pipe_t::get_identity ()
return identity; return identity;
} }
zmq::blob_t zmq::pipe_t::get_credential () const
{
return credential;
}
bool zmq::pipe_t::check_read () bool zmq::pipe_t::check_read ()
{ {
if (unlikely (!in_active)) if (unlikely (!in_active))
...@@ -143,11 +148,21 @@ bool zmq::pipe_t::read (msg_t *msg_) ...@@ -143,11 +148,21 @@ bool zmq::pipe_t::read (msg_t *msg_)
if (unlikely (state != active && state != waiting_for_delimiter)) if (unlikely (state != active && state != waiting_for_delimiter))
return false; return false;
read_message:
if (!inpipe->read (msg_)) { if (!inpipe->read (msg_)) {
in_active = false; in_active = false;
return false; return false;
} }
// If this is a credential, save a copy and receive next message.
if (unlikely (msg_->is_credential ())) {
const unsigned char *data = static_cast <const unsigned char *> (msg_->data ());
credential = blob_t (data, msg_->size ());
const int rc = msg_->close ();
zmq_assert (rc == 0);
goto read_message;
}
// If delimiter was read, start termination process of the pipe. // If delimiter was read, start termination process of the pipe.
if (msg_->is_delimiter ()) { if (msg_->is_delimiter ()) {
process_delimiter (); process_delimiter ();
......
...@@ -78,6 +78,8 @@ namespace zmq ...@@ -78,6 +78,8 @@ namespace zmq
void set_identity (const blob_t &identity_); void set_identity (const blob_t &identity_);
blob_t get_identity (); blob_t get_identity ();
blob_t get_credential () const;
// Returns true if there is at least one message to read in the pipe. // Returns true if there is at least one message to read in the pipe.
bool check_read (); bool check_read ();
...@@ -198,6 +200,9 @@ namespace zmq ...@@ -198,6 +200,9 @@ namespace zmq
// Identity of the writer. Used uniquely by the reader side. // Identity of the writer. Used uniquely by the reader side.
blob_t identity; blob_t identity;
// Pipe's credential.
blob_t credential;
// Returns true if the message is delimiter; false otherwise. // Returns true if the message is delimiter; false otherwise.
static bool is_delimiter (const msg_t &msg_); static bool is_delimiter (const msg_t &msg_);
......
...@@ -468,6 +468,9 @@ int zmq::plain_mechanism_t::receive_and_process_zap_reply () ...@@ -468,6 +468,9 @@ int zmq::plain_mechanism_t::receive_and_process_zap_reply ()
goto error; goto error;
} }
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame // Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()), rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ()); msg [6].size ());
......
...@@ -60,3 +60,8 @@ bool zmq::pull_t::xhas_in () ...@@ -60,3 +60,8 @@ bool zmq::pull_t::xhas_in ()
{ {
return fq.has_in (); return fq.has_in ();
} }
zmq::blob_t zmq::pull_t::get_credential () const
{
return fq.get_credential ();
}
...@@ -46,6 +46,7 @@ namespace zmq ...@@ -46,6 +46,7 @@ namespace zmq
void xattach_pipe (zmq::pipe_t *pipe_, bool subscribe_to_all_); void xattach_pipe (zmq::pipe_t *pipe_, bool subscribe_to_all_);
int xrecv (zmq::msg_t *msg_); int xrecv (zmq::msg_t *msg_);
bool xhas_in (); bool xhas_in ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_); void xread_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_);
......
...@@ -371,6 +371,11 @@ bool zmq::router_t::xhas_out () ...@@ -371,6 +371,11 @@ bool zmq::router_t::xhas_out ()
return true; return true;
} }
zmq::blob_t zmq::router_t::get_credential () const
{
return fq.get_credential ();
}
bool zmq::router_t::identify_peer (pipe_t *pipe_) bool zmq::router_t::identify_peer (pipe_t *pipe_)
{ {
msg_t msg; msg_t msg;
......
...@@ -59,6 +59,7 @@ namespace zmq ...@@ -59,6 +59,7 @@ namespace zmq
// Rollback any message parts that were sent but not yet flushed. // Rollback any message parts that were sent but not yet flushed.
int rollback (); int rollback ();
blob_t get_credential () const;
private: private:
......
...@@ -1043,6 +1043,11 @@ int zmq::socket_base_t::xrecv (msg_t *) ...@@ -1043,6 +1043,11 @@ int zmq::socket_base_t::xrecv (msg_t *)
return -1; return -1;
} }
zmq::blob_t zmq::socket_base_t::get_credential () const
{
return blob_t ();
}
void zmq::socket_base_t::xread_activated (pipe_t *) void zmq::socket_base_t::xread_activated (pipe_t *)
{ {
zmq_assert (false); zmq_assert (false);
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "own.hpp" #include "own.hpp"
#include "array.hpp" #include "array.hpp"
#include "blob.hpp"
#include "stdint.hpp" #include "stdint.hpp"
#include "poller.hpp" #include "poller.hpp"
#include "atomic_counter.hpp" #include "atomic_counter.hpp"
...@@ -144,6 +145,11 @@ namespace zmq ...@@ -144,6 +145,11 @@ namespace zmq
virtual bool xhas_in (); virtual bool xhas_in ();
virtual int xrecv (zmq::msg_t *msg_); virtual int xrecv (zmq::msg_t *msg_);
// Returns the credential for the peer from which we have received
// the last message. If no message has been received yet,
// the function returns empty credential.
virtual blob_t get_credential () const;
// i_pipe_events will be forwarded to these functions. // i_pipe_events will be forwarded to these functions.
virtual void xread_activated (pipe_t *pipe_); virtual void xread_activated (pipe_t *pipe_);
virtual void xwrite_activated (pipe_t *pipe_); virtual void xwrite_activated (pipe_t *pipe_);
......
...@@ -692,7 +692,7 @@ void zmq::stream_engine_t::mechanism_ready () ...@@ -692,7 +692,7 @@ void zmq::stream_engine_t::mechanism_ready ()
} }
read_msg = &stream_engine_t::pull_and_encode; read_msg = &stream_engine_t::pull_and_encode;
write_msg = &stream_engine_t::decode_and_push; write_msg = &stream_engine_t::write_credential;
} }
int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_) int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_)
...@@ -705,6 +705,29 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_) ...@@ -705,6 +705,29 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_)
return session->push_msg (msg_); return session->push_msg (msg_);
} }
int zmq::stream_engine_t::write_credential (msg_t *msg_)
{
zmq_assert (mechanism != NULL);
zmq_assert (session != NULL);
const blob_t credential = mechanism->get_user_id ();
if (credential.size () > 0) {
msg_t msg;
int rc = msg.init_size (credential.size ());
zmq_assert (rc == 0);
memcpy (msg.data (), credential.data (), credential.size ());
msg.set_flags (msg_t::credential);
rc = session->push_msg (&msg);
if (rc == -1) {
rc = msg.close ();
errno_assert (rc == 0);
return -1;
}
}
write_msg = &stream_engine_t::decode_and_push;
return decode_and_push (msg_);
}
int zmq::stream_engine_t::pull_and_encode (msg_t *msg_) int zmq::stream_engine_t::pull_and_encode (msg_t *msg_)
{ {
zmq_assert (mechanism != NULL); zmq_assert (mechanism != NULL);
......
...@@ -101,6 +101,7 @@ namespace zmq ...@@ -101,6 +101,7 @@ namespace zmq
int pull_msg_from_session (msg_t *msg_); int pull_msg_from_session (msg_t *msg_);
int push_msg_to_session (msg_t *msg); int push_msg_to_session (msg_t *msg);
int write_credential (msg_t *msg_);
int pull_and_encode (msg_t *msg_); int pull_and_encode (msg_t *msg_);
int decode_and_push (msg_t *msg_); int decode_and_push (msg_t *msg_);
int push_one_then_decode_and_push (msg_t *msg_); int push_one_then_decode_and_push (msg_t *msg_);
......
...@@ -199,6 +199,11 @@ bool zmq::xsub_t::xhas_in () ...@@ -199,6 +199,11 @@ bool zmq::xsub_t::xhas_in ()
} }
} }
zmq::blob_t zmq::xsub_t::get_credential () const
{
return fq.get_credential ();
}
bool zmq::xsub_t::match (msg_t *msg_) bool zmq::xsub_t::match (msg_t *msg_)
{ {
return subscriptions.check ((unsigned char*) msg_->data (), msg_->size ()); return subscriptions.check ((unsigned char*) msg_->data (), msg_->size ());
......
...@@ -49,6 +49,7 @@ namespace zmq ...@@ -49,6 +49,7 @@ namespace zmq
bool xhas_out (); bool xhas_out ();
int xrecv (zmq::msg_t *msg_); int xrecv (zmq::msg_t *msg_);
bool xhas_in (); bool xhas_in ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_); void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_);
void xhiccuped (pipe_t *pipe_); void xhiccuped (pipe_t *pipe_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment