curve_client.cpp 9.02 KB
Newer Older
1
/*
2
    Copyright (c) 2007-2016 Contributors as noted in the AUTHORS file
3

4
    This file is part of libzmq, the ZeroMQ core engine in C++.
5

6 7 8
    libzmq is free software; you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
9 10
    (at your option) any later version.

11 12 13 14 15 16 17 18 19 20 21 22 23 24
    As a special exception, the Contributors give you permission to link
    this library with independent modules to produce an executable,
    regardless of the license terms of these independent modules, and to
    copy and distribute the resulting executable under terms of your choice,
    provided that you also meet, for each linked independent module, the
    terms and conditions of the license of that module. An independent
    module is a module which is not derived from or based on this library.
    If you modify this library, you must extend this exception to your
    version of the library.

    libzmq is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
    License for more details.
25 26 27 28 29

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

30 31
#include "precompiled.hpp"
#include "macros.hpp"
32

33
#ifdef ZMQ_HAVE_CURVE
34 35 36 37 38 39

#include "msg.hpp"
#include "session_base.hpp"
#include "err.hpp"
#include "curve_client.hpp"
#include "wire.hpp"
40
#include "curve_client_tools.hpp"
41

42 43
// Builds the client side of the CURVE mechanism.
//
// Both base classes receive the session and the options; in addition
// curve_mechanism_base_t is given two 16-character tags (presumably the
// nonce prefixes used for each traffic direction - confirm against
// curve_mechanism_base_t). The state machine starts at send_hello and the
// key material found in the options is handed over to _tools.
zmq::curve_client_t::curve_client_t (session_base_t *session_,
                                     const options_t &options_) :
    mechanism_base_t (session_, options_),
    curve_mechanism_base_t (session_,
                            options_,
                            "CurveZMQMESSAGEC",
                            "CurveZMQMESSAGES"),
    _state (send_hello),
    _tools (options_.curve_public_key,
            options_.curve_secret_key,
            options_.curve_server_key)
{
    // Nothing else to do: all members are set up by the initializer list.
}

// Destructor. The body is intentionally empty: this class performs no
// explicit clean-up of its own here.
zmq::curve_client_t::~curve_client_t ()
{
}

58
// Produces the next outgoing handshake command, if the state machine is in
// a state that has one to send.
//
// - send_hello:    writes a HELLO into msg_ and, on success, moves on to
//                  expect_welcome.
// - send_initiate: writes an INITIATE into msg_ and, on success, moves on to
//                  expect_ready.
// - anything else: nothing to send; errno is set to EAGAIN.
//
// Returns 0 on success and -1 otherwise.
int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
{
    if (_state == send_hello) {
        const int result = produce_hello (msg_);
        // Only advance when the command was actually produced.
        if (result == 0)
            _state = expect_welcome;
        return result;
    }

    if (_state == send_initiate) {
        const int result = produce_initiate (msg_);
        if (result == 0)
            _state = expect_ready;
        return result;
    }

    // No command is pending in the current state.
    errno = EAGAIN;
    return -1;
}

80
// Dispatches an incoming handshake command to the matching handler.
//
// The command is classified with the curve_client_tools_t helpers as
// WELCOME, READY or ERROR. Anything else raises a "handshake failed"
// protocol event (unexpected command), sets errno to EPROTO and returns -1.
//
// When the handler succeeds, msg_ is closed and re-initialised as an empty
// message before returning 0. When the handler fails, msg_ is left
// untouched and the handler's return code is passed through.
int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
{
    const unsigned char *const payload =
      static_cast<unsigned char *> (msg_->data ());
    const size_t payload_size = msg_->size ();

    int rc;
    if (curve_client_tools_t::is_handshake_command_welcome (payload,
                                                            payload_size)) {
        rc = process_welcome (payload, payload_size);
    } else if (curve_client_tools_t::is_handshake_command_ready (
                 payload, payload_size)) {
        rc = process_ready (payload, payload_size);
    } else if (curve_client_tools_t::is_handshake_command_error (
                 payload, payload_size)) {
        rc = process_error (payload, payload_size);
    } else {
        // Not a command this client understands.
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
        errno = EPROTO;
        return -1;
    }

    if (rc != 0)
        return rc;

    // The command has been consumed: hand back an empty message.
    rc = msg_->close ();
    errno_assert (rc == 0);
    rc = msg_->init ();
    errno_assert (rc == 0);

    return rc;
}

112 113
// Encodes msg_ by delegating to curve_mechanism_base_t::encode.
// Asserts that the state machine is in the connected state.
// Returns whatever the base implementation returns.
int zmq::curve_client_t::encode (msg_t *msg_)
{
    zmq_assert (_state == connected);

    const int rc = curve_mechanism_base_t::encode (msg_);
    return rc;
}

// Decodes msg_ by delegating to curve_mechanism_base_t::decode.
// Asserts that the state machine is in the connected state.
// Returns whatever the base implementation returns.
int zmq::curve_client_t::decode (msg_t *msg_)
{
    zmq_assert (_state == connected);

    const int rc = curve_mechanism_base_t::decode (msg_);
    return rc;
}

124
// Maps the internal handshake state onto the generic mechanism status:
//   connected      -> ready
//   error_received -> error
//   anything else  -> handshaking
zmq::mechanism_t::status_t zmq::curve_client_t::status () const
{
    switch (_state) {
        case connected:
            return mechanism_t::ready;
        case error_received:
            return mechanism_t::error;
        default:
            return mechanism_t::handshaking;
    }
}

134
// Writes a HELLO command into msg_.
//
// msg_ is initialised to 200 bytes and filled in by _tools.produce_hello
// using the current value of cn_nonce. On success cn_nonce is incremented
// and 0 is returned. If the tools call reports -1, a "handshake failed"
// protocol event (cryptographic error) is raised and -1 is returned; errno
// is not set by this function on that path.
int zmq::curve_client_t::produce_hello (msg_t *msg_)
{
    const int init_rc = msg_->init_size (200);
    errno_assert (init_rc == 0);

    if (_tools.produce_hello (msg_->data (), cn_nonce) == -1) {
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);

        // TODO this is somewhat inconsistent: we call init_size, but we may
        // not close msg_; i.e. we assume that msg_ is initialized but empty
        // (if it were non-empty, calling init_size might cause a leak!)

        // msg_->close ();
        return -1;
    }

    // The nonce value was used for this command; move on to the next one.
    ++cn_nonce;

    return 0;
}

157 158
int zmq::curve_client_t::process_welcome (const uint8_t *msg_data_,
                                          size_t msg_size_)
159
{
160
    int rc = _tools.process_welcome (msg_data_, msg_size_, cn_precom);
161

162
    if (rc == -1) {
163
        session->get_socket ()->event_handshake_failed_protocol (
164
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
165

166 167 168 169
        errno = EPROTO;
        return -1;
    }

170
    _state = send_initiate;
171

172 173 174
    return 0;
}

175
// Writes an INITIATE command into msg_.
//
// The basic properties (metadata) are serialised into a temporary heap
// buffer, msg_ is sized to hold the command, and _tools.produce_initiate
// fills it in using the current cn_nonce. The temporary buffer is released
// in every case. On success cn_nonce is incremented and 0 is returned. If
// the tools call reports -1, a "handshake failed" protocol event
// (cryptographic error) is raised and -1 is returned; errno is not set by
// this function on that path.
int zmq::curve_client_t::produce_initiate (msg_t *msg_)
{
    // Serialise the metadata into a scratch buffer.
    const size_t props_size = basic_properties_len ();
    unsigned char *const props =
      static_cast<unsigned char *> (malloc (props_size));
    alloc_assert (props);

    add_basic_properties (props, props_size);

    // Fixed part of the command plus the variable-length metadata.
    const size_t command_size =
      113 + 128 + crypto_box_BOXZEROBYTES + props_size;
    const int init_rc = msg_->init_size (command_size);
    errno_assert (init_rc == 0);

    const int tools_rc = _tools.produce_initiate (
      msg_->data (), command_size, cn_nonce, props, props_size);

    // The scratch buffer is no longer needed, whatever the outcome.
    free (props);

    if (tools_rc == -1) {
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);

        // TODO see comment in produce_hello
        return -1;
    }

    ++cn_nonce;

    return 0;
}

206 207
int zmq::curve_client_t::process_ready (const uint8_t *msg_data_,
                                        size_t msg_size_)
208
{
209
    if (msg_size_ < 30) {
210 211 212
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY);
213 214 215 216
        errno = EPROTO;
        return -1;
    }

217
    const size_t clen = (msg_size_ - 14) + crypto_box_BOXZEROBYTES;
218

219
    uint8_t ready_nonce[crypto_box_NONCEBYTES];
220 221
    uint8_t *ready_plaintext =
      static_cast<uint8_t *> (malloc (crypto_box_ZEROBYTES + clen));
222 223
    alloc_assert (ready_plaintext);
    uint8_t *ready_box =
224
      static_cast<uint8_t *> (malloc (crypto_box_BOXZEROBYTES + 16 + clen));
225
    alloc_assert (ready_box);
226 227

    memset (ready_box, 0, crypto_box_BOXZEROBYTES);
228
    memcpy (ready_box + crypto_box_BOXZEROBYTES, msg_data_ + 14,
229
            clen - crypto_box_BOXZEROBYTES);
230 231

    memcpy (ready_nonce, "CurveZMQREADY---", 16);
232 233
    memcpy (ready_nonce + 16, msg_data_ + 6, 8);
    cn_peer_nonce = get_uint64 (msg_data_ + 6);
234

235 236
    int rc = crypto_box_open_afternm (ready_plaintext, ready_box, clen,
                                      ready_nonce, cn_precom);
237
    free (ready_box);
238 239

    if (rc != 0) {
240
        session->get_socket ()->event_handshake_failed_protocol (
241
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
242 243 244 245
        errno = EPROTO;
        return -1;
    }

246 247
    rc = parse_metadata (ready_plaintext + crypto_box_ZEROBYTES,
                         clen - crypto_box_ZEROBYTES);
248 249
    free (ready_plaintext);

250
    if (rc == 0)
251
        _state = connected;
252
    else {
253 254 255 256
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA);
        errno = EPROTO;
    }
257

258 259 260
    return rc;
}

261 262
int zmq::curve_client_t::process_error (const uint8_t *msg_data_,
                                        size_t msg_size_)
263
{
264
    if (_state != expect_welcome && _state != expect_ready) {
265 266
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
267 268 269
        errno = EPROTO;
        return -1;
    }
270
    if (msg_size_ < 7) {
271
        session->get_socket ()->event_handshake_failed_protocol (
272 273
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
274 275 276
        errno = EPROTO;
        return -1;
    }
277 278
    const size_t error_reason_len = static_cast<size_t> (msg_data_[6]);
    if (error_reason_len > msg_size_ - 7) {
279
        session->get_socket ()->event_handshake_failed_protocol (
280 281
          session->get_endpoint (),
          ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
282 283 284
        errno = EPROTO;
        return -1;
    }
285
    const char *error_reason = reinterpret_cast<const char *> (msg_data_) + 7;
286
    handle_error_reason (error_reason, error_reason_len);
287
    _state = error_received;
288 289 290
    return 0;
}

291
#endif