Commit 632888d6 authored by Kenton Varda's avatar Kenton Varda

Redesign server-side WebSocket handling.

Previously HttpService had two virtual methods: request() and openWebSocket(). Since it's legitimate to respond to a WebSocket request with a normal HTTP response, openWebSocket() actually had a default implementation that fell back to request().

In the new design, there is only request(). The HttpService detects a WebSocket request by checking the headers. A convenience method, HttpHeaders::isWebSocket(), is provided for this purpose.

The new approach makes life much easier for services composed of many layers. For example, you might write an HttpService implementation which performs some URL or header rewrite and then calls on to another HttpService. Previously, every such wrapper would have to separately handle regular requests and WebSockets, usually with near-identical code. Of course, you could factor out the common code, but in practice this often turned out pretty clunky. Worse, developers would often just omit the openWebSocket() implementation since implementing only request() seems to work fine -- until you need a WebSocket, and everything is broken. With the new approach, you have to go somewhat out of your way to write a wrapper layer that breaks WebSockets.

I did not apply the same logic to HttpClient because:

1. It's not as easy: HttpClient's methods return results rather than calling a callback on completion, so unifying the methods would have forced request()'s signature to change. Lots of code would need to be updated, and would likely become uglier, as request() would now have to return a `webSocketOrBody` variant type even when the caller isn't asking for a WebSocket.
2. People don't implement custom HttpClients nearly as often as they implement custom HttpServices.
parent 9032d060
......@@ -1691,11 +1691,8 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here");
}
KJ_ASSERT(headers.isWebSocket());
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h));
......@@ -2491,26 +2488,25 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(sendPromise.attach(kj::mv(body)));
promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
if (!headers.isWebSocket()) {
KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
} else {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(sendPromise.attach(kj::mv(body)));
promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
}
}
private:
......
......@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) {
return result;
}
namespace {
template <char... chars>
constexpr bool fastCaseCmp(const char* actual);
} // namespace
bool HttpHeaders::isWebSocket() const {
return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr());
}
void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) {
id.requireFrom(*table);
requireValidHeaderValue(value);
......@@ -3165,46 +3178,45 @@ public:
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
.attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));
promises.add(innerReq.response
.then([&response](HttpClient::Response&& innerResponse) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
innerResponse.body->tryGetLength());
auto promise = innerResponse.body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
}));
return kj::joinPromises(promises.finish());
}
if (!headers.isWebSocket()) {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
.attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));
promises.add(innerReq.response
.then([&response](HttpClient::Response&& innerResponse) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
innerResponse.body->tryGetLength());
auto promise = innerResponse.body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
}));
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(ws->pumpTo(*ws2));
promises.add(ws2->pumpTo(*ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
return kj::joinPromises(promises.finish());
} else {
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(ws->pumpTo(*ws2));
promises.add(ws2->pumpTo(*ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
}
}
}
KJ_UNREACHABLE;
});
KJ_UNREACHABLE;
});
}
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
......@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError(
return sendError(statusCode, statusText, HttpHeaders(headerTable));
}
kj::Promise<void> HttpService::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) {
class EmptyStream final: public kj::AsyncInputStream {
public:
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return size_t(0);
}
};
auto requestBody = heap<EmptyStream>();
auto promise = request(HttpMethod::GET, url, headers, *requestBody, response);
return promise.attach(kj::mv(requestBody));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}
class HttpServer::Connection final: private HttpService::WebSocketResponse {
class HttpServer::Connection final: private HttpService::Response {
public:
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream,
HttpService& service)
......@@ -3348,48 +3346,30 @@ public:
}
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
currentMethod = req->method;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers);
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
promise = service.request(
req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body));
}
return promise
auto promise = service.request(
req->method, req->url, headers, *body, *this);
return promise.attach(kj::mv(body))
.then([this]() -> kj::Promise<void> {
// Response done. Await next request.
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream.
......@@ -3424,7 +3404,15 @@ public:
if (currentMethod == nullptr) {
// Dang, already sent a partial response. Can't do anything else.
//
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection. Don't log
// the exception because it's probably a side-effect of this.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
// If it's a DISCONNECTED exception, it's probably that the client disconnected, which is
// not really worth logging.
if (e.getType() != kj::Exception::Type::DISCONNECTED) {
......@@ -3461,23 +3449,17 @@ private:
HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false;
bool closed = false;
bool upgraded = false;
kj::Maybe<kj::Promise<void>> webSocketError;
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize) override {
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
......@@ -3509,9 +3491,31 @@ private:
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request");
auto& requestHeaders = httpInput.getHeaders();
KJ_REQUIRE(requestHeaders.isWebSocket(),
"can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr;
websocketKey = nullptr;
if (method != HttpMethod::GET) {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
kj::String key;
KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
key = kj::str(*k);
} else {
return sendWebSocketError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
upgraded = true;
auto websocketAccept = generateWebSocketAccept(key);
......@@ -3539,6 +3543,42 @@ private:
httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush
}
kj::Own<WebSocket> sendWebSocketError(
uint statusCode, kj::StringPtr statusText, kj::String errorMessage) {
kj::Exception exception = KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage);
webSocketError = sendError(statusCode, statusText, kj::mv(errorMessage));
kj::throwRecoverableException(kj::mv(exception));
// Fallback path when exceptions are disabled.
class BrokenWebSocket final: public WebSocket {
public:
BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return kj::cp(exception);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return kj::cp(exception);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return kj::cp(exception);
}
kj::Promise<void> disconnect() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override {
return kj::cp(exception);
}
private:
kj::Exception exception;
};
return kj::heap<BrokenWebSocket>(KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage));
}
};
HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service,
......
......@@ -260,6 +260,13 @@ public:
// Creates a shallow clone of the HttpHeaders. The returned object references the same strings
// as the original, owning none of them.
bool isWebSocket() const;
// Convenience method that checks for the presence of the header `Upgrade: websocket`.
//
// Note that this does not actually validate that the request is a complete WebSocket handshake
// with the correct version number -- such validation will occur if and when you call
// acceptWebSocket().
kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const;
// Read a header.
......@@ -413,7 +420,7 @@ public:
virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received.
// again after Close is received.
kj::Promise<void> pumpTo(WebSocket& other);
// Continuously receives messages from this WebSocket and send them to `other`.
......@@ -513,6 +520,9 @@ public:
// `statusText` and `headers` need only remain valid until send() returns (they can be
// stack-allocated).
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send().
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
const HttpHeaders& headers);
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
......@@ -536,21 +546,6 @@ public:
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first.
class WebSocketResponse: public Response {
public:
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response);
// Tries to open a WebSocket. Default implementation calls request() and never returns a
// WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
// UNIMPLEMENTED.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment