Commit 7bfd9512 authored by Doron Somech's avatar Doron Somech

problem: ws_encoder allocate a new msg for masking

solution: if msg is not shared or constant, mask the message in place
parent 40de4539
......@@ -55,19 +55,24 @@ void zmq::ws_encoder_t::message_ready ()
{
int offset = 0;
_is_binary = false;
if (in_progress ()->is_ping ())
_tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_ping;
else if (in_progress ()->is_pong ())
_tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_pong;
else if (in_progress ()->is_close_cmd ())
_tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_close;
else
else {
_tmp_buf[offset++] = 0x82; // Final | binary
_is_binary = true;
}
_tmp_buf[offset] = _must_mask ? 0x80 : 0x00;
size_t size = in_progress ()->size ();
size++; // TODO: check if binary
if (_is_binary)
size++;
if (size <= 125)
_tmp_buf[offset++] |= static_cast<unsigned char> (size & 127);
......@@ -88,17 +93,17 @@ void zmq::ws_encoder_t::message_ready ()
offset += 4;
}
// TODO: check if binary
// Encode flags.
unsigned char protocol_flags = 0;
if (in_progress ()->flags () & msg_t::more)
protocol_flags |= ws_protocol_t::more_flag;
if (in_progress ()->flags () & msg_t::command)
protocol_flags |= ws_protocol_t::command_flag;
if (_is_binary) {
// Encode flags.
unsigned char protocol_flags = 0;
if (in_progress ()->flags () & msg_t::more)
protocol_flags |= ws_protocol_t::more_flag;
if (in_progress ()->flags () & msg_t::command)
protocol_flags |= ws_protocol_t::command_flag;
_tmp_buf[offset++] =
_must_mask ? protocol_flags ^ _mask[0] : protocol_flags;
_tmp_buf[offset++] =
_must_mask ? protocol_flags ^ _mask[0] : protocol_flags;
}
next_step (_tmp_buf, offset, &ws_encoder_t::size_ready, false);
}
......@@ -109,20 +114,23 @@ void zmq::ws_encoder_t::size_ready ()
assert (in_progress () != &_masked_msg);
const size_t size = in_progress ()->size ();
_masked_msg.close ();
_masked_msg.init_size (size);
int mask_index = 1; // TODO: check if binary message
unsigned char *dest =
static_cast<unsigned char *> (_masked_msg.data ());
unsigned char *src =
static_cast<unsigned char *> (in_progress ()->data ());
for (size_t i = 0, size = in_progress ()->size (); i < size;
++i, mask_index++)
unsigned char *dest = src;
// If msg is shared or data is constant we cannot mask in-place, allocate a new msg for it
if (in_progress ()->flags () & msg_t::shared
|| in_progress ()->is_cmsg ()) {
_masked_msg.close ();
_masked_msg.init_size (size);
dest = static_cast<unsigned char *> (_masked_msg.data ());
}
int mask_index = _is_binary ? 1 : 0;
for (size_t i = 0; i < size; ++i, mask_index++)
dest[i] = src[i] ^ _mask[mask_index % 4];
next_step (_masked_msg.data (), _masked_msg.size (),
&ws_encoder_t::message_ready, true);
next_step (dest, size, &ws_encoder_t::message_ready, true);
} else {
next_step (in_progress ()->data (), in_progress ()->size (),
&ws_encoder_t::message_ready, true);
......
......@@ -50,6 +50,7 @@ class ws_encoder_t ZMQ_FINAL : public encoder_base_t<ws_encoder_t>
bool _must_mask;
unsigned char _mask[4];
msg_t _masked_msg;
bool _is_binary;
ZMQ_NON_COPYABLE_NOR_MOVABLE (ws_encoder_t)
};
......
......@@ -215,6 +215,61 @@ void test_curve ()
test_context_socket_close (server);
}
void test_mask_shared_msg ()
{
char connect_address[MAX_SOCKET_STRING + strlen ("/mask-shared")];
size_t addr_length = sizeof (connect_address);
void *sb = test_context_socket (ZMQ_DEALER);
TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/mask-shared"));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length));
strcat (connect_address, "/mask-shared");
void *sc = test_context_socket (ZMQ_DEALER);
TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address));
zmq_msg_t msg;
zmq_msg_init_size (
&msg, 255); // Message have to be long enough so it won't fit inside msg
unsigned char *data = (unsigned char *) zmq_msg_data (&msg);
for (int i = 0; i < 255; i++)
data[i] = i;
// Taking a copy to make the msg shared
zmq_msg_t copy;
zmq_msg_init (&copy);
zmq_msg_copy (&copy, &msg);
// Sending the shared msg
int rc = zmq_msg_send (&msg, sc, 0);
TEST_ASSERT_EQUAL_INT (255, rc);
// Recv the msg and check that it was masked correctly
rc = zmq_msg_recv (&msg, sb, 0);
TEST_ASSERT_EQUAL_INT (255, rc);
data = (unsigned char *) zmq_msg_data (&msg);
for (int i = 0; i < 255; i++)
TEST_ASSERT_EQUAL_INT (i, data[i]);
// Testing that copy was not masked
data = (unsigned char *) zmq_msg_data (&copy);
for (int i = 0; i < 255; i++)
TEST_ASSERT_EQUAL_INT (i, data[i]);
// Constant msg cannot be masked as well, as it is constant
rc = zmq_send_const (sc, "HELLO", 5, 0);
TEST_ASSERT_EQUAL_INT (5, rc);
recv_string_expect_success (sb, "HELLO", 0);
zmq_msg_close (&copy);
zmq_msg_close (&msg);
test_context_socket_close (sc);
test_context_socket_close (sb);
}
int main ()
{
setup_test_environment ();
......@@ -225,6 +280,7 @@ int main ()
RUN_TEST (test_short_message);
RUN_TEST (test_large_message);
RUN_TEST (test_heartbeat);
RUN_TEST (test_mask_shared_msg);
if (zmq_has ("curve"))
RUN_TEST (test_curve);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment