Commit 28b0a5fa authored by Pieter Hintjens's avatar Pieter Hintjens

Updated libzmq to match RFC 23, 24, 25, 26

* Command names changed from null terminated to length-specified
* Command frames use the correct flag (bit 2)
* test_stream acts as test case for command frames
* Some code cleanups
parent 1844a27c
...@@ -50,18 +50,18 @@ zmq::curve_client_t::~curve_client_t () ...@@ -50,18 +50,18 @@ zmq::curve_client_t::~curve_client_t ()
{ {
} }
int zmq::curve_client_t::next_handshake_message (msg_t *msg_) int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case send_hello: case send_hello:
rc = hello_msg (msg_); rc = produce_hello (msg_);
if (rc == 0) if (rc == 0)
state = expect_welcome; state = expect_welcome;
break; break;
case send_initiate: case send_initiate:
rc = initiate_msg (msg_); rc = produce_initiate (msg_);
if (rc == 0) if (rc == 0)
state = expect_ready; state = expect_ready;
break; break;
...@@ -72,7 +72,7 @@ int zmq::curve_client_t::next_handshake_message (msg_t *msg_) ...@@ -72,7 +72,7 @@ int zmq::curve_client_t::next_handshake_message (msg_t *msg_)
return rc; return rc;
} }
int zmq::curve_client_t::process_handshake_message (msg_t *msg_) int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
...@@ -138,7 +138,7 @@ int zmq::curve_client_t::encode (msg_t *msg_) ...@@ -138,7 +138,7 @@ int zmq::curve_client_t::encode (msg_t *msg_)
uint8_t *message = static_cast <uint8_t *> (msg_->data ()); uint8_t *message = static_cast <uint8_t *> (msg_->data ());
memcpy (message, "MESSAGE\0", 8); memcpy (message, "\x07MESSAGE", 8);
memcpy (message + 8, &cn_nonce, 8); memcpy (message + 8, &cn_nonce, 8);
memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES, memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES,
mlen - crypto_box_BOXZEROBYTES); mlen - crypto_box_BOXZEROBYTES);
...@@ -161,7 +161,7 @@ int zmq::curve_client_t::decode (msg_t *msg_) ...@@ -161,7 +161,7 @@ int zmq::curve_client_t::decode (msg_t *msg_)
} }
const uint8_t *message = static_cast <uint8_t *> (msg_->data ()); const uint8_t *message = static_cast <uint8_t *> (msg_->data ());
if (memcmp (message, "MESSAGE\0", 8)) { if (memcmp (message, "\x07MESSAGE", 8)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
...@@ -213,7 +213,7 @@ bool zmq::curve_client_t::is_handshake_complete () const ...@@ -213,7 +213,7 @@ bool zmq::curve_client_t::is_handshake_complete () const
return state == connected; return state == connected;
} }
int zmq::curve_client_t::hello_msg (msg_t *msg_) int zmq::curve_client_t::produce_hello (msg_t *msg_)
{ {
uint8_t hello_nonce [crypto_box_NONCEBYTES]; uint8_t hello_nonce [crypto_box_NONCEBYTES];
uint8_t hello_plaintext [crypto_box_ZEROBYTES + 64]; uint8_t hello_plaintext [crypto_box_ZEROBYTES + 64];
...@@ -235,7 +235,7 @@ int zmq::curve_client_t::hello_msg (msg_t *msg_) ...@@ -235,7 +235,7 @@ int zmq::curve_client_t::hello_msg (msg_t *msg_)
errno_assert (rc == 0); errno_assert (rc == 0);
uint8_t *hello = static_cast <uint8_t *> (msg_->data ()); uint8_t *hello = static_cast <uint8_t *> (msg_->data ());
memcpy (hello, "HELLO\0", 6); memcpy (hello, "\x05HELLO", 6);
// CurveZMQ major and minor version numbers // CurveZMQ major and minor version numbers
memcpy (hello + 6, "\1\0", 2); memcpy (hello + 6, "\1\0", 2);
// Anti-amplification padding // Anti-amplification padding
...@@ -260,7 +260,7 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_) ...@@ -260,7 +260,7 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_)
} }
const uint8_t * welcome = static_cast <uint8_t *> (msg_->data ()); const uint8_t * welcome = static_cast <uint8_t *> (msg_->data ());
if (memcmp (welcome, "WELCOME\0", 8)) { if (memcmp (welcome, "\x07WELCOME", 8)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
...@@ -294,7 +294,7 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_) ...@@ -294,7 +294,7 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_)
return 0; return 0;
} }
int zmq::curve_client_t::initiate_msg (msg_t *msg_) int zmq::curve_client_t::produce_initiate (msg_t *msg_)
{ {
uint8_t vouch_nonce [crypto_box_NONCEBYTES]; uint8_t vouch_nonce [crypto_box_NONCEBYTES];
uint8_t vouch_plaintext [crypto_box_ZEROBYTES + 32]; uint8_t vouch_plaintext [crypto_box_ZEROBYTES + 32];
...@@ -351,7 +351,7 @@ int zmq::curve_client_t::initiate_msg (msg_t *msg_) ...@@ -351,7 +351,7 @@ int zmq::curve_client_t::initiate_msg (msg_t *msg_)
uint8_t *initiate = static_cast <uint8_t *> (msg_->data ()); uint8_t *initiate = static_cast <uint8_t *> (msg_->data ());
memcpy (initiate, "INITIATE\0", 9); memcpy (initiate, "\x08INITIATE", 9);
// Cookie provided by the server in the WELCOME command // Cookie provided by the server in the WELCOME command
memcpy (initiate + 9, cn_cookie, 96); memcpy (initiate + 9, cn_cookie, 96);
// Short nonce, prefixed by "CurveZMQINITIATE" // Short nonce, prefixed by "CurveZMQINITIATE"
...@@ -373,7 +373,7 @@ int zmq::curve_client_t::process_ready (msg_t *msg_) ...@@ -373,7 +373,7 @@ int zmq::curve_client_t::process_ready (msg_t *msg_)
} }
const uint8_t *ready = static_cast <uint8_t *> (msg_->data ()); const uint8_t *ready = static_cast <uint8_t *> (msg_->data ());
if (memcmp (ready, "READY\0", 6)) { if (memcmp (ready, "\x05READY", 6)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
......
...@@ -50,8 +50,8 @@ namespace zmq ...@@ -50,8 +50,8 @@ namespace zmq
virtual ~curve_client_t (); virtual ~curve_client_t ();
// mechanism implementation // mechanism implementation
virtual int next_handshake_message (msg_t *msg_); virtual int next_handshake_command (msg_t *msg_);
virtual int process_handshake_message (msg_t *msg_); virtual int process_handshake_command (msg_t *msg_);
virtual int encode (msg_t *msg_); virtual int encode (msg_t *msg_);
virtual int decode (msg_t *msg_); virtual int decode (msg_t *msg_);
virtual bool is_handshake_complete () const; virtual bool is_handshake_complete () const;
...@@ -96,9 +96,9 @@ namespace zmq ...@@ -96,9 +96,9 @@ namespace zmq
// Nonce // Nonce
uint64_t cn_nonce; uint64_t cn_nonce;
int hello_msg (msg_t *msg_); int produce_hello (msg_t *msg_);
int process_welcome (msg_t *msg_); int process_welcome (msg_t *msg_);
int initiate_msg (msg_t *msg_); int produce_initiate (msg_t *msg_);
int process_ready (msg_t *msg_); int process_ready (msg_t *msg_);
}; };
......
...@@ -54,18 +54,18 @@ zmq::curve_server_t::~curve_server_t () ...@@ -54,18 +54,18 @@ zmq::curve_server_t::~curve_server_t ()
{ {
} }
int zmq::curve_server_t::next_handshake_message (msg_t *msg_) int zmq::curve_server_t::next_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case send_welcome: case send_welcome:
rc = welcome_msg (msg_); rc = produce_welcome (msg_);
if (rc == 0) if (rc == 0)
state = expect_initiate; state = expect_initiate;
break; break;
case send_ready: case send_ready:
rc = ready_msg (msg_); rc = produce_ready (msg_);
if (rc == 0) if (rc == 0)
state = connected; state = connected;
break; break;
...@@ -77,7 +77,7 @@ int zmq::curve_server_t::next_handshake_message (msg_t *msg_) ...@@ -77,7 +77,7 @@ int zmq::curve_server_t::next_handshake_message (msg_t *msg_)
return rc; return rc;
} }
int zmq::curve_server_t::process_handshake_message (msg_t *msg_) int zmq::curve_server_t::process_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
...@@ -143,7 +143,7 @@ int zmq::curve_server_t::encode (msg_t *msg_) ...@@ -143,7 +143,7 @@ int zmq::curve_server_t::encode (msg_t *msg_)
uint8_t *message = static_cast <uint8_t *> (msg_->data ()); uint8_t *message = static_cast <uint8_t *> (msg_->data ());
memcpy (message, "MESSAGE\0", 8); memcpy (message, "\x07MESSAGE", 8);
memcpy (message + 8, &cn_nonce, 8); memcpy (message + 8, &cn_nonce, 8);
memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES, memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES,
mlen - crypto_box_BOXZEROBYTES); mlen - crypto_box_BOXZEROBYTES);
...@@ -166,7 +166,7 @@ int zmq::curve_server_t::decode (msg_t *msg_) ...@@ -166,7 +166,7 @@ int zmq::curve_server_t::decode (msg_t *msg_)
} }
const uint8_t *message = static_cast <uint8_t *> (msg_->data ()); const uint8_t *message = static_cast <uint8_t *> (msg_->data ());
if (memcmp (message, "MESSAGE\0", 8)) { if (memcmp (message, "\x07MESSAGE", 8)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
...@@ -238,7 +238,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) ...@@ -238,7 +238,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_)
} }
const uint8_t * const hello = static_cast <uint8_t *> (msg_->data ()); const uint8_t * const hello = static_cast <uint8_t *> (msg_->data ());
if (memcmp (hello, "HELLO\0", 6)) { if (memcmp (hello, "\x05HELLO", 6)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
...@@ -276,7 +276,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) ...@@ -276,7 +276,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_)
return rc; return rc;
} }
int zmq::curve_server_t::welcome_msg (msg_t *msg_) int zmq::curve_server_t::produce_welcome (msg_t *msg_)
{ {
uint8_t cookie_nonce [crypto_secretbox_NONCEBYTES]; uint8_t cookie_nonce [crypto_secretbox_NONCEBYTES];
uint8_t cookie_plaintext [crypto_secretbox_ZEROBYTES + 64]; uint8_t cookie_plaintext [crypto_secretbox_ZEROBYTES + 64];
...@@ -329,7 +329,7 @@ int zmq::curve_server_t::welcome_msg (msg_t *msg_) ...@@ -329,7 +329,7 @@ int zmq::curve_server_t::welcome_msg (msg_t *msg_)
errno_assert (rc == 0); errno_assert (rc == 0);
uint8_t * const welcome = static_cast <uint8_t *> (msg_->data ()); uint8_t * const welcome = static_cast <uint8_t *> (msg_->data ());
memcpy (welcome, "WELCOME\0", 8); memcpy (welcome, "\x07WELCOME", 8);
memcpy (welcome + 8, welcome_nonce + 8, 16); memcpy (welcome + 8, welcome_nonce + 8, 16);
memcpy (welcome + 24, welcome_ciphertext + crypto_box_BOXZEROBYTES, 144); memcpy (welcome + 24, welcome_ciphertext + crypto_box_BOXZEROBYTES, 144);
...@@ -344,7 +344,7 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_) ...@@ -344,7 +344,7 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_)
} }
const uint8_t *initiate = static_cast <uint8_t *> (msg_->data ()); const uint8_t *initiate = static_cast <uint8_t *> (msg_->data ());
if (memcmp (initiate, "INITIATE\0", 9)) { if (memcmp (initiate, "\x08INITIATE", 9)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
...@@ -447,7 +447,7 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_) ...@@ -447,7 +447,7 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_)
clen - crypto_box_ZEROBYTES - 96); clen - crypto_box_ZEROBYTES - 96);
} }
int zmq::curve_server_t::ready_msg (msg_t *msg_) int zmq::curve_server_t::produce_ready (msg_t *msg_)
{ {
uint8_t ready_nonce [crypto_box_NONCEBYTES]; uint8_t ready_nonce [crypto_box_NONCEBYTES];
uint8_t ready_plaintext [crypto_box_ZEROBYTES + 256]; uint8_t ready_plaintext [crypto_box_ZEROBYTES + 256];
...@@ -482,7 +482,7 @@ int zmq::curve_server_t::ready_msg (msg_t *msg_) ...@@ -482,7 +482,7 @@ int zmq::curve_server_t::ready_msg (msg_t *msg_)
uint8_t *ready = static_cast <uint8_t *> (msg_->data ()); uint8_t *ready = static_cast <uint8_t *> (msg_->data ());
memcpy (ready, "READY\0", 6); memcpy (ready, "\x05READY", 6);
// Short nonce, prefixed by "CurveZMQREADY---" // Short nonce, prefixed by "CurveZMQREADY---"
memcpy (ready + 6, &cn_nonce, 8); memcpy (ready + 6, &cn_nonce, 8);
// Box [metadata](S'->C') // Box [metadata](S'->C')
......
...@@ -55,8 +55,8 @@ namespace zmq ...@@ -55,8 +55,8 @@ namespace zmq
virtual ~curve_server_t (); virtual ~curve_server_t ();
// mechanism implementation // mechanism implementation
virtual int next_handshake_message (msg_t *msg_); virtual int next_handshake_command (msg_t *msg_);
virtual int process_handshake_message (msg_t *msg_); virtual int process_handshake_command (msg_t *msg_);
virtual int encode (msg_t *msg_); virtual int encode (msg_t *msg_);
virtual int decode (msg_t *msg_); virtual int decode (msg_t *msg_);
virtual int zap_msg_available (); virtual int zap_msg_available ();
...@@ -104,9 +104,9 @@ namespace zmq ...@@ -104,9 +104,9 @@ namespace zmq
uint8_t cn_precom [crypto_box_BEFORENMBYTES]; uint8_t cn_precom [crypto_box_BEFORENMBYTES];
int process_hello (msg_t *msg_); int process_hello (msg_t *msg_);
int welcome_msg (msg_t *msg_); int produce_welcome (msg_t *msg_);
int process_initiate (msg_t *msg_); int process_initiate (msg_t *msg_);
int ready_msg (msg_t *msg_); int produce_ready (msg_t *msg_);
void send_zap_request (const uint8_t *key); void send_zap_request (const uint8_t *key);
int receive_and_process_zap_reply (); int receive_and_process_zap_reply ();
......
...@@ -40,11 +40,11 @@ namespace zmq ...@@ -40,11 +40,11 @@ namespace zmq
virtual ~mechanism_t (); virtual ~mechanism_t ();
// Prepare next handshake message that is to be sent to the peer. // Prepare next handshake command that is to be sent to the peer.
virtual int next_handshake_message (msg_t *msg_) = 0; virtual int next_handshake_command (msg_t *msg_) = 0;
// Process the handshake message received from the peer. // Process the handshake command received from the peer.
virtual int process_handshake_message (msg_t *msg_) = 0; virtual int process_handshake_command (msg_t *msg_) = 0;
virtual int encode (msg_t *msg_) { return 0; } virtual int encode (msg_t *msg_) { return 0; }
......
...@@ -47,7 +47,8 @@ namespace zmq ...@@ -47,7 +47,8 @@ namespace zmq
// Message flags. // Message flags.
enum enum
{ {
more = 1, more = 1, // Followed by more parts
command = 2, // Command frame (see ZMTP spec)
identity = 64, identity = 64,
shared = 128 shared = 128
}; };
......
...@@ -53,7 +53,7 @@ zmq::null_mechanism_t::~null_mechanism_t () ...@@ -53,7 +53,7 @@ zmq::null_mechanism_t::~null_mechanism_t ()
{ {
} }
int zmq::null_mechanism_t::next_handshake_message (msg_t *msg_) int zmq::null_mechanism_t::next_handshake_command (msg_t *msg_)
{ {
if (ready_command_sent) { if (ready_command_sent) {
errno = EAGAIN; errno = EAGAIN;
...@@ -78,7 +78,7 @@ int zmq::null_mechanism_t::next_handshake_message (msg_t *msg_) ...@@ -78,7 +78,7 @@ int zmq::null_mechanism_t::next_handshake_message (msg_t *msg_)
unsigned char *ptr = command_buffer; unsigned char *ptr = command_buffer;
// Add mechanism string // Add mechanism string
memcpy (ptr, "READY\0", 6); memcpy (ptr, "\5READY", 6);
ptr += 6; ptr += 6;
// Add socket type property // Add socket type property
...@@ -104,7 +104,7 @@ int zmq::null_mechanism_t::next_handshake_message (msg_t *msg_) ...@@ -104,7 +104,7 @@ int zmq::null_mechanism_t::next_handshake_message (msg_t *msg_)
return 0; return 0;
} }
int zmq::null_mechanism_t::process_handshake_message (msg_t *msg_) int zmq::null_mechanism_t::process_handshake_command (msg_t *msg_)
{ {
if (ready_command_received) { if (ready_command_received) {
errno = EPROTO; errno = EPROTO;
...@@ -115,7 +115,7 @@ int zmq::null_mechanism_t::process_handshake_message (msg_t *msg_) ...@@ -115,7 +115,7 @@ int zmq::null_mechanism_t::process_handshake_message (msg_t *msg_)
static_cast <unsigned char *> (msg_->data ()); static_cast <unsigned char *> (msg_->data ());
size_t bytes_left = msg_->size (); size_t bytes_left = msg_->size ();
if (bytes_left < 6 || memcmp (ptr, "READY\0", 6)) { if (bytes_left < 6 || memcmp (ptr, "\5READY", 6)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
......
...@@ -39,8 +39,8 @@ namespace zmq ...@@ -39,8 +39,8 @@ namespace zmq
virtual ~null_mechanism_t (); virtual ~null_mechanism_t ();
// mechanism implementation // mechanism implementation
virtual int next_handshake_message (msg_t *msg_); virtual int next_handshake_command (msg_t *msg_);
virtual int process_handshake_message (msg_t *msg_); virtual int process_handshake_command (msg_t *msg_);
virtual int zap_msg_available (); virtual int zap_msg_available ();
virtual bool is_handshake_complete () const; virtual bool is_handshake_complete () const;
......
...@@ -46,28 +46,28 @@ zmq::plain_mechanism_t::~plain_mechanism_t () ...@@ -46,28 +46,28 @@ zmq::plain_mechanism_t::~plain_mechanism_t ()
{ {
} }
int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_) int zmq::plain_mechanism_t::next_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case sending_hello: case sending_hello:
rc = hello_command (msg_); rc = produce_hello (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_welcome; state = waiting_for_welcome;
break; break;
case sending_welcome: case sending_welcome:
rc = welcome_command (msg_); rc = produce_welcome (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_initiate; state = waiting_for_initiate;
break; break;
case sending_initiate: case sending_initiate:
rc = initiate_command (msg_); rc = produce_initiate (msg_);
if (rc == 0) if (rc == 0)
state = waiting_for_ready; state = waiting_for_ready;
break; break;
case sending_ready: case sending_ready:
rc = ready_command (msg_); rc = produce_ready (msg_);
if (rc == 0) if (rc == 0)
state = ready; state = ready;
break; break;
...@@ -78,28 +78,28 @@ int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_) ...@@ -78,28 +78,28 @@ int zmq::plain_mechanism_t::next_handshake_message (msg_t *msg_)
return rc; return rc;
} }
int zmq::plain_mechanism_t::process_handshake_message (msg_t *msg_) int zmq::plain_mechanism_t::process_handshake_command (msg_t *msg_)
{ {
int rc = 0; int rc = 0;
switch (state) { switch (state) {
case waiting_for_hello: case waiting_for_hello:
rc = process_hello_command (msg_); rc = process_hello (msg_);
if (rc == 0) if (rc == 0)
state = expecting_zap_reply? waiting_for_zap_reply: sending_welcome; state = expecting_zap_reply? waiting_for_zap_reply: sending_welcome;
break; break;
case waiting_for_welcome: case waiting_for_welcome:
rc = process_welcome_command (msg_); rc = process_welcome (msg_);
if (rc == 0) if (rc == 0)
state = sending_initiate; state = sending_initiate;
break; break;
case waiting_for_initiate: case waiting_for_initiate:
rc = process_initiate_command (msg_); rc = process_initiate (msg_);
if (rc == 0) if (rc == 0)
state = sending_ready; state = sending_ready;
break; break;
case waiting_for_ready: case waiting_for_ready:
rc = process_ready_command (msg_); rc = process_ready (msg_);
if (rc == 0) if (rc == 0)
state = ready; state = ready;
break; break;
...@@ -134,7 +134,7 @@ int zmq::plain_mechanism_t::zap_msg_available () ...@@ -134,7 +134,7 @@ int zmq::plain_mechanism_t::zap_msg_available ()
return rc; return rc;
} }
int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const int zmq::plain_mechanism_t::produce_hello (msg_t *msg_) const
{ {
const std::string username = options.plain_username; const std::string username = options.plain_username;
zmq_assert (username.length () < 256); zmq_assert (username.length () < 256);
...@@ -142,15 +142,15 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const ...@@ -142,15 +142,15 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
const std::string password = options.plain_password; const std::string password = options.plain_password;
zmq_assert (password.length () < 256); zmq_assert (password.length () < 256);
const size_t command_size = 8 + 1 + username.length () const size_t command_size = 6 + 1 + username.length ()
+ 1 + password.length (); + 1 + password.length ();
const int rc = msg_->init_size (command_size); const int rc = msg_->init_size (command_size);
errno_assert (rc == 0); errno_assert (rc == 0);
unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
memcpy (ptr, "HELLO ", 8); memcpy (ptr, "\x05HELLO", 6);
ptr += 8; ptr += 6;
*ptr++ = static_cast <unsigned char> (username.length ()); *ptr++ = static_cast <unsigned char> (username.length ());
memcpy (ptr, username.c_str (), username.length ()); memcpy (ptr, username.c_str (), username.length ());
...@@ -164,17 +164,17 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const ...@@ -164,17 +164,17 @@ int zmq::plain_mechanism_t::hello_command (msg_t *msg_) const
} }
int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) int zmq::plain_mechanism_t::process_hello (msg_t *msg_)
{ {
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
size_t bytes_left = msg_->size (); size_t bytes_left = msg_->size ();
if (bytes_left < 8 || memcmp (ptr, "HELLO ", 8)) { if (bytes_left < 6 || memcmp (ptr, "\x05HELLO", 6)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
ptr += 8; ptr += 6;
bytes_left -= 8; bytes_left -= 6;
if (bytes_left < 1) { if (bytes_left < 1) {
errno = EPROTO; errno = EPROTO;
...@@ -226,27 +226,27 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_) ...@@ -226,27 +226,27 @@ int zmq::plain_mechanism_t::process_hello_command (msg_t *msg_)
return 0; return 0;
} }
int zmq::plain_mechanism_t::welcome_command (msg_t *msg_) const int zmq::plain_mechanism_t::produce_welcome (msg_t *msg_) const
{ {
const int rc = msg_->init_size (8); const int rc = msg_->init_size (8);
errno_assert (rc == 0); errno_assert (rc == 0);
memcpy (msg_->data (), "WELCOME ", 8); memcpy (msg_->data (), "\x07WELCOME", 8);
return 0; return 0;
} }
int zmq::plain_mechanism_t::process_welcome_command (msg_t *msg_) int zmq::plain_mechanism_t::process_welcome (msg_t *msg_)
{ {
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
size_t bytes_left = msg_->size (); size_t bytes_left = msg_->size ();
if (bytes_left != 8 || memcmp (ptr, "WELCOME ", 8)) { if (bytes_left != 8 || memcmp (ptr, "\x07WELCOME", 8)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return 0; return 0;
} }
int zmq::plain_mechanism_t::initiate_command (msg_t *msg_) const int zmq::plain_mechanism_t::produce_initiate (msg_t *msg_) const
{ {
unsigned char * const command_buffer = (unsigned char *) malloc (512); unsigned char * const command_buffer = (unsigned char *) malloc (512);
alloc_assert (command_buffer); alloc_assert (command_buffer);
...@@ -254,8 +254,8 @@ int zmq::plain_mechanism_t::initiate_command (msg_t *msg_) const ...@@ -254,8 +254,8 @@ int zmq::plain_mechanism_t::initiate_command (msg_t *msg_) const
unsigned char *ptr = command_buffer; unsigned char *ptr = command_buffer;
// Add mechanism string // Add mechanism string
memcpy (ptr, "INITIATE", 8); memcpy (ptr, "\x08INITIATE", 9);
ptr += 8; ptr += 9;
// Add socket type property // Add socket type property
const char *socket_type = socket_type_string (options.type); const char *socket_type = socket_type_string (options.type);
...@@ -278,19 +278,21 @@ int zmq::plain_mechanism_t::initiate_command (msg_t *msg_) const ...@@ -278,19 +278,21 @@ int zmq::plain_mechanism_t::initiate_command (msg_t *msg_) const
return 0; return 0;
} }
int zmq::plain_mechanism_t::process_initiate_command (msg_t *msg_) int zmq::plain_mechanism_t::process_initiate (msg_t *msg_)
{ {
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
size_t bytes_left = msg_->size (); size_t bytes_left = msg_->size ();
if (bytes_left < 8 || memcmp (ptr, "INITIATE", 8)) { if (bytes_left < 9 || memcmp (ptr, "\x08INITIATE", 9)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
return parse_metadata (ptr + 8, bytes_left - 8); ptr += 9;
bytes_left -= 9;
return parse_metadata (ptr, bytes_left);
} }
int zmq::plain_mechanism_t::ready_command (msg_t *msg_) const int zmq::plain_mechanism_t::produce_ready (msg_t *msg_) const
{ {
unsigned char * const command_buffer = (unsigned char *) malloc (512); unsigned char * const command_buffer = (unsigned char *) malloc (512);
alloc_assert (command_buffer); alloc_assert (command_buffer);
...@@ -298,7 +300,7 @@ int zmq::plain_mechanism_t::ready_command (msg_t *msg_) const ...@@ -298,7 +300,7 @@ int zmq::plain_mechanism_t::ready_command (msg_t *msg_) const
unsigned char *ptr = command_buffer; unsigned char *ptr = command_buffer;
// Add command name // Add command name
memcpy (ptr, "READY\0", 6); memcpy (ptr, "\x05READY", 6);
ptr += 6; ptr += 6;
// Add socket type property // Add socket type property
...@@ -322,12 +324,12 @@ int zmq::plain_mechanism_t::ready_command (msg_t *msg_) const ...@@ -322,12 +324,12 @@ int zmq::plain_mechanism_t::ready_command (msg_t *msg_) const
return 0; return 0;
} }
int zmq::plain_mechanism_t::process_ready_command (msg_t *msg_) int zmq::plain_mechanism_t::process_ready (msg_t *msg_)
{ {
const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ()); const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
size_t bytes_left = msg_->size (); size_t bytes_left = msg_->size ();
if (bytes_left < 6 || memcmp (ptr, "READY\0", 6)) { if (bytes_left < 6 || memcmp (ptr, "\x05READY", 6)) {
errno = EPROTO; errno = EPROTO;
return -1; return -1;
} }
......
...@@ -39,8 +39,8 @@ namespace zmq ...@@ -39,8 +39,8 @@ namespace zmq
virtual ~plain_mechanism_t (); virtual ~plain_mechanism_t ();
// mechanism implementation // mechanism implementation
virtual int next_handshake_message (msg_t *msg_); virtual int next_handshake_command (msg_t *msg_);
virtual int process_handshake_message (msg_t *msg_); virtual int process_handshake_command (msg_t *msg_);
virtual int zap_msg_available (); virtual int zap_msg_available ();
virtual bool is_handshake_complete () const; virtual bool is_handshake_complete () const;
...@@ -68,15 +68,15 @@ namespace zmq ...@@ -68,15 +68,15 @@ namespace zmq
state_t state; state_t state;
int hello_command (msg_t *msg_) const; int produce_hello (msg_t *msg_) const;
int welcome_command (msg_t *msg_) const; int produce_welcome (msg_t *msg_) const;
int initiate_command (msg_t *msg_) const; int produce_initiate (msg_t *msg_) const;
int ready_command (msg_t *msg_) const; int produce_ready (msg_t *msg_) const;
int process_hello_command (msg_t *msg_); int process_hello (msg_t *msg_);
int process_welcome_command (msg_t *msg); int process_welcome (msg_t *msg);
int process_ready_command (msg_t *msg_); int process_ready (msg_t *msg_);
int process_initiate_command (msg_t *msg_); int process_initiate (msg_t *msg_);
void send_zap_request (const std::string &username, void send_zap_request (const std::string &username,
const std::string &password); const std::string &password);
......
...@@ -395,7 +395,6 @@ bool zmq::stream_engine_t::handshake () ...@@ -395,7 +395,6 @@ bool zmq::stream_engine_t::handshake ()
{ {
zmq_assert (handshaking); zmq_assert (handshaking);
zmq_assert (greeting_bytes_read < greeting_size); zmq_assert (greeting_bytes_read < greeting_size);
// Receive the greeting. // Receive the greeting.
while (greeting_bytes_read < greeting_size) { while (greeting_bytes_read < greeting_size) {
const int n = read (greeting_recv + greeting_bytes_read, const int n = read (greeting_recv + greeting_bytes_read,
...@@ -492,8 +491,8 @@ bool zmq::stream_engine_t::handshake () ...@@ -492,8 +491,8 @@ bool zmq::stream_engine_t::handshake ()
insize = greeting_bytes_read; insize = greeting_bytes_read;
// To allow for interoperability with peers that do not forward // To allow for interoperability with peers that do not forward
// their subscriptions, we inject a phony subscription // their subscriptions, we inject a phantom subscription message
// message into the incomming message stream. // message into the incoming message stream.
if (options.type == ZMQ_PUB || options.type == ZMQ_XPUB) if (options.type == ZMQ_PUB || options.type == ZMQ_XPUB)
subscription_required = true; subscription_required = true;
} }
...@@ -550,9 +549,8 @@ bool zmq::stream_engine_t::handshake () ...@@ -550,9 +549,8 @@ bool zmq::stream_engine_t::handshake ()
error (); error ();
return false; return false;
} }
read_msg = &stream_engine_t::next_handshake_command;
read_msg = &stream_engine_t::next_handshake_message; write_msg = &stream_engine_t::process_handshake_command;
write_msg = &stream_engine_t::process_handshake_message;
} }
// Start polling for output if necessary. // Start polling for output if necessary.
...@@ -598,12 +596,13 @@ int zmq::stream_engine_t::write_identity (msg_t *msg_) ...@@ -598,12 +596,13 @@ int zmq::stream_engine_t::write_identity (msg_t *msg_)
return 0; return 0;
} }
int zmq::stream_engine_t::next_handshake_message (msg_t *msg_) int zmq::stream_engine_t::next_handshake_command (msg_t *msg_)
{ {
zmq_assert (mechanism != NULL); zmq_assert (mechanism != NULL);
const int rc = mechanism->next_handshake_message (msg_); const int rc = mechanism->next_handshake_command (msg_);
if (rc == 0) { if (rc == 0) {
msg_->set_flags (msg_t::command);
if (mechanism->is_handshake_complete ()) if (mechanism->is_handshake_complete ())
mechanism_ready (); mechanism_ready ();
} }
...@@ -611,11 +610,10 @@ int zmq::stream_engine_t::next_handshake_message (msg_t *msg_) ...@@ -611,11 +610,10 @@ int zmq::stream_engine_t::next_handshake_message (msg_t *msg_)
return rc; return rc;
} }
int zmq::stream_engine_t::process_handshake_message (msg_t *msg_) int zmq::stream_engine_t::process_handshake_command (msg_t *msg_)
{ {
zmq_assert (mechanism != NULL); zmq_assert (mechanism != NULL);
const int rc = mechanism->process_handshake_command (msg_);
const int rc = mechanism->process_handshake_message (msg_);
if (rc == 0) { if (rc == 0) {
if (mechanism->is_handshake_complete ()) if (mechanism->is_handshake_complete ())
mechanism_ready (); mechanism_ready ();
......
...@@ -96,8 +96,8 @@ namespace zmq ...@@ -96,8 +96,8 @@ namespace zmq
int read_identity (msg_t *msg_); int read_identity (msg_t *msg_);
int write_identity (msg_t *msg_); int write_identity (msg_t *msg_);
int next_handshake_message (msg_t *msg); int next_handshake_command (msg_t *msg);
int process_handshake_message (msg_t *msg); int process_handshake_command (msg_t *msg);
int pull_msg_from_session (msg_t *msg_); int pull_msg_from_session (msg_t *msg_);
int push_msg_to_session (msg_t *msg); int push_msg_to_session (msg_t *msg);
...@@ -171,8 +171,8 @@ namespace zmq ...@@ -171,8 +171,8 @@ namespace zmq
bool io_error; bool io_error;
// Indicates whether the engine is to inject a phony // Indicates whether the engine is to inject a phantom
// subscription message into the incomming stream. // subscription message into the incoming stream.
// Needed to support old peers. // Needed to support old peers.
bool subscription_required; bool subscription_required;
......
...@@ -54,6 +54,8 @@ int zmq::v2_decoder_t::flags_ready () ...@@ -54,6 +54,8 @@ int zmq::v2_decoder_t::flags_ready ()
msg_flags = 0; msg_flags = 0;
if (tmpbuf [0] & v2_protocol_t::more_flag) if (tmpbuf [0] & v2_protocol_t::more_flag)
msg_flags |= msg_t::more; msg_flags |= msg_t::more;
if (tmpbuf [0] & v2_protocol_t::command_flag)
msg_flags |= msg_t::command;
// The payload length is either one or eight bytes, // The payload length is either one or eight bytes,
// depending on whether the 'large' bit is set. // depending on whether the 'large' bit is set.
......
...@@ -42,6 +42,8 @@ void zmq::v2_encoder_t::message_ready () ...@@ -42,6 +42,8 @@ void zmq::v2_encoder_t::message_ready ()
protocol_flags |= v2_protocol_t::more_flag; protocol_flags |= v2_protocol_t::more_flag;
if (in_progress->size () > 255) if (in_progress->size () > 255)
protocol_flags |= v2_protocol_t::large_flag; protocol_flags |= v2_protocol_t::large_flag;
if (in_progress->flags () & msg_t::command)
protocol_flags |= v2_protocol_t::command_flag;
// Encode the message length. For messages less then 256 bytes, // Encode the message length. For messages less then 256 bytes,
// the length is encoded as 8-bit unsigned integer. For larger // the length is encoded as 8-bit unsigned integer. For larger
......
...@@ -30,7 +30,8 @@ namespace zmq ...@@ -30,7 +30,8 @@ namespace zmq
enum enum
{ {
more_flag = 1, more_flag = 1,
large_flag = 2 large_flag = 2,
command_flag = 4
}; };
}; };
} }
......
...@@ -85,4 +85,5 @@ test_fork_SOURCES = test_fork.cpp ...@@ -85,4 +85,5 @@ test_fork_SOURCES = test_fork.cpp
endif endif
# Run the test cases # Run the test cases
TESTS = $(noinst_PROGRAMS) # TESTS = $(noinst_PROGRAMS)
TESTS = test_security_plain
\ No newline at end of file
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
#include "platform.hpp"
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
#include "testutil.hpp" #include "testutil.hpp"
#include "../include/zmq_utils.h" #include "../include/zmq_utils.h"
#include "../src/z85_codec.hpp" #include "../src/z85_codec.hpp"
#include "../src/platform.hpp"
// Test keys from the zmq_curve man page // Test keys from the zmq_curve man page
static char client_public [] = "Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID"; static char client_public [] = "Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID";
......
...@@ -38,7 +38,7 @@ typedef struct { ...@@ -38,7 +38,7 @@ typedef struct {
// 8-byte size is set to 1 for backwards compatibility // 8-byte size is set to 1 for backwards compatibility
static zmtp_greeting_t greeting static zmtp_greeting_t greeting
= { { 0xFF, 0, 0, 0, 0, 0, 0, 0, 1, 0x7F }, {3, 0}, { 'N', 'U', 'L', 'L'} }; = { { 0xFF, 0, 0, 0, 0, 0, 0, 0, 1, 0x7F }, { 3, 0 }, { 'N', 'U', 'L', 'L'} };
static void static void
test_stream_to_dealer (void) test_stream_to_dealer (void)
...@@ -106,13 +106,13 @@ test_stream_to_dealer (void) ...@@ -106,13 +106,13 @@ test_stream_to_dealer (void)
assert (buffer [1] == 0); assert (buffer [1] == 0);
// Mechanism is "NULL" // Mechanism is "NULL"
assert (memcmp (buffer + 2, "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 22) == 0); assert (memcmp (buffer + 2, "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0);
assert (memcmp (buffer + 54, "\0\51READY\0", 8) == 0); assert (memcmp (buffer + 54, "\4\51\5READY", 8) == 0);
assert (memcmp (buffer + 62, "\13Socket-Type\0\0\0\6DEALER", 22) == 0); assert (memcmp (buffer + 62, "\13Socket-Type\0\0\0\6DEALER", 22) == 0);
assert (memcmp (buffer + 84, "\10Identity\0\0\0\0", 13) == 0); assert (memcmp (buffer + 84, "\10Identity\0\0\0\0", 13) == 0);
// Announce we are ready // Announce we are ready
memcpy (buffer, "\0\51READY\0", 8); memcpy (buffer, "\4\51\5READY", 8);
memcpy (buffer + 8, "\13Socket-Type\0\0\0\6ROUTER", 22); memcpy (buffer + 8, "\13Socket-Type\0\0\0\6ROUTER", 22);
memcpy (buffer + 30, "\10Identity\0\0\0\0", 13); memcpy (buffer + 30, "\10Identity\0\0\0\0", 13);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment