// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "serialize.h"
#include "layout.h"
#include <kj/debug.h>
#include <exception>

namespace capnp {

UnalignedFlatArrayMessageReader::UnalignedFlatArrayMessageReader(
    kj::ArrayPtr<const word> array, ReaderOptions options)
    : MessageReader(options), end(array.end()) {
  // Parses a flat message laid out as
  //   [segment table][segment 0 data][segment 1 data]...
  // where the segment table is a run of 32-bit values -- (segment count - 1), then the size in
  // words of each segment -- padded to a whole number of words.  Segments are stored as views
  // into `array`; nothing is copied.  If the buffer is truncated, a KJ_REQUIRE recovery block
  // returns early, keeping whatever was parsed so far and leaving `end` at array.end().

  if (array.size() == 0) {
    // Nothing to parse: treat as an empty message.
    return;
  }

  auto sizeTable = reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());

  // A stored count of 0xffffffff wraps numSegments to 0; segment 0 is still parsed below.
  const uint numSegments = sizeTable[0].get() + 1;

  // `offset` is the current position in words; segment data begins right after the table.
  size_t offset = numSegments / 2u + 1u;
  KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") {
    return;
  }

  // Segment 0 is kept in its own member rather than in moreSegments.
  {
    uint segmentSize = sizeTable[1].get();
    KJ_REQUIRE(array.size() >= offset + segmentSize,
               "Message ends prematurely in first segment.") {
      return;
    }
    segment0 = array.slice(offset, offset + segmentSize);
    offset += segmentSize;
  }

  if (numSegments > 1) {
    const uint extraCount = numSegments - 1;
    moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(extraCount);

    // Table entry (slot + 2) holds the size of segment (slot + 1).
    for (uint slot = 0; slot < extraCount; ++slot) {
      uint segmentSize = sizeTable[slot + 2].get();

      KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") {
        // Drop the partially-filled list so only segment 0 remains visible.
        moreSegments = nullptr;
        return;
      }

      moreSegments[slot] = array.slice(offset, offset + segmentSize);
      offset += segmentSize;
    }
  }

  // Remember where this message stopped so callers can locate any trailing data.
  end = array.begin() + offset;
}

size_t expectedSizeInWordsFromPrefix(kj::ArrayPtr<const word> array) {
  // Estimates the total size in words of the message that begins at `array`, using only the
  // segment-table entries already present in the prefix.  The estimate may grow as more of the
  // table becomes available.
  if (array.size() == 0) {
    // Every message occupies at least one word.
    return 1;
  }

  auto sizeTable = reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());
  uint numSegments = sizeTable[0].get() + 1;

  // Start with the words taken up by the segment table itself.
  size_t estimate = numSegments / 2u + 1u;

  // With array.size() words there are array.size() * 2 table entries visible, one of which is
  // the segment count, so at most that many minus one sizes can be summed.
  const uint knownSizes = kj::min(numSegments, array.size() * 2 - 1u);

  const _::WireValue<uint32_t>* sizes = sizeTable + 1;
  for (const _::WireValue<uint32_t>* entry = sizes; entry != sizes + knownSizes; ++entry) {
    estimate += entry->get();
  }
  return estimate;
}

kj::ArrayPtr<const word> UnalignedFlatArrayMessageReader::getSegment(uint id) {
  // Segment 0 lives in its own member; segment N (N >= 1) is moreSegments[N - 1].
  // Any id past the last segment yields a null/empty pointer.
  if (id == 0) return segment0;
  if (id > moreSegments.size()) return nullptr;
  return moreSegments[id - 1];
}

// Verifies that `array` starts on a pointer-size boundary and returns it unchanged, so it can be
// used inline when initializing the reader.  On failure KJ_REQUIRE reports the error; the message
// points callers at UnalignedFlatArrayMessageReader for unaligned input.
// NOTE(review): the check is modulo sizeof(void*), i.e. 4 bytes on 32-bit targets rather than
// the 8-byte size of a word -- confirm that is the intended requirement.
kj::ArrayPtr<const word> FlatArrayMessageReader::checkAlignment(kj::ArrayPtr<const word> array) {
  KJ_REQUIRE((uintptr_t)array.begin() % sizeof(void*) == 0,
      "Input to FlatArrayMessageReader is not aligned. If your architecture supports unaligned "
      "access (e.g. x86/x64/modern ARM), you may use UnalignedFlatArrayMessageReader instead, "
      "though this may harm performance.");

  return array;
}

kj::ArrayPtr<const word> initMessageBuilderFromFlatArrayCopy(
    kj::ArrayPtr<const word> array, MessageBuilder& target, ReaderOptions options) {
  // Parse the flat array in place, deep-copy its root into `target`, and hand back the words
  // that follow the parsed message (empty if the message filled the whole array).
  FlatArrayMessageReader parsed(array, options);
  auto root = parsed.getRoot<AnyPointer>();
  target.setRoot(root);

  const word* messageEnd = parsed.getEnd();
  return kj::arrayPtr(messageEnd, array.end());
}

kj::Array<word> messageToFlatArray(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  // Serializes `segments` into one contiguous, newly-allocated array using the stream framing:
  //   [segment count - 1][size of each segment][optional 32-bit zero padding][segment data...]
  // computeSerializedSizeInWords() rejects an empty segment list, which also keeps
  // `segments.size() - 1` below from wrapping.
  kj::Array<word> result = kj::heapArray<word>(computeSerializedSizeInWords(segments));

  _::WireValue<uint32_t>* table =
      reinterpret_cast<_::WireValue<uint32_t>*>(result.begin());

  // We write the segment count - 1 because this makes the first word zero for single-segment
  // messages, improving compression.  We don't bother doing this with segment sizes because
  // one-word segments are rare anyway.
  table[0].set(segments.size() - 1);

  for (uint i = 0; i < segments.size(); i++) {
    table[i + 1].set(segments[i].size());
  }

  if (segments.size() % 2 == 0) {
    // The table holds 1 + segments.size() 32-bit entries.  When that count is odd, add a zero
    // 32-bit padding entry so the segment data starts on a word boundary.
    table[segments.size() + 1].set(0);
  }

  word* dst = result.begin() + segments.size() / 2 + 1;

  for (const auto& segment: segments) {
    // Skip empty segments: their begin() may be null, and passing a null pointer to memcpy is
    // undefined behavior even when the byte count is zero.
    if (segment.size() == 0) continue;
    memcpy(dst, segment.begin(), segment.size() * sizeof(word));
    dst += segment.size();
  }

  KJ_DASSERT(dst == result.end(), "Buffer overrun/underrun bug in code above.");

  // Return the local by name so NRVO / implicit move applies; wrapping it in kj::mv() blocks NRVO.
  return result;
}

size_t computeSerializedSizeInWords(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");

  // The segment table holds one 32-bit count plus one 32-bit size per segment, rounded up to a
  // whole number of 64-bit words: segments.size() / 2 + 1 words in total.
  const size_t tableWords = segments.size() / 2 + 1;

  size_t dataWords = 0;
  for (size_t i = 0; i < segments.size(); ++i) {
    dataWords += segments[i].size();
  }

  return tableWords + dataWords;
}

// =======================================================================================

InputStreamMessageReader::InputStreamMessageReader(
    kj::InputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    : MessageReader(options), inputStream(inputStream), readPos(nullptr) {
  // Reads the segment table eagerly, then sizes a buffer: the caller's scratchSpace if it is big
  // enough, otherwise a heap array owned by this reader.  Single-segment messages are read in
  // full here.  For multi-segment messages only segment 0 is guaranteed to be read now;
  // `readPos` records how far the buffer has been filled so getSegment() can read the rest on
  // demand, and the destructor skips whatever was never read.

  // The first word holds (segment count - 1) and the size of segment 0.
  _::WireValue<uint32_t> firstWord[2];

  inputStream.read(firstWord, sizeof(firstWord));

  // A stored count of 0xffffffff wraps segmentCount to 0; in that case no segment 0 size is used.
  uint segmentCount = firstWord[0].get() + 1;
  uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get();

  size_t totalWords = segment0Size;

  // Reject messages with too many segments for security reasons.
  KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") {
    // Recovery path (runs only if the failure does not throw): fall back to a one-word,
    // single-segment message.  totalWords must be reset as well.  Otherwise it keeps the
    // untrusted size read above, and a buffer sized from it could be smaller than segment0
    // (e.g. 0 words), making the slice below out of bounds.
    segmentCount = 1;
    segment0Size = 1;
    totalWords = segment0Size;
    break;
  }

  // Read sizes for all segments except the first.  Include padding if necessary.
  // That is (segmentCount - 1) sizes plus one padding entry when segmentCount is even, which
  // equals segmentCount & ~1.
  KJ_STACK_ARRAY(_::WireValue<uint32_t>, moreSizes, segmentCount & ~1, 16, 64);
  if (segmentCount > 1) {
    inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]));
    for (uint i = 0; i < segmentCount - 1; i++) {
      totalWords += moreSizes[i].get();
    }
  }

  // Don't accept a message which the receiver couldn't possibly traverse without hitting the
  // traversal limit.  Without this check, a malicious client could transmit a very large segment
  // size to make the receiver allocate excessive space and possibly crash.
  KJ_REQUIRE(totalWords <= options.traversalLimitInWords,
             "Message is too large.  To increase the limit on the receiving end, see "
             "capnp::ReaderOptions.") {
    segmentCount = 1;
    segment0Size = kj::min(segment0Size, options.traversalLimitInWords);
    totalWords = segment0Size;
    break;
  }

  if (scratchSpace.size() < totalWords) {
    // TODO(perf):  Consider allocating each segment as a separate chunk to reduce memory
    //   fragmentation.
    ownedSpace = kj::heapArray<word>(totalWords);
    scratchSpace = ownedSpace;
  }

  segment0 = scratchSpace.slice(0, segment0Size);

  if (segmentCount > 1) {
    // Lay the remaining segments out back-to-back after segment 0 in the same buffer.
    moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1);
    size_t offset = segment0Size;

    for (uint i = 0; i < segmentCount - 1; i++) {
      uint segmentSize = moreSizes[i].get();
      moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize);
      offset += segmentSize;
    }
  }

  if (segmentCount == 1) {
    inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
  } else if (segmentCount > 1) {
    // Read at least all of segment 0 and at most the whole message; the rest is read lazily.
    readPos = scratchSpace.asBytes().begin();
    readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word));
  }
}

InputStreamMessageReader::~InputStreamMessageReader() noexcept(false) {
  // readPos stays null unless the constructor set up a lazy read, so there is nothing to drain.
  if (readPos == nullptr) {
    return;
  }

  // Consume the unread tail of the message so the stream is positioned at the next message.
  // Exceptions are suppressed only if we are already unwinding.
  unwindDetector.catchExceptionsIfUnwinding([&]() {
    // Lazy reads only happen for multi-segment messages, so moreSegments is non-empty.
    auto messageEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
    inputStream.skip(messageEnd - readPos);
  });
}

kj::ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
  // Ids beyond the last segment yield a null/empty pointer.
  if (id > moreSegments.size()) {
    return nullptr;
  }

  kj::ArrayPtr<const word> segment = (id == 0) ? segment0 : moreSegments[id - 1];

  if (readPos == nullptr) {
    // Either the whole message was read eagerly or no lazy read is pending.
    return segment;
  }

  // Lazy mode: make sure the stream has been read at least up to the end of this segment,
  // opportunistically reading up to the end of the whole message.
  auto wantedEnd = reinterpret_cast<const byte*>(segment.end());
  if (readPos < wantedEnd) {
    // Lazy reads only happen for multi-segment messages, so moreSegments is non-empty.
    auto messageEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
    readPos += inputStream.read(readPos, wantedEnd - readPos, messageEnd - readPos);
  }

  return segment;
}

void readMessageCopy(kj::InputStream& input, MessageBuilder& target,
                     ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
  // Read one message from `input` and deep-copy its root into `target`.  The temporary reader
  // (and any buffer it allocated) is destroyed on return.
  InputStreamMessageReader reader(input, options, scratchSpace);
  auto root = reader.getRoot<AnyPointer>();
  target.setRoot(root);
}

// -------------------------------------------------------------------

void writeMessage(kj::OutputStream& output, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");

  const size_t segmentCount = segments.size();

  // Segment table: one 32-bit entry for (count - 1), one per segment size, plus a zero padding
  // entry when needed so the table fills a whole number of 64-bit words.
  KJ_STACK_ARRAY(_::WireValue<uint32_t>, table, (segmentCount + 2) & ~size_t(1), 16, 64);

  // Storing count - 1 makes the first word all zeros for single-segment messages, which
  // compresses better.  Sizes are stored as-is since one-word segments are rare anyway.
  table[0].set(segmentCount - 1);
  uint slot = 1;
  for (auto& segment: segments) {
    table[slot++].set(segment.size());
  }
  if (segmentCount % 2 == 0) {
    table[slot].set(0);  // 32-bit padding entry
  }

  // Emit the table followed by every segment's bytes in a single gathered write.
  KJ_STACK_ARRAY(kj::ArrayPtr<const byte>, pieces, segmentCount + 1, 4, 32);
  pieces[0] = table.asBytes();
  for (size_t i = 0; i < segmentCount; ++i) {
    pieces[i + 1] = segments[i].asBytes();
  }

  output.write(pieces);
}

// =======================================================================================

// Out-of-line destructor with an empty body; declared noexcept(false).  Presumably all cleanup
// is done by base-class/member destructors -- confirm against the class declaration in
// serialize.h.
StreamFdMessageReader::~StreamFdMessageReader() noexcept(false) {}

void writeMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
  kj::FdOutputStream stream(fd);
  writeMessage(stream, segments);
}

void readMessageCopyFromFd(int fd, MessageBuilder& target,
                           ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
  kj::FdInputStream stream(fd);
  readMessageCopy(stream, target, options, scratchSpace);
}

}  // namespace capnp