Commit d7a37783 authored by sigiesec's avatar sigiesec

Problem: plain_server_t duplicates zap_client_t::send_zap_request

Solution: Use zap_client_t::send_zap_request
parent 014b201d
...@@ -194,7 +194,7 @@ zmq::mechanism_t::status_t zmq::null_mechanism_t::status () const ...@@ -194,7 +194,7 @@ zmq::mechanism_t::status_t zmq::null_mechanism_t::status () const
int zmq::null_mechanism_t::send_zap_request () int zmq::null_mechanism_t::send_zap_request ()
{ {
return zap_client.send_zap_request ("NULL", 4, NULL, 0); return zap_client.send_zap_request ("NULL", 4, NULL, NULL, 0);
} }
int zmq::null_mechanism_t::receive_and_process_zap_reply () int zmq::null_mechanism_t::receive_and_process_zap_reply ()
......
...@@ -43,6 +43,7 @@ zmq::plain_server_t::plain_server_t (session_base_t *session_, ...@@ -43,6 +43,7 @@ zmq::plain_server_t::plain_server_t (session_base_t *session_,
mechanism_t (options_), mechanism_t (options_),
session (session_), session (session_),
peer_address (peer_address_), peer_address (peer_address_),
zap_client (session, peer_address, options),
state (waiting_for_hello) state (waiting_for_hello)
{ {
} }
...@@ -259,89 +260,12 @@ int zmq::plain_server_t::produce_error (msg_t *msg_) const ...@@ -259,89 +260,12 @@ int zmq::plain_server_t::produce_error (msg_t *msg_) const
int zmq::plain_server_t::send_zap_request (const std::string &username, int zmq::plain_server_t::send_zap_request (const std::string &username,
const std::string &password) const std::string &password)
{ {
int rc; const uint8_t *credentials[] = {
msg_t msg; reinterpret_cast<const uint8_t *> (username.c_str ()),
reinterpret_cast<const uint8_t *> (password.c_str ())};
// Address delimiter frame size_t credentials_sizes[] = {username.size (), password.size ()};
rc = msg.init (); return zap_client.send_zap_request ("PLAIN", 5, credentials,
errno_assert (rc == 0); credentials_sizes, 2);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Version frame
rc = msg.init_size (3);
errno_assert (rc == 0);
memcpy (msg.data (), "1.0", 3);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Request id frame
rc = msg.init_size (1);
errno_assert (rc == 0);
memcpy (msg.data (), "1", 1);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Domain frame
rc = msg.init_size (options.zap_domain.length ());
errno_assert (rc == 0);
memcpy (msg.data (), options.zap_domain.c_str (), options.zap_domain.length ());
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Address frame
rc = msg.init_size (peer_address.length ());
errno_assert (rc == 0);
memcpy (msg.data (), peer_address.c_str (), peer_address.length ());
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Identity frame
rc = msg.init_size (options.identity_size);
errno_assert (rc == 0);
memcpy (msg.data (), options.identity, options.identity_size);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Mechanism frame
rc = msg.init_size (5);
errno_assert (rc == 0);
memcpy (msg.data (), "PLAIN", 5);
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Username frame
rc = msg.init_size (username.length ());
errno_assert (rc == 0);
memcpy (msg.data (), username.c_str (), username.length ());
msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
// Password frame
rc = msg.init_size (password.length ());
errno_assert (rc == 0);
memcpy (msg.data (), password.c_str (), password.length ());
rc = session->write_zap_msg (&msg);
if (rc != 0)
return close_and_return (&msg, -1);
return 0;
} }
int zmq::plain_server_t::receive_and_process_zap_reply () int zmq::plain_server_t::receive_and_process_zap_reply ()
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "mechanism.hpp" #include "mechanism.hpp"
#include "options.hpp" #include "options.hpp"
#include "zap_client.hpp"
namespace zmq namespace zmq
{ {
...@@ -71,6 +72,8 @@ namespace zmq ...@@ -71,6 +72,8 @@ namespace zmq
const std::string peer_address; const std::string peer_address;
zap_client_t zap_client;
// Status code as received from ZAP handler // Status code as received from ZAP handler
std::string status_code; std::string status_code;
......
...@@ -34,10 +34,29 @@ ...@@ -34,10 +34,29 @@
namespace zmq namespace zmq
{ {
zap_client_t::zap_client_t (session_base_t *const session_,
const std::string &peer_address_,
const options_t &options_) :
session (session_),
peer_address (peer_address_),
options (options_)
{
}
int zap_client_t::send_zap_request (const char *mechanism, int zap_client_t::send_zap_request (const char *mechanism,
size_t mechanism_length, size_t mechanism_length,
const uint8_t *credentials, const uint8_t *credentials,
size_t credentials_size) size_t credentials_size)
{
return send_zap_request (mechanism, mechanism_length, &credentials,
&credentials_size, 1);
}
int zap_client_t::send_zap_request (const char *mechanism,
size_t mechanism_length,
const uint8_t **credentials,
size_t *credentials_sizes,
size_t credentials_count)
{ {
// TODO I don't think the rc can be -1 anywhere below. // TODO I don't think the rc can be -1 anywhere below.
// It might only be -1 if the HWM was exceeded, but on the ZAP socket, // It might only be -1 if the HWM was exceeded, but on the ZAP socket,
...@@ -105,18 +124,19 @@ int zap_client_t::send_zap_request (const char *mechanism, ...@@ -105,18 +124,19 @@ int zap_client_t::send_zap_request (const char *mechanism,
rc = msg.init_size (mechanism_length); rc = msg.init_size (mechanism_length);
errno_assert (rc == 0); errno_assert (rc == 0);
memcpy (msg.data (), mechanism, mechanism_length); memcpy (msg.data (), mechanism, mechanism_length);
if (credentials) if (credentials_count)
msg.set_flags (msg_t::more); msg.set_flags (msg_t::more);
rc = session->write_zap_msg (&msg); rc = session->write_zap_msg (&msg);
if (rc != 0) if (rc != 0)
return close_and_return (&msg, -1); return close_and_return (&msg, -1);
// Credentials frame // Credentials frames
// Skip if credential is NULL for (size_t i = 0; i < credentials_count; ++i) {
if (credentials) { rc = msg.init_size (credentials_sizes[i]);
rc = msg.init_size (credentials_size);
errno_assert (rc == 0); errno_assert (rc == 0);
memcpy (msg.data (), credentials, credentials_size); if (i < credentials_count - 1)
msg.set_flags (msg_t::more);
memcpy (msg.data (), credentials[i], credentials_sizes[i]);
rc = session->write_zap_msg (&msg); rc = session->write_zap_msg (&msg);
if (rc != 0) if (rc != 0)
return close_and_return (&msg, -1); return close_and_return (&msg, -1);
......
...@@ -39,18 +39,19 @@ class zap_client_t ...@@ -39,18 +39,19 @@ class zap_client_t
public: public:
zap_client_t (session_base_t *const session_, zap_client_t (session_base_t *const session_,
const std::string &peer_address_, const std::string &peer_address_,
const options_t &options_) : const options_t &options_);
session (session_),
peer_address (peer_address_),
options (options_)
{
}
int send_zap_request (const char *mechanism, int send_zap_request (const char *mechanism,
size_t mechanism_length, size_t mechanism_length,
const uint8_t *credentials, const uint8_t *credentials,
size_t credentials_size); size_t credentials_size);
int send_zap_request (const char *mechanism,
size_t mechanism_length,
const uint8_t **credentials,
size_t *credentials_sizes,
size_t credentials_count);
private: private:
session_base_t *const session; session_base_t *const session;
const std::string &peer_address; const std::string &peer_address;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment