/*
    Copyright (c) 2007-2016 Contributors as noted in the AUTHORS file

    This file is part of libzmq, the ZeroMQ core engine in C++.

    libzmq is free software; you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    As a special exception, the Contributors give you permission to link
    this library with independent modules to produce an executable,
    regardless of the license terms of these independent modules, and to
    copy and distribute the resulting executable under terms of your choice,
    provided that you also meet, for each linked independent module, the
    terms and conditions of the license of that module. An independent
    module is a module which is not derived from or based on this library.
    If you modify this library, you must extend this exception to your
    version of the library.

    libzmq is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
    License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

30 31
#include "precompiled.hpp"
#include "macros.hpp"
32

33
#ifdef ZMQ_HAVE_CURVE
34 35 36 37 38 39

#include "msg.hpp"
#include "session_base.hpp"
#include "err.hpp"
#include "curve_client.hpp"
#include "wire.hpp"
40
#include "curve_client_tools.hpp"
41

42 43
zmq::curve_client_t::curve_client_t (session_base_t *session_,
                                     const options_t &options_) :
44
    mechanism_base_t (session_, options_),
45 46
    curve_mechanism_base_t (
      session_, options_, "CurveZMQMESSAGEC", "CurveZMQMESSAGES"),
47
    state (send_hello),
48 49
    tools (options_.curve_public_key,
           options_.curve_secret_key,
50
           options_.curve_server_key)
51 52 53 54 55 56 57
{
}

zmq::curve_client_t::~curve_client_t ()
{
    //  Intentionally empty.
}

58
int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
59 60 61 62 63
{
    int rc = 0;

    switch (state) {
        case send_hello:
64
            rc = produce_hello (msg_);
65 66 67 68
            if (rc == 0)
                state = expect_welcome;
            break;
        case send_initiate:
69
            rc = produce_initiate (msg_);
70 71 72 73 74 75 76 77 78 79
            if (rc == 0)
                state = expect_ready;
            break;
        default:
            errno = EAGAIN;
            rc = -1;
    }
    return rc;
}

80
int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
81
{
82 83 84
    const unsigned char *msg_data =
        static_cast <unsigned char *> (msg_->data ());
    const size_t msg_size = msg_->size ();
85

86
    int rc = 0;
87
    if (curve_client_tools_t::is_handshake_command_welcome (msg_data, msg_size))
88
        rc = process_welcome (msg_data, msg_size);
89 90
    else if (curve_client_tools_t::is_handshake_command_ready (msg_data,
                                                               msg_size))
91
        rc = process_ready (msg_data, msg_size);
92 93
    else if (curve_client_tools_t::is_handshake_command_error (msg_data,
                                                               msg_size))
94 95
        rc = process_error (msg_data, msg_size);
    else {
96 97 98
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
99 100
        errno = EPROTO;
        rc = -1;
101
    }
102

103 104 105 106 107 108
    if (rc == 0) {
        rc = msg_->close ();
        errno_assert (rc == 0);
        rc = msg_->init ();
        errno_assert (rc == 0);
    }
109

110 111 112
    return rc;
}

113 114 115
int zmq::curve_client_t::encode (msg_t *msg_)
{
    zmq_assert (state == connected);
116
    return curve_mechanism_base_t::encode (msg_);
117 118 119 120 121
}

//  Passes an incoming message to curve_mechanism_base_t::decode.
//  Must only be called after the handshake has reached state connected
//  (asserted).
int zmq::curve_client_t::decode (msg_t *msg_)
{
    zmq_assert (state == connected);
    const int rc = curve_mechanism_base_t::decode (msg_);
    return rc;
}

125
zmq::mechanism_t::status_t zmq::curve_client_t::status () const
126
{
127 128 129 130 131 132 133
    if (state == connected)
        return mechanism_t::ready;
    else
    if (state == error_received)
        return mechanism_t::error;
    else
        return mechanism_t::handshaking;
134 135
}

136
int zmq::curve_client_t::produce_hello (msg_t *msg_)
137
{
138 139
    int rc = msg_->init_size (200);
    errno_assert (rc == 0);
140

141 142
    rc = tools.produce_hello (msg_->data (), cn_nonce);
    if (rc == -1) {
143 144 145 146
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
      
147 148 149
        // TODO this is somewhat inconsistent: we call init_size, but we may 
        // not close msg_; i.e. we assume that msg_ is initialized but empty 
        // (if it were non-empty, calling init_size might cause a leak!)
150

151
        // msg_->close ();
152
        return -1;
153
    }
154

155
    cn_nonce++;
156 157 158 159

    return 0;
}

160 161
int zmq::curve_client_t::process_welcome (const uint8_t *msg_data,
                                          size_t msg_size)
162
{
163
    int rc = tools.process_welcome (msg_data, msg_size, cn_precom);
164

165
    if (rc == -1) {
166 167 168 169
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);

170 171 172 173
        errno = EPROTO;
        return -1;
    }

174 175
    state = send_initiate;

176 177 178
    return 0;
}

179
int zmq::curve_client_t::produce_initiate (msg_t *msg_)
180
{
181 182 183 184
    const size_t metadata_length = basic_properties_len ();
    unsigned char *metadata_plaintext =
      (unsigned char *) malloc (metadata_length);
    alloc_assert (metadata_plaintext);
185

186
    add_basic_properties (metadata_plaintext, metadata_length);
187

188 189
    size_t msg_size = 113 + 128 + crypto_box_BOXZEROBYTES + metadata_length;
    int rc = msg_->init_size (msg_size);
190 191
    errno_assert (rc == 0);

192 193 194 195 196 197
    rc = tools.produce_initiate (msg_->data (), msg_size, cn_nonce,
                                 metadata_plaintext, metadata_length);

    free (metadata_plaintext);

    if (-1 == rc) {
198 199 200 201
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);

202 203 204
        // TODO see comment in produce_hello
        return -1;
    }
205

206
    cn_nonce++;
207 208 209 210

    return 0;
}

211 212
int zmq::curve_client_t::process_ready (
        const uint8_t *msg_data, size_t msg_size)
213
{
214
    if (msg_size < 30) {
215 216 217
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY);
218 219 220 221
        errno = EPROTO;
        return -1;
    }

222
    const size_t clen = (msg_size - 14) + crypto_box_BOXZEROBYTES;
223 224

    uint8_t ready_nonce [crypto_box_NONCEBYTES];
225 226 227 228 229
    uint8_t *ready_plaintext = (uint8_t *) malloc (crypto_box_ZEROBYTES + clen);
    alloc_assert (ready_plaintext);
    uint8_t *ready_box =
      (uint8_t *) malloc (crypto_box_BOXZEROBYTES + 16 + clen);
    alloc_assert (ready_box);
230 231 232

    memset (ready_box, 0, crypto_box_BOXZEROBYTES);
    memcpy (ready_box + crypto_box_BOXZEROBYTES,
233
            msg_data + 14, clen - crypto_box_BOXZEROBYTES);
234 235

    memcpy (ready_nonce, "CurveZMQREADY---", 16);
236
    memcpy (ready_nonce + 16, msg_data + 6, 8);
237
    cn_peer_nonce = get_uint64(msg_data + 6);
238 239

    int rc = crypto_box_open_afternm (ready_plaintext, ready_box,
240
                                      clen, ready_nonce, cn_precom);
241
    free (ready_box);
242 243

    if (rc != 0) {
244 245 246
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
247 248 249 250
        errno = EPROTO;
        return -1;
    }

251 252
    rc = parse_metadata (ready_plaintext + crypto_box_ZEROBYTES,
                         clen - crypto_box_ZEROBYTES);
253 254
    free (ready_plaintext);

255 256
    if (rc == 0)
        state = connected;
257 258 259 260 261 262
    else
    {
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA);
        errno = EPROTO;
    }
263

264 265 266
    return rc;
}

267 268 269 270
int zmq::curve_client_t::process_error (
        const uint8_t *msg_data, size_t msg_size)
{
    if (state != expect_welcome && state != expect_ready) {
271 272
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
273 274 275 276
        errno = EPROTO;
        return -1;
    }
    if (msg_size < 7) {
277
        session->get_socket ()->event_handshake_failed_protocol (
278 279
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
280 281 282 283 284
        errno = EPROTO;
        return -1;
    }
    const size_t error_reason_len = static_cast <size_t> (msg_data [6]);
    if (error_reason_len > msg_size - 7) {
285
        session->get_socket ()->event_handshake_failed_protocol (
286 287
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
288 289 290
        errno = EPROTO;
        return -1;
    }
291 292
    const char *error_reason = reinterpret_cast<const char *> (msg_data) + 7;
    handle_error_reason (error_reason, error_reason_len);
293 294 295 296
    state = error_received;
    return 0;
}

297
#endif