Commit cba1b179 (unverified), authored by Kenton Varda, committed by GitHub

Merge pull request #691 from capnproto/harris/read-all-bytes-safely

Add limit to readAllBytes()/readAllText()
parents 90349b4f 49b44f96
......@@ -681,6 +681,72 @@ kj::Promise<void> expectRead(kj::AsyncInputStream& in, kj::StringPtr expected) {
}));
}
class MockAsyncInputStream: public AsyncInputStream {
  // Test double: an AsyncInputStream backed by an in-memory byte range. Each tryRead() hands out
  // at most `blockSize` bytes, except when the caller's `minBytes` demands more, and never more
  // than what is left in the range.

public:
  MockAsyncInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize)
      : bytes(bytes), blockSize(blockSize) {}

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Innermost: cap the request at blockSize. Middle: raise it back up to minBytes if the cap
    // went below that. Outermost: never exceed the bytes still available.
    size_t count = kj::min(kj::max(kj::min(blockSize, maxBytes), minBytes), bytes.size());

    memcpy(buffer, bytes.begin(), count);

    // Advance past what was just delivered; the remaining range shrinks toward empty.
    bytes = bytes.slice(count, bytes.size());
    return count;
  }

private:
  kj::ArrayPtr<const byte> bytes;
  // Data not yet delivered to a reader.

  size_t blockSize;
  // Upper bound on bytes delivered per tryRead(), unless minBytes is larger.
};
// Exercises AsyncInputStream::readAllText() and readAllBytes() over every combination of input
// size, per-read block size, and byte limit. A read is expected to succeed only when the limit is
// strictly greater than the input size; otherwise it must throw "Reached limit before EOF.".
KJ_TEST("AsyncInputStream::readAllText() / readAllBytes()") {
kj::EventLoop loop;
WaitScope ws(loop);
// Source text; every test input is a prefix of this buffer.
auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ",");
// Input lengths include empty, tiny, and values on and around 4096 and 8192 (the async reader
// allocates its parts in pieces of up to 4096 bytes), plus the whole buffer.
size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() };
// Maximum bytes the mock stream delivers per tryRead() call.
size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() };
// Limits include zero, small values, values just below / at / just above the full text size, and
// the default "no limit" value.
uint64_t limits[] = {
0, 1, 256,
bigText.size() / 2,
bigText.size() - 1,
bigText.size(),
bigText.size() + 1,
kj::maxValue
};
for (size_t inputSize: inputSizes) {
for (size_t blockSize: blockSizes) {
for (uint64_t limit: limits) {
// NOTE(review): presumably attaches the current parameter combination to any failure reported
// below -- confirm against the KJ_CONTEXT documentation.
KJ_CONTEXT(inputSize, blockSize, limit);
auto textSlice = bigText.asBytes().slice(0, inputSize);
// Each helper builds a fresh mock stream so every call starts reading from the beginning of
// textSlice, then blocks on the returned promise via the WaitScope.
auto readAllText = [&]() {
MockAsyncInputStream input(textSlice, blockSize);
return input.readAllText(limit).wait(ws);
};
auto readAllBytes = [&]() {
MockAsyncInputStream input(textSlice, blockSize);
return input.readAllBytes(limit).wait(ws);
};
// Strictly greater: a limit equal to the input size is treated as "limit reached", because the
// reader needs remaining headroom to issue the read that observes EOF.
if (limit > inputSize) {
KJ_EXPECT(readAllText().asBytes() == textSlice);
KJ_EXPECT(readAllBytes() == textSlice);
} else {
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText());
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes());
}
}
}
}
}
KJ_TEST("Userland pipe") {
kj::EventLoop loop;
WaitScope ws(loop);
......
......@@ -121,17 +121,17 @@ class AllReader {
public:
AllReader(AsyncInputStream& input): input(input) {}
Promise<Array<byte>> readAllBytes() {
return loop().then([this](uint64_t size) {
auto out = heapArray<byte>(size);
Promise<Array<byte>> readAllBytes(uint64_t limit) {
return loop(limit).then([this, limit](uint64_t headroom) {
auto out = heapArray<byte>(limit - headroom);
copyInto(out);
return out;
});
}
Promise<String> readAllText() {
return loop().then([this](uint64_t size) {
auto out = heapArray<char>(size + 1);
Promise<String> readAllText(uint64_t limit) {
return loop(limit).then([this, limit](uint64_t headroom) {
auto out = heapArray<char>(limit - headroom + 1);
copyInto(out.slice(0, out.size() - 1).asBytes());
out.back() = '\0';
return String(kj::mv(out));
......@@ -142,17 +142,19 @@ private:
AsyncInputStream& input;
Vector<Array<byte>> parts;
Promise<uint64_t> loop(uint64_t total = 0) {
auto part = heapArray<byte>(4096);
Promise<uint64_t> loop(uint64_t limit) {
KJ_REQUIRE(limit > 0, "Reached limit before EOF.");
auto part = heapArray<byte>(kj::min(4096, limit));
auto partPtr = part.asPtr();
parts.add(kj::mv(part));
return input.tryRead(partPtr.begin(), partPtr.size(), partPtr.size())
.then([this,KJ_CPCAP(partPtr),total](size_t amount) -> Promise<uint64_t> {
uint64_t newTotal = total + amount;
.then([this,KJ_CPCAP(partPtr),limit](size_t amount) mutable -> Promise<uint64_t> {
limit -= amount;
if (amount < partPtr.size()) {
return newTotal;
return limit;
} else {
return loop(newTotal);
return loop(limit);
}
});
}
......@@ -169,15 +171,15 @@ private:
} // namespace
Promise<Array<byte>> AsyncInputStream::readAllBytes() {
Promise<Array<byte>> AsyncInputStream::readAllBytes(uint64_t limit) {
auto reader = kj::heap<AllReader>(*this);
auto promise = reader->readAllBytes();
auto promise = reader->readAllBytes(limit);
return promise.attach(kj::mv(reader));
}
Promise<String> AsyncInputStream::readAllText() {
Promise<String> AsyncInputStream::readAllText(uint64_t limit) {
auto reader = kj::heap<AllReader>(*this);
auto promise = reader->readAllText();
auto promise = reader->readAllText(limit);
return promise.attach(kj::mv(reader));
}
......
......@@ -79,9 +79,13 @@ public:
// The default implementation first tries calling output.tryPumpFrom(), but if that fails, it
// performs a naive pump by allocating a buffer and reading to it / writing from it in a loop.
Promise<Array<byte>> readAllBytes();
Promise<String> readAllText();
// Read until EOF and return as one big byte array or string.
Promise<Array<byte>> readAllBytes(uint64_t limit = kj::maxValue);
Promise<String> readAllText(uint64_t limit = kj::maxValue);
// Read until EOF and return as one big byte array or string. Throw an exception if EOF is not
// seen before reading `limit` bytes.
//
// To prevent runaway memory allocation, consider using a more conservative value for `limit` than
// the default, particularly on untrusted data streams which may never see EOF.
};
class AsyncOutputStream {
......
......@@ -137,14 +137,41 @@ private:
size_t blockSize;
};
KJ_TEST("InputStream::readAllText()") {
KJ_TEST("InputStream::readAllText() / readAllBytes()") {
auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ",");
size_t blockSizes[] = { 1, 4, 256, bigText.size() };
size_t inputSizes[] = { 0, 1, 256, 4096, 8191, 8192, 8193, 10000, bigText.size() };
size_t blockSizes[] = { 1, 4, 256, 4096, 8192, bigText.size() };
uint64_t limits[] = {
0, 1, 256,
bigText.size() / 2,
bigText.size() - 1,
bigText.size(),
bigText.size() + 1,
kj::maxValue
};
for (size_t blockSize: blockSizes) {
KJ_CONTEXT(blockSize);
MockInputStream input(bigText.asBytes(), blockSize);
KJ_EXPECT(input.readAllText() == bigText);
for (size_t inputSize: inputSizes) {
for (size_t blockSize: blockSizes) {
for (uint64_t limit: limits) {
KJ_CONTEXT(inputSize, blockSize, limit);
auto textSlice = bigText.asBytes().slice(0, inputSize);
auto readAllText = [&]() {
MockInputStream input(textSlice, blockSize);
return input.readAllText(limit);
};
auto readAllBytes = [&]() {
MockInputStream input(textSlice, blockSize);
return input.readAllBytes(limit);
};
if (limit > inputSize) {
KJ_EXPECT(readAllText().asBytes() == textSlice);
KJ_EXPECT(readAllBytes() == textSlice);
} else {
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllText());
KJ_EXPECT_THROW_MESSAGE("Reached limit before EOF.", readAllBytes());
}
}
}
}
}
......
......@@ -70,13 +70,15 @@ void InputStream::skip(size_t bytes) {
namespace {
Array<byte> readAll(InputStream& input, bool nulTerminate) {
Array<byte> readAll(InputStream& input, uint64_t limit, bool nulTerminate) {
Vector<Array<byte>> parts;
constexpr size_t BLOCK_SIZE = 4096;
for (;;) {
auto part = heapArray<byte>(BLOCK_SIZE);
KJ_REQUIRE(limit > 0, "Reached limit before EOF.");
auto part = heapArray<byte>(kj::min(BLOCK_SIZE, limit));
size_t n = input.tryRead(part.begin(), part.size(), part.size());
limit -= n;
if (n < part.size()) {
auto result = heapArray<byte>(parts.size() * BLOCK_SIZE + n + nulTerminate);
byte* pos = result.begin();
......@@ -97,11 +99,11 @@ Array<byte> readAll(InputStream& input, bool nulTerminate) {
} // namespace
String InputStream::readAllText() {
return String(readAll(*this, true).releaseAsChars());
String InputStream::readAllText(uint64_t limit) {
return String(readAll(*this, limit, true).releaseAsChars());
}
Array<byte> InputStream::readAllBytes() {
return readAll(*this, false);
Array<byte> InputStream::readAllBytes(uint64_t limit) {
return readAll(*this, limit, false);
}
void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
......
......@@ -29,6 +29,7 @@
#include "common.h"
#include "array.h"
#include "exception.h"
#include <stdint.h>
namespace kj {
......@@ -66,9 +67,13 @@ public:
// Skips past the given number of bytes, discarding them. The default implementation read()s
// into a scratch buffer.
String readAllText();
Array<byte> readAllBytes();
// Read until EOF and return as one big byte array or string.
String readAllText(uint64_t limit = kj::maxValue);
Array<byte> readAllBytes(uint64_t limit = kj::maxValue);
// Read until EOF and return as one big byte array or string. Throw an exception if EOF is not
// seen before reading `limit` bytes.
//
// To prevent runaway memory allocation, consider using a more conservative value for `limit` than
// the default, particularly on untrusted data streams which may never see EOF.
};
class OutputStream {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment