Commit e46ec312 authored by Martin Hurton's avatar Martin Hurton

Implement socket_base_t::get_credential member function

The get_credential () member function returns
credential for the last peer we received message for.
The idea is that this function is used to implement user-level API.
parent 5c4f3cc6
......@@ -613,6 +613,9 @@ int zmq::curve_server_t::receive_and_process_zap_reply ()
goto error;
}
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ());
......
......@@ -98,6 +98,12 @@ bool zmq::dealer_t::xhas_out ()
return lb.has_out ();
}
zmq::blob_t zmq::dealer_t::get_credential () const
{
return fq.get_credential ();
}
void zmq::dealer_t::xread_activated (pipe_t *pipe_)
{
fq.activated (pipe_);
......
......@@ -51,6 +51,7 @@ namespace zmq
int xrecv (zmq::msg_t *msg_);
bool xhas_in ();
bool xhas_out ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_);
......
......@@ -24,6 +24,7 @@
zmq::fq_t::fq_t () :
active (0),
last_in (NULL),
current (0),
more (false)
{
......@@ -54,6 +55,11 @@ void zmq::fq_t::pipe_terminated (pipe_t *pipe_)
current = 0;
}
pipes.erase (pipe_);
if (last_in == pipe_) {
saved_credential = last_in->get_credential ();
last_in = NULL;
}
}
void zmq::fq_t::activated (pipe_t *pipe_)
......@@ -88,8 +94,10 @@ int zmq::fq_t::recvpipe (msg_t *msg_, pipe_t **pipe_)
if (pipe_)
*pipe_ = pipes [current];
more = msg_->flags () & msg_t::more? true: false;
if (!more)
if (!more) {
last_in = pipes [current];
current = (current + 1) % active;
}
return 0;
}
......@@ -136,3 +144,9 @@ bool zmq::fq_t::has_in ()
return false;
}
zmq::blob_t zmq::fq_t::get_credential () const
{
return last_in?
last_in->get_credential (): saved_credential;
}
......@@ -21,6 +21,7 @@
#define __ZMQ_FQ_HPP_INCLUDED__
#include "array.hpp"
#include "blob.hpp"
#include "pipe.hpp"
#include "msg.hpp"
......@@ -45,6 +46,7 @@ namespace zmq
int recv (msg_t *msg_);
int recvpipe (msg_t *msg_, pipe_t **pipe_);
bool has_in ();
blob_t get_credential () const;
private:
......@@ -56,6 +58,11 @@ namespace zmq
// beginning of the pipes array.
pipes_t::size_type active;
// Pointer to the last pipe we received message from.
// NULL when no message has been received or the pipe
// has terminated.
pipe_t *last_in;
// Index of the next bound pipe to read a message from.
pipes_t::size_type current;
......@@ -63,6 +70,9 @@ namespace zmq
// there are following parts still waiting in the current pipe.
bool more;
// Holds credential after the last_acive_pipe has terminated.
blob_t saved_credential;
fq_t (const fq_t&);
const fq_t &operator = (const fq_t&);
};
......
......@@ -47,6 +47,16 @@ void zmq::mechanism_t::peer_identity (msg_t *msg_)
msg_->set_flags (msg_t::identity);
}
void zmq::mechanism_t::set_user_id (const void *data_, size_t size_)
{
user_id = blob_t (static_cast <const unsigned char*> (data_), size_);
}
zmq::blob_t zmq::mechanism_t::get_user_id () const
{
return user_id;
}
const char *zmq::mechanism_t::socket_type_string (int socket_type) const
{
static const char *names [] = {"PAIR", "PUB", "SUB", "REQ", "REP",
......
......@@ -60,6 +60,10 @@ namespace zmq
void peer_identity (msg_t *msg_);
void set_user_id (const void *user_id, size_t size);
blob_t get_user_id () const;
protected:
// Only used to identify the socket for the Socket-Type
......@@ -91,6 +95,8 @@ namespace zmq
blob_t identity;
blob_t user_id;
// Returns true iff socket associated with the mechanism
// is compatible with a given socket type 'type_'.
bool check_socket_type (const std::string type_) const;
......
......@@ -267,6 +267,11 @@ bool zmq::msg_t::is_identity () const
return (u.base.flags & identity) == identity;
}
bool zmq::msg_t::is_credential () const
{
return (u.base.flags & credential) == credential;
}
bool zmq::msg_t::is_delimiter () const
{
return u.base.type == type_delimiter;
......
......@@ -49,6 +49,7 @@ namespace zmq
{
more = 1, // Followed by more parts
command = 2, // Command frame (see ZMTP spec)
credential = 32,
identity = 64,
shared = 128
};
......@@ -70,6 +71,7 @@ namespace zmq
int64_t fd ();
void set_fd (int64_t fd_);
bool is_identity () const;
bool is_credential () const;
bool is_delimiter () const;
bool is_vsm ();
bool is_cmsg ();
......
......@@ -268,6 +268,9 @@ int zmq::null_mechanism_t::receive_and_process_zap_reply ()
goto error;
}
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ());
......
......@@ -24,7 +24,8 @@
zmq::pair_t::pair_t (class ctx_t *parent_, uint32_t tid_, int sid_) :
socket_base_t (parent_, tid_, sid_),
pipe (NULL)
pipe (NULL),
last_in (NULL)
{
options.type = ZMQ_PAIR;
}
......@@ -51,8 +52,13 @@ void zmq::pair_t::xattach_pipe (pipe_t *pipe_, bool subscribe_to_all_)
void zmq::pair_t::xpipe_terminated (pipe_t *pipe_)
{
if (pipe_ == pipe)
if (pipe_ == pipe) {
if (last_in == pipe) {
saved_credential = last_in->get_credential ();
last_in = NULL;
}
pipe = NULL;
}
}
void zmq::pair_t::xread_activated (pipe_t *)
......@@ -99,6 +105,7 @@ int zmq::pair_t::xrecv (msg_t *msg_)
errno = EAGAIN;
return -1;
}
last_in = pipe;
return 0;
}
......@@ -117,3 +124,8 @@ bool zmq::pair_t::xhas_out ()
return pipe->check_write ();
}
zmq::blob_t zmq::pair_t::get_credential () const
{
return last_in? last_in->get_credential (): saved_credential;
}
......@@ -20,6 +20,7 @@
#ifndef __ZMQ_PAIR_HPP_INCLUDED__
#define __ZMQ_PAIR_HPP_INCLUDED__
#include "blob.hpp"
#include "socket_base.hpp"
#include "session_base.hpp"
......@@ -45,6 +46,7 @@ namespace zmq
int xrecv (zmq::msg_t *msg_);
bool xhas_in ();
bool xhas_out ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_);
......@@ -53,6 +55,10 @@ namespace zmq
zmq::pipe_t *pipe;
zmq::pipe_t *last_in;
blob_t saved_credential;
pair_t (const pair_t&);
const pair_t &operator = (const pair_t&);
};
......
......@@ -110,6 +110,11 @@ zmq::blob_t zmq::pipe_t::get_identity ()
return identity;
}
zmq::blob_t zmq::pipe_t::get_credential () const
{
return credential;
}
bool zmq::pipe_t::check_read ()
{
if (unlikely (!in_active))
......@@ -143,11 +148,21 @@ bool zmq::pipe_t::read (msg_t *msg_)
if (unlikely (state != active && state != waiting_for_delimiter))
return false;
read_message:
if (!inpipe->read (msg_)) {
in_active = false;
return false;
}
// If this is a credential, save a copy and receive next message.
if (unlikely (msg_->is_credential ())) {
const unsigned char *data = static_cast <const unsigned char *> (msg_->data ());
credential = blob_t (data, msg_->size ());
const int rc = msg_->close ();
zmq_assert (rc == 0);
goto read_message;
}
// If delimiter was read, start termination process of the pipe.
if (msg_->is_delimiter ()) {
process_delimiter ();
......
......@@ -78,6 +78,8 @@ namespace zmq
void set_identity (const blob_t &identity_);
blob_t get_identity ();
blob_t get_credential () const;
// Returns true if there is at least one message to read in the pipe.
bool check_read ();
......@@ -198,6 +200,9 @@ namespace zmq
// Identity of the writer. Used uniquely by the reader side.
blob_t identity;
// Pipe's credential.
blob_t credential;
// Returns true if the message is delimiter; false otherwise.
static bool is_delimiter (const msg_t &msg_);
......
......@@ -468,6 +468,9 @@ int zmq::plain_mechanism_t::receive_and_process_zap_reply ()
goto error;
}
// Save user id
set_user_id (msg [5].data (), msg [5].size ());
// Process metadata frame
rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
msg [6].size ());
......
......@@ -60,3 +60,8 @@ bool zmq::pull_t::xhas_in ()
{
return fq.has_in ();
}
zmq::blob_t zmq::pull_t::get_credential () const
{
return fq.get_credential ();
}
......@@ -46,6 +46,7 @@ namespace zmq
void xattach_pipe (zmq::pipe_t *pipe_, bool subscribe_to_all_);
int xrecv (zmq::msg_t *msg_);
bool xhas_in ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_);
void xpipe_terminated (zmq::pipe_t *pipe_);
......
......@@ -371,6 +371,11 @@ bool zmq::router_t::xhas_out ()
return true;
}
zmq::blob_t zmq::router_t::get_credential () const
{
return fq.get_credential ();
}
bool zmq::router_t::identify_peer (pipe_t *pipe_)
{
msg_t msg;
......
......@@ -59,6 +59,7 @@ namespace zmq
// Rollback any message parts that were sent but not yet flushed.
int rollback ();
blob_t get_credential () const;
private:
......
......@@ -1043,6 +1043,11 @@ int zmq::socket_base_t::xrecv (msg_t *)
return -1;
}
zmq::blob_t zmq::socket_base_t::get_credential () const
{
return blob_t ();
}
void zmq::socket_base_t::xread_activated (pipe_t *)
{
zmq_assert (false);
......
......@@ -26,6 +26,7 @@
#include "own.hpp"
#include "array.hpp"
#include "blob.hpp"
#include "stdint.hpp"
#include "poller.hpp"
#include "atomic_counter.hpp"
......@@ -144,6 +145,11 @@ namespace zmq
virtual bool xhas_in ();
virtual int xrecv (zmq::msg_t *msg_);
// Returns the credential for the peer from which we have received
// the last message. If no message has been received yet,
// the function returns empty credential.
virtual blob_t get_credential () const;
// i_pipe_events will be forwarded to these functions.
virtual void xread_activated (pipe_t *pipe_);
virtual void xwrite_activated (pipe_t *pipe_);
......
......@@ -692,7 +692,7 @@ void zmq::stream_engine_t::mechanism_ready ()
}
read_msg = &stream_engine_t::pull_and_encode;
write_msg = &stream_engine_t::decode_and_push;
write_msg = &stream_engine_t::write_credential;
}
int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_)
......@@ -705,6 +705,29 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_)
return session->push_msg (msg_);
}
int zmq::stream_engine_t::write_credential (msg_t *msg_)
{
zmq_assert (mechanism != NULL);
zmq_assert (session != NULL);
const blob_t credential = mechanism->get_user_id ();
if (credential.size () > 0) {
msg_t msg;
int rc = msg.init_size (credential.size ());
zmq_assert (rc == 0);
memcpy (msg.data (), credential.data (), credential.size ());
msg.set_flags (msg_t::credential);
rc = session->push_msg (&msg);
if (rc == -1) {
rc = msg.close ();
errno_assert (rc == 0);
return -1;
}
}
write_msg = &stream_engine_t::decode_and_push;
return decode_and_push (msg_);
}
int zmq::stream_engine_t::pull_and_encode (msg_t *msg_)
{
zmq_assert (mechanism != NULL);
......
......@@ -101,6 +101,7 @@ namespace zmq
int pull_msg_from_session (msg_t *msg_);
int push_msg_to_session (msg_t *msg);
int write_credential (msg_t *msg_);
int pull_and_encode (msg_t *msg_);
int decode_and_push (msg_t *msg_);
int push_one_then_decode_and_push (msg_t *msg_);
......
......@@ -199,6 +199,11 @@ bool zmq::xsub_t::xhas_in ()
}
}
zmq::blob_t zmq::xsub_t::get_credential () const
{
return fq.get_credential ();
}
bool zmq::xsub_t::match (msg_t *msg_)
{
return subscriptions.check ((unsigned char*) msg_->data (), msg_->size ());
......
......@@ -49,6 +49,7 @@ namespace zmq
bool xhas_out ();
int xrecv (zmq::msg_t *msg_);
bool xhas_in ();
blob_t get_credential () const;
void xread_activated (zmq::pipe_t *pipe_);
void xwrite_activated (zmq::pipe_t *pipe_);
void xhiccuped (pipe_t *pipe_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment