Commit 3b08c76b authored by Kenton Varda's avatar Kenton Varda

Refactor handling of connection-level headers.

Although applications in theory shouldn't care to see connection-level headers (e.g. `Transfer-Encoding`), higher-level specs like the JavaScript Fetch API often specify that these headers should be visible, and they can be useful for debugging. So, this change makes it so that the application can see the connection-level headers on incoming requests.

For outgoing requests, the application can provide an HttpHeaders object that specifies these headers (important especially for the pass-through case), but the HTTP implementation will ignore them.

Additionally, we can now allow the application to set WebSocket connection-level headers on non-WebSocket requests. This is useful for frameworks that emulate WebSocket over HTTP and assume the ability to set WebSocket headers (especially `Sec-WebSocket-Extension`) on regular non-WebSocket HTTP requests.
parent cb49fa92
......@@ -119,20 +119,21 @@ KJ_TEST("HttpHeaders::parseRequest") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr);
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
});
KJ_EXPECT(unpackedHeaders.size() == 4);
KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeRequest(result.method, result.url, result.connectionHeaders) ==
KJ_EXPECT(headers.serializeRequest(result.method, result.url) ==
"POST /some/path HTTP/1.1\r\n"
"Content-Length: 123\r\n"
"Host: example.com\r\n"
......@@ -168,21 +169,22 @@ KJ_TEST("HttpHeaders::parseResponse") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr);
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
});
KJ_EXPECT(unpackedHeaders.size() == 4);
KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeResponse(
result.statusCode, result.statusText, result.connectionHeaders) ==
result.statusCode, result.statusText) ==
"HTTP/1.1 418 I'm a teapot\r\n"
"Content-Length: 123\r\n"
"Host: example.com\r\n"
......
......@@ -476,31 +476,25 @@ static const char* BUILTIN_HEADER_NAMES[] = {
#undef HEADER_NAME
};
enum class BuiltinHeaderIndices {
enum class BuiltinHeaderIndicesEnum {
#define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID
};
static constexpr size_t CONNECTION_HEADER_COUNT KJ_UNUSED = 0
#define COUNT_HEADER(id, name) + 1
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(COUNT_HEADER)
#undef COUNT_HEADER
;
enum class ConnectionHeaderIndices {
#define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HEADER_ID)
namespace BuiltinHeaderIndices {
#define HEADER_ID(id, name) constexpr uint id = static_cast<uint>(BuiltinHeaderIndicesEnum::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID
};
static constexpr uint CONNECTION_HEADER_XOR = kj::maxValue;
static constexpr uint CONNECTION_HEADER_THRESHOLD = CONNECTION_HEADER_XOR >> 1;
constexpr uint CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::SEC_WEBSOCKET_KEY;
constexpr uint WEBSOCKET_CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::HOST;
} // namespace
#define DEFINE_HEADER(id, name) \
const HttpHeaderId HttpHeaderId::id(nullptr, static_cast<uint>(BuiltinHeaderIndices::id));
const HttpHeaderId HttpHeaderId::id(nullptr, BuiltinHeaderIndices::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER)
#undef DEFINE_HEADER
......@@ -562,15 +556,9 @@ HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) {
HttpHeaderTable::HttpHeaderTable()
: idsByName(kj::heap<IdsByNameMap>()) {
#define ADD_HEADER(id, name) \
idsByName->map.insert(std::make_pair(name, \
static_cast<uint>(ConnectionHeaderIndices::id) ^ CONNECTION_HEADER_XOR));
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(ADD_HEADER);
#undef ADD_HEADER
#define ADD_HEADER(id, name) \
namesById.add(name); \
idsByName->map.insert(std::make_pair(name, static_cast<uint>(BuiltinHeaderIndices::id)));
idsByName->map.insert(std::make_pair(name, BuiltinHeaderIndices::id));
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER);
#undef ADD_HEADER
}
......@@ -657,8 +645,7 @@ void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) {
requireValidHeaderName(name);
requireValidHeaderValue(value);
KJ_REQUIRE(addNoCheck(name, value) == nullptr,
"can't set connection-level headers on HttpHeaders", name, value) { break; }
addNoCheck(name, value);
}
void HttpHeaders::add(kj::StringPtr name, kj::String&& value) {
......@@ -672,12 +659,8 @@ void HttpHeaders::add(kj::String&& name, kj::String&& value) {
takeOwnership(kj::mv(value));
}
kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
KJ_IF_MAYBE(id, table->stringToId(name)) {
if (id->id > CONNECTION_HEADER_THRESHOLD) {
return id->id ^ CONNECTION_HEADER_XOR;
}
if (indexedHeaders[id->id] == nullptr) {
indexedHeaders[id->id] = value;
} else {
......@@ -689,8 +672,6 @@ kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value)
} else {
unindexedHeaders.add(Header {name, value});
}
return nullptr;
}
void HttpHeaders::takeOwnership(kj::String&& string) {
......@@ -887,7 +868,7 @@ kj::Maybe<HttpHeaders::Request> HttpHeaders::tryParseRequest(kj::ArrayPtr<char>
// Ignore rest of line. Don't care about "HTTP/1.1" or whatever.
consumeLine(ptr);
if (!parseHeaders(ptr, end, request.connectionHeaders)) return nullptr;
if (!parseHeaders(ptr, end)) return nullptr;
return request;
}
......@@ -914,28 +895,16 @@ kj::Maybe<HttpHeaders::Response> HttpHeaders::tryParseResponse(kj::ArrayPtr<char
response.statusText = consumeLine(ptr);
if (!parseHeaders(ptr, end, response.connectionHeaders)) return nullptr;
if (!parseHeaders(ptr, end)) return nullptr;
return response;
}
bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders) {
bool HttpHeaders::parseHeaders(char* ptr, char* end) {
while (*ptr != '\0') {
KJ_IF_MAYBE(name, consumeHeaderName(ptr)) {
kj::StringPtr line = consumeLine(ptr);
KJ_IF_MAYBE(connectionHeaderId, addNoCheck(*name, line)) {
// Parsed a connection header.
switch (*connectionHeaderId) {
#define HANDLE_HEADER(id, name) \
case static_cast<uint>(ConnectionHeaderIndices::id): \
connectionHeaders.id = line; \
break;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
default:
KJ_UNREACHABLE;
}
}
addNoCheck(*name, line);
} else {
return false;
}
......@@ -946,13 +915,15 @@ bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connecti
// -----------------------------------------------------------------------------
kj::String HttpHeaders::serializeRequest(HttpMethod method, kj::StringPtr url,
const ConnectionHeaders& connectionHeaders) const {
kj::String HttpHeaders::serializeRequest(
HttpMethod method, kj::StringPtr url,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders);
}
kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusText,
const ConnectionHeaders& connectionHeaders) const {
kj::String HttpHeaders::serializeResponse(
uint statusCode, kj::StringPtr statusText,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
auto statusCodeStr = kj::toCharSequence(statusCode);
return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders);
......@@ -961,7 +932,7 @@ kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusT
kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const {
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
const kj::StringPtr space = " ";
const kj::StringPtr newline = "\r\n";
const kj::StringPtr colon = ": ";
......@@ -970,15 +941,11 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) {
size += word1.size() + word2.size() + word3.size() + 4;
}
#define HANDLE_HEADER(id, name) \
if (connectionHeaders.id != nullptr) { \
size += connectionHeaders.id.size() + (sizeof(name) + 3); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size());
for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) {
size += table->idToString(HttpHeaderId(table, i)).size() + indexedHeaders[i].size() + 4;
kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
if (value != nullptr) {
size += table->idToString(HttpHeaderId(table, i)).size() + value.size() + 4;
}
}
for (auto& header: unindexedHeaders) {
......@@ -991,16 +958,10 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) {
ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline);
}
#define HANDLE_HEADER(id, name) \
if (connectionHeaders.id != nullptr) { \
ptr = kj::_::fill(ptr, kj::StringPtr(name), colon, connectionHeaders.id, newline); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) {
ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon,
indexedHeaders[i], newline);
kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
if (value != nullptr) {
ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, value, newline);
}
}
for (auto& header: unindexedHeaders) {
......@@ -1013,7 +974,7 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
}
kj::String HttpHeaders::toString() const {
return serialize(nullptr, nullptr, nullptr, ConnectionHeaders());
return serialize(nullptr, nullptr, nullptr, nullptr);
}
// =======================================================================================
......@@ -1196,7 +1157,7 @@ public:
kj::Own<kj::AsyncInputStream> getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders);
const kj::HttpHeaders& headers);
struct ReleasedBuffer {
kj::Array<byte> buffer;
......@@ -1570,13 +1531,13 @@ static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), "");
kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders) {
const kj::HttpHeaders& headers) {
if (type == RESPONSE) {
if (method == HttpMethod::HEAD) {
// Body elided.
kj::Maybe<uint64_t> length;
if (connectionHeaders.contentLength != nullptr) {
length = strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10);
KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
length = strtoull(cl->cStr(), nullptr, 10);
}
return kj::heap<HttpNullEntityReader>(*this, length);
} else if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
......@@ -1585,19 +1546,18 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
}
}
if (connectionHeaders.transferEncoding != nullptr) {
KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) {
// TODO(someday): Support plugable transfer encodings? Or at least gzip?
// TODO(soon): Support stacked transfer encodings, e.g. "gzip, chunked".
if (fastCaseCmp<'c','h','u','n','k','e','d'>(connectionHeaders.transferEncoding.cStr())) {
if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) {
return kj::heap<HttpChunkedEntityReader>(*this);
} else {
KJ_FAIL_REQUIRE("unknown transfer encoding") { break; }
}
}
if (connectionHeaders.contentLength != nullptr) {
return kj::heap<HttpFixedLengthEntityReader>(*this,
strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10));
KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
return kj::heap<HttpFixedLengthEntityReader>(*this, strtoull(cl->cStr(), nullptr, 10));
}
if (type == REQUEST) {
......@@ -1605,10 +1565,10 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
return kj::heap<HttpNullEntityReader>(*this, uint64_t(0));
}
if (connectionHeaders.connection != nullptr) {
KJ_IF_MAYBE(c, headers.get(HttpHeaderId::CONNECTION)) {
// TODO(soon): Connection header can actually have multiple tokens... but no one ever uses
// that feature?
if (fastCaseCmp<'c','l','o','s','e'>(connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c','l','o','s','e'>(c->cStr())) {
return kj::heap<HttpConnectionCloseEntityReader>(*this);
}
}
......@@ -2463,16 +2423,16 @@ public:
"this HttpClient's connection has been closed by the server or due to an error");
closeWatcherTask = nullptr;
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
if (method == HttpMethod::GET || method == HttpMethod::HEAD) {
// No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr;
connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else {
connectionHeaders.transferEncoding = "chunked";
connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
}
httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders));
......@@ -2491,14 +2451,16 @@ public:
auto responsePromise = httpInput.readResponseHeaders()
.then([this,method](kj::Maybe<HttpHeaders::Response>&& response) -> HttpClient::Response {
KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
HttpClient::Response result {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode,
r->connectionHeaders)
&headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode, headers)
};
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true;
} else {
watchForClose();
......@@ -2533,11 +2495,11 @@ public:
"HttpClient").generate(keyBytes);
auto keyBase64 = kj::encodeBase64(keyBytes);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.connection = "Upgrade";
connectionHeaders.upgrade = "websocket";
connectionHeaders.websocketVersion = "13";
connectionHeaders.websocketKey = keyBase64;
kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_VERSION] = "13";
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_KEY] = keyBase64;
httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));
......@@ -2549,18 +2511,23 @@ public:
[this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response)
-> HttpClient::WebSocketResponse {
KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
if (r->statusCode == 101) {
if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
r->connectionHeaders.upgrade.cStr())) {
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'",
r->connectionHeaders.upgrade) { break; }
headers.get(HttpHeaderId::UPGRADE).orDefault("(null)")) {
break;
}
return HttpClient::WebSocketResponse();
}
auto expectedAccept = generateWebSocketAccept(keyBase64);
if (r->connectionHeaders.websocketAccept != expectedAccept) {
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr)
!= expectedAccept) {
KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header",
r->connectionHeaders.websocketAccept, expectedAccept) { break; }
headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault("(null)"),
expectedAccept) { break; }
return HttpClient::WebSocketResponse();
}
......@@ -2575,11 +2542,12 @@ public:
HttpClient::WebSocketResponse result {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
&headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode,
r->connectionHeaders)
headers)
};
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true;
} else {
watchForClose();
......@@ -3367,31 +3335,32 @@ public:
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
req->connectionHeaders.upgrade.cStr())) {
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (req->connectionHeaders.websocketVersion != "13") {
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
if (req->connectionHeaders.websocketKey == nullptr) {
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
currentMethod = HttpMethod::GET;
websocketKey = kj::str(req->connectionHeaders.websocketKey);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders);
HttpInputStream::REQUEST, req->method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
......@@ -3399,7 +3368,7 @@ public:
// other HTTP servers do similar things.
promise = server.service.request(
req->method, req->url, httpInput.getHeaders(), *body, *this);
req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body));
}
......@@ -3494,16 +3463,16 @@ private:
httpInput.finishRead();
}
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
// No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr;
connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else {
connectionHeaders.transferEncoding = "chunked";
connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
}
httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText, connectionHeaders));
......@@ -3532,10 +3501,10 @@ private:
auto websocketAccept = generateWebSocketAccept(key);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.websocketAccept = websocketAccept;
connectionHeaders.upgrade = "websocket";
connectionHeaders.connection = "Upgrade";
kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept;
connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
......@@ -3544,16 +3513,13 @@ private:
}
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
auto bodySize = kj::str(body.size());
HttpHeaders failed(server.requestHeaderTable);
HttpHeaders::ConnectionHeaders connHeaders;
connHeaders.connection = "close";
connHeaders.contentLength = bodySize;
failed.set(HttpHeaderId::CONNECTION, "close");
failed.set(HttpHeaderId::CONTENT_LENGTH, kj::str(body.size()));
failed.set(HttpHeaderId::CONTENT_TYPE, "text/plain");
httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText, connHeaders));
httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText));
httpOutput.writeBodyData(kj::mv(body));
httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush
......
......@@ -78,19 +78,6 @@ namespace kj {
MACRO(UNSUBSCRIBE)
/* UPnP */
#define KJ_HTTP_FOR_EACH_CONNECTION_HEADER(MACRO) \
MACRO(connection, "Connection") \
MACRO(contentLength, "Content-Length") \
MACRO(keepAlive, "Keep-Alive") \
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept") \
MACRO(websocketExtensions, "Sec-WebSocket-Extensions")
enum class HttpMethod {
// Enum of known HTTP methods.
//
......@@ -138,12 +125,27 @@ public:
// In opt mode, no-op.
#define KJ_HTTP_FOR_EACH_BUILTIN_HEADER(MACRO) \
/* Headers that are always read-only. */ \
MACRO(CONNECTION, "Connection") \
MACRO(CONTENT_LENGTH, "Content-Length") \
MACRO(KEEP_ALIVE, "Keep-Alive") \
MACRO(TE, "TE") \
MACRO(TRAILER, "Trailer") \
MACRO(TRANSFER_ENCODING, "Transfer-Encoding") \
MACRO(UPGRADE, "Upgrade") \
\
/* Headers that are read-only for WebSocket handshakes. */ \
MACRO(SEC_WEBSOCKET_KEY, "Sec-WebSocket-Key") \
MACRO(SEC_WEBSOCKET_VERSION, "Sec-WebSocket-Version") \
MACRO(SEC_WEBSOCKET_ACCEPT, "Sec-WebSocket-Accept") \
MACRO(SEC_WEBSOCKET_EXTENSIONS, "Sec-WebSocket-Extensions") \
\
/* Headers that you can write. */ \
MACRO(HOST, "Host") \
MACRO(DATE, "Date") \
MACRO(LOCATION, "Location") \
MACRO(CONTENT_TYPE, "Content-Type")
// For convenience, these very-common headers are valid for all HttpHeaderTables. You can refer
// to them like:
// For convenience, these headers are valid for all HttpHeaderTables. You can refer to them like:
//
// HttpHeaderId::HOST
//
......@@ -300,26 +302,13 @@ public:
// Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful
// when you've passed a dynamic value to set() or add() or parse*().
struct ConnectionHeaders {
// These headers govern details of the specific HTTP connection or framing of the content.
// Hence, they are managed internally within the HTTP library, and never appear in an
// HttpHeaders structure.
#define DECLARE_HEADER(id, name) \
kj::StringPtr id;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(DECLARE_HEADER)
#undef DECLARE_HEADER
};
struct Request {
HttpMethod method;
kj::StringPtr url;
ConnectionHeaders connectionHeaders;
};
struct Response {
uint statusCode;
kj::StringPtr statusText;
ConnectionHeaders connectionHeaders;
};
kj::Maybe<Request> tryParseRequest(kj::ArrayPtr<char> content);
......@@ -334,11 +323,15 @@ public:
// `HttpHeaders` is destroyed, or pass it to `takeOwnership()`.
kj::String serializeRequest(HttpMethod method, kj::StringPtr url,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
kj::String serializeResponse(uint statusCode, kj::StringPtr statusText,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
// Serialize the headers as a complete request or response blob. The blob uses '\r\n' newlines
// and includes the double-newline to indicate the end of the headers.
//
// `connectionHeaders`, if provided, contains connection-level headers supplied by the HTTP
// implementation, in the order specified by the KJ_HTTP_FOR_EACH_BUILTIN_HEADER macro. These
// headers values override any corresponding header value in the HttpHeaders object.
kj::String toString() const;
......@@ -356,16 +349,16 @@ private:
kj::Vector<kj::Array<char>> ownedStrings;
kj::Maybe<uint> addNoCheck(kj::StringPtr name, kj::StringPtr value);
void addNoCheck(kj::StringPtr name, kj::StringPtr value);
kj::StringPtr cloneToOwn(kj::StringPtr str);
kj::String serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const;
bool parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders);
bool parseHeaders(char* ptr, char* end);
// TODO(perf): Arguably we should store a map, but header sets are never very long
// TODO(perf): We could optimize for common headers by storing them directly as fields. We could
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment