Commit 49b44f96 authored by Harris Hancock

Add limit to readAllBytes()/readAllText()

Reading an unbounded amount of data from a stream is a potential DoS vector. To manage this risk, readAllText() and readAllBytes() now accept a `limit` parameter. For source backwards-compatibility, this limit defaults to kj::maxValue.
parent 90349b4f
...@@ -681,6 +681,72 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) { ...@@ -681,6 +681,72 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) {
})); }));
} }
class MockAsyncInputStream: public AsyncInputStream {
  // In-memory AsyncInputStream used by the tests. It serves the bytes it was constructed with and
  // caps each read at `blockSize` bytes whenever the caller's `minBytes` permits it.

public:
  MockAsyncInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize)
      : remaining(bytes), blockSize(blockSize) {}

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Prefer handing out at most one block, but never less than the caller's minimum.
    size_t wanted = kj::max(kj::min(blockSize, maxBytes), minBytes);

    // The data still held by the stream is a hard upper bound; a result shorter than `minBytes`
    // is therefore only produced once the data has run out.
    size_t count = kj::min(wanted, remaining.size());

    memcpy(buffer, remaining.begin(), count);

    // Drop the bytes just delivered from the front of the pending data.
    remaining = remaining.slice(count, remaining.size());
    return count;
  }

private:
  kj::ArrayPtr<const byte> remaining;  // Bytes not yet handed to a reader.
  size_t blockSize;                    // Preferred upper bound on the size of a single read.
};
// Checks readAllText() and readAllBytes() on MockAsyncInputStream for every combination of input
// size, mock block size, and limit listed below.
KJ_TEST("AsyncInputStream::readAllText() / readAllBytes()") {
kj::EventLoop loop;
WaitScope ws(loop);
// Source text; each test case reads a prefix of it.
auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ",");
// Prefix lengths to read, including values at and on either side of 4096 and 8192.
size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() };
// Values passed as the mock stream's blockSize.
size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() };
// Limits passed to readAllText()/readAllBytes(), from zero up to kj::maxValue.
uint64_t limits[] = {
0, 1, 256,
bigText.size() / 2,
bigText.size() - 1,
bigText.size(),
bigText.size() + 1,
kj::maxValue
};
for (size_t inputSize: inputSizes) {
for (size_t blockSize: blockSizes) {
for (uint64_t limit: limits) {
// Attach the current parameter combination to any failure reported below.
KJ_CONTEXT(inputSize, blockSize, limit);
auto textSlice = bigText.asBytes().slice(0, inputSize);
// Each helper constructs a fresh mock stream, so every call reads textSlice from its start.
auto readAllText = [&]() {
MockAsyncInputStream input(textSlice, blockSize);
return input.readAllText(limit).wait(ws);
};
auto readAllBytes = [&]() {
MockAsyncInputStream input(textSlice, blockSize);
return input.readAllBytes(limit).wait(ws);
};
// Success is expected only when the limit strictly exceeds the input size; a limit equal to
// the input size is expected to throw, just like a smaller one.
if (limit > inputSize) {
KJ_EXPECT(readAllText().asBytes() == textSlice);
KJ_EXPECT(readAllBytes() == textSlice);
} else {
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText());
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes());
}
}
}
}
}
KJ_TEST("Userland pipe") { KJ_TEST("Userland pipe") {
kj::EventLoop loop; kj::EventLoop loop;
WaitScope ws(loop); WaitScope ws(loop);
......
...@@ -121,17 +121,17 @@ class AllReader { ...@@ -121,17 +121,17 @@ class AllReader {
public: public:
AllReader(AsyncInputStream& input): input(input) {} AllReader(AsyncInputStream& input): input(input) {}
Promise<Array<byte>> readAllBytes() { Promise<Array<byte>> readAllBytes(uint64_t limit) {
return loop().then([this](uint64_t size) { return loop(limit).then([this, limit](uint64_t headroom) {
auto out = heapArray<byte>(size); auto out = heapArray<byte>(limit - headroom);
copyInto(out); copyInto(out);
return out; return out;
}); });
} }
Promise<String> readAllText() { Promise<String> readAllText(uint64_t limit) {
return loop().then([this](uint64_t size) { return loop(limit).then([this, limit](uint64_t headroom) {
auto out = heapArray<char>(size + 1); auto out = heapArray<char>(limit - headroom + 1);
copyInto(out.slice(0, out.size() - 1).asBytes()); copyInto(out.slice(0, out.size() - 1).asBytes());
out.back() = '\0'; out.back() = '\0';
return String(kj::mv(out)); return String(kj::mv(out));
...@@ -142,17 +142,19 @@ private: ...@@ -142,17 +142,19 @@ private:
AsyncInputStream& input; AsyncInputStream& input;
Vector<Array<byte>> parts; Vector<Array<byte>> parts;
Promise<uint64_t> loop(uint64_t total = 0) { Promise<uint64_t> loop(uint64_t limit) {
auto part = heapArray<byte>(4096); KJ_REQUIRE(limit > 0, "Reached limit before EOF.");
auto part = heapArray<byte>(kj::min(4096, limit));
auto partPtr = part.asPtr(); auto partPtr = part.asPtr();
parts.add(kj::mv(part)); parts.add(kj::mv(part));
return input.tryRead(partPtr.begin(), partPtr.size(), partPtr.size()) return input.tryRead(partPtr.begin(), partPtr.size(), partPtr.size())
.then([this,KJ_CPCAP(partPtr),total](size_t amount) -> Promise<uint64_t> { .then([this,KJ_CPCAP(partPtr),limit](size_t amount) mutable -> Promise<uint64_t> {
uint64_t newTotal = total + amount; limit -= amount;
if (amount < partPtr.size()) { if (amount < partPtr.size()) {
return newTotal; return limit;
} else { } else {
return loop(newTotal); return loop(limit);
} }
}); });
} }
...@@ -169,15 +171,15 @@ private: ...@@ -169,15 +171,15 @@ private:
} // namespace } // namespace
Promise<Array<byte>> AsyncInputStream::readAllBytes() { Promise<Array<byte>> AsyncInputStream::readAllBytes(uint64_t limit) {
auto reader = kj::heap<AllReader>(*this); auto reader = kj::heap<AllReader>(*this);
auto promise = reader->readAllBytes(); auto promise = reader->readAllBytes(limit);
return promise.attach(kj::mv(reader)); return promise.attach(kj::mv(reader));
} }
Promise<String> AsyncInputStream::readAllText() { Promise<String> AsyncInputStream::readAllText(uint64_t limit) {
auto reader = kj::heap<AllReader>(*this); auto reader = kj::heap<AllReader>(*this);
auto promise = reader->readAllText(); auto promise = reader->readAllText(limit);
return promise.attach(kj::mv(reader)); return promise.attach(kj::mv(reader));
} }
......
...@@ -79,9 +79,13 @@ public: ...@@ -79,9 +79,13 @@ public:
// The default implementation first tries calling output.tryPumpFrom(), but if that fails, it // The default implementation first tries calling output.tryPumpFrom(), but if that fails, it
// performs a naive pump by allocating a buffer and reading to it / writing from it in a loop. // performs a naive pump by allocating a buffer and reading to it / writing from it in a loop.
Promise<Array<byte>> readAllBytes(); Promise<Array<byte>> readAllBytes(uint64_t limit = kj::maxValue);
Promise<String> readAllText(); Promise<String> readAllText(uint64_t limit = kj::maxValue);
// Read until EOF and return as one big byte array or string. // Read until EOF and return as one big byte array or string. Throw an exception if EOF is not
// seen before reading `limit` bytes.
//
// To prevent runaway memory allocation, consider using a more conservative value for `limit` than
// the default, particularly on untrusted data streams which may never see EOF.
}; };
class AsyncOutputStream { class AsyncOutputStream {
......
...@@ -137,14 +137,41 @@ private: ...@@ -137,14 +137,41 @@ private:
size_t blockSize; size_t blockSize;
}; };
KJ_TEST("InputStream::readAllText()") { KJ_TEST("InputStream::readAllText() / readAllBytes()") {
auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ","); auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ",");
size_t blockSizes[] = { 1, 4, 256, bigText.size() }; size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() };
size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() };
uint64_t limits[] = {
0, 1, 256,
bigText.size() / 2,
bigText.size() - 1,
bigText.size(),
bigText.size() + 1,
kj::maxValue
};
for (size_t inputSize: inputSizes) {
for (size_t blockSize: blockSizes) { for (size_t blockSize: blockSizes) {
KJ_CONTEXT(blockSize); for (uint64_t limit: limits) {
MockInputStream input(bigText.asBytes(), blockSize); KJ_CONTEXT(inputSize, blockSize, limit);
KJ_EXPECT(input.readAllText() == bigText); auto textSlice = bigText.asBytes().slice(0, inputSize);
auto readAllText = [&]() {
MockInputStream input(textSlice, blockSize);
return input.readAllText(limit);
};
auto readAllBytes = [&]() {
MockInputStream input(textSlice, blockSize);
return input.readAllBytes(limit);
};
if (limit > inputSize) {
KJ_EXPECT(readAllText().asBytes() == textSlice);
KJ_EXPECT(readAllBytes() == textSlice);
} else {
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText());
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes());
}
}
}
} }
} }
......
...@@ -70,13 +70,15 @@ void InputStream::skip(size_t bytes) { ...@@ -70,13 +70,15 @@ void InputStream::skip(size_t bytes) {
namespace { namespace {
Array<byte> readAll(InputStream& input, bool nulTerminate) { Array<byte> readAll(InputStream& input, uint64_t limit, bool nulTerminate) {
Vector<Array<byte>> parts; Vector<Array<byte>> parts;
constexpr size_t BLOCK_SIZE = 4096; constexpr size_t BLOCK_SIZE = 4096;
for (;;) { for (;;) {
auto part = heapArray<byte>(BLOCK_SIZE); KJ_REQUIRE(limit > 0, "Reached limit before EOF.");
auto part = heapArray<byte>(kj::min(BLOCK_SIZE, limit));
size_t n = input.tryRead(part.begin(), part.size(), part.size()); size_t n = input.tryRead(part.begin(), part.size(), part.size());
limit -= n;
if (n < part.size()) { if (n < part.size()) {
auto result = heapArray<byte>(parts.size() * BLOCK_SIZE + n + nulTerminate); auto result = heapArray<byte>(parts.size() * BLOCK_SIZE + n + nulTerminate);
byte* pos = result.begin(); byte* pos = result.begin();
...@@ -97,11 +99,11 @@ Array<byte> readAll(InputStream& input, bool nulTerminate) { ...@@ -97,11 +99,11 @@ Array<byte> readAll(InputStream& input, bool nulTerminate) {
} // namespace } // namespace
String InputStream::readAllText() { String InputStream::readAllText(uint64_t limit) {
return String(readAll(*this, true).releaseAsChars()); return String(readAll(*this, limit, true).releaseAsChars());
} }
Array<byte> InputStream::readAllBytes() { Array<byte> InputStream::readAllBytes(uint64_t limit) {
return readAll(*this, false); return readAll(*this, limit, false);
} }
void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) { void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "common.h" #include "common.h"
#include "array.h" #include "array.h"
#include "exception.h" #include "exception.h"
#include <stdint.h>
namespace kj { namespace kj {
...@@ -66,9 +67,13 @@ public: ...@@ -66,9 +67,13 @@ public:
// Skips past the given number of bytes, discarding them. The default implementation read()s // Skips past the given number of bytes, discarding them. The default implementation read()s
// into a scratch buffer. // into a scratch buffer.
String readAllText(); String readAllText(uint64_t limit = kj::maxValue);
Array<byte> readAllBytes(); Array<byte> readAllBytes(uint64_t limit = kj::maxValue);
// Read until EOF and return as one big byte array or string. // Read until EOF and return as one big byte array or string. Throw an exception if EOF is not
// seen before reading `limit` bytes.
//
// To prevent runaway memory allocation, consider using a more conservative value for `limit` than
// the default, particularly on untrusted data streams which may never see EOF.
}; };
class OutputStream { class OutputStream {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment