Unverified Commit 6052cb94 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #618 from capnproto/connection-headers-handling

 Refactor handling of connection-level headers.
parents cb49fa92 d42d3b40
......@@ -119,20 +119,21 @@ KJ_TEST("HttpHeaders::parseRequest") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr);
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
});
KJ_EXPECT(unpackedHeaders.size() == 4);
KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeRequest(result.method, result.url, result.connectionHeaders) ==
KJ_EXPECT(headers.serializeRequest(result.method, result.url) ==
"POST /some/path HTTP/1.1\r\n"
"Content-Length: 123\r\n"
"Host: example.com\r\n"
......@@ -168,21 +169,22 @@ KJ_TEST("HttpHeaders::parseResponse") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr);
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
});
KJ_EXPECT(unpackedHeaders.size() == 4);
KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeResponse(
result.statusCode, result.statusText, result.connectionHeaders) ==
result.statusCode, result.statusText) ==
"HTTP/1.1 418 I'm a teapot\r\n"
"Content-Length: 123\r\n"
"Host: example.com\r\n"
......
......@@ -476,31 +476,25 @@ static const char* BUILTIN_HEADER_NAMES[] = {
#undef HEADER_NAME
};
enum class BuiltinHeaderIndices {
enum class BuiltinHeaderIndicesEnum {
#define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID
};
static constexpr size_t CONNECTION_HEADER_COUNT KJ_UNUSED = 0
#define COUNT_HEADER(id, name) + 1
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(COUNT_HEADER)
#undef COUNT_HEADER
;
enum class ConnectionHeaderIndices {
#define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HEADER_ID)
namespace BuiltinHeaderIndices {
#define HEADER_ID(id, name) constexpr uint id = static_cast<uint>(BuiltinHeaderIndicesEnum::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID
};
static constexpr uint CONNECTION_HEADER_XOR = kj::maxValue;
static constexpr uint CONNECTION_HEADER_THRESHOLD = CONNECTION_HEADER_XOR >> 1;
constexpr uint CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::SEC_WEBSOCKET_KEY;
constexpr uint WEBSOCKET_CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::HOST;
} // namespace
#define DEFINE_HEADER(id, name) \
const HttpHeaderId HttpHeaderId::id(nullptr, static_cast<uint>(BuiltinHeaderIndices::id));
const HttpHeaderId HttpHeaderId::id(nullptr, BuiltinHeaderIndices::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER)
#undef DEFINE_HEADER
......@@ -562,15 +556,9 @@ HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) {
HttpHeaderTable::HttpHeaderTable()
: idsByName(kj::heap<IdsByNameMap>()) {
#define ADD_HEADER(id, name) \
idsByName->map.insert(std::make_pair(name, \
static_cast<uint>(ConnectionHeaderIndices::id) ^ CONNECTION_HEADER_XOR));
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(ADD_HEADER);
#undef ADD_HEADER
#define ADD_HEADER(id, name) \
namesById.add(name); \
idsByName->map.insert(std::make_pair(name, static_cast<uint>(BuiltinHeaderIndices::id)));
idsByName->map.insert(std::make_pair(name, BuiltinHeaderIndices::id));
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER);
#undef ADD_HEADER
}
......@@ -657,8 +645,7 @@ void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) {
requireValidHeaderName(name);
requireValidHeaderValue(value);
KJ_REQUIRE(addNoCheck(name, value) == nullptr,
"can't set connection-level headers on HttpHeaders", name, value) { break; }
addNoCheck(name, value);
}
void HttpHeaders::add(kj::StringPtr name, kj::String&& value) {
......@@ -672,12 +659,8 @@ void HttpHeaders::add(kj::String&& name, kj::String&& value) {
takeOwnership(kj::mv(value));
}
kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
KJ_IF_MAYBE(id, table->stringToId(name)) {
if (id->id > CONNECTION_HEADER_THRESHOLD) {
return id->id ^ CONNECTION_HEADER_XOR;
}
if (indexedHeaders[id->id] == nullptr) {
indexedHeaders[id->id] = value;
} else {
......@@ -689,8 +672,6 @@ kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value)
} else {
unindexedHeaders.add(Header {name, value});
}
return nullptr;
}
void HttpHeaders::takeOwnership(kj::String&& string) {
......@@ -887,7 +868,7 @@ kj::Maybe<HttpHeaders::Request> HttpHeaders::tryParseRequest(kj::ArrayPtr<char>
// Ignore rest of line. Don't care about "HTTP/1.1" or whatever.
consumeLine(ptr);
if (!parseHeaders(ptr, end, request.connectionHeaders)) return nullptr;
if (!parseHeaders(ptr, end)) return nullptr;
return request;
}
......@@ -914,28 +895,16 @@ kj::Maybe<HttpHeaders::Response> HttpHeaders::tryParseResponse(kj::ArrayPtr<char
response.statusText = consumeLine(ptr);
if (!parseHeaders(ptr, end, response.connectionHeaders)) return nullptr;
if (!parseHeaders(ptr, end)) return nullptr;
return response;
}
bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders) {
bool HttpHeaders::parseHeaders(char* ptr, char* end) {
while (*ptr != '\0') {
KJ_IF_MAYBE(name, consumeHeaderName(ptr)) {
kj::StringPtr line = consumeLine(ptr);
KJ_IF_MAYBE(connectionHeaderId, addNoCheck(*name, line)) {
// Parsed a connection header.
switch (*connectionHeaderId) {
#define HANDLE_HEADER(id, name) \
case static_cast<uint>(ConnectionHeaderIndices::id): \
connectionHeaders.id = line; \
break;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
default:
KJ_UNREACHABLE;
}
}
addNoCheck(*name, line);
} else {
return false;
}
......@@ -946,13 +915,15 @@ bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connecti
// -----------------------------------------------------------------------------
kj::String HttpHeaders::serializeRequest(HttpMethod method, kj::StringPtr url,
const ConnectionHeaders& connectionHeaders) const {
kj::String HttpHeaders::serializeRequest(
HttpMethod method, kj::StringPtr url,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders);
}
kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusText,
const ConnectionHeaders& connectionHeaders) const {
kj::String HttpHeaders::serializeResponse(
uint statusCode, kj::StringPtr statusText,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
auto statusCodeStr = kj::toCharSequence(statusCode);
return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders);
......@@ -961,7 +932,7 @@ kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusT
kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const {
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
const kj::StringPtr space = " ";
const kj::StringPtr newline = "\r\n";
const kj::StringPtr colon = ": ";
......@@ -970,15 +941,11 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) {
size += word1.size() + word2.size() + word3.size() + 4;
}
#define HANDLE_HEADER(id, name) \
if (connectionHeaders.id != nullptr) { \
size += connectionHeaders.id.size() + (sizeof(name) + 3); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size());
for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) {
size += table->idToString(HttpHeaderId(table, i)).size() + indexedHeaders[i].size() + 4;
kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
if (value != nullptr) {
size += table->idToString(HttpHeaderId(table, i)).size() + value.size() + 4;
}
}
for (auto& header: unindexedHeaders) {
......@@ -991,16 +958,10 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) {
ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline);
}
#define HANDLE_HEADER(id, name) \
if (connectionHeaders.id != nullptr) { \
ptr = kj::_::fill(ptr, kj::StringPtr(name), colon, connectionHeaders.id, newline); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) {
ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon,
indexedHeaders[i], newline);
kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
if (value != nullptr) {
ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, value, newline);
}
}
for (auto& header: unindexedHeaders) {
......@@ -1013,7 +974,7 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
}
kj::String HttpHeaders::toString() const {
return serialize(nullptr, nullptr, nullptr, ConnectionHeaders());
return serialize(nullptr, nullptr, nullptr, nullptr);
}
// =======================================================================================
......@@ -1144,15 +1105,18 @@ public:
}
inline kj::Promise<kj::Maybe<HttpHeaders::Request>> readRequestHeaders() {
headers.clear();
return readMessageHeaders().then([this](kj::ArrayPtr<char> text) {
headers.clear();
return headers.tryParseRequest(text);
});
}
inline kj::Promise<kj::Maybe<HttpHeaders::Response>> readResponseHeaders() {
headers.clear();
// Note: readResponseHeaders() could be called multiple times concurrently when pipelining
// requests. readMessageHeaders() will serialize these, but it's important not to mess with
// state (like calling headers.clear()) before said serialization has taken place.
return readMessageHeaders().then([this](kj::ArrayPtr<char> text) {
headers.clear();
return headers.tryParseResponse(text);
});
}
......@@ -1196,7 +1160,7 @@ public:
kj::Own<kj::AsyncInputStream> getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders);
const kj::HttpHeaders& headers);
struct ReleasedBuffer {
kj::Array<byte> buffer;
......@@ -1570,13 +1534,13 @@ static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), "");
kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders) {
const kj::HttpHeaders& headers) {
if (type == RESPONSE) {
if (method == HttpMethod::HEAD) {
// Body elided.
kj::Maybe<uint64_t> length;
if (connectionHeaders.contentLength != nullptr) {
length = strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10);
KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
length = strtoull(cl->cStr(), nullptr, 10);
}
return kj::heap<HttpNullEntityReader>(*this, length);
} else if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
......@@ -1585,19 +1549,18 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
}
}
if (connectionHeaders.transferEncoding != nullptr) {
KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) {
// TODO(someday): Support plugable transfer encodings? Or at least gzip?
// TODO(soon): Support stacked transfer encodings, e.g. "gzip, chunked".
if (fastCaseCmp<'c','h','u','n','k','e','d'>(connectionHeaders.transferEncoding.cStr())) {
if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) {
return kj::heap<HttpChunkedEntityReader>(*this);
} else {
KJ_FAIL_REQUIRE("unknown transfer encoding") { break; }
}
}
if (connectionHeaders.contentLength != nullptr) {
return kj::heap<HttpFixedLengthEntityReader>(*this,
strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10));
KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
return kj::heap<HttpFixedLengthEntityReader>(*this, strtoull(cl->cStr(), nullptr, 10));
}
if (type == REQUEST) {
......@@ -1605,10 +1568,10 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
return kj::heap<HttpNullEntityReader>(*this, uint64_t(0));
}
if (connectionHeaders.connection != nullptr) {
KJ_IF_MAYBE(c, headers.get(HttpHeaderId::CONNECTION)) {
// TODO(soon): Connection header can actually have multiple tokens... but no one ever uses
// that feature?
if (fastCaseCmp<'c','l','o','s','e'>(connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c','l','o','s','e'>(c->cStr())) {
return kj::heap<HttpConnectionCloseEntityReader>(*this);
}
}
......@@ -2461,18 +2424,20 @@ public:
"of being upgraded");
KJ_REQUIRE(!closed,
"this HttpClient's connection has been closed by the server or due to an error");
KJ_REQUIRE(httpOutput.canReuse(),
"can't start new request until previous request body has been fully written");
closeWatcherTask = nullptr;
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
if (method == HttpMethod::GET || method == HttpMethod::HEAD) {
// No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr;
connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else {
connectionHeaders.transferEncoding = "chunked";
connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
}
httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders));
......@@ -2491,14 +2456,16 @@ public:
auto responsePromise = httpInput.readResponseHeaders()
.then([this,method](kj::Maybe<HttpHeaders::Response>&& response) -> HttpClient::Response {
KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
HttpClient::Response result {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode,
r->connectionHeaders)
&headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode, headers)
};
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true;
} else {
watchForClose();
......@@ -2533,11 +2500,11 @@ public:
"HttpClient").generate(keyBytes);
auto keyBase64 = kj::encodeBase64(keyBytes);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.connection = "Upgrade";
connectionHeaders.upgrade = "websocket";
connectionHeaders.websocketVersion = "13";
connectionHeaders.websocketKey = keyBase64;
kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_VERSION] = "13";
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_KEY] = keyBase64;
httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));
......@@ -2549,18 +2516,23 @@ public:
[this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response)
-> HttpClient::WebSocketResponse {
KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
if (r->statusCode == 101) {
if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
r->connectionHeaders.upgrade.cStr())) {
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'",
r->connectionHeaders.upgrade) { break; }
headers.get(HttpHeaderId::UPGRADE).orDefault("(null)")) {
break;
}
return HttpClient::WebSocketResponse();
}
auto expectedAccept = generateWebSocketAccept(keyBase64);
if (r->connectionHeaders.websocketAccept != expectedAccept) {
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr)
!= expectedAccept) {
KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header",
r->connectionHeaders.websocketAccept, expectedAccept) { break; }
headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault("(null)"),
expectedAccept) { break; }
return HttpClient::WebSocketResponse();
}
......@@ -2575,11 +2547,12 @@ public:
HttpClient::WebSocketResponse result {
r->statusCode,
r->statusText,
&httpInput.getHeaders(),
&headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode,
r->connectionHeaders)
headers)
};
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true;
} else {
watchForClose();
......@@ -3367,31 +3340,32 @@ public:
KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
req->connectionHeaders.upgrade.cStr())) {
headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request."));
}
if (req->connectionHeaders.websocketVersion != "13") {
if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported."));
}
if (req->connectionHeaders.websocketKey == nullptr) {
KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
}
currentMethod = HttpMethod::GET;
websocketKey = kj::str(req->connectionHeaders.websocketKey);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
currentMethod = req->method;
websocketKey = nullptr;
auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders);
HttpInputStream::REQUEST, req->method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
......@@ -3399,7 +3373,7 @@ public:
// other HTTP servers do similar things.
promise = server.service.request(
req->method, req->url, httpInput.getHeaders(), *body, *this);
req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body));
}
......@@ -3494,16 +3468,16 @@ private:
httpInput.finishRead();
}
HttpHeaders::ConnectionHeaders connectionHeaders;
kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr;
if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
// No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr;
connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else {
connectionHeaders.transferEncoding = "chunked";
connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
}
httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText, connectionHeaders));
......@@ -3532,10 +3506,10 @@ private:
auto websocketAccept = generateWebSocketAccept(key);
HttpHeaders::ConnectionHeaders connectionHeaders;
connectionHeaders.websocketAccept = websocketAccept;
connectionHeaders.upgrade = "websocket";
connectionHeaders.connection = "Upgrade";
kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept;
connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
......@@ -3544,16 +3518,13 @@ private:
}
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
auto bodySize = kj::str(body.size());
HttpHeaders failed(server.requestHeaderTable);
HttpHeaders::ConnectionHeaders connHeaders;
connHeaders.connection = "close";
connHeaders.contentLength = bodySize;
failed.set(HttpHeaderId::CONNECTION, "close");
failed.set(HttpHeaderId::CONTENT_LENGTH, kj::str(body.size()));
failed.set(HttpHeaderId::CONTENT_TYPE, "text/plain");
httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText, connHeaders));
httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText));
httpOutput.writeBodyData(kj::mv(body));
httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush
......
......@@ -78,19 +78,6 @@ namespace kj {
MACRO(UNSUBSCRIBE)
/* UPnP */
#define KJ_HTTP_FOR_EACH_CONNECTION_HEADER(MACRO) \
MACRO(connection, "Connection") \
MACRO(contentLength, "Content-Length") \
MACRO(keepAlive, "Keep-Alive") \
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept") \
MACRO(websocketExtensions, "Sec-WebSocket-Extensions")
enum class HttpMethod {
// Enum of known HTTP methods.
//
......@@ -138,12 +125,27 @@ public:
// In opt mode, no-op.
#define KJ_HTTP_FOR_EACH_BUILTIN_HEADER(MACRO) \
/* Headers that are always read-only. */ \
MACRO(CONNECTION, "Connection") \
MACRO(CONTENT_LENGTH, "Content-Length") \
MACRO(KEEP_ALIVE, "Keep-Alive") \
MACRO(TE, "TE") \
MACRO(TRAILER, "Trailer") \
MACRO(TRANSFER_ENCODING, "Transfer-Encoding") \
MACRO(UPGRADE, "Upgrade") \
\
/* Headers that are read-only for WebSocket handshakes. */ \
MACRO(SEC_WEBSOCKET_KEY, "Sec-WebSocket-Key") \
MACRO(SEC_WEBSOCKET_VERSION, "Sec-WebSocket-Version") \
MACRO(SEC_WEBSOCKET_ACCEPT, "Sec-WebSocket-Accept") \
MACRO(SEC_WEBSOCKET_EXTENSIONS, "Sec-WebSocket-Extensions") \
\
/* Headers that you can write. */ \
MACRO(HOST, "Host") \
MACRO(DATE, "Date") \
MACRO(LOCATION, "Location") \
MACRO(CONTENT_TYPE, "Content-Type")
// For convenience, these very-common headers are valid for all HttpHeaderTables. You can refer
// to them like:
// For convenience, these headers are valid for all HttpHeaderTables. You can refer to them like:
//
// HttpHeaderId::HOST
//
......@@ -300,26 +302,13 @@ public:
// Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful
// when you've passed a dynamic value to set() or add() or parse*().
struct ConnectionHeaders {
// These headers govern details of the specific HTTP connection or framing of the content.
// Hence, they are managed internally within the HTTP library, and never appear in an
// HttpHeaders structure.
#define DECLARE_HEADER(id, name) \
kj::StringPtr id;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(DECLARE_HEADER)
#undef DECLARE_HEADER
};
struct Request {
HttpMethod method;
kj::StringPtr url;
ConnectionHeaders connectionHeaders;
};
struct Response {
uint statusCode;
kj::StringPtr statusText;
ConnectionHeaders connectionHeaders;
};
kj::Maybe<Request> tryParseRequest(kj::ArrayPtr<char> content);
......@@ -334,11 +323,15 @@ public:
// `HttpHeaders` is destroyed, or pass it to `takeOwnership()`.
kj::String serializeRequest(HttpMethod method, kj::StringPtr url,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
kj::String serializeResponse(uint statusCode, kj::StringPtr statusText,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
// Serialize the headers as a complete request or response blob. The blob uses '\r\n' newlines
// and includes the double-newline to indicate the end of the headers.
//
// `connectionHeaders`, if provided, contains connection-level headers supplied by the HTTP
// implementation, in the order specified by the KJ_HTTP_FOR_EACH_BUILTIN_HEADER macro. These
// headers values override any corresponding header value in the HttpHeaders object.
kj::String toString() const;
......@@ -356,16 +349,16 @@ private:
kj::Vector<kj::Array<char>> ownedStrings;
kj::Maybe<uint> addNoCheck(kj::StringPtr name, kj::StringPtr value);
void addNoCheck(kj::StringPtr name, kj::StringPtr value);
kj::StringPtr cloneToOwn(kj::StringPtr str);
kj::String serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const;
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const;
bool parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders);
bool parseHeaders(char* ptr, char* end);
// TODO(perf): Arguably we should store a map, but header sets are never very long
// TODO(perf): We could optimize for common headers by storing them directly as fields. We could
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment