Commit c8d8575a authored by Kenton Varda's avatar Kenton Varda

Implement WebSocket server-side handshake.

parent 8a099294
......@@ -1460,6 +1460,180 @@ KJ_TEST("WebSocket ping received during pong send") {
clientTask.wait(io.waitScope);
}
class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler {
public:
TestWebSocketService(HttpHeaderTable& headerTable, HttpHeaderId hMyHeader)
: headerTable(headerTable), hMyHeader(hMyHeader), tasks(*this) {}
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here");
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h));
}
if (url == "/return-error") {
response.send(404, "Not Found", responseHeaders, uint64_t(0));
return kj::READY_NOW;
} else if (url == "/ws-inline") {
auto ws = response.acceptWebSocket(responseHeaders);
return doWebSocket(*ws, "start-inline").attach(kj::mv(ws));
} else if (url == "/ws-detached") {
auto ws = response.acceptWebSocket(responseHeaders);
tasks.add(doWebSocket(*ws, "start-detached").attach(kj::mv(ws)));
return kj::READY_NOW;
} else {
KJ_FAIL_ASSERT("unexpected path", url);
}
}
private:
HttpHeaderTable& headerTable;
HttpHeaderId hMyHeader;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
static kj::Promise<void> doWebSocket(WebSocket& ws, kj::StringPtr message) {
auto copy = kj::str(message);
return ws.send(copy).attach(kj::mv(copy))
.then([&ws]() {
return ws.receive();
}).then([&ws](WebSocket::Message&& message) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(str, kj::String) {
return doWebSocket(ws, kj::str("reply:", str));
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return doWebSocket(ws, kj::str("reply:", data));
}
KJ_CASE_ONEOF(close, WebSocket::Close) {
auto reason = kj::str("close-reply:", close.reason);
return ws.close(close.code + 1, reason).attach(kj::mv(reason));
}
}
KJ_UNREACHABLE;
});
}
};
const char WEBSOCKET_REQUEST_HANDSHAKE[] =
" HTTP/1.1\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"My-Header: foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE[] =
"HTTP/1.1 101 Switching Protocols\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] =
"HTTP/1.1 404 Not Found\r\n"
"Content-Length: 0\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] = "\x81\x0c" "start-inline";
const byte WEBSOCKET_FIRST_MESSAGE_DETACHED[] = "\x81\x0e" "start-detached";
const byte WEBSOCKET_SEND_MESSAGE[] = "\x81\x03" "bar";
const byte WEBSOCKET_REPLY_MESSAGE[] = "\x81\x09" "reply:bar";
const byte WEBSOCKET_SEND_CLOSE[] = "\x88\x05\x12\x34" "qux";
const byte WEBSOCKET_REPLY_CLOSE[] = "\x88\x11\x12\x35" "close-reply:qux";
template <size_t s>
kj::ArrayPtr<const byte> nulterm(const byte (&bytes)[s]) {
// Ugh, the byte arrays defined above all end up having NUL terminators because I specified them
// as string literals. This function will remove the NUL.
return kj::ArrayPtr<const byte>(bytes, s - 1);
}
KJ_TEST("HttpServer WebSocket handshake") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-inline", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_INLINE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
listenTask.wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake detached") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-detached", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
listenTask.wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_DETACHED)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake error") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
// Can send more requests!
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
pipe.ends[1]->shutdownWrite();
listenTask.wait(io.waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("HttpServer request timeout") {
......
This diff is collapsed.
......@@ -86,7 +86,10 @@ namespace kj {
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade")
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept")
enum class HttpMethod {
// Enum of known HTTP methods.
......@@ -523,12 +526,10 @@ public:
class WebSocketResponse: public Response {
public:
kj::Own<WebSocket> acceptWebSocket(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers);
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `statusText` and `headers` need only remain valid until acceptWebSocket() returns (they can
// be stack-allocated).
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment