Unverified Commit b5425465 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #637 from capnproto/rewrite-websocket

Redesign server-side WebSocket handling.
parents 9032d060 632888d6
......@@ -1691,11 +1691,8 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here");
}
KJ_ASSERT(headers.isWebSocket());
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h));
......@@ -2491,26 +2488,25 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(sendPromise.attach(kj::mv(body)));
promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
if (!headers.isWebSocket()) {
KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
} else {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(sendPromise.attach(kj::mv(body)));
promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
}
}
private:
......
......@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) {
return result;
}
namespace {
template <char... chars>
constexpr bool fastCaseCmp(const char* actual);
} // namespace
bool HttpHeaders::isWebSocket() const {
return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr());
}
void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) {
id.requireFrom(*table);
requireValidHeaderValue(value);
......@@ -3165,46 +3178,45 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
.attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));
promises.add(innerReq.response
.then([&response](HttpClient::Response&& innerResponse) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
innerResponse.body->tryGetLength());
auto promise = innerResponse.body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
}));
return kj::joinPromises(promises.finish());
}
if (!headers.isWebSocket()) {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
.attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));
promises.add(innerReq.response
.then([&response](HttpClient::Response&& innerResponse) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
innerResponse.body->tryGetLength());
auto promise = innerResponse.body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
}));
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(ws->pumpTo(*ws2));
promises.add(ws2->pumpTo(*ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
return kj::joinPromises(promises.finish());
} else {
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(ws->pumpTo(*ws2));
promises.add(ws2->pumpTo(*ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
}
}
}
KJ_UNREACHABLE;
});
KJ_UNREACHABLE;
});
}
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
......@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError(
return sendError(statusCode, statusText, HttpHeaders(headerTable));
}
kj::Promise<void> HttpService::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) {
class EmptyStream final: public kj::AsyncInputStream {
public:
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return size_t(0);
}
};
auto requestBody = heap<EmptyStream>();
auto promise = request(HttpMethod::GET, url, headers, *requestBody, response);
return promise.attach(kj::mv(requestBody));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}
class HttpServer::Connection final: private HttpService::WebSocketResponse {
class HttpServer::Connection final: private HttpService::Response {
public:
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream,
HttpService& service)
......@@ -3348,48 +3346,30 @@ public:
}
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
currentMethod = req->method;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers);
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
promise = service.request(
req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body));
}
return promise
auto promise = service.request(
req->method, req->url, headers, *body, *this);
return promise.attach(kj::mv(body))
.then([this]() -> kj::Promise<void> {
// Response done. Await next request.
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream.
......@@ -3424,7 +3404,15 @@ public:
if (currentMethod == nullptr) {
// Dang, already sent a partial response. Can't do anything else.
//
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection. Don't log
// the exception because it's probably a side-effect of this.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
// If it's a DISCONNECTED exception, it's probably that the client disconnected, which is
// not really worth logging.
if (e.getType() != kj::Exception::Type::DISCONNECTED) {
......@@ -3461,23 +3449,17 @@ private:
HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false;
bool closed = false;
bool upgraded = false;
kj::Maybe<kj::Promise<void>> webSocketError;
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize) override {
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
......@@ -3509,9 +3491,31 @@ private:
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request");
auto& requestHeaders = httpInput.getHeaders();
KJ_REQUIRE(requestHeaders.isWebSocket(),
"can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr;
websocketKey = nullptr;
if (method != HttpMethod::GET) {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
kj::String key;
KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
key = kj::str(*k);
} else {
return sendWebSocketError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
upgraded = true;
auto websocketAccept = generateWebSocketAccept(key);
......@@ -3539,6 +3543,42 @@ private:
httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush
}
kj::Own<WebSocket> sendWebSocketError(
uint statusCode, kj::StringPtr statusText, kj::String errorMessage) {
kj::Exception exception = KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage);
webSocketError = sendError(statusCode, statusText, kj::mv(errorMessage));
kj::throwRecoverableException(kj::mv(exception));
// Fallback path when exceptions are disabled.
class BrokenWebSocket final: public WebSocket {
public:
BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return kj::cp(exception);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return kj::cp(exception);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return kj::cp(exception);
}
kj::Promise<void> disconnect() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override {
return kj::cp(exception);
}
private:
kj::Exception exception;
};
return kj::heap<BrokenWebSocket>(KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage));
}
};
HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service,
......
......@@ -260,6 +260,13 @@ public:
// Creates a shallow clone of the HttpHeaders. The returned object references the same strings
// as the original, owning none of them.
bool isWebSocket() const;
// Convenience method that checks for the presence of the header `Upgrade: websocket`.
//
// Note that this does not actually validate that the request is a complete WebSocket handshake
// with the correct version number -- such validation will occur if and when you call
// acceptWebSocket().
kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const;
// Read a header.
......@@ -413,7 +420,7 @@ public:
virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received.
// again after Close is received.
kj::Promise<void> pumpTo(WebSocket& other);
// Continuously receives messages from this WebSocket and send them to `other`.
......@@ -513,6 +520,9 @@ public:
// `statusText` and `headers` need only remain valid until send() returns (they can be
// stack-allocated).
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send().
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
const HttpHeaders& headers);
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
......@@ -536,21 +546,6 @@ public:
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first.
class WebSocketResponse: public Response {
public:
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response);
// Tries to open a WebSocket. Default implementation calls request() and never returns a
// WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
// UNIMPLEMENTED.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment