// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors // Licensed under the MIT License: // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
#include "serialize-packed.h"
#include <kj/debug.h>
#include "layout.h"
#include <vector>

namespace capnp {

namespace _ {  // private

// Wraps `inner`; all packed bytes are pulled from it through its read buffer.
PackedInputStream::PackedInputStream(kj::BufferedInputStream& inner): inner(inner) {}
PackedInputStream::~PackedInputStream() noexcept(false) {}

// Unpacks data from the inner stream into `dst`.
//
// Each packed word starts with a tag byte whose bit i says whether byte i of the unpacked word
// is present in the input (bit set) or is zero (bit clear). A tag of 0x00 is followed by a count
// of additional all-zero words; a tag of 0xff is followed by a count of words that are stored
// verbatim.
//
// Returns the number of unpacked bytes written to `dst`. Returns 0 immediately if `maxBytes` is
// zero or if the inner stream has no data available on entry. Once at least `minBytes` have been
// produced, the function returns as soon as the current input buffer runs low instead of
// fetching another one. `minBytes` and `maxBytes` must be multiples of the word size (checked in
// debug builds only).
size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
  if (maxBytes == 0) {
    return 0;
  }

  KJ_DREQUIRE(minBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");
  KJ_DREQUIRE(maxBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");

  uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(dst);
  uint8_t* const outEnd = reinterpret_cast<uint8_t*>(dst) + maxBytes;
  uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;

  kj::ArrayPtr<const byte> buffer = inner.tryGetReadBuffer();
  if (buffer.size() == 0) {
    return 0;
  }
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

  // Consumes the whole current buffer from the inner stream and fetches the next one. If the
  // next buffer is empty, the requirement fails and (in recoverable mode) we return the number
  // of bytes produced so far.
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
    return out - reinterpret_cast<uint8_t*>(dst); \
  } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

  // These two are also used by skip() below.
#define BUFFER_END (reinterpret_cast<const uint8_t*>(buffer.end()))
#define BUFFER_REMAINING ((size_t)(BUFFER_END - in))

  for (;;) {
    uint8_t tag;

    KJ_DASSERT((out - reinterpret_cast<uint8_t*>(dst)) % sizeof(word) == 0,
               "Output pointer should always be aligned here.");

    // The fast path reads a tag, up to eight data bytes and one run-length byte without any
    // bounds checks, hence the threshold of 10.
    if (BUFFER_REMAINING < 10) {
      if (out >= outMin) {
        // We read at least the minimum amount, so go ahead and return.
        inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          *out++ = *in++;
        } else {
          *out++ = 0;
        }
      }

      // The run-length byte following a 0x00 or 0xff tag is read unconditionally below, so make
      // sure it is available.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

      // Branch-free: writes either *in or 0 and advances `in` only if the tag bit is set.
#define HANDLE_BYTE(n) \
      { \
        bool isNonzero = (tag & (1u << n)) != 0; \
        *out++ = *in & (-(int8_t)isNonzero); \
        in += isNonzero; \
      }
      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE
    }

    if (tag == 0) {
      // Run of additional all-zero words.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary.") {
        return out - reinterpret_cast<uint8_t*>(dst);
      }
      memset(out, 0, runLength);
      out += runLength;

    } else if (tag == 0xffu) {
      // Run of words stored verbatim.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= outEnd - out,
                 "Packed input did not end cleanly on a segment boundary.") {
        return out - reinterpret_cast<uint8_t*>(dst);
      }

      // Kept as size_t so that a very large read buffer is not truncated before the comparison.
      size_t inRemaining = BUFFER_REMAINING;
      if (inRemaining >= runLength) {
        // Fast path.
        memcpy(out, in, runLength);
        out += runLength;
        in += runLength;
      } else {
        // Copy over the first buffer, then do one big read for the rest.
        memcpy(out, in, inRemaining);
        out += inRemaining;
        runLength -= inRemaining;

        inner.skip(buffer.size());
        inner.read(out, runLength);
        out += runLength;

        if (out == outEnd) {
          return maxBytes;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (out == outEnd) {
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return maxBytes;
    }
  }

  KJ_FAIL_ASSERT("Can't get here.");
  return 0;  // GCC knows KJ_FAIL_ASSERT doesn't return, but Eclipse CDT still warns...

#undef REFRESH_BUFFER
}

// Advances the stream past `bytes` bytes of *unpacked* data without producing any output.
//
// The packed input still has to be decoded tag by tag to find out how many packed bytes that
// corresponds to. `bytes` must be a multiple of the word size (checked in debug builds only).
void PackedInputStream::skip(size_t bytes) {
  // We can't just read into buffers because buffers must end on block boundaries.

  if (bytes == 0) {
    return;
  }

  KJ_DREQUIRE(bytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");

  kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());

  // Same as the macro in tryRead(), except that there is no byte count to return on failure.
#define REFRESH_BUFFER() \
  inner.skip(buffer.size()); \
  buffer = inner.getReadBuffer(); \
  KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
  in = reinterpret_cast<const uint8_t*>(buffer.begin())

  for (;;) {
    uint8_t tag;

    if (BUFFER_REMAINING < 10) {
      if (BUFFER_REMAINING == 0) {
        REFRESH_BUFFER();
        continue;
      }

      // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
      // check on each byte.

      tag = *in++;

      for (uint i = 0; i < 8; i++) {
        if (tag & (1u << i)) {
          if (BUFFER_REMAINING == 0) {
            REFRESH_BUFFER();
          }
          in++;
        }
      }
      bytes -= 8;

      // The run-length byte following a 0x00 or 0xff tag is read unconditionally below.
      if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
        REFRESH_BUFFER();
      }
    } else {
      tag = *in++;

      // Advance past one packed byte for every bit set in the tag.
#define HANDLE_BYTE(n) \
      in += (tag & (1u << n)) != 0
      HANDLE_BYTE(0);
      HANDLE_BYTE(1);
      HANDLE_BYTE(2);
      HANDLE_BYTE(3);
      HANDLE_BYTE(4);
      HANDLE_BYTE(5);
      HANDLE_BYTE(6);
      HANDLE_BYTE(7);
#undef HANDLE_BYTE

      bytes -= 8;
    }

    if (tag == 0) {
      // Run of additional all-zero words: nothing more to consume from the input.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
        return;
      }

      bytes -= runLength;

    } else if (tag == 0xffu) {
      // Run of words stored verbatim: skip the same number of packed bytes.
      KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");

      uint runLength = *in++ * sizeof(word);

      KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
        return;
      }

      bytes -= runLength;

      // Kept as size_t so that a very large read buffer is not truncated before the comparison.
      size_t inRemaining = BUFFER_REMAINING;
      if (inRemaining > runLength) {
        // Fast path.
        in += runLength;
      } else {
        // Forward skip to the underlying stream.
        runLength -= inRemaining;
        inner.skip(buffer.size() + runLength);

        if (bytes == 0) {
          return;
        } else {
          buffer = inner.getReadBuffer();
          in = reinterpret_cast<const uint8_t*>(buffer.begin());

          // Skip the bounds check below since we just did the same check above.
          continue;
        }
      }
    }

    if (bytes == 0) {
      inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
      return;
    }
  }

  KJ_FAIL_ASSERT("Can't get here.");

  // Don't let the helper macros leak into the rest of the file.
#undef REFRESH_BUFFER
#undef BUFFER_REMAINING
#undef BUFFER_END
}

// -------------------------------------------------------------------

// Wraps `inner`; packed bytes are written into its write buffer where possible.
PackedOutputStream::PackedOutputStream(kj::BufferedOutputStream& inner)
    : inner(inner) {}
PackedOutputStream::~PackedOutputStream() noexcept(false) {}

// Packs `size` bytes starting at `src` and writes the result to the inner stream.
//
// Input is consumed eight bytes at a time. For each word a tag byte is emitted whose bit i is set
// iff byte i is non-zero, followed by the non-zero bytes only. Tags 0x00 and 0xff are each
// followed by a one-byte run count, as described in the branches below.
//
// NOTE(review): the zero-run scan reads the input through a `uint64_t*` and every iteration
// consumes a full eight bytes, so this appears to assume that `src` is word-aligned and `size` is
// a multiple of the word size -- confirm against callers.
void PackedOutputStream::write(const void* src, size_t size) {
  kj::ArrayPtr<byte> buffer = inner.getWriteBuffer();
  byte slowBuffer[20];

  uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(buffer.begin());

  const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(src);
  const uint8_t* const inEnd = reinterpret_cast<const uint8_t*>(src) + size;

  while (in < inEnd) {
    if (reinterpret_cast<uint8_t*>(buffer.end()) - out < 10) {
      // Oops, we're out of space.  We need at least 10 bytes for the fast path, since we don't
      // bounds-check on every byte.

      // Write what we have so far.
      inner.write(buffer.begin(), out - reinterpret_cast<uint8_t*>(buffer.begin()));

      // Use a slow buffer into which we'll encode 10 to 20 bytes.  This should get us past the
      // output stream's buffer boundary.
      buffer = kj::arrayPtr(slowBuffer, sizeof(slowBuffer));
      out = reinterpret_cast<uint8_t*>(buffer.begin());
    }

    // Reserve the tag byte; it is filled in once all eight bytes have been examined.
    uint8_t* tagPos = out++;

#define HANDLE_BYTE(n) \
    uint8_t bit##n = *in != 0; \
    *out = *in; \
    out += bit##n; /* out only advances if the byte was non-zero */ \
    ++in

    HANDLE_BYTE(0);
    HANDLE_BYTE(1);
    HANDLE_BYTE(2);
    HANDLE_BYTE(3);
    HANDLE_BYTE(4);
    HANDLE_BYTE(5);
    HANDLE_BYTE(6);
    HANDLE_BYTE(7);
#undef HANDLE_BYTE

    uint8_t tag = (bit0 << 0) | (bit1 << 1) | (bit2 << 2) | (bit3 << 3)
                | (bit4 << 4) | (bit5 << 5) | (bit6 << 6) | (bit7 << 7);
    *tagPos = tag;

    if (tag == 0) {
      // An all-zero word is followed by a count of consecutive zero words (not including the
      // first one).

      // We can check a whole word at a time.
      const uint64_t* inWord = reinterpret_cast<const uint64_t*>(in);

      // The count must fit in 1 byte, so limit to 255 words.
      const uint64_t* limit = reinterpret_cast<const uint64_t*>(inEnd);
      if (limit - inWord > 255) {
        limit = inWord + 255;
      }

      while (inWord < limit && *inWord == 0) {
        ++inWord;
      }

      // Write the count.
      *out++ = inWord - reinterpret_cast<const uint64_t*>(in);

      // Advance input.
      in = reinterpret_cast<const uint8_t*>(inWord);

    } else if (tag == 0xffu) {
      // An all-nonzero word is followed by a count of consecutive uncompressed words, followed
      // by the uncompressed words themselves.

      // Count the number of consecutive words in the input which have no more than a single
      // zero-byte.  We look for at least two zeros because that's the point where our compression
      // scheme becomes a net win.
      // TODO(perf):  Maybe look for three zeros?  Compressing a two-zero word is a loss if the
      //   following word has no zeros.
      const uint8_t* runStart = in;

      const uint8_t* limit = inEnd;
      if ((size_t)(limit - in) > 255 * sizeof(word)) {
        limit = in + 255 * sizeof(word);
      }

      while (in < limit) {
        // Check eight input bytes for zeros.
        uint c = *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;
        c += *in++ == 0;

        if (c >= 2) {
          // Un-read the word with multiple zeros, since we'll want to compress that one.
          in -= 8;
          break;
        }
      }

      // Write the count.
      uint count = in - runStart;
      *out++ = count / sizeof(word);

      if (count <= reinterpret_cast<uint8_t*>(buffer.end()) - out) {
        // There's enough space to memcpy.
        memcpy(out, runStart, count);
        out += count;
      } else {
        // Input overruns the output buffer.  We'll give it to the output stream in one chunk
        // and let it decide what to do.
        inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
        inner.write(runStart, in - runStart);
        buffer = inner.getWriteBuffer();
        out = reinterpret_cast<uint8_t*>(buffer.begin());
      }
    }
  }

  // Write whatever is left.
  inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
}

}  // namespace _ (private)

// =======================================================================================

// Sets up the PackedInputStream base over `inputStream`, then hands that unpacking stream to
// InputStreamMessageReader together with `options` and `scratchSpace`.
PackedMessageReader::PackedMessageReader(
    kj::BufferedInputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : PackedInputStream(inputStream),
      InputStreamMessageReader(static_cast<PackedInputStream&>(*this), options, scratchSpace) {}

PackedMessageReader::~PackedMessageReader() noexcept(false) {}

// Chains the bases together: an FdInputStream over `fd`, a BufferedInputStreamWrapper over that,
// and a PackedMessageReader over the buffered wrapper.
PackedFdMessageReader::PackedFdMessageReader(
    int fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : FdInputStream(fd),
      BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
      PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
                          options, scratchSpace) {}

// Same as above, except that the kj::AutoCloseFd is moved into the FdInputStream base.
PackedFdMessageReader::PackedFdMessageReader(
    kj::AutoCloseFd fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : FdInputStream(kj::mv(fd)),
      BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
      PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
                          options, scratchSpace) {}

PackedFdMessageReader::~PackedFdMessageReader() noexcept(false) {}

// Writes `segments` to `output` by routing writeMessage() through a PackedOutputStream.
void writePackedMessage(kj::BufferedOutputStream& output,
                        kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  _::PackedOutputStream packedOutput(output);
  writeMessage(packedOutput, segments);
}

// Overload for streams not statically known to be buffered. If `output` turns out to be a
// kj::BufferedOutputStream it is used directly; otherwise it is wrapped with a temporary
// 8192-byte stack buffer for the duration of the call.
void writePackedMessage(kj::OutputStream& output,
                        kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  KJ_IF_MAYBE(bufferedOutputPtr, kj::dynamicDowncastIfAvailable<kj::BufferedOutputStream>(output)) {
    writePackedMessage(*bufferedOutputPtr, segments);
  } else {
    byte buffer[8192];
    kj::BufferedOutputStreamWrapper bufferedOutput(output, kj::arrayPtr(buffer, sizeof(buffer)));
    writePackedMessage(bufferedOutput, segments);
  }
}

// Writes `segments` in packed form to the file descriptor `fd` through a kj::FdOutputStream.
void writePackedMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  kj::FdOutputStream output(fd);
  writePackedMessage(output, segments);
}

// Returns the number of words that `packedBytes` would occupy once unpacked, without actually
// unpacking anything. Fails a KJ_REQUIRE with "invalid packed data" if the input is truncated.
size_t computeUnpackedSizeInWords(kj::ArrayPtr<const byte> packedBytes) {
  const byte* ptr = packedBytes.begin();
  const byte* end = packedBytes.end();

  size_t total = 0;
  while (ptr < end) {
    uint tag = *ptr;
    size_t count = kj::popCount(tag);
    total += 1;

    // `ptr` still points at the tag, so the tag byte plus `count` data bytes must all be present.
    // A non-strict comparison here would accept input that is one byte short and move `ptr`
    // beyond one-past-the-end.
    KJ_REQUIRE(end - ptr > count, "invalid packed data");
    ptr += count + 1;

    if (tag == 0) {
      // Followed by the number of additional all-zero words.
      KJ_REQUIRE(ptr < end, "invalid packed data");
      total += *ptr++;
    } else if (tag == 0xff) {
      // Followed by the number of words stored verbatim, then those words.
      KJ_REQUIRE(ptr < end, "invalid packed data");
      size_t words = *ptr++;
      total += words;
      size_t bytes = words * sizeof(word);
      KJ_REQUIRE(end - ptr >= bytes, "invalid packed data");
      ptr += bytes;
    }
  }

  return total;
}

}  // namespace capnp