Commit 8a099294 authored by Kenton Varda's avatar Kenton Varda

Implement WebSocket core protocol.

Still need to add handshake separately.
parent 2d72fe55
...@@ -381,6 +381,26 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { ...@@ -381,6 +381,26 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) {
})); }));
} }
kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::ArrayPtr<const byte> expected) {
if (expected.size() == 0) return kj::READY_NOW;
auto buffer = kj::heapArray<byte>(expected.size());
auto promise = in.tryRead(buffer.begin(), 1, buffer.size());
return promise.then(kj::mvCapture(buffer, [&in,expected](kj::Array<byte> buffer, size_t amount) {
if (amount == 0) {
KJ_FAIL_ASSERT("expected data never sent", expected);
}
auto actual = buffer.slice(0, amount);
if (memcmp(actual.begin(), expected.begin(), actual.size()) != 0) {
KJ_FAIL_ASSERT("data from stream doesn't match expected", expected, actual);
}
return expectRead(in, expected.slice(amount, expected.size()));
}));
}
void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& testCase) { void testHttpClientRequest(kj::AsyncIoContext& io, const HttpRequestTestCase& testCase) {
auto pipe = io.provider->newTwoWayPipe(); auto pipe = io.provider->newTwoWayPipe();
...@@ -1150,6 +1170,298 @@ KJ_TEST("HttpClient <-> HttpServer") { ...@@ -1150,6 +1170,298 @@ KJ_TEST("HttpClient <-> HttpServer") {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
KJ_TEST("WebSocket core protocol") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = newWebSocket(kj::mv(pipe.ends[0]));
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto mediumString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 30), "");
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("123456789"), 10000), "");
auto clientTask = client->send(kj::StringPtr("hello"))
.then([&]() { return client->send(mediumString); })
.then([&]() { return client->send(bigString); })
.then([&]() { return client->send(kj::StringPtr("world").asBytes()); })
.then([&]() { return client->close(1234, "bored"); });
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello");
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == mediumString);
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == bigString);
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::Array<byte>>());
KJ_EXPECT(kj::str(message.get<kj::Array<byte>>().asChars()) == "world");
}
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 1234);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "bored");
}
auto serverTask = server->close(4321, "whatever");
{
auto message = client->receive().wait(io.waitScope);
KJ_ASSERT(message.is<WebSocket::Close>());
KJ_EXPECT(message.get<WebSocket::Close>().code == 4321);
KJ_EXPECT(message.get<WebSocket::Close>().reason == "whatever");
}
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket fragmented") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x00, 0x03, 'w', 'o', 'r',
0x80, 0x02, 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
clientTask.wait(io.waitScope);
}
class ConstantMaskGenerator final: public WebSocket::MaskKeyGenerator {
public:
void next(byte (&bytes)[4]) override {
bytes[0] = 12;
bytes[1] = 34;
bytes[2] = 56;
bytes[3] = 78;
}
};
KJ_TEST("WebSocket masked") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
ConstantMaskGenerator maskGenerator;
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]), maskGenerator);
byte DATA[] = {
0x81, 0x86, 12, 34, 56, 78, 'h' ^ 12, 'e' ^ 34, 'l' ^ 56, 'l' ^ 78, 'o' ^ 12, ' ' ^ 34,
};
auto clientTask = client->write(DATA, sizeof(DATA));
auto serverTask = server->send(kj::StringPtr("hello "));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello ");
}
expectRead(*client, DATA).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket unsolicited pong") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x8A, 0x03, 'f', 'o', 'o',
0x80, 0x05, 'w', 'o', 'r', 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
clientTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
// Be extra-annoying by having the ping arrive between fragments.
byte DATA[] = {
0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ',
0x89, 0x03, 'f', 'o', 'o',
0x80, 0x05, 'w', 'o', 'r', 'l', 'd',
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "hello world");
}
auto serverTask = server->send(kj::StringPtr("bar"));
byte EXPECTED[] = {
0x8A, 0x03, 'f', 'o', 'o', // pong
0x81, 0x03, 'b', 'a', 'r', // message
};
expectRead(*client, EXPECTED).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping mid-send") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
byte DATA[] = {
0x89, 0x03, 'f', 'o', 'o', // ping
0x81, 0x03, 'b', 'a', 'r', // some other message
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(*client, EXPECTED1).wait(io.waitScope);
expectRead(*client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' };
expectRead(*client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket double-ping mid-send") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
auto serverTask = server->send(bigString).eagerlyEvaluate(nullptr);
byte DATA[] = {
0x89, 0x03, 'f', 'o', 'o', // ping
0x89, 0x03, 'q', 'u', 'x', // ping2
0x81, 0x03, 'b', 'a', 'r', // some other message
};
auto clientTask = client->write(DATA, sizeof(DATA));
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x81, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(*client, EXPECTED1).wait(io.waitScope);
expectRead(*client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'q', 'u', 'x' };
expectRead(*client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
serverTask.wait(io.waitScope);
}
KJ_TEST("WebSocket ping received during pong send") {
auto io = kj::setupAsyncIo();
auto pipe = io.provider->newTwoWayPipe();
auto client = kj::mv(pipe.ends[0]);
auto server = newWebSocket(kj::mv(pipe.ends[1]));
// Send a very large ping so that sending the pong takes a while. Then send a second ping
// immediately after.
byte PREFIX[] = { 0x89, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
auto bigString = kj::strArray(kj::repeat(kj::StringPtr("12345678"), 65536), "");
byte POSTFIX[] = {
0x89, 0x03, 'f', 'o', 'o',
0x81, 0x03, 'b', 'a', 'r',
};
kj::ArrayPtr<const byte> parts[] = {PREFIX, bigString.asBytes(), POSTFIX};
auto clientTask = client->write(parts);
{
auto message = server->receive().wait(io.waitScope);
KJ_ASSERT(message.is<kj::String>());
KJ_EXPECT(message.get<kj::String>() == "bar");
}
byte EXPECTED1[] = { 0x8A, 0x7f, 0, 0, 0, 0, 0, 8, 0, 0 };
expectRead(*client, EXPECTED1).wait(io.waitScope);
expectRead(*client, bigString).wait(io.waitScope);
byte EXPECTED2[] = { 0x8A, 0x03, 'f', 'o', 'o' };
expectRead(*client, EXPECTED2).wait(io.waitScope);
clientTask.wait(io.waitScope);
}
// -----------------------------------------------------------------------------
KJ_TEST("HttpServer request timeout") { KJ_TEST("HttpServer request timeout") {
auto PIPELINE_TESTS = pipelineTestCases(); auto PIPELINE_TESTS = pipelineTestCases();
......
...@@ -1547,6 +1547,466 @@ private: ...@@ -1547,6 +1547,466 @@ private:
// ======================================================================================= // =======================================================================================
class WebSocketImpl final: public WebSocket {
public:
WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator,
kj::Array<byte> buffer = kj::heapArray<byte>(4096),
size_t bytesAlreadyAvailable = 0)
: stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
recvAvail(bytesAlreadyAvailable), recvBuffer(kj::mv(buffer)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return sendImpl(OPCODE_BINARY, message);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return sendImpl(OPCODE_TEXT, message.asBytes());
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
kj::Array<byte> payload;
if (code == 1005) {
KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason");
// code 1005 -- leave payload empty
} else {
payload = heapArray<byte>(reason.size() + 2);
payload[0] = code >> 8;
payload[1] = code;
memcpy(payload.begin() + 2, reason.begin(), reason.size());
}
auto promise = sendImpl(OPCODE_CLOSE, payload);
return promise.attach(kj::mv(payload));
}
kj::Promise<Message> receive() override {
auto& recvHeader = *reinterpret_cast<Header*>(recvBuffer.begin());
size_t headerSize = recvHeader.headerSize(recvAvail);
if (headerSize > recvAvail) {
return stream->tryRead(recvBuffer.begin() + recvAvail, 1, recvBuffer.size() - recvAvail)
.then([this](size_t actual) -> kj::Promise<Message> {
if (actual == 0) {
if (recvAvail) {
return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header");
} else {
// It's incorrect for the WebSocket to disconnect without sending `Close`.
return KJ_EXCEPTION(DISCONNECTED,
"WebSocket disconnected between frames without sending `Close`.");
}
}
recvAvail += actual;
return receive();
});
}
size_t payloadLen = recvHeader.getPayloadLen();
auto opcode = recvHeader.getOpcode();
bool isData = opcode < OPCODE_FIRST_CONTROL;
if (opcode == OPCODE_CONTINUATION) {
KJ_REQUIRE(!fragments.empty(), "unexpected continuation frame in WebSocket");
opcode = fragmentOpcode;
} else if (isData) {
KJ_REQUIRE(fragments.empty(), "expected continuation frame in WebSocket");
}
bool isFin = recvHeader.isFin();
kj::Array<byte> message; // space to allocate
byte* payloadTarget; // location into which to read payload (size is payloadLen)
if (isFin) {
// Add space for NUL terminator when allocating text message.
size_t amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin);
if (isData && !fragments.empty()) {
// Final frame of a fragmented message. Gather the fragments.
size_t offset = 0;
for (auto& fragment: fragments) offset += fragment.size();
message = kj::heapArray<byte>(offset + amountToAllocate);
offset = 0;
for (auto& fragment: fragments) {
memcpy(message.begin() + offset, fragment.begin(), fragment.size());
offset += fragment.size();
}
payloadTarget = message.begin() + offset;
fragments.clear();
fragmentOpcode = 0;
} else {
// Single-frame message.
message = kj::heapArray<byte>(amountToAllocate);
payloadTarget = message.begin();
}
} else {
// Fragmented message, and this isn't the final fragment.
KJ_REQUIRE(isData, "WebSocket control frame cannot be fragmented");
message = kj::heapArray<byte>(payloadLen);
payloadTarget = message.begin();
if (fragments.empty()) {
// This is the first fragment, so set the opcode.
fragmentOpcode = opcode;
}
}
Mask mask = recvHeader.getMask();
auto handleMessage = kj::mvCapture(message,
[this,opcode,payloadTarget,payloadLen,mask,isFin]
(kj::Array<byte>&& message) -> kj::Promise<Message> {
if (!mask.isZero()) {
mask.apply(kj::arrayPtr(payloadTarget, payloadLen));
}
if (!isFin) {
// Add fragment to the list and loop.
fragments.add(kj::mv(message));
return receive();
}
switch (opcode) {
case OPCODE_CONTINUATION:
// Shouldn't get here; handled above.
KJ_UNREACHABLE;
case OPCODE_TEXT:
message.back() = '\0';
return Message(kj::String(message.releaseAsChars()));
case OPCODE_BINARY:
return Message(message.releaseAsBytes());
case OPCODE_CLOSE:
if (message.size() < 2) {
return Message(Close { 1005, nullptr });
} else {
uint16_t status = (static_cast<uint16_t>(message[0]) << 8)
| (static_cast<uint16_t>(message[1]) );
return Message(Close {
status, kj::heapString(message.slice(2, message.size()).asChars())
});
}
case OPCODE_PING:
// Send back a pong.
queuePong(kj::mv(message));
return receive();
case OPCODE_PONG:
// Unsolicited pong. Ignore.
return receive();
default:
KJ_FAIL_REQUIRE("unknown WebSocket opcode", opcode);
}
});
if (headerSize + payloadLen <= recvAvail) {
// All data already received.
memcpy(payloadTarget, recvBuffer.begin() + headerSize, payloadLen);
size_t consumed = headerSize + payloadLen;
size_t remaining = recvAvail - consumed;
memmove(recvBuffer.begin(), recvBuffer.begin() + consumed, remaining);
recvAvail = remaining;
return handleMessage();
} else {
// Need to read more data.
size_t consumed = recvAvail - headerSize;
memcpy(payloadTarget, recvBuffer.begin() + headerSize, consumed);
recvAvail = 0;
size_t remaining = payloadLen - consumed;
auto promise = stream->tryRead(payloadTarget + consumed, remaining, remaining)
.then([remaining](size_t amount) {
if (amount < remaining) {
kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message"));
}
});
return promise.then(kj::mv(handleMessage));
}
}
private:
class Mask {
public:
Mask(): maskBytes { 0, 0, 0, 0 } {}
Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); }
Mask(kj::Maybe<WebSocket::MaskKeyGenerator&> generator) {
KJ_IF_MAYBE(g, generator) {
g->next(maskBytes);
} else {
memset(maskBytes, 0, 4);
}
}
void apply(kj::ArrayPtr<byte> bytes) const {
apply(bytes.begin(), bytes.size());
}
void copyTo(byte* output) const {
memcpy(output, maskBytes, 4);
}
bool isZero() const {
return (maskBytes[0] | maskBytes[1] | maskBytes[2] | maskBytes[3]) == 0;
}
private:
byte maskBytes[4];
void apply(byte* __restrict__ bytes, size_t size) const {
for (size_t i = 0; i < size; i++) {
bytes[i] ^= maskBytes[i % 4];
}
}
};
class Header {
public:
kj::ArrayPtr<const byte> compose(bool fin, byte opcode, uint64_t payloadLen, Mask mask) {
bytes[0] = (fin ? FIN_MASK : 0) | opcode;
bool hasMask = !mask.isZero();
size_t fill;
if (payloadLen < 126) {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | payloadLen;
if (hasMask) {
mask.copyTo(bytes + 2);
fill = 6;
} else {
fill = 2;
}
} else if (payloadLen < 65536) {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 126;
bytes[2] = static_cast<byte>(payloadLen >> 8);
bytes[3] = static_cast<byte>(payloadLen );
if (hasMask) {
mask.copyTo(bytes + 4);
fill = 8;
} else {
fill = 4;
}
} else {
bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 127;
bytes[2] = static_cast<byte>(payloadLen >> 56);
bytes[3] = static_cast<byte>(payloadLen >> 48);
bytes[4] = static_cast<byte>(payloadLen >> 40);
bytes[5] = static_cast<byte>(payloadLen >> 42);
bytes[6] = static_cast<byte>(payloadLen >> 24);
bytes[7] = static_cast<byte>(payloadLen >> 16);
bytes[8] = static_cast<byte>(payloadLen >> 8);
bytes[9] = static_cast<byte>(payloadLen );
if (hasMask) {
mask.copyTo(bytes + 10);
fill = 14;
} else {
fill = 10;
}
}
return arrayPtr(bytes, fill);
}
bool isFin() const {
return bytes[0] & FIN_MASK;
}
bool hasRsv() const {
return bytes[0] & RSV_MASK;
}
byte getOpcode() const {
return bytes[0] & OPCODE_MASK;
}
uint64_t getPayloadLen() const {
byte payloadLen = bytes[1] & PAYLOAD_MASK;
if (payloadLen == 127) {
return (static_cast<uint64_t>(bytes[2]) << 56)
| (static_cast<uint64_t>(bytes[3]) << 48)
| (static_cast<uint64_t>(bytes[4]) << 40)
| (static_cast<uint64_t>(bytes[5]) << 32)
| (static_cast<uint64_t>(bytes[6]) << 24)
| (static_cast<uint64_t>(bytes[7]) << 16)
| (static_cast<uint64_t>(bytes[8]) << 8)
| (static_cast<uint64_t>(bytes[9]) );
} else if (payloadLen == 126) {
return (static_cast<uint64_t>(bytes[2]) << 8)
| (static_cast<uint64_t>(bytes[3]) );
} else {
return payloadLen;
}
}
Mask getMask() const {
if (bytes[1] & USE_MASK_MASK) {
byte payloadLen = bytes[1] & PAYLOAD_MASK;
if (payloadLen == 128) {
return Mask(bytes + 10);
} else if (payloadLen == 127) {
return Mask(bytes + 4);
} else {
return Mask(bytes + 2);
}
} else {
return Mask();
}
}
size_t headerSize(size_t sizeSoFar) {
if (sizeSoFar < 2) return 2;
size_t required = 2;
if (bytes[1] & USE_MASK_MASK) {
required += 4;
}
byte payloadLen = bytes[1] & PAYLOAD_MASK;
if (payloadLen == 127) {
required += 8;
} else if (payloadLen == 126) {
required += 2;
}
return required;
}
private:
byte bytes[14];
static constexpr byte FIN_MASK = 0x80;
static constexpr byte RSV_MASK = 0x70;
static constexpr byte OPCODE_MASK = 0x0f;
static constexpr byte USE_MASK_MASK = 0x80;
static constexpr byte PAYLOAD_MASK = 0x7f;
};
static constexpr byte OPCODE_CONTINUATION = 0;
static constexpr byte OPCODE_TEXT = 1;
static constexpr byte OPCODE_BINARY = 2;
static constexpr byte OPCODE_CLOSE = 8;
static constexpr byte OPCODE_PING = 9;
static constexpr byte OPCODE_PONG = 10;
static constexpr byte OPCODE_FIRST_CONTROL = 8;
// ---------------------------------------------------------------------------
kj::Own<kj::AsyncIoStream> stream;
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator;
bool sendClosed = false;
bool currentlySending = false;
Header sendHeader;
kj::ArrayPtr<const byte> sendParts[2];
kj::Maybe<kj::Array<byte>> queuedPong;
// If a Ping is received while currentlySending is true, then queuedPong is set to the body of
// a pong message that should be sent once the current send is complete.
kj::Maybe<kj::Promise<void>> sendingPong;
// If a Pong is being sent asynchronously in response to a Ping, this is a promise for the
// completion of that send.
uint fragmentOpcode = 0;
kj::Vector<kj::Array<byte>> fragments;
// If `fragments` is non-empty, we've already received some fragments of a message.
// `fragmentOpcode` is the original opcode.
uint recvAvail = 0;
kj::Array<byte> recvBuffer;
kj::Promise<void> sendImpl(byte opcode, kj::ArrayPtr<const byte> message) {
KJ_REQUIRE(!sendClosed, "WebSocket already closed");
KJ_REQUIRE(!currentlySending, "another message send is already in progress");
currentlySending = true;
KJ_IF_MAYBE(p, sendingPong) {
// We recently sent a pong, make sure it's finished before proceeding.
auto promise = p->then([this, opcode, message]() {
currentlySending = false;
return sendImpl(opcode, message);
});
sendingPong = nullptr;
return promise;
}
sendClosed = opcode == OPCODE_CLOSE;
Mask mask(maskKeyGenerator);
kj::Array<byte> ownMessage;
if (!mask.isZero()) {
// Sadness, we have to make a copy to apply the mask.
ownMessage = kj::heapArray(message);
mask.apply(ownMessage);
message = ownMessage;
}
sendParts[0] = sendHeader.compose(true, opcode, message.size(), mask);
sendParts[1] = message;
auto promise = stream->write(sendParts);
if (!mask.isZero()) {
promise = promise.attach(kj::mv(ownMessage));
}
return promise.then([this]() {
currentlySending = false;
// Send queued pong if needed.
KJ_IF_MAYBE(q, queuedPong) {
kj::Array<byte> payload = kj::mv(*q);
queuedPong = nullptr;
queuePong(kj::mv(payload));
}
});
}
void queuePong(kj::Array<byte> payload) {
if (currentlySending) {
// There is a message-send in progress, so we cannot write to the stream now.
//
// Note: According to spec, if the server receives a second ping before responding to the
// previous one, it can opt to respond only to the last ping. So we don't have to check if
// queuedPong is already non-null.
queuedPong = kj::mv(payload);
} else KJ_IF_MAYBE(promise, sendingPong) {
// We're still sending a previous pong. Wait for it to finish before sending ours.
sendingPong = promise->then(kj::mvCapture(payload, [this](kj::Array<byte> payload) mutable {
return sendPong(kj::mv(payload));
}));
} else {
// We're not sending any pong currently.
sendingPong = sendPong(kj::mv(payload));
}
}
kj::Promise<void> sendPong(kj::Array<byte> payload) {
if (sendClosed) {
return kj::READY_NOW;
}
sendParts[0] = sendHeader.compose(true, OPCODE_PONG, payload.size(), Mask(maskKeyGenerator));
sendParts[1] = payload;
return stream->write(sendParts);
}
};
} // namespace
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator) {
return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator);
}
// =======================================================================================
namespace {
class HttpClientImpl final: public HttpClient { class HttpClientImpl final: public HttpClient {
public: public:
HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& rawStream) HttpClientImpl(HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& rawStream)
...@@ -1607,7 +2067,7 @@ private: ...@@ -1607,7 +2067,7 @@ private:
} // namespace } // namespace
kj::Promise<HttpClient::WebSocketResponse> HttpClient::openWebSocket( kj::Promise<HttpClient::WebSocketResponse> HttpClient::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, kj::Own<WebSocket> downstream) { kj::StringPtr url, const HttpHeaders& headers) {
return request(HttpMethod::GET, url, headers, nullptr) return request(HttpMethod::GET, url, headers, nullptr)
.response.then([](HttpClient::Response&& response) -> WebSocketResponse { .response.then([](HttpClient::Response&& response) -> WebSocketResponse {
kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> body; kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> body;
......
...@@ -369,12 +369,56 @@ private: ...@@ -369,12 +369,56 @@ private:
}; };
class WebSocket { class WebSocket {
// Interface representincg an open WebSocket session.
//
// Each side can send and receive data and "close" messages.
//
// Ping/Pong and message fragmentation are not exposed through this interface. These features of
// the underlying WebSocket protocol are not exposed by the browser-level Javascript API either,
// and thus applications typically need to implement these features at the applicaiton protocol
// level instead. The implementation is, however, expected to reply to Ping messages it receives.
public: public:
WebSocket(kj::Own<kj::AsyncIoStream> stream); virtual kj::Promise<void> send(kj::ArrayPtr<const byte> message) = 0;
// Create a WebSocket wrapping the given I/O stream. virtual kj::Promise<void> send(kj::ArrayPtr<const char> message) = 0;
// Send a message (binary or text). The underlying buffer must remain valid, and you must not
// call send() again, until the returned promise resolves.
virtual kj::Promise<void> close(uint16_t code, kj::StringPtr reason) = 0;
// Send a Close message.
//
// Note that the returned Promise resolves once the message has been sent -- it does NOT wait
// for the other end to send a Close reply. The application should await a reply before dropping
// the WebSocket object.
struct Close {
uint16_t code;
kj::String reason;
};
typedef kj::OneOf<kj::String, kj::Array<byte>, Close> Message;
virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received.
class MaskKeyGenerator {
// Class for generating WebSocket packet masks keys. See RFC6455 to understand how masking is
// used in WebSockets.
//
// The RFC insists that mask keys must be crypto-random, but it is not crypto -- it's just a
// value to be XOR'd with each four bytes of the data, and the mask itself is transmitted in
// plaintext ahead of the message. Apparently the WebSocket designers imagined that a random
// mask would make mass surveillance via string matching more difficult, but in practice this
// seems like no more than a minor speedbump. The other purpose of the mask is to prevent dumb
// proxies and captive portals from getting confused, but even a global constant mask could
// accomplish that.
//
// KJ leaves it up to the application to decide how to generate masks.
kj::Promise<void> send(kj::ArrayPtr<const byte> message); public:
kj::Promise<void> send(kj::ArrayPtr<const char> message); virtual void next(byte (&bytes)[4]) = 0;
};
}; };
class HttpClient { class HttpClient {
...@@ -428,10 +472,11 @@ public: ...@@ -428,10 +472,11 @@ public:
// `statusText` and `headers` remain valid until `upstreamOrBody` is dropped. // `statusText` and `headers` remain valid until `upstreamOrBody` is dropped.
}; };
virtual kj::Promise<WebSocketResponse> openWebSocket( virtual kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, kj::Own<WebSocket> downstream); kj::StringPtr url, const HttpHeaders& headers);
// Tries to open a WebSocket. Default implementation calls send() and never returns a WebSocket. // Tries to open a WebSocket. Default implementation calls send() and never returns a WebSocket.
// //
// `url` and `headers` are invalidated when the returned promise resolves. // `url` and `headers` need only remain valid until `openWebSocket()` returns (they can be
// stack-allocated).
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host); virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy clients. Default implementation throws // Handles CONNECT requests. Only relevant for proxy clients. Default implementation throws
...@@ -478,12 +523,11 @@ public: ...@@ -478,12 +523,11 @@ public:
class WebSocketResponse: public Response { class WebSocketResponse: public Response {
public: public:
kj::Own<WebSocket> startWebSocket( kj::Own<WebSocket> acceptWebSocket(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers);
WebSocket& upstream); // Accept and open the WebSocket.
// Begin the response.
// //
// `statusText` and `headers` need only remain valid until startWebSocket() returns (they can // `statusText` and `headers` need only remain valid until acceptWebSocket() returns (they can
// be stack-allocated). // be stack-allocated).
}; };
...@@ -523,6 +567,15 @@ kj::Own<HttpClient> newHttpClient(HttpService& service); ...@@ -523,6 +567,15 @@ kj::Own<HttpClient> newHttpClient(HttpService& service);
kj::Own<HttpService> newHttpService(HttpClient& client); kj::Own<HttpService> newHttpService(HttpClient& client);
// Adapts an HttpClient to an HttpService and vice versa. // Adapts an HttpClient to an HttpService and vice versa.
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
kj::Maybe<WebSocket::MaskKeyGenerator&> maskKeyGenerator = nullptr);
// Create a new WebSocket on top of the given stream. It is assumed that the HTTP -> WebSocket
// upgrade handshake has already occurred (or is not needed), and messages can immediately be
// sent and received on the stream. Normally applications would not call this directly.
//
// `maskKeyGenerator` is optional, but if omitted, the WebSocket frames will not be masked. Refer
// to RFC6455 to understand when masking is required.
struct HttpServerSettings { struct HttpServerSettings {
kj::Duration headerTimeout = 15 * kj::SECONDS; kj::Duration headerTimeout = 15 * kj::SECONDS;
// After initial connection open, or after receiving the first byte of a pipelined request, // After initial connection open, or after receiving the first byte of a pipelined request,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment