Commit 6b6fe39c authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #531 from capnproto/websocket

Add WebSocket support to HTTP library
parents 2444c9eb 92bab53c
......@@ -381,6 +381,26 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) {
}));
}
kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::ArrayPtr<const byte> expected) {
if (expected.size() == 0) return kj::READY_NOW;
auto buffer = kj::heapArray<byte>(expected.size());
auto promise = in.tryRead(buffer.begin(), 1, buffer.size());
return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array<byte> buffer, size_t amount) {
if (amount == 0) {
KJ_FAIL_ASSERT("expected data never sent", expected);
}
auto actual = buffer.slice(0, amount);
if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) {
KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual);
}
return expectRead(in, expected.slice(amount, expected.size()));
}));
}
void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& testCase) {
auto pipe = io.provider->newTwoWayPipe();
......@@ -1150,6 +1170,773 @@ KJ_TEST("HttpClient <-> HttpServer") {
// -----------------------------------------------------------------------------
KJ_TEST("WebSocket core protocol") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = newWebSocket(kj::mv(pipe.ends[0]), nullptr);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
auto mediumString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 30), "");
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 10000), "");
auto clientTask = client->send(kj::StringPtr("hello"))
.then([&]() { return client->send(mediumString); })
.then([&]() { return client->send(bigString); })
.then([&]() { return client->send(kj::StringPtr("world").asBytes()); })
.then([&]() { return client->close(1234, "bored"); });
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello");
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == mediumString);
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == bigString);
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::Array<byte>>());
KJ_EXPECT(kj::str(message.get<kj::Array<byte>>().asChars()) == "world");
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 1234);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "bored");
}
auto serverTask = server->close(4321, "whatever");
{
auto message = client->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 4321);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "whatever");
}
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket fragmented") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x00, 0x03, 'w', 'o', 'r',
0x80, 0x02, 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
clientTask.wait(io.waitScope);
}
class FakeEntropySource final: public EntropySource {
public:
void generate(kj::ArrayPtr<byte> buffer) override {
static constexpr byte DUMMY[4] = { 12, 34, 56, 78 };
for (auto i: kj::indices(buffer)) {
buffer[i] = DUMMY[i % sizeof(DUMMY)];
}
}
};
KJ_TEST("WebSocket masked") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
FakeEntropySource maskGenerator;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), maskGenerator);
byte DATA[] = {
0x81, 0x86, 12, 34, 56, 78, 'h' ^ 12, 'e' ^ 34, 'l' ^ 56, 'l' ^ 78, 'o' ^ 12, ' ' ^ 34,
};
auto clientTask = client->write(DATA, sizeof(DATA));
auto serverTask = server->send(kj::StringPtr("hello "));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello ");
}
expectRead(*client, DATA).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket unsolicited pong") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x8A, 0x03, 'f', 'o', 'o',
0x80, 0x05, 'w', 'o', 'r', 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
clientTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
// Be extra-annoying by having the ping arrive between fragments.
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x89, 0x03, 'f', 'o', 'o',
0x80, 0x05, 'w', 'o', 'r', 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
auto serverTask = server->send(kj::StringPtr("bar"));
byte EXPECTED[] = {
0x8A, 0x03, 'f', 'o', 'o', // pong
0x81, 0x03, 'b', 'a', 'r', // message
};
expectRead(*client, EXPECTED).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping mid-send") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
byte DATA[] = {
0x89, 0x03, 'f', 'o', 'o', // ping
0x81, 0x03, 'b', 'a', 'r', // some other message
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(*client, EXPECTED1).wait(io.waitScope);
expectRead(*client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' };
expectRead(*client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
class UnbufferedPipe final: public AsyncIoStream {
// An in-memory one-way pipe with no internal buffer. read() blocks waiting for write()s and
// write() blocks waiting for read()s.
//
// TODO(cleanup): This is probably broadly useful. Put it in a utility library somewhere.
// NOTE: Must implement handling of cancellation first!
public:
kj::Promise<void> write(const void* buffer, size_t size) override {
KJ_SWITCH_ONEOF(current) {
KJ_CASE_ONEOF(w, CurrentWrite) {
KJ_FAIL_REQUIRE("can only call write() once at a time");
}
KJ_CASE_ONEOF(r, CurrentRead) {
if (size < r.minBytes) {
// Write does not complete the current read.
memcpy(r.buffer.begin(), buffer, size);
r.minBytes -= size;
r.alreadyRead += size;
r.buffer = r.buffer.slice(size, r.buffer.size());
return kj::READY_NOW;
} else if (size <= r.buffer.size()) {
// Write satisfies the current read, and read satisfies the write.
memcpy(r.buffer.begin(), buffer, size);
r.fulfiller->fulfill(r.alreadyRead + size);
current = None();
return kj::READY_NOW;
} else {
// Write satisfies the read and still has more data leftover to write.
size_t amount = r.buffer.size();
memcpy(r.buffer.begin(), buffer, amount);
r.fulfiller->fulfill(amount + r.alreadyRead);
auto paf = kj::newPromiseAndFulfiller<void>();
current = CurrentWrite {
kj::arrayPtr(reinterpret_cast<const byte*>(buffer) + amount, size - amount),
kj::mv(paf.fulfiller)
};
return kj::mv(paf.promise);
}
}
KJ_CASE_ONEOF(e, Eof) {
KJ_FAIL_REQUIRE("write after EOF");
}
KJ_CASE_ONEOF(n, None) {
auto paf = kj::newPromiseAndFulfiller<void>();
current = CurrentWrite {
kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size),
kj::mv(paf.fulfiller)
};
return kj::mv(paf.promise);
}
}
KJ_UNREACHABLE;
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_SWITCH_ONEOF(current) {
KJ_CASE_ONEOF(w, CurrentWrite) {
if (maxBytes < w.buffer.size()) {
// Entire read satisfied by write, write is still pending.
memcpy(buffer, w.buffer.begin(), maxBytes);
w.buffer = w.buffer.slice(maxBytes, w.buffer.size());
return maxBytes;
} else if (minBytes <= w.buffer.size()) {
// Read is satisfied by write and consumes entire write.
size_t result = w.buffer.size();
memcpy(buffer, w.buffer.begin(), result);
w.fulfiller->fulfill();
current = None();
return result;
} else {
// Read consumes entire write and is not satisfied.
size_t alreadyRead = w.buffer.size();
memcpy(buffer, w.buffer.begin(), alreadyRead);
w.fulfiller->fulfill();
auto paf = kj::newPromiseAndFulfiller<size_t>();
current = CurrentRead {
kj::arrayPtr(reinterpret_cast<byte*>(buffer) + alreadyRead, maxBytes - alreadyRead),
minBytes - alreadyRead,
alreadyRead,
kj::mv(paf.fulfiller)
};
return kj::mv(paf.promise);
}
}
KJ_CASE_ONEOF(r, CurrentRead) {
KJ_FAIL_REQUIRE("can only call read() once at a time");
}
KJ_CASE_ONEOF(e, Eof) {
return size_t(0);
}
KJ_CASE_ONEOF(n, None) {
auto paf = kj::newPromiseAndFulfiller<size_t>();
current = CurrentRead {
kj::arrayPtr(reinterpret_cast<byte*>(buffer), maxBytes),
minBytes,
0,
kj::mv(paf.fulfiller)
};
return kj::mv(paf.promise);
}
}
KJ_UNREACHABLE;
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
// TODO(cleanup): Should this be the defalut implementation of this method?
if (pieces.size() == 0) return kj::READY_NOW;
return write(pieces[0].begin(), pieces[0].size())
.then([this, pieces]() {
return write(pieces.slice(1, pieces.size()));
});
}
void shutdownWrite() override {
KJ_SWITCH_ONEOF(current) {
KJ_CASE_ONEOF(w, CurrentWrite) {
KJ_FAIL_REQUIRE("can't call shutdownWrite() during a write()");
}
KJ_CASE_ONEOF(r, CurrentRead) {
r.fulfiller->fulfill(kj::mv(r.alreadyRead));
}
KJ_CASE_ONEOF(e, Eof) {
// ignore
}
KJ_CASE_ONEOF(n, None) {
// ignore
}
}
current = Eof();
}
private:
struct CurrentWrite {
kj::ArrayPtr<const byte> buffer;
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
};
struct CurrentRead {
kj::ArrayPtr<byte> buffer;
size_t minBytes;
size_t alreadyRead;
kj::Own<kj::PromiseFulfiller<size_t>> fulfiller;
};
struct Eof {};
struct None {};
kj::OneOf<CurrentWrite, CurrentRead, Eof, None> current = None();
};
class InputOutputPair final: public kj::AsyncIoStream {
// Creates an AsyncIoStream out of an AsyncInputStream and an AsyncOutputStream.
public:
InputOutputPair(kj::AsyncInputStream& in, kj::AsyncIoStream& out)
: in(in), out(out) {}
kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
return in.read(buffer, minBytes, maxBytes);
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return in.tryRead(buffer, minBytes, maxBytes);
}
Maybe<uint64_t> tryGetLength() override {
return in.tryGetLength();
}
Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount = kj::maxValue) override {
return in.pumpTo(output, amount);
}
kj::Promise<void> write(const void* buffer, size_t size) override {
return out.write(buffer, size);
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
return out.write(pieces);
}
kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
return out.tryPumpFrom(input, amount);
}
void shutdownWrite() override {
return out.shutdownWrite();
}
private:
kj::AsyncInputStream& in;
kj::AsyncIoStream& out;
};
KJ_TEST("WebSocket double-ping mid-send") {
auto io = kj::setupAsyncIo();
UnbufferedPipe upPipe;
UnbufferedPipe downPipe;
InputOutputPair client(downPipe, upPipe);
auto server = newWebSocket(kj::heap<InputOutputPair>(upPipe, downPipe), nullptr);
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
byte DATA[] = {
0x89, 0x03, 'f', 'o', 'o', // ping
0x89, 0x03, 'q', 'u', 'x', // ping2
0x81, 0x03, 'b', 'a', 'r', // some other message
};
auto clientTask = client.write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(client, EXPECTED1).wait(io.waitScope);
expectRead(client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'q', 'u', 'x' };
expectRead(client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping received during pong send") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), nullptr);
// Send a very large ping so that sending the pong takes a while. Then send a second ping
// immediately after.
byte PREFIX[] = { 0x89, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
byte POSTFIX[] = {
0x89, 0x03, 'f', 'o', 'o',
0x81, 0x03, 'b', 'a', 'r',
};
kj::ArrayPtr<const byte> parts[] = {PREFIX, bigString.asBytes(), POSTFIX};
auto clientTask = client->write(parts);
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x8A, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(*client, EXPECTED1).wait(io.waitScope);
expectRead(*client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' };
expectRead(*client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
}
class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler {
public:
TestWebSocketService(HttpHeaderTable& headerTable, HttpHeaderId hMyHeader)
: headerTable(headerTable), hMyHeader(hMyHeader), tasks(*this) {}
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here");
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h));
}
if (url == "/return-error") {
response.send(404, "Not Found", responseHeaders, uint64_t(0));
return kj::READY_NOW;
} else if (url == "/ws-inline") {
auto ws = response.acceptWebSocket(responseHeaders);
return doWebSocket(*ws, "start-inline").attach(kj::mv(ws));
} else if (url == "/ws-detached") {
auto ws = response.acceptWebSocket(responseHeaders);
tasks.add(doWebSocket(*ws, "start-detached").attach(kj::mv(ws)));
return kj::READY_NOW;
} else {
KJ_FAIL_ASSERT("unexpected path", url);
}
}
private:
HttpHeaderTable& headerTable;
HttpHeaderId hMyHeader;
kj::TaskSet tasks;
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
static kj::Promise<void> doWebSocket(WebSocket& ws, kj::StringPtr message) {
auto copy = kj::str(message);
return ws.send(copy).attach(kj::mv(copy))
.then([&ws]() {
return ws.receive();
}).then([&ws](WebSocket::Message&& message) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(str, kj::String) {
return doWebSocket(ws, kj::str("reply:", str));
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return doWebSocket(ws, kj::str("reply:", data));
}
KJ_CASE_ONEOF(close, WebSocket::Close) {
auto reason = kj::str("close-reply:", close.reason);
return ws.close(close.code + 1, reason).attach(kj::mv(reason));
}
}
KJ_UNREACHABLE;
});
}
};
const char WEBSOCKET_REQUEST_HANDSHAKE[] =
" HTTP/1.1\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"My-Header: foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE[] =
"HTTP/1.1 101 Switching Protocols\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const char WEBSOCKET_RESPONSE_HANDSHAKE_ERROR[] =
"HTTP/1.1 404 Not Found\r\n"
"Content-Length: 0\r\n"
"My-Header: respond-foo\r\n"
"\r\n";
const byte WEBSOCKET_FIRST_MESSAGE_INLINE[] =
{ 0x81, 0x0c, 's','t','a','r','t','-','i','n','l','i','n','e' };
const byte WEBSOCKET_FIRST_MESSAGE_DETACHED[] =
{ 0x81, 0x0e, 's','t','a','r','t','-','d','e','t','a','c','h','e','d' };
const byte WEBSOCKET_SEND_MESSAGE[] =
{ 0x81, 0x83, 12, 34, 56, 78, 'b'^12, 'a'^34, 'r'^56 };
const byte WEBSOCKET_REPLY_MESSAGE[] =
{ 0x81, 0x09, 'r','e','p','l','y',':','b','a','r' };
const byte WEBSOCKET_SEND_CLOSE[] =
{ 0x88, 0x85, 12, 34, 56, 78, 0x12^12, 0x34^34, 'q'^56, 'u'^78, 'x'^12 };
const byte WEBSOCKET_REPLY_CLOSE[] =
{ 0x88, 0x11, 0x12, 0x35, 'c','l','o','s','e','-','r','e','p','l','y',':','q','u','x' };
template <size_t s>
kj::ArrayPtr<const byte> asBytes(const char (&chars)[s]) {
return kj::ArrayPtr<const char>(chars, s - 1).asBytes();
}
KJ_TEST("HttpClient WebSocket handshake") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto serverTask = expectRead(*pipe.ends[1], request)
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)}); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_FIRST_MESSAGE_INLINE}); })
.then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_MESSAGE); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_REPLY_MESSAGE}); })
.then([&]() { return expectRead(*pipe.ends[1], WEBSOCKET_SEND_CLOSE); })
.then([&]() { return pipe.ends[1]->write({WEBSOCKET_REPLY_CLOSE}); })
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
FakeEntropySource entropySource;
auto client = newHttpClient(*headerTable, *pipe.ends[0], entropySource);
kj::HttpHeaders headers(*headerTable);
headers.set(hMyHeader, "foo");
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 101);
KJ_EXPECT(response.statusText == "Switching Protocols", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<WebSocket>>());
auto ws = kj::mv(response.webSocketOrBody.get<kj::Own<WebSocket>>());
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "start-inline");
}
ws->send(kj::StringPtr("bar")).wait(io.waitScope);
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "reply:bar");
}
ws->close(0x1234, "qux").wait(io.waitScope);
{
auto message = ws->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 0x1235);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "close-reply:qux");
}
serverTask.wait(io.waitScope);
}
KJ_TEST("HttpClient WebSocket error") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto serverTask = expectRead(*pipe.ends[1], request)
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)}); })
.then([&]() { return expectRead(*pipe.ends[1], request); })
.then([&]() { return pipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE_ERROR)}); })
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
FakeEntropySource entropySource;
auto client = newHttpClient(*headerTable, *pipe.ends[0], entropySource);
kj::HttpHeaders headers(*headerTable);
headers.set(hMyHeader, "foo");
{
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 404);
KJ_EXPECT(response.statusText == "Not Found", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<AsyncInputStream>>());
}
{
auto response = client->openWebSocket("/websocket", headers).wait(io.waitScope);
KJ_EXPECT(response.statusCode == 404);
KJ_EXPECT(response.statusText == "Not Found", response.statusText);
KJ_EXPECT(KJ_ASSERT_NONNULL(response.headers->get(hMyHeader)) == "respond-foo");
KJ_ASSERT(response.webSocketOrBody.is<kj::Own<AsyncInputStream>>());
}
serverTask.wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-inline", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_CLOSE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(io.waitScope);
listenTask.wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake detached") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /ws-detached", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
listenTask.wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_FIRST_MESSAGE_DETACHED).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_MESSAGE).wait(io.waitScope);
pipe.ends[1]->write({WEBSOCKET_SEND_CLOSE}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_REPLY_CLOSE).wait(io.waitScope);
}
KJ_TEST("HttpServer WebSocket handshake error") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
HttpHeaderTable::Builder tableBuilder;
HttpHeaderId hMyHeader = tableBuilder.add("My-Header");
auto headerTable = tableBuilder.build();
TestWebSocketService service(*headerTable, hMyHeader);
HttpServer server(io.provider->getTimer(), *headerTable, service);
auto listenTask = server.listenHttp(kj::mv(pipe.ends[0]));
auto request = kj::str("GET /return-error", WEBSOCKET_REQUEST_HANDSHAKE);
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
// Can send more requests!
pipe.ends[1]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*pipe.ends[1], WEBSOCKET_RESPONSE_HANDSHAKE_ERROR).wait(io.waitScope);
pipe.ends[1]->shutdownWrite();
listenTask.wait(io.waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("HttpServer request timeout") {
auto PIPELINE_TESTS = pipelineTestCases();
......@@ -1484,6 +2271,88 @@ KJ_TEST("newHttpService from HttpClient") {
writeResponsesPromise.wait(io.waitScope);
}
KJ_TEST("newHttpService from HttpClient WebSockets") {
auto io = kj::setupAsyncIo();
auto frontPipe = io.provider->newTwoWayPipe();
auto backPipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
.then([&]() { return backPipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)}); })
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_FIRST_MESSAGE_INLINE}); })
.then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); })
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_REPLY_MESSAGE}); })
.then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_CLOSE); })
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_REPLY_CLOSE}); })
// expect EOF
.then([&]() { return backPipe.ends[1]->readAllBytes(); })
.then([&](kj::ArrayPtr<byte> content) {
KJ_EXPECT(content.size() == 0);
// Send EOF.
backPipe.ends[1]->shutdownWrite();
})
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
{
HttpHeaderTable table;
FakeEntropySource entropySource;
auto backClient = newHttpClient(table, *backPipe.ends[0], entropySource);
auto frontService = newHttpService(*backClient);
HttpServer frontServer(io.provider->getTimer(), table, *frontService);
auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1]));
frontPipe.ends[0]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(io.waitScope);
frontPipe.ends[0]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_MESSAGE).wait(io.waitScope);
frontPipe.ends[0]->write({WEBSOCKET_SEND_CLOSE}).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_REPLY_CLOSE).wait(io.waitScope);
frontPipe.ends[0]->shutdownWrite();
listenTask.wait(io.waitScope);
}
writeResponsesPromise.wait(io.waitScope);
}
KJ_TEST("newHttpService from HttpClient WebSockets disconnect") {
auto io = kj::setupAsyncIo();
auto frontPipe = io.provider->newTwoWayPipe();
auto backPipe = io.provider->newTwoWayPipe();
auto request = kj::str("GET /websocket", WEBSOCKET_REQUEST_HANDSHAKE);
auto writeResponsesPromise = expectRead(*backPipe.ends[1], request)
.then([&]() { return backPipe.ends[1]->write({asBytes(WEBSOCKET_RESPONSE_HANDSHAKE)}); })
.then([&]() { return backPipe.ends[1]->write({WEBSOCKET_FIRST_MESSAGE_INLINE}); })
.then([&]() { return expectRead(*backPipe.ends[1], WEBSOCKET_SEND_MESSAGE); })
.then([&]() { backPipe.ends[1]->shutdownWrite(); })
.eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
{
HttpHeaderTable table;
FakeEntropySource entropySource;
auto backClient = newHttpClient(table, *backPipe.ends[0], entropySource);
auto frontService = newHttpService(*backClient);
HttpServer frontServer(io.provider->getTimer(), table, *frontService);
auto listenTask = frontServer.listenHttp(kj::mv(frontPipe.ends[1]));
frontPipe.ends[0]->write({request.asBytes()}).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_RESPONSE_HANDSHAKE).wait(io.waitScope);
expectRead(*frontPipe.ends[0], WEBSOCKET_FIRST_MESSAGE_INLINE).wait(io.waitScope);
frontPipe.ends[0]->write({WEBSOCKET_SEND_MESSAGE}).wait(io.waitScope);
KJ_EXPECT(frontPipe.ends[0]->readAllText().wait(io.waitScope) == "");
frontPipe.ends[0]->shutdownWrite();
listenTask.wait(io.waitScope);
}
writeResponsesPromise.wait(io.waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("HttpClient to capnproto.org") {
......
......@@ -24,9 +24,303 @@
#include <kj/parse/char.h>
#include <unordered_map>
#include <stdlib.h>
#include <kj/encoding.h>
namespace kj {
// =======================================================================================
// SHA-1 implementation from https://github.com/clibs/sha1
//
// The WebSocket standard depends on SHA-1. ARRRGGGHHHHH.
//
// Any old checksum would have served the purpose, or hell, even just returning the header
// verbatim. But NO, they decided to throw a whole complicated hash algorithm in there, AND
// THEY CHOSE A BROKEN ONE THAT WE OTHERWISE WOULDN'T NEED ANYMORE.
//
// TODO(cleanup): Move this to a shared hashing library. Maybe. Or maybe don't, becaues no one
// should be using SHA-1 anymore.
//
// THIS USAGE IS NOT SECURITY SENSITIVE. IF YOU REPORT A SECURITY ISSUE BECAUSE YOU SAW SHA1 IN THE
// SOURCE CODE I WILL MAKE FUN OF YOU.
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
Test Vectors (from FIPS PUB 180-1)
"abc"
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
A million repetitions of "a"
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/
/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
/* #define SHA1HANDSOFF * Copies data before messing with it. */
#define SHA1HANDSOFF
typedef struct
{
uint32_t state[5];
uint32_t count[2];
unsigned char buffer[64];
} SHA1_CTX;
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
/* blk0() and blk() perform the initial expand. */
/* I got the idea of expanding during the round function from SSLeay */
#if BYTE_ORDER == LITTLE_ENDIAN
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|(rol(block->l[i],8)&0x00FF00FF))
#elif BYTE_ORDER == BIG_ENDIAN
#define blk0(i) block->l[i]
#else
#error "Endianness not defined!"
#endif
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
^block->l[(i+2)&15]^block->l[i&15],1))
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
/* Hash a single 512-bit block. This is the core of the algorithm. */
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
)
{
uint32_t a, b, c, d, e;
typedef union
{
unsigned char c[64];
uint32_t l[16];
} CHAR64LONG16;
#ifdef SHA1HANDSOFF
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
memcpy(block, buffer, 64);
#else
/* The following had better never be used because it causes the
* pointer-to-const buffer to be cast into a pointer to non-const.
* And the result is written through. I threw a "const" in, hoping
* this will cause a diagnostic.
*/
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
#endif
/* Copy context->state[] to working vars */
a = state[0];
b = state[1];
c = state[2];
d = state[3];
e = state[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
R0(a, b, c, d, e, 0);
R0(e, a, b, c, d, 1);
R0(d, e, a, b, c, 2);
R0(c, d, e, a, b, 3);
R0(b, c, d, e, a, 4);
R0(a, b, c, d, e, 5);
R0(e, a, b, c, d, 6);
R0(d, e, a, b, c, 7);
R0(c, d, e, a, b, 8);
R0(b, c, d, e, a, 9);
R0(a, b, c, d, e, 10);
R0(e, a, b, c, d, 11);
R0(d, e, a, b, c, 12);
R0(c, d, e, a, b, 13);
R0(b, c, d, e, a, 14);
R0(a, b, c, d, e, 15);
R1(e, a, b, c, d, 16);
R1(d, e, a, b, c, 17);
R1(c, d, e, a, b, 18);
R1(b, c, d, e, a, 19);
R2(a, b, c, d, e, 20);
R2(e, a, b, c, d, 21);
R2(d, e, a, b, c, 22);
R2(c, d, e, a, b, 23);
R2(b, c, d, e, a, 24);
R2(a, b, c, d, e, 25);
R2(e, a, b, c, d, 26);
R2(d, e, a, b, c, 27);
R2(c, d, e, a, b, 28);
R2(b, c, d, e, a, 29);
R2(a, b, c, d, e, 30);
R2(e, a, b, c, d, 31);
R2(d, e, a, b, c, 32);
R2(c, d, e, a, b, 33);
R2(b, c, d, e, a, 34);
R2(a, b, c, d, e, 35);
R2(e, a, b, c, d, 36);
R2(d, e, a, b, c, 37);
R2(c, d, e, a, b, 38);
R2(b, c, d, e, a, 39);
R3(a, b, c, d, e, 40);
R3(e, a, b, c, d, 41);
R3(d, e, a, b, c, 42);
R3(c, d, e, a, b, 43);
R3(b, c, d, e, a, 44);
R3(a, b, c, d, e, 45);
R3(e, a, b, c, d, 46);
R3(d, e, a, b, c, 47);
R3(c, d, e, a, b, 48);
R3(b, c, d, e, a, 49);
R3(a, b, c, d, e, 50);
R3(e, a, b, c, d, 51);
R3(d, e, a, b, c, 52);
R3(c, d, e, a, b, 53);
R3(b, c, d, e, a, 54);
R3(a, b, c, d, e, 55);
R3(e, a, b, c, d, 56);
R3(d, e, a, b, c, 57);
R3(c, d, e, a, b, 58);
R3(b, c, d, e, a, 59);
R4(a, b, c, d, e, 60);
R4(e, a, b, c, d, 61);
R4(d, e, a, b, c, 62);
R4(c, d, e, a, b, 63);
R4(b, c, d, e, a, 64);
R4(a, b, c, d, e, 65);
R4(e, a, b, c, d, 66);
R4(d, e, a, b, c, 67);
R4(c, d, e, a, b, 68);
R4(b, c, d, e, a, 69);
R4(a, b, c, d, e, 70);
R4(e, a, b, c, d, 71);
R4(d, e, a, b, c, 72);
R4(c, d, e, a, b, 73);
R4(b, c, d, e, a, 74);
R4(a, b, c, d, e, 75);
R4(e, a, b, c, d, 76);
R4(d, e, a, b, c, 77);
R4(c, d, e, a, b, 78);
R4(b, c, d, e, a, 79);
/* Add the working vars back into context.state[] */
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
state[4] += e;
/* Wipe variables */
a = b = c = d = e = 0;
#ifdef SHA1HANDSOFF
memset(block, '\0', sizeof(block));
#endif
}
/* SHA1Init - Initialize new context */
void SHA1Init(
SHA1_CTX * context
)
{
/* SHA1 initialization constants */
context->state[0] = 0x67452301;
context->state[1] = 0xEFCDAB89;
context->state[2] = 0x98BADCFE;
context->state[3] = 0x10325476;
context->state[4] = 0xC3D2E1F0;
context->count[0] = context->count[1] = 0;
}
/* Run your data through this. */
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
)
{
uint32_t i;
uint32_t j;
j = context->count[0];
if ((context->count[0] += len << 3) < j)
context->count[1]++;
context->count[1] += (len >> 29);
j = (j >> 3) & 63;
if ((j + len) > 63)
{
memcpy(&context->buffer[j], data, (i = 64 - j));
SHA1Transform(context->state, context->buffer);
for (; i + 63 < len; i += 64)
{
SHA1Transform(context->state, &data[i]);
}
j = 0;
}
else
i = 0;
memcpy(&context->buffer[j], &data[i], len - i);
}
/* Add padding and return the message digest. */
void SHA1Final(
unsigned char digest[20],
SHA1_CTX * context
)
{
unsigned i;
unsigned char finalcount[8];
unsigned char c;
#if 0 /* untested "improvement" by DHR */
/* Convert context->count to a sequence of bytes
* in finalcount. Second element first, but
* big-endian order within element.
* But we do it all backwards.
*/
unsigned char *fcp = &finalcount[8];
for (i = 0; i < 2; i++)
{
uint32_t t = context->count[i];
int j;
for (j = 0; j < 4; t >>= 8, j++)
*--fcp = (unsigned char) t}
#else
for (i = 0; i < 8; i++)
{
finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
}
#endif
c = 0200;
SHA1Update(context, &c, 1);
while ((context->count[0] & 504) != 448)
{
c = 0000;
SHA1Update(context, &c, 1);
}
SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
for (i = 0; i < 20; i++)
{
digest[i] = (unsigned char)
((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
}
/* Wipe variables */
memset(context, '\0', sizeof(*context));
memset(&finalcount, '\0', sizeof(finalcount));
}
// End SHA-1 implementation.
// =======================================================================================
static const char* METHOD_NAMES[] = {
#define METHOD_NAME(id) #id,
KJ_HTTP_FOR_EACH_METHOD(METHOD_NAME)
......@@ -129,6 +423,20 @@ kj::Maybe<HttpMethod> tryParseHttpMethod(kj::StringPtr name) {
namespace {
constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// From RFC6455.
static kj::String generateWebSocketAccept(kj::StringPtr key) {
// WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY?
SHA1_CTX ctx;
byte digest[20];
SHA1Init(&ctx);
SHA1Update(&ctx, key.asBytes().begin(), key.size());
SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID));
SHA1Final(digest, &ctx);
return kj::encodeBase64(digest);
}
constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t");
// RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
......@@ -866,6 +1174,15 @@ public:
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders);
struct ReleasedBuffer {
kj::Array<byte> buffer;
kj::ArrayPtr<byte> leftover;
};
ReleasedBuffer releaseBuffer() {
return { headerBuffer.releaseAsBytes(), leftover.asBytes() };
}
private:
AsyncIoStream& inner;
kj::Array<char> headerBuffer;
......@@ -1547,14 +1864,526 @@ private:
// =======================================================================================
class WebSocketImpl final: public WebSocket {
public:
WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<EntropySource&> maskKeyGenerator,
kj::Array<byte> buffer = kj::heapArray<byte>(4096),
kj::ArrayPtr<byte> leftover = nullptr,
kj::Maybe<kj::Promise<void>> waitBeforeSend = nullptr)
: stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
sendingPong(kj::mv(waitBeforeSend)),
recvBuffer(kj::mv(buffer)), recvData(leftover) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return sendImpl(OPCODE_BINARY, message);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return sendImpl(OPCODE_TEXT, message.asBytes());
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
kj::Array<byte> payload;
if (code == 1005) {
KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason");
// code 1005 -- leave payload empty
} else {
payload = heapArray<byte>(reason.size() + 2);
payload[0] = code >> 8;
payload[1] = code;
memcpy(payload.begin() + 2, reason.begin(), reason.size());
}
auto promise = sendImpl(OPCODE_CLOSE, payload);
return promise.attach(kj::mv(payload));
}
kj::Promise<void> disconnect() override {
if (!sendClosed) {
KJ_REQUIRE(!currentlySending, "another message send is already in progress");
KJ_IF_MAYBE(p, sendingPong) {
// We recently sent a pong, make sure it's finished before proceeding.
currentlySending = true;
auto promise = p->then([this]() {
currentlySending = false;
return disconnect();
});
sendingPong = nullptr;
return promise;
}
sendClosed = true;
}
stream->shutdownWrite();
return kj::READY_NOW;
}
kj::Promise<Message> receive() override {
auto& recvHeader = *reinterpret_cast<Header*>(recvData.begin());
size_t headerSize = recvHeader.headerSize(recvData.size());
if (headerSize > recvData.size()) {
if (recvData.begin() != recvBuffer.begin()) {
// Move existing data to front of buffer.
if (recvData.size() > 0) {
memmove(recvBuffer.begin(), recvData.begin(), recvData.size());
}
recvData = recvBuffer.slice(0, recvData.size());
}
return stream->tryRead(recvData.end(), 1, recvBuffer.end() - recvData.end())
.then([this](size_t actual) -> kj::Promise<Message> {
if (actual == 0) {
if (recvData.size() > 0) {
return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header");
} else {
// It's incorrect for the WebSocket to disconnect without sending `Close`.
return KJ_EXCEPTION(DISCONNECTED,
"WebSocket disconnected between frames without sending `Close`.");
}
}
recvData = recvBuffer.slice(0, recvData.size() + actual);
return receive();
});
}
recvData = recvData.slice(headerSize, recvData.size());
size_t payloadLen = recvHeader.getPayloadLen();
auto opcode = recvHeader.getOpcode();
bool isData = opcode < OPCODE_FIRST_CONTROL;
if (opcode == OPCODE_CONTINUATION) {
KJ_REQUIRE(!fragments.empty(), "unexpected continuation frame in WebSocket");
opcode = fragmentOpcode;
} else if (isData) {
KJ_REQUIRE(fragments.empty(), "expected continuation frame in WebSocket");
}
bool isFin = recvHeader.isFin();
kj::Array<byte> message; // space to allocate
byte* payloadTarget; // location into which to read payload (size is payloadLen)
if (isFin) {
// Add space for NUL terminator when allocating text message.
size_t amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin);
if (isData && !fragments.empty()) {
// Final frame of a fragmented message. Gather the fragments.
size_t offset = 0;
for (auto& fragment: fragments) offset += fragment.size();
message = kj::heapArray<byte>(offset + amountToAllocate);
offset = 0;
for (auto& fragment: fragments) {
memcpy(message.begin() + offset, fragment.begin(), fragment.size());
offset += fragment.size();
}
payloadTarget = message.begin() + offset;
fragments.clear();
fragmentOpcode = 0;
} else {
// Single-frame message.
message = kj::heapArray<byte>(amountToAllocate);
payloadTarget = message.begin();
}
} else {
// Fragmented message, and this isn't the final fragment.
KJ_REQUIRE(isData, "WebSocket control frame cannot be fragmented");
message = kj::heapArray<byte>(payloadLen);
payloadTarget = message.begin();
if (fragments.empty()) {
// This is the first fragment, so set the opcode.
fragmentOpcode = opcode;
}
}
Mask mask = recvHeader.getMask();
auto handleMessage = kj::mvCapture(message,
[this,opcode,payloadTarget,payloadLen,mask,isFin]
(kj::Array<byte>&& message) -> kj::Promise<Message> {
if (!mask.isZero()) {
mask.apply(kj::arrayPtr(payloadTarget, payloadLen));
}
if (!isFin) {
// Add fragment to the list and loop.
fragments.add(kj::mv(message));
return receive();
}
switch (opcode) {
case OPCODE_CONTINUATION:
// Shouldn't get here; handled above.
KJ_UNREACHABLE;
case OPCODE_TEXT:
message.back() = '\0';
return Message(kj::String(message.releaseAsChars()));
case OPCODE_BINARY:
return Message(message.releaseAsBytes());
case OPCODE_CLOSE:
if (message.size() < 2) {
return Message(Close { 1005, nullptr });
} else {
uint16_t status = (static_cast<uint16_t>(message[0]) << 8)
| (static_cast<uint16_t>(message[1]) );
return Message(Close {
status, kj::heapString(message.slice(2, message.size()).asChars())
});
}
case OPCODE_PING:
// Send back a pong.
queuePong(kj::mv(message));
return receive();
case OPCODE_PONG:
// Unsolicited pong. Ignore.
return receive();
default:
KJ_FAIL_REQUIRE("unknown WebSocket opcode", opcode);
}
});
if (payloadLen <= recvData.size()) {
// All data already received.
memcpy(payloadTarget, recvData.begin(), payloadLen);
recvData = recvData.slice(payloadLen, recvData.size());
return handleMessage();
} else {
// Need to read more data.
memcpy(payloadTarget, recvData.begin(), recvData.size());
size_t remaining = payloadLen - recvData.size();
auto promise = stream->tryRead(payloadTarget + recvData.size(), remaining, remaining)
.then([remaining](size_t amount) {
if (amount < remaining) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message"));
}
});
recvData = nullptr;
return promise.then(kj::mv(handleMessage));
}
}
private:
class Mask {
public:
Mask(): maskBytes { 0, 0, 0, 0 } {}
Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); }
Mask(kj::Maybe<EntropySource&> generator) {
KJ_IF_MAYBE(g, generator) {
g->generate(maskBytes);
} else {
memset(maskBytes, 0, 4);
}
}
void apply(kj::ArrayPtr<byte> bytes) const {
apply(bytes.begin(), bytes.size());
}
void copyTo(byte* output) const {
memcpy(output, maskBytes, 4);
}
bool isZero() const {
return (maskBytes[0] | maskBytes[1] | maskBytes[2] | maskBytes[3]) == 0;
}
private:
byte maskBytes[4];
void apply(byte* __restrict__ bytes, size_t size) const {
for (size_t i = 0; i < size; i++) {
bytes[i] ^= maskBytes[i % 4];
}
}
};
class Header {
public:
kj::ArrayPtr<const byte> compose(bool fin, byte opcode, uint64_t payloadLen, Mask mask) {
bytes[0] = (fin ? FIN_MASK : 0) | opcode;
bool hasMask = !mask.isZero();
size_t fill;
if (payloadLen < 126) {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | payloadLen;
if (hasMask) {
mask.copyTo(bytes + 2);
fill = 6;
} else {
fill = 2;
}
} else if (payloadLen < 65536) {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 126;
bytes[2] = static_cast<byte>(payloadLen >> 8);
bytes[3] = static_cast<byte>(payloadLen );
if (hasMask) {
mask.copyTo(bytes + 4);
fill = 8;
} else {
fill = 4;
}
} else {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 127;
bytes[2] = static_cast<byte>(payloadLen >> 56);
bytes[3] = static_cast<byte>(payloadLen >> 48);
bytes[4] = static_cast<byte>(payloadLen >> 40);
bytes[5] = static_cast<byte>(payloadLen >> 42);
bytes[6] = static_cast<byte>(payloadLen >> 24);
bytes[7] = static_cast<byte>(payloadLen >> 16);
bytes[8] = static_cast<byte>(payloadLen >> 8);
bytes[9] = static_cast<byte>(payloadLen );
if (hasMask) {
mask.copyTo(bytes + 10);
fill = 14;
} else {
fill = 10;
}
}
return arrayPtr(bytes, fill);
}
bool isFin() const {
return bytes[0] & FIN_MASK;
}
bool hasRsv() const {
return bytes[0] & RSV_MASK;
}
byte getOpcode() const {
return bytes[0] & OPCODE_MASK;
}
uint64_t getPayloadLen() const {
byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
if (payloadLen == 127) {
return (static_cast<uint64_t>(bytes[2]) << 56)
| (static_cast<uint64_t>(bytes[3]) << 48)
| (static_cast<uint64_t>(bytes[4]) << 40)
| (static_cast<uint64_t>(bytes[5]) << 32)
| (static_cast<uint64_t>(bytes[6]) << 24)
| (static_cast<uint64_t>(bytes[7]) << 16)
| (static_cast<uint64_t>(bytes[8]) << 8)
| (static_cast<uint64_t>(bytes[9]) );
} else if (payloadLen == 126) {
return (static_cast<uint64_t>(bytes[2]) << 8)
| (static_cast<uint64_t>(bytes[3]) );
} else {
return payloadLen;
}
}
Mask getMask() const {
if (bytes[1] & USE_MASK_MASK) {
byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
if (payloadLen == 127) {
return Mask(bytes + 10);
} else if (payloadLen == 126) {
return Mask(bytes + 4);
} else {
return Mask(bytes + 2);
}
} else {
return Mask();
}
}
size_t headerSize(size_t sizeSoFar) {
if (sizeSoFar < 2) return 2;
size_t required = 2;
if (bytes[1] & USE_MASK_MASK) {
required += 4;
}
byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
if (payloadLen == 127) {
required += 8;
} else if (payloadLen == 126) {
required += 2;
}
return required;
}
private:
byte bytes[14];
static constexpr byte FIN_MASK = 0x80;
static constexpr byte RSV_MASK = 0x70;
static constexpr byte OPCODE_MASK = 0x0f;
static constexpr byte USE_MASK_MASK = 0x80;
static constexpr byte PAYLOAD_LEN_MASK = 0x7f;
};
static constexpr byte OPCODE_CONTINUATION = 0;
static constexpr byte OPCODE_TEXT = 1;
static constexpr byte OPCODE_BINARY = 2;
static constexpr byte OPCODE_CLOSE = 8;
static constexpr byte OPCODE_PING = 9;
static constexpr byte OPCODE_PONG = 10;
static constexpr byte OPCODE_FIRST_CONTROL = 8;
// ---------------------------------------------------------------------------
kj::Own<kj::AsyncIoStream> stream;
kj::Maybe<EntropySource&> maskKeyGenerator;
bool sendClosed = false;
bool currentlySending = false;
Header sendHeader;
kj::ArrayPtr<const byte> sendParts[2];
kj::Maybe<kj::Array<byte>> queuedPong;
// If a Ping is received while currentlySending is true, then queuedPong is set to the body of
// a pong message that should be sent once the current send is complete.
kj::Maybe<kj::Promise<void>> sendingPong;
// If a Pong is being sent asynchronously in response to a Ping, this is a promise for the
// completion of that send.
//
// Additionally, this member is used if we need to block our first send on WebSocket startup,
// e.g. because we need to wait for HTTP handshake writes to flush before we can start sending
// WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same.
// Perhaps it should be renamed to `blockSend` or `writeQueue`.
uint fragmentOpcode = 0;
kj::Vector<kj::Array<byte>> fragments;
// If `fragments` is non-empty, we've already received some fragments of a message.
// `fragmentOpcode` is the original opcode.
kj::Array<byte> recvBuffer;
kj::ArrayPtr<byte> recvData;
kj::Promise<void> sendImpl(byte opcode, kj::ArrayPtr<const byte> message) {
KJ_REQUIRE(!sendClosed, "WebSocket already closed");
KJ_REQUIRE(!currentlySending, "another message send is already in progress");
currentlySending = true;
KJ_IF_MAYBE(p, sendingPong) {
// We recently sent a pong, make sure it's finished before proceeding.
auto promise = p->then([this, opcode, message]() {
currentlySending = false;
return sendImpl(opcode, message);
});
sendingPong = nullptr;
return promise;
}
sendClosed = opcode == OPCODE_CLOSE;
Mask mask(maskKeyGenerator);
kj::Array<byte> ownMessage;
if (!mask.isZero()) {
// Sadness, we have to make a copy to apply the mask.
ownMessage = kj::heapArray(message);
mask.apply(ownMessage);
message = ownMessage;
}
sendParts[0] = sendHeader.compose(true, opcode, message.size(), mask);
sendParts[1] = message;
auto promise = stream->write(sendParts);
if (!mask.isZero()) {
promise = promise.attach(kj::mv(ownMessage));
}
return promise.then([this]() {
currentlySending = false;
// Send queued pong if needed.
KJ_IF_MAYBE(q, queuedPong) {
kj::Array<byte> payload = kj::mv(*q);
queuedPong = nullptr;
queuePong(kj::mv(payload));
}
});
}
void queuePong(kj::Array<byte> payload) {
if (currentlySending) {
// There is a message-send in progress, so we cannot write to the stream now.
//
// Note: According to spec, if the server receives a second ping before responding to the
// previous one, it can opt to respond only to the last ping. So we don't have to check if
// queuedPong is already non-null.
queuedPong = kj::mv(payload);
} else KJ_IF_MAYBE(promise, sendingPong) {
// We're still sending a previous pong. Wait for it to finish before sending ours.
sendingPong = promise->then(kj::mvCapture(payload, [this](kj::Array<byte> payload) mutable {
return sendPong(kj::mv(payload));
}));
} else {
// We're not sending any pong currently.
sendingPong = sendPong(kj::mv(payload));
}
}
kj::Promise<void> sendPong(kj::Array<byte> payload) {
if (sendClosed) {
return kj::READY_NOW;
}
sendParts[0] = sendHeader.compose(true, OPCODE_PONG, payload.size(), Mask(maskKeyGenerator));
sendParts[1] = payload;
return stream->write(sendParts).attach(kj::mv(payload));
}
};
kj::Own<WebSocket> upgradeToWebSocket(
kj::Own<kj::AsyncIoStream> stream, HttpInputStream& httpInput, HttpOutputStream& httpOutput,
kj::Maybe<EntropySource&> maskKeyGenerator) {
// Create a WebSocket upgraded from an HTTP stream.
auto releasedBuffer = httpInput.releaseBuffer();
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator,
kj::mv(releasedBuffer.buffer), releasedBuffer.leftover,
httpOutput.flush());
}
} // namespace
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<EntropySource&> maskKeyGenerator) {
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator);
}
// =======================================================================================
namespace {
class HttpClientImpl final: public HttpClient {
public:
HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& rawStream)
: httpInput(rawStream, responseHeaderTable),
httpOutput(rawStream) {}
HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::Own<kj::AsyncIoStream> rawStream,
kj::Maybe<EntropySource&> entropySource)
: httpInput(*rawStream, responseHeaderTable),
httpOutput(*rawStream),
ownStream(kj::mv(rawStream)),
entropySource(entropySource) {}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
KJ_REQUIRE(!upgraded,
"can't make further requests on this HttpClient because it has been or is in the process "
"of being upgraded");
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::String lengthStr;
......@@ -1599,15 +2428,88 @@ public:
return { kj::mv(bodyStream), kj::mv(responsePromise) };
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers) override {
KJ_REQUIRE(!upgraded,
"can't make further requests on this HttpClient because it has been or is in the process "
"of being upgraded");
// Mark upgraded for now, even though the upgrade could fail, because we can't allow pipelined
// requests in the meantime.
upgraded = true;
byte keyBytes[16];
KJ_ASSERT_NONNULL(this->entropySource,
"can't use openWebSocket() because no EntropySource was provided when creating the "
"HttpClient").generate(keyBytes);
auto keyBase64 = kj::encodeBase64(keyBytes);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.connection = "Upgrade";
connectionHeaders.upgrade = "websocket";
connectionHeaders.websocketVersion = "13";
connectionHeaders.websocketKey = keyBase64;
httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));
// No entity-body.
httpOutput.finishBody();
return httpInput.readResponseHeaders()
.then(kj::mvCapture(keyBase64,
[this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response)
-> HttpClient::WebSocketResponse {
KJ_IF_MAYBE(r, response) {
if (r->statusCode == 101) {
if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
r->connectionHeaders.upgrade.cStr())) {
KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'",
r->connectionHeaders.upgrade) { break; }
return HttpClient::WebSocketResponse();
}
auto expectedAccept = generateWebSocketAccept(keyBase64);
if (r->connectionHeaders.websocketAccept != expectedAccept) {
KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header",
r->connectionHeaders.websocketAccept, expectedAccept) { break; }
return HttpClient::WebSocketResponse();
}
return {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, entropySource),
};
} else {
upgraded = false;
return {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode,
r->connectionHeaders)
};
}
} else {
KJ_FAIL_REQUIRE("received invalid HTTP response") { break; }
return HttpClient::WebSocketResponse();
}
}));
}
private:
HttpInputStream httpInput;
HttpOutputStream httpOutput;
kj::Own<AsyncIoStream> ownStream;
kj::Maybe<EntropySource&> entropySource;
bool upgraded = false;
};
} // namespace
kj::Promise<HttpClient::WebSocketResponse> HttpClient::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, kj::Own<WebSocket> downstream) {
kj::StringPtr url, const HttpHeaders& headers) {
return request(HttpMethod::GET, url, headers, nullptr)
.response.then([](HttpClient::Response&& response) -> WebSocketResponse {
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> body;
......@@ -1627,8 +2529,11 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> HttpClient::connect(kj::StringPtr host)
}
kj::Own<HttpClient> newHttpClient(
HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream) {
return kj::heap<HttpClientImpl>(responseHeaderTable, stream);
HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream,
kj::Maybe<EntropySource&> entropySource) {
return kj::heap<HttpClientImpl>(responseHeaderTable,
kj::Own<kj::AsyncIoStream>(&stream, kj::NullDisposer::instance),
entropySource);
}
// =======================================================================================
......@@ -1660,6 +2565,30 @@ public:
return kj::joinPromises(promises.finish());
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(pumpWebSocket(*ws, *ws2));
promises.add(pumpWebSocket(*ws2, *ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
}
}
KJ_UNREACHABLE;
});
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
return client.connect(kj::mv(host));
}
......@@ -1668,6 +2597,41 @@ public:
private:
HttpClient& client;
static kj::Promise<void> pumpWebSocket(WebSocket& from, WebSocket& to) {
return kj::evalNow([&]() {
return pumpWebSocketLoop(from, to);
}).catch_([&to](kj::Exception&& e) -> kj::Promise<void> {
if (e.getType() == kj::Exception::Type::DISCONNECTED) {
return to.disconnect();
} else {
return to.close(1002, e.getDescription());
}
});
}
static kj::Promise<void> pumpWebSocketLoop(WebSocket& from, WebSocket& to) {
return from.receive().then([&from,&to](WebSocket::Message&& message) {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
return to.send(text)
.attach(kj::mv(text))
.then([&from,&to]() { return pumpWebSocketLoop(from, to); });
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return to.send(data)
.attach(kj::mv(data))
.then([&from,&to]() { return pumpWebSocketLoop(from, to); });
}
KJ_CASE_ONEOF(close, WebSocket::Close) {
return to.close(close.code, close.reason)
.attach(kj::mv(close))
.then([&from,&to]() { return pumpWebSocketLoop(from, to); });
}
}
KJ_UNREACHABLE;
});
}
};
} // namespace
......@@ -1696,14 +2660,8 @@ kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host)
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}
class HttpServer::Connection final: private HttpService::Response {
class HttpServer::Connection final: private HttpService::WebSocketResponse {
public:
Connection(HttpServer& server, kj::AsyncIoStream& stream)
: server(server),
httpInput(stream, server.requestHeaderTable),
httpOutput(stream) {
++server.connectionCount;
}
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream)
: server(server),
httpInput(*stream, server.requestHeaderTable),
......@@ -1767,7 +2725,7 @@ public:
}
return receivedHeaders
.then([this,firstRequest](kj::Maybe<HttpHeaders::Request>&& request) -> kj::Promise<void> {
.then([this](kj::Maybe<HttpHeaders::Request>&& request) -> kj::Promise<void> {
if (closed) {
// Client closed connection. Close our end too.
return httpOutput.flush();
......@@ -1787,7 +2745,30 @@ public:
}
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
req->connectionHeaders.upgrade.cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (req->connectionHeaders.websocketVersion != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
if (req->connectionHeaders.websocketKey == nullptr) {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
currentMethod = HttpMethod::GET;
websocketKey = kj::str(req->connectionHeaders.websocketKey);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders);
......@@ -1796,13 +2777,26 @@ public:
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
auto promise = server.service.request(
promise = server.service.request(
req->method, req->url, httpInput.getHeaders(), *body, *this);
return promise.attach(kj::mv(body))
.then([this]() { return httpOutput.flush(); })
promise = promise.attach(kj::mv(body));
}
return promise
.then([this]() -> kj::Promise<void> {
// Response done. Await next request.
if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream.
//
// Note that the WebSocket itself also flush()es the httpOutput before writing any
// WebSocket content, but we should also make sure that we don't let the listen loop
// exit until that flush is done, since we can't destroy the HttpOutputStream in the
// meantime.
return httpOutput.flush();
}
if (currentMethod != nullptr) {
return sendError(500, "Internal Server Error", kj::str(
"ERROR: The HttpService did not generate a response."));
......@@ -1813,7 +2807,7 @@ public:
return httpOutput.flush();
}
return loop(false);
return httpOutput.flush().then([this]() { return loop(false); });
});
} else {
// Bad request.
......@@ -1862,8 +2856,10 @@ private:
HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false;
bool closed = false;
bool upgraded = false;
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
......@@ -1871,6 +2867,12 @@ private:
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()");
currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::String lengthStr;
......@@ -1901,6 +2903,25 @@ private:
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request");
currentMethod = nullptr;
websocketKey = nullptr;
upgraded = true;
auto websocketAccept = generateWebSocketAccept(key);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.websocketAccept = websocketAccept;
connectionHeaders.upgrade = "websocket";
connectionHeaders.connection = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
return upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, nullptr);
}
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
auto bodySize = kj::str(body.size());
......
......@@ -86,7 +86,10 @@ namespace kj {
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade")
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept")
enum class HttpMethod {
// Enum of known HTTP methods.
......@@ -368,13 +371,54 @@ private:
// also add direct accessors for those headers.
};
class EntropySource {
// Interface for an object that generates entropy. Typically, cryptographically-random entropy
// is expected.
//
// TODO(cleanup): Put this somewhere more general.
public:
virtual void generate(kj::ArrayPtr<byte> buffer) = 0;
};
class WebSocket {
// Interface representincg an open WebSocket session.
//
// Each side can send and receive data and "close" messages.
//
// Ping/Pong and message fragmentation are not exposed through this interface. These features of
// the underlying WebSocket protocol are not exposed by the browser-level Javascript API either,
// and thus applications typically need to implement these features at the application protocol
// level instead. The implementation is, however, expected to reply to Ping messages it receives.
public:
WebSocket(kj::Own<kj::AsyncIoStream> stream);
// Create a WebSocket wrapping the given I/O stream.
virtual kj::Promise<void> send(kj::ArrayPtr<const byte> message) = 0;
virtual kj::Promise<void> send(kj::ArrayPtr<const char> message) = 0;
// Send a message (binary or text). The underlying buffer must remain valid, and you must not
// call send() again, until the returned promise resolves.
virtual kj::Promise<void> close(uint16_t code, kj::StringPtr reason) = 0;
// Send a Close message.
//
// Note that the returned Promise resolves once the message has been sent -- it does NOT wait
// for the other end to send a Close reply. The application should await a reply before dropping
// the WebSocket object.
virtual kj::Promise<void> disconnect() = 0;
// Sends EOF on the underlying connection without sending a "close" message. This is NOT a clean
// shutdown, but is sometimes useful when you want the other end to trigger whatever behavior
// it normally triggers when a connection is dropped.
struct Close {
uint16_t code;
kj::String reason;
};
typedef kj::OneOf<kj::String, kj::Array<byte>, Close> Message;
kj::Promise<void> send(kj::ArrayPtr<const byte> message);
kj::Promise<void> send(kj::ArrayPtr<const char> message);
virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received.
};
class HttpClient {
......@@ -392,7 +436,7 @@ public:
kj::StringPtr statusText;
const HttpHeaders* headers;
kj::Own<kj::AsyncInputStream> body;
// `statusText` and `headers` remain valid until `body` is dropped.
// `statusText` and `headers` remain valid until `body` is dropped or read from.
};
struct Request {
......@@ -424,14 +468,15 @@ public:
uint statusCode;
kj::StringPtr statusText;
const HttpHeaders* headers;
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> upstreamOrBody;
// `statusText` and `headers` remain valid until `upstreamOrBody` is dropped.
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> webSocketOrBody;
// `statusText` and `headers` remain valid until `webSocketOrBody` is dropped or read from.
};
virtual kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, kj::Own<WebSocket> downstream);
kj::StringPtr url, const HttpHeaders& headers);
// Tries to open a WebSocket. Default implementation calls send() and never returns a WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
// `url` and `headers` need only remain valid until `openWebSocket()` returns (they can be
// stack-allocated).
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy clients. Default implementation throws
......@@ -478,13 +523,10 @@ public:
class WebSocketResponse: public Response {
public:
kj::Own<WebSocket> startWebSocket(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
WebSocket& upstream);
// Begin the response.
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `statusText` and `headers` need only remain valid until startWebSocket() returns (they can
// be stack-allocated).
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
......@@ -500,15 +542,25 @@ public:
};
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Network& network,
kj::Maybe<kj::Network&> tlsNetwork = nullptr);
kj::Maybe<kj::Network&> tlsNetwork = nullptr,
kj::Maybe<EntropySource&> entropySource = nullptr);
// Creates a proxy HttpClient that connects to hosts over the given network.
//
// `responseHeaderTable` is used when parsing HTTP responses. Requests can use any header table.
//
// `tlsNetwork` is required to support HTTPS destination URLs. Otherwise, only HTTP URLs can be
// fetched.
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream);
//
// `entropySource` must be provided in order to use `openWebSocket`. If you don't need WebSockets,
// `entropySource` can be omitted. The WebSocket protocol uses random values to avoid triggering
// flaws (including security flaws) in certain HTTP proxy software. Specifically, entropy is used
// to generate the `Sec-WebSocket-Key` header and to generate frame masks. If you know that there
// are no broken or vulnerable proxies between you and the server, you can provide an dummy entropy
// source that doesn't generate real entropy (e.g. returning the same value every time). Otherwise,
// you must provide a cryptographically-random entropy source.
kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream,
kj::Maybe<EntropySource&> entropySource = nullptr);
// Creates an HttpClient that speaks over the given pre-established connection. The client may
// be used as a proxy client or a host client depending on whether the peer is operating as
// a proxy.
......@@ -518,11 +570,33 @@ kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Asyn
// fail as well. If the destination server chooses to close the connection after a response,
// subsequent requests will fail. If a response takes a long time, it blocks subsequent responses.
// If a WebSocket is opened successfully, all subsequent requests fail.
//
// `entropySource` must be provided in order to use `openWebSocket`. If you don't need WebSockets,
// `entropySource` can be omitted. The WebSocket protocol uses random values to avoid triggering
// flaws (including security flaws) in certain HTTP proxy software. Specifically, entropy is used
// to generate the `Sec-WebSocket-Key` header and to generate frame masks. If you know that there
// are no broken or vulnerable proxies between you and the server, you can provide an dummy entropy
// source that doesn't generate real entropy (e.g. returning the same value every time). Otherwise,
// you must provide a cryptographically-random entropy source.
kj::Own<HttpClient> newHttpClient(HttpService& service);
kj::Own<HttpService> newHttpService(HttpClient& client);
// Adapts an HttpClient to an HttpService and vice versa.
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<EntropySource&> maskEntropySource);
// Create a new WebSocket on top of the given stream. It is assumed that the HTTP -> WebSocket
// upgrade handshake has already occurred (or is not needed), and messages can immediately be
// sent and received on the stream. Normally applications would not call this directly.
//
// `maskEntropySource` is used to generate cryptographically-random frame masks. If null, outgoing
// frames will not be masked. Servers are required NOT to mask their outgoing frames, but clients
// ARE required to do so. So, on the client side, you MUST specify an entropy source. The mask
// must be crytographically random if the data being sent on the WebSocket may be malicious. The
// purpose of the mask is to prevent badly-written HTTP proxies from interpreting "things that look
// like HTTP requests" in a message as being actual HTTP requests, which could result in cache
// poisoning. See RFC6455 section 10.3.
struct HttpServerSettings {
kj::Duration headerTimeout = 15 * kj::SECONDS;
// After initial connection open, or after receiving the first byte of a pipelined request,
......
......@@ -98,4 +98,41 @@ TEST(OneOf, Copy) {
EXPECT_STREQ("foo", var2.get<const char*>());
}
TEST(OneOf, Switch) {
OneOf<int, float, const char*> var;
var = "foo";
uint count = 0;
{
KJ_SWITCH_ONEOF(var) {
KJ_CASE_ONEOF(i, int) {
KJ_FAIL_ASSERT("expected char*, got int", i);
}
KJ_CASE_ONEOF(s, const char*) {
KJ_EXPECT(kj::StringPtr(s) == "foo");
++count;
}
KJ_CASE_ONEOF(n, float) {
KJ_FAIL_ASSERT("expected char*, got float", n);
}
}
}
KJ_EXPECT(count == 1);
{
KJ_SWITCH_ONEOF(kj::cp(var)) {
KJ_CASE_ONEOF(i, int) {
KJ_FAIL_ASSERT("expected char*, got int", i);
}
KJ_CASE_ONEOF(s, const char*) {
KJ_EXPECT(kj::StringPtr(s) == "foo");
}
KJ_CASE_ONEOF(n, float) {
KJ_FAIL_ASSERT("expected char*, got float", n);
}
}
}
}
} // namespace kj
......@@ -37,6 +37,31 @@ struct TypeIndex_ { static constexpr uint value = TypeIndex_<i + 1, Key, Rest...
template <uint i, typename Key, typename... Rest>
struct TypeIndex_<i, Key, Key, Rest...> { static constexpr uint value = i; };
enum class Variants0 {};
enum class Variants1 { _variant0 };
enum class Variants2 { _variant0, _variant1 };
enum class Variants3 { _variant0, _variant1, _variant2 };
enum class Variants4 { _variant0, _variant1, _variant2, _variant3 };
enum class Variants5 { _variant0, _variant1, _variant2, _variant3, _variant4 };
enum class Variants6 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5 };
enum class Variants7 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6 };
enum class Variants8 { _variant0, _variant1, _variant2, _variant3, _variant4, _variant5, _variant6,
_variant7 };
template <uint i> struct Variants_;
template <> struct Variants_<0> { typedef Variants0 Type; };
template <> struct Variants_<1> { typedef Variants1 Type; };
template <> struct Variants_<2> { typedef Variants2 Type; };
template <> struct Variants_<3> { typedef Variants3 Type; };
template <> struct Variants_<4> { typedef Variants4 Type; };
template <> struct Variants_<5> { typedef Variants5 Type; };
template <> struct Variants_<6> { typedef Variants6 Type; };
template <> struct Variants_<7> { typedef Variants7 Type; };
template <> struct Variants_<8> { typedef Variants8 Type; };
template <uint i>
using Variants = typename Variants_<i>::Type;
} // namespace _ (private)
template <typename... Variants>
......@@ -48,7 +73,12 @@ class OneOf {
public:
inline OneOf(): tag(0) {}
OneOf(const OneOf& other) { copyFrom(other); }
OneOf(OneOf& other) { copyFrom(other); }
OneOf(OneOf&& other) { moveFrom(other); }
template <typename T>
OneOf(T&& other): tag(typeIndex<Decay<T>>()) {
ctor(*reinterpret_cast<Decay<T>*>(space), kj::fwd<T>(other));
}
~OneOf() { destroy(); }
OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; }
......@@ -96,6 +126,22 @@ public:
// block call allHandled<n>() where n is the number of variants. This will fail to compile
// if new variants are added in the future.
typedef _::Variants<sizeof...(Variants)> Tag;
Tag which() {
KJ_IREQUIRE(tag != 0, "Can't KJ_SWITCH_ONEOF() on uninitialized value.");
return static_cast<Tag>(tag - 1);
}
template <typename T>
static constexpr Tag tagFor() {
return static_cast<Tag>(typeIndex<T>() - 1);
}
OneOf* _switchSubject() & { return this; }
const OneOf* _switchSubject() const& { return this; }
_::NullableValue<OneOf> _switchSubject() && { return kj::mv(*this); }
private:
uint tag;
......@@ -150,6 +196,20 @@ private:
doAll(copyVariantFrom<Variants>(other)...);
}
template <typename T>
inline bool copyVariantFrom(OneOf& other) {
if (other.is<T>()) {
ctor(*reinterpret_cast<T*>(space), other.get<T>());
}
return false;
}
void copyFrom(OneOf& other) {
// Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag
// is invalid.
tag = other.tag;
doAll(copyVariantFrom<Variants>(other)...);
}
template <typename T>
inline bool moveVariantFrom(OneOf& other) {
if (other.is<T>()) {
......@@ -176,6 +236,53 @@ void OneOf<Variants...>::allHandled() {
KJ_UNREACHABLE;
}
#if __cplusplus > 201402L
#define KJ_SWITCH_ONEOF(value) \
switch (auto _kj_switch_subject = value._switchSubject(); _kj_switch_subject->which())
#else
#define KJ_SWITCH_ONEOF(value) \
/* Without C++17, we can only support one switch per containing block. Deal with it. */ \
auto _kj_switch_subject = value._switchSubject(); \
switch (_kj_switch_subject->which())
#endif
#define KJ_CASE_ONEOF(name, ...) \
break; \
case ::kj::Decay<decltype(*_kj_switch_subject)>::tagFor<__VA_ARGS__>(): \
for (auto& name = _kj_switch_subject->get<__VA_ARGS__>(), *_kj_switch_done = &name; \
_kj_switch_done; _kj_switch_done = nullptr)
#define KJ_CASE_ONEOF_DEFAULT break; default:
// Allows switching over a OneOf.
//
// Example:
//
// kj::OneOf<int, float, const char*> variant;
// KJ_SWITCH_ONEOF(variant) {
// KJ_CASE_ONEOF(i, int) {
// doSomethingWithInt(i);
// }
// KJ_CASE_ONEOF(s, const char*) {
// doSomethingWithString(s);
// }
// KJ_CASE_ONEOF_DEFAULT {
// doSomethingElse();
// }
// }
//
// Notes:
// - If you don't handle all possible types and don't include a default branch, you'll get a
// compiler warning, just like a regular switch() over an enum where one of the enum values is
// missing.
// - There's no need for a `break` statement in a KJ_CASE_ONEOF; it is implied.
// - Under C++11 and C++14, only one KJ_SWITCH_ONEOF() can appear in a block. Wrap the switch in
// a pair of braces if you need a second switch in the same block. If C++17 is enabled, this is
// not an issue.
//
// Implementation notes:
// - The use of __VA_ARGS__ is to account for template types that have commas separating type
// parameters, since macros don't recognize <> as grouping.
// - _kj_switch_done is really used as a boolean flag to prevent the for() loop from actually
// looping, but it's defined as a pointer since that's all we can define in this context.
} // namespace kj
#endif // KJ_ONE_OF_H_
......@@ -11,6 +11,9 @@ QUICK=
PARALLEL=$(nproc 2>/dev/null || echo 1)
# Have automake dump test failure to stdout. Important for CI.
export VERBOSE=true
while [ $# -gt 0 ]; do
case "$1" in
-j* )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment