/*
    Copyright (c) 2007-2014 Contributors as noted in the AUTHORS file

    This file is part of 0MQ.

    0MQ is free software; you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    0MQ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "platform.hpp"
#ifdef ZMQ_HAVE_WINDOWS
#include "windows.hpp"
#endif

#include <string.h>
#include <string>

#include "msg.hpp"
29
#include "session_base.hpp"
30 31 32 33
#include "err.hpp"
#include "plain_mechanism.hpp"
#include "wire.hpp"

34
zmq::plain_mechanism_t::plain_mechanism_t (session_base_t *session_,
35
                                           const std::string &peer_address_,
36
                                           const options_t &options_) :
37
    mechanism_t (options_),
38
    session (session_),
39
    peer_address (peer_address_),
40
    expecting_zap_reply (false),
41
    state (options.as_server? waiting_for_hello: sending_hello)
42 43 44 45 46 47 48
{
}

//  No explicit cleanup is performed by this class.
zmq::plain_mechanism_t::~plain_mechanism_t ()
{
    //  Intentionally empty.
}

49
int zmq::plain_mechanism_t::next_handshake_command (msg_t *msg_)
50 51 52 53
{
    int rc = 0;

    switch (state) {
54
        case sending_hello:
55
            rc = produce_hello (msg_);
56 57 58 59
            if (rc == 0)
                state = waiting_for_welcome;
            break;
        case sending_welcome:
60
            rc = produce_welcome (msg_);
61 62 63 64
            if (rc == 0)
                state = waiting_for_initiate;
            break;
        case sending_initiate:
65
            rc = produce_initiate (msg_);
66 67 68 69
            if (rc == 0)
                state = waiting_for_ready;
            break;
        case sending_ready:
70
            rc = produce_ready (msg_);
71 72 73 74 75 76
            if (rc == 0)
                state = ready;
            break;
        default:
            errno = EAGAIN;
            rc = -1;
77 78 79 80
    }
    return rc;
}

81
int zmq::plain_mechanism_t::process_handshake_command (msg_t *msg_)
82 83 84 85
{
    int rc = 0;

    switch (state) {
86
        case waiting_for_hello:
87
            rc = process_hello (msg_);
88 89
            if (rc == 0)
                state = expecting_zap_reply? waiting_for_zap_reply: sending_welcome;
90 91
            break;
        case waiting_for_welcome:
92
            rc = process_welcome (msg_);
93 94 95 96
            if (rc == 0)
                state = sending_initiate;
            break;
        case waiting_for_initiate:
97
            rc = process_initiate (msg_);
98 99 100 101
            if (rc == 0)
                state = sending_ready;
            break;
        case waiting_for_ready:
102
            rc = process_ready (msg_);
103 104 105 106
            if (rc == 0)
                state = ready;
            break;
        default:
107
            errno = EPROTO;
108
            rc = -1;
109
            break;
110 111 112 113 114 115 116
    }
    if (rc == 0) {
        rc = msg_->close ();
        errno_assert (rc == 0);
        rc = msg_->init ();
        errno_assert (rc == 0);
    }
117
    return rc;
118 119 120 121 122 123 124
}

//  True once the state machine has reached its terminal `ready` state.
bool zmq::plain_mechanism_t::is_handshake_complete () const
{
    const bool complete = (ready == state);
    return complete;
}

125 126 127 128 129 130 131 132 133 134 135
int zmq::plain_mechanism_t::zap_msg_available ()
{
    if (state != waiting_for_zap_reply) {
        errno = EFSM;
        return -1;
    }
    const int rc = receive_and_process_zap_reply ();
    if (rc == 0)
        state = sending_welcome;
    return rc;
}
136

137
int zmq::plain_mechanism_t::produce_hello (msg_t *msg_) const
138
{
139
    const std::string username = options.plain_username;
140
    zmq_assert (username.length () < 256);
141 142

    const std::string password = options.plain_password;
143 144
    zmq_assert (password.length () < 256);

145
    const size_t command_size = 6 + 1 + username.length ()
Pieter Hintjens's avatar
Pieter Hintjens committed
146
                                  + 1 + password.length ();
147 148 149 150

    const int rc = msg_->init_size (command_size);
    errno_assert (rc == 0);

Pieter Hintjens's avatar
Pieter Hintjens committed
151
    unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
152 153
    memcpy (ptr, "\x05HELLO", 6);
    ptr += 6;
Martin Hurton's avatar
Martin Hurton committed
154

155 156 157
    *ptr++ = static_cast <unsigned char> (username.length ());
    memcpy (ptr, username.c_str (), username.length ());
    ptr += username.length ();
Martin Hurton's avatar
Martin Hurton committed
158

159 160 161 162 163 164 165
    *ptr++ = static_cast <unsigned char> (password.length ());
    memcpy (ptr, password.c_str (), password.length ());
    ptr += password.length ();

    return 0;
}

166

167
int zmq::plain_mechanism_t::process_hello (msg_t *msg_)
168 169 170 171
{
    const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
    size_t bytes_left = msg_->size ();

172
    if (bytes_left < 6 || memcmp (ptr, "\x05HELLO", 6)) {
173 174 175
        errno = EPROTO;
        return -1;
    }
176 177
    ptr += 6;
    bytes_left -= 6;
178 179 180 181 182

    if (bytes_left < 1) {
        errno = EPROTO;
        return -1;
    }
Martin Hurton's avatar
Martin Hurton committed
183
    const size_t username_length = static_cast <size_t> (*ptr++);
184 185 186 187 188 189 190 191 192 193 194 195 196 197
    bytes_left -= 1;

    if (bytes_left < username_length) {
        errno = EPROTO;
        return -1;
    }
    const std::string username = std::string ((char *) ptr, username_length);
    ptr += username_length;
    bytes_left -= username_length;

    if (bytes_left < 1) {
        errno = EPROTO;
        return -1;
    }
Martin Hurton's avatar
Martin Hurton committed
198
    const size_t password_length = static_cast <size_t> (*ptr++);
199 200 201 202 203 204 205 206 207 208 209 210 211 212
    bytes_left -= 1;

    if (bytes_left < password_length) {
        errno = EPROTO;
        return -1;
    }
    const std::string password = std::string ((char *) ptr, password_length);
    ptr += password_length;
    bytes_left -= password_length;

    if (bytes_left > 0) {
        errno = EPROTO;
        return -1;
    }
213

214
    //  Use ZAP protocol (RFC 27) to authenticate the user.
215
    int rc = session->zap_connect ();
216 217 218 219 220 221 222 223
    if (rc == 0) {
        send_zap_request (username, password);
        rc = receive_and_process_zap_reply ();
        if (rc != 0) {
            if (errno != EAGAIN)
                return -1;
            expecting_zap_reply = true;
        }
224 225
    }

226 227 228
    return 0;
}

229
int zmq::plain_mechanism_t::produce_welcome (msg_t *msg_) const
230 231 232
{
    const int rc = msg_->init_size (8);
    errno_assert (rc == 0);
233
    memcpy (msg_->data (), "\x07WELCOME", 8);
234 235 236
    return 0;
}

237
int zmq::plain_mechanism_t::process_welcome (msg_t *msg_)
238 239 240 241
{
    const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
    size_t bytes_left = msg_->size ();

242
    if (bytes_left != 8 ||  memcmp (ptr, "\x07WELCOME", 8)) {
243 244 245 246 247 248
        errno = EPROTO;
        return -1;
    }
    return 0;
}

249
int zmq::plain_mechanism_t::produce_initiate (msg_t *msg_) const
250 251 252 253 254 255 256
{
    unsigned char * const command_buffer = (unsigned char *) malloc (512);
    alloc_assert (command_buffer);

    unsigned char *ptr = command_buffer;

    //  Add mechanism string
257 258
    memcpy (ptr, "\x08INITIATE", 9);
    ptr += 9;
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280

    //  Add socket type property
    const char *socket_type = socket_type_string (options.type);
    ptr += add_property (ptr, "Socket-Type", socket_type, strlen (socket_type));

    //  Add identity property
    if (options.type == ZMQ_REQ
    ||  options.type == ZMQ_DEALER
    ||  options.type == ZMQ_ROUTER) {
        ptr += add_property (ptr, "Identity",
            options.identity, options.identity_size);
    }

    const size_t command_size = ptr - command_buffer;
    const int rc = msg_->init_size (command_size);
    errno_assert (rc == 0);
    memcpy (msg_->data (), command_buffer, command_size);
    free (command_buffer);

    return 0;
}

281
int zmq::plain_mechanism_t::process_initiate (msg_t *msg_)
282 283 284 285
{
    const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
    size_t bytes_left = msg_->size ();

286
    if (bytes_left < 9 || memcmp (ptr, "\x08INITIATE", 9)) {
287 288 289
        errno = EPROTO;
        return -1;
    }
290 291 292
    ptr += 9;
    bytes_left -= 9;
    return parse_metadata (ptr, bytes_left);
293 294
}

295
int zmq::plain_mechanism_t::produce_ready (msg_t *msg_) const
296 297 298 299 300 301
{
    unsigned char * const command_buffer = (unsigned char *) malloc (512);
    alloc_assert (command_buffer);

    unsigned char *ptr = command_buffer;

302
    //  Add command name
303
    memcpy (ptr, "\x05READY", 6);
304
    ptr += 6;
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326

    //  Add socket type property
    const char *socket_type = socket_type_string (options.type);
    ptr += add_property (ptr, "Socket-Type", socket_type, strlen (socket_type));

    //  Add identity property
    if (options.type == ZMQ_REQ
    ||  options.type == ZMQ_DEALER
    ||  options.type == ZMQ_ROUTER) {
        ptr += add_property (ptr, "Identity",
            options.identity, options.identity_size);
    }

    const size_t command_size = ptr - command_buffer;
    const int rc = msg_->init_size (command_size);
    errno_assert (rc == 0);
    memcpy (msg_->data (), command_buffer, command_size);
    free (command_buffer);

    return 0;
}

327
int zmq::plain_mechanism_t::process_ready (msg_t *msg_)
328 329 330 331
{
    const unsigned char *ptr = static_cast <unsigned char *> (msg_->data ());
    size_t bytes_left = msg_->size ();

332
    if (bytes_left < 6 || memcmp (ptr, "\x05READY", 6)) {
333 334 335
        errno = EPROTO;
        return -1;
    }
336 337 338
    ptr += 6;
    bytes_left -= 6;
    return parse_metadata (ptr, bytes_left);
339 340
}

341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
void zmq::plain_mechanism_t::send_zap_request (const std::string &username,
                                               const std::string &password)
{
    int rc;
    msg_t msg;

    //  Address delimiter frame
    rc = msg.init ();
    errno_assert (rc == 0);
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

    //  Version frame
    rc = msg.init_size (3);
    errno_assert (rc == 0);
    memcpy (msg.data (), "1.0", 3);
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

362
    //  Request id frame
363 364 365 366 367 368 369 370
    rc = msg.init_size (1);
    errno_assert (rc == 0);
    memcpy (msg.data (), "1", 1);
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

    //  Domain frame
371
    rc = msg.init_size (options.zap_domain.length ());
372
    errno_assert (rc == 0);
373
    memcpy (msg.data (), options.zap_domain.c_str (), options.zap_domain.length ());
374 375 376 377
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

378 379 380 381 382 383 384 385
    //  Address frame
    rc = msg.init_size (peer_address.length ());
    errno_assert (rc == 0);
    memcpy (msg.data (), peer_address.c_str (), peer_address.length ());
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

386
    //  Identity frame
387
    rc = msg.init_size (options.identity_size);
388
    errno_assert (rc == 0);
389 390 391 392 393
    memcpy (msg.data (), options.identity, options.identity_size);
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
    //  Mechanism frame
    rc = msg.init_size (5);
    errno_assert (rc == 0);
    memcpy (msg.data (), "PLAIN", 5);
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

    //  Username frame
    rc = msg.init_size (username.length ());
    errno_assert (rc == 0);
    memcpy (msg.data (), username.c_str (), username.length ());
    msg.set_flags (msg_t::more);
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);

    //  Password frame
    rc = msg.init_size (password.length ());
    errno_assert (rc == 0);
    memcpy (msg.data (), password.c_str (), password.length ());
    rc = session->write_zap_msg (&msg);
    errno_assert (rc == 0);
}

418 419 420
int zmq::plain_mechanism_t::receive_and_process_zap_reply ()
{
    int rc = 0;
421
    msg_t msg [7];  //  ZAP reply consists of 7 frames
422

423 424
    //  Initialize all reply frames
    for (int i = 0; i < 7; i++) {
425 426 427 428
        rc = msg [i].init ();
        errno_assert (rc == 0);
    }

429
    for (int i = 0; i < 7; i++) {
430 431 432
        rc = session->read_zap_msg (&msg [i]);
        if (rc == -1)
            break;
433
        if ((msg [i].flags () & msg_t::more) == (i < 6? 0: msg_t::more)) {
434 435 436 437 438 439 440 441 442 443 444
            errno = EPROTO;
            rc = -1;
            break;
        }
    }

    if (rc != 0)
        goto error;

    //  Address delimiter frame
    if (msg [0].size () > 0) {
445
        rc = -1;
446 447 448 449 450 451
        errno = EPROTO;
        goto error;
    }

    //  Version frame
    if (msg [1].size () != 3 || memcmp (msg [1].data (), "1.0", 3)) {
452
        rc = -1;
453 454 455 456
        errno = EPROTO;
        goto error;
    }

457
    //  Request id frame
458
    if (msg [2].size () != 1 || memcmp (msg [2].data (), "1", 1)) {
459
        rc = -1;
460 461 462 463 464 465
        errno = EPROTO;
        goto error;
    }

    //  Status code frame
    if (msg [3].size () != 3 || memcmp (msg [3].data (), "200", 3)) {
466
        rc = -1;
467 468 469 470
        errno = EACCES;
        goto error;
    }

471 472 473
    //  Save user id
    set_user_id (msg [5].data (), msg [5].size ());

474 475 476 477
    //  Process metadata frame
    rc = parse_metadata (static_cast <const unsigned char*> (msg [6].data ()),
                         msg [6].size ());

478
error:
479
    for (int i = 0; i < 7; i++) {
480 481 482 483 484 485
        const int rc2 = msg [i].close ();
        errno_assert (rc2 == 0);
    }

    return rc;
}