Unverified Commit 5c8e496e authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #684 from capnproto/synchronous-gzip

Add sychronous gzip stream implementations.
parents cdc5c91c 113fa5a6
......@@ -432,12 +432,14 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, ErrorReporter& errorReporterP
initLocation(location, builder);
return result;
}),
p::transform(stringLiteral,
[this](Located<Text::Reader>&& value) -> Orphan<Expression> {
p::transform(p::oneOrMore(stringLiteral),
[this](kj::Array<Located<Text::Reader>>&& value) -> Orphan<Expression> {
auto result = orphanage.newOrphan<Expression>();
auto builder = result.get();
builder.setString(value.value);
value.copyLocationTo(builder);
builder.setString(kj::strArray(
KJ_MAP(part, value) { return part.value; }, ""));
builder.setStartByte(value.front().startByte);
builder.setEndByte(value.back().endByte);
return result;
}),
p::transform(binaryLiteral,
......
......@@ -127,7 +127,8 @@ struct TestDefaults {
textList = ["quux", "corge", "grault"],
dataList = ["garply", "waldo", "fred"],
structList = [
(textField = "x structlist 1"),
(textField = "x " "structlist"
" 1"),
(textField = "x structlist 2"),
(textField = "x structlist 3")],
enumList = [qux, bar, grault]
......@@ -704,7 +705,8 @@ struct TestConstants {
textList = ["quux", "corge", "grault"],
dataList = ["garply", "waldo", "fred"],
structList = [
(textField = "x structlist 1"),
(textField = "x " "structlist"
" 1"),
(textField = "x structlist 2"),
(textField = "x structlist 3")],
enumList = [qux, bar, grault]
......
......@@ -36,11 +36,36 @@ static const byte FOOBAR_GZIP[] = {
0x00, 0x00,
};
class MockInputStream: public AsyncInputStream {
class MockInputStream: public InputStream {
public:
MockInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize)
: bytes(bytes), blockSize(blockSize) {}
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
// Clamp max read to blockSize.
size_t n = kj::min(blockSize, maxBytes);
// Unless that's less than minBytes -- in which case, use minBytes.
n = kj::max(n, minBytes);
// But also don't read more data than we have.
n = kj::min(n, bytes.size());
memcpy(buffer, bytes.begin(), n);
bytes = bytes.slice(n, bytes.size());
return n;
}
private:
kj::ArrayPtr<const byte> bytes;
size_t blockSize;
};
class MockAsyncInputStream: public AsyncInputStream {
public:
MockAsyncInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize)
: bytes(bytes), blockSize(blockSize) {}
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
// Clamp max read to blockSize.
size_t n = kj::min(blockSize, maxBytes);
......@@ -62,25 +87,66 @@ private:
};
KJ_TEST("gzip decompression") {
// Normal read.
{
MockInputStream rawInput(FOOBAR_GZIP, kj::maxValue);
GzipInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText() == "foobar");
}
// Force read one byte at a time.
{
MockInputStream rawInput(FOOBAR_GZIP, 1);
GzipInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText() == "foobar");
}
// Read truncated input.
{
MockInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue);
GzipInputStream gzip(rawInput);
char text[16];
size_t n = gzip.tryRead(text, 1, sizeof(text));
text[n] = '\0';
KJ_EXPECT(StringPtr(text, n) == "fo");
KJ_EXPECT_THROW_MESSAGE("gzip compressed stream ended prematurely",
gzip.tryRead(text, 1, sizeof(text)));
}
// Read concatenated input.
{
Vector<byte> bytes;
bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP));
bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP));
MockInputStream rawInput(bytes, kj::maxValue);
GzipInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText() == "foobarfoobar");
}
}
KJ_TEST("async gzip decompression") {
auto io = setupAsyncIo();
// Normal read.
{
MockInputStream rawInput(FOOBAR_GZIP, kj::maxValue);
MockAsyncInputStream rawInput(FOOBAR_GZIP, kj::maxValue);
GzipAsyncInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar");
}
// Force read one byte at a time.
{
MockInputStream rawInput(FOOBAR_GZIP, 1);
MockAsyncInputStream rawInput(FOOBAR_GZIP, 1);
GzipAsyncInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar");
}
// Read truncated input.
{
MockInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue);
MockAsyncInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue);
GzipAsyncInputStream gzip(rawInput);
char text[16];
......@@ -97,19 +163,39 @@ KJ_TEST("gzip decompression") {
Vector<byte> bytes;
bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP));
bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP));
MockInputStream rawInput(bytes, kj::maxValue);
MockAsyncInputStream rawInput(bytes, kj::maxValue);
GzipAsyncInputStream gzip(rawInput);
KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobarfoobar");
}
}
class MockOutputStream: public AsyncOutputStream {
class MockOutputStream: public OutputStream {
public:
kj::Vector<byte> bytes;
kj::String decompress(WaitScope& ws) {
kj::String decompress() {
MockInputStream rawInput(bytes, kj::maxValue);
GzipInputStream gzip(rawInput);
return gzip.readAllText();
}
void write(const void* buffer, size_t size) override {
bytes.addAll(arrayPtr(reinterpret_cast<const byte*>(buffer), size));
}
void write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
for (auto& piece: pieces) {
bytes.addAll(piece);
}
}
};
class MockAsyncOutputStream: public AsyncOutputStream {
public:
kj::Vector<byte> bytes;
kj::String decompress(WaitScope& ws) {
MockAsyncInputStream rawInput(bytes, kj::maxValue);
GzipAsyncInputStream gzip(rawInput);
return gzip.readAllText().wait(ws);
}
......@@ -127,11 +213,73 @@ public:
};
KJ_TEST("gzip compression") {
// Normal write.
{
MockOutputStream rawOutput;
{
GzipOutputStream gzip(rawOutput);
gzip.write("foobar", 6);
}
KJ_EXPECT(rawOutput.decompress() == "foobar");
}
// Multi-part write.
{
MockOutputStream rawOutput;
{
GzipOutputStream gzip(rawOutput);
gzip.write("foo", 3);
gzip.write("bar", 3);
}
KJ_EXPECT(rawOutput.decompress() == "foobar");
}
// Array-of-arrays write.
{
MockOutputStream rawOutput;
{
GzipOutputStream gzip(rawOutput);
ArrayPtr<const byte> pieces[] = {
kj::StringPtr("foo").asBytes(),
kj::StringPtr("bar").asBytes(),
};
gzip.write(pieces);
}
KJ_EXPECT(rawOutput.decompress() == "foobar");
}
}
KJ_TEST("gzip huge round trip") {
auto bytes = heapArray<byte>(65536);
for (auto& b: bytes) {
b = rand();
}
MockOutputStream rawOutput;
{
GzipOutputStream gzipOut(rawOutput);
gzipOut.write(bytes.begin(), bytes.size());
}
MockInputStream rawInput(rawOutput.bytes, kj::maxValue);
GzipInputStream gzipIn(rawInput);
auto decompressed = gzipIn.readAllBytes();
KJ_ASSERT(decompressed.size() == bytes.size());
KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0);
}
KJ_TEST("async gzip compression") {
auto io = setupAsyncIo();
// Normal write.
{
MockOutputStream rawOutput;
MockAsyncOutputStream rawOutput;
GzipAsyncOutputStream gzip(rawOutput);
gzip.write("foobar", 6).wait(io.waitScope);
gzip.end().wait(io.waitScope);
......@@ -141,7 +289,7 @@ KJ_TEST("gzip compression") {
// Multi-part write.
{
MockOutputStream rawOutput;
MockAsyncOutputStream rawOutput;
GzipAsyncOutputStream gzip(rawOutput);
gzip.write("foo", 3).wait(io.waitScope);
gzip.write("bar", 3).wait(io.waitScope);
......@@ -152,7 +300,7 @@ KJ_TEST("gzip compression") {
// Array-of-arrays write.
{
MockOutputStream rawOutput;
MockAsyncOutputStream rawOutput;
GzipAsyncOutputStream gzip(rawOutput);
ArrayPtr<const byte> pieces[] = {
......@@ -166,7 +314,7 @@ KJ_TEST("gzip compression") {
}
}
KJ_TEST("gzip huge round trip") {
KJ_TEST("async gzip huge round trip") {
auto io = setupAsyncIo();
auto bytes = heapArray<byte>(65536);
......@@ -174,12 +322,12 @@ KJ_TEST("gzip huge round trip") {
b = rand();
}
MockOutputStream rawOutput;
MockAsyncOutputStream rawOutput;
GzipAsyncOutputStream gzipOut(rawOutput);
gzipOut.write(bytes.begin(), bytes.size()).wait(io.waitScope);
gzipOut.end().wait(io.waitScope);
MockInputStream rawInput(rawOutput.bytes, kj::maxValue);
MockAsyncInputStream rawInput(rawOutput.bytes, kj::maxValue);
GzipAsyncInputStream gzipIn(rawInput);
auto decompressed = gzipIn.readAllBytes().wait(io.waitScope);
......
......@@ -26,6 +26,138 @@
namespace kj {
GzipInputStream::GzipInputStream(InputStream& inner)
: inner(inner) {
memset(&ctx, 0, sizeof(ctx));
ctx.next_in = nullptr;
ctx.avail_in = 0;
ctx.next_out = nullptr;
ctx.avail_out = 0;
// windowBits = 15 (maximum) + magic value 16 to ask for gzip.
KJ_ASSERT(inflateInit2(&ctx, 15 + 16) == Z_OK);
}
GzipInputStream::~GzipInputStream() noexcept(false) {
inflateEnd(&ctx);
}
size_t GzipInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) {
if (maxBytes == 0) return size_t(0);
return readImpl(reinterpret_cast<byte*>(out), minBytes, maxBytes, 0);
}
size_t GzipInputStream::readImpl(
byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) {
if (ctx.avail_in == 0) {
size_t amount = inner.tryRead(buffer, 1, sizeof(buffer));
if (amount == 0) {
if (!atValidEndpoint) {
KJ_FAIL_REQUIRE("gzip compressed stream ended prematurely");
}
return alreadyRead;
} else {
ctx.next_in = buffer;
ctx.avail_in = amount;
}
}
ctx.next_out = reinterpret_cast<byte*>(out);
ctx.avail_out = maxBytes;
auto inflateResult = inflate(&ctx, Z_NO_FLUSH);
atValidEndpoint = inflateResult == Z_STREAM_END;
if (inflateResult == Z_OK || inflateResult == Z_STREAM_END) {
if (atValidEndpoint && ctx.avail_in > 0) {
// There's more data available. Assume start of new content.
KJ_ASSERT(inflateReset(&ctx) == Z_OK);
}
size_t n = maxBytes - ctx.avail_out;
if (n >= minBytes) {
return n + alreadyRead;
} else {
return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n);
}
} else {
if (ctx.msg == nullptr) {
KJ_FAIL_REQUIRE("gzip decompression failed", inflateResult);
} else {
KJ_FAIL_REQUIRE("gzip decompression failed", ctx.msg);
}
}
}
// =======================================================================================
GzipOutputStream::GzipOutputStream(OutputStream& inner, int compressionLevel)
: inner(inner) {
memset(&ctx, 0, sizeof(ctx));
ctx.next_in = nullptr;
ctx.avail_in = 0;
ctx.next_out = nullptr;
ctx.avail_out = 0;
int initResult =
deflateInit2(&ctx, compressionLevel, Z_DEFLATED,
15 + 16, // windowBits = 15 (maximum) + magic value 16 to ask for gzip.
8, // memLevel = 8 (the default)
Z_DEFAULT_STRATEGY);
KJ_ASSERT(initResult == Z_OK, initResult);
}
GzipOutputStream::~GzipOutputStream() noexcept(false) {
KJ_DEFER(deflateEnd(&ctx));
for (;;) {
ctx.next_out = buffer;
ctx.avail_out = sizeof(buffer);
auto deflateResult = deflate(&ctx, Z_FINISH);
if (deflateResult != Z_OK && deflateResult != Z_STREAM_END) {
if (ctx.msg == nullptr) {
KJ_FAIL_REQUIRE("gzip compression failed", deflateResult);
} else {
KJ_FAIL_REQUIRE("gzip compression failed", ctx.msg);
}
}
size_t n = sizeof(buffer) - ctx.avail_out;
inner.write(buffer, n);
if (deflateResult == Z_STREAM_END) {
break;
}
}
}
void GzipOutputStream::write(const void* in, size_t size) {
ctx.next_in = const_cast<byte*>(reinterpret_cast<const byte*>(in));
ctx.avail_in = size;
pump();
}
void GzipOutputStream::pump() {
while (ctx.avail_in > 0) {
ctx.next_out = buffer;
ctx.avail_out = sizeof(buffer);
auto deflateResult = deflate(&ctx, Z_NO_FLUSH);
if (deflateResult != Z_OK) {
if (ctx.msg == nullptr) {
KJ_FAIL_REQUIRE("gzip compression failed", deflateResult);
} else {
KJ_FAIL_REQUIRE("gzip compression failed", ctx.msg);
}
}
size_t n = sizeof(buffer) - ctx.avail_out;
inner.write(buffer, n);
}
}
// =======================================================================================
GzipAsyncInputStream::GzipAsyncInputStream(AsyncInputStream& inner)
: inner(inner) {
memset(&ctx, 0, sizeof(ctx));
......
......@@ -21,11 +21,48 @@
#pragma once
#include <kj/io.h>
#include <kj/async-io.h>
#include <zlib.h>
namespace kj {
class GzipInputStream final: public InputStream {
public:
GzipInputStream(InputStream& inner);
~GzipInputStream() noexcept(false);
KJ_DISALLOW_COPY(GzipInputStream);
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
private:
InputStream& inner;
z_stream ctx;
bool atValidEndpoint = false;
byte buffer[4096];
size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead);
};
class GzipOutputStream final: public OutputStream {
public:
GzipOutputStream(OutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION);
~GzipOutputStream() noexcept(false);
KJ_DISALLOW_COPY(GzipOutputStream);
void write(const void* buffer, size_t size) override;
using OutputStream::write;
private:
OutputStream& inner;
z_stream ctx;
byte buffer[4096];
void pump();
};
class GzipAsyncInputStream final: public AsyncInputStream {
public:
GzipAsyncInputStream(AsyncInputStream& inner);
......
......@@ -112,5 +112,41 @@ KJ_TEST("VectorOutputStream") {
KJ_ASSERT(output.getWriteBuffer().begin() == output.getArray().begin() + 40);
}
class MockInputStream: public InputStream {
public:
MockInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize)
: bytes(bytes), blockSize(blockSize) {}
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
// Clamp max read to blockSize.
size_t n = kj::min(blockSize, maxBytes);
// Unless that's less than minBytes -- in which case, use minBytes.
n = kj::max(n, minBytes);
// But also don't read more data than we have.
n = kj::min(n, bytes.size());
memcpy(buffer, bytes.begin(), n);
bytes = bytes.slice(n, bytes.size());
return n;
}
private:
kj::ArrayPtr<const byte> bytes;
size_t blockSize;
};
KJ_TEST("InputStream::readAllText()") {
auto bigText = strArray(kj::repeat("foo bar baz"_kj, 12345), ",");
size_t blockSizes[] = { 1, 4, 256, bigText.size() };
for (size_t blockSize: blockSizes) {
KJ_CONTEXT(blockSize);
MockInputStream input(bigText.asBytes(), blockSize);
KJ_EXPECT(input.readAllText() == bigText);
}
}
} // namespace
} // namespace kj
......@@ -28,6 +28,7 @@
#include "miniposix.h"
#include <algorithm>
#include <errno.h>
#include "vector.h"
#if _WIN32
#ifndef NOMINMAX
......@@ -66,6 +67,43 @@ void InputStream::skip(size_t bytes) {
}
}
namespace {
Array<byte> readAll(InputStream& input, bool nulTerminate) {
Vector<Array<byte>> parts;
constexpr size_t BLOCK_SIZE = 4096;
for (;;) {
auto part = heapArray<byte>(BLOCK_SIZE);
size_t n = input.tryRead(part.begin(), part.size(), part.size());
if (n < part.size()) {
auto result = heapArray<byte>(parts.size() * BLOCK_SIZE + n + nulTerminate);
byte* pos = result.begin();
for (auto& p: parts) {
memcpy(pos, p.begin(), BLOCK_SIZE);
pos += BLOCK_SIZE;
}
memcpy(pos, part.begin(), n);
pos += n;
if (nulTerminate) *pos++ = '\0';
KJ_ASSERT(pos == result.end());
return result;
} else {
parts.add(kj::mv(part));
}
}
}
} // namespace
String InputStream::readAllText() {
return String(readAll(*this, true).releaseAsChars());
}
Array<byte> InputStream::readAllBytes() {
return readAll(*this, false);
}
void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
for (auto piece: pieces) {
write(piece.begin(), piece.size());
......
......@@ -65,6 +65,10 @@ public:
virtual void skip(size_t bytes);
// Skips past the given number of bytes, discarding them. The default implementation read()s
// into a scratch buffer.
String readAllText();
Array<byte> readAllBytes();
// Read until EOF and return as one big byte array or string.
};
class OutputStream {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment