Commit abb19c3c authored by Kenton Varda's avatar Kenton Varda

Improve HTTP body cancellation handling.

parent fac09cf7
...@@ -869,6 +869,40 @@ KJ_TEST("HttpClient responses") { ...@@ -869,6 +869,40 @@ KJ_TEST("HttpClient responses") {
} }
} }
KJ_TEST("HttpClient canceled write") {
kj::EventLoop eventLoop;
kj::WaitScope waitScope(eventLoop);
auto pipe = kj::newTwoWayPipe();
auto serverPromise = pipe.ends[1]->readAllText();
{
HttpHeaderTable table;
auto client = newHttpClient(table, *pipe.ends[0]);
auto body = kj::heapArray<byte>(4096);
memset(body.begin(), 0xcf, body.size());
auto req = client->request(HttpMethod::POST, "/", HttpHeaders(table), uint64_t(4096));
// Start a write and immediately cancel it.
(void)req.body->write(body.begin(), body.size());
KJ_EXPECT_THROW_MESSAGE("overwrote", req.body->write("foo", 3).wait(waitScope));
req.body = nullptr;
KJ_EXPECT(!serverPromise.poll(waitScope));
KJ_EXPECT_THROW_MESSAGE("can't start new request until previous request body",
client->request(HttpMethod::GET, "/", HttpHeaders(table)).response.wait(waitScope));
}
pipe.ends[0]->shutdownWrite();
auto text = serverPromise.wait(waitScope);
KJ_EXPECT(text == "POST / HTTP/1.1\r\nContent-Length: 4096\r\n\r\n", text);
}
KJ_TEST("HttpServer requests") { KJ_TEST("HttpServer requests") {
HttpResponseTestCase RESPONSE = { HttpResponseTestCase RESPONSE = {
"HTTP/1.1 200 OK\r\n" "HTTP/1.1 200 OK\r\n"
......
...@@ -1607,12 +1607,13 @@ public: ...@@ -1607,12 +1607,13 @@ public:
HttpOutputStream(AsyncOutputStream& inner): inner(inner) {} HttpOutputStream(AsyncOutputStream& inner): inner(inner) {}
bool canReuse() { bool canReuse() {
return !inBody && !broken; return !inBody && !broken && !writeInProgress;
} }
void writeHeaders(String content) { void writeHeaders(String content) {
// Writes some header content and begins a new entity body. // Writes some header content and begins a new entity body.
KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; }
KJ_REQUIRE(!inBody, "previous HTTP message body incomplete; can't write more messages"); KJ_REQUIRE(!inBody, "previous HTTP message body incomplete; can't write more messages");
inBody = true; inBody = true;
...@@ -1620,42 +1621,56 @@ public: ...@@ -1620,42 +1621,56 @@ public:
} }
void writeBodyData(kj::String content) { void writeBodyData(kj::String content) {
KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; }
KJ_REQUIRE(inBody) { return; } KJ_REQUIRE(inBody) { return; }
queueWrite(kj::mv(content)); queueWrite(kj::mv(content));
} }
kj::Promise<void> writeBodyData(const void* buffer, size_t size) { kj::Promise<void> writeBodyData(const void* buffer, size_t size) {
KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; }
KJ_REQUIRE(inBody) { return kj::READY_NOW; } KJ_REQUIRE(inBody) { return kj::READY_NOW; }
auto fork = writeQueue.then([this,buffer,size]() { writeInProgress = true;
return inner.write(buffer, size); auto fork = writeQueue.fork();
}).fork();
writeQueue = fork.addBranch(); writeQueue = fork.addBranch();
return fork.addBranch();
return fork.addBranch().then([this,buffer,size]() {
return inner.write(buffer, size);
}).then([this]() {
writeInProgress = false;
});
} }
kj::Promise<void> writeBodyData(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { kj::Promise<void> writeBodyData(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; }
KJ_REQUIRE(inBody) { return kj::READY_NOW; } KJ_REQUIRE(inBody) { return kj::READY_NOW; }
auto fork = writeQueue.then([this,pieces]() { writeInProgress = true;
return inner.write(pieces); auto fork = writeQueue.fork();
}).fork();
writeQueue = fork.addBranch(); writeQueue = fork.addBranch();
return fork.addBranch();
return fork.addBranch().then([this,pieces]() {
return inner.write(pieces);
}).then([this]() {
writeInProgress = false;
});
} }
Promise<uint64_t> pumpBodyFrom(AsyncInputStream& input, uint64_t amount) { Promise<uint64_t> pumpBodyFrom(AsyncInputStream& input, uint64_t amount) {
KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return uint64_t(0); }
KJ_REQUIRE(inBody) { return uint64_t(0); } KJ_REQUIRE(inBody) { return uint64_t(0); }
auto fork = writeQueue.then([this,&input,amount]() { writeInProgress = true;
return input.pumpTo(inner, amount); auto fork = writeQueue.fork();
}).fork(); writeQueue = fork.addBranch();
writeQueue = fork.addBranch().ignoreResult(); return fork.addBranch().then([this,&input,amount]() {
return fork.addBranch(); return input.pumpTo(inner, amount);
}).then([this](uint64_t actual) {
writeInProgress = false;
return actual;
});
} }
void finishBody() { void finishBody() {
...@@ -1689,6 +1704,11 @@ private: ...@@ -1689,6 +1704,11 @@ private:
bool inBody = false; bool inBody = false;
bool broken = false; bool broken = false;
bool writeInProgress = false;
// True if a write method has been called and has not completed successfully. In the case that
// a write throws an exception or is canceled, this remains true forever. In these cases, the
// underlying steram is in an inconsitent state and cannot be reused.
void queueWrite(kj::String content) { void queueWrite(kj::String content) {
writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) { writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) {
auto promise = inner.write(content.begin(), content.size()); auto promise = inner.write(content.begin(), content.size());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment