// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#define CAPNP_TESTING_CAPNP 1

#include "rpc-twoparty.h"
#include "test-util.h"
#include <capnp/rpc.capnp.h>
#include <kj/debug.h>
#include <kj/thread.h>
#include <kj/compat/gtest.h>

// TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp {
namespace rpc {
// Renders an rpc::Message union discriminant as its underlying numeric value, giving KJ's
// string helpers a way to print `Message::Which`.
inline kj::String KJ_STRINGIFY(Message::Which which) {
  const auto discriminant = static_cast<uint16_t>(which);
  return kj::str(discriminant);
}
}  // namespace rpc
}  // namespace capnp

namespace capnp {
namespace _ {
namespace {

// SturdyRefRestorer used by the test server. Each TestSturdyRefObjectId tag is mapped onto a
// newly heap-allocated test implementation; the implementations receive references to the
// counters supplied at construction time.
class TestRestorer final: public SturdyRefRestorer<test::TestSturdyRefObjectId> {
  using Tag = test::TestSturdyRefObjectId::Tag;

public:
  // Both counters are held by reference and must outlive this restorer.
  TestRestorer(int& calls, int& handles): callCount(calls), handleCount(handles) {}

  // Returns the capability corresponding to the tag stored in `objectId`.
  Capability::Client restore(test::TestSturdyRefObjectId::Reader objectId) override {
    const auto tag = objectId.getTag();
    switch (tag) {
      case Tag::TEST_INTERFACE:   return kj::heap<TestInterfaceImpl>(callCount);
      case Tag::TEST_EXTENDS:
        // There is no TestExtends implementation here; hand back a broken capability instead.
        return Capability::Client(newBrokenCap("No TestExtends implemented."));
      case Tag::TEST_PIPELINE:    return kj::heap<TestPipelineImpl>(callCount);
      case Tag::TEST_TAIL_CALLEE: return kj::heap<TestTailCalleeImpl>(callCount);
      case Tag::TEST_TAIL_CALLER: return kj::heap<TestTailCallerImpl>(callCount);
      case Tag::TEST_MORE_STUFF:  return kj::heap<TestMoreStuffImpl>(callCount, handleCount);
    }
    // Every enumerant above returns, so falling out of the switch means an unknown tag.
    KJ_UNREACHABLE;
  }

private:
  int& callCount;    // Passed to every implementation created by restore().
  int& handleCount;  // Passed only to TestMoreStuffImpl.
};

// Launches the test RPC server on a new pipe thread and returns the PipeThread handle.
//
// The thread serves the SERVER side of a two-party network over its end of the pipe, restoring
// capabilities through a TestRestorer bound to `callCount` and `handleCount`, and runs until the
// network reports a disconnect. Both counters are captured by reference, so they must outlive
// the returned thread.
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider,
                                          int& callCount, int& handleCount) {
  auto serverMain = [&callCount, &handleCount](
      kj::AsyncIoProvider&, kj::AsyncIoStream& pipeEnd, kj::WaitScope& threadScope) {
    TwoPartyVatNetwork vatNetwork(pipeEnd, rpc::twoparty::Side::SERVER);
    TestRestorer restorer(callCount, handleCount);

    // `rpcServer` is not referenced again; presumably it only needs to remain in scope so the
    // RPC system keeps servicing the connection while we block below.
    auto rpcServer = makeRpcServer(vatNetwork, restorer);
    vatNetwork.onDisconnect().wait(threadScope);
  };

  return ioProvider.newPipeThread(kj::mv(serverMain));
}

// Restores the capability identified by `tag` from the vat on the given `side`, using the RPC
// system `client`. The result is untyped; callers cast it to the interface they expect.
Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::VatId>& client,
                                    rpc::twoparty::Side side,
                                    test::TestSturdyRefObjectId::Tag tag) {
  // Scratch message holding the VatId that names the peer.
  MallocMessageBuilder vatIdScratch(8);
  auto vatId = vatIdScratch.initRoot<rpc::twoparty::VatId>();
  vatId.setSide(side);

  // Scratch message holding the SturdyRefObjectId that names the object.
  MallocMessageBuilder objectIdScratch(8);
  auto objectId = objectIdScratch.initRoot<test::TestSturdyRefObjectId>();
  objectId.setTag(tag);

  // restore() takes the object ID as an AnyPointer, so re-read the root untyped.
  auto untypedObjectId = objectIdScratch.getRoot<AnyPointer>();
  return client.restore(vatId, untypedObjectId);
}

TEST(TwoPartyNetwork, Basic) {
  // Smoke test: restore a TestInterface from the server thread, issue three calls back to back,
  // and check the results once they have all been waited on.
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;
  auto& waitScope = ioContext.waitScope;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Ask the server for its TestInterface capability.
  auto testInterface = getPersistentCap(
      rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_INTERFACE).castAs<test::TestInterface>();

  // Call #1: foo().
  auto fooReq = testInterface.fooRequest();
  fooReq.setI(123);
  fooReq.setJ(true);
  auto fooPromise = fooReq.send();

  // Call #2: baz(), carrying the standard test message.
  auto bazReq = testInterface.bazRequest();
  initTestMessage(bazReq.initS());
  auto bazPromise = bazReq.send();

  // Call #3: bar(). The success branch is a test failure; the error branch sets the flag.
  bool barFailed = false;
  auto barReq = testInterface.barRequest();
  auto barPromise = barReq.send().then(
      [](Response<test::TestInterface::BarResults>&&) {
        ADD_FAILURE() << "Expected bar() call to fail.";
      },
      [&barFailed](kj::Exception&&) {
        barFailed = true;
      });

  // Nothing has been waited on yet, so no call is expected to have been counted.
  EXPECT_EQ(0, callCount);

  auto response1 = fooPromise.wait(waitScope);

  EXPECT_EQ("foo", response1.getX());

  // The baz() response is kept in a named local so it lives until the end of the test.
  auto bazResponse = bazPromise.wait(waitScope);

  barPromise.wait(waitScope);

  // Two calls are expected to have been counted, and bar() must have taken the error path.
  EXPECT_EQ(2, callCount);
  EXPECT_TRUE(barFailed);
}

TEST(TwoPartyNetwork, Pipelining) {
  // Exercises promise pipelining over a two-party connection: calls are issued on a capability
  // taken from a not-yet-resolved getCap() result. Afterwards the connection is shut down and
  // the same sequence of calls is expected to fail without reaching either side.

  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;
  int reverseCallCount = 0;  // Calls back from server to client.

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // `disconnected` flips to true when the client-side network reports a disconnect.
  bool disconnected = false;
  kj::Promise<void> disconnectPromise = network.onDisconnect().then([&]() { disconnected = true; });

  {
    // Request the particular capability from the server.
    auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
        test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();

    {
      // Use the capability.
      auto request = client.getCapRequest();
      request.setN(234);
      // Pass a capability hosted on this side; it is constructed with reverseCallCount.
      request.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto promise = request.send();

      // Pipeline two calls on the capability inside the still-pending result of getCap().
      auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
      pipelineRequest.setI(321);
      auto pipelinePromise = pipelineRequest.send();

      auto pipelineRequest2 = promise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto pipelinePromise2 = pipelineRequest2.send();

      promise = nullptr;  // Just to be annoying, drop the original promise.

      // Nothing has been waited on yet, so neither side is expected to have counted a call.
      EXPECT_EQ(0, callCount);
      EXPECT_EQ(0, reverseCallCount);

      auto response = pipelinePromise.wait(ioContext.waitScope);
      EXPECT_EQ("bar", response.getX());

      auto response2 = pipelinePromise2.wait(ioContext.waitScope);
      checkTestMessage(response2);

      // Expected totals: three calls counted on the server and one callback to the client.
      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }

    EXPECT_FALSE(disconnected);

    // What if we disconnect?
    serverThread.pipe->shutdownWrite();

    // The other side should also disconnect.
    disconnectPromise.wait(ioContext.waitScope);

    {
      // Use the now-broken capability.
      auto request = client.getCapRequest();
      request.setN(234);
      request.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto promise = request.send();

      auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
      pipelineRequest.setI(321);
      auto pipelinePromise = pipelineRequest.send();

      auto pipelineRequest2 = promise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto pipelinePromise2 = pipelineRequest2.send();

      // Both pipelined calls must now fail...
      EXPECT_ANY_THROW(pipelinePromise.wait(ioContext.waitScope));
      EXPECT_ANY_THROW(pipelinePromise2.wait(ioContext.waitScope));

      // ...and the counters must be unchanged from before the disconnect.
      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }
  }
}

TEST(TwoPartyNetwork, Release) {
  // Verifies that dropping client-side handle references is eventually reflected in the
  // server-side handleCount without any further calls being made.

  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Request the particular capability from the server.
  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  // handle1 comes from a call whose promise is consumed immediately; for handle2 the promise
  // object is kept in a local so that it can be dropped separately further down.
  auto handle1 = client.getHandleRequest().send().wait(ioContext.waitScope).getHandle();
  auto promise = client.getHandleRequest().send();
  auto handle2 = promise.wait(ioContext.waitScope).getHandle();

  EXPECT_EQ(2, handleCount);

  handle1 = nullptr;

  // There once was a bug where the last outgoing message (and any capabilities attached) would
  // not get cleaned up (until a new message was sent). This appeared to be a bug in Release,
  // because if a client received a message and then released a capability from it but then did
  // not make any further calls, then the capability would not be released because the message
  // introducing it remained the last server -> client message (because a "Release" message has
  // no reply). Here we are explicitly trying to catch this bug. This proves tricky, because when
  // we drop a reference on the client side, there's no particular way to wait for the release
  // message to reach the server except to make a subsequent call and wait for the return -- but
  // that would mask the bug. So, we spin-wait for handleCount to change.

  // Spin budget shared by both loops below: at most 1000 timer waits of 10ms each in total.
  uint maxSpins = 1000;

  while (handleCount > 1) {
    ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
    // Fail hard rather than spin forever if the release never shows up.
    KJ_ASSERT(--maxSpins > 0);
  }
  EXPECT_EQ(1, handleCount);

  handle2 = nullptr;

  // The test expects the count to stay at 1 after handle2 is dropped; presumably `promise`
  // still keeps the second handle referenced until it is released below.
  ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
  EXPECT_EQ(1, handleCount);

  promise = nullptr;

  while (handleCount > 0) {
    ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
    KJ_ASSERT(--maxSpins > 0);
  }
  EXPECT_EQ(0, handleCount);
}

TEST(TwoPartyNetwork, Abort) {
  // Verify that aborts are received.
  //
  // This test bypasses the RPC system on the client side: it talks to the server through a raw
  // VatNetwork connection so that it can inject a protocol-violating message by hand.

  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);

  // Build the VatId addressing the server side of the connection.
  MallocMessageBuilder refMessage(128);
  auto hostId = refMessage.initRoot<rpc::twoparty::VatId>();
  hostId.setSide(rpc::twoparty::Side::SERVER);

  auto conn = KJ_ASSERT_NONNULL(network.connect(hostId));

  {
    // Send an invalid message (Return to non-existent question).
    auto msg = conn->newOutgoingMessage(128);
    auto body = msg->getBody().initAs<rpc::Message>().initReturn();
    body.setAnswerId(1234);
    body.setCanceled();
    msg->send();
  }

  // The first incoming message is expected to be an Abort...
  auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope));
  EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());

  // ...and the next receive is expected to yield null, i.e. no further message arrives.
  EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr);
}

TEST(TwoPartyNetwork, ConvenienceClasses) {
  // End-to-end check of TwoPartyServer / TwoPartyClient over a real loopback socket, all on a
  // single event loop.
  auto ioContext = kj::setupAsyncIo();
  auto& waitScope = ioContext.waitScope;
  auto& kjNetwork = ioContext.provider->getNetwork();

  int callCount = 0;
  TwoPartyServer rpcServer(kj::heap<TestInterfaceImpl>(callCount));

  // Listen on 127.0.0.1 without naming a port, and start accepting connections.
  auto endpoint = kjNetwork.parseAddress("127.0.0.1").wait(waitScope);

  auto listener = endpoint->listen();
  auto listenPromise = rpcServer.listen(*listener);

  // Re-parse the address with the port the listener actually got, then connect to it.
  endpoint = kjNetwork.parseAddress("127.0.0.1", listener->getPort()).wait(waitScope);

  auto connection = endpoint->connect().wait(waitScope);
  TwoPartyClient rpcClient(*connection);
  auto cap = rpcClient.bootstrap().castAs<test::TestInterface>();

  auto fooReq = cap.fooRequest();
  fooReq.setI(123);
  fooReq.setJ(true);

  // Building the request alone must not have counted a call.
  EXPECT_EQ(0, callCount);
  auto response = fooReq.send().wait(waitScope);
  EXPECT_EQ("foo", response.getX());
  EXPECT_EQ(1, callCount);
}

TEST(TwoPartyNetwork, HugeMessage) {
  // Verifies that a message over the single-message size limit is rejected with a recoverable
  // exception in either direction, and that the connection remains usable afterwards.

  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  // Oversized request fails.
  {
    auto req = client.methodWithDefaultsRequest();
    req.initA(100000000);  // 100 MB

    KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit",
        req.send().ignoreResult().wait(ioContext.waitScope));
  }

  // Oversized response fails.
  KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit",
      client.getEnormousStringRequest().send().ignoreResult().wait(ioContext.waitScope));

  // Connection is still up.
  {
    // A fresh call must still complete; the test expects it to report sequence number 0.
    auto req = client.getCallSequenceRequest();
    req.setExpected(0);
    KJ_EXPECT(req.send().wait(ioContext.waitScope).getN() == 0);
  }
}

// Bootstrap capability created for one particular client. It remembers the VatId it was
// constructed with and reports it back through getCallerId().
class TestAuthenticatedBootstrapImpl final
    : public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server {
public:
  // `explicit` so that a VatId reader can never be implicitly converted into a server object.
  explicit TestAuthenticatedBootstrapImpl(rpc::twoparty::VatId::Reader clientId) {
    // Store the ID in a message owned by this object instead of holding on to the caller's
    // reader.
    this->clientId.setRoot(clientId);
  }

protected:
  // Answers with the VatId captured at construction time; completes immediately.
  kj::Promise<void> getCallerId(GetCallerIdContext context) override {
    context.getResults().setCaller(clientId.getRoot<rpc::twoparty::VatId>());
    return kj::READY_NOW;
  }

private:
  MallocMessageBuilder clientId;  // Backing storage for the remembered VatId.
};

// BootstrapFactory used by the authenticating server. For every requesting vat it records that
// it was consulted, checks that the requester identifies as the CLIENT side, and returns a
// bootstrap capability bound to that requester's VatId.
//
// Marked `final` like the other test implementations in this file.
class TestBootstrapFactory final: public BootstrapFactory<rpc::twoparty::VatId> {
public:
  // `override` makes the compiler reject this method if it ever stops matching the base-class
  // virtual, instead of silently declaring an unrelated function.
  Capability::Client createFor(rpc::twoparty::VatId::Reader clientId) override {
    called = true;
    EXPECT_EQ(rpc::twoparty::Side::CLIENT, clientId.getSide());
    return kj::heap<TestAuthenticatedBootstrapImpl>(clientId);
  }

  // Set once createFor() has run; inspected by the test after its RPC round trip completes.
  bool called = false;
};

// Launches an RPC server on a new pipe thread whose bootstrap capabilities come from
// `bootstrapFactory`. The thread runs until its network reports a disconnect. The factory is
// captured by reference, so it must outlive the returned thread.
kj::AsyncIoProvider::PipeThread runAuthenticatingServer(
    kj::AsyncIoProvider& ioProvider, BootstrapFactory<rpc::twoparty::VatId>& bootstrapFactory) {
  auto serverMain = [&bootstrapFactory](
      kj::AsyncIoProvider&, kj::AsyncIoStream& pipeEnd, kj::WaitScope& threadScope) {
    TwoPartyVatNetwork vatNetwork(pipeEnd, rpc::twoparty::Side::SERVER);

    // Not referenced again; presumably it only has to stay in scope while we block below.
    auto rpcServer = makeRpcServer(vatNetwork, bootstrapFactory);
    vatNetwork.onDisconnect().wait(threadScope);
  };

  return ioProvider.newPipeThread(kj::mv(serverMain));
}

TEST(TwoPartyNetwork, BootstrapFactory) {
  // Checks that a server built from a BootstrapFactory consults the factory, and that the
  // bootstrap capability it hands out reports the caller as the CLIENT side.
  using AuthenticatedBootstrap = test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>;

  auto ioContext = kj::setupAsyncIo();
  TestBootstrapFactory bootstrapFactory;
  auto authServer = runAuthenticatingServer(*ioContext.provider, bootstrapFactory);
  TwoPartyClient rpcClient(*authServer.pipe);

  // Kept as one expression so that the intermediate capability and request objects are
  // destroyed as soon as the response has arrived.
  auto resp = rpcClient.bootstrap()
      .castAs<AuthenticatedBootstrap>()
      .getCallerIdRequest()
      .send()
      .wait(ioContext.waitScope);

  EXPECT_EQ(rpc::twoparty::Side::CLIENT, resp.getCaller().getSide());
  EXPECT_TRUE(bootstrapFactory.called);
}

}  // namespace
}  // namespace _
}  // namespace capnp