rpc-twoparty-test.c++ 14.8 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21 22 23

#include "rpc-twoparty.h"
#include "test-util.h"
Kenton Varda's avatar
Kenton Varda committed
24
#include <capnp/rpc.capnp.h>
25 26
#include <kj/debug.h>
#include <kj/thread.h>
27 28 29 30 31 32 33 34 35 36
#include <kj/compat/gtest.h>

// TODO(cleanup): Auto-generate stringification functions for union discriminants.
namespace capnp {
namespace rpc {

// Lets kj::str() (and anything built on it) render an rpc::Message::Which union
// discriminant. There is no generated name table, so the raw numeric tag is printed.
inline kj::String KJ_STRINGIFY(Message::Which which) {
  const uint16_t discriminant = static_cast<uint16_t>(which);
  return kj::str(discriminant);
}

}  // namespace rpc
}  // namespace capnp
37 38 39 40 41 42 43

namespace capnp {
namespace _ {
namespace {

// SturdyRefRestorer used by the server side of these tests: maps each
// TestSturdyRefObjectId tag to a freshly heap-allocated test implementation.
// The two counters belong to the caller and are handed by reference to the
// implementations it creates (which presumably retain them -- see test-util.h),
// so they should outlive every capability restored here.
class TestRestorer final: public SturdyRefRestorer<test::TestSturdyRefObjectId> {
public:
  TestRestorer(int& callCount, int& handleCount)
      : calls(callCount), handles(handleCount) {}

  Capability::Client restore(test::TestSturdyRefObjectId::Reader objectId) override {
    using Tag = test::TestSturdyRefObjectId::Tag;
    const Tag tag = objectId.getTag();

    switch (tag) {
      case Tag::TEST_INTERFACE:
        return kj::heap<TestInterfaceImpl>(calls);
      case Tag::TEST_EXTENDS:
        // No implementation is provided; hand out a broken capability carrying this reason.
        return Capability::Client(newBrokenCap("No TestExtends implemented."));
      case Tag::TEST_PIPELINE:
        return kj::heap<TestPipelineImpl>(calls);
      case Tag::TEST_TAIL_CALLEE:
        return kj::heap<TestTailCalleeImpl>(calls);
      case Tag::TEST_TAIL_CALLER:
        return kj::heap<TestTailCallerImpl>(calls);
      case Tag::TEST_MORE_STUFF:
        // The only implementation that also tracks outstanding handles.
        return kj::heap<TestMoreStuffImpl>(calls, handles);
    }

    // Every enumerator is handled above.
    KJ_UNREACHABLE;
  }

private:
  int& calls;
  int& handles;
};

70 71
// Starts a thread that serves TestRestorer-backed capabilities over one end of
// an in-process pipe. The caller talks to it through the returned
// PipeThread's `pipe`. The server thread blocks until the connection reports
// a disconnect. `callCount` and `handleCount` are shared with the calling thread
// by reference.
kj::AsyncIoProvider::PipeThread runServer(kj::AsyncIoProvider& ioProvider,
                                          int& callCount, int& handleCount) {
  return ioProvider.newPipeThread(
      [&callCount, &handleCount](kj::AsyncIoProvider&, kj::AsyncIoStream& serverEnd,
                                 kj::WaitScope& threadWaitScope) {
    TwoPartyVatNetwork vatNetwork(serverEnd, rpc::twoparty::Side::SERVER);
    TestRestorer restorer(callCount, handleCount);
    auto rpcServer = makeRpcServer(vatNetwork, restorer);

    // Keep the network, restorer and RPC server alive until the peer goes away.
    vatNetwork.onDisconnect().wait(threadWaitScope);
  });
}

82
// Asks `client` to restore, from the vat on `side`, the test object identified
// by `tag`, and returns the resulting capability.
Capability::Client getPersistentCap(RpcSystem<rpc::twoparty::VatId>& client,
                                    rpc::twoparty::Side side,
                                    test::TestSturdyRefObjectId::Tag tag) {
  // Which vat hosts the object.
  MallocMessageBuilder vatIdBuilder(8);
  auto vatId = vatIdBuilder.initRoot<rpc::twoparty::VatId>();
  vatId.setSide(side);

  // Which object within that vat.
  MallocMessageBuilder objectIdBuilder(8);
  auto objectId = objectIdBuilder.initRoot<test::TestSturdyRefObjectId>();
  objectId.setTag(tag);

  return client.restore(vatId, objectIdBuilder.getRoot<AnyPointer>());
}

// Sends three calls to a restored TestInterface without waiting in between:
// foo() and baz() are expected to succeed and bar() to fail.
TEST(TwoPartyNetwork, Basic) {
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Request the particular capability from the server.
  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_INTERFACE).castAs<test::TestInterface>();

  // Use the capability: issue all three calls before waiting on any of them.
  auto request1 = client.fooRequest();
  request1.setI(123);
  request1.setJ(true);
  auto promise1 = request1.send();

  auto request2 = client.bazRequest();
  initTestMessage(request2.initS());
  auto promise2 = request2.send();

  // bar() must fail; only the error branch records anything.
  bool barFailed = false;
  auto request3 = client.barRequest();
  auto promise3 = request3.send().then(
      [](Response<test::TestInterface::BarResults>&&) {
        ADD_FAILURE() << "Expected bar() call to fail.";
      }, [&](kj::Exception&&) {
        barFailed = true;
      });

  // Nothing has been waited on yet, so the server is expected not to have counted any call.
  EXPECT_EQ(0, callCount);

  auto response1 = promise1.wait(ioContext.waitScope);

  EXPECT_EQ("foo", response1.getX());

  // Only completion of baz() matters here; its response is not inspected.
  promise2.wait(ioContext.waitScope);

  promise3.wait(ioContext.waitScope);

  EXPECT_EQ(2, callCount);
  EXPECT_TRUE(barFailed);
}

144
// Exercises promise pipelining over a live connection, then shows that the
// same pipelined calls fail once the connection has been shut down.
TEST(TwoPartyNetwork, Pipelining) {
  auto ioContext = kj::setupAsyncIo();
  auto& waitScope = ioContext.waitScope;
  int callCount = 0;
  int handleCount = 0;
  int reverseCallCount = 0;  // Calls back from server to client, on the cap we pass in.

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  bool disconnected = false;
  kj::Promise<void> disconnectPromise = network.onDisconnect().then([&]() { disconnected = true; });

  {
    // Request the particular capability from the server.
    auto pipeline = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
        test::TestSturdyRefObjectId::Tag::TEST_PIPELINE).castAs<test::TestPipeline>();

    // Phase 1: pipelined calls while connected.
    {
      auto getCap = pipeline.getCapRequest();
      getCap.setN(234);
      getCap.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto getCapPromise = getCap.send();

      // Call methods on the capability getCap() will return, before it has returned.
      auto fooRequest = getCapPromise.getOutBox().getCap().fooRequest();
      fooRequest.setI(321);
      auto fooPromise = fooRequest.send();

      auto graultRequest = getCapPromise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto graultPromise = graultRequest.send();

      // Just to be annoying, drop the original promise; the pipelined calls must still work.
      getCapPromise = nullptr;

      EXPECT_EQ(0, callCount);
      EXPECT_EQ(0, reverseCallCount);

      auto fooResponse = fooPromise.wait(waitScope);
      EXPECT_EQ("bar", fooResponse.getX());

      auto graultResponse = graultPromise.wait(waitScope);
      checkTestMessage(graultResponse);

      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }

    EXPECT_FALSE(disconnected);

    // Phase 2: disconnect. Shutting down our write side should make the other side
    // disconnect too, which resolves disconnectPromise.
    serverThread.pipe->shutdownWrite();
    disconnectPromise.wait(waitScope);

    {
      // The same pipelined calls, now on a broken capability: both must throw and
      // neither counter may move.
      auto getCap = pipeline.getCapRequest();
      getCap.setN(234);
      getCap.setInCap(kj::heap<TestInterfaceImpl>(reverseCallCount));

      auto getCapPromise = getCap.send();

      auto fooRequest = getCapPromise.getOutBox().getCap().fooRequest();
      fooRequest.setI(321);
      auto fooPromise = fooRequest.send();

      auto graultRequest = getCapPromise.getOutBox().getCap()
          .castAs<test::TestExtends>().graultRequest();
      auto graultPromise = graultRequest.send();

      EXPECT_ANY_THROW(fooPromise.wait(waitScope));
      EXPECT_ANY_THROW(graultPromise.wait(waitScope));

      EXPECT_EQ(3, callCount);
      EXPECT_EQ(1, reverseCallCount);
    }
  }
}

226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
// Checks that dropping client-side references to capabilities eventually
// releases them on the server, as observed through handleCount.
TEST(TwoPartyNetwork, Release) {
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  // Request the particular capability from the server.
  auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  auto handle1 = client.getHandleRequest().send().wait(ioContext.waitScope).getHandle();
  // `promise` is kept alive on purpose: the test below expects it (presumably via its
  // pipeline half) to keep handle2's capability referenced after handle2 is dropped.
  auto promise = client.getHandleRequest().send();
  auto handle2 = promise.wait(ioContext.waitScope).getHandle();

  EXPECT_EQ(2, handleCount);

  // Runs the event loop for 10ms, giving pending RPC messages a chance to be delivered.
  auto waitTick = [&]() {
    ioContext.provider->getTimer().afterDelay(10 * kj::MILLISECONDS).wait(ioContext.waitScope);
  };

  handle1 = nullptr;

  // There once was a bug where the last outgoing message (and any capabilities attached) would
  // not get cleaned up (until a new message was sent). This appeared to be a bug in Release,
  // because if a client received a message and then released a capability from it but then did
  // not make any further calls, then the capability would not be released because the message
  // introducing it remained the last server -> client message (because a "Release" message has
  // no reply). Here we are explicitly trying to catch this bug. This proves tricky, because when
  // we drop a reference on the client side, there's no particular way to wait for the release
  // message to reach the server except to make a subsequent call and wait for the return -- but
  // that would mask the bug. So, we spin waiting for handleCount to change.
  //
  // NOTE(review): handleCount is presumably updated on the server thread and is read here
  // without synchronization; the test relies on that working in practice.

  uint maxSpins = 1000;  // One budget shared by both spin loops below.
  auto spinUntilHandleCountIs = [&](int target) {
    while (handleCount > target) {
      waitTick();
      KJ_ASSERT(--maxSpins > 0);
    }
  };

  spinUntilHandleCountIs(1);
  EXPECT_EQ(1, handleCount);

  handle2 = nullptr;

  // `promise` still references handle2's capability, so nothing should be released yet.
  waitTick();
  EXPECT_EQ(1, handleCount);

  promise = nullptr;

  spinUntilHandleCountIs(0);
  EXPECT_EQ(0, handleCount);
}

Kenton Varda's avatar
Kenton Varda committed
279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
// Verify that aborts are received: after a protocol violation the server must
// reply with an Abort message, and then the connection must end.
TEST(TwoPartyNetwork, Abort) {
  auto ioContext = kj::setupAsyncIo();
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);

  // Use the raw VatNetwork connection (no RpcSystem) so a malformed message can be sent.
  MallocMessageBuilder serverIdMessage(128);
  auto serverId = serverIdMessage.initRoot<rpc::twoparty::VatId>();
  serverId.setSide(rpc::twoparty::Side::SERVER);

  auto conn = KJ_ASSERT_NONNULL(network.connect(serverId));

  {
    // A Return for a question that was never asked.
    auto bogus = conn->newOutgoingMessage(128);
    auto returnMsg = bogus->getBody().initAs<rpc::Message>().initReturn();
    returnMsg.setAnswerId(1234);
    returnMsg.setCanceled();
    bogus->send();
  }

  auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope));
  EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());

  // After the Abort, the stream should end: no further message arrives.
  EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr);
}

310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
// Exercises the TwoPartyServer / TwoPartyClient wrappers over a real loopback
// network connection.
TEST(TwoPartyNetwork, ConvenienceClasses) {
  auto ioContext = kj::setupAsyncIo();
  auto& waitScope = ioContext.waitScope;

  // Server side: its bootstrap capability is a TestInterfaceImpl.
  int callCount = 0;
  TwoPartyServer server(kj::heap<TestInterfaceImpl>(callCount));

  // Listen on loopback without naming a port, then read back the port actually bound.
  auto address = ioContext.provider->getNetwork()
      .parseAddress("127.0.0.1").wait(waitScope);

  auto listener = address->listen();
  // Held for the whole test: dropping a kj::Promise cancels the work it represents.
  auto listenPromise = server.listen(*listener);

  address = ioContext.provider->getNetwork()
      .parseAddress("127.0.0.1", listener->getPort()).wait(waitScope);

  // Client side: connect and fetch the server's bootstrap capability.
  auto connection = address->connect().wait(waitScope);
  TwoPartyClient client(*connection);
  auto cap = client.bootstrap().castAs<test::TestInterface>();

  auto fooRequest = cap.fooRequest();
  fooRequest.setI(123);
  fooRequest.setJ(true);
  EXPECT_EQ(0, callCount);
  auto fooResponse = fooRequest.send().wait(waitScope);
  EXPECT_EQ("foo", fooResponse.getX());
  EXPECT_EQ(1, callCount);
}

338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
// Oversized requests and responses must fail with a recoverable error that
// names the single-message size limit, without breaking the connection.
TEST(TwoPartyNetwork, HugeMessage) {
  auto ioContext = kj::setupAsyncIo();
  auto& waitScope = ioContext.waitScope;
  int callCount = 0;
  int handleCount = 0;

  auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
  TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
  auto rpcClient = makeRpcClient(network);

  auto moreStuff = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
      test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();

  // Oversized request fails.
  {
    auto oversized = moreStuff.methodWithDefaultsRequest();
    oversized.initA(100000000);  // 100 MB

    KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than the single-message size limit",
        oversized.send().ignoreResult().wait(waitScope));
  }

  // Oversized response fails.
  KJ_EXPECT_THROW_RECOVERABLE_MESSAGE("larger than the single-message size limit",
      moreStuff.getEnormousStringRequest().send().ignoreResult().wait(waitScope));

  // Connection is still up: an ordinary call succeeds.
  {
    auto sequence = moreStuff.getCallSequenceRequest();
    sequence.setExpected(0);
    KJ_EXPECT(sequence.send().wait(waitScope).getN() == 0);
  }
}

Kenton Varda's avatar
Kenton Varda committed
371
class TestAuthenticatedBootstrapImpl final
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419
    : public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server {
public:
  TestAuthenticatedBootstrapImpl(rpc::twoparty::VatId::Reader clientId) {
    this->clientId.setRoot(clientId);
  }

protected:
  kj::Promise<void> getCallerId(GetCallerIdContext context) override {
    context.getResults().setCaller(clientId.getRoot<rpc::twoparty::VatId>());
    return kj::READY_NOW;
  }

private:
  MallocMessageBuilder clientId;
};

// BootstrapFactory that records that it was consulted and gives each client a
// TestAuthenticatedBootstrapImpl remembering that client's VatId.
class TestBootstrapFactory final: public BootstrapFactory<rpc::twoparty::VatId> {
public:
  // Invoked by the server-side RpcSystem to produce the bootstrap capability for
  // `clientId`. The test only connects as the CLIENT side, so anything else is a failure.
  // `override` makes the compiler verify that this matches BootstrapFactory's virtual.
  Capability::Client createFor(rpc::twoparty::VatId::Reader clientId) override {
    called = true;
    EXPECT_EQ(rpc::twoparty::Side::CLIENT, clientId.getSide());
    return kj::heap<TestAuthenticatedBootstrapImpl>(clientId);
  }

  // Set by createFor(). NOTE(review): presumably written on the server thread and read
  // by the test thread without synchronization.
  bool called = false;
};

// Starts a thread serving one end of an in-process pipe. Bootstrap capabilities
// come from `bootstrapFactory`, which must outlive the thread because it is
// captured by reference. The thread blocks until the connection disconnects.
kj::AsyncIoProvider::PipeThread runAuthenticatingServer(
    kj::AsyncIoProvider& ioProvider, BootstrapFactory<rpc::twoparty::VatId>& bootstrapFactory) {
  return ioProvider.newPipeThread(
      [&bootstrapFactory](kj::AsyncIoProvider&, kj::AsyncIoStream& serverEnd,
                          kj::WaitScope& threadWaitScope) {
    TwoPartyVatNetwork vatNetwork(serverEnd, rpc::twoparty::Side::SERVER);
    auto rpcServer = makeRpcServer(vatNetwork, bootstrapFactory);

    // Keep the network and RPC server alive until the peer goes away.
    vatNetwork.onDisconnect().wait(threadWaitScope);
  });
}

// The server builds its bootstrap capability per client via TestBootstrapFactory.
// The capability must report the CLIENT side as the caller, and the factory must
// actually have been used.
TEST(TwoPartyNetwork, BootstrapFactory) {
  using AuthBootstrap = test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>;

  auto ioContext = kj::setupAsyncIo();
  TestBootstrapFactory bootstrapFactory;
  auto serverThread = runAuthenticatingServer(*ioContext.provider, bootstrapFactory);
  TwoPartyClient client(*serverThread.pipe);

  auto callerIdResponse = client.bootstrap()
      .castAs<AuthBootstrap>()
      .getCallerIdRequest()
      .send()
      .wait(ioContext.waitScope);

  EXPECT_EQ(rpc::twoparty::Side::CLIENT, callerIdResponse.getCaller().getSide());
  EXPECT_TRUE(bootstrapFactory.called);
}

420 421 422
}  // namespace
}  // namespace _
}  // namespace capnp