Unverified Commit 2dfdcaff authored by Luca Boccassi's avatar Luca Boccassi Committed by GitHub

Merge pull request #3141 from sigiesec/analyze

More code style improvements
parents bbc90388 47dcd84f
......@@ -740,6 +740,7 @@ set (cxx-sources
pgm_socket.hpp
pipe.hpp
plain_client.hpp
plain_common.hpp
plain_server.hpp
poll.hpp
poller.hpp
......
......@@ -134,6 +134,7 @@ src_libzmq_la_SOURCES = \
src/pipe.hpp \
src/plain_client.cpp \
src/plain_client.hpp \
src/plain_common.hpp \
src/plain_server.cpp \
src/plain_server.hpp \
src/platform.hpp \
......
......@@ -70,8 +70,6 @@ zmq::ctx_t::ctx_t () :
_starting (true),
_terminating (false),
_reaper (NULL),
_slot_count (0),
_slots (NULL),
_max_sockets (clipped_maxsocket (ZMQ_MAX_SOCKETS_DFLT)),
_max_msgsz (INT_MAX),
_io_thread_count (ZMQ_IO_THREADS_DFLT),
......@@ -115,10 +113,8 @@ zmq::ctx_t::~ctx_t ()
// Deallocate the reaper thread object.
LIBZMQ_DELETE (_reaper);
// Deallocate the array of mailboxes. No special work is
// needed as mailboxes themselves were deallocated with their
// The mailboxes in _slots themselves were deallocated with their
// corresponding io_thread/socket objects.
free (_slots);
// De-initialise crypto library, if needed.
zmq::random_close ();
......@@ -283,19 +279,23 @@ int zmq::ctx_t::get (int option_)
bool zmq::ctx_t::start ()
{
// Initialise the array of mailboxes. Additional three slots are for
// Initialise the array of mailboxes. Additional two slots are for
// zmq_ctx_term thread and reaper thread.
_opt_sync.lock ();
int mazmq = _max_sockets;
int ios = _io_thread_count;
const int term_and_reaper_threads_count = 2;
const int mazmq = _max_sockets;
const int ios = _io_thread_count;
_opt_sync.unlock ();
_slot_count = mazmq + ios + 2;
_slots =
static_cast<i_mailbox **> (malloc (sizeof (i_mailbox *) * _slot_count));
if (!_slots) {
int slot_count = mazmq + ios + term_and_reaper_threads_count;
try {
_slots.reserve (slot_count);
_empty_slots.reserve (slot_count - term_and_reaper_threads_count);
}
catch (const std::bad_alloc &) {
errno = ENOMEM;
goto fail;
return false;
}
_slots.resize (term_and_reaper_threads_count);
// Initialise the infrastructure for zmq_ctx_term thread.
_slots[term_tid] = &_term_mailbox;
......@@ -312,12 +312,10 @@ bool zmq::ctx_t::start ()
_reaper->start ();
// Create I/O thread objects and launch them.
for (int32_t i = static_cast<int32_t> (_slot_count) - 1;
i >= static_cast<int32_t> (2); i--) {
_slots[i] = NULL;
}
_slots.resize (slot_count, NULL);
for (int i = 2; i != ios + 2; i++) {
for (int i = term_and_reaper_threads_count;
i != ios + term_and_reaper_threads_count; i++) {
io_thread_t *io_thread = new (std::nothrow) io_thread_t (this, i);
if (!io_thread) {
errno = ENOMEM;
......@@ -333,8 +331,8 @@ bool zmq::ctx_t::start ()
}
// In the unused part of the slot array, create a list of empty slots.
for (int32_t i = static_cast<int32_t> (_slot_count) - 1;
i >= static_cast<int32_t> (ios) + 2; i--) {
for (int32_t i = static_cast<int32_t> (_slots.size ()) - 1;
i >= static_cast<int32_t> (ios) + term_and_reaper_threads_count; i--) {
_empty_slots.push_back (i);
}
......@@ -347,10 +345,7 @@ fail_cleanup_reaper:
_reaper = NULL;
fail_cleanup_slots:
free (_slots);
_slots = NULL;
fail:
_slots.clear ();
return false;
}
......
......@@ -198,8 +198,7 @@ class ctx_t : public thread_ctx_t
io_threads_t _io_threads;
// Array of pointers to mailboxes for both application and I/O threads.
uint32_t _slot_count;
i_mailbox **_slots;
std::vector<i_mailbox *> _slots;
// Mailbox for zmq_ctx_term thread.
mailbox_t _term_mailbox;
......
......@@ -218,10 +218,13 @@ int zmq::ip_resolver_t::resolve (ip_addr_t *ip_addr_, const char *name_)
// Trim any square brackets surrounding the address. Used for
// IPv6 addresses to remove the confusion with the port
// delimiter. Should we validate that the brackets are present if
// delimiter.
// TODO Should we validate that the brackets are present if
// 'addr' contains ':' ?
if (addr.size () >= 2 && addr[0] == '[' && addr[addr.size () - 1] == ']') {
addr = addr.substr (1, addr.size () - 2);
const size_t brackets_length = 2;
if (addr.size () >= brackets_length && addr[0] == '['
&& addr[addr.size () - 1] == ']') {
addr = addr.substr (1, addr.size () - brackets_length);
}
// Look for an interface name / zone_id in the address
......
......@@ -29,6 +29,7 @@
#include "precompiled.hpp"
#include <string.h>
#include <limits.h>
#include "mechanism.hpp"
#include "options.hpp"
......@@ -64,7 +65,7 @@ void zmq::mechanism_t::set_user_id (const void *data_, size_t size_)
_user_id.set (static_cast<const unsigned char *> (data_), size_);
zap_properties.ZMQ_MAP_INSERT_OR_EMPLACE (
std::string (ZMQ_MSG_PROPERTY_USER_ID),
std::string ((char *) data_, size_));
std::string (reinterpret_cast<const char *> (data_), size_));
}
const zmq::blob_t &zmq::mechanism_t::get_user_id () const
......@@ -113,15 +114,18 @@ const char *zmq::mechanism_t::socket_type_string (int socket_type_) const
return names[socket_type_];
}
const size_t name_len_size = sizeof (unsigned char);
const size_t value_len_size = sizeof (uint32_t);
static size_t property_len (size_t name_len_, size_t value_len_)
{
return 1 + name_len_ + 4 + value_len_;
return name_len_size + name_len_ + value_len_size + value_len_;
}
static size_t name_len (const char *name_)
{
const size_t name_len = strlen (name_);
zmq_assert (name_len <= 255);
zmq_assert (name_len <= UCHAR_MAX);
return name_len;
}
......@@ -135,12 +139,13 @@ size_t zmq::mechanism_t::add_property (unsigned char *ptr_,
const size_t total_len = ::property_len (name_len, value_len_);
zmq_assert (total_len <= ptr_capacity_);
*ptr_++ = static_cast<unsigned char> (name_len);
*ptr_ = static_cast<unsigned char> (name_len);
ptr_ += name_len_size;
memcpy (ptr_, name_, name_len);
ptr_ += name_len;
zmq_assert (value_len_ <= 0x7FFFFFFF);
put_uint32 (ptr_, static_cast<uint32_t> (value_len_));
ptr_ += 4;
ptr_ += value_len_size;
memcpy (ptr_, value_, value_len_);
return total_len;
......@@ -228,20 +233,21 @@ int zmq::mechanism_t::parse_metadata (const unsigned char *ptr_,
while (bytes_left > 1) {
const size_t name_length = static_cast<size_t> (*ptr_);
ptr_ += 1;
bytes_left -= 1;
ptr_ += name_len_size;
bytes_left -= name_len_size;
if (bytes_left < name_length)
break;
const std::string name = std::string ((char *) ptr_, name_length);
const std::string name =
std::string (reinterpret_cast<const char *> (ptr_), name_length);
ptr_ += name_length;
bytes_left -= name_length;
if (bytes_left < 4)
if (bytes_left < value_len_size)
break;
const size_t value_length = static_cast<size_t> (get_uint32 (ptr_));
ptr_ += 4;
bytes_left -= 4;
ptr_ += value_len_size;
bytes_left -= value_len_size;
if (bytes_left < value_length)
break;
......@@ -264,7 +270,8 @@ int zmq::mechanism_t::parse_metadata (const unsigned char *ptr_,
}
(zap_flag_ ? zap_properties : zmtp_properties)
.ZMQ_MAP_INSERT_OR_EMPLACE (
name, std::string ((char *) value, value_length));
name,
std::string (reinterpret_cast<const char *> (value), value_length));
}
if (bytes_left > 0) {
errno = EPROTO;
......
......@@ -55,12 +55,24 @@ int zmq::mechanism_base_t::check_basic_command_structure (msg_t *msg_)
void zmq::mechanism_base_t::handle_error_reason (const char *error_reason_,
size_t error_reason_len_)
{
if (error_reason_len_ == 3 && error_reason_[1] == '0'
&& error_reason_[2] == '0' && error_reason_[0] >= '3'
&& error_reason_[0] <= '5') {
// it is a ZAP status code, so emit an authentication failure event
const size_t status_code_len = 3;
const char zero_digit = '0';
const size_t significant_digit_index = 0;
const size_t first_zero_digit_index = 1;
const size_t second_zero_digit_index = 2;
const int factor = 100;
if (error_reason_len_ == status_code_len
&& error_reason_[first_zero_digit_index] == zero_digit
&& error_reason_[second_zero_digit_index] == zero_digit
&& error_reason_[significant_digit_index] >= '3'
&& error_reason_[significant_digit_index] <= '5') {
// it is a ZAP error status code (300, 400 or 500), so emit an authentication failure event
session->get_socket ()->event_handshake_failed_auth (
session->get_endpoint (), (error_reason_[0] - '0') * 100);
session->get_endpoint (),
(error_reason_[significant_digit_index] - zero_digit) * factor);
} else {
// this is a violation of the ZAP protocol
// TODO zmq_assert in this case?
}
}
......
......@@ -299,27 +299,21 @@ int zmq::msg_t::copy (msg_t &src_)
if (unlikely (rc < 0))
return rc;
if (src_._u.base.type == type_lmsg) {
// One reference is added to shared messages. Non-shared messages
// are turned into shared messages and reference count is set to 2.
if (src_._u.lmsg.flags & msg_t::shared)
src_._u.lmsg.content->refcnt.add (1);
else {
src_._u.lmsg.flags |= msg_t::shared;
src_._u.lmsg.content->refcnt.set (2);
}
}
// The initial reference count, when a non-shared message is initially
// shared (between the original and the copy we create here).
const atomic_counter_t::integer_t initial_shared_refcnt = 2;
if (src_.is_zcmsg ()) {
if (src_.is_lmsg () || src_.is_zcmsg ()) {
// One reference is added to shared messages. Non-shared messages
// are turned into shared messages and reference count is set to 2.
if (src_._u.zclmsg.flags & msg_t::shared)
// are turned into shared messages.
if (src_.flags () & msg_t::shared)
src_.refcnt ()->add (1);
else {
src_._u.zclmsg.flags |= msg_t::shared;
src_.refcnt ()->set (2);
src_.set_flags (msg_t::shared);
src_.refcnt ()->set (initial_shared_refcnt);
}
}
if (src_._u.base.metadata != NULL)
src_._u.base.metadata->add_ref ();
......@@ -431,6 +425,11 @@ bool zmq::msg_t::is_cmsg () const
return _u.base.type == type_cmsg;
}
bool zmq::msg_t::is_lmsg () const
{
return _u.base.type == type_lmsg;
}
bool zmq::msg_t::is_zcmsg () const
{
return _u.base.type == type_zclmsg;
......
......@@ -117,6 +117,7 @@ class msg_t
bool is_leave () const;
bool is_vsm () const;
bool is_cmsg () const;
bool is_lmsg () const;
bool is_zcmsg () const;
uint32_t get_routing_id ();
int set_routing_id (uint32_t routing_id_);
......
......@@ -38,6 +38,13 @@
#include "session_base.hpp"
#include "null_mechanism.hpp"
const char error_command_name[] = "\5ERROR";
const size_t error_command_name_len = sizeof (error_command_name) - 1;
const size_t error_reason_len_size = 1;
const char ready_command_name[] = "\5READY";
const size_t ready_command_name_len = sizeof (ready_command_name) - 1;
zmq::null_mechanism_t::null_mechanism_t (session_base_t *session_,
const std::string &peer_address_,
const options_t &options_) :
......@@ -96,20 +103,24 @@ int zmq::null_mechanism_t::next_handshake_command (msg_t *msg_)
_error_command_sent = true;
if (status_code != "300") {
const size_t status_code_len = 3;
const int rc = msg_->init_size (6 + 1 + status_code_len);
const int rc = msg_->init_size (
error_command_name_len + error_reason_len_size + status_code_len);
zmq_assert (rc == 0);
unsigned char *msg_data =
static_cast<unsigned char *> (msg_->data ());
memcpy (msg_data, "\5ERROR", 6);
msg_data[6] = status_code_len;
memcpy (msg_data + 7, status_code.c_str (), status_code_len);
memcpy (msg_data, error_command_name, error_command_name_len);
msg_data += error_command_name_len;
*msg_data = status_code_len;
msg_data += error_reason_len_size;
memcpy (msg_data, status_code.c_str (), status_code_len);
return 0;
}
errno = EAGAIN;
return -1;
}
make_command_with_basic_properties (msg_, "\5READY", 6);
make_command_with_basic_properties (msg_, ready_command_name,
ready_command_name_len);
_ready_command_sent = true;
......@@ -130,9 +141,11 @@ int zmq::null_mechanism_t::process_handshake_command (msg_t *msg_)
const size_t data_size = msg_->size ();
int rc = 0;
if (data_size >= 6 && !memcmp (cmd_data, "\5READY", 6))
if (data_size >= ready_command_name_len
&& !memcmp (cmd_data, ready_command_name, ready_command_name_len))
rc = process_ready_command (cmd_data, data_size);
else if (data_size >= 6 && !memcmp (cmd_data, "\5ERROR", 6))
else if (data_size >= error_command_name_len
&& !memcmp (cmd_data, error_command_name, error_command_name_len))
rc = process_error_command (cmd_data, data_size);
else {
session->get_socket ()->event_handshake_failed_protocol (
......@@ -154,13 +167,16 @@ int zmq::null_mechanism_t::process_ready_command (
const unsigned char *cmd_data_, size_t data_size_)
{
_ready_command_received = true;
return parse_metadata (cmd_data_ + 6, data_size_ - 6);
return parse_metadata (cmd_data_ + ready_command_name_len,
data_size_ - ready_command_name_len);
}
int zmq::null_mechanism_t::process_error_command (
const unsigned char *cmd_data_, size_t data_size_)
{
if (data_size_ < 7) {
const size_t fixed_prefix_size =
error_command_name_len + error_reason_len_size;
if (data_size_ < fixed_prefix_size) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
......@@ -168,8 +184,9 @@ int zmq::null_mechanism_t::process_error_command (
errno = EPROTO;
return -1;
}
const size_t error_reason_len = static_cast<size_t> (cmd_data_[6]);
if (error_reason_len > data_size_ - 7) {
const size_t error_reason_len =
static_cast<size_t> (cmd_data_[error_command_name_len]);
if (error_reason_len > data_size_ - fixed_prefix_size) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
......@@ -177,7 +194,8 @@ int zmq::null_mechanism_t::process_error_command (
errno = EPROTO;
return -1;
}
const char *error_reason = reinterpret_cast<const char *> (cmd_data_) + 7;
const char *error_reason =
reinterpret_cast<const char *> (cmd_data_) + fixed_prefix_size;
handle_error_reason (error_reason, error_reason_len);
_error_command_received = true;
return 0;
......
......@@ -29,6 +29,7 @@
#include "precompiled.hpp"
#include <string.h>
#include <limits.h>
#include <set>
#include "options.hpp"
......@@ -262,7 +263,8 @@ int zmq::options_t::set_curve_key (uint8_t *destination_,
return 0;
case CURVE_KEYSIZE_Z85 + 1:
if (zmq_z85_decode (destination_, (char *) optval_)) {
if (zmq_z85_decode (destination_,
reinterpret_cast<const char *> (optval_))) {
mechanism = ZMQ_CURVE;
return 0;
}
......@@ -270,7 +272,8 @@ int zmq::options_t::set_curve_key (uint8_t *destination_,
case CURVE_KEYSIZE_Z85:
char z85_key[CURVE_KEYSIZE_Z85 + 1];
memcpy (z85_key, (char *) optval_, optvallen_);
memcpy (z85_key, reinterpret_cast<const char *> (optval_),
optvallen_);
z85_key[CURVE_KEYSIZE_Z85] = 0;
if (zmq_z85_decode (destination_, z85_key)) {
mechanism = ZMQ_CURVE;
......@@ -491,7 +494,7 @@ int zmq::options_t::setsockopt (int option_,
case ZMQ_TCP_ACCEPT_FILTER: {
std::string filter_str;
int rc = do_setsockopt_string_allow_empty_strict (
optval_, optvallen_, &filter_str, 255);
optval_, optvallen_, &filter_str, UCHAR_MAX);
if (rc == 0) {
if (filter_str.empty ()) {
tcp_accept_filters.clear ();
......@@ -559,7 +562,7 @@ int zmq::options_t::setsockopt (int option_,
case ZMQ_ZAP_DOMAIN:
return do_setsockopt_string_allow_empty_relaxed (
optval_, optvallen_, &zap_domain, 255);
optval_, optvallen_, &zap_domain, UCHAR_MAX);
break;
// If curve encryption isn't built, these options provoke EINVAL
......@@ -718,15 +721,14 @@ int zmq::options_t::setsockopt (int option_,
case ZMQ_METADATA:
if (optvallen_ > 0 && !is_int) {
std::string s ((char *) optval_);
size_t pos = 0;
std::string key, val, delimiter = ":";
pos = s.find (delimiter);
const std::string s (reinterpret_cast<const char *> (optval_));
const size_t pos = s.find (":");
if (pos != std::string::npos && pos != 0
&& pos != s.length () - 1) {
key = s.substr (0, pos);
if (key.compare (0, 2, "X-") == 0 && key.length () < 256) {
val = s.substr (pos + 1, s.length ());
std::string key = s.substr (0, pos);
if (key.compare (0, 2, "X-") == 0
&& key.length () <= UCHAR_MAX) {
std::string val = s.substr (pos + 1, s.length ());
app_metadata.insert (
std::pair<std::string, std::string> (key, val));
return 0;
......@@ -735,7 +737,6 @@ int zmq::options_t::setsockopt (int option_,
}
errno = EINVAL;
return -1;
break;
case ZMQ_MULTICAST_LOOP:
return do_setsockopt_int_as_bool_relaxed (optval_, optvallen_,
......
......@@ -31,11 +31,13 @@
#include "macros.hpp"
#include <string>
#include <limits.h>
#include "msg.hpp"
#include "err.hpp"
#include "plain_client.hpp"
#include "session_base.hpp"
#include "plain_common.hpp"
zmq::plain_client_t::plain_client_t (session_base_t *const session_,
const options_t &options_) :
......@@ -54,13 +56,11 @@ int zmq::plain_client_t::next_handshake_command (msg_t *msg_)
switch (_state) {
case sending_hello:
rc = produce_hello (msg_);
if (rc == 0)
produce_hello (msg_);
_state = waiting_for_welcome;
break;
case sending_initiate:
rc = produce_initiate (msg_);
if (rc == 0)
produce_initiate (msg_);
_state = waiting_for_ready;
break;
default:
......@@ -77,11 +77,14 @@ int zmq::plain_client_t::process_handshake_command (msg_t *msg_)
const size_t data_size = msg_->size ();
int rc = 0;
if (data_size >= 8 && !memcmp (cmd_data, "\7WELCOME", 8))
if (data_size >= welcome_prefix_len
&& !memcmp (cmd_data, welcome_prefix, welcome_prefix_len))
rc = process_welcome (cmd_data, data_size);
else if (data_size >= 6 && !memcmp (cmd_data, "\5READY", 6))
else if (data_size >= ready_prefix_len
&& !memcmp (cmd_data, ready_prefix, ready_prefix_len))
rc = process_ready (cmd_data, data_size);
else if (data_size >= 6 && !memcmp (cmd_data, "\5ERROR", 6))
else if (data_size >= error_prefix_len
&& !memcmp (cmd_data, error_prefix, error_prefix_len))
rc = process_error (cmd_data, data_size);
else {
session->get_socket ()->event_handshake_failed_protocol (
......@@ -110,23 +113,24 @@ zmq::mechanism_t::status_t zmq::plain_client_t::status () const
return mechanism_t::handshaking;
}
int zmq::plain_client_t::produce_hello (msg_t *msg_) const
void zmq::plain_client_t::produce_hello (msg_t *msg_) const
{
const std::string username = options.plain_username;
zmq_assert (username.length () < 256);
zmq_assert (username.length () <= UCHAR_MAX);
const std::string password = options.plain_password;
zmq_assert (password.length () < 256);
zmq_assert (password.length () <= UCHAR_MAX);
const size_t command_size =
6 + 1 + username.length () + 1 + password.length ();
const size_t command_size = hello_prefix_len + brief_len_size
+ username.length () + brief_len_size
+ password.length ();
const int rc = msg_->init_size (command_size);
errno_assert (rc == 0);
unsigned char *ptr = static_cast<unsigned char *> (msg_->data ());
memcpy (ptr, "\x05HELLO", 6);
ptr += 6;
memcpy (ptr, hello_prefix, hello_prefix_len);
ptr += hello_prefix_len;
*ptr++ = static_cast<unsigned char> (username.length ());
memcpy (ptr, username.c_str (), username.length ());
......@@ -134,8 +138,6 @@ int zmq::plain_client_t::produce_hello (msg_t *msg_) const
*ptr++ = static_cast<unsigned char> (password.length ());
memcpy (ptr, password.c_str (), password.length ());
return 0;
}
int zmq::plain_client_t::process_welcome (const unsigned char *cmd_data_,
......@@ -149,7 +151,7 @@ int zmq::plain_client_t::process_welcome (const unsigned char *cmd_data_,
errno = EPROTO;
return -1;
}
if (data_size_ != 8) {
if (data_size_ != welcome_prefix_len) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME);
......@@ -160,11 +162,10 @@ int zmq::plain_client_t::process_welcome (const unsigned char *cmd_data_,
return 0;
}
int zmq::plain_client_t::produce_initiate (msg_t *msg_) const
void zmq::plain_client_t::produce_initiate (msg_t *msg_) const
{
make_command_with_basic_properties (msg_, "\x08INITIATE", 9);
return 0;
make_command_with_basic_properties (msg_, initiate_prefix,
initiate_prefix_len);
}
int zmq::plain_client_t::process_ready (const unsigned char *cmd_data_,
......@@ -176,7 +177,8 @@ int zmq::plain_client_t::process_ready (const unsigned char *cmd_data_,
errno = EPROTO;
return -1;
}
const int rc = parse_metadata (cmd_data_ + 6, data_size_ - 6);
const int rc = parse_metadata (cmd_data_ + ready_prefix_len,
data_size_ - ready_prefix_len);
if (rc == 0)
_state = ready;
else
......@@ -195,22 +197,25 @@ int zmq::plain_client_t::process_error (const unsigned char *cmd_data_,
errno = EPROTO;
return -1;
}
if (data_size_ < 7) {
const size_t start_of_error_reason = error_prefix_len + brief_len_size;
if (data_size_ < start_of_error_reason) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
errno = EPROTO;
return -1;
}
const size_t error_reason_len = static_cast<size_t> (cmd_data_[6]);
if (error_reason_len > data_size_ - 7) {
const size_t error_reason_len =
static_cast<size_t> (cmd_data_[error_prefix_len]);
if (error_reason_len > data_size_ - start_of_error_reason) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
errno = EPROTO;
return -1;
}
const char *error_reason = reinterpret_cast<const char *> (cmd_data_) + 7;
const char *error_reason =
reinterpret_cast<const char *> (cmd_data_) + start_of_error_reason;
handle_error_reason (error_reason, error_reason_len);
_state = error_command_received;
return 0;
......
......@@ -61,8 +61,8 @@ class plain_client_t : public mechanism_base_t
state_t _state;
int produce_hello (msg_t *msg_) const;
int produce_initiate (msg_t *msg_) const;
void produce_hello (msg_t *msg_) const;
void produce_initiate (msg_t *msg_) const;
int process_welcome (const unsigned char *cmd_data_, size_t data_size_);
int process_ready (const unsigned char *cmd_data_, size_t data_size_);
......
/*
Copyright (c) 2007-2016 Contributors as noted in the AUTHORS file
This file is part of libzmq, the ZeroMQ core engine in C++.
libzmq is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.
As a special exception, the Contributors give you permission to link
this library with independent modules to produce an executable,
regardless of the license terms of these independent modules, and to
copy and distribute the resulting executable under terms of your choice,
provided that you also meet, for each linked independent module, the
terms and conditions of the license of that module. An independent
module is a module which is not derived from or based on this library.
If you modify this library, you must extend this exception to your
version of the library.
libzmq is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
License for more details.
You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __ZMQ_PLAIN_COMMON_HPP_INCLUDED__
#define __ZMQ_PLAIN_COMMON_HPP_INCLUDED__
namespace zmq
{
const char hello_prefix[] = "\x05WELCOME";
const size_t hello_prefix_len = sizeof (hello_prefix) - 1;
const char welcome_prefix[] = "\x07WELCOME";
const size_t welcome_prefix_len = sizeof (welcome_prefix) - 1;
const char initiate_prefix[] = "\x08INITIATE";
const size_t initiate_prefix_len = sizeof (initiate_prefix) - 1;
const char ready_prefix[] = "\x05READY";
const size_t ready_prefix_len = sizeof (ready_prefix) - 1;
const char error_prefix[] = "\x05ERROR";
const size_t error_prefix_len = sizeof (error_prefix) - 1;
const size_t brief_len_size = sizeof (char);
}
#endif
......@@ -36,6 +36,7 @@
#include "err.hpp"
#include "plain_server.hpp"
#include "wire.hpp"
#include "plain_common.hpp"
zmq::plain_server_t::plain_server_t (session_base_t *session_,
const std::string &peer_address_,
......@@ -63,18 +64,15 @@ int zmq::plain_server_t::next_handshake_command (msg_t *msg_)
switch (state) {
case sending_welcome:
rc = produce_welcome (msg_);
if (rc == 0)
produce_welcome (msg_);
state = waiting_for_initiate;
break;
case sending_ready:
rc = produce_ready (msg_);
if (rc == 0)
produce_ready (msg_);
state = ready;
break;
case sending_error:
rc = produce_error (msg_);
if (rc == 0)
produce_error (msg_);
state = error_sent;
break;
default:
......@@ -118,17 +116,18 @@ int zmq::plain_server_t::process_hello (msg_t *msg_)
if (rc == -1)
return -1;
const unsigned char *ptr = static_cast<unsigned char *> (msg_->data ());
const char *ptr = static_cast<char *> (msg_->data ());
size_t bytes_left = msg_->size ();
if (bytes_left < 6 || memcmp (ptr, "\x05HELLO", 6)) {
if (bytes_left < hello_prefix_len
|| memcmp (ptr, hello_prefix, hello_prefix_len)) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
errno = EPROTO;
return -1;
}
ptr += 6;
bytes_left -= 6;
ptr += hello_prefix_len;
bytes_left -= hello_prefix_len;
if (bytes_left < 1) {
// PLAIN I: invalid PLAIN client, did not send username
......@@ -149,7 +148,7 @@ int zmq::plain_server_t::process_hello (msg_t *msg_)
errno = EPROTO;
return -1;
}
const std::string username = std::string ((char *) ptr, username_length);
const std::string username = std::string (ptr, username_length);
ptr += username_length;
bytes_left -= username_length;
if (bytes_left < 1) {
......@@ -172,7 +171,7 @@ int zmq::plain_server_t::process_hello (msg_t *msg_)
return -1;
}
const std::string password = std::string ((char *) ptr, password_length);
const std::string password = std::string (ptr, password_length);
ptr += password_length;
bytes_left -= password_length;
if (bytes_left > 0) {
......@@ -202,12 +201,11 @@ int zmq::plain_server_t::process_hello (msg_t *msg_)
return receive_and_process_zap_reply () == -1 ? -1 : 0;
}
int zmq::plain_server_t::produce_welcome (msg_t *msg_) const
void zmq::plain_server_t::produce_welcome (msg_t *msg_) const
{
const int rc = msg_->init_size (8);
const int rc = msg_->init_size (welcome_prefix_len);
errno_assert (rc == 0);
memcpy (msg_->data (), "\x07WELCOME", 8);
return 0;
memcpy (msg_->data (), welcome_prefix, welcome_prefix_len);
}
int zmq::plain_server_t::process_initiate (msg_t *msg_)
......@@ -215,35 +213,39 @@ int zmq::plain_server_t::process_initiate (msg_t *msg_)
const unsigned char *ptr = static_cast<unsigned char *> (msg_->data ());
const size_t bytes_left = msg_->size ();
if (bytes_left < 9 || memcmp (ptr, "\x08INITIATE", 9)) {
if (bytes_left < initiate_prefix_len
|| memcmp (ptr, initiate_prefix, initiate_prefix_len)) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
errno = EPROTO;
return -1;
}
const int rc = parse_metadata (ptr + 9, bytes_left - 9);
const int rc = parse_metadata (ptr + initiate_prefix_len,
bytes_left - initiate_prefix_len);
if (rc == 0)
state = sending_ready;
return rc;
}
int zmq::plain_server_t::produce_ready (msg_t *msg_) const
void zmq::plain_server_t::produce_ready (msg_t *msg_) const
{
make_command_with_basic_properties (msg_, "\5READY", 6);
return 0;
make_command_with_basic_properties (msg_, ready_prefix, ready_prefix_len);
}
int zmq::plain_server_t::produce_error (msg_t *msg_) const
void zmq::plain_server_t::produce_error (msg_t *msg_) const
{
zmq_assert (status_code.length () == 3);
const int rc = msg_->init_size (6 + 1 + status_code.length ());
const char expected_status_code_len = 3;
zmq_assert (status_code.length ()
== static_cast<size_t> (expected_status_code_len));
const size_t status_code_len_size = sizeof (expected_status_code_len);
const int rc = msg_->init_size (error_prefix_len + status_code_len_size
+ expected_status_code_len);
zmq_assert (rc == 0);
char *msg_data = static_cast<char *> (msg_->data ());
memcpy (msg_data, "\5ERROR", 6);
msg_data[6] = static_cast<char> (status_code.length ());
memcpy (msg_data + 7, status_code.c_str (), status_code.length ());
return 0;
memcpy (msg_data, error_prefix, error_prefix_len);
msg_data[error_prefix_len] = expected_status_code_len;
memcpy (msg_data + error_prefix_len + status_code_len_size,
status_code.c_str (), status_code.length ());
}
void zmq::plain_server_t::send_zap_request (const std::string &username_,
......@@ -253,6 +255,8 @@ void zmq::plain_server_t::send_zap_request (const std::string &username_,
reinterpret_cast<const uint8_t *> (username_.c_str ()),
reinterpret_cast<const uint8_t *> (password_.c_str ())};
size_t credentials_sizes[] = {username_.size (), password_.size ()};
zap_client_t::send_zap_request ("PLAIN", 5, credentials, credentials_sizes,
2);
const char plain_mechanism_name[] = "PLAIN";
zap_client_t::send_zap_request (
plain_mechanism_name, sizeof (plain_mechanism_name) - 1, credentials,
credentials_sizes, sizeof (credentials) / sizeof (credentials[0]));
}
......@@ -51,9 +51,9 @@ class plain_server_t : public zap_client_common_handshake_t
virtual int process_handshake_command (msg_t *msg_);
private:
int produce_welcome (msg_t *msg_) const;
int produce_ready (msg_t *msg_) const;
int produce_error (msg_t *msg_) const;
void produce_welcome (msg_t *msg_) const;
void produce_ready (msg_t *msg_) const;
void produce_error (msg_t *msg_) const;
int process_hello (msg_t *msg_);
int process_initiate (msg_t *msg_);
......
......@@ -99,11 +99,10 @@ static int sleep_ms (unsigned int ms_)
static int close_wait_ms (int fd_, unsigned int max_ms_ = 2000)
{
unsigned int ms_so_far = 0;
unsigned int step_ms = max_ms_ / 10;
if (step_ms < 1)
step_ms = 1;
if (step_ms > 100)
step_ms = 100;
const unsigned int min_step_ms = 1;
const unsigned int max_step_ms = 100;
const unsigned int step_ms =
std::min (std::max (min_step_ms, max_ms_ / 10), max_step_ms);
int rc = 0; // do not sleep on first attempt
do {
......
......@@ -1731,14 +1731,14 @@ void zmq::socket_base_t::monitor_event (int event_,
if (_monitor_socket) {
// Send event in first frame
const uint16_t event = static_cast<uint16_t> (event_);
const uint32_t value = static_cast<uint32_t> (value_);
zmq_msg_t msg;
zmq_msg_init_size (&msg, 6);
zmq_msg_init_size (&msg, sizeof (event) + sizeof (value));
uint8_t *data = static_cast<uint8_t *> (zmq_msg_data (&msg));
// Avoid dereferencing uint32_t on unaligned address
uint16_t event = static_cast<uint16_t> (event_);
uint32_t value = static_cast<uint32_t> (value_);
memcpy (data + 0, &event, sizeof (event));
memcpy (data + 2, &value, sizeof (value));
memcpy (data + sizeof (event), &value, sizeof (value));
zmq_sendmsg (_monitor_socket, &msg, ZMQ_SNDMORE);
// Send address in second frame
......
......@@ -30,6 +30,7 @@
#include "precompiled.hpp"
#include "macros.hpp"
#include <limits.h>
#include <string.h>
#ifndef ZMQ_HAVE_WINDOWS
......@@ -232,7 +233,7 @@ void zmq::stream_engine_t::plug (io_thread_t *io_thread_,
// Send the 'length' and 'flags' fields of the routing id message.
// The 'length' field is encoded in the long format.
_outpos = _greeting_send;
_outpos[_outsize++] = 0xff;
_outpos[_outsize++] = UCHAR_MAX;
put_uint64 (&_outpos[_outsize], _options.routing_id_size + 1);
_outsize += 8;
_outpos[_outsize++] = 0x7f;
......@@ -587,7 +588,8 @@ bool zmq::stream_engine_t::handshake ()
// Since there is no way to tell the encoder to
// skip the message header, we simply throw that
// header data away.
const size_t header_size = _options.routing_id_size + 1 >= 255 ? 10 : 2;
const size_t header_size =
_options.routing_id_size + 1 >= UCHAR_MAX ? 10 : 2;
unsigned char tmp[10], *bufferp = tmp;
// Prepare the routing id message and load it into encoder.
......
......@@ -48,6 +48,8 @@
#include <stdlib.h>
#endif
#include <limits.h>
zmq::tcp_address_t::tcp_address_t () : _has_src_addr (false)
{
memset (&_address, 0, sizeof (_address));
......@@ -227,17 +229,18 @@ int zmq::tcp_address_mask_t::resolve (const char *name_, bool ipv6_)
return rc;
// Parse the cidr mask number.
const int full_mask_ipv4 = sizeof (_address.ipv4.sin_addr) * CHAR_BIT;
const int full_mask_ipv6 = sizeof (_address.ipv6.sin6_addr) * CHAR_BIT;
if (mask_str.empty ()) {
if (_address.family () == AF_INET6)
_address_mask = 128;
else
_address_mask = 32;
_address_mask =
_address.family () == AF_INET6 ? full_mask_ipv6 : full_mask_ipv4;
} else if (mask_str == "0")
_address_mask = 0;
else {
const int mask = atoi (mask_str.c_str ());
if ((mask < 1) || (_address.family () == AF_INET6 && mask > 128)
|| (_address.family () != AF_INET6 && mask > 32)) {
if ((mask < 1)
|| (_address.family () == AF_INET6 && mask > full_mask_ipv6)
|| (_address.family () != AF_INET6 && mask > full_mask_ipv4)) {
errno = EINVAL;
return -1;
}
......
......@@ -146,7 +146,7 @@ void zmq::thread_t::
{
int priority =
(_thread_priority >= 0 ? _thread_priority : DEFAULT_PRIORITY);
priority = (priority < 255 ? priority : DEFAULT_PRIORITY);
priority = (priority < UCHAR_MAX ? priority : DEFAULT_PRIORITY);
if (_descriptor != NULL || _descriptor > 0) {
taskPrioritySet (_descriptor, priority);
}
......
......@@ -362,8 +362,8 @@ void zmq::udp_engine_t::out_event ()
msg_t body_msg;
rc = _session->pull_msg (&body_msg);
size_t group_size = group_msg.size ();
size_t body_size = body_msg.size ();
const size_t group_size = group_msg.size ();
const size_t body_size = body_msg.size ();
size_t size;
if (_options.raw_socket) {
......@@ -400,12 +400,11 @@ void zmq::udp_engine_t::out_event ()
errno_assert (rc == 0);
#ifdef ZMQ_HAVE_WINDOWS
rc = sendto (_fd, reinterpret_cast<const char *> (_out_buffer),
static_cast<int> (size), 0, _out_address,
rc = sendto (_fd, _out_buffer, static_cast<int> (size), 0, _out_address,
static_cast<int> (_out_address_len));
wsa_assert (rc != SOCKET_ERROR);
#elif defined ZMQ_HAVE_VXWORKS
rc = sendto (_fd, (caddr_t) _out_buffer, size, 0,
rc = sendto (_fd, reinterpret_cast<caddr_t> (_out_buffer), size, 0,
(sockaddr *) _out_address, (int) _out_address_len);
errno_assert (rc != -1);
#else
......@@ -440,7 +439,7 @@ void zmq::udp_engine_t::in_event ()
socklen_t in_addrlen = sizeof (sockaddr_storage);
#ifdef ZMQ_HAVE_WINDOWS
int nbytes =
recvfrom (_fd, reinterpret_cast<char *> (_in_buffer), MAX_UDP_MSG, 0,
recvfrom (_fd, _in_buffer, MAX_UDP_MSG, 0,
reinterpret_cast<sockaddr *> (&in_address), &in_addrlen);
const int last_error = WSAGetLastError ();
if (nbytes == SOCKET_ERROR) {
......@@ -449,7 +448,7 @@ void zmq::udp_engine_t::in_event ()
return;
}
#elif defined ZMQ_HAVE_VXWORKS
int nbytes = recvfrom (_fd, (char *) _in_buffer, MAX_UDP_MSG, 0,
int nbytes = recvfrom (_fd, _in_buffer, MAX_UDP_MSG, 0,
(sockaddr *) &in_address, (int *) &in_addrlen);
if (nbytes == -1) {
errno_assert (errno != EBADF && errno != EFAULT && errno != ENOMEM
......@@ -483,9 +482,10 @@ void zmq::udp_engine_t::in_event ()
body_size = nbytes;
body_offset = 0;
} else {
const char *group_buffer =
reinterpret_cast<const char *> (_in_buffer) + 1;
int group_size = _in_buffer[0];
// TODO in out_event, the group size is an *unsigned* char. what is
// the maximum value?
const char *group_buffer = _in_buffer + 1;
const int group_size = _in_buffer[0];
rc = msg.init_size (group_size);
errno_assert (rc == 0);
......
......@@ -62,8 +62,8 @@ class udp_engine_t : public io_object_t, public i_engine
const struct sockaddr *_out_address;
socklen_t _out_address_len;
unsigned char _out_buffer[MAX_UDP_MSG];
unsigned char _in_buffer[MAX_UDP_MSG];
char _out_buffer[MAX_UDP_MSG];
char _in_buffer[MAX_UDP_MSG];
bool _send_enabled;
bool _recv_enabled;
};
......
......@@ -31,6 +31,7 @@
#include <stdlib.h>
#include <string.h>
#include <limits>
#include <limits.h>
#include "decoder.hpp"
#include "v1_decoder.hpp"
......@@ -57,10 +58,10 @@ zmq::v1_decoder_t::~v1_decoder_t ()
int zmq::v1_decoder_t::one_byte_size_ready (unsigned char const *)
{
// First byte of size is read. If it is 0xff read 8-byte size.
// First byte of size is read. If it is UCHAR_MAX (0xff) read 8-byte size.
// Otherwise allocate the buffer for message data and read the
// message data into it.
if (*_tmpbuf == 0xff)
if (*_tmpbuf == UCHAR_MAX)
next_step (_tmpbuf, 8, &v1_decoder_t::eight_byte_size_ready);
else {
// There has to be at least one byte (the flags) in the message).
......
......@@ -33,6 +33,8 @@
#include "msg.hpp"
#include "wire.hpp"
#include <limits.h>
zmq::v1_encoder_t::v1_encoder_t (size_t bufsize_) :
encoder_base_t<v1_encoder_t> (bufsize_)
{
......@@ -62,12 +64,12 @@ void zmq::v1_encoder_t::message_ready ()
// For messages less than 255 bytes long, write one byte of message size.
// For longer messages write 0xff escape character followed by 8-byte
// message size. In both cases 'flags' field follows.
if (size < 255) {
if (size < UCHAR_MAX) {
_tmpbuf[0] = static_cast<unsigned char> (size);
_tmpbuf[1] = (in_progress->flags () & msg_t::more);
next_step (_tmpbuf, 2, &v1_encoder_t::size_ready, false);
} else {
_tmpbuf[0] = 0xff;
_tmpbuf[0] = UCHAR_MAX;
put_uint64 (_tmpbuf + 1, size);
_tmpbuf[9] = (in_progress->flags () & msg_t::more);
next_step (_tmpbuf, 10, &v1_encoder_t::size_ready, false);
......
......@@ -34,6 +34,8 @@
#include "likely.hpp"
#include "wire.hpp"
#include <limits.h>
zmq::v2_encoder_t::v2_encoder_t (size_t bufsize_) :
encoder_base_t<v2_encoder_t> (bufsize_)
{
......@@ -52,7 +54,7 @@ void zmq::v2_encoder_t::message_ready ()
protocol_flags = 0;
if (in_progress->flags () & msg_t::more)
protocol_flags |= v2_protocol_t::more_flag;
if (in_progress->size () > 255)
if (in_progress->size () > UCHAR_MAX)
protocol_flags |= v2_protocol_t::large_flag;
if (in_progress->flags () & msg_t::command)
protocol_flags |= v2_protocol_t::command_flag;
......@@ -61,7 +63,7 @@ void zmq::v2_encoder_t::message_ready ()
// the length is encoded as 8-bit unsigned integer. For larger
// messages, 64-bit unsigned integer in network byte order is used.
const size_t size = in_progress->size ();
if (unlikely (size > 255)) {
if (unlikely (size > UCHAR_MAX)) {
put_uint64 (_tmp_buf + 1, size);
next_step (_tmp_buf, 9, &v2_encoder_t::size_ready, false);
} else {
......
......@@ -35,6 +35,12 @@
namespace zmq
{
const char zap_version[] = "1.0";
const size_t zap_version_len = sizeof (zap_version) - 1;
const char id[] = "1";
const size_t id_len = sizeof (id) - 1;
zap_client_t::zap_client_t (session_base_t *const session_,
const std::string &peer_address_,
const options_t &options_) :
......@@ -72,17 +78,17 @@ void zap_client_t::send_zap_request (const char *mechanism_,
errno_assert (rc == 0);
// Version frame
rc = msg.init_size (3);
rc = msg.init_size (zap_version_len);
errno_assert (rc == 0);
memcpy (msg.data (), "1.0", 3);
memcpy (msg.data (), zap_version, zap_version_len);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
errno_assert (rc == 0);
// Request ID frame
rc = msg.init_size (1);
rc = msg.init_size (id_len);
errno_assert (rc == 0);
memcpy (msg.data (), "1", 1);
memcpy (msg.data (), id, id_len);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
errno_assert (rc == 0);
......@@ -136,15 +142,16 @@ void zap_client_t::send_zap_request (const char *mechanism_,
int zap_client_t::receive_and_process_zap_reply ()
{
int rc = 0;
msg_t msg[7]; // ZAP reply consists of 7 frames
const size_t zap_reply_frame_count = 7;
msg_t msg[zap_reply_frame_count];
// Initialize all reply frames
for (int i = 0; i < 7; i++) {
for (size_t i = 0; i < zap_reply_frame_count; i++) {
rc = msg[i].init ();
errno_assert (rc == 0);
}
for (int i = 0; i < 7; i++) {
for (size_t i = 0; i < zap_reply_frame_count; i++) {
rc = session->read_zap_msg (&msg[i]);
if (rc == -1) {
if (errno == EAGAIN) {
......@@ -152,7 +159,8 @@ int zap_client_t::receive_and_process_zap_reply ()
}
return close_and_return (msg, -1);
}
if ((msg[i].flags () & msg_t::more) == (i < 6 ? 0 : msg_t::more)) {
if ((msg[i].flags () & msg_t::more)
== (i < zap_reply_frame_count - 1 ? 0 : msg_t::more)) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZAP_MALFORMED_REPLY);
errno = EPROTO;
......@@ -170,7 +178,8 @@ int zap_client_t::receive_and_process_zap_reply ()
}
// Version frame
if (msg[1].size () != 3 || memcmp (msg[1].data (), "1.0", 3)) {
if (msg[1].size () != zap_version_len
|| memcmp (msg[1].data (), zap_version, zap_version_len)) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZAP_BAD_VERSION);
errno = EPROTO;
......@@ -178,7 +187,7 @@ int zap_client_t::receive_and_process_zap_reply ()
}
// Request id frame
if (msg[2].size () != 1 || memcmp (msg[2].data (), "1", 1)) {
if (msg[2].size () != id_len || memcmp (msg[2].data (), id, id_len)) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID);
errno = EPROTO;
......@@ -214,7 +223,7 @@ int zap_client_t::receive_and_process_zap_reply ()
}
// Close all reply frames
for (int i = 0; i < 7; i++) {
for (size_t i = 0; i < zap_reply_frame_count; i++) {
const int rc2 = msg[i].close ();
errno_assert (rc2 == 0);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment