Commit f909b9c7 authored by Pieter Hintjens's avatar Pieter Hintjens

plain_mechanism now uses options.as_server

- we need to switch to PLAIN according to options.mechanism
- we need to catch case when both peers are as-server (or neither is)
- and to use username/password from options, for client
parent da1e9a17
...@@ -113,7 +113,7 @@ void zmq::ipc_connecter_t::out_event () ...@@ -113,7 +113,7 @@ void zmq::ipc_connecter_t::out_event ()
} }
// Create the engine object for this connection. // Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow) stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, false, endpoint); stream_engine_t (fd, options, endpoint);
alloc_assert (engine); alloc_assert (engine);
// Attach the engine to the corresponding session object. // Attach the engine to the corresponding session object.
......
...@@ -81,7 +81,7 @@ void zmq::ipc_listener_t::in_event () ...@@ -81,7 +81,7 @@ void zmq::ipc_listener_t::in_event ()
// Create the engine object for this connection. // Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow) stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, true, endpoint); stream_engine_t (fd, options, endpoint);
alloc_assert (engine); alloc_assert (engine);
// Choose I/O thread to run connecter in. Given that we are already // Choose I/O thread to run connecter in. Given that we are already
......
...@@ -52,7 +52,7 @@ zmq::options_t::options_t () : ...@@ -52,7 +52,7 @@ zmq::options_t::options_t () :
tcp_keepalive_idle (-1), tcp_keepalive_idle (-1),
tcp_keepalive_intvl (-1), tcp_keepalive_intvl (-1),
mechanism (ZMQ_NULL), mechanism (ZMQ_NULL),
plain_server (0), as_server (0),
socket_id (0) socket_id (0)
{ {
} }
...@@ -251,7 +251,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_, ...@@ -251,7 +251,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
case ZMQ_PLAIN_SERVER: case ZMQ_PLAIN_SERVER:
if (is_int && (value == 0 || value == 1)) { if (is_int && (value == 0 || value == 1)) {
plain_server = value; as_server = value;
mechanism = value? ZMQ_PLAIN: ZMQ_NULL; mechanism = value? ZMQ_PLAIN: ZMQ_NULL;
return 0; return 0;
} }
...@@ -265,7 +265,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_, ...@@ -265,7 +265,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
else else
if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) { if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) {
plain_username.assign ((const char *) optval_, optvallen_); plain_username.assign ((const char *) optval_, optvallen_);
plain_server = false; as_server = 0;
mechanism = ZMQ_PLAIN; mechanism = ZMQ_PLAIN;
return 0; return 0;
} }
...@@ -279,7 +279,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_, ...@@ -279,7 +279,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
else else
if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) { if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) {
plain_password.assign ((const char *) optval_, optvallen_); plain_password.assign ((const char *) optval_, optvallen_);
plain_server = false; as_server = 0;
mechanism = ZMQ_PLAIN; mechanism = ZMQ_PLAIN;
return 0; return 0;
} }
...@@ -485,7 +485,7 @@ int zmq::options_t::getsockopt (int option_, void *optval_, size_t *optvallen_) ...@@ -485,7 +485,7 @@ int zmq::options_t::getsockopt (int option_, void *optval_, size_t *optvallen_)
case ZMQ_PLAIN_SERVER: case ZMQ_PLAIN_SERVER:
if (is_int) { if (is_int) {
*value = plain_server; *value = as_server && mechanism == ZMQ_PLAIN;
return 0; return 0;
} }
break; break;
......
...@@ -125,11 +125,12 @@ namespace zmq ...@@ -125,11 +125,12 @@ namespace zmq
// Security mechanism for all connections on this socket // Security mechanism for all connections on this socket
int mechanism; int mechanism;
// If peer is acting as server for PLAIN or CURVE mechanisms
int as_server;
// Security credentials for PLAIN mechanism // Security credentials for PLAIN mechanism
std::string plain_username; std::string plain_username;
std::string plain_password; std::string plain_password;
int plain_server;
// ID of the socket. // ID of the socket.
int socket_id; int socket_id;
......
...@@ -30,10 +30,9 @@ ...@@ -30,10 +30,9 @@
#include "plain_mechanism.hpp" #include "plain_mechanism.hpp"
#include "wire.hpp" #include "wire.hpp"
zmq::plain_mechanism_t::plain_mechanism_t (const options_t &options_, zmq::plain_mechanism_t::plain_mechanism_t (const options_t &options_) :
bool as_server_) :
mechanism_t (options_), mechanism_t (options_),
state (as_server_? waiting_for_hello: sending_hello) state (options.as_server? waiting_for_hello: sending_hello)
{ {
} }
...@@ -46,31 +45,30 @@ int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_) ...@@ -46,31 +45,30 @@ int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_)
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case sending_hello: case sending_hello:
rc = hello_command (msg_); rc = hello_command (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_welcome; state = waiting_for_welcome;
break; break;
case sending_welcome: case sending_welcome:
rc = welcome_command (msg_); rc = welcome_command (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_initiate; state = waiting_for_initiate;
break; break;
case sending_initiate: case sending_initiate:
rc = initiate_command (msg_); rc = initiate_command (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_ready; state = waiting_for_ready;
break; break;
case sending_ready: case sending_ready:
rc = ready_command (msg_); rc = ready_command (msg_);
if (rc == 0) if (rc == 0)
state = ready; state = ready;
break; break;
default: default:
errno = EAGAIN; errno = EAGAIN;
rc = -1; rc = -1;
} }
return rc; return rc;
} }
...@@ -79,46 +77,46 @@ int zmq::plain_mechanism_t::process_handshake_message (msg_t *msg_) ...@@ -79,46 +77,46 @@ int zmq::plain_mechanism_t::process_handshake_message (msg_t *msg_)
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case waiting_for_hello: case waiting_for_hello:
rc = process_hello_command (msg_); rc = process_hello_command (msg_);
if (rc == 0) if (rc == 0)
state = sending_welcome; state = sending_welcome;
break; break;
case waiting_for_welcome: case waiting_for_welcome:
rc = process_welcome_command (msg_); rc = process_welcome_command (msg_);
if (rc == 0) if (rc == 0)
state = sending_initiate; state = sending_initiate;
break; break;
case waiting_for_initiate: case waiting_for_initiate:
rc = process_initiate_command (msg_); rc = process_initiate_command (msg_);
if (rc == 0) if (rc == 0)
state = sending_ready; state = sending_ready;
break; break;
case waiting_for_ready: case waiting_for_ready:
rc = process_ready_command (msg_); rc = process_ready_command (msg_);
if (rc == 0) if (rc == 0)
state = ready; state = ready;
break; break;
default: default:
errno = EAGAIN; errno = EAGAIN;
rc = -1; rc = -1;
} }
if (rc == 0) { if (rc == 0) {
rc = msg_->close (); rc = msg_->close ();
errno_assert (rc == 0); errno_assert (rc == 0);
rc = msg_->init (); rc = msg_->init ();
errno_assert (rc == 0); errno_assert (rc == 0);
} }
return 0; return 0;
} }
bool zmq::plain_mechanism_t::is_handshake_complete () const bool zmq::plain_mechanism_t::is_handshake_complete () const
{ {
return state == ready; return state == ready;
} }
int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
{ {
const std::string username = options.plain_username; const std::string username = options.plain_username;
...@@ -147,6 +145,7 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const ...@@ -147,6 +145,7 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
return 0; return 0;
} }
int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
{ {
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
...@@ -164,7 +163,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -164,7 +163,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
size_t username_length = static_cast <size_t> (*ptr++); size_t username_length = static_cast <size_t> (*ptr++);
bytes_left -= 1; bytes_left -= 1;
...@@ -172,7 +170,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -172,7 +170,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
const std::string username = std::string ((char *) ptr, username_length); const std::string username = std::string ((char *) ptr, username_length);
ptr += username_length; ptr += username_length;
bytes_left -= username_length; bytes_left -= username_length;
...@@ -181,7 +178,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -181,7 +178,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
size_t password_length = static_cast <size_t> (*ptr++); size_t password_length = static_cast <size_t> (*ptr++);
bytes_left -= 1; bytes_left -= 1;
...@@ -189,7 +185,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -189,7 +185,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
const std::string password = std::string ((char *) ptr, password_length); const std::string password = std::string ((char *) ptr, password_length);
ptr += password_length; ptr += password_length;
bytes_left -= password_length; bytes_left -= password_length;
...@@ -198,9 +193,8 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -198,9 +193,8 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
// TODO: Add user authentication // TODO: Add user authentication
return 0; return 0;
} }
...@@ -221,7 +215,6 @@ int zmq::plain_mechanism_t::process_welcome_command (msg_t *msg_) ...@@ -221,7 +215,6 @@ int zmq::plain_mechanism_t::process_welcome_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return 0; return 0;
} }
...@@ -266,7 +259,6 @@ int zmq::plain_mechanism_t::process_initiate_command (msg_t *msg_) ...@@ -266,7 +259,6 @@ int zmq::plain_mechanism_t::process_initiate_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return parse_property_list (ptr + 8, bytes_left - 8); return parse_property_list (ptr + 8, bytes_left - 8);
} }
...@@ -311,7 +303,6 @@ int zmq::plain_mechanism_t::process_ready_command (msg_t *msg_) ...@@ -311,7 +303,6 @@ int zmq::plain_mechanism_t::process_ready_command (msg_t *msg_)
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return parse_property_list (ptr + 8, bytes_left - 8); return parse_property_list (ptr + 8, bytes_left - 8);
} }
...@@ -322,21 +313,21 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr, ...@@ -322,21 +313,21 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr,
const size_t name_length = static_cast <size_t> (*ptr); const size_t name_length = static_cast <size_t> (*ptr);
ptr += 1; ptr += 1;
bytes_left -= 1; bytes_left -= 1;
if (bytes_left < name_length) if (bytes_left < name_length)
break; break;
const std::string name = std::string((const char *) ptr, name_length);
const std::string name = std::string ((const char *) ptr, name_length);
ptr += name_length; ptr += name_length;
bytes_left -= name_length; bytes_left -= name_length;
if (bytes_left < 4) if (bytes_left < 4)
break; break;
const size_t value_length = static_cast <size_t> (get_uint32 (ptr)); const size_t value_length = static_cast <size_t> (get_uint32 (ptr));
ptr += 4; ptr += 4;
bytes_left -= 4; bytes_left -= 4;
if (bytes_left < value_length) if (bytes_left < value_length)
break; break;
const unsigned char * const value = ptr; const unsigned char * const value = ptr;
ptr += value_length; ptr += value_length;
bytes_left -= value_length; bytes_left -= value_length;
...@@ -348,11 +339,9 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr, ...@@ -348,11 +339,9 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr,
if (name == "Identity" && options.recv_identity) if (name == "Identity" && options.recv_identity)
set_peer_identity (value, value_length); set_peer_identity (value, value_length);
} }
if (bytes_left > 0) { if (bytes_left > 0) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return 0; return 0;
} }
...@@ -32,7 +32,7 @@ namespace zmq ...@@ -32,7 +32,7 @@ namespace zmq
{ {
public: public:
plain_mechanism_t (const options_t &options_, bool as_server_); plain_mechanism_t (const options_t &options_);
virtual ~plain_mechanism_t (); virtual ~plain_mechanism_t ();
// mechanism implementation // mechanism implementation
......
...@@ -50,10 +50,9 @@ ...@@ -50,10 +50,9 @@
#include "likely.hpp" #include "likely.hpp"
#include "wire.hpp" #include "wire.hpp"
zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_, zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_,
bool as_server_, const std::string &endpoint_) : const std::string &endpoint_) :
s (fd_), s (fd_),
as_server (as_server_),
inpos (NULL), inpos (NULL),
insize (0), insize (0),
decoder (NULL), decoder (NULL),
...@@ -80,7 +79,7 @@ zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_, ...@@ -80,7 +79,7 @@ zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_,
{ {
int rc = tx_msg.init (); int rc = tx_msg.init ();
errno_assert (rc == 0); errno_assert (rc == 0);
// Put the socket into non-blocking mode. // Put the socket into non-blocking mode.
unblock_socket (s); unblock_socket (s);
// Set the socket buffer limits for the underlying socket. // Set the socket buffer limits for the underlying socket.
...@@ -532,8 +531,7 @@ bool zmq::stream_engine_t::handshake () ...@@ -532,8 +531,7 @@ bool zmq::stream_engine_t::handshake ()
} }
else else
if (memcmp (greeting_recv + 12, "PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) { if (memcmp (greeting_recv + 12, "PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) {
mechanism = new (std::nothrow) mechanism = new (std::nothrow) plain_mechanism_t (options);
plain_mechanism_t (options, options.plain_server);
alloc_assert (mechanism); alloc_assert (mechanism);
} }
else { else {
......
...@@ -52,8 +52,8 @@ namespace zmq ...@@ -52,8 +52,8 @@ namespace zmq
{ {
public: public:
stream_engine_t (fd_t fd_, const options_t &options_, stream_engine_t (fd_t fd_, const options_t &options_,
bool as_server_, const std::string &endpoint); const std::string &endpoint);
~stream_engine_t (); ~stream_engine_t ();
// i_engine interface implementation. // i_engine interface implementation.
......
...@@ -126,7 +126,7 @@ void zmq::tcp_connecter_t::out_event () ...@@ -126,7 +126,7 @@ void zmq::tcp_connecter_t::out_event ()
// Create the engine object for this connection. // Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow) stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, false, endpoint); stream_engine_t (fd, options, endpoint);
alloc_assert (engine); alloc_assert (engine);
// Attach the engine to the corresponding session object. // Attach the engine to the corresponding session object.
......
...@@ -92,7 +92,7 @@ void zmq::tcp_listener_t::in_event () ...@@ -92,7 +92,7 @@ void zmq::tcp_listener_t::in_event ()
// Create the engine object for this connection. // Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow) stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, true, endpoint); stream_engine_t (fd, options, endpoint);
alloc_assert (engine); alloc_assert (engine);
// Choose I/O thread to run connecter in. Given that we are already // Choose I/O thread to run connecter in. Given that we are already
......
...@@ -17,11 +17,7 @@ ...@@ -17,11 +17,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#include "../include/zmq.h" #include "testutil.hpp"
#include <string.h>
#include <stdbool.h>
#undef NDEBUG
#include <assert.h>
int main (void) int main (void)
{ {
...@@ -40,18 +36,43 @@ int main (void) ...@@ -40,18 +36,43 @@ int main (void)
int rc; int rc;
size_t optsize; size_t optsize;
int mechanism; int mechanism;
int as_server;
optsize = sizeof (int); optsize = sizeof (int);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize); rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0); assert (rc == 0);
assert (mechanism == ZMQ_NULL); assert (mechanism == ZMQ_NULL);
optsize = sizeof (int);
rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize); rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0); assert (rc == 0);
assert (mechanism == ZMQ_NULL); assert (mechanism == ZMQ_NULL);
rc = zmq_getsockopt (client, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
assert (as_server == 0);
rc = zmq_getsockopt (server, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
assert (as_server == 0);
rc = zmq_bind (server, "tcp://*:9999");
assert (rc == 0);
rc = zmq_connect (client, "tcp://localhost:9999");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client);
assert (rc == 0);
rc = zmq_close (server);
assert (rc == 0);
// Check PLAIN security // Check PLAIN security
server = zmq_socket (ctx, ZMQ_DEALER);
assert (server);
client = zmq_socket (ctx, ZMQ_DEALER);
assert (client);
char username [256]; char username [256];
optsize = 256; optsize = 256;
rc = zmq_getsockopt (client, ZMQ_PLAIN_USERNAME, username, &optsize); rc = zmq_getsockopt (client, ZMQ_PLAIN_USERNAME, username, &optsize);
...@@ -80,36 +101,63 @@ int main (void) ...@@ -80,36 +101,63 @@ int main (void)
assert (rc == 0); assert (rc == 0);
assert (optsize == 8 + 1); assert (optsize == 8 + 1);
as_server = 1;
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
optsize = sizeof (int); optsize = sizeof (int);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize); rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0); assert (rc == 0);
assert (mechanism == ZMQ_PLAIN); assert (mechanism == ZMQ_PLAIN);
int as_server = 1;
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
optsize = sizeof (int);
rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize); rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0); assert (rc == 0);
assert (mechanism == ZMQ_PLAIN); assert (mechanism == ZMQ_PLAIN);
// Check we can switch back to NULL security rc = zmq_getsockopt (client, ZMQ_PLAIN_SERVER, &as_server, &optsize);
rc = zmq_setsockopt (client, ZMQ_PLAIN_USERNAME, NULL, 0);
assert (rc == 0); assert (rc == 0);
rc = zmq_setsockopt (client, ZMQ_PLAIN_PASSWORD, NULL, 0); assert (as_server == 0);
rc = zmq_getsockopt (server, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0); assert (rc == 0);
optsize = sizeof (int); assert (as_server == 1);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
rc = zmq_bind (server, "tcp://*:9998");
assert (rc == 0); assert (rc == 0);
assert (mechanism == ZMQ_NULL); rc = zmq_connect (client, "tcp://localhost:9998");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client); rc = zmq_close (client);
assert (rc == 0); assert (rc == 0);
rc = zmq_close (server);
assert (rc == 0);
// Check PLAIN security -- two servers trying to talk to each other
server = zmq_socket (ctx, ZMQ_DEALER);
assert (server);
client = zmq_socket (ctx, ZMQ_DEALER);
assert (client);
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
rc = zmq_setsockopt (client, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
rc = zmq_bind (server, "tcp://*:9997");
assert (rc == 0);
rc = zmq_connect (client, "tcp://localhost:9997");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client);
assert (rc == 0);
rc = zmq_close (server); rc = zmq_close (server);
assert (rc == 0); assert (rc == 0);
// Shutdown
rc = zmq_ctx_term (ctx); rc = zmq_ctx_term (ctx);
assert (rc == 0); assert (rc == 0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment