Commit f909b9c7 authored by Pieter Hintjens's avatar Pieter Hintjens

plain_mechanism now uses options.as_server

- we need to switch to PLAIN according to options.mechanism
- we need to catch case when both peers are as-server (or neither is)
- and to use username/password from options, for client
parent da1e9a17
......@@ -113,7 +113,7 @@ void zmq::ipc_connecter_t::out_event ()
}
// Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, false, endpoint);
stream_engine_t (fd, options, endpoint);
alloc_assert (engine);
// Attach the engine to the corresponding session object.
......
......@@ -81,7 +81,7 @@ void zmq::ipc_listener_t::in_event ()
// Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, true, endpoint);
stream_engine_t (fd, options, endpoint);
alloc_assert (engine);
// Choose I/O thread to run connecter in. Given that we are already
......
......@@ -52,7 +52,7 @@ zmq::options_t::options_t () :
tcp_keepalive_idle (-1),
tcp_keepalive_intvl (-1),
mechanism (ZMQ_NULL),
plain_server (0),
as_server (0),
socket_id (0)
{
}
......@@ -251,7 +251,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
case ZMQ_PLAIN_SERVER:
if (is_int && (value == 0 || value == 1)) {
plain_server = value;
as_server = value;
mechanism = value? ZMQ_PLAIN: ZMQ_NULL;
return 0;
}
......@@ -265,7 +265,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
else
if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) {
plain_username.assign ((const char *) optval_, optvallen_);
plain_server = false;
as_server = 0;
mechanism = ZMQ_PLAIN;
return 0;
}
......@@ -279,7 +279,7 @@ int zmq::options_t::setsockopt (int option_, const void *optval_,
else
if (optvallen_ >= 0 && optvallen_ < 256 && optval_ != NULL) {
plain_password.assign ((const char *) optval_, optvallen_);
plain_server = false;
as_server = 0;
mechanism = ZMQ_PLAIN;
return 0;
}
......@@ -485,7 +485,7 @@ int zmq::options_t::getsockopt (int option_, void *optval_, size_t *optvallen_)
case ZMQ_PLAIN_SERVER:
if (is_int) {
*value = plain_server;
*value = as_server && mechanism == ZMQ_PLAIN;
return 0;
}
break;
......
......@@ -125,11 +125,12 @@ namespace zmq
// Security mechanism for all connections on this socket
int mechanism;
// If peer is acting as server for PLAIN or CURVE mechanisms
int as_server;
// Security credentials for PLAIN mechanism
std::string plain_username;
std::string plain_password;
int plain_server;
// ID of the socket.
int socket_id;
......
......@@ -30,10 +30,9 @@
#include "plain_mechanism.hpp"
#include "wire.hpp"
zmq::plain_mechanism_t::plain_mechanism_t (const options_t &options_,
bool as_server_) :
zmq::plain_mechanism_t::plain_mechanism_t (const options_t &options_) :
mechanism_t (options_),
state (as_server_? waiting_for_hello: sending_hello)
state (options.as_server? waiting_for_hello: sending_hello)
{
}
......@@ -46,31 +45,30 @@ int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_)
int rc = 0;
switch (state) {
case sending_hello:
rc = hello_command (msg_);
if (rc == 0)
state = waiting_for_welcome;
break;
case sending_welcome:
rc = welcome_command (msg_);
if (rc == 0)
state = waiting_for_initiate;
break;
case sending_initiate:
rc = initiate_command (msg_);
if (rc == 0)
state = waiting_for_ready;
break;
case sending_ready:
rc = ready_command (msg_);
if (rc == 0)
state = ready;
break;
default:
errno = EAGAIN;
rc = -1;
case sending_hello:
rc = hello_command (msg_);
if (rc == 0)
state = waiting_for_welcome;
break;
case sending_welcome:
rc = welcome_command (msg_);
if (rc == 0)
state = waiting_for_initiate;
break;
case sending_initiate:
rc = initiate_command (msg_);
if (rc == 0)
state = waiting_for_ready;
break;
case sending_ready:
rc = ready_command (msg_);
if (rc == 0)
state = ready;
break;
default:
errno = EAGAIN;
rc = -1;
}
return rc;
}
......@@ -79,46 +77,46 @@ int zmq::plain_mechanism_t::process_handshake_message (msg_t *msg_)
int rc = 0;
switch (state) {
case waiting_for_hello:
rc = process_hello_command (msg_);
if (rc == 0)
state = sending_welcome;
break;
case waiting_for_welcome:
rc = process_welcome_command (msg_);
if (rc == 0)
state = sending_initiate;
break;
case waiting_for_initiate:
rc = process_initiate_command (msg_);
if (rc == 0)
state = sending_ready;
break;
case waiting_for_ready:
rc = process_ready_command (msg_);
if (rc == 0)
state = ready;
break;
default:
errno = EAGAIN;
rc = -1;
case waiting_for_hello:
rc = process_hello_command (msg_);
if (rc == 0)
state = sending_welcome;
break;
case waiting_for_welcome:
rc = process_welcome_command (msg_);
if (rc == 0)
state = sending_initiate;
break;
case waiting_for_initiate:
rc = process_initiate_command (msg_);
if (rc == 0)
state = sending_ready;
break;
case waiting_for_ready:
rc = process_ready_command (msg_);
if (rc == 0)
state = ready;
break;
default:
errno = EAGAIN;
rc = -1;
}
if (rc == 0) {
rc = msg_->close ();
errno_assert (rc == 0);
rc = msg_->init ();
errno_assert (rc == 0);
}
return 0;
}
bool zmq::plain_mechanism_t::is_handshake_complete () const
{
return state == ready;
}
int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
{
const std::string username = options.plain_username;
......@@ -147,6 +145,7 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
return 0;
}
int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
{
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
......@@ -164,7 +163,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
size_t username_length = static_cast <size_t> (*ptr++);
bytes_left -= 1;
......@@ -172,7 +170,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
const std::string username = std::string ((char *) ptr, username_length);
ptr += username_length;
bytes_left -= username_length;
......@@ -181,7 +178,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
size_t password_length = static_cast <size_t> (*ptr++);
bytes_left -= 1;
......@@ -189,7 +185,6 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
const std::string password = std::string ((char *) ptr, password_length);
ptr += password_length;
bytes_left -= password_length;
......@@ -198,9 +193,8 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
// TODO: Add user authentication
return 0;
}
......@@ -221,7 +215,6 @@ int zmq::plain_mechanism_t::process_welcome_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
return 0;
}
......@@ -266,7 +259,6 @@ int zmq::plain_mechanism_t::process_initiate_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
return parse_property_list (ptr + 8, bytes_left - 8);
}
......@@ -311,7 +303,6 @@ int zmq::plain_mechanism_t::process_ready_command (msg_t *msg_)
errno = EPROTO;
return -1;
}
return parse_property_list (ptr + 8, bytes_left - 8);
}
......@@ -322,21 +313,21 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr,
const size_t name_length = static_cast <size_t> (*ptr);
ptr += 1;
bytes_left -= 1;
if (bytes_left < name_length)
break;
const std::string name = std::string((const char *) ptr, name_length);
const std::string name = std::string ((const char *) ptr, name_length);
ptr += name_length;
bytes_left -= name_length;
if (bytes_left < 4)
break;
const size_t value_length = static_cast <size_t> (get_uint32 (ptr));
ptr += 4;
bytes_left -= 4;
if (bytes_left < value_length)
break;
const unsigned char * const value = ptr;
ptr += value_length;
bytes_left -= value_length;
......@@ -348,11 +339,9 @@ int zmq::plain_mechanism_t::parse_property_list (const unsigned char *ptr,
if (name == "Identity" && options.recv_identity)
set_peer_identity (value, value_length);
}
if (bytes_left > 0) {
errno = EPROTO;
return -1;
}
return 0;
}
......@@ -32,7 +32,7 @@ namespace zmq
{
public:
plain_mechanism_t (const options_t &options_, bool as_server_);
plain_mechanism_t (const options_t &options_);
virtual ~plain_mechanism_t ();
// mechanism implementation
......
......@@ -50,10 +50,9 @@
#include "likely.hpp"
#include "wire.hpp"
zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_,
bool as_server_, const std::string &endpoint_) :
zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_,
const std::string &endpoint_) :
s (fd_),
as_server (as_server_),
inpos (NULL),
insize (0),
decoder (NULL),
......@@ -80,7 +79,7 @@ zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_,
{
int rc = tx_msg.init ();
errno_assert (rc == 0);
// Put the socket into non-blocking mode.
unblock_socket (s);
// Set the socket buffer limits for the underlying socket.
......@@ -532,8 +531,7 @@ bool zmq::stream_engine_t::handshake ()
}
else
if (memcmp (greeting_recv + 12, "PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) {
mechanism = new (std::nothrow)
plain_mechanism_t (options, options.plain_server);
mechanism = new (std::nothrow) plain_mechanism_t (options);
alloc_assert (mechanism);
}
else {
......
......@@ -52,8 +52,8 @@ namespace zmq
{
public:
stream_engine_t (fd_t fd_, const options_t &options_,
bool as_server_, const std::string &endpoint);
stream_engine_t (fd_t fd_, const options_t &options_,
const std::string &endpoint);
~stream_engine_t ();
// i_engine interface implementation.
......
......@@ -126,7 +126,7 @@ void zmq::tcp_connecter_t::out_event ()
// Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, false, endpoint);
stream_engine_t (fd, options, endpoint);
alloc_assert (engine);
// Attach the engine to the corresponding session object.
......
......@@ -92,7 +92,7 @@ void zmq::tcp_listener_t::in_event ()
// Create the engine object for this connection.
stream_engine_t *engine = new (std::nothrow)
stream_engine_t (fd, options, true, endpoint);
stream_engine_t (fd, options, endpoint);
alloc_assert (engine);
// Choose I/O thread to run connecter in. Given that we are already
......
......@@ -17,11 +17,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "../include/zmq.h"
#include <string.h>
#include <stdbool.h>
#undef NDEBUG
#include <assert.h>
#include "testutil.hpp"
int main (void)
{
......@@ -40,18 +36,43 @@ int main (void)
int rc;
size_t optsize;
int mechanism;
int as_server;
optsize = sizeof (int);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0);
assert (mechanism == ZMQ_NULL);
optsize = sizeof (int);
rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0);
assert (mechanism == ZMQ_NULL);
assert (mechanism == ZMQ_NULL);
rc = zmq_getsockopt (client, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
assert (as_server == 0);
rc = zmq_getsockopt (server, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
assert (as_server == 0);
rc = zmq_bind (server, "tcp://*:9999");
assert (rc == 0);
rc = zmq_connect (client, "tcp://localhost:9999");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client);
assert (rc == 0);
rc = zmq_close (server);
assert (rc == 0);
// Check PLAIN security
server = zmq_socket (ctx, ZMQ_DEALER);
assert (server);
client = zmq_socket (ctx, ZMQ_DEALER);
assert (client);
char username [256];
optsize = 256;
rc = zmq_getsockopt (client, ZMQ_PLAIN_USERNAME, username, &optsize);
......@@ -80,36 +101,63 @@ int main (void)
assert (rc == 0);
assert (optsize == 8 + 1);
as_server = 1;
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
optsize = sizeof (int);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0);
assert (mechanism == ZMQ_PLAIN);
int as_server = 1;
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
optsize = sizeof (int);
rc = zmq_getsockopt (server, ZMQ_MECHANISM, &mechanism, &optsize);
assert (rc == 0);
assert (mechanism == ZMQ_PLAIN);
// Check we can switch back to NULL security
rc = zmq_setsockopt (client, ZMQ_PLAIN_USERNAME, NULL, 0);
rc = zmq_getsockopt (client, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
rc = zmq_setsockopt (client, ZMQ_PLAIN_PASSWORD, NULL, 0);
assert (as_server == 0);
rc = zmq_getsockopt (server, ZMQ_PLAIN_SERVER, &as_server, &optsize);
assert (rc == 0);
optsize = sizeof (int);
rc = zmq_getsockopt (client, ZMQ_MECHANISM, &mechanism, &optsize);
assert (as_server == 1);
rc = zmq_bind (server, "tcp://*:9998");
assert (rc == 0);
assert (mechanism == ZMQ_NULL);
rc = zmq_connect (client, "tcp://localhost:9998");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client);
assert (rc == 0);
rc = zmq_close (server);
assert (rc == 0);
// Check PLAIN security -- two servers trying to talk to each other
server = zmq_socket (ctx, ZMQ_DEALER);
assert (server);
client = zmq_socket (ctx, ZMQ_DEALER);
assert (client);
rc = zmq_setsockopt (server, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
rc = zmq_setsockopt (client, ZMQ_PLAIN_SERVER, &as_server, sizeof (int));
assert (rc == 0);
rc = zmq_bind (server, "tcp://*:9997");
assert (rc == 0);
rc = zmq_connect (client, "tcp://localhost:9997");
assert (rc == 0);
bounce (server, client);
rc = zmq_close (client);
assert (rc == 0);
rc = zmq_close (server);
assert (rc == 0);
// Shutdown
rc = zmq_ctx_term (ctx);
assert (rc == 0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment