Commit 632888d6 authored by Kenton Varda's avatar Kenton Varda

Redesign server-side WebSocket handling.

Previously HttpService had two virtual methods: request() and openWebSocket(). Since it's legitimate to respond to a WebSocket request with a normal HTTP response, openWebSocket() actually had a default implementation that fell back to request().

In the new design, there is only request(). The HttpService detects a WebSocket request by checking the headers. A convenience method, HttpHeaders::isWebSocket(), is provided for this purpose.

The new approach makes life much easier for services composed of many layers. For example, you might write an HttpService implementation which performs some URL or header rewrite and then calls on to another HttpService. Previously, every such wrapper would have to separately handle regular requests and WebSockets, usually with near-identical code. Of course, you could factor out the common code, but in practice this often turned out pretty clunky. Worse, developers would often just omit the openWebSocket() implementation since implementing only request() seems to work fine -- until you need a WebSocket, and everything is broken. With the new approach, you have to go somewhat out of your way to write a wrapper layer that breaks WebSockets.

I did not apply the same logic to HttpClient because:

1. It's not as easy: HttpClient's methods return results rather than calling a callback on completion, so unifying the methods would have forced request()'s signature to change. Lots of code would need to be updated, and would likely become uglier, as request() would now have to return a `webSocketOrBody` variant type even when the caller isn't asking for a WebSocket.
2. People don't implement custom HttpClients nearly as often as they implement custom HttpServices.
parent 9032d060
...@@ -1691,11 +1691,8 @@ public: ...@@ -1691,11 +1691,8 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here"); KJ_ASSERT(headers.isWebSocket());
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable); HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) { KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h)); responseHeaders.set(hMyHeader, kj::str("respond-", *h));
...@@ -2491,6 +2488,7 @@ public: ...@@ -2491,6 +2488,7 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
if (!headers.isWebSocket()) {
KJ_ASSERT(url != "/throw"); KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
...@@ -2499,10 +2497,7 @@ public: ...@@ -2499,10 +2497,7 @@ public:
promises.add(stream->write(body.begin(), body.size())); promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult()); promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
} } else {
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body); auto sendPromise = ws->send(body);
...@@ -2512,6 +2507,7 @@ public: ...@@ -2512,6 +2507,7 @@ public:
promises.add(ws->receive().ignoreResult()); promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws)); return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
} }
}
private: private:
HttpHeaderTable& headerTable; HttpHeaderTable& headerTable;
......
...@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) { ...@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) {
return result; return result;
} }
namespace {
template <char... chars>
constexpr bool fastCaseCmp(const char* actual);
} // namespace
bool HttpHeaders::isWebSocket() const {
return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr());
}
void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) { void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) {
id.requireFrom(*table); id.requireFrom(*table);
requireValidHeaderValue(value); requireValidHeaderValue(value);
...@@ -3165,6 +3178,7 @@ public: ...@@ -3165,6 +3178,7 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
if (!headers.isWebSocket()) {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength()); auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2); auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
...@@ -3181,10 +3195,7 @@ public: ...@@ -3181,10 +3195,7 @@ public:
})); }));
return kj::joinPromises(promises.finish()); return kj::joinPromises(promises.finish());
} } else {
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
return client.openWebSocket(url, headers) return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> { .then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) { KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
...@@ -3206,6 +3217,7 @@ public: ...@@ -3206,6 +3217,7 @@ public:
KJ_UNREACHABLE; KJ_UNREACHABLE;
}); });
} }
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override { kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
return client.connect(kj::mv(host)); return client.connect(kj::mv(host));
...@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError( ...@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError(
return sendError(statusCode, statusText, HttpHeaders(headerTable)); return sendError(statusCode, statusText, HttpHeaders(headerTable));
} }
kj::Promise<void> HttpService::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) {
class EmptyStream final: public kj::AsyncInputStream {
public:
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return size_t(0);
}
};
auto requestBody = heap<EmptyStream>();
auto promise = request(HttpMethod::GET, url, headers, *requestBody, response);
return promise.attach(kj::mv(requestBody));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) { kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService"); KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
} }
class HttpServer::Connection final: private HttpService::WebSocketResponse { class HttpServer::Connection final: private HttpService::Response {
public: public:
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream, Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream,
HttpService& service) HttpService& service)
...@@ -3348,31 +3346,9 @@ public: ...@@ -3348,31 +3346,9 @@ public:
} }
KJ_IF_MAYBE(req, request) { KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders(); auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
} else {
currentMethod = req->method; currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody( auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers); HttpInputStream::REQUEST, req->method, 0, headers);
...@@ -3381,15 +3357,19 @@ public: ...@@ -3381,15 +3357,19 @@ public:
// be able to shutdown the upstream but still wait on the downstream, but I believe many // be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things. // other HTTP servers do similar things.
promise = service.request( auto promise = service.request(
req->method, req->url, headers, *body, *this); req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body)); return promise.attach(kj::mv(body))
}
return promise
.then([this]() -> kj::Promise<void> { .then([this]() -> kj::Promise<void> {
// Response done. Await next request. // Response done. Await next request.
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
if (upgraded) { if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer // We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream. // own the stream.
...@@ -3424,7 +3404,15 @@ public: ...@@ -3424,7 +3404,15 @@ public:
if (currentMethod == nullptr) { if (currentMethod == nullptr) {
// Dang, already sent a partial response. Can't do anything else. // Dang, already sent a partial response. Can't do anything else.
//
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection. Don't log
// the exception because it's probably a side-effect of this.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
// If it's a DISCONNECTED exception, it's probably that the client disconnected, which is // If it's a DISCONNECTED exception, it's probably that the client disconnected, which is
// not really worth logging. // not really worth logging.
if (e.getType() != kj::Exception::Type::DISCONNECTED) { if (e.getType() != kj::Exception::Type::DISCONNECTED) {
...@@ -3461,23 +3449,17 @@ private: ...@@ -3461,23 +3449,17 @@ private:
HttpOutputStream httpOutput; HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream; kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod; kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false; bool timedOut = false;
bool closed = false; bool closed = false;
bool upgraded = false; bool upgraded = false;
kj::Maybe<kj::Promise<void>> webSocketError;
kj::Own<kj::AsyncOutputStream> send( kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize) override { kj::Maybe<uint64_t> expectedBodySize) override {
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()"); auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr; currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT]; kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr; kj::String lengthStr;
...@@ -3509,9 +3491,31 @@ private: ...@@ -3509,9 +3491,31 @@ private:
} }
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override { kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request"); auto& requestHeaders = httpInput.getHeaders();
KJ_REQUIRE(requestHeaders.isWebSocket(),
"can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr; currentMethod = nullptr;
websocketKey = nullptr;
if (method != HttpMethod::GET) {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
kj::String key;
KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
key = kj::str(*k);
} else {
return sendWebSocketError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
upgraded = true; upgraded = true;
auto websocketAccept = generateWebSocketAccept(key); auto websocketAccept = generateWebSocketAccept(key);
...@@ -3539,6 +3543,42 @@ private: ...@@ -3539,6 +3543,42 @@ private:
httpOutput.finishBody(); httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush return httpOutput.flush(); // loop ends after flush
} }
kj::Own<WebSocket> sendWebSocketError(
uint statusCode, kj::StringPtr statusText, kj::String errorMessage) {
kj::Exception exception = KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage);
webSocketError = sendError(statusCode, statusText, kj::mv(errorMessage));
kj::throwRecoverableException(kj::mv(exception));
// Fallback path when exceptions are disabled.
class BrokenWebSocket final: public WebSocket {
public:
BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return kj::cp(exception);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return kj::cp(exception);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return kj::cp(exception);
}
kj::Promise<void> disconnect() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override {
return kj::cp(exception);
}
private:
kj::Exception exception;
};
return kj::heap<BrokenWebSocket>(KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage));
}
}; };
HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service,
......
...@@ -260,6 +260,13 @@ public: ...@@ -260,6 +260,13 @@ public:
// Creates a shallow clone of the HttpHeaders. The returned object references the same strings // Creates a shallow clone of the HttpHeaders. The returned object references the same strings
// as the original, owning none of them. // as the original, owning none of them.
bool isWebSocket() const;
// Convenience method that checks for the presence of the header `Upgrade: websocket`.
//
// Note that this does not actually validate that the request is a complete WebSocket handshake
// with the correct version number -- such validation will occur if and when you call
// acceptWebSocket().
kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const; kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const;
// Read a header. // Read a header.
...@@ -413,7 +420,7 @@ public: ...@@ -413,7 +420,7 @@ public:
virtual kj::Promise<Message> receive() = 0; virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call // Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received. // again after Close is received.
kj::Promise<void> pumpTo(WebSocket& other); kj::Promise<void> pumpTo(WebSocket& other);
// Continuously receives messages from this WebSocket and send them to `other`. // Continuously receives messages from this WebSocket and send them to `other`.
...@@ -513,6 +520,9 @@ public: ...@@ -513,6 +520,9 @@ public:
// `statusText` and `headers` need only remain valid until send() returns (they can be // `statusText` and `headers` need only remain valid until send() returns (they can be
// stack-allocated). // stack-allocated).
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send().
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
const HttpHeaders& headers); const HttpHeaders& headers);
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
...@@ -536,21 +546,6 @@ public: ...@@ -536,21 +546,6 @@ public:
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned // `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first. // promise resolves, whichever comes first.
class WebSocketResponse: public Response {
public:
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response);
// Tries to open a WebSocket. Default implementation calls request() and never returns a
// WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host); virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
// UNIMPLEMENTED. // UNIMPLEMENTED.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment