Commit 89e53131 authored by Simon Giesecke's avatar Simon Giesecke

Refactored zmq::stream_engine_t::handshake, extracted several sub-methods

parent c3739ff6
...@@ -490,18 +490,50 @@ bool zmq::stream_engine_t::handshake () ...@@ -490,18 +490,50 @@ bool zmq::stream_engine_t::handshake ()
zmq_assert (_handshaking); zmq_assert (_handshaking);
zmq_assert (_greeting_bytes_read < _greeting_size); zmq_assert (_greeting_bytes_read < _greeting_size);
// Receive the greeting. // Receive the greeting.
const int rc = receive_greeting ();
if (rc == -1)
return false;
const bool unversioned = rc != 0;
// Position of the revision field in the greeting.
const size_t revision_pos = 10;
if (!(this
->*select_handshake_fun (unversioned,
_greeting_recv[revision_pos])) ())
return false;
// Start polling for output if necessary.
if (_outsize == 0)
set_pollout (_handle);
// Handshaking was successful.
// Switch into the normal message flow.
_handshaking = false;
if (_has_handshake_timer) {
cancel_timer (handshake_timer_id);
_has_handshake_timer = false;
}
return true;
}
int zmq::stream_engine_t::receive_greeting ()
{
bool unversioned = false;
while (_greeting_bytes_read < _greeting_size) { while (_greeting_bytes_read < _greeting_size) {
const int n = tcp_read (_s, _greeting_recv + _greeting_bytes_read, const int n = tcp_read (_s, _greeting_recv + _greeting_bytes_read,
_greeting_size - _greeting_bytes_read); _greeting_size - _greeting_bytes_read);
if (n == 0) { if (n == 0) {
errno = EPIPE; errno = EPIPE;
error (connection_error); error (connection_error);
return false; return -1;
} }
if (n == -1) { if (n == -1) {
if (errno != EAGAIN) if (errno != EAGAIN)
error (connection_error); error (connection_error);
return false; return -1;
} }
_greeting_bytes_read += n; _greeting_bytes_read += n;
...@@ -509,8 +541,10 @@ bool zmq::stream_engine_t::handshake () ...@@ -509,8 +541,10 @@ bool zmq::stream_engine_t::handshake ()
// We have received at least one byte from the peer. // We have received at least one byte from the peer.
// If the first byte is not 0xff, we know that the // If the first byte is not 0xff, we know that the
// peer is using unversioned protocol. // peer is using unversioned protocol.
if (_greeting_recv[0] != 0xff) if (_greeting_recv[0] != 0xff) {
unversioned = true;
break; break;
}
if (_greeting_bytes_read < signature_size) if (_greeting_bytes_read < signature_size)
continue; continue;
...@@ -519,208 +553,231 @@ bool zmq::stream_engine_t::handshake () ...@@ -519,208 +553,231 @@ bool zmq::stream_engine_t::handshake ()
// with the 'flags' field if a regular message was sent). // with the 'flags' field if a regular message was sent).
// Zero indicates this is a header of a routing id message // Zero indicates this is a header of a routing id message
// (i.e. the peer is using the unversioned protocol). // (i.e. the peer is using the unversioned protocol).
if (!(_greeting_recv[9] & 0x01)) if (!(_greeting_recv[9] & 0x01)) {
unversioned = true;
break; break;
}
// The peer is using versioned protocol. // The peer is using versioned protocol.
// Send the major version number. receive_greeting_versioned ();
if (_outpos + _outsize == _greeting_send + signature_size) { }
return unversioned ? 1 : 0;
}
void zmq::stream_engine_t::receive_greeting_versioned ()
{
// Send the major version number.
if (_outpos + _outsize == _greeting_send + signature_size) {
if (_outsize == 0)
set_pollout (_handle);
_outpos[_outsize++] = 3; // Major version number
}
if (_greeting_bytes_read > signature_size) {
if (_outpos + _outsize == _greeting_send + signature_size + 1) {
if (_outsize == 0) if (_outsize == 0)
set_pollout (_handle); set_pollout (_handle);
_outpos[_outsize++] = 3; // Major version number
}
if (_greeting_bytes_read > signature_size) { // Use ZMTP/2.0 to talk to older peers.
if (_outpos + _outsize == _greeting_send + signature_size + 1) { if (_greeting_recv[10] == ZMTP_1_0
if (_outsize == 0) || _greeting_recv[10] == ZMTP_2_0)
set_pollout (_handle); _outpos[_outsize++] = _options.type;
else {
// Use ZMTP/2.0 to talk to older peers. _outpos[_outsize++] = 0; // Minor version number
if (_greeting_recv[10] == ZMTP_1_0 memset (_outpos + _outsize, 0, 20);
|| _greeting_recv[10] == ZMTP_2_0)
_outpos[_outsize++] = _options.type; zmq_assert (_options.mechanism == ZMQ_NULL
else { || _options.mechanism == ZMQ_PLAIN
_outpos[_outsize++] = 0; // Minor version number || _options.mechanism == ZMQ_CURVE
memset (_outpos + _outsize, 0, 20); || _options.mechanism == ZMQ_GSSAPI);
zmq_assert (_options.mechanism == ZMQ_NULL if (_options.mechanism == ZMQ_NULL)
|| _options.mechanism == ZMQ_PLAIN memcpy (_outpos + _outsize, "NULL", 4);
|| _options.mechanism == ZMQ_CURVE else if (_options.mechanism == ZMQ_PLAIN)
|| _options.mechanism == ZMQ_GSSAPI); memcpy (_outpos + _outsize, "PLAIN", 5);
else if (_options.mechanism == ZMQ_GSSAPI)
if (_options.mechanism == ZMQ_NULL) memcpy (_outpos + _outsize, "GSSAPI", 6);
memcpy (_outpos + _outsize, "NULL", 4); else if (_options.mechanism == ZMQ_CURVE)
else if (_options.mechanism == ZMQ_PLAIN) memcpy (_outpos + _outsize, "CURVE", 5);
memcpy (_outpos + _outsize, "PLAIN", 5); _outsize += 20;
else if (_options.mechanism == ZMQ_GSSAPI) memset (_outpos + _outsize, 0, 32);
memcpy (_outpos + _outsize, "GSSAPI", 6); _outsize += 32;
else if (_options.mechanism == ZMQ_CURVE) _greeting_size = v3_greeting_size;
memcpy (_outpos + _outsize, "CURVE", 5);
_outsize += 20;
memset (_outpos + _outsize, 0, 32);
_outsize += 32;
_greeting_size = v3_greeting_size;
}
} }
} }
} }
}
// Position of the revision field in the greeting. zmq::stream_engine_t::handshake_fun_t
const size_t revision_pos = 10; zmq::stream_engine_t::select_handshake_fun (bool unversioned,
unsigned char revision)
{
// Is the peer using ZMTP/1.0 with no revision number? // Is the peer using ZMTP/1.0 with no revision number?
// If so, we send and receive rest of routing id message if (unversioned) {
if (_greeting_recv[0] != 0xff || !(_greeting_recv[9] & 0x01)) { return &stream_engine_t::handshake_v1_0_unversioned;
if (_session->zap_enabled ()) { }
// reject ZMTP 1.0 connections if ZAP is enabled switch (revision) {
error (protocol_error); case ZMTP_1_0:
return false; return &stream_engine_t::handshake_v1_0;
} case ZMTP_2_0:
return &stream_engine_t::handshake_v2_0;
default:
return &stream_engine_t::handshake_v3_0;
}
}
_encoder = new (std::nothrow) v1_encoder_t (out_batch_size); bool zmq::stream_engine_t::handshake_v1_0_unversioned ()
alloc_assert (_encoder); {
// We send and receive rest of routing id message
if (_session->zap_enabled ()) {
// reject ZMTP 1.0 connections if ZAP is enabled
error (protocol_error);
return false;
}
_decoder = _encoder = new (std::nothrow) v1_encoder_t (out_batch_size);
new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); alloc_assert (_encoder);
alloc_assert (_decoder);
_decoder =
new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize);
alloc_assert (_decoder);
// We have already sent the message header.
// Since there is no way to tell the encoder to
// skip the message header, we simply throw that
// header data away.
const size_t header_size =
_options.routing_id_size + 1 >= UCHAR_MAX ? 10 : 2;
unsigned char tmp[10], *bufferp = tmp;
// Prepare the routing id message and load it into encoder.
// Then consume bytes we have already sent to the peer.
const int rc = _tx_msg.init_size (_options.routing_id_size);
zmq_assert (rc == 0);
memcpy (_tx_msg.data (), _options.routing_id, _options.routing_id_size);
_encoder->load_msg (&_tx_msg);
const size_t buffer_size = _encoder->encode (&bufferp, header_size);
zmq_assert (buffer_size == header_size);
// Make sure the decoder sees the data we have already received.
_inpos = _greeting_recv;
_insize = _greeting_bytes_read;
// To allow for interoperability with peers that do not forward
// their subscriptions, we inject a phantom subscription message
// message into the incoming message stream.
if (_options.type == ZMQ_PUB || _options.type == ZMQ_XPUB)
_subscription_required = true;
// We are sending our routing id now and the next message
// will come from the socket.
_next_msg = &stream_engine_t::pull_msg_from_session;
// We have already sent the message header. // We are expecting routing id message.
// Since there is no way to tell the encoder to _process_msg = &stream_engine_t::process_routing_id_msg;
// skip the message header, we simply throw that
// header data away.
const size_t header_size =
_options.routing_id_size + 1 >= UCHAR_MAX ? 10 : 2;
unsigned char tmp[10], *bufferp = tmp;
// Prepare the routing id message and load it into encoder.
// Then consume bytes we have already sent to the peer.
const int rc = _tx_msg.init_size (_options.routing_id_size);
zmq_assert (rc == 0);
memcpy (_tx_msg.data (), _options.routing_id, _options.routing_id_size);
_encoder->load_msg (&_tx_msg);
size_t buffer_size = _encoder->encode (&bufferp, header_size);
zmq_assert (buffer_size == header_size);
// Make sure the decoder sees the data we have already received.
_inpos = _greeting_recv;
_insize = _greeting_bytes_read;
// To allow for interoperability with peers that do not forward
// their subscriptions, we inject a phantom subscription message
// message into the incoming message stream.
if (_options.type == ZMQ_PUB || _options.type == ZMQ_XPUB)
_subscription_required = true;
// We are sending our routing id now and the next message
// will come from the socket.
_next_msg = &stream_engine_t::pull_msg_from_session;
// We are expecting routing id message. return true;
_process_msg = &stream_engine_t::process_routing_id_msg; }
} else if (_greeting_recv[revision_pos] == ZMTP_1_0) {
if (_session->zap_enabled ()) {
// reject ZMTP 1.0 connections if ZAP is enabled
error (protocol_error);
return false;
}
_encoder = new (std::nothrow) v1_encoder_t (out_batch_size); bool zmq::stream_engine_t::handshake_v1_0 ()
alloc_assert (_encoder); {
if (_session->zap_enabled ()) {
// reject ZMTP 1.0 connections if ZAP is enabled
error (protocol_error);
return false;
}
_decoder = _encoder = new (std::nothrow) v1_encoder_t (out_batch_size);
new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize); alloc_assert (_encoder);
alloc_assert (_decoder);
} else if (_greeting_recv[revision_pos] == ZMTP_2_0) {
if (_session->zap_enabled ()) {
// reject ZMTP 2.0 connections if ZAP is enabled
error (protocol_error);
return false;
}
_encoder = new (std::nothrow) v2_encoder_t (out_batch_size); _decoder =
alloc_assert (_encoder); new (std::nothrow) v1_decoder_t (in_batch_size, _options.maxmsgsize);
alloc_assert (_decoder);
_decoder = new (std::nothrow) return true;
v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); }
alloc_assert (_decoder);
} else {
_encoder = new (std::nothrow) v2_encoder_t (out_batch_size);
alloc_assert (_encoder);
_decoder = new (std::nothrow) bool zmq::stream_engine_t::handshake_v2_0 ()
v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy); {
alloc_assert (_decoder); if (_session->zap_enabled ()) {
// reject ZMTP 2.0 connections if ZAP is enabled
error (protocol_error);
return false;
}
if (_options.mechanism == ZMQ_NULL _encoder = new (std::nothrow) v2_encoder_t (out_batch_size);
&& memcmp (_greeting_recv + 12, alloc_assert (_encoder);
"NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20)
== 0) { _decoder = new (std::nothrow)
v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy);
alloc_assert (_decoder);
return true;
}
bool zmq::stream_engine_t::handshake_v3_0 ()
{
_encoder = new (std::nothrow) v2_encoder_t (out_batch_size);
alloc_assert (_encoder);
_decoder = new (std::nothrow)
v2_decoder_t (in_batch_size, _options.maxmsgsize, _options.zero_copy);
alloc_assert (_decoder);
if (_options.mechanism == ZMQ_NULL
&& memcmp (_greeting_recv + 12, "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0",
20)
== 0) {
_mechanism = new (std::nothrow)
null_mechanism_t (_session, _peer_address, _options);
alloc_assert (_mechanism);
} else if (_options.mechanism == ZMQ_PLAIN
&& memcmp (_greeting_recv + 12,
"PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20)
== 0) {
if (_options.as_server)
_mechanism = new (std::nothrow) _mechanism = new (std::nothrow)
null_mechanism_t (_session, _peer_address, _options); plain_server_t (_session, _peer_address, _options);
alloc_assert (_mechanism); else
} else if (_options.mechanism == ZMQ_PLAIN _mechanism = new (std::nothrow) plain_client_t (_session, _options);
&& memcmp (_greeting_recv + 12, alloc_assert (_mechanism);
"PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) }
== 0) {
if (_options.as_server)
_mechanism = new (std::nothrow)
plain_server_t (_session, _peer_address, _options);
else
_mechanism =
new (std::nothrow) plain_client_t (_session, _options);
alloc_assert (_mechanism);
}
#ifdef ZMQ_HAVE_CURVE #ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE else if (_options.mechanism == ZMQ_CURVE
&& memcmp (_greeting_recv + 12, && memcmp (_greeting_recv + 12,
"CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20)
== 0) { == 0) {
if (_options.as_server) if (_options.as_server)
_mechanism = new (std::nothrow) _mechanism = new (std::nothrow)
curve_server_t (_session, _peer_address, _options); curve_server_t (_session, _peer_address, _options);
else else
_mechanism = _mechanism = new (std::nothrow) curve_client_t (_session, _options);
new (std::nothrow) curve_client_t (_session, _options); alloc_assert (_mechanism);
alloc_assert (_mechanism); }
}
#endif #endif
#ifdef HAVE_LIBGSSAPI_KRB5 #ifdef HAVE_LIBGSSAPI_KRB5
else if (_options.mechanism == ZMQ_GSSAPI else if (_options.mechanism == ZMQ_GSSAPI
&& memcmp (_greeting_recv + 12, && memcmp (_greeting_recv + 12,
"GSSAPI\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) "GSSAPI\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20)
== 0) { == 0) {
if (_options.as_server) if (_options.as_server)
_mechanism = new (std::nothrow) _mechanism = new (std::nothrow)
gssapi_server_t (_session, _peer_address, _options); gssapi_server_t (_session, _peer_address, _options);
else else
_mechanism = _mechanism =
new (std::nothrow) gssapi_client_t (_session, _options); new (std::nothrow) gssapi_client_t (_session, _options);
alloc_assert (_mechanism); alloc_assert (_mechanism);
}
#endif
else {
_session->get_socket ()->event_handshake_failed_protocol (
_session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH);
error (protocol_error);
return false;
}
_next_msg = &stream_engine_t::next_handshake_command;
_process_msg = &stream_engine_t::process_handshake_command;
} }
#endif
// Start polling for output if necessary. else {
if (_outsize == 0) _session->get_socket ()->event_handshake_failed_protocol (
set_pollout (_handle); _session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH);
// Handshaking was successful. error (protocol_error);
// Switch into the normal message flow. return false;
_handshaking = false;
if (_has_handshake_timer) {
cancel_timer (handshake_timer_id);
_has_handshake_timer = false;
} }
_next_msg = &stream_engine_t::next_handshake_command;
_process_msg = &stream_engine_t::process_handshake_command;
return true; return true;
} }
......
...@@ -93,12 +93,22 @@ class stream_engine_t : public io_object_t, public i_engine ...@@ -93,12 +93,22 @@ class stream_engine_t : public io_object_t, public i_engine
// Function to handle network disconnections. // Function to handle network disconnections.
void error (error_reason_t reason_); void error (error_reason_t reason_);
// Receives the greeting message from the peer.
int receive_greeting ();
// Detects the protocol used by the peer. // Detects the protocol used by the peer.
bool handshake (); bool handshake ();
// Receive the greeting from the peer.
int receive_greeting ();
void receive_greeting_versioned ();
typedef bool (stream_engine_t::*handshake_fun_t) ();
static handshake_fun_t select_handshake_fun (bool unversioned,
unsigned char revision);
bool handshake_v1_0_unversioned ();
bool handshake_v1_0 ();
bool handshake_v2_0 ();
bool handshake_v3_0 ();
int routing_id_msg (msg_t *msg_); int routing_id_msg (msg_t *msg_);
int process_routing_id_msg (msg_t *msg_); int process_routing_id_msg (msg_t *msg_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment