Unverified Commit b5425465 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #637 from capnproto/rewrite-websocket

Redesign server-side WebSocket handling.
parents 9032d060 632888d6
...@@ -1691,11 +1691,8 @@ public: ...@@ -1691,11 +1691,8 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
KJ_FAIL_ASSERT("can't get here"); KJ_ASSERT(headers.isWebSocket());
}
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
HttpHeaders responseHeaders(headerTable); HttpHeaders responseHeaders(headerTable);
KJ_IF_MAYBE(h, headers.get(hMyHeader)) { KJ_IF_MAYBE(h, headers.get(hMyHeader)) {
responseHeaders.set(hMyHeader, kj::str("respond-", *h)); responseHeaders.set(hMyHeader, kj::str("respond-", *h));
...@@ -2491,6 +2488,7 @@ public: ...@@ -2491,6 +2488,7 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
if (!headers.isWebSocket()) {
KJ_ASSERT(url != "/throw"); KJ_ASSERT(url != "/throw");
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
...@@ -2499,10 +2497,7 @@ public: ...@@ -2499,10 +2497,7 @@ public:
promises.add(stream->write(body.begin(), body.size())); promises.add(stream->write(body.begin(), body.size()));
promises.add(requestBody.readAllBytes().ignoreResult()); promises.add(requestBody.readAllBytes().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body)); return kj::joinPromises(promises.finish()).attach(kj::mv(stream), kj::mv(body));
} } else {
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
auto ws = response.acceptWebSocket(HttpHeaders(headerTable)); auto ws = response.acceptWebSocket(HttpHeaders(headerTable));
auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url); auto body = kj::str(headers.get(HttpHeaderId::HOST).orDefault("null"), ":", url);
auto sendPromise = ws->send(body); auto sendPromise = ws->send(body);
...@@ -2512,6 +2507,7 @@ public: ...@@ -2512,6 +2507,7 @@ public:
promises.add(ws->receive().ignoreResult()); promises.add(ws->receive().ignoreResult());
return kj::joinPromises(promises.finish()).attach(kj::mv(ws)); return kj::joinPromises(promises.finish()).attach(kj::mv(ws));
} }
}
private: private:
HttpHeaderTable& headerTable; HttpHeaderTable& headerTable;
......
...@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) { ...@@ -630,6 +630,19 @@ kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) {
return result; return result;
} }
namespace {
template <char... chars>
constexpr bool fastCaseCmp(const char* actual);
} // namespace
bool HttpHeaders::isWebSocket() const {
return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr());
}
void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) { void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) {
id.requireFrom(*table); id.requireFrom(*table);
requireValidHeaderValue(value); requireValidHeaderValue(value);
...@@ -3165,6 +3178,7 @@ public: ...@@ -3165,6 +3178,7 @@ public:
kj::Promise<void> request( kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override { kj::AsyncInputStream& requestBody, Response& response) override {
if (!headers.isWebSocket()) {
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength()); auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2); auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
...@@ -3181,10 +3195,7 @@ public: ...@@ -3181,10 +3195,7 @@ public:
})); }));
return kj::joinPromises(promises.finish()); return kj::joinPromises(promises.finish());
} } else {
kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) override {
return client.openWebSocket(url, headers) return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> { .then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) { KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
...@@ -3206,6 +3217,7 @@ public: ...@@ -3206,6 +3217,7 @@ public:
KJ_UNREACHABLE; KJ_UNREACHABLE;
}); });
} }
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override { kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
return client.connect(kj::mv(host)); return client.connect(kj::mv(host));
...@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError( ...@@ -3237,25 +3249,11 @@ kj::Promise<void> HttpService::Response::sendError(
return sendError(statusCode, statusText, HttpHeaders(headerTable)); return sendError(statusCode, statusText, HttpHeaders(headerTable));
} }
kj::Promise<void> HttpService::openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response) {
class EmptyStream final: public kj::AsyncInputStream {
public:
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return size_t(0);
}
};
auto requestBody = heap<EmptyStream>();
auto promise = request(HttpMethod::GET, url, headers, *requestBody, response);
return promise.attach(kj::mv(requestBody));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) { kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService"); KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
} }
class HttpServer::Connection final: private HttpService::WebSocketResponse { class HttpServer::Connection final: private HttpService::Response {
public: public:
Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream, Connection(HttpServer& server, kj::Own<kj::AsyncIoStream>&& stream,
HttpService& service) HttpService& service)
...@@ -3348,31 +3346,9 @@ public: ...@@ -3348,31 +3346,9 @@ public:
} }
KJ_IF_MAYBE(req, request) { KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders(); auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
} else {
currentMethod = req->method; currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody( auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, headers); HttpInputStream::REQUEST, req->method, 0, headers);
...@@ -3381,15 +3357,19 @@ public: ...@@ -3381,15 +3357,19 @@ public:
// be able to shutdown the upstream but still wait on the downstream, but I believe many // be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things. // other HTTP servers do similar things.
promise = service.request( auto promise = service.request(
req->method, req->url, headers, *body, *this); req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body)); return promise.attach(kj::mv(body))
}
return promise
.then([this]() -> kj::Promise<void> { .then([this]() -> kj::Promise<void> {
// Response done. Await next request. // Response done. Await next request.
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
if (upgraded) { if (upgraded) {
// We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer // We've upgraded to WebSocket so we can exit this listen loop. In fact, we no longer
// own the stream. // own the stream.
...@@ -3424,7 +3404,15 @@ public: ...@@ -3424,7 +3404,15 @@ public:
if (currentMethod == nullptr) { if (currentMethod == nullptr) {
// Dang, already sent a partial response. Can't do anything else. // Dang, already sent a partial response. Can't do anything else.
//
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection. Don't log
// the exception because it's probably a side-effect of this.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
// If it's a DISCONNECTED exception, it's probably that the client disconnected, which is // If it's a DISCONNECTED exception, it's probably that the client disconnected, which is
// not really worth logging. // not really worth logging.
if (e.getType() != kj::Exception::Type::DISCONNECTED) { if (e.getType() != kj::Exception::Type::DISCONNECTED) {
...@@ -3461,23 +3449,17 @@ private: ...@@ -3461,23 +3449,17 @@ private:
HttpOutputStream httpOutput; HttpOutputStream httpOutput;
kj::Own<kj::AsyncIoStream> ownStream; kj::Own<kj::AsyncIoStream> ownStream;
kj::Maybe<HttpMethod> currentMethod; kj::Maybe<HttpMethod> currentMethod;
kj::Maybe<kj::String> websocketKey;
bool timedOut = false; bool timedOut = false;
bool closed = false; bool closed = false;
bool upgraded = false; bool upgraded = false;
kj::Maybe<kj::Promise<void>> webSocketError;
kj::Own<kj::AsyncOutputStream> send( kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize) override { kj::Maybe<uint64_t> expectedBodySize) override {
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called startResponse()"); auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr; currentMethod = nullptr;
if (websocketKey != nullptr) {
// This was a WebSocket request but the upgrade wasn't accepted.
websocketKey = nullptr;
httpInput.finishRead();
}
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT]; kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr; kj::String lengthStr;
...@@ -3509,9 +3491,31 @@ private: ...@@ -3509,9 +3491,31 @@ private:
} }
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override { kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
auto key = KJ_REQUIRE_NONNULL(kj::mv(websocketKey), "not a WebSocket request"); auto& requestHeaders = httpInput.getHeaders();
KJ_REQUIRE(requestHeaders.isWebSocket(),
"can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
currentMethod = nullptr; currentMethod = nullptr;
websocketKey = nullptr;
if (method != HttpMethod::GET) {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendWebSocketError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
kj::String key;
KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
key = kj::str(*k);
} else {
return sendWebSocketError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
upgraded = true; upgraded = true;
auto websocketAccept = generateWebSocketAccept(key); auto websocketAccept = generateWebSocketAccept(key);
...@@ -3539,6 +3543,42 @@ private: ...@@ -3539,6 +3543,42 @@ private:
httpOutput.finishBody(); httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush return httpOutput.flush(); // loop ends after flush
} }
kj::Own<WebSocket> sendWebSocketError(
uint statusCode, kj::StringPtr statusText, kj::String errorMessage) {
kj::Exception exception = KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage);
webSocketError = sendError(statusCode, statusText, kj::mv(errorMessage));
kj::throwRecoverableException(kj::mv(exception));
// Fallback path when exceptions are disabled.
class BrokenWebSocket final: public WebSocket {
public:
BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
return kj::cp(exception);
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
return kj::cp(exception);
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
return kj::cp(exception);
}
kj::Promise<void> disconnect() override {
return kj::cp(exception);
}
kj::Promise<Message> receive() override {
return kj::cp(exception);
}
private:
kj::Exception exception;
};
return kj::heap<BrokenWebSocket>(KJ_EXCEPTION(FAILED,
"received bad WebSocket handshake", errorMessage));
}
}; };
HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service, HttpServer::HttpServer(kj::Timer& timer, HttpHeaderTable& requestHeaderTable, HttpService& service,
......
...@@ -260,6 +260,13 @@ public: ...@@ -260,6 +260,13 @@ public:
// Creates a shallow clone of the HttpHeaders. The returned object references the same strings // Creates a shallow clone of the HttpHeaders. The returned object references the same strings
// as the original, owning none of them. // as the original, owning none of them.
bool isWebSocket() const;
// Convenience method that checks for the presence of the header `Upgrade: websocket`.
//
// Note that this does not actually validate that the request is a complete WebSocket handshake
// with the correct version number -- such validation will occur if and when you call
// acceptWebSocket().
kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const; kj::Maybe<kj::StringPtr> get(HttpHeaderId id) const;
// Read a header. // Read a header.
...@@ -413,7 +420,7 @@ public: ...@@ -413,7 +420,7 @@ public:
virtual kj::Promise<Message> receive() = 0; virtual kj::Promise<Message> receive() = 0;
// Read one message from the WebSocket and return it. Can only call once at a time. Do not call // Read one message from the WebSocket and return it. Can only call once at a time. Do not call
// again after EndOfStream is received. // again after Close is received.
kj::Promise<void> pumpTo(WebSocket& other); kj::Promise<void> pumpTo(WebSocket& other);
// Continuously receives messages from this WebSocket and send them to `other`. // Continuously receives messages from this WebSocket and send them to `other`.
...@@ -513,6 +520,9 @@ public: ...@@ -513,6 +520,9 @@ public:
// `statusText` and `headers` need only remain valid until send() returns (they can be // `statusText` and `headers` need only remain valid until send() returns (they can be
// stack-allocated). // stack-allocated).
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// If headers.isWebSocket() is true then you can call acceptWebSocket() instead of send().
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
const HttpHeaders& headers); const HttpHeaders& headers);
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText,
...@@ -536,21 +546,6 @@ public: ...@@ -536,21 +546,6 @@ public:
// `url` and `headers` are invalidated on the first read from `requestBody` or when the returned // `url` and `headers` are invalidated on the first read from `requestBody` or when the returned
// promise resolves, whichever comes first. // promise resolves, whichever comes first.
class WebSocketResponse: public Response {
public:
virtual kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) = 0;
// Accept and open the WebSocket.
//
// `headers` need only remain valid until acceptWebSocket() returns (it can be stack-allocated).
};
virtual kj::Promise<void> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers, WebSocketResponse& response);
// Tries to open a WebSocket. Default implementation calls request() and never returns a
// WebSocket.
//
// `url` and `headers` are invalidated when the returned promise resolves.
virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host); virtual kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host);
// Handles CONNECT requests. Only relevant for proxy services. Default implementation throws // Handles CONNECT requests. Only relevant for proxy services. Default implementation throws
// UNIMPLEMENTED. // UNIMPLEMENTED.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment