Commit 632888d6 authored by Kenton Varda's avatar Kenton Varda

Redesign server-side WebSocket handling.

Previously HttpService had two virtual methods: request() and openWebSocket(). Since it's legitimate to respond to a WebSocket request with a normal HTTP response, openWebSocket() actually had a default implementation that fell back to request().

In the new design, there is only request(). The HttpService detects a WebSocket request by checking the headers. A convenience method, HttpHeaders::isWebSocket(), is provided for this purpose.

The new approach makes life much easier for services composed of many layers. For example, you might write an HttpService implementation which performs some URL or header rewrite and then calls on to another HttpService. Previously, every such wrapper would have to separately handle regular requests and WebSockets, usually with near-identical code. Of course, you could factor out the common code, but in practice this often turned out pretty clunky. Worse, developers would often just omit the openWebSocket() implementation since implementing only request() seems to work fine -- until you need a WebSocket, and everything is broken. With the new approach, you have to go somewhat out of your way to write a wrapper layer that breaks WebSockets.

I did not apply the same logic to HttpClient because:

1. It's not as easy: HttpClient's methods return results rather than calling a callback on completion, so unifying the methods would have forced request()'s signature to change. Lots of code would need to be updated, and would likely become uglier, as request() would now have to return a `webSocketOrBody` variant type even when the caller isn't asking for a WebSocket.
2. People don't implement custom HttpClients nearly as often as they implement custom HttpServices.
parent 9032d060
...@@ -1691,11 +1691,8 @@ public: ...@@ -1691,11 +1691,8 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here"); KJ_ASSERT(headers.isWebSocket());
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable); HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) { KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h)); responseHeaders.set(hMyHeader, kj::str("respond-", *h));
...@@ -2491,26 +2488,25 @@ public: ...@@ -2491,26 +2488,25 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
KJ_ASSERT(url != "/throw"); if (!headers.isWebSocket()) {
KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size()); auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2); auto stream = response.send(200, "OK", HttpHeaders(headerTable), body.size());
promises.add(stream->write(body.begin(), body.size())); auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.readAllBytes().ignoreResult()); promises.add(stream->write(body.begin(), body.size()));
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); promises.add(requestBody.readAllBytes().ignoreResult());
} return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
} else {
kj::Promise<void> openWebSocket( auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override { auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); auto sendPromise = ws->send(body);
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body); auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(sendPromise.attach(kj::mv(body)));
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2); promises.add(ws->receive().ignoreResult());
promises.add(sendPromise.attach(kj::mv(body))); return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
promises.add(ws->receive().ignoreResult()); }
return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
} }
private: private:
......
This diff is collapsed.
...@@ -260,6 +260,13 @@ public: ...@@ -260,6 +260,13 @@ public:
// Creates a shallow clone of the HttpHeaders. The returned object references the same strings // Creates a shallow clone of the HttpHeaders. The returned object references the same strings
// as the original, owning none of them. // as the original, owning none of them.
bool isWebSocket() const;
// Convenience method that checks for the presence of the header `Upgrade: websocket`.
//
// Note that this does not actually validate that the request is a complete WebSocket handshake
// with the correct version number -- such validation will occur if and when you call
// acceptWebSocket().
kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const; kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const;
// Read a header. // Read a header.
...@@ -413,7 +420,7 @@ public: ...@@ -413,7 +420,7 @@ public:
virtual kj::Promise<Message> receive() = 0; virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call // Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received. // again after Close is received.
kj::Promise<void> pumpTo(WebSocket& other); kj::Promise<void> pumpTo(WebSocket& other);
// Continuously receives messages from this WebSocket and send them to `other`. // Continuously receives messages from this WebSocket and send them to `other`.
...@@ -513,6 +520,9 @@ public: ...@@ -513,6 +520,9 @@ public:
// `statusText` and `headers` need only remain valid until send() returns (they can be // `statusText` and `headers` need only remain valid until send() returns (they can be
// stack-allocated). // stack-allocated).
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send().
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
const HttpHeaders& headers); const HttpHeaders& headers);
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
...@@ -536,21 +546,6 @@ public: ...@@ -536,21 +546,6 @@ public:
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned // `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first. // promise resolves, whichever comes first.
class WebSocketResponse: public Response {
public:
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response);
// Tries to open a WebSocket. Default implementation calls request() and never returns a
// WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host); virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
// UNIMPLEMENTED. // UNIMPLEMENTED.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment