diff --git a/c++/src/kj/compat/gzip-test.c++ b/c++/src/kj/compat/gzip-test.c++ index db83d8f0b6e2bf6de99c897863eb346d15dbd8dc..2c44aee068c36156fa28c2f81559115364376a19 100644 --- a/c++/src/kj/compat/gzip-test.c++ +++ b/c++/src/kj/compat/gzip-test.c++ @@ -36,11 +36,36 @@ static const byte FOOBAR_GZIP[] = { 0x00, 0x00, }; -class MockInputStream: public AsyncInputStream { +class MockInputStream: public InputStream { public: MockInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize) : bytes(bytes), blockSize(blockSize) {} + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // Clamp max read to blockSize. + size_t n = kj::min(blockSize, maxBytes); + + // Unless that's less than minBytes -- in which case, use minBytes. + n = kj::max(n, minBytes); + + // But also don't read more data than we have. + n = kj::min(n, bytes.size()); + + memcpy(buffer, bytes.begin(), n); + bytes = bytes.slice(n, bytes.size()); + return n; + } + +private: + kj::ArrayPtr<const byte> bytes; + size_t blockSize; +}; + +class MockAsyncInputStream: public AsyncInputStream { +public: + MockAsyncInputStream(kj::ArrayPtr<const byte> bytes, size_t blockSize) + : bytes(bytes), blockSize(blockSize) {} + Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { // Clamp max read to blockSize. size_t n = kj::min(blockSize, maxBytes); @@ -62,25 +87,66 @@ private: }; KJ_TEST("gzip decompression") { + // Normal read. + { + MockInputStream rawInput(FOOBAR_GZIP, kj::maxValue); + GzipInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText() == "foobar"); + } + + // Force read one byte at a time. + { + MockInputStream rawInput(FOOBAR_GZIP, 1); + GzipInputStream gzip(rawInput); + KJ_EXPECT(gzip.readAllText() == "foobar"); + } + + // Read truncated input. + { + MockInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue); + GzipInputStream gzip(rawInput); + + char text[16]; + size_t n = gzip.tryRead(text, 1, sizeof(text)); + text[n] = '\0'; + KJ_EXPECT(StringPtr(text, n) == "fo"); + + KJ_EXPECT_THROW_MESSAGE("gzip compressed stream ended prematurely", + gzip.tryRead(text, 1, sizeof(text))); + } + + // Read concatenated input. + { + Vector<byte> bytes; + bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP)); + bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP)); + MockInputStream rawInput(bytes, kj::maxValue); + GzipInputStream gzip(rawInput); + + KJ_EXPECT(gzip.readAllText() == "foobarfoobar"); + } +} + +KJ_TEST("async gzip decompression") { auto io = setupAsyncIo(); // Normal read. { - MockInputStream rawInput(FOOBAR_GZIP, kj::maxValue); + MockAsyncInputStream rawInput(FOOBAR_GZIP, kj::maxValue); GzipAsyncInputStream gzip(rawInput); KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar"); } // Force read one byte at a time. { - MockInputStream rawInput(FOOBAR_GZIP, 1); + MockAsyncInputStream rawInput(FOOBAR_GZIP, 1); GzipAsyncInputStream gzip(rawInput); KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobar"); } // Read truncated input. { - MockInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue); + MockAsyncInputStream rawInput(kj::arrayPtr(FOOBAR_GZIP, sizeof(FOOBAR_GZIP) / 2), kj::maxValue); GzipAsyncInputStream gzip(rawInput); char text[16]; @@ -97,19 +163,39 @@ KJ_TEST("gzip decompression") { Vector<byte> bytes; bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP)); bytes.addAll(ArrayPtr<const byte>(FOOBAR_GZIP)); - MockInputStream rawInput(bytes, kj::maxValue); + MockAsyncInputStream rawInput(bytes, kj::maxValue); GzipAsyncInputStream gzip(rawInput); KJ_EXPECT(gzip.readAllText().wait(io.waitScope) == "foobarfoobar"); } } -class MockOutputStream: public AsyncOutputStream { +class MockOutputStream: public OutputStream { public: kj::Vector<byte> bytes; - kj::String decompress(WaitScope& ws) { + kj::String decompress() { MockInputStream rawInput(bytes, kj::maxValue); + GzipInputStream gzip(rawInput); + return gzip.readAllText(); + } + + void write(const void* buffer, size_t size) override { + bytes.addAll(arrayPtr(reinterpret_cast<const byte*>(buffer), size)); + } + void write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { + for (auto& piece: pieces) { + bytes.addAll(piece); + } + } +}; + +class MockAsyncOutputStream: public AsyncOutputStream { +public: + kj::Vector<byte> bytes; + + kj::String decompress(WaitScope& ws) { + MockAsyncInputStream rawInput(bytes, kj::maxValue); GzipAsyncInputStream gzip(rawInput); return gzip.readAllText().wait(ws); } @@ -127,11 +213,73 @@ public: }; KJ_TEST("gzip compression") { + // Normal write. + { + MockOutputStream rawOutput; + { + GzipOutputStream gzip(rawOutput); + gzip.write("foobar", 6); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Multi-part write. + { + MockOutputStream rawOutput; + { + GzipOutputStream gzip(rawOutput); + gzip.write("foo", 3); + gzip.write("bar", 3); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } + + // Array-of-arrays write. + { + MockOutputStream rawOutput; + + { + GzipOutputStream gzip(rawOutput); + + ArrayPtr<const byte> pieces[] = { + kj::StringPtr("foo").asBytes(), + kj::StringPtr("bar").asBytes(), + }; + gzip.write(pieces); + } + + KJ_EXPECT(rawOutput.decompress() == "foobar"); + } +} + +KJ_TEST("gzip huge round trip") { + auto bytes = heapArray<byte>(65536); + for (auto& b: bytes) { + b = rand(); + } + + MockOutputStream rawOutput; + { + GzipOutputStream gzipOut(rawOutput); + gzipOut.write(bytes.begin(), bytes.size()); + } + + MockInputStream rawInput(rawOutput.bytes, kj::maxValue); + GzipInputStream gzipIn(rawInput); + auto decompressed = gzipIn.readAllBytes(); + + KJ_ASSERT(decompressed.size() == bytes.size()); + KJ_ASSERT(memcmp(bytes.begin(), decompressed.begin(), bytes.size()) == 0); +} + +KJ_TEST("async gzip compression") { auto io = setupAsyncIo(); // Normal write. { - MockOutputStream rawOutput; + MockAsyncOutputStream rawOutput; GzipAsyncOutputStream gzip(rawOutput); gzip.write("foobar", 6).wait(io.waitScope); gzip.end().wait(io.waitScope); @@ -141,7 +289,7 @@ KJ_TEST("gzip compression") { // Multi-part write. { - MockOutputStream rawOutput; + MockAsyncOutputStream rawOutput; GzipAsyncOutputStream gzip(rawOutput); gzip.write("foo", 3).wait(io.waitScope); gzip.write("bar", 3).wait(io.waitScope); @@ -152,7 +300,7 @@ KJ_TEST("gzip compression") { // Array-of-arrays write. { - MockOutputStream rawOutput; + MockAsyncOutputStream rawOutput; GzipAsyncOutputStream gzip(rawOutput); ArrayPtr<const byte> pieces[] = { @@ -166,7 +314,7 @@ KJ_TEST("gzip compression") { } } -KJ_TEST("gzip huge round trip") { +KJ_TEST("async gzip huge round trip") { auto io = setupAsyncIo(); auto bytes = heapArray<byte>(65536); @@ -174,12 +322,12 @@ KJ_TEST("gzip huge round trip") { b = rand(); } - MockOutputStream rawOutput; + MockAsyncOutputStream rawOutput; GzipAsyncOutputStream gzipOut(rawOutput); gzipOut.write(bytes.begin(), bytes.size()).wait(io.waitScope); gzipOut.end().wait(io.waitScope); - MockInputStream rawInput(rawOutput.bytes, kj::maxValue); + MockAsyncInputStream rawInput(rawOutput.bytes, kj::maxValue); GzipAsyncInputStream gzipIn(rawInput); auto decompressed = gzipIn.readAllBytes().wait(io.waitScope); diff --git a/c++/src/kj/compat/gzip.c++ b/c++/src/kj/compat/gzip.c++ index 6fbbf48ade501358dfa809d3b0415e14b92e60c5..a8e8f475ad005c9db6f07b1a576d5b68357c8b94 100644 --- a/c++/src/kj/compat/gzip.c++ +++ b/c++/src/kj/compat/gzip.c++ @@ -26,6 +26,138 @@ namespace kj { +GzipInputStream::GzipInputStream(InputStream& inner) + : inner(inner) { + memset(&ctx, 0, sizeof(ctx)); + ctx.next_in = nullptr; + ctx.avail_in = 0; + ctx.next_out = nullptr; + ctx.avail_out = 0; + + // windowBits = 15 (maximum) + magic value 16 to ask for gzip. + KJ_ASSERT(inflateInit2(&ctx, 15 + 16) == Z_OK); +} + +GzipInputStream::~GzipInputStream() noexcept(false) { + inflateEnd(&ctx); +} + +size_t GzipInputStream::tryRead(void* out, size_t minBytes, size_t maxBytes) { + if (maxBytes == 0) return size_t(0); + + return readImpl(reinterpret_cast<byte*>(out), minBytes, maxBytes, 0); +} + +size_t GzipInputStream::readImpl( + byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) { + if (ctx.avail_in == 0) { + size_t amount = inner.tryRead(buffer, 1, sizeof(buffer)); + if (amount == 0) { + if (!atValidEndpoint) { + KJ_FAIL_REQUIRE("gzip compressed stream ended prematurely"); + } + return alreadyRead; + } else { + ctx.next_in = buffer; + ctx.avail_in = amount; + } + } + + ctx.next_out = reinterpret_cast<byte*>(out); + ctx.avail_out = maxBytes; + + auto inflateResult = inflate(&ctx, Z_NO_FLUSH); + atValidEndpoint = inflateResult == Z_STREAM_END; + if (inflateResult == Z_OK || inflateResult == Z_STREAM_END) { + if (atValidEndpoint && ctx.avail_in > 0) { + // There's more data available. Assume start of new content. + KJ_ASSERT(inflateReset(&ctx) == Z_OK); + } + + size_t n = maxBytes - ctx.avail_out; + if (n >= minBytes) { + return n + alreadyRead; + } else { + return readImpl(out + n, minBytes - n, maxBytes - n, alreadyRead + n); + } + } else { + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE("gzip decompression failed", inflateResult); + } else { + KJ_FAIL_REQUIRE("gzip decompression failed", ctx.msg); + } + } +} + +// ======================================================================================= + +GzipOutputStream::GzipOutputStream(OutputStream& inner, int compressionLevel) + : inner(inner) { + memset(&ctx, 0, sizeof(ctx)); + ctx.next_in = nullptr; + ctx.avail_in = 0; + ctx.next_out = nullptr; + ctx.avail_out = 0; + + int initResult = + deflateInit2(&ctx, compressionLevel, Z_DEFLATED, + 15 + 16, // windowBits = 15 (maximum) + magic value 16 to ask for gzip. + 8, // memLevel = 8 (the default) + Z_DEFAULT_STRATEGY); + KJ_ASSERT(initResult == Z_OK, initResult); +} + +GzipOutputStream::~GzipOutputStream() noexcept(false) { + KJ_DEFER(deflateEnd(&ctx)); + + for (;;) { + ctx.next_out = buffer; + ctx.avail_out = sizeof(buffer); + + auto deflateResult = deflate(&ctx, Z_FINISH); + if (deflateResult != Z_OK && deflateResult != Z_STREAM_END) { + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE("gzip compression failed", deflateResult); + } else { + KJ_FAIL_REQUIRE("gzip compression failed", ctx.msg); + } + } + + size_t n = sizeof(buffer) - ctx.avail_out; + inner.write(buffer, n); + if (deflateResult == Z_STREAM_END) { + break; + } + } +} + +void GzipOutputStream::write(const void* in, size_t size) { + ctx.next_in = const_cast<byte*>(reinterpret_cast<const byte*>(in)); + ctx.avail_in = size; + pump(); +} + +void GzipOutputStream::pump() { + while (ctx.avail_in > 0) { + ctx.next_out = buffer; + ctx.avail_out = sizeof(buffer); + + auto deflateResult = deflate(&ctx, Z_NO_FLUSH); + if (deflateResult != Z_OK) { + if (ctx.msg == nullptr) { + KJ_FAIL_REQUIRE("gzip compression failed", deflateResult); + } else { + KJ_FAIL_REQUIRE("gzip compression failed", ctx.msg); + } + } + + size_t n = sizeof(buffer) - ctx.avail_out; + inner.write(buffer, n); + } +} + +// ======================================================================================= + GzipAsyncInputStream::GzipAsyncInputStream(AsyncInputStream& inner) : inner(inner) { memset(&ctx, 0, sizeof(ctx)); diff --git a/c++/src/kj/compat/gzip.h b/c++/src/kj/compat/gzip.h index 9a08d045e390b40f339d0fc1ad6ed637a436fae3..962a595eff77ad9bcc84b008e9f4b4f91d865dfb 100644 --- a/c++/src/kj/compat/gzip.h +++ b/c++/src/kj/compat/gzip.h @@ -21,11 +21,48 @@ #pragma once +#include <kj/io.h> #include <kj/async-io.h> #include <zlib.h> namespace kj { +class GzipInputStream final: public InputStream { +public: + GzipInputStream(InputStream& inner); + ~GzipInputStream() noexcept(false); + KJ_DISALLOW_COPY(GzipInputStream); + + size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + +private: + InputStream& inner; + z_stream ctx; + bool atValidEndpoint = false; + + byte buffer[4096]; + + size_t readImpl(byte* buffer, size_t minBytes, size_t maxBytes, size_t alreadyRead); +}; + +class GzipOutputStream final: public OutputStream { +public: + GzipOutputStream(OutputStream& inner, int compressionLevel = Z_DEFAULT_COMPRESSION); + ~GzipOutputStream() noexcept(false); + KJ_DISALLOW_COPY(GzipOutputStream); + + void write(const void* buffer, size_t size) override; + using OutputStream::write; + +private: + OutputStream& inner; + z_stream ctx; + + byte buffer[4096]; + + void pump(); +}; + class GzipAsyncInputStream final: public AsyncInputStream { public: GzipAsyncInputStream(AsyncInputStream& inner);