Commit 745f8a5c authored by Kenton Varda's avatar Kenton Varda

Implement WebSocket client-side handshake.

parent c8d8575a
......@@ -1174,8 +1174,8 @@ KJ_TEST("WebSocket core protocol") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = newWebSocket(kj::mv(pipe.ends[0]));
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
auto mediumString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 30), "");
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 10000), "");
......@@ -1235,7 +1235,7 @@ KJ_TEST("WebSocket fragmented") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
......@@ -1256,20 +1256,21 @@ KJ_TEST("WebSocket fragmented") {
clientTask.wait(io.waitScope);
}
class ConstantMaskGenerator final: public WebSocket::MaskKeyGenerator {
class FakeEntropySource final: public EntropySource {
public:
void next(byte (&bytes)[4]) override {
bytes[0] = 12;
bytes[1] = 34;
bytes[2] = 56;
bytes[3] = 78;
void generate(kj::ArrayPtr<byte> buffer) override {
static constexpr byte DUMMY[4] = { 12, 34, 56, 78 };
for (auto i: kj::indices(buffer)) {
buffer[i] = DUMMY[i % sizeof(DUMMY)];
}
}
};
KJ_TEST("WebSocket masked") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
ConstantMaskGenerator maskGenerator;
FakeEntropySource maskGenerator;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), maskGenerator);
......@@ -1298,7 +1299,7 @@ KJ_TEST("WebSocket unsolicited pong") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
......@@ -1324,7 +1325,7 @@ KJ_TEST("WebSocket ping") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
// Be extra-annoying by having the ping arrive between fragments.
byte DATA[] = {
......@@ -1361,7 +1362,7 @@ KJ_TEST("WebSocket ping mid-send") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
......@@ -1395,7 +1396,7 @@ KJ_TEST("WebSocket double-ping mid-send") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
......@@ -1430,7 +1431,7 @@ KJ_TEST("WebSocket ping received during pong send") {
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
// Send a very large ping so that sending the pong takes a while. Then send a second ping
// immediately after.
......@@ -1529,7 +1530,7 @@ const char WEBSOCKET_REQUEST_HANDSHAKE[] =
" HTTP/1.1\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"My-Header: foo\r\n"
"\r\n";
......@@ -1537,7 +1538,7 @@ const char WEBSOCKET_RESPONSE_HANDSHAKE[] =
"HTTP/1.1 101 Switching Protocols\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
"Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] =
......@@ -1545,18 +1546,123 @@ const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] =
"Content-Length: 0\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] = "\x81\x0c" "start-inline";
const byte WEBSOCKET_FIRST_MESSAGE_DETACHED[] = "\x81\x0e" "start-detached";
const byte WEBSOCKET_SEND_MESSAGE[] = "\x81\x03" "bar";
const byte WEBSOCKET_REPLY_MESSAGE[] = "\x81\x09" "reply:bar";
const byte WEBSOCKET_SEND_CLOSE[] = "\x88\x05\x12\x34" "qux";
const byte WEBSOCKET_REPLY_CLOSE[] = "\x88\x11\x12\x35" "close-reply:qux";
const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] =
{ 0x81, 0x0c, 's','t','a','r','t','-','i','n','l','i','n','e' };
const byte WEBSOCKET_FIRST_MESSAGE_DETACHED[] =
{ 0x81, 0x0e, 's','t','a','r','t','-','d','e','t','a','c','h','e','d' };
const byte WEBSOCKET_SEND_MESSAGE[] =
{ 0x81, 0x83, 12, 34, 56, 78, 'b'^12, 'a'^34, 'r'^56 };
const byte WEBSOCKET_REPLY_MESSAGE[] =
{ 0x81, 0x09, 'r','e','p','l','y',':','b','a','r' };
const byte WEBSOCKET_SEND_CLOSE[] =
{ 0x88, 0x85, 12, 34, 56, 78, 0x12^12, 0x34^34, 'q'^56, 'u'^78, 'x'^12 };
const byte WEBSOCKET_REPLY_CLOSE[] =
{ 0x88, 0x11, 0x12, 0x35, 'c','l','o','s','e','-','r','e','p','l','y',':','q','u','x' };
template <size_t s>
kj::ArrayPtr<const byte> nulterm(const byte (&bytes)[s]) {
// Ugh, the byte arrays defined above all end up having NUL terminators because I specified them
// as string literals. This function will remove the NUL.
return kj::ArrayPtr<const byte>(bytes, s - 1);
kj::ArrayPtr<const byte> asBytes(const char (&chars)[s]) {
return kj::ArrayPtr<const char>(chars, s - 1).asBytes();
}
KJ_TEST("HttpClient WebSocket handshake") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto serverTask = expectRead(*pipe.ends[1], request)
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)}); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_FIRST_MESSAGE_INLINE}); })
.then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_REPLY_MESSAGE}); })
.then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_REPLY_CLOSE}); })
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
FakeEntropySource entropySource;
auto client = newHttpClient(*headerTable, *pipe.ends[0], entropySource);
kj::HttpHeaders headers(*headerTable);
headers.set(hMyHeader, "foo");
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 101);
KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<WebSocket>>());
auto ws = kj::mv(response.webSocketOrBody.get<kj::Own<WebSocket>>());
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "start-inline");
}
ws->send(kj::StringPtr("bar")).wait(io.waitScope);
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "reply:bar");
}
ws->close(0x1234, "qux").wait(io.waitScope);
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 0x1235);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "close-reply:qux");
}
serverTask.wait(io.waitScope);
}
KJ_TEST("HttpClient WebSocket error") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto serverTask = expectRead(*pipe.ends[1], request)
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)}); })
.then([&]() { return expectRead(*pipe.ends[1], request); })
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)}); })
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
FakeEntropySource entropySource;
auto client = newHttpClient(*headerTable, *pipe.ends[0], entropySource);
kj::HttpHeaders headers(*headerTable);
headers.set(hMyHeader, "foo");
{
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 404);
KJ_EXPECT(response.statusText == "Not Found", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<AsyncInputStream>>());
}
{
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 404);
KJ_EXPECT(response.statusText == "Not Found", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<AsyncInputStream>>());
}
serverTask.wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake") {
......@@ -1575,11 +1681,11 @@ KJ_TEST("HttpServer WebSocket handshake") {
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_INLINE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_CLOSE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(io.waitScope);
listenTask.wait(io.waitScope);
}
......@@ -1602,11 +1708,11 @@ KJ_TEST("HttpServer WebSocket handshake detached") {
listenTask.wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_FIRST_MESSAGE_DETACHED)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_MESSAGE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_MESSAGE)).wait(io.waitScope);
pipe.ends[1]->write({nulterm(WEBSOCKET_SEND_CLOSE)}).wait(io.waitScope);
expectRead(*pipe.ends[1], nulterm(WEBSOCKET_REPLY_CLOSE)).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_DETACHED).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_CLOSE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake error") {
......
......@@ -426,6 +426,17 @@ namespace {
constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// From RFC6455.
static kj::String generateWebSocketAccept(kj::StringPtr key) {
// WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY?
SHA1_CTX ctx;
byte digest[20];
SHA1Init(&ctx);
SHA1Update(&ctx, key.asBytes().begin(), key.size());
SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID));
SHA1Final(digest, &ctx);
return kj::encodeBase64(digest);
}
constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t");
// RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
......@@ -1165,14 +1176,11 @@ public:
struct ReleasedBuffer {
kj::Array<byte> buffer;
size_t filled;
kj::ArrayPtr<byte> leftover;
};
ReleasedBuffer releaseBuffer() {
if (leftover.size() > 0) {
memmove(headerBuffer.begin(), leftover.begin(), leftover.size());
}
return { headerBuffer.releaseAsBytes(), leftover.size() };
return { headerBuffer.releaseAsBytes(), leftover.asBytes() };
}
private:
......@@ -1859,13 +1867,13 @@ private:
class WebSocketImpl final: public WebSocket {
public:
WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator,
kj::Maybe<EntropySource&> maskKeyGenerator,
kj::Array<byte> buffer = kj::heapArray<byte>(4096),
size_t bytesAlreadyAvailable = 0,
kj::ArrayPtr<byte> leftover = nullptr,
kj::Maybe<kj::Promise<void>> waitBeforeSend = nullptr)
: stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
sendingPong(kj::mv(waitBeforeSend)),
recvAvail(bytesAlreadyAvailable), recvBuffer(kj::mv(buffer)) {}
recvData(leftover), recvBuffer(kj::mv(buffer)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return sendImpl(OPCODE_BINARY, message);
......@@ -1893,14 +1901,22 @@ public:
}
kj::Promise<Message> receive() override {
auto& recvHeader = *reinterpret_cast<Header*>(recvBuffer.begin());
size_t headerSize = recvHeader.headerSize(recvAvail);
auto& recvHeader = *reinterpret_cast<Header*>(recvData.begin());
size_t headerSize = recvHeader.headerSize(recvData.size());
if (headerSize > recvData.size()) {
if (recvData.begin() != recvBuffer.begin()) {
// Move existing data to front of buffer.
if (recvData.size() > 0) {
memmove(recvBuffer.begin(), recvData.begin(), recvData.size());
}
recvData = recvBuffer.slice(0, recvData.size());
}
if (headerSize > recvAvail) {
return stream->tryRead(recvBuffer.begin() + recvAvail, 1, recvBuffer.size() - recvAvail)
return stream->tryRead(recvData.end(), 1, recvBuffer.end() - recvData.end())
.then([this](size_t actual) -> kj::Promise<Message> {
if (actual == 0) {
if (recvAvail) {
if (recvData.size() > 0) {
return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header");
} else {
// It's incorrect for the WebSocket to disconnect without sending `Close`.
......@@ -1909,11 +1925,13 @@ public:
}
}
recvAvail += actual;
recvData = recvBuffer.slice(0, recvData.size() + actual);
return receive();
});
}
recvData = recvData.slice(headerSize, recvData.size());
size_t payloadLen = recvHeader.getPayloadLen();
auto opcode = recvHeader.getOpcode();
......@@ -2012,26 +2030,22 @@ public:
}
});
if (headerSize + payloadLen <= recvAvail) {
if (payloadLen <= recvData.size()) {
// All data already received.
memcpy(payloadTarget, recvBuffer.begin() + headerSize, payloadLen);
size_t consumed = headerSize + payloadLen;
size_t remaining = recvAvail - consumed;
memmove(recvBuffer.begin(), recvBuffer.begin() + consumed, remaining);
recvAvail = remaining;
memcpy(payloadTarget, recvData.begin(), payloadLen);
recvData = recvData.slice(payloadLen, recvData.size());
return handleMessage();
} else {
// Need to read more data.
size_t consumed = recvAvail - headerSize;
memcpy(payloadTarget, recvBuffer.begin() + headerSize, consumed);
recvAvail = 0;
size_t remaining = payloadLen - consumed;
auto promise = stream->tryRead(payloadTarget + consumed, remaining, remaining)
memcpy(payloadTarget, recvData.begin(), recvData.size());
size_t remaining = payloadLen - recvData.size();
auto promise = stream->tryRead(payloadTarget + recvData.size(), remaining, remaining)
.then([remaining](size_t amount) {
if (amount < remaining) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message"));
}
});
recvData = nullptr;
return promise.then(kj::mv(handleMessage));
}
}
......@@ -2042,9 +2056,9 @@ private:
Mask(): maskBytes { 0, 0, 0, 0 } {}
Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); }
Mask(kj::Maybe<WebSocket::MaskKeyGenerator&> generator) {
Mask(kj::Maybe<EntropySource&> generator) {
KJ_IF_MAYBE(g, generator) {
g->next(maskBytes);
g->generate(maskBytes);
} else {
memset(maskBytes, 0, 4);
}
......@@ -2207,7 +2221,7 @@ private:
// ---------------------------------------------------------------------------
kj::Own<kj::AsyncIoStream> stream;
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator;
kj::Maybe<EntropySource&> maskKeyGenerator;
bool sendClosed = false;
bool currentlySending = false;
......@@ -2232,8 +2246,8 @@ private:
// If `fragments` is non-empty, we've already received some fragments of a message.
// `fragmentOpcode` is the original opcode.
uint recvAvail = 0;
kj::Array<byte> recvBuffer;
kj::ArrayPtr<byte> recvData;
kj::Promise<void> sendImpl(byte opcode, kj::ArrayPtr<const byte> message) {
KJ_REQUIRE(!sendClosed, "WebSocket already closed");
......@@ -2314,18 +2328,18 @@ private:
kj::Own<WebSocket> upgradeToWebSocket(
kj::Own<kj::AsyncIoStream> stream, HttpInputStream& httpInput, HttpOutputStream& httpOutput,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator = nullptr) {
kj::Maybe<EntropySource&> maskKeyGenerator) {
// Create a WebSocket upgraded from an HTTP stream.
auto releasedBuffer = httpInput.releaseBuffer();
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator,
kj::mv(releasedBuffer.buffer), releasedBuffer.filled,
kj::mv(releasedBuffer.buffer), releasedBuffer.leftover,
httpOutput.flush());
}
} // namespace
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator) {
kj::Maybe<EntropySource&> maskKeyGenerator) {
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator);
}
......@@ -2335,12 +2349,19 @@ namespace {
class HttpClientImpl final: public HttpClient {
public:
HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& rawStream)
: httpInput(rawStream, responseHeaderTable),
httpOutput(rawStream) {}
HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::Own<kj::AsyncIoStream> rawStream,
kj::Maybe<EntropySource&> entropySource)
: httpInput(*rawStream, responseHeaderTable),
httpOutput(*rawStream),
ownStream(kj::mv(rawStream)),
entropySource(entropySource) {}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
KJ_REQUIRE(!upgraded,
"can't make further requests on this HttpClient because it has been or is in the process "
"of being upgraded");
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::String lengthStr;
......@@ -2385,9 +2406,82 @@ public:
return { kj::mv(bodyStream), kj::mv(responsePromise) };
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers) override {
KJ_REQUIRE(!upgraded,
"can't make further requests on this HttpClient because it has been or is in the process "
"of being upgraded");
// Mark upgraded for now, even though the upgrade could fail, because we can't allow pipelined
// requests in the meantime.
upgraded = true;
byte keyBytes[16];
KJ_ASSERT_NONNULL(this->entropySource,
"can't use openWebSocket() because no EntropySource was provided when creating the "
"HttpClient").generate(keyBytes);
auto keyBase64 = kj::encodeBase64(keyBytes);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.connection = "Upgrade";
connectionHeaders.upgrade = "websocket";
connectionHeaders.websocketVersion = "13";
connectionHeaders.websocketKey = keyBase64;
httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));
// No entity-body.
httpOutput.finishBody();
return httpInput.readResponseHeaders()
.then(kj::mvCapture(keyBase64,
[this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response)
-> HttpClient::WebSocketResponse {
KJ_IF_MAYBE(r, response) {
if (r->statusCode == 101) {
if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
r->connectionHeaders.upgrade.cStr())) {
KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'",
r->connectionHeaders.upgrade) { break; }
return HttpClient::WebSocketResponse();
}
auto expectedAccept = generateWebSocketAccept(keyBase64);
if (r->connectionHeaders.websocketAccept != expectedAccept) {
KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header",
r->connectionHeaders.websocketAccept, expectedAccept) { break; }
return HttpClient::WebSocketResponse();
}
return {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, entropySource),
};
} else {
upgraded = false;
return {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode,
r->connectionHeaders)
};
}
} else {
KJ_FAIL_REQUIRE("received invalid HTTP response") { break; }
return HttpClient::WebSocketResponse();
}
}));
}
private:
HttpInputStream httpInput;
HttpOutputStream httpOutput;
kj::Own<AsyncIoStream> ownStream;
kj::Maybe<EntropySource&> entropySource;
bool upgraded = false;
};
} // namespace
......@@ -2413,8 +2507,11 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> HttpClient::connect(kj::StringPtr host)
}
kj::Own<HttpClient> newHttpClient(
HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream) {
return kj::heap<HttpClientImpl>(responseHeaderTable, stream);
HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream,
kj::Maybe<EntropySource&> entropySource) {
return kj::heap<HttpClientImpl>(responseHeaderTable,
kj::Own<kj::AsyncIoStream>(&stream, kj::NullDisposer::instance),
entropySource);
}
// =======================================================================================
......@@ -2569,7 +2666,8 @@ public:
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
if (req->connectionHeaders.upgrade == "websocket") {
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
req->connectionHeaders.upgrade.cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
......@@ -2730,14 +2828,7 @@ private:
websocketKey = nullptr;
upgraded = true;
// WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY?
SHA1_CTX ctx;
byte digest[20];
SHA1Init(&ctx);
SHA1Update(&ctx, key.asBytes().begin(), key.size());
SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID));
SHA1Final(digest, &ctx);
auto websocketAccept = kj::encodeBase64(digest);
auto websocketAccept = generateWebSocketAccept(key);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.websocketAccept = websocketAccept;
......@@ -2747,7 +2838,7 @@ private:
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
return upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput);
return upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, nullptr);
}
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
......
......@@ -371,6 +371,16 @@ private:
// also add direct accessors for those headers.
};
class EntropySource {
// Interface for an object that generates entropy. Typically, cryptographically-random entropy
// is expected.
//
// TODO(cleanup): Put this somewhere more general.
public:
virtual void generate(kj::ArrayPtr<byte> buffer) = 0;
};
class WebSocket {
// Interface representincg an open WebSocket session.
//
......@@ -404,24 +414,6 @@ public:
virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received.
class MaskKeyGenerator {
// Class for generating WebSocket packet masks keys. See RFC6455 to understand how masking is
// used in WebSockets.
//
// The RFC insists that mask keys must be crypto-random, but it is not crypto -- it's just a
// value to be XOR'd with each four bytes of the data, and the mask itself is transmitted in
// plaintext ahead of the message. Apparently the WebSocket designers imagined that a random
// mask would make mass surveillance via string matching more difficult, but in practice this
// seems like no more than a minor speedbump. The other purpose of the mask is to prevent dumb
// proxies and captive portals from getting confused, but even a global constant mask could
// accomplish that.
//
// KJ leaves it up to the application to decide how to generate masks.
public:
virtual void next(byte (&bytes)[4]) = 0;
};
};
class HttpClient {
......@@ -471,7 +463,7 @@ public:
uint statusCode;
kj::StringPtr statusText;
const HttpHeaders* headers;
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> upstreamOrBody;
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> webSocketOrBody;
// `statusText` and `headers` remain valid until `upstreamOrBody` is dropped.
};
virtual kj::Promise<WebSocketResponse> openWebSocket(
......@@ -545,15 +537,25 @@ public:
};
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Network& network,
kj::Maybe<kj::Network&> tlsNetwork = nullptr);
kj::Maybe<kj::Network&> tlsNetwork = nullptr,
kj::Maybe<EntropySource&> entropySource = nullptr);
// Creates a proxy HttpClient that connects to hosts over the given network.
//
// `responseHeaderTable` is used when parsing HTTP responses. Requests can use any header table.
//
// `tlsNetwork` is required to support HTTPS destination URLs. Otherwise, only HTTP URLs can be
// fetched.
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream);
//
// `entropySource` must be provided in order to use `openWebSocket`. If you don't need WebSockets,
// `entropySource` can be omitted. The WebSocket protocol uses random values to avoid triggering
// flaws (including security flaws) in certain HTTP proxy software. Specifically, entropy is used
// to generate the `Sec-WebSocket-Key` header and to generate frame masks. If you know that there
// are no broken or vulnerable proxies between you and the server, you can provide an dummy entropy
// source that doesn't generate real entropy (e.g. returning the same value every time). Otherwise,
// you must provide a cryptographically-random entropy source.
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream,
kj::Maybe<EntropySource&> entropySource = nullptr);
// Creates an HttpClient that speaks over the given pre-established connection. The client may
// be used as a proxy client or a host client depending on whether the peer is operating as
// a proxy.
......@@ -563,19 +565,32 @@ kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Asyn
// fail as well. If the destination server chooses to close the connection after a response,
// subsequent requests will fail. If a response takes a long time, it blocks subsequent responses.
// If a WebSocket is opened successfully, all subsequent requests fail.
//
// `entropySource` must be provided in order to use `openWebSocket`. If you don't need WebSockets,
// `entropySource` can be omitted. The WebSocket protocol uses random values to avoid triggering
// flaws (including security flaws) in certain HTTP proxy software. Specifically, entropy is used
// to generate the `Sec-WebSocket-Key` header and to generate frame masks. If you know that there
// are no broken or vulnerable proxies between you and the server, you can provide an dummy entropy
// source that doesn't generate real entropy (e.g. returning the same value every time). Otherwise,
// you must provide a cryptographically-random entropy source.
kj::Own<HttpClient> newHttpClient(HttpService& service);
kj::Own<HttpService> newHttpService(HttpClient& client);
// Adapts an HttpClient to an HttpService and vice versa.
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator = nullptr);
kj::Maybe<EntropySource&> maskEntropySource);
// Create a new WebSocket on top of the given stream. It is assumed that the HTTP -> WebSocket
// upgrade handshake has already occurred (or is not needed), and messages can immediately be
// sent and received on the stream. Normally applications would not call this directly.
//
// `maskKeyGenerator` is optional, but if omitted, the WebSocket frames will not be masked. Refer
// to RFC6455 to understand when masking is required.
// `maskEntropySource` is used to generate cryptographically-random frame masks. If null, outgoing
// frames will not be masked. Servers are not required to mask their outgoing frames, but clients
// ARE required to do so. So, on the client side, you MUST specify an entropy source. The mask
// must be crytographically random if the data being sent on the WebSocket may be malicious. The
// purpose of the mask is to prevent badly-written HTTP proxies from interpreting "things that look
// like HTTP requests" in a message as being actual HTTP requests, which could result in cache
// poisoning. See RFC6455 section 10.3.
struct HttpServerSettings {
kj::Duration headerTimeout = 15 * kj::SECONDS;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment