Unverified Commit 06bdebfe authored by Luca Boccassi's avatar Luca Boccassi Committed by GitHub

Merge pull request #3805 from sigiesec/curve-zerocopy

CURVE: Reduce number of memory allocations and message copies
parents f1513f96 4177bf74
......@@ -174,6 +174,7 @@ endif()
# To disable curve, use --disable-curve
option(WITH_LIBSODIUM "Use libsodium instead of built-in tweetnacl" OFF)
option(WITH_LIBSODIUM_STATIC "Use static libsodium library" OFF)
option(ENABLE_CURVE "Enable CURVE security" ON)
if(NOT ENABLE_CURVE)
......@@ -183,6 +184,9 @@ elseif(WITH_LIBSODIUM)
if(SODIUM_FOUND)
message(STATUS "Using libsodium for CURVE security")
include_directories(${SODIUM_INCLUDE_DIRS})
if(WITH_LIBSODIUM_STATIC)
add_compile_definitions(SODIUM_STATIC)
endif()
set(ZMQ_USE_LIBSODIUM 1)
set(ZMQ_HAVE_CURVE 1)
else()
......
......@@ -1085,7 +1085,8 @@ test_apps += \
unittests/unittest_mtrie \
unittests/unittest_ip_resolver \
unittests/unittest_udp_address \
unittests/unittest_radix_tree
unittests/unittest_radix_tree \
unittests/unittest_curve_encoding
unittests_unittest_poller_SOURCES = unittests/unittest_poller.cpp
unittests_unittest_poller_CPPFLAGS = -I$(top_srcdir)/src ${TESTUTIL_CPPFLAGS} $(CODE_COVERAGE_CPPFLAGS)
......@@ -1140,6 +1141,15 @@ unittests_unittest_radix_tree_LDADD = \
$(top_builddir)/src/.libs/libzmq.a \
${src_libzmq_la_LIBADD} \
$(CODE_COVERAGE_LDFLAGS)
unittests_unittest_curve_encoding_SOURCES = unittests/unittest_curve_encoding.cpp
unittests_unittest_curve_encoding_CPPFLAGS = -I$(top_srcdir)/src ${TESTUTIL_CPPFLAGS} $(CODE_COVERAGE_CPPFLAGS)
unittests_unittest_curve_encoding_CXXFLAGS = $(CODE_COVERAGE_CXXFLAGS)
unittests_unittest_curve_encoding_LDADD = \
${TESTUTIL_LIBS} \
$(top_builddir)/src/.libs/libzmq.a \
${src_libzmq_la_LIBADD} \
$(CODE_COVERAGE_LDFLAGS)
endif
check_PROGRAMS = ${test_apps}
......
......@@ -137,7 +137,7 @@ int zmq::curve_client_t::produce_hello (msg_t *msg_)
int rc = msg_->init_size (200);
errno_assert (rc == 0);
rc = _tools.produce_hello (msg_->data (), cn_nonce);
rc = _tools.produce_hello (msg_->data (), get_and_inc_nonce ());
if (rc == -1) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
......@@ -150,15 +150,14 @@ int zmq::curve_client_t::produce_hello (msg_t *msg_)
return -1;
}
cn_nonce++;
return 0;
}
int zmq::curve_client_t::process_welcome (const uint8_t *msg_data_,
size_t msg_size_)
{
const int rc = _tools.process_welcome (msg_data_, msg_size_, cn_precom);
const int rc = _tools.process_welcome (msg_data_, msg_size_,
get_writable_precom_buffer ());
if (rc == -1) {
session->get_socket ()->event_handshake_failed_protocol (
......@@ -186,7 +185,7 @@ int zmq::curve_client_t::produce_initiate (msg_t *msg_)
int rc = msg_->init_size (msg_size);
errno_assert (rc == 0);
rc = _tools.produce_initiate (msg_->data (), msg_size, cn_nonce,
rc = _tools.produce_initiate (msg_->data (), msg_size, get_and_inc_nonce (),
&metadata_plaintext[0], metadata_length);
if (-1 == rc) {
......@@ -197,8 +196,6 @@ int zmq::curve_client_t::produce_initiate (msg_t *msg_)
return -1;
}
cn_nonce++;
return 0;
}
......@@ -227,10 +224,10 @@ int zmq::curve_client_t::process_ready (const uint8_t *msg_data_,
memcpy (ready_nonce, "CurveZMQREADY---", 16);
memcpy (ready_nonce + 16, msg_data_ + 6, 8);
cn_peer_nonce = get_uint64 (msg_data_ + 6);
set_peer_nonce (get_uint64 (msg_data_ + 6));
int rc = crypto_box_open_afternm (&ready_plaintext[0], &ready_box[0], clen,
ready_nonce, cn_precom);
ready_nonce, get_precom_buffer ());
if (rc != 0) {
session->get_socket ()->event_handshake_failed_protocol (
......
This diff is collapsed.
......@@ -52,7 +52,40 @@
namespace zmq
{
class curve_mechanism_base_t : public virtual mechanism_base_t
class curve_encoding_t
{
public:
curve_encoding_t (const char *encode_nonce_prefix_,
const char *decode_nonce_prefix_);
int encode (msg_t *msg_);
int decode (msg_t *msg_, int *error_event_code_);
uint8_t *get_writable_precom_buffer () { return _cn_precom; }
const uint8_t *get_precom_buffer () const { return _cn_precom; }
typedef uint64_t nonce_t;
nonce_t get_and_inc_nonce () { return _cn_nonce++; }
void set_peer_nonce (nonce_t peer_nonce_) { _cn_peer_nonce = peer_nonce_; };
private:
int check_validity (msg_t *msg_, int *error_event_code_);
const char *_encode_nonce_prefix;
const char *_decode_nonce_prefix;
nonce_t _cn_nonce;
nonce_t _cn_peer_nonce;
// Intermediary buffer used to speed up boxing and unboxing.
uint8_t _cn_precom[crypto_box_BEFORENMBYTES];
ZMQ_NON_COPYABLE_NOR_MOVABLE (curve_encoding_t)
};
class curve_mechanism_base_t : public virtual mechanism_base_t,
public curve_encoding_t
{
public:
curve_mechanism_base_t (session_base_t *session_,
......@@ -63,16 +96,6 @@ class curve_mechanism_base_t : public virtual mechanism_base_t
// mechanism implementation
int encode (msg_t *msg_) ZMQ_OVERRIDE;
int decode (msg_t *msg_) ZMQ_OVERRIDE;
protected:
const char *encode_nonce_prefix;
const char *decode_nonce_prefix;
uint64_t cn_nonce;
uint64_t cn_peer_nonce;
// Intermediary buffer used to speed up boxing and unboxing.
uint8_t cn_precom[crypto_box_BEFORENMBYTES];
};
}
......
......@@ -181,7 +181,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_)
memcpy (hello_nonce, "CurveZMQHELLO---", 16);
memcpy (hello_nonce + 16, hello + 112, 8);
cn_peer_nonce = get_uint64 (hello + 112);
set_peer_nonce (get_uint64 (hello + 112));
memset (hello_box, 0, crypto_box_BOXZEROBYTES);
memcpy (hello_box + crypto_box_BOXZEROBYTES, hello + 120, 80);
......@@ -345,7 +345,7 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_)
memcpy (initiate_nonce, "CurveZMQINITIATE", 16);
memcpy (initiate_nonce + 16, initiate + 105, 8);
cn_peer_nonce = get_uint64 (initiate + 105);
set_peer_nonce (get_uint64 (initiate + 105));
const uint8_t *client_key = &initiate_plaintext[crypto_box_ZEROBYTES];
......@@ -396,7 +396,8 @@ int zmq::curve_server_t::process_initiate (msg_t *msg_)
}
// Precompute connection secret from client key
rc = crypto_box_beforenm (cn_precom, _cn_client, _cn_secret);
rc = crypto_box_beforenm (get_writable_precom_buffer (), _cn_client,
_cn_secret);
zmq_assert (rc == 0);
// Given this is a backward-incompatible change, it's behind a socket
......@@ -449,13 +450,13 @@ int zmq::curve_server_t::produce_ready (msg_t *msg_)
const size_t mlen = ptr - &ready_plaintext[0];
memcpy (ready_nonce, "CurveZMQREADY---", 16);
put_uint64 (ready_nonce + 16, cn_nonce);
put_uint64 (ready_nonce + 16, get_and_inc_nonce ());
std::vector<uint8_t> ready_box (crypto_box_BOXZEROBYTES + 16
+ metadata_length);
int rc = crypto_box_afternm (&ready_box[0], &ready_plaintext[0], mlen,
ready_nonce, cn_precom);
ready_nonce, get_precom_buffer ());
zmq_assert (rc == 0);
rc = msg_->init_size (14 + mlen - crypto_box_BOXZEROBYTES);
......@@ -470,8 +471,6 @@ int zmq::curve_server_t::produce_ready (msg_t *msg_)
memcpy (ready + 14, &ready_box[crypto_box_BOXZEROBYTES],
mlen - crypto_box_BOXZEROBYTES);
cn_nonce++;
return 0;
}
......
......@@ -361,6 +361,30 @@ size_t zmq::msg_t::size () const
}
}
void zmq::msg_t::shrink (size_t new_size_)
{
// Check the validity of the message.
zmq_assert (check ());
zmq_assert (new_size_ <= size ());
switch (_u.base.type) {
case type_vsm:
_u.vsm.size = static_cast<unsigned char> (new_size_);
break;
case type_lmsg:
_u.lmsg.content->size = new_size_;
break;
case type_zclmsg:
_u.zclmsg.content->size = new_size_;
break;
case type_cmsg:
_u.cmsg.size = new_size_;
break;
default:
zmq_assert (false);
}
}
unsigned char zmq::msg_t::flags () const
{
return _u.base.flags;
......
......@@ -161,6 +161,8 @@ class msg_t
// references drops to 0, the message is closed and false is returned.
bool rm_refs (int refs_);
void shrink (size_t new_size_);
// Size in bytes of the largest message that is still copied around
// rather than being reference-counted.
enum
......
......@@ -61,6 +61,10 @@
// duplicated from fd.hpp
#ifdef ZMQ_HAVE_WINDOWS
#ifndef NOMINMAX
#define NOMINMAX // Macros min(a,b) and max(a,b)
#endif
#include <winsock2.h>
#include <ws2tcpip.h>
#include <stdexcept>
......
......@@ -8,6 +8,7 @@ set(unittests
unittest_ip_resolver
unittest_udp_address
unittest_radix_tree
unittest_curve_encoding
)
#if(ENABLE_DRAFTS)
......
/*
Copyright (c) 2018 Contributors as noted in the AUTHORS file
This file is part of 0MQ.
0MQ is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.
0MQ is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "../tests/testutil_unity.hpp"
// TODO: remove this ugly hack
#ifdef close
#undef close
#endif
#include <curve_mechanism_base.hpp>
#include <msg.hpp>
#include <random.hpp>
#include <unity.h>
#include <vector>
void setUp ()
{
}
void tearDown ()
{
}
void test_roundtrip (zmq::msg_t *msg_)
{
#ifdef ZMQ_HAVE_CURVE
const std::vector<uint8_t> original (static_cast<uint8_t *> (msg_->data ()),
static_cast<uint8_t *> (msg_->data ())
+ msg_->size ());
zmq::curve_encoding_t encoding_client ("CurveZMQMESSAGEC",
"CurveZMQMESSAGES");
zmq::curve_encoding_t encoding_server ("CurveZMQMESSAGES",
"CurveZMQMESSAGEC");
uint8_t client_public[32];
uint8_t client_secret[32];
TEST_ASSERT_SUCCESS_ERRNO (
crypto_box_keypair (client_public, client_secret));
uint8_t server_public[32];
uint8_t server_secret[32];
TEST_ASSERT_SUCCESS_ERRNO (
crypto_box_keypair (server_public, server_secret));
TEST_ASSERT_SUCCESS_ERRNO (
crypto_box_beforenm (encoding_client.get_writable_precom_buffer (),
server_public, client_secret));
TEST_ASSERT_SUCCESS_ERRNO (
crypto_box_beforenm (encoding_server.get_writable_precom_buffer (),
client_public, server_secret));
TEST_ASSERT_SUCCESS_ERRNO (encoding_client.encode (msg_));
// TODO: This is hacky...
encoding_server.set_peer_nonce (0);
int error_event_code;
TEST_ASSERT_SUCCESS_ERRNO (
encoding_server.decode (msg_, &error_event_code));
TEST_ASSERT_EQUAL_INT (original.size (), msg_->size ());
if (!original.empty ()) {
TEST_ASSERT_EQUAL_UINT8_ARRAY (&original[0], msg_->data (),
original.size ());
}
#else
TEST_IGNORE_MESSAGE ("CURVE support is disabled");
#endif
}
void test_roundtrip_empty ()
{
zmq::msg_t msg;
msg.init ();
test_roundtrip (&msg);
msg.close ();
}
void test_roundtrip_small ()
{
zmq::msg_t msg;
msg.init_size (32);
memcpy (msg.data (), "0123456789ABCDEF0123456789ABCDEF", 32);
test_roundtrip (&msg);
msg.close ();
}
void test_roundtrip_large ()
{
zmq::msg_t msg;
msg.init_size (2048);
for (size_t pos = 0; pos < 2048; pos += 32) {
memcpy (static_cast<char *> (msg.data ()) + pos,
"0123456789ABCDEF0123456789ABCDEF", 32);
}
test_roundtrip (&msg);
msg.close ();
}
void test_roundtrip_empty_more ()
{
zmq::msg_t msg;
msg.init ();
msg.set_flags (zmq::msg_t::more);
test_roundtrip (&msg);
TEST_ASSERT_TRUE (msg.flags () & zmq::msg_t::more);
msg.close ();
}
int main ()
{
setup_test_environment ();
zmq::random_open ();
UNITY_BEGIN ();
RUN_TEST (test_roundtrip_empty);
RUN_TEST (test_roundtrip_small);
RUN_TEST (test_roundtrip_large);
RUN_TEST (test_roundtrip_empty_more);
zmq::random_close ();
return UNITY_END ();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment