Commit 1f2c549c authored by Joe Lee's avatar Joe Lee

Add HttpClient wrapper for limiting concurrent connections

parent 48547eb6
......@@ -3089,20 +3089,166 @@ KJ_TEST("HttpClient disable connection reuse") {
});
};
// We can do several requests in a row and only have one connection.
// Each serial request gets its own connection.
doRequest().wait(io.waitScope);
doRequest().wait(io.waitScope);
doRequest().wait(io.waitScope);
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 3);
// But if we do two in parallel, we'll end up with two connections.
// Each parallel request gets its own connection.
auto req1 = doRequest();
auto req2 = doRequest();
req1.wait(io.waitScope);
req2.wait(io.waitScope);
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 5);
}
KJ_TEST("HttpClient concurrency limiting") {
auto io = kj::setupAsyncIo();
kj::TimerImpl serverTimer(kj::origin<kj::TimePoint>());
kj::TimerImpl clientTimer(kj::origin<kj::TimePoint>());
HttpHeaderTable headerTable;
auto listener = io.provider->getNetwork().parseAddress("localhost", 0)
.wait(io.waitScope)->listen();
DummyService service(headerTable);
HttpServerSettings serverSettings;
HttpServer server(serverTimer, headerTable, service, serverSettings);
auto listenTask = server.listenHttp(*listener);
auto addr = io.provider->getNetwork().parseAddress("localhost", listener->getPort())
.wait(io.waitScope);
uint count = 0;
uint cumulative = 0;
CountingNetworkAddress countingAddr(*addr, count, cumulative);
FakeEntropySource entropySource;
HttpClientSettings clientSettings;
clientSettings.entropySource = entropySource;
clientSettings.idleTimout = 0 * kj::SECONDS;
auto innerClient = newHttpClient(clientTimer, headerTable, countingAddr, clientSettings);
struct CallbackEvent {
uint runningCount;
uint pendingCount;
bool operator==(const CallbackEvent& other) const {
return runningCount == other.runningCount && pendingCount == other.pendingCount;
}
bool operator!=(const CallbackEvent& other) const { return !(*this == other); }
// TODO(someday): Can use default spaceship operator in C++20:
//auto operator<=>(const CallbackEvent&) const = default;
};
kj::Vector<CallbackEvent> callbackEvents;
auto callback = [&](uint runningCount, uint pendingCount) {
callbackEvents.add(CallbackEvent{runningCount, pendingCount});
};
auto client = newConcurrencyLimitingHttpClient(*innerClient, 1, kj::mv(callback));
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 0);
uint i = 0;
auto doRequest = [&]() {
uint n = i++;
return client->request(HttpMethod::GET, kj::str("/", n), HttpHeaders(headerTable)).response
.then([](HttpClient::Response&& response) {
auto promise = response.body->readAllText();
return promise.attach(kj::mv(response.body));
}).then([n](kj::String body) {
KJ_EXPECT(body == kj::str("null:/", n));
});
};
// Second connection blocked by first.
auto req1 = doRequest();
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 0} }));
callbackEvents.clear();
auto req2 = doRequest();
// TODO(someday): Figure out why this poll() is necessary on Windows and macOS.
io.waitScope.poll();
KJ_EXPECT(req1.poll(io.waitScope));
KJ_EXPECT(!req2.poll(io.waitScope));
KJ_EXPECT(count == 1);
KJ_EXPECT(cumulative == 1);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 1} }));
callbackEvents.clear();
// Releasing first connection allows second to start.
req1.wait(io.waitScope);
KJ_EXPECT(req2.poll(io.waitScope));
KJ_EXPECT(count == 1);
KJ_EXPECT(cumulative == 2);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 0} }));
callbackEvents.clear();
req2.wait(io.waitScope);
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 2);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {0, 0} }));
callbackEvents.clear();
// Using body stream after releasing blocked response promise throws no exception
auto req3 = doRequest();
{
kj::Own<kj::AsyncOutputStream> req4Body;
{
auto req4 = client->request(HttpMethod::GET, kj::str("/", ++i), HttpHeaders(headerTable));
io.waitScope.poll();
req4Body = kj::mv(req4.body);
}
auto writePromise = req4Body->write("a", 1);
KJ_EXPECT(!writePromise.poll(io.waitScope));
}
req3.wait(io.waitScope);
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 3);
// Similar connection limiting for web sockets
#if __linux__
// TODO(someday): Figure out why the sequencing of websockets events does
// not work correctly on Windows (and maybe macOS?). The solution is not as
// simple as inserting poll()s as above, since doing so puts the websocket in
// a state that trips a "previous HTTP message body incomplete" assertion,
// while trying to write 500 network response.
callbackEvents.clear();
auto ws1 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable)));
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 0} }));
callbackEvents.clear();
auto ws2 = kj::heap(client->openWebSocket(kj::str("/websocket"), HttpHeaders(headerTable)));
KJ_EXPECT(ws1->poll(io.waitScope));
KJ_EXPECT(!ws2->poll(io.waitScope));
KJ_EXPECT(count == 1);
KJ_EXPECT(cumulative == 4);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 1} }));
callbackEvents.clear();
{
auto response1 = ws1->wait(io.waitScope);
KJ_EXPECT(!ws2->poll(io.waitScope));
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({}));
}
KJ_EXPECT(ws2->poll(io.waitScope));
KJ_EXPECT(count == 1);
KJ_EXPECT(cumulative == 5);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {1, 0} }));
callbackEvents.clear();
{
auto response2 = ws2->wait(io.waitScope);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({}));
}
KJ_EXPECT(count == 0);
KJ_EXPECT(cumulative == 5);
KJ_EXPECT(callbackEvents == kj::ArrayPtr<const CallbackEvent>({ {0, 0} }));
#endif
}
KJ_TEST("HttpClient multi host") {
......
......@@ -27,6 +27,7 @@
#include <stdlib.h>
#include <kj/encoding.h>
#include <deque>
#include <queue>
#include <map>
namespace kj {
......@@ -3951,6 +3952,171 @@ kj::Own<HttpClient> newHttpClient(kj::Timer& timer, HttpHeaderTable& responseHea
namespace {
class ConcurrencyLimitingHttpClient final: public HttpClient {
public:
ConcurrencyLimitingHttpClient(
kj::HttpClient& inner, uint maxConcurrentRequests,
kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback)
: inner(inner),
maxConcurrentRequests(maxConcurrentRequests),
countChangedCallback(kj::mv(countChangedCallback)) {}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
if (concurrentRequests < maxConcurrentRequests) {
auto counter = ConnectionCounter(*this);
auto request = inner.request(method, url, headers, expectedBodySize);
fireCountChanged();
auto promise = attachCounter(kj::mv(request.response), kj::mv(counter));
return { kj::mv(request.body), kj::mv(promise) };
}
auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
auto urlCopy = kj::str(url);
auto headersCopy = headers.clone();
auto combined = paf.promise
.then([this,
method,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy),
expectedBodySize](ConnectionCounter&& counter) mutable {
auto req = inner.request(method, urlCopy, headersCopy, expectedBodySize);
return kj::tuple(kj::mv(req.body), attachCounter(kj::mv(req.response), kj::mv(counter)));
});
auto split = combined.split();
pendingRequests.push(kj::mv(paf.fulfiller));
fireCountChanged();
return { kj::heap<PromiseOutputStream>(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const kj::HttpHeaders& headers) override {
if (concurrentRequests < maxConcurrentRequests) {
auto counter = ConnectionCounter(*this);
auto response = inner.openWebSocket(url, headers);
fireCountChanged();
return attachCounter(kj::mv(response), kj::mv(counter));
}
auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
auto urlCopy = kj::str(url);
auto headersCopy = headers.clone();
auto promise = paf.promise
.then([this,
urlCopy = kj::mv(urlCopy),
headersCopy = kj::mv(headersCopy)](ConnectionCounter&& counter) mutable {
return attachCounter(inner.openWebSocket(urlCopy, headersCopy), kj::mv(counter));
});
pendingRequests.push(kj::mv(paf.fulfiller));
fireCountChanged();
return kj::mv(promise);
}
private:
struct ConnectionCounter;
kj::HttpClient& inner;
uint maxConcurrentRequests;
uint concurrentRequests = 0;
kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback;
std::queue<kj::Own<kj::PromiseFulfiller<ConnectionCounter>>> pendingRequests;
// TODO(someday): want maximum cap on queue size?
struct ConnectionCounter final {
ConnectionCounter(ConcurrencyLimitingHttpClient& client) : parent(&client) {
++parent->concurrentRequests;
}
KJ_DISALLOW_COPY(ConnectionCounter);
~ConnectionCounter() noexcept(false) {
if (parent != nullptr) {
--parent->concurrentRequests;
parent->serviceQueue();
parent->fireCountChanged();
}
}
ConnectionCounter(ConnectionCounter&& other) : parent(other.parent) {
other.parent = nullptr;
}
ConnectionCounter& operator=(ConnectionCounter&& other) {
if (this != &other) {
this->parent = other.parent;
other.parent = nullptr;
}
return *this;
}
ConcurrencyLimitingHttpClient* parent;
};
void serviceQueue() {
if (concurrentRequests >= maxConcurrentRequests) { return; }
if (pendingRequests.empty()) { return; }
auto fulfiller = kj::mv(pendingRequests.front());
pendingRequests.pop();
fulfiller->fulfill(ConnectionCounter(*this));
}
void fireCountChanged() {
countChangedCallback(concurrentRequests, pendingRequests.size());
}
using WebSocketOrBody = kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>>;
static WebSocketOrBody attachCounter(WebSocketOrBody&& webSocketOrBody,
ConnectionCounter&& counter) {
KJ_SWITCH_ONEOF(webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
return ws.attach(kj::mv(counter));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
return body.attach(kj::mv(counter));
}
}
KJ_UNREACHABLE;
}
static kj::Promise<WebSocketResponse> attachCounter(kj::Promise<WebSocketResponse>&& promise,
ConnectionCounter&& counter) {
return promise.then([counter = kj::mv(counter)](WebSocketResponse&& response) mutable {
return WebSocketResponse {
response.statusCode,
response.statusText,
response.headers,
attachCounter(kj::mv(response.webSocketOrBody), kj::mv(counter))
};
});
}
static kj::Promise<Response> attachCounter(kj::Promise<Response>&& promise,
ConnectionCounter&& counter) {
return promise.then([counter = kj::mv(counter)](Response&& response) mutable {
return Response {
response.statusCode,
response.statusText,
response.headers,
response.body.attach(kj::mv(counter))
};
});
}
};
}
kj::Own<HttpClient> newConcurrencyLimitingHttpClient(
HttpClient& inner, uint maxConcurrentRequests,
kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback) {
return kj::heap<ConcurrencyLimitingHttpClient>(inner, maxConcurrentRequests,
kj::mv(countChangedCallback));
}
// =======================================================================================
namespace {
class NullInputStream final: public kj::AsyncInputStream {
public:
NullInputStream(kj::Maybe<size_t> expectedLength = size_t(0))
......
......@@ -706,6 +706,14 @@ kj::Own<HttpClient> newHttpClient(HttpHeaderTable& responseHeaderTable, kj::Asyn
// subsequent requests will fail. If a response takes a long time, it blocks subsequent responses.
// If a WebSocket is opened successfully, all subsequent requests fail.
kj::Own<HttpClient> newConcurrencyLimitingHttpClient(
HttpClient& inner, uint maxConcurrentRequests,
kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback);
// Creates an HttpClient that is limited to a maximum number of concurrent requests. Additional
// requests are queued, to be opened only after an open request completes. `countChangedCallback`
// is called when a new connection is opened or enqueued and when an open connection is closed,
// passing the number of open and pending connections.
kj::Own<HttpClient> newHttpClient(HttpService& service);
kj::Own<HttpService> newHttpService(HttpClient& client);
// Adapts an HttpClient to an HttpService and vice versa.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment