Commit 41b9af2c authored by somdoron's avatar somdoron

problem: WS transport doesn't support mechanism

Solution: add support to mechanism
parent 7296fb5b
...@@ -60,6 +60,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. ...@@ -60,6 +60,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
#include "random.hpp" #include "random.hpp"
#include "ws_decoder.hpp" #include "ws_decoder.hpp"
#include "ws_encoder.hpp" #include "ws_encoder.hpp"
#include "null_mechanism.hpp"
#include "plain_server.hpp"
#include "plain_client.hpp"
#ifdef ZMQ_HAVE_CURVE
#include "curve_client.hpp"
#include "curve_server.hpp"
#endif
#ifdef ZMQ_HAVE_WINDOWS #ifdef ZMQ_HAVE_WINDOWS
#define strcasecmp _stricmp #define strcasecmp _stricmp
...@@ -99,10 +107,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_, ...@@ -99,10 +107,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_,
memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1); memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1); memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1);
_next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> ( _next_msg = &ws_engine_t::next_handshake_command;
&ws_engine_t::routing_id_msg); _process_msg = &ws_engine_t::process_handshake_command;
_process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::process_routing_id_msg);
} }
zmq::ws_engine_t::~ws_engine_t () zmq::ws_engine_t::~ws_engine_t ()
...@@ -112,6 +118,18 @@ zmq::ws_engine_t::~ws_engine_t () ...@@ -112,6 +118,18 @@ zmq::ws_engine_t::~ws_engine_t ()
void zmq::ws_engine_t::start_ws_handshake () void zmq::ws_engine_t::start_ws_handshake ()
{ {
if (_client) { if (_client) {
char protocol[21];
if (_options.mechanism == ZMQ_NULL)
strcpy (protocol, "ZWS2.0/NULL,ZWS2.0");
else if (_options.mechanism == ZMQ_PLAIN)
strcpy (protocol, "ZWS2.0/PLAIN");
#ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE)
strcpy (protocol, "ZWS2.0/CURVE");
#endif
else
assert (false);
unsigned char nonce[16]; unsigned char nonce[16];
int *p = (int *) nonce; int *p = (int *) nonce;
...@@ -131,9 +149,10 @@ void zmq::ws_engine_t::start_ws_handshake () ...@@ -131,9 +149,10 @@ void zmq::ws_engine_t::start_ws_handshake ()
"Upgrade: websocket\r\n" "Upgrade: websocket\r\n"
"Connection: Upgrade\r\n" "Connection: Upgrade\r\n"
"Sec-WebSocket-Key: %s\r\n" "Sec-WebSocket-Key: %s\r\n"
"Sec-WebSocket-Protocol: ZWS2.0\r\n" "Sec-WebSocket-Protocol: %s\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n", "Sec-WebSocket-Version: 13\r\n\r\n",
_address.path (), _address.host (), _websocket_key); _address.path (), _address.host (), _websocket_key,
protocol);
assert (size > 0 && size < WS_BUFFER_SIZE); assert (size > 0 && size < WS_BUFFER_SIZE);
_outpos = _write_buffer; _outpos = _write_buffer;
_outsize = size; _outsize = size;
...@@ -177,6 +196,48 @@ int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_) ...@@ -177,6 +196,48 @@ int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
return 0; return 0;
} }
bool zmq::ws_engine_t::select_protocol (char *protocol)
{
if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol) == 0)) {
_next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::routing_id_msg);
_process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::process_routing_id_msg);
return true;
} else if (_options.mechanism == ZMQ_NULL
&& strcmp ("ZWS2.0/NULL", protocol) == 0) {
_mechanism = new (std::nothrow)
null_mechanism_t (session (), _peer_address, _options);
alloc_assert (_mechanism);
return true;
} else if (_options.mechanism == ZMQ_PLAIN
&& strcmp ("ZWS2.0/PLAIN", protocol) == 0) {
if (_options.as_server)
_mechanism = new (std::nothrow)
plain_server_t (session (), _peer_address, _options);
else
_mechanism =
new (std::nothrow) plain_client_t (session (), _options);
alloc_assert (_mechanism);
return true;
}
#ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE
&& strcmp ("ZWS2.0/CURVE", protocol) == 0) {
if (_options.as_server)
_mechanism = new (std::nothrow)
curve_server_t (session (), _peer_address, _options);
else
_mechanism =
new (std::nothrow) curve_client_t (session (), _options);
alloc_assert (_mechanism);
return true;
}
#endif
return false;
}
bool zmq::ws_engine_t::handshake () bool zmq::ws_engine_t::handshake ()
{ {
bool complete; bool complete;
...@@ -390,7 +451,7 @@ bool zmq::ws_engine_t::server_handshake () ...@@ -390,7 +451,7 @@ bool zmq::ws_engine_t::server_handshake ()
if (*p == ' ') if (*p == ' ')
p++; p++;
if (strcmp ("ZWS2.0", p) == 0) { if (select_protocol (p)) {
strcpy (_websocket_protocol, p); strcpy (_websocket_protocol, p);
break; break;
} }
...@@ -760,7 +821,7 @@ bool zmq::ws_engine_t::client_handshake () ...@@ -760,7 +821,7 @@ bool zmq::ws_engine_t::client_handshake ()
strcpy (_websocket_accept, _header_value); strcpy (_websocket_accept, _header_value);
else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name) else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
== 0) { == 0) {
if (strcmp ("ZWS2.0", _header_value) == 0) if (select_protocol (_header_value))
strcpy (_websocket_protocol, _header_value); strcpy (_websocket_protocol, _header_value);
} }
_client_handshake_state = client_header_field_cr; _client_handshake_state = client_header_field_cr;
......
...@@ -143,6 +143,8 @@ class ws_engine_t : public stream_engine_base_t ...@@ -143,6 +143,8 @@ class ws_engine_t : public stream_engine_base_t
int routing_id_msg (msg_t *msg_); int routing_id_msg (msg_t *msg_);
int process_routing_id_msg (msg_t *msg_); int process_routing_id_msg (msg_t *msg_);
bool select_protocol (char *protocol);
bool client_handshake (); bool client_handshake ();
bool server_handshake (); bool server_handshake ();
......
...@@ -106,6 +106,43 @@ void test_large_message () ...@@ -106,6 +106,43 @@ void test_large_message ()
test_context_socket_close (sb); test_context_socket_close (sb);
} }
void test_curve ()
{
char client_public[41];
char client_secret[41];
char server_public[41];
char server_secret[41];
TEST_ASSERT_SUCCESS_ERRNO (
zmq_curve_keypair (server_public, server_secret));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_curve_keypair (client_public, client_secret));
void *server = test_context_socket (ZMQ_REP);
int as_server = 1;
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (server, ZMQ_CURVE_SERVER, &as_server, sizeof (int)));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (server, ZMQ_CURVE_SECRETKEY, server_secret, 41));
TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (server, "ws://*:5556/roundtrip"));
void *client = test_context_socket (ZMQ_REQ);
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_SERVERKEY, server_public, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_PUBLICKEY, client_public, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_SECRETKEY, client_secret, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_connect (client, "ws://127.0.0.1:5556/roundtrip"));
bounce (server, client);
test_context_socket_close (client);
test_context_socket_close (server);
}
int main () int main ()
{ {
setup_test_environment (); setup_test_environment ();
...@@ -114,5 +151,9 @@ int main () ...@@ -114,5 +151,9 @@ int main ()
RUN_TEST (test_roundtrip); RUN_TEST (test_roundtrip);
RUN_TEST (test_short_message); RUN_TEST (test_short_message);
RUN_TEST (test_large_message); RUN_TEST (test_large_message);
if (zmq_has ("curve"))
RUN_TEST (test_curve);
return UNITY_END (); return UNITY_END ();
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment