Commit c8d8575a authored by Kenton Varda's avatar Kenton Varda

Implement WebSocket server-side handshake.

parent 8a099294
......@@ -1460,6 +1460,180 @@ KJ_TEST("WebSocket ping received during pong send") {
clientTask.wait(io.waitScope);
}
class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler {
public:
TestWebSocketService(HttpHeaderTable& headerTable, HttpHeaderId hMyHeader)
: headerTable(headerTable), hMyHeader(hMyHeader), tasks(*this) {}
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here");
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h));
}
if (url == "/return-error") {
response.send(404, "Not Found", responseHeaders, uint64_t(0));
return kj::READY_NOW;
} else if (url == "/ws-inline") {
auto ws = response.acceptWebSocket(responseHeaders);
return doWebSocket(*ws, "start-inline").attach(kj::mv(ws));
} else if (url == "/ws-detached") {
auto ws = response.acceptWebSocket(responseHeaders);
tasks.add(doWebSocket(*ws, "start-detached").attach(kj::mv(ws)));
return kj::READY_NOW;
} else {
KJ_FAIL_ASSERT("unexpected path", url);
}
}
private:
HttpHeaderTable& headerTable;
HttpHeaderId hMyHeader;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
static kj::Promise<void> doWebSocket(WebSocket& ws, kj::StringPtr message) {
auto copy = kj::str(message);
return ws.send(copy).attach(kj::mv(copy))
.then([&ws]() {
return ws.receive();
}).then([&ws](WebSocket::Message&& message) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(str, kj::String) {
return doWebSocket(ws, kj::str("reply:", str));
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return doWebSocket(ws, kj::str("reply:", data));
}
KJ_CASE_ONEOF(close, WebSocket::Close) {
auto reason = kj::str("close-reply:", close.reason);
return ws.close(close.code + 1, reason).attach(kj::mv(reason));
}
}
KJ_UNREACHABLE;
});
}
};
const char WEBSOCKET_REQUEST_HANDSHAKE[] =
" HTTP/1.1\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"My-Header: foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE[] =
"HTTP/1.1 101 Switching Protocols\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] =
"HTTP/1.1 404 Not Found\r\n"
"Content-Length: 0\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] = "\x81\x0c" "start-inline";
const byte WEBSOCKET_FIRST_MESSAGE_DETACHED[] = "\x81\x0e" "start-detached";
const byte WEBSOCKET_SEND_MESSAGE[] = "\x81\x03" "bar";
const byte WEBSOCKET_REPLY_MESSAGE[] = "\x81\x09" "reply:bar";
const byte WEBSOCKET_SEND_CLOSE[] = "\x88\x05\x12\x34" "qux";
const byte WEBSOCKET_REPLY_CLOSE[] = "\x88\x11\x12\x35" "close-reply:qux";
template <size_t s>
kj::ArrayPtr<const byte> nulterm(const byte (&bytes)[s]) {
// Ugh, the byte arrays defined above all end up having NUL terminators because I specified them
// as string literals. This function will remove the NUL.
return kj::ArrayPtr<const byte>(bytes, s - 1);
}
KJ_TEST("HttpServer WebSocket handshake") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-inline", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_INLINE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
listenTask.wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake detached") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-detached", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
listenTask.wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_DETACHED)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake error") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
// Can send more requests!
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
pipe.ends[1]->shutdownWrite();
listenTask.wait(io.waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("HttpServer request timeout") {
......
......@@ -24,9 +24,303 @@
#include <kj/parse/char.h>
#include <unordered_map>
#include <stdlib.h>
#include <kj/encoding.h>
namespace kj {
// =======================================================================================
// SHA-1 implementation from https://github.com/clibs/sha1
//
// The WebSocket standard depends on SHA-1. ARRRGGGHHHHH.
//
// Any old checksum would have served the purpose, or hell, even just returning the header
// verbatim. But NO, they decided to throw a whole complicated hash algorithm in there, AND
// THEY CHOSE A BROKEN ONE THAT WE OTHERWISE WOULDN'T NEED ANYMORE.
//
// TODO(cleanup): Move this to a shared hashing library. Maybe. Or maybe don't, becaues no one
// should be using SHA-1 anymore.
//
// THIS USAGE IS NOT SECURITY SENSITIVE. IF YOU REPORT A SECURITY ISSUE BECAUSE YOU SAW SHA1 IN THE
// SOURCE CODE I WILL MAKE FUN OF YOU.
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
Test Vectors (from FIPS PUB 180-1)
"abc"
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
A million repetitions of "a"
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/
/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
/* #define SHA1HANDSOFF * Copies data before messing with it. */
#define SHA1HANDSOFF
typedef struct
{
uint32_t state[5];
uint32_t count[2];
unsigned char buffer[64];
} SHA1_CTX;
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
/* blk0() and blk() perform the initial expand. */
/* I got the idea of expanding during the round function from SSLeay */
#if BYTE_ORDER == LITTLE_ENDIAN
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|(rol(block->l[i],8)&0x00FF00FF))
#elif BYTE_ORDER == BIG_ENDIAN
#define blk0(i) block->l[i]
#else
#error "Endianness not defined!"
#endif
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
^block->l[(i+2)&15]^block->l[i&15],1))
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
/* Hash a single 512-bit block. This is the core of the algorithm. */
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
)
{
uint32_t a, b, c, d, e;
typedef union
{
unsigned char c[64];
uint32_t l[16];
} CHAR64LONG16;
#ifdef SHA1HANDSOFF
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
memcpy(block, buffer, 64);
#else
/* The following had better never be used because it causes the
* pointer-to-const buffer to be cast into a pointer to non-const.
* And the result is written through. I threw a "const" in, hoping
* this will cause a diagnostic.
*/
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
#endif
/* Copy context->state[] to working vars */
a = state[0];
b = state[1];
c = state[2];
d = state[3];
e = state[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
R0(a, b, c, d, e, 0);
R0(e, a, b, c, d, 1);
R0(d, e, a, b, c, 2);
R0(c, d, e, a, b, 3);
R0(b, c, d, e, a, 4);
R0(a, b, c, d, e, 5);
R0(e, a, b, c, d, 6);
R0(d, e, a, b, c, 7);
R0(c, d, e, a, b, 8);
R0(b, c, d, e, a, 9);
R0(a, b, c, d, e, 10);
R0(e, a, b, c, d, 11);
R0(d, e, a, b, c, 12);
R0(c, d, e, a, b, 13);
R0(b, c, d, e, a, 14);
R0(a, b, c, d, e, 15);
R1(e, a, b, c, d, 16);
R1(d, e, a, b, c, 17);
R1(c, d, e, a, b, 18);
R1(b, c, d, e, a, 19);
R2(a, b, c, d, e, 20);
R2(e, a, b, c, d, 21);
R2(d, e, a, b, c, 22);
R2(c, d, e, a, b, 23);
R2(b, c, d, e, a, 24);
R2(a, b, c, d, e, 25);
R2(e, a, b, c, d, 26);
R2(d, e, a, b, c, 27);
R2(c, d, e, a, b, 28);
R2(b, c, d, e, a, 29);
R2(a, b, c, d, e, 30);
R2(e, a, b, c, d, 31);
R2(d, e, a, b, c, 32);
R2(c, d, e, a, b, 33);
R2(b, c, d, e, a, 34);
R2(a, b, c, d, e, 35);
R2(e, a, b, c, d, 36);
R2(d, e, a, b, c, 37);
R2(c, d, e, a, b, 38);
R2(b, c, d, e, a, 39);
R3(a, b, c, d, e, 40);
R3(e, a, b, c, d, 41);
R3(d, e, a, b, c, 42);
R3(c, d, e, a, b, 43);
R3(b, c, d, e, a, 44);
R3(a, b, c, d, e, 45);
R3(e, a, b, c, d, 46);
R3(d, e, a, b, c, 47);
R3(c, d, e, a, b, 48);
R3(b, c, d, e, a, 49);
R3(a, b, c, d, e, 50);
R3(e, a, b, c, d, 51);
R3(d, e, a, b, c, 52);
R3(c, d, e, a, b, 53);
R3(b, c, d, e, a, 54);
R3(a, b, c, d, e, 55);
R3(e, a, b, c, d, 56);
R3(d, e, a, b, c, 57);
R3(c, d, e, a, b, 58);
R3(b, c, d, e, a, 59);
R4(a, b, c, d, e, 60);
R4(e, a, b, c, d, 61);
R4(d, e, a, b, c, 62);
R4(c, d, e, a, b, 63);
R4(b, c, d, e, a, 64);
R4(a, b, c, d, e, 65);
R4(e, a, b, c, d, 66);
R4(d, e, a, b, c, 67);
R4(c, d, e, a, b, 68);
R4(b, c, d, e, a, 69);
R4(a, b, c, d, e, 70);
R4(e, a, b, c, d, 71);
R4(d, e, a, b, c, 72);
R4(c, d, e, a, b, 73);
R4(b, c, d, e, a, 74);
R4(a, b, c, d, e, 75);
R4(e, a, b, c, d, 76);
R4(d, e, a, b, c, 77);
R4(c, d, e, a, b, 78);
R4(b, c, d, e, a, 79);
/* Add the working vars back into context.state[] */
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
state[4] += e;
/* Wipe variables */
a = b = c = d = e = 0;
#ifdef SHA1HANDSOFF
memset(block, '\0', sizeof(block));
#endif
}
/* SHA1Init - Initialize new context */
void SHA1Init(
SHA1_CTX * context
)
{
/* SHA1 initialization constants */
context->state[0] = 0x67452301;
context->state[1] = 0xEFCDAB89;
context->state[2] = 0x98BADCFE;
context->state[3] = 0x10325476;
context->state[4] = 0xC3D2E1F0;
context->count[0] = context->count[1] = 0;
}
/* Run your data through this. */
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
)
{
uint32_t i;
uint32_t j;
j = context->count[0];
if ((context->count[0] += len << 3) < j)
context->count[1]++;
context->count[1] += (len >> 29);
j = (j >> 3) & 63;
if ((j + len) > 63)
{
memcpy(&context->buffer[j], data, (i = 64 - j));
SHA1Transform(context->state, context->buffer);
for (; i + 63 < len; i += 64)
{
SHA1Transform(context->state, &data[i]);
}
j = 0;
}
else
i = 0;
memcpy(&context->buffer[j], &data[i], len - i);
}
/* Add padding and return the message digest. */
void SHA1Final(
unsigned char digest[20],
SHA1_CTX * context
)
{
unsigned i;
unsigned char finalcount[8];
unsigned char c;
#if 0 /* untested "improvement" by DHR */
/* Convert context->count to a sequence of bytes
* in finalcount. Second element first, but
* big-endian order within element.
* But we do it all backwards.
*/
unsigned char *fcp = &finalcount[8];
for (i = 0; i < 2; i++)
{
uint32_t t = context->count[i];
int j;
for (j = 0; j < 4; t >>= 8, j++)
*--fcp = (unsigned char) t}
#else
for (i = 0; i < 8; i++)
{
finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
}
#endif
c = 0200;
SHA1Update(context, &c, 1);
while ((context->count[0] & 504) != 448)
{
c = 0000;
SHA1Update(context, &c, 1);
}
SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
for (i = 0; i < 20; i++)
{
digest[i] = (unsigned char)
((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
}
/* Wipe variables */
memset(context, '\0', sizeof(*context));
memset(&finalcount, '\0', sizeof(finalcount));
}
// End SHA-1 implementation.
// =======================================================================================
static const char* METHOD_NAMES[] = {
#define METHOD_NAME(id) #id,
KJ_HTTP_FOR_EACH_METHOD(METHOD_NAME)
......@@ -129,6 +423,9 @@ kj::Maybe<HttpMethod> tryParseHttpMethod(kj::StringPtr name) {
namespace {
constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// From RFC6455.
constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t");
// RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
......@@ -866,6 +1163,18 @@ public:
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders);
struct ReleasedBuffer {
kj::Array<byte> buffer;
size_t filled;
};
ReleasedBuffer releaseBuffer() {
if (leftover.size() > 0) {
memmove(headerBuffer.begin(), leftover.begin(), leftover.size());
}
return { headerBuffer.releaseAsBytes(), leftover.size() };
}
private:
AsyncIoStream& inner;
kj::Array<char> headerBuffer;
......@@ -1552,8 +1861,10 @@ public:
WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator,
kj::Array<byte> buffer = kj::heapArray<byte>(4096),
size_t bytesAlreadyAvailable = 0)
size_t bytesAlreadyAvailable = 0,
kj::Maybe<kj::Promise<void>> waitBeforeSend = nullptr)
: stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
sendingPong(kj::mv(waitBeforeSend)),
recvAvail(bytesAlreadyAvailable), recvBuffer(kj::mv(buffer)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
......@@ -1910,6 +2221,11 @@ private:
kj::Maybe<kj::Promise<void>> sendingPong;
// If a Pong is being sent asynchronously in response to a Ping, this is a promise for the
// completion of that send.
//
// Additionally, this member is used if we need to block our first send on WebSocket startup,
// e.g. because we need to wait for HTTP handshake writes to flush before we can start sending
// WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same.
// Perhaps it should be renamed to `blockSend` or `writeQueue`.
uint fragmentOpcode = 0;
kj::Vector<kj::Array<byte>> fragments;
......@@ -1996,6 +2312,16 @@ private:
}
};
kj::Own<WebSocket> upgradeToWebSocket(
kj::Own<kj::AsyncIoStream> stream, HttpInputStream& httpInput, HttpOutputStream& httpOutput,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator = nullptr) {
// Create a WebSocket upgraded from an HTTP stream.
auto releasedBuffer = httpInput.releaseBuffer();
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator,
kj::mv(releasedBuffer.buffer), releasedBuffer.filled,
httpOutput.flush());
}
} // namespace
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
......@@ -2156,14 +2482,8 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host)
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}
class HttpServer::Connection final: private HttpService::Response {
class HttpServer::Connection final: private HttpService::WebSocketResponse {
public:
Connection(HttpServer& server, kj::AsyncIoStream& stream)
: server(server),
httpInput(stream, server.requestHeaderTable),
httpOutput(stream) {
++server.connectionCount;
}
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream)
: server(server),
httpInput(*stream, server.requestHeaderTable),
......@@ -2247,7 +2567,29 @@ public:
}
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
if (req->connectionHeaders.upgrade == "websocket") {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (req->connectionHeaders.websocketVersion != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
if (req->connectionHeaders.websocketKey == nullptr) {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
currentMethod = HttpMethod::GET;
websocketKey = kj::str(req->connectionHeaders.websocketKey);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders);
......@@ -2256,13 +2598,26 @@ public:
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
auto promise = server.service.request(
promise = server.service.request(
req->method, req->url, httpInput.getHeaders(), *body, *this);
return promise.attach(kj::mv(body))
.then([this]() { return httpOutput.flush(); })
promise = promise.attach(kj::mv(body));
}
return promise
.then([this]() -> kj::Promise<void> {
// Response done. Await next request.
if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream.
//
// Note that the WebSocket itself also flush()es the httpOutput before writing any
// WebSocket content, but we should also make sure that we don't let the listen loop
// exit until that flush is done, since we can't destroy the HttpOutputStream in the
// meantime.
return httpOutput.flush();
}
if (currentMethod != nullptr) {
return sendError(500, "Internal Server Error", kj::str(
"ERROR: The HttpService did not generate a response."));
......@@ -2273,7 +2628,7 @@ public:
return httpOutput.flush();
}
return loop(false);
return httpOutput.flush().then([this]() { return loop(false); });
});
} else {
// Bad request.
......@@ -2322,8 +2677,10 @@ private:
HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false;
bool closed = false;
bool upgraded = false;
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
......@@ -2331,6 +2688,12 @@ private:
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()");
currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::String lengthStr;
......@@ -2361,6 +2724,32 @@ private:
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request");
currentMethod = nullptr;
websocketKey = nullptr;
upgraded = true;
// WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY?
SHA1_CTX ctx;
byte digest[20];
SHA1Init(&ctx);
SHA1Update(&ctx, key.asBytes().begin(), key.size());
SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID));
SHA1Final(digest, &ctx);
auto websocketAccept = kj::encodeBase64(digest);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.websocketAccept = websocketAccept;
connectionHeaders.upgrade = "websocket";
connectionHeaders.connection = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
return upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput);
}
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
auto bodySize = kj::str(body.size());
......
......@@ -86,7 +86,10 @@ namespace kj {
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade")
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept")
enum class HttpMethod {
// Enum of known HTTP methods.
......@@ -523,12 +526,10 @@ public:
class WebSocketResponse: public Response {
public:
kj::Own<WebSocket> acceptWebSocket(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers);
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `statusText` and `headers` need only remain valid until acceptWebSocket() returns (they can
// be stack-allocated).
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment