Unverified Commit 6052cb94 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #618 from capnproto/connection-headers-handling

 Refactor handling of connection-level headers.
parents cb49fa92 d42d3b40
...@@ -119,20 +119,21 @@ KJ_TEST("HttpHeaders::parseRequest") { ...@@ -119,20 +119,21 @@ KJ_TEST("HttpHeaders::parseRequest") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz"); KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr); KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123"); KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders; std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) { headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second); KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
}); });
KJ_EXPECT(unpackedHeaders.size() == 4); KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com"); KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early"); KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz"); KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep"); KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeRequest(result.method, result.url, result.connectionHeaders) == KJ_EXPECT(headers.serializeRequest(result.method, result.url) ==
"POST /some/path HTTP/1.1\r\n" "POST /some/path HTTP/1.1\r\n"
"Content-Length: 123\r\n" "Content-Length: 123\r\n"
"Host: example.com\r\n" "Host: example.com\r\n"
...@@ -168,21 +169,22 @@ KJ_TEST("HttpHeaders::parseResponse") { ...@@ -168,21 +169,22 @@ KJ_TEST("HttpHeaders::parseResponse") {
KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz"); KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(fooBar)) == "Baz");
KJ_EXPECT(headers.get(bazQux) == nullptr); KJ_EXPECT(headers.get(bazQux) == nullptr);
KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::CONTENT_TYPE) == nullptr);
KJ_EXPECT(result.connectionHeaders.contentLength == "123"); KJ_EXPECT(KJ_ASSERT_NONNULL(headers.get(HttpHeaderId::CONTENT_LENGTH)) == "123");
KJ_EXPECT(result.connectionHeaders.transferEncoding == nullptr); KJ_EXPECT(headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr);
std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders; std::map<kj::StringPtr, kj::StringPtr> unpackedHeaders;
headers.forEach([&](kj::StringPtr name, kj::StringPtr value) { headers.forEach([&](kj::StringPtr name, kj::StringPtr value) {
KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second); KJ_EXPECT(unpackedHeaders.insert(std::make_pair(name, value)).second);
}); });
KJ_EXPECT(unpackedHeaders.size() == 4); KJ_EXPECT(unpackedHeaders.size() == 5);
KJ_EXPECT(unpackedHeaders["Content-Length"] == "123");
KJ_EXPECT(unpackedHeaders["Host"] == "example.com"); KJ_EXPECT(unpackedHeaders["Host"] == "example.com");
KJ_EXPECT(unpackedHeaders["Date"] == "early"); KJ_EXPECT(unpackedHeaders["Date"] == "early");
KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz"); KJ_EXPECT(unpackedHeaders["Foo-Bar"] == "Baz");
KJ_EXPECT(unpackedHeaders["other-Header"] == "yep"); KJ_EXPECT(unpackedHeaders["other-Header"] == "yep");
KJ_EXPECT(headers.serializeResponse( KJ_EXPECT(headers.serializeResponse(
result.statusCode, result.statusText, result.connectionHeaders) == result.statusCode, result.statusText) ==
"HTTP/1.1 418 I'm a teapot\r\n" "HTTP/1.1 418 I'm a teapot\r\n"
"Content-Length: 123\r\n" "Content-Length: 123\r\n"
"Host: example.com\r\n" "Host: example.com\r\n"
......
...@@ -476,31 +476,25 @@ static const char* BUILTIN_HEADER_NAMES[] = { ...@@ -476,31 +476,25 @@ static const char* BUILTIN_HEADER_NAMES[] = {
#undef HEADER_NAME #undef HEADER_NAME
}; };
enum class BuiltinHeaderIndices { enum class BuiltinHeaderIndicesEnum {
#define HEADER_ID(id, name) id, #define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID #undef HEADER_ID
}; };
static constexpr size_t CONNECTION_HEADER_COUNT KJ_UNUSED = 0 namespace BuiltinHeaderIndices {
#define COUNT_HEADER(id, name) + 1 #define HEADER_ID(id, name) constexpr uint id = static_cast<uint>(BuiltinHeaderIndicesEnum::id);
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(COUNT_HEADER) KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef COUNT_HEADER
;
enum class ConnectionHeaderIndices {
#define HEADER_ID(id, name) id,
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HEADER_ID)
#undef HEADER_ID #undef HEADER_ID
}; };
static constexpr uint CONNECTION_HEADER_XOR = kj::maxValue; constexpr uint CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::SEC_WEBSOCKET_KEY;
static constexpr uint CONNECTION_HEADER_THRESHOLD = CONNECTION_HEADER_XOR >> 1; constexpr uint WEBSOCKET_CONNECTION_HEADERS_COUNT = BuiltinHeaderIndices::HOST;
} // namespace } // namespace
#define DEFINE_HEADER(id, name) \ #define DEFINE_HEADER(id, name) \
const HttpHeaderId HttpHeaderId::id(nullptr, static_cast<uint>(BuiltinHeaderIndices::id)); const HttpHeaderId HttpHeaderId::id(nullptr, BuiltinHeaderIndices::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER) KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER)
#undef DEFINE_HEADER #undef DEFINE_HEADER
...@@ -562,15 +556,9 @@ HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) { ...@@ -562,15 +556,9 @@ HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) {
HttpHeaderTable::HttpHeaderTable() HttpHeaderTable::HttpHeaderTable()
: idsByName(kj::heap<IdsByNameMap>()) { : idsByName(kj::heap<IdsByNameMap>()) {
#define ADD_HEADER(id, name) \
idsByName->map.insert(std::make_pair(name, \
static_cast<uint>(ConnectionHeaderIndices::id) ^ CONNECTION_HEADER_XOR));
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(ADD_HEADER);
#undef ADD_HEADER
#define ADD_HEADER(id, name) \ #define ADD_HEADER(id, name) \
namesById.add(name); \ namesById.add(name); \
idsByName->map.insert(std::make_pair(name, static_cast<uint>(BuiltinHeaderIndices::id))); idsByName->map.insert(std::make_pair(name, BuiltinHeaderIndices::id));
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER); KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER);
#undef ADD_HEADER #undef ADD_HEADER
} }
...@@ -657,8 +645,7 @@ void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) { ...@@ -657,8 +645,7 @@ void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) {
requireValidHeaderName(name); requireValidHeaderName(name);
requireValidHeaderValue(value); requireValidHeaderValue(value);
KJ_REQUIRE(addNoCheck(name, value) == nullptr, addNoCheck(name, value);
"can't set connection-level headers on HttpHeaders", name, value) { break; }
} }
void HttpHeaders::add(kj::StringPtr name, kj::String&& value) { void HttpHeaders::add(kj::StringPtr name, kj::String&& value) {
...@@ -672,12 +659,8 @@ void HttpHeaders::add(kj::String&& name, kj::String&& value) { ...@@ -672,12 +659,8 @@ void HttpHeaders::add(kj::String&& name, kj::String&& value) {
takeOwnership(kj::mv(value)); takeOwnership(kj::mv(value));
} }
kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) { void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
KJ_IF_MAYBE(id, table->stringToId(name)) { KJ_IF_MAYBE(id, table->stringToId(name)) {
if (id->id > CONNECTION_HEADER_THRESHOLD) {
return id->id ^ CONNECTION_HEADER_XOR;
}
if (indexedHeaders[id->id] == nullptr) { if (indexedHeaders[id->id] == nullptr) {
indexedHeaders[id->id] = value; indexedHeaders[id->id] = value;
} else { } else {
...@@ -689,8 +672,6 @@ kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) ...@@ -689,8 +672,6 @@ kj::Maybe<uint> HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value)
} else { } else {
unindexedHeaders.add(Header {name, value}); unindexedHeaders.add(Header {name, value});
} }
return nullptr;
} }
void HttpHeaders::takeOwnership(kj::String&& string) { void HttpHeaders::takeOwnership(kj::String&& string) {
...@@ -887,7 +868,7 @@ kj::Maybe<HttpHeaders::Request> HttpHeaders::tryParseRequest(kj::ArrayPtr<char> ...@@ -887,7 +868,7 @@ kj::Maybe<HttpHeaders::Request> HttpHeaders::tryParseRequest(kj::ArrayPtr<char>
// Ignore rest of line. Don't care about "HTTP/1.1" or whatever. // Ignore rest of line. Don't care about "HTTP/1.1" or whatever.
consumeLine(ptr); consumeLine(ptr);
if (!parseHeaders(ptr, end, request.connectionHeaders)) return nullptr; if (!parseHeaders(ptr, end)) return nullptr;
return request; return request;
} }
...@@ -914,28 +895,16 @@ kj::Maybe<HttpHeaders::Response> HttpHeaders::tryParseResponse(kj::ArrayPtr<char ...@@ -914,28 +895,16 @@ kj::Maybe<HttpHeaders::Response> HttpHeaders::tryParseResponse(kj::ArrayPtr<char
response.statusText = consumeLine(ptr); response.statusText = consumeLine(ptr);
if (!parseHeaders(ptr, end, response.connectionHeaders)) return nullptr; if (!parseHeaders(ptr, end)) return nullptr;
return response; return response;
} }
bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders) { bool HttpHeaders::parseHeaders(char* ptr, char* end) {
while (*ptr != '\0') { while (*ptr != '\0') {
KJ_IF_MAYBE(name, consumeHeaderName(ptr)) { KJ_IF_MAYBE(name, consumeHeaderName(ptr)) {
kj::StringPtr line = consumeLine(ptr); kj::StringPtr line = consumeLine(ptr);
KJ_IF_MAYBE(connectionHeaderId, addNoCheck(*name, line)) { addNoCheck(*name, line);
// Parsed a connection header.
switch (*connectionHeaderId) {
#define HANDLE_HEADER(id, name) \
case static_cast<uint>(ConnectionHeaderIndices::id): \
connectionHeaders.id = line; \
break;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
default:
KJ_UNREACHABLE;
}
}
} else { } else {
return false; return false;
} }
...@@ -946,13 +915,15 @@ bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connecti ...@@ -946,13 +915,15 @@ bool HttpHeaders::parseHeaders(char* ptr, char* end, ConnectionHeaders& connecti
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
kj::String HttpHeaders::serializeRequest(HttpMethod method, kj::StringPtr url, kj::String HttpHeaders::serializeRequest(
const ConnectionHeaders& connectionHeaders) const { HttpMethod method, kj::StringPtr url,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders); return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders);
} }
kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusText, kj::String HttpHeaders::serializeResponse(
const ConnectionHeaders& connectionHeaders) const { uint statusCode, kj::StringPtr statusText,
kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
auto statusCodeStr = kj::toCharSequence(statusCode); auto statusCodeStr = kj::toCharSequence(statusCode);
return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders); return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders);
...@@ -961,7 +932,7 @@ kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusT ...@@ -961,7 +932,7 @@ kj::String HttpHeaders::serializeResponse(uint statusCode, kj::StringPtr statusT
kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1, kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2, kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3, kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const { kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
const kj::StringPtr space = " "; const kj::StringPtr space = " ";
const kj::StringPtr newline = "\r\n"; const kj::StringPtr newline = "\r\n";
const kj::StringPtr colon = ": "; const kj::StringPtr colon = ": ";
...@@ -970,15 +941,11 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1, ...@@ -970,15 +941,11 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) { if (word1 != nullptr) {
size += word1.size() + word2.size() + word3.size() + 4; size += word1.size() + word2.size() + word3.size() + 4;
} }
#define HANDLE_HEADER(id, name) \ KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size());
if (connectionHeaders.id != nullptr) { \
size += connectionHeaders.id.size() + (sizeof(name) + 3); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
for (auto i: kj::indices(indexedHeaders)) { for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) { kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
size += table->idToString(HttpHeaderId(table, i)).size() + indexedHeaders[i].size() + 4; if (value != nullptr) {
size += table->idToString(HttpHeaderId(table, i)).size() + value.size() + 4;
} }
} }
for (auto& header: unindexedHeaders) { for (auto& header: unindexedHeaders) {
...@@ -991,16 +958,10 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1, ...@@ -991,16 +958,10 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
if (word1 != nullptr) { if (word1 != nullptr) {
ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline); ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline);
} }
#define HANDLE_HEADER(id, name) \
if (connectionHeaders.id != nullptr) { \
ptr = kj::_::fill(ptr, kj::StringPtr(name), colon, connectionHeaders.id, newline); \
}
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(HANDLE_HEADER)
#undef HANDLE_HEADER
for (auto i: kj::indices(indexedHeaders)) { for (auto i: kj::indices(indexedHeaders)) {
if (indexedHeaders[i] != nullptr) { kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, if (value != nullptr) {
indexedHeaders[i], newline); ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, value, newline);
} }
} }
for (auto& header: unindexedHeaders) { for (auto& header: unindexedHeaders) {
...@@ -1013,7 +974,7 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1, ...@@ -1013,7 +974,7 @@ kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
} }
kj::String HttpHeaders::toString() const { kj::String HttpHeaders::toString() const {
return serialize(nullptr, nullptr, nullptr, ConnectionHeaders()); return serialize(nullptr, nullptr, nullptr, nullptr);
} }
// ======================================================================================= // =======================================================================================
...@@ -1144,15 +1105,18 @@ public: ...@@ -1144,15 +1105,18 @@ public:
} }
inline kj::Promise<kj::Maybe<HttpHeaders::Request>> readRequestHeaders() { inline kj::Promise<kj::Maybe<HttpHeaders::Request>> readRequestHeaders() {
headers.clear();
return readMessageHeaders().then([this](kj::ArrayPtr<char> text) { return readMessageHeaders().then([this](kj::ArrayPtr<char> text) {
headers.clear();
return headers.tryParseRequest(text); return headers.tryParseRequest(text);
}); });
} }
inline kj::Promise<kj::Maybe<HttpHeaders::Response>> readResponseHeaders() { inline kj::Promise<kj::Maybe<HttpHeaders::Response>> readResponseHeaders() {
headers.clear(); // Note: readResponseHeaders() could be called multiple times concurrently when pipelining
// requests. readMessageHeaders() will serialize these, but it's important not to mess with
// state (like calling headers.clear()) before said serialization has taken place.
return readMessageHeaders().then([this](kj::ArrayPtr<char> text) { return readMessageHeaders().then([this](kj::ArrayPtr<char> text) {
headers.clear();
return headers.tryParseResponse(text); return headers.tryParseResponse(text);
}); });
} }
...@@ -1196,7 +1160,7 @@ public: ...@@ -1196,7 +1160,7 @@ public:
kj::Own<kj::AsyncInputStream> getEntityBody( kj::Own<kj::AsyncInputStream> getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode, RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders); const kj::HttpHeaders& headers);
struct ReleasedBuffer { struct ReleasedBuffer {
kj::Array<byte> buffer; kj::Array<byte> buffer;
...@@ -1570,13 +1534,13 @@ static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), ""); ...@@ -1570,13 +1534,13 @@ static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), "");
kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody( kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode, RequestOrResponse type, HttpMethod method, uint statusCode,
HttpHeaders::ConnectionHeaders& connectionHeaders) { const kj::HttpHeaders& headers) {
if (type == RESPONSE) { if (type == RESPONSE) {
if (method == HttpMethod::HEAD) { if (method == HttpMethod::HEAD) {
// Body elided. // Body elided.
kj::Maybe<uint64_t> length; kj::Maybe<uint64_t> length;
if (connectionHeaders.contentLength != nullptr) { KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
length = strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10); length = strtoull(cl->cStr(), nullptr, 10);
} }
return kj::heap<HttpNullEntityReader>(*this, length); return kj::heap<HttpNullEntityReader>(*this, length);
} else if (statusCode == 204 || statusCode == 205 || statusCode == 304) { } else if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
...@@ -1585,19 +1549,18 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody( ...@@ -1585,19 +1549,18 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
} }
} }
if (connectionHeaders.transferEncoding != nullptr) { KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) {
// TODO(someday): Support plugable transfer encodings? Or at least gzip? // TODO(someday): Support plugable transfer encodings? Or at least gzip?
// TODO(soon): Support stacked transfer encodings, e.g. "gzip, chunked". // TODO(soon): Support stacked transfer encodings, e.g. "gzip, chunked".
if (fastCaseCmp<'c','h','u','n','k','e','d'>(connectionHeaders.transferEncoding.cStr())) { if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) {
return kj::heap<HttpChunkedEntityReader>(*this); return kj::heap<HttpChunkedEntityReader>(*this);
} else { } else {
KJ_FAIL_REQUIRE("unknown transfer encoding") { break; } KJ_FAIL_REQUIRE("unknown transfer encoding") { break; }
} }
} }
if (connectionHeaders.contentLength != nullptr) { KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
return kj::heap<HttpFixedLengthEntityReader>(*this, return kj::heap<HttpFixedLengthEntityReader>(*this, strtoull(cl->cStr(), nullptr, 10));
strtoull(connectionHeaders.contentLength.cStr(), nullptr, 10));
} }
if (type == REQUEST) { if (type == REQUEST) {
...@@ -1605,10 +1568,10 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody( ...@@ -1605,10 +1568,10 @@ kj::Own<kj::AsyncInputStream> HttpInputStream::getEntityBody(
return kj::heap<HttpNullEntityReader>(*this, uint64_t(0)); return kj::heap<HttpNullEntityReader>(*this, uint64_t(0));
} }
if (connectionHeaders.connection != nullptr) { KJ_IF_MAYBE(c, headers.get(HttpHeaderId::CONNECTION)) {
// TODO(soon): Connection header can actually have multiple tokens... but no one ever uses // TODO(soon): Connection header can actually have multiple tokens... but no one ever uses
// that feature? // that feature?
if (fastCaseCmp<'c','l','o','s','e'>(connectionHeaders.connection.cStr())) { if (fastCaseCmp<'c','l','o','s','e'>(c->cStr())) {
return kj::heap<HttpConnectionCloseEntityReader>(*this); return kj::heap<HttpConnectionCloseEntityReader>(*this);
} }
} }
...@@ -2461,18 +2424,20 @@ public: ...@@ -2461,18 +2424,20 @@ public:
"of being upgraded"); "of being upgraded");
KJ_REQUIRE(!closed, KJ_REQUIRE(!closed,
"this HttpClient's connection has been closed by the server or due to an error"); "this HttpClient's connection has been closed by the server or due to an error");
KJ_REQUIRE(httpOutput.canReuse(),
"can't start new request until previous request body has been fully written");
closeWatcherTask = nullptr; closeWatcherTask = nullptr;
HttpHeaders::ConnectionHeaders connectionHeaders; kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr; kj::String lengthStr;
if (method == HttpMethod::GET || method == HttpMethod::HEAD) { if (method == HttpMethod::GET || method == HttpMethod::HEAD) {
// No entity-body. // No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) { } else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s); lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr; connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else { } else {
connectionHeaders.transferEncoding = "chunked"; connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
} }
httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders)); httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders));
...@@ -2491,14 +2456,16 @@ public: ...@@ -2491,14 +2456,16 @@ public:
auto responsePromise = httpInput.readResponseHeaders() auto responsePromise = httpInput.readResponseHeaders()
.then([this,method](kj::Maybe<HttpHeaders::Response>&& response) -> HttpClient::Response { .then([this,method](kj::Maybe<HttpHeaders::Response>&& response) -> HttpClient::Response {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
HttpClient::Response result { HttpClient::Response result {
r->statusCode, r->statusCode,
r->statusText, r->statusText,
&httpInput.getHeaders(), &headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode, httpInput.getEntityBody(HttpInputStream::RESPONSE, method, r->statusCode, headers)
r->connectionHeaders)
}; };
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) {
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true; closed = true;
} else { } else {
watchForClose(); watchForClose();
...@@ -2533,11 +2500,11 @@ public: ...@@ -2533,11 +2500,11 @@ public:
"HttpClient").generate(keyBytes); "HttpClient").generate(keyBytes);
auto keyBase64 = kj::encodeBase64(keyBytes); auto keyBase64 = kj::encodeBase64(keyBytes);
HttpHeaders::ConnectionHeaders connectionHeaders; kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders.connection = "Upgrade"; connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
connectionHeaders.upgrade = "websocket"; connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders.websocketVersion = "13"; connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_VERSION] = "13";
connectionHeaders.websocketKey = keyBase64; connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_KEY] = keyBase64;
httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders)); httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));
...@@ -2549,18 +2516,23 @@ public: ...@@ -2549,18 +2516,23 @@ public:
[this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response) [this](kj::StringPtr keyBase64, kj::Maybe<HttpHeaders::Response>&& response)
-> HttpClient::WebSocketResponse { -> HttpClient::WebSocketResponse {
KJ_IF_MAYBE(r, response) { KJ_IF_MAYBE(r, response) {
auto& headers = httpInput.getHeaders();
if (r->statusCode == 101) { if (r->statusCode == 101) {
if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
r->connectionHeaders.upgrade.cStr())) { headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'", KJ_FAIL_REQUIRE("server returned incorrect Upgrade header; should be 'websocket'",
r->connectionHeaders.upgrade) { break; } headers.get(HttpHeaderId::UPGRADE).orDefault("(null)")) {
break;
}
return HttpClient::WebSocketResponse(); return HttpClient::WebSocketResponse();
} }
auto expectedAccept = generateWebSocketAccept(keyBase64); auto expectedAccept = generateWebSocketAccept(keyBase64);
if (r->connectionHeaders.websocketAccept != expectedAccept) { if (headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr)
!= expectedAccept) {
KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header", KJ_FAIL_REQUIRE("server returned incorrect Sec-WebSocket-Accept header",
r->connectionHeaders.websocketAccept, expectedAccept) { break; } headers.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault("(null)"),
expectedAccept) { break; }
return HttpClient::WebSocketResponse(); return HttpClient::WebSocketResponse();
} }
...@@ -2575,11 +2547,12 @@ public: ...@@ -2575,11 +2547,12 @@ public:
HttpClient::WebSocketResponse result { HttpClient::WebSocketResponse result {
r->statusCode, r->statusCode,
r->statusText, r->statusText,
&httpInput.getHeaders(), &headers,
httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode, httpInput.getEntityBody(HttpInputStream::RESPONSE, HttpMethod::GET, r->statusCode,
r->connectionHeaders) headers)
}; };
if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(r->connectionHeaders.connection.cStr())) { if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
headers.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
closed = true; closed = true;
} else { } else {
watchForClose(); watchForClose();
...@@ -3367,31 +3340,32 @@ public: ...@@ -3367,31 +3340,32 @@ public:
KJ_IF_MAYBE(req, request) { KJ_IF_MAYBE(req, request) {
kj::Promise<void> promise = nullptr; kj::Promise<void> promise = nullptr;
auto& headers = httpInput.getHeaders();
if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( if (fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
req->connectionHeaders.upgrade.cStr())) { headers.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
if (req->method != HttpMethod::GET) { if (req->method != HttpMethod::GET) {
return sendError(400, "Bad Request", kj::str( return sendError(400, "Bad Request", kj::str(
"ERROR: WebSocket must be initiated with a GET request.")); "ERROR: WebSocket must be initiated with a GET request."));
} }
if (req->connectionHeaders.websocketVersion != "13") { if (headers.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendError(400, "Bad Request", kj::str( return sendError(400, "Bad Request", kj::str(
"ERROR: The requested WebSocket version is not supported.")); "ERROR: The requested WebSocket version is not supported."));
} }
if (req->connectionHeaders.websocketKey == nullptr) { KJ_IF_MAYBE(key, headers.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
currentMethod = HttpMethod::GET;
websocketKey = kj::str(*key);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else {
return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key")); return sendError(400, "Bad Request", kj::str("ERROR: Missing Sec-WebSocket-Key"));
} }
currentMethod = HttpMethod::GET;
websocketKey = kj::str(req->connectionHeaders.websocketKey);
promise = server.service.openWebSocket(req->url, httpInput.getHeaders(), *this);
} else { } else {
currentMethod = req->method; currentMethod = req->method;
websocketKey = nullptr; websocketKey = nullptr;
auto body = httpInput.getEntityBody( auto body = httpInput.getEntityBody(
HttpInputStream::REQUEST, req->method, 0, req->connectionHeaders); HttpInputStream::REQUEST, req->method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to // TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should // prevent permanent deadlock. It's slightly weird in that arguably the client should
...@@ -3399,7 +3373,7 @@ public: ...@@ -3399,7 +3373,7 @@ public:
// other HTTP servers do similar things. // other HTTP servers do similar things.
promise = server.service.request( promise = server.service.request(
req->method, req->url, httpInput.getHeaders(), *body, *this); req->method, req->url, headers, *body, *this);
promise = promise.attach(kj::mv(body)); promise = promise.attach(kj::mv(body));
} }
...@@ -3494,16 +3468,16 @@ private: ...@@ -3494,16 +3468,16 @@ private:
httpInput.finishRead(); httpInput.finishRead();
} }
HttpHeaders::ConnectionHeaders connectionHeaders; kj::StringPtr connectionHeaders[CONNECTION_HEADERS_COUNT];
kj::String lengthStr; kj::String lengthStr;
if (statusCode == 204 || statusCode == 205 || statusCode == 304) { if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
// No entity-body. // No entity-body.
} else KJ_IF_MAYBE(s, expectedBodySize) { } else KJ_IF_MAYBE(s, expectedBodySize) {
lengthStr = kj::str(*s); lengthStr = kj::str(*s);
connectionHeaders.contentLength = lengthStr; connectionHeaders[BuiltinHeaderIndices::CONTENT_LENGTH] = lengthStr;
} else { } else {
connectionHeaders.transferEncoding = "chunked"; connectionHeaders[BuiltinHeaderIndices::TRANSFER_ENCODING] = "chunked";
} }
httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText, connectionHeaders)); httpOutput.writeHeaders(headers.serializeResponse(statusCode, statusText, connectionHeaders));
...@@ -3532,10 +3506,10 @@ private: ...@@ -3532,10 +3506,10 @@ private:
auto websocketAccept = generateWebSocketAccept(key); auto websocketAccept = generateWebSocketAccept(key);
HttpHeaders::ConnectionHeaders connectionHeaders; kj::StringPtr connectionHeaders[WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders.websocketAccept = websocketAccept; connectionHeaders[BuiltinHeaderIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept;
connectionHeaders.upgrade = "websocket"; connectionHeaders[BuiltinHeaderIndices::UPGRADE] = "websocket";
connectionHeaders.connection = "Upgrade"; connectionHeaders[BuiltinHeaderIndices::CONNECTION] = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse( httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders)); 101, "Switching Protocols", connectionHeaders));
...@@ -3544,16 +3518,13 @@ private: ...@@ -3544,16 +3518,13 @@ private:
} }
kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) { kj::Promise<void> sendError(uint statusCode, kj::StringPtr statusText, kj::String body) {
auto bodySize = kj::str(body.size());
HttpHeaders failed(server.requestHeaderTable); HttpHeaders failed(server.requestHeaderTable);
HttpHeaders::ConnectionHeaders connHeaders; failed.set(HttpHeaderId::CONNECTION, "close");
connHeaders.connection = "close"; failed.set(HttpHeaderId::CONTENT_LENGTH, kj::str(body.size()));
connHeaders.contentLength = bodySize;
failed.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); failed.set(HttpHeaderId::CONTENT_TYPE, "text/plain");
httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText, connHeaders)); httpOutput.writeHeaders(failed.serializeResponse(statusCode, statusText));
httpOutput.writeBodyData(kj::mv(body)); httpOutput.writeBodyData(kj::mv(body));
httpOutput.finishBody(); httpOutput.finishBody();
return httpOutput.flush(); // loop ends after flush return httpOutput.flush(); // loop ends after flush
......
...@@ -78,19 +78,6 @@ namespace kj { ...@@ -78,19 +78,6 @@ namespace kj {
MACRO(UNSUBSCRIBE) MACRO(UNSUBSCRIBE)
/* UPnP */ /* UPnP */
#define KJ_HTTP_FOR_EACH_CONNECTION_HEADER(MACRO) \
MACRO(connection, "Connection") \
MACRO(contentLength, "Content-Length") \
MACRO(keepAlive, "Keep-Alive") \
MACRO(te, "TE") \
MACRO(trailer, "Trailer") \
MACRO(transferEncoding, "Transfer-Encoding") \
MACRO(upgrade, "Upgrade") \
MACRO(websocketKey, "Sec-WebSocket-Key") \
MACRO(websocketVersion, "Sec-WebSocket-Version") \
MACRO(websocketAccept, "Sec-WebSocket-Accept") \
MACRO(websocketExtensions, "Sec-WebSocket-Extensions")
enum class HttpMethod { enum class HttpMethod {
// Enum of known HTTP methods. // Enum of known HTTP methods.
// //
...@@ -138,12 +125,27 @@ public: ...@@ -138,12 +125,27 @@ public:
// In opt mode, no-op. // In opt mode, no-op.
#define KJ_HTTP_FOR_EACH_BUILTIN_HEADER(MACRO) \ #define KJ_HTTP_FOR_EACH_BUILTIN_HEADER(MACRO) \
/* Headers that are always read-only. */ \
MACRO(CONNECTION, "Connection") \
MACRO(CONTENT_LENGTH, "Content-Length") \
MACRO(KEEP_ALIVE, "Keep-Alive") \
MACRO(TE, "TE") \
MACRO(TRAILER, "Trailer") \
MACRO(TRANSFER_ENCODING, "Transfer-Encoding") \
MACRO(UPGRADE, "Upgrade") \
\
/* Headers that are read-only for WebSocket handshakes. */ \
MACRO(SEC_WEBSOCKET_KEY, "Sec-WebSocket-Key") \
MACRO(SEC_WEBSOCKET_VERSION, "Sec-WebSocket-Version") \
MACRO(SEC_WEBSOCKET_ACCEPT, "Sec-WebSocket-Accept") \
MACRO(SEC_WEBSOCKET_EXTENSIONS, "Sec-WebSocket-Extensions") \
\
/* Headers that you can write. */ \
MACRO(HOST, "Host") \ MACRO(HOST, "Host") \
MACRO(DATE, "Date") \ MACRO(DATE, "Date") \
MACRO(LOCATION, "Location") \ MACRO(LOCATION, "Location") \
MACRO(CONTENT_TYPE, "Content-Type") MACRO(CONTENT_TYPE, "Content-Type")
// For convenience, these very-common headers are valid for all HttpHeaderTables. You can refer // For convenience, these headers are valid for all HttpHeaderTables. You can refer to them like:
// to them like:
// //
// HttpHeaderId::HOST // HttpHeaderId::HOST
// //
...@@ -300,26 +302,13 @@ public: ...@@ -300,26 +302,13 @@ public:
// Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful // Takes overship of a string so that it lives until the HttpHeaders object is destroyed. Useful
// when you've passed a dynamic value to set() or add() or parse*(). // when you've passed a dynamic value to set() or add() or parse*().
struct ConnectionHeaders {
// These headers govern details of the specific HTTP connection or framing of the content.
// Hence, they are managed internally within the HTTP library, and never appear in an
// HttpHeaders structure.
#define DECLARE_HEADER(id, name) \
kj::StringPtr id;
KJ_HTTP_FOR_EACH_CONNECTION_HEADER(DECLARE_HEADER)
#undef DECLARE_HEADER
};
struct Request { struct Request {
HttpMethod method; HttpMethod method;
kj::StringPtr url; kj::StringPtr url;
ConnectionHeaders connectionHeaders;
}; };
struct Response { struct Response {
uint statusCode; uint statusCode;
kj::StringPtr statusText; kj::StringPtr statusText;
ConnectionHeaders connectionHeaders;
}; };
kj::Maybe<Request> tryParseRequest(kj::ArrayPtr<char> content); kj::Maybe<Request> tryParseRequest(kj::ArrayPtr<char> content);
...@@ -334,11 +323,15 @@ public: ...@@ -334,11 +323,15 @@ public:
// `HttpHeaders` is destroyed, or pass it to `takeOwnership()`. // `HttpHeaders` is destroyed, or pass it to `takeOwnership()`.
kj::String serializeRequest(HttpMethod method, kj::StringPtr url, kj::String serializeRequest(HttpMethod method, kj::StringPtr url,
const ConnectionHeaders& connectionHeaders) const; kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
kj::String serializeResponse(uint statusCode, kj::StringPtr statusText, kj::String serializeResponse(uint statusCode, kj::StringPtr statusText,
const ConnectionHeaders& connectionHeaders) const; kj::ArrayPtr<const kj::StringPtr> connectionHeaders = nullptr) const;
// Serialize the headers as a complete request or response blob. The blob uses '\r\n' newlines // Serialize the headers as a complete request or response blob. The blob uses '\r\n' newlines
// and includes the double-newline to indicate the end of the headers. // and includes the double-newline to indicate the end of the headers.
//
// `connectionHeaders`, if provided, contains connection-level headers supplied by the HTTP
// implementation, in the order specified by the KJ_HTTP_FOR_EACH_BUILTIN_HEADER macro. These
// headers values override any corresponding header value in the HttpHeaders object.
kj::String toString() const; kj::String toString() const;
...@@ -356,16 +349,16 @@ private: ...@@ -356,16 +349,16 @@ private:
kj::Vector<kj::Array<char>> ownedStrings; kj::Vector<kj::Array<char>> ownedStrings;
kj::Maybe<uint> addNoCheck(kj::StringPtr name, kj::StringPtr value); void addNoCheck(kj::StringPtr name, kj::StringPtr value);
kj::StringPtr cloneToOwn(kj::StringPtr str); kj::StringPtr cloneToOwn(kj::StringPtr str);
kj::String serialize(kj::ArrayPtr<const char> word1, kj::String serialize(kj::ArrayPtr<const char> word1,
kj::ArrayPtr<const char> word2, kj::ArrayPtr<const char> word2,
kj::ArrayPtr<const char> word3, kj::ArrayPtr<const char> word3,
const ConnectionHeaders& connectionHeaders) const; kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const;
bool parseHeaders(char* ptr, char* end, ConnectionHeaders& connectionHeaders); bool parseHeaders(char* ptr, char* end);
// TODO(perf): Arguably we should store a map, but header sets are never very long // TODO(perf): Arguably we should store a map, but header sets are never very long
// TODO(perf): We could optimize for common headers by storing them directly as fields. We could // TODO(perf): We could optimize for common headers by storing them directly as fields. We could
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment