rpc-twoparty-test.c++ 14.9 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21

22 23
#define CAPNP_TESTING_CAPNP 1

24 25
#include "rpc-twoparty.h"
#include "test-util.h"
Kenton Varda's avatar
Kenton Varda committed
26
#include <capnp/rpc.capnp.h>
27 28
#include <kj/debug.h>
#include <kj/thread.h>
29 30 31 32 33 34 35 36 37 38
#include <kj/compat/gtest.h>

// TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp {
namespace rpc {

// Renders an rpc::Message union discriminant as its raw numeric value. This gives test
// assertions that compare `Message::Which` values (e.g. the ABORT check in the Abort test)
// something printable when they fail.
inline kj::String KJ_STRINGIFY(Message::Which which) {
  const uint16_t discriminant = static_cast<uint16_t>(which);
  return kj::str(discriminant);
}

}  // namespace rpc
}  // namespace capnp
39 40 41 42 43 44 45

namespace capnp {
namespace _ {
namespace {

// Test restorer that maps each TestSturdyRefObjectId tag to a freshly-constructed server-side
// implementation (or a broken cap for TEST_EXTENDS).
//
// `callCount` and `handleCount` are held by reference and shared with the test body. The test
// reads them to observe what the server-side objects recorded. The referenced ints must outlive
// this restorer.
class TestRestorer final: public SturdyRefRestorer<test::TestSturdyRefObjectId> {
public:
  TestRestorer(int& callCount, int& handleCount)
      : callCount(callCount), handleCount(handleCount) {}

  Capability::Client restore(test::TestSturdyRefObjectId::Reader objectId) override {
    switch (objectId.getTag()) {
      case test::TestSturdyRefObjectId::Tag::TEST_INTERFACE:
        return kj::heap<TestInterfaceImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_EXTENDS:
        return Capability::Client(newBrokenCap("No TestExtends implemented."));
      case test::TestSturdyRefObjectId::Tag::TEST_PIPELINE:
        return kj::heap<TestPipelineImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLEE:
        return kj::heap<TestTailCalleeImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER:
        return kj::heap<TestTailCallerImpl>(callCount);
      case test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF:
        return kj::heap<TestMoreStuffImpl>(callCount, handleCount);
    }
    KJ_UNREACHABLE;
  }

private:
  int& callCount;
  // Only passed to TestMoreStuffImpl. Per the Release test, it tracks outstanding handles.
  int& handleCount;
};

72 73
// Starts an RPC server on a new thread, connected to the calling thread by a two-way pipe.
// The returned PipeThread's `pipe` is the client end. The server restores capabilities via
// TestRestorer and keeps serving until its side of the connection reports a disconnect.
//
// `callCount` and `handleCount` are captured by reference. The tests therefore declare them
// before the returned PipeThread, so the thread it owns is torn down first.
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider,
                                          int& callCount, int& handleCount) {
  return ioProvider.newPipeThread(
      [&callCount, &handleCount](
       kj::AsyncIoProvider&, kj::AsyncIoStream& stream, kj::WaitScope& waitScope) {
    TwoPartyVatNetwork network(stream, rpc::twoparty::Side::SERVER);
    TestRestorer restorer(callCount, handleCount);
    // The RPC system must stay alive for as long as we are serving, hence the named local.
    auto server = makeRpcServer(network, restorer);
    network.onDisconnect().wait(waitScope);
  });
}

84
// Asks `client` to restore the persistent capability identified by `tag` from the vat on
// `side`. The caller is expected to castAs<>() the result to the interface matching `tag`.
Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::VatId>& client,
                                    rpc::twoparty::Side side,
                                    test::TestSturdyRefObjectId::Tag tag) {
  // Create the VatId. (8-word first segment: these messages are tiny.)
  MallocMessageBuilder hostIdMessage(8);
  auto hostId = hostIdMessage.initRoot<rpc::twoparty::VatId>();
  hostId.setSide(side);

  // Create the SturdyRefObjectId.
  MallocMessageBuilder objectIdMessage(8);
  objectIdMessage.initRoot<test::TestSturdyRefObjectId>().setTag(tag);

  // Connect to the remote capability.
  return client.restore(hostId, objectIdMessage.getRoot<AnyPointer>());
}

TEST(TwoPartyNetwork, Basic) {
  // Issue three calls on a restored TestInterface without waiting in between, then check
  // the outcome: foo() and baz() succeed, while bar() is expected to fail.
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Request the particular capability from the server.
  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_INTERFACE).castAs<test::TestInterface>();

  // Use the capability.
  auto request1 = client.fooRequest();
  request1.setI(123);
  request1.setJ(true);
  auto promise1 = request1.send();

  auto request2 = client.bazRequest();
  initTestMessage(request2.initS());
  auto promise2 = request2.send();

  bool barFailed = false;
  auto request3 = client.barRequest();
  auto promise3 = request3.send().then(
      [](Response<test::TestInterface::BarResults>&& /* response */) {
        ADD_FAILURE() << "Expected bar() call to fail.";
      }, [&](kj::Exception&& /* e */) {
        barFailed = true;
      });

  // We have not waited on anything yet, so the server has not counted any call.
  EXPECT_EQ(0, callCount);

  auto response1 = promise1.wait(ioContext.waitScope);

  EXPECT_EQ("foo", response1.getX());

  // baz()'s response is not inspected; we only need it to complete.
  promise2.wait(ioContext.waitScope);

  promise3.wait(ioContext.waitScope);

  // Two counted calls: foo() and baz() (bar() is expected to have failed).
  EXPECT_EQ(2, callCount);
  EXPECT_TRUE(barFailed);
}

146
TEST(TwoPartyNetwork, Pipelining) {
  // Verify promise pipelining over a two-party connection. Calls made on a capability inside a
  // not-yet-returned result must be delivered. After the connection is shut down, the same
  // pattern must produce errors and make no new calls.
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;
  int reverseCallCount = 0;  // Calls back from server to client.

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  bool disconnected = false;
  kj::Promise<void> disconnectPromise =
      network.onDisconnect().then([&]() { disconnected = true; });

  {
    // Request the particular capability from the server.
    auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
        test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();

    {
      // Use the capability.
      auto request = client.getCapRequest();
      request.setN(234);
      request.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto promise = request.send();

      // Pipeline calls onto the capability in the (not yet received) result.
      auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
      pipelineRequest.setI(321);
      auto pipelinePromise = pipelineRequest.send();

      auto pipelineRequest2 = promise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto pipelinePromise2 = pipelineRequest2.send();

      promise = nullptr;  // Just to be annoying, drop the original promise.

      EXPECT_EQ(0, callCount);
      EXPECT_EQ(0, reverseCallCount);

      auto response = pipelinePromise.wait(ioContext.waitScope);
      EXPECT_EQ("bar", response.getX());

      auto response2 = pipelinePromise2.wait(ioContext.waitScope);
      checkTestMessage(response2);

      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }

    EXPECT_FALSE(disconnected);

    // What if we disconnect?
    serverThread.pipe->shutdownWrite();

    // The other side should also disconnect.
    disconnectPromise.wait(ioContext.waitScope);

    {
      // Use the now-broken capability.
      auto request = client.getCapRequest();
      request.setN(234);
      request.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto promise = request.send();

      auto pipelineRequest = promise.getOutBox().getCap().fooRequest();
      pipelineRequest.setI(321);
      auto pipelinePromise = pipelineRequest.send();

      auto pipelineRequest2 = promise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto pipelinePromise2 = pipelineRequest2.send();

      EXPECT_ANY_THROW(pipelinePromise.wait(ioContext.waitScope));
      EXPECT_ANY_THROW(pipelinePromise2.wait(ioContext.waitScope));

      // No new calls were counted in either direction.
      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }
  }
}

228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
TEST(TwoPartyNetwork, Release) {
  // Verify that dropping client-side references to capabilities makes the server release the
  // corresponding objects. This must hold even when no further calls are made afterwards.
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Request the particular capability from the server.
  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  auto handle1 = client.getHandleRequest().send().wait(ioContext.waitScope).getHandle();
  auto promise = client.getHandleRequest().send();
  auto handle2 = promise.wait(ioContext.waitScope).getHandle();

  EXPECT_EQ(2, handleCount);

  handle1 = nullptr;

  // There once was a bug where the last outgoing message (and any capabilities attached) would
  // not get cleaned up (until a new message was sent). This appeared to be a bug in Release,
  // because if a client received a message and then released a capability from it but then did
  // not make any further calls, then the capability would not be released because the message
  // introducing it remained the last server -> client message (because a "Release" message has
  // no reply). Here we are explicitly trying to catch this bug. This proves tricky, because when
  // we drop a reference on the client side, there's no particular way to wait for the release
  // message to reach the server except to make a subsequent call and wait for the return -- but
  // that would mask the bug. So, we spin-wait for handleCount to change.

  // One budget shared by both spin loops below: at most 1000 polls of 10ms each.
  uint maxSpins = 1000;

  while (handleCount > 1) {
    ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
    KJ_ASSERT(--maxSpins > 0);
  }
  EXPECT_EQ(1, handleCount);

  handle2 = nullptr;

  // Dropping handle2 alone must not release the second handle yet. Presumably that is because
  // `promise` still holds the response that carried it; it is dropped below.
  ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
  EXPECT_EQ(1, handleCount);

  promise = nullptr;

  while (handleCount > 0) {
    ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
    KJ_ASSERT(--maxSpins > 0);
  }
  EXPECT_EQ(0, handleCount);
}

Kenton Varda's avatar
Kenton Varda committed
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
TEST(TwoPartyNetwork, Abort) {
  // Verify that aborts are received. A protocol violation sent by the client must be answered
  // with an Abort message, and after that the incoming stream ends.

  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);

  // Talk to the server vat through the raw connection (no RpcSystem on this side), so we can
  // emit a hand-built, invalid message.
  MallocMessageBuilder vatIdBuilder(128);
  auto serverVatId = vatIdBuilder.initRoot<rpc::twoparty::VatId>();
  serverVatId.setSide(rpc::twoparty::Side::SERVER);

  auto connection = KJ_ASSERT_NONNULL(network.connect(serverVatId));

  {
    // Send an invalid message (Return to non-existent question).
    auto outgoing = connection->newOutgoingMessage(128);
    auto rpcMessage = outgoing->getBody().initAs<rpc::Message>();
    auto ret = rpcMessage.initReturn();
    ret.setAnswerId(1234);
    ret.setCanceled();
    outgoing->send();
  }

  auto incoming = KJ_ASSERT_NONNULL(
      connection->receiveIncomingMessage().wait(ioContext.waitScope));
  EXPECT_EQ(rpc::Message::ABORT, incoming->getBody().getAs<rpc::Message>().which());

  // Nothing follows the abort.
  EXPECT_TRUE(connection->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr);
}

312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
TEST(TwoPartyNetwork, ConvenienceClasses) {
  // Exercise TwoPartyServer / TwoPartyClient over a real loopback network connection.
  auto ioContext = kj::setupAsyncIo();
  auto& network = ioContext.provider->getNetwork();

  int callCount = 0;
  TwoPartyServer server(kj::heap<TestInterfaceImpl>(callCount));

  // Listen on loopback. No port is given, so the listener's actual port is read back below.
  // `listenPromise` is kept alive so that the server keeps accepting.
  auto address = network.parseAddress("127.0.0.1").wait(ioContext.waitScope);
  auto listener = address->listen();
  auto listenPromise = server.listen(*listener);

  // Resolve the address again with the port the listener reports, then connect to it.
  address = network.parseAddress("127.0.0.1", listener->getPort()).wait(ioContext.waitScope);
  auto connection = address->connect().wait(ioContext.waitScope);

  TwoPartyClient client(*connection);
  auto cap = client.bootstrap().castAs<test::TestInterface>();

  auto fooCall = cap.fooRequest();
  fooCall.setI(123);
  fooCall.setJ(true);
  EXPECT_EQ(0, callCount);
  auto fooResult = fooCall.send().wait(ioContext.waitScope);
  EXPECT_EQ("foo", fooResult.getX());
  EXPECT_EQ(1, callCount);
}

340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
TEST(TwoPartyNetwork, HugeMessage) {
  // Over-limit messages, in either direction, must fail the individual call with a
  // recoverable error. The connection itself must remain usable afterwards.
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  // Oversized request fails.
  {
    auto req = client.methodWithDefaultsRequest();
    req.initA(100000000);  // 100 MB

    KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit",
        req.send().ignoreResult().wait(ioContext.waitScope));
  }

  // Oversized response fails.
  KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than our single-message size limit",
      client.getEnormousStringRequest().send().ignoreResult().wait(ioContext.waitScope));

  // Connection is still up.
  {
    auto req = client.getCallSequenceRequest();
    req.setExpected(0);
    KJ_EXPECT(req.send().wait(ioContext.waitScope).getN() == 0);
  }
}

Kenton Varda's avatar
Kenton Varda committed
373
class TestAuthenticatedBootstrapImpl final
374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421
    : public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server {
public:
  TestAuthenticatedBootstrapImpl(rpc::twoparty::VatId::Reader clientId) {
    this->clientId.setRoot(clientId);
  }

protected:
  kj::Promise<void> getCallerId(GetCallerIdContext context) override {
    context.getResults().setCaller(clientId.getRoot<rpc::twoparty::VatId>());
    return kj::READY_NOW;
  }

private:
  MallocMessageBuilder clientId;
};

// BootstrapFactory that records that it was consulted. For each connecting client it returns a
// TestAuthenticatedBootstrapImpl bound to that client's VatId.
class TestBootstrapFactory final: public BootstrapFactory<rpc::twoparty::VatId> {
public:
  Capability::Client createFor(rpc::twoparty::VatId::Reader clientId) override {
    called = true;
    // The connecting peer is expected to identify itself as the client side.
    EXPECT_EQ(rpc::twoparty::Side::CLIENT, clientId.getSide());
    return kj::heap<TestAuthenticatedBootstrapImpl>(clientId);
  }

  // Set once createFor() has run.
  // NOTE(review): this is written on the server thread (see runAuthenticatingServer) and read by
  // the test thread after an RPC round-trip, without explicit synchronization.
  bool called = false;
};

// Variant of runServer() for the bootstrap tests. It starts an RPC server on a new thread,
// connected to the caller by a two-way pipe, and the server obtains each client's bootstrap
// capability from `bootstrapFactory` rather than restoring sturdy refs. `bootstrapFactory` is
// captured by reference, so it must outlive the returned PipeThread.
kj::AsyncIoProvider::PipeThread runAuthenticatingServer(
    kj::AsyncIoProvider& ioProvider, BootstrapFactory<rpc::twoparty::VatId>& bootstrapFactory) {
  return ioProvider.newPipeThread(
      [&bootstrapFactory](kj::AsyncIoProvider&, kj::AsyncIoStream& serverEnd,
                          kj::WaitScope& serverWaitScope) {
    TwoPartyVatNetwork serverNetwork(serverEnd, rpc::twoparty::Side::SERVER);
    // Held in a named local so the RPC system lives until the peer disconnects.
    auto rpcServer = makeRpcServer(serverNetwork, bootstrapFactory);
    serverNetwork.onDisconnect().wait(serverWaitScope);
  });
}

TEST(TwoPartyNetwork, BootstrapFactory) {
  // The server's BootstrapFactory must be consulted for the connecting client. The capability it
  // hands out must report the client's identity back to us.
  using AuthBootstrap = test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>;

  auto ioContext = kj::setupAsyncIo();
  TestBootstrapFactory bootstrapFactory;
  auto serverThread = runAuthenticatingServer(*ioContext.provider, bootstrapFactory);
  TwoPartyClient client(*serverThread.pipe);

  auto callerIdResponse = client.bootstrap().castAs<AuthBootstrap>()
      .getCallerIdRequest()
      .send()
      .wait(ioContext.waitScope);

  EXPECT_EQ(rpc::twoparty::Side::CLIENT, callerIdResponse.getCaller().getSide());
  EXPECT_TRUE(bootstrapFactory.called);
}

422 423 424
}  // namespace
}  // namespace _
}  // namespace capnp