curve_client.cpp 13.1 KB
Newer Older
/*
    Copyright (c) 2007-2014 Contributors as noted in the AUTHORS file

    This file is part of 0MQ.

    0MQ is free software; you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    0MQ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "platform.hpp"

#ifdef HAVE_LIBSODIUM

#ifdef ZMQ_HAVE_WINDOWS
#include "windows.hpp"
#endif

#include "msg.hpp"
#include "session_base.hpp"
#include "err.hpp"
#include "curve_client.hpp"
#include "wire.hpp"

//  Constructs a CURVE client mechanism in the send_hello state.
//  Copies the long-term client key pair and the server public key out of
//  the options, initialises the crypto library and generates the
//  short-term (connection) key pair used for this handshake.
zmq::curve_client_t::curve_client_t (const options_t &options_) :
    mechanism_t (options_),
    state (send_hello),
    cn_nonce(1),
    cn_peer_nonce(1),
    sync()
{
    memcpy (public_key, options_.curve_public_key, crypto_box_PUBLICKEYBYTES);
    memcpy (secret_key, options_.curve_secret_key, crypto_box_SECRETKEYBYTES);
    memcpy (server_key, options_.curve_server_key, crypto_box_PUBLICKEYBYTES);

    //  NOTE(review): sync is a member of this object, so the lock is
    //  per-instance; confirm that is the intended scope for guarding
    //  library initialisation.
    scoped_lock_t lock (sync);
#if defined(HAVE_TWEETNACL)
    // allow opening of /dev/urandom
    unsigned char tmpbytes[4];
    randombytes(tmpbytes, 4);
#else
    //  sodium_init () returns -1 on failure; 0 (first initialisation) and
    //  1 (already initialised) are both fine. Continuing after a failure
    //  would mean generating keys with an unusable library.
    const int init_rc = sodium_init ();
    zmq_assert (init_rc != -1);
#endif

    //  Generate short-term key pair
    const int rc = crypto_box_keypair (cn_public, cn_secret);
    zmq_assert (rc == 0);
}

//  Intentionally empty: the destructor performs no explicit cleanup.
zmq::curve_client_t::~curve_client_t ()
{
}

63
//  Produces the next outgoing handshake command into msg_, if the state
//  machine currently has one to send. Returns 0 on success. When there is
//  nothing to send in the current state, returns -1 with errno set to
//  EAGAIN. The state only advances when the command was produced.
int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
{
    if (state == send_hello) {
        const int hello_rc = produce_hello (msg_);
        if (hello_rc == 0)
            state = expect_welcome;
        return hello_rc;
    }

    if (state == send_initiate) {
        const int initiate_rc = produce_initiate (msg_);
        if (initiate_rc == 0)
            state = expect_ready;
        return initiate_rc;
    }

    //  Nothing to send right now
    errno = EAGAIN;
    return -1;
}

85
//  Dispatches an incoming handshake command (WELCOME, READY or ERROR) to
//  its handler based on the command name at the start of the message.
//  Unknown commands fail with errno set to EPROTO. On success the message
//  is closed and re-initialised as an empty message.
int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
{
    const unsigned char *data =
        static_cast <unsigned char *> (msg_->data ());
    const size_t size = msg_->size ();

    int rc;
    if (size >= 8 && memcmp (data, "\7WELCOME", 8) == 0)
        rc = process_welcome (data, size);
    else
    if (size >= 6 && memcmp (data, "\5READY", 6) == 0)
        rc = process_ready (data, size);
    else
    if (size >= 6 && memcmp (data, "\5ERROR", 6) == 0)
        rc = process_error (data, size);
    else {
        errno = EPROTO;
        rc = -1;
    }

    //  Leave the message untouched when the command was rejected
    if (rc != 0)
        return rc;

    //  Command consumed: hand back an empty message
    rc = msg_->close ();
    errno_assert (rc == 0);
    rc = msg_->init ();
    errno_assert (rc == 0);

    return rc;
}

115 116 117 118 119 120 121 122 123 124
//  Encrypts msg_ in place into a CURVE MESSAGE command. The resulting
//  message is the 8-byte command name, the 8-byte short nonce and the box
//  holding one flags byte (bit 0 = "more") followed by the original
//  payload. Must only be called once the handshake has completed.
//  Always returns 0; failures are asserted.
int zmq::curve_client_t::encode (msg_t *msg_)
{
    zmq_assert (state == connected);

    const uint8_t flags = (msg_->flags () & msg_t::more) ? 0x01 : 0x00;

    //  Full nonce: fixed client prefix plus the outgoing counter
    uint8_t nonce [crypto_box_NONCEBYTES];
    memcpy (nonce, "CurveZMQMESSAGEC", 16);
    put_uint64 (nonce + 16, cn_nonce);

    const size_t mlen = crypto_box_ZEROBYTES + 1 + msg_->size ();

    //  Plaintext layout: zero padding, flags byte, payload
    uint8_t *plaintext = static_cast <uint8_t *> (malloc (mlen));
    alloc_assert (plaintext);

    memset (plaintext, 0, crypto_box_ZEROBYTES);
    plaintext [crypto_box_ZEROBYTES] = flags;
    memcpy (plaintext + crypto_box_ZEROBYTES + 1,
            msg_->data (), msg_->size ());

    uint8_t *box = static_cast <uint8_t *> (malloc (mlen));
    alloc_assert (box);

    int rc = crypto_box_afternm (box, plaintext, mlen, nonce, cn_precom);
    zmq_assert (rc == 0);

    //  Replace the message content with the wire representation
    rc = msg_->close ();
    zmq_assert (rc == 0);

    rc = msg_->init_size (16 + mlen - crypto_box_BOXZEROBYTES);
    zmq_assert (rc == 0);

    uint8_t *wire = static_cast <uint8_t *> (msg_->data ());

    memcpy (wire, "\x07MESSAGE", 8);
    memcpy (wire + 8, nonce + 16, 8);
    memcpy (wire + 16, box + crypto_box_BOXZEROBYTES,
            mlen - crypto_box_BOXZEROBYTES);

    free (plaintext);
    free (box);

    //  Each nonce value is used for one message only
    cn_nonce++;

    return 0;
}

//  Decrypts a CURVE MESSAGE command in place. Rejects (errno = EPROTO,
//  return -1) messages that are shorter than 33 bytes, do not carry the
//  MESSAGE command name, reuse or go back on the peer nonce, or fail to
//  authenticate. On success msg_ holds the payload and has the "more"
//  flag set if bit 0 of the decrypted flags byte is set.
int zmq::curve_client_t::decode (msg_t *msg_)
{
    zmq_assert (state == connected);

    const size_t wire_size = msg_->size ();
    if (wire_size < 33) {
        errno = EPROTO;
        return -1;
    }

    const uint8_t *wire = static_cast <uint8_t *> (msg_->data ());
    if (memcmp (wire, "\x07MESSAGE", 8) != 0) {
        errno = EPROTO;
        return -1;
    }

    //  The peer nonce must be strictly increasing
    const uint64_t peer_nonce = get_uint64(wire + 8);
    if (peer_nonce <= cn_peer_nonce) {
        errno = EPROTO;
        return -1;
    }
    cn_peer_nonce = peer_nonce;

    //  Full nonce: fixed server prefix plus the short nonce from the wire
    uint8_t nonce [crypto_box_NONCEBYTES];
    memcpy (nonce, "CurveZMQMESSAGES", 16);
    memcpy (nonce + 16, wire + 8, 8);

    const size_t boxed_size = wire_size - 16;
    const size_t clen = crypto_box_BOXZEROBYTES + boxed_size;

    uint8_t *plaintext = static_cast <uint8_t *> (malloc (clen));
    alloc_assert (plaintext);

    uint8_t *box = static_cast <uint8_t *> (malloc (clen));
    alloc_assert (box);

    memset (box, 0, crypto_box_BOXZEROBYTES);
    memcpy (box + crypto_box_BOXZEROBYTES, wire + 16, boxed_size);

    int rc = crypto_box_open_afternm (plaintext, box, clen, nonce, cn_precom);
    if (rc != 0)
        errno = EPROTO;
    else {
        //  Swap the wire content for the decrypted payload
        rc = msg_->close ();
        zmq_assert (rc == 0);

        rc = msg_->init_size (clen - 1 - crypto_box_ZEROBYTES);
        zmq_assert (rc == 0);

        if (plaintext [crypto_box_ZEROBYTES] & 0x01)
            msg_->set_flags (msg_t::more);

        memcpy (msg_->data (),
                plaintext + crypto_box_ZEROBYTES + 1,
                msg_->size ());
    }

    free (plaintext);
    free (box);

    return rc;
}

229
//  Maps the internal handshake state onto the mechanism status:
//  connected -> ready, error_received -> error, anything else ->
//  handshaking.
zmq::mechanism_t::status_t zmq::curve_client_t::status () const
{
    switch (state) {
        case connected:
            return mechanism_t::ready;
        case error_received:
            return mechanism_t::error;
        default:
            return mechanism_t::handshaking;
    }
}

240
//  Builds the 200-byte HELLO command into msg_: command name, version,
//  anti-amplification padding, the client's short-term public key, the
//  short nonce and a signature box over 64 zero bytes. Consumes one
//  outgoing nonce value. Always returns 0; failures are asserted.
int zmq::curve_client_t::produce_hello (msg_t *msg_)
{
    //  Full nonce: fixed prefix plus the outgoing counter
    uint8_t nonce [crypto_box_NONCEBYTES];
    memcpy (nonce, "CurveZMQHELLO---", 16);
    put_uint64 (nonce + 16, cn_nonce);

    //  Signature is Box [64 * %x0](C'->S)
    uint8_t zeros [crypto_box_ZEROBYTES + 64];
    memset (zeros, 0, sizeof zeros);

    uint8_t signature [crypto_box_BOXZEROBYTES + 80];
    int rc = crypto_box (signature, zeros, sizeof zeros,
                         nonce, server_key, cn_secret);
    zmq_assert (rc == 0);

    rc = msg_->init_size (200);
    errno_assert (rc == 0);

    //  Fill the command front to back with a write cursor
    uint8_t *out = static_cast <uint8_t *> (msg_->data ());

    memcpy (out, "\x05HELLO", 6);
    out += 6;
    //  CurveZMQ major and minor version numbers
    memcpy (out, "\1\0", 2);
    out += 2;
    //  Anti-amplification padding
    memset (out, 0, 72);
    out += 72;
    //  Client public connection key
    memcpy (out, cn_public, crypto_box_PUBLICKEYBYTES);
    out += 32;
    //  Short nonce, prefixed by "CurveZMQHELLO---"
    memcpy (out, nonce + 16, 8);
    out += 8;
    //  Signature, Box [64 * %x0](C'->S)
    memcpy (out, signature + crypto_box_BOXZEROBYTES, 80);

    cn_nonce++;

    return 0;
}

279 280
int zmq::curve_client_t::process_welcome (
        const uint8_t *msg_data, size_t msg_size)
281
{
282
    if (msg_size != 168) {
283 284 285 286 287 288 289 290 291 292
        errno = EPROTO;
        return -1;
    }

    uint8_t welcome_nonce [crypto_box_NONCEBYTES];
    uint8_t welcome_plaintext [crypto_box_ZEROBYTES + 128];
    uint8_t welcome_box [crypto_box_BOXZEROBYTES + 144];

    //  Open Box [S' + cookie](C'->S)
    memset (welcome_box, 0, crypto_box_BOXZEROBYTES);
293
    memcpy (welcome_box + crypto_box_BOXZEROBYTES, msg_data + 24, 144);
294 295

    memcpy (welcome_nonce, "WELCOME-", 8);
296
    memcpy (welcome_nonce + 8, msg_data + 8, 16);
297 298 299

    int rc = crypto_box_open (welcome_plaintext, welcome_box,
                              sizeof welcome_box,
300
                              welcome_nonce, server_key, cn_secret);
301 302 303 304 305
    if (rc != 0) {
        errno = EPROTO;
        return -1;
    }

306 307
    memcpy (cn_server, welcome_plaintext + crypto_box_ZEROBYTES, 32);
    memcpy (cn_cookie, welcome_plaintext + crypto_box_ZEROBYTES + 32, 16 + 80);
308 309

    //  Message independent precomputation
310
    rc = crypto_box_beforenm (cn_precom, cn_server, cn_secret);
311 312
    zmq_assert (rc == 0);

313 314
    state = send_initiate;

315 316 317
    return 0;
}

318
//  Builds the INITIATE command into msg_: command name, the cookie from
//  WELCOME, the short nonce and Box [C + vouch + metadata](C'->S').
//  Metadata always carries Socket-Type and, for REQ/DEALER/ROUTER
//  sockets, Identity. Consumes one outgoing nonce value.
//  Always returns 0; failures are asserted.
//
//  The plaintext and box buffers are sized from the properties actually
//  sent. A fixed 256-byte metadata area is too small once a long identity
//  is combined with the Socket-Type property, and would be overrun.
int zmq::curve_client_t::produce_initiate (msg_t *msg_)
{
    uint8_t vouch_nonce [crypto_box_NONCEBYTES];
    uint8_t vouch_plaintext [crypto_box_ZEROBYTES + 64];
    uint8_t vouch_box [crypto_box_BOXZEROBYTES + 80];

    //  Create vouch = Box [C',S](C->S')
    memset (vouch_plaintext, 0, crypto_box_ZEROBYTES);
    memcpy (vouch_plaintext + crypto_box_ZEROBYTES, cn_public, 32);
    memcpy (vouch_plaintext + crypto_box_ZEROBYTES + 32, server_key, 32);

    memcpy (vouch_nonce, "VOUCH---", 8);
    randombytes (vouch_nonce + 8, 16);

    int rc = crypto_box (vouch_box, vouch_plaintext,
                         sizeof vouch_plaintext,
                         vouch_nonce, cn_server, secret_key);
    zmq_assert (rc == 0);

    //  Decide which properties are sent and how much room they need.
    //  Assumes add_property emits the ZMTP property encoding: 1-byte name
    //  length, name, 4-byte value length, value - TODO confirm against
    //  mechanism_t::add_property. The assert after the writes trips if
    //  this estimate is ever too small.
    const char *socket_type = socket_type_string (options.type);
    const size_t socket_type_len = strlen (socket_type);
    const bool send_identity = options.type == ZMQ_REQ
                            || options.type == ZMQ_DEALER
                            || options.type == ZMQ_ROUTER;

    size_t metadata_capacity =
        1 + strlen ("Socket-Type") + 4 + socket_type_len;
    if (send_identity)
        metadata_capacity +=
            1 + strlen ("Identity") + 4 + options.identity_size;

    //  Zero padding + C (32) + vouch nonce (16) + vouch box (80) + metadata
    const size_t capacity = crypto_box_ZEROBYTES + 128 + metadata_capacity;

    uint8_t initiate_nonce [crypto_box_NONCEBYTES];
    uint8_t *initiate_plaintext = static_cast <uint8_t *> (malloc (capacity));
    alloc_assert (initiate_plaintext);
    uint8_t *initiate_box = static_cast <uint8_t *> (malloc (capacity));
    alloc_assert (initiate_box);

    //  Create Box [C + vouch + metadata](C'->S')
    memset (initiate_plaintext, 0, crypto_box_ZEROBYTES);
    memcpy (initiate_plaintext + crypto_box_ZEROBYTES,
            public_key, 32);
    memcpy (initiate_plaintext + crypto_box_ZEROBYTES + 32,
            vouch_nonce + 8, 16);
    memcpy (initiate_plaintext + crypto_box_ZEROBYTES + 48,
            vouch_box + crypto_box_BOXZEROBYTES, 80);

    //  Metadata starts after vouch
    uint8_t *ptr = initiate_plaintext + crypto_box_ZEROBYTES + 128;

    //  Add socket type property
    ptr += add_property (ptr, "Socket-Type", socket_type, socket_type_len);

    //  Add identity property
    if (send_identity)
        ptr += add_property (ptr, "Identity", options.identity, options.identity_size);

    const size_t mlen = ptr - initiate_plaintext;
    zmq_assert (mlen <= capacity);

    memcpy (initiate_nonce, "CurveZMQINITIATE", 16);
    put_uint64 (initiate_nonce + 16, cn_nonce);

    rc = crypto_box (initiate_box, initiate_plaintext,
                     mlen, initiate_nonce, cn_server, cn_secret);
    zmq_assert (rc == 0);

    rc = msg_->init_size (113 + mlen - crypto_box_BOXZEROBYTES);
    errno_assert (rc == 0);

    uint8_t *initiate = static_cast <uint8_t *> (msg_->data ());

    memcpy (initiate, "\x08INITIATE", 9);
    //  Cookie provided by the server in the WELCOME command
    memcpy (initiate + 9, cn_cookie, 96);
    //  Short nonce, prefixed by "CurveZMQINITIATE"
    memcpy (initiate + 105, initiate_nonce + 16, 8);
    //  Box [C + vouch + metadata](C'->S')
    memcpy (initiate + 113, initiate_box + crypto_box_BOXZEROBYTES,
            mlen - crypto_box_BOXZEROBYTES);

    free (initiate_plaintext);
    free (initiate_box);

    cn_nonce++;

    return 0;
}

391 392
int zmq::curve_client_t::process_ready (
        const uint8_t *msg_data, size_t msg_size)
393
{
394
    if (msg_size < 30) {
395 396 397 398
        errno = EPROTO;
        return -1;
    }

399
    const size_t clen = (msg_size - 14) + crypto_box_BOXZEROBYTES;
400 401 402 403 404 405 406

    uint8_t ready_nonce [crypto_box_NONCEBYTES];
    uint8_t ready_plaintext [crypto_box_ZEROBYTES + 256];
    uint8_t ready_box [crypto_box_BOXZEROBYTES + 16 + 256];

    memset (ready_box, 0, crypto_box_BOXZEROBYTES);
    memcpy (ready_box + crypto_box_BOXZEROBYTES,
407
            msg_data + 14, clen - crypto_box_BOXZEROBYTES);
408 409

    memcpy (ready_nonce, "CurveZMQREADY---", 16);
410
    memcpy (ready_nonce + 16, msg_data + 6, 8);
411
    cn_peer_nonce = get_uint64(msg_data + 6);
412 413

    int rc = crypto_box_open_afternm (ready_plaintext, ready_box,
414
                                      clen, ready_nonce, cn_precom);
415 416 417 418 419 420

    if (rc != 0) {
        errno = EPROTO;
        return -1;
    }

421 422
    rc = parse_metadata (ready_plaintext + crypto_box_ZEROBYTES,
                         clen - crypto_box_ZEROBYTES);
423 424 425
    if (rc == 0)
        state = connected;

426 427 428
    return rc;
}

429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448
//  Handles the server's ERROR command. It is only accepted while waiting
//  for WELCOME or READY, must be at least 7 bytes long, and the reason
//  length in byte 6 must fit in the remainder of the command. On success
//  the state moves to error_received; otherwise returns -1 with errno set
//  to EPROTO.
int zmq::curve_client_t::process_error (
        const uint8_t *msg_data, size_t msg_size)
{
    const bool awaiting_reply =
        state == expect_welcome || state == expect_ready;

    //  Short-circuit order matters: byte 6 is read only once the size
    //  check has passed
    if (!awaiting_reply
    ||  msg_size < 7
    ||  static_cast <size_t> (msg_data [6]) > msg_size - 7) {
        errno = EPROTO;
        return -1;
    }

    state = error_received;
    return 0;
}

449
#endif