// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "ez-rpc.h"
#include "rpc-twoparty.h"
#include <capnp/rpc.capnp.h>
#include <kj/async-io.h>
#include <kj/debug.h>
#include <kj/threadlocal.h>
#include <map>
#include <string.h>

namespace capnp {

32
KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr;
33 34 35

class EzRpcContext: public kj::Refcounted {
public:
36
  EzRpcContext(): ioContext(kj::setupAsyncIo()) {
37 38 39 40 41 42 43 44 45 46 47
    threadEzContext = this;
  }

  ~EzRpcContext() noexcept(false) {
    KJ_REQUIRE(threadEzContext == this,
               "EzRpcContext destroyed from different thread than it was created.") {
      return;
    }
    threadEzContext = nullptr;
  }

48 49 50 51
  kj::WaitScope& getWaitScope() {
    return ioContext.waitScope;
  }

52
  kj::AsyncIoProvider& getIoProvider() {
53 54 55 56 57
    return *ioContext.provider;
  }

  kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() {
    return *ioContext.lowLevelProvider;
58 59 60 61 62 63 64 65 66 67 68 69
  }

  static kj::Own<EzRpcContext> getThreadLocal() {
    EzRpcContext* existing = threadEzContext;
    if (existing != nullptr) {
      return kj::addRef(*existing);
    } else {
      return kj::refcounted<EzRpcContext>();
    }
  }

private:
70
  kj::AsyncIoContext ioContext;
71 72 73 74
};

// =======================================================================================

75 76 77 78
kj::Promise<kj::Own<kj::AsyncIoStream>> connectAttach(kj::Own<kj::NetworkAddress>&& addr) {
  return addr->connect().attach(kj::mv(addr));
}

79 80 81 82 83 84
struct EzRpcClient::Impl {
  kj::Own<EzRpcContext> context;

  struct ClientContext {
    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
85
    RpcSystem<rpc::twoparty::VatId> rpcSystem;
86

87
    ClientContext(kj::Own<kj::AsyncIoStream>&& stream, ReaderOptions readerOpts)
88
        : stream(kj::mv(stream)),
89
          network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts),
90 91
          rpcSystem(makeRpcClient(network)) {}

92 93 94 95 96 97 98 99 100
    Capability::Client getMain() {
      word scratch[4];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);
      auto hostId = message.getRoot<rpc::twoparty::VatId>();
      hostId.setSide(rpc::twoparty::Side::SERVER);
      return rpcSystem.bootstrap(hostId);
    }

101 102 103 104
    Capability::Client restore(kj::StringPtr name) {
      word scratch[64];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);
105 106 107

      auto hostIdOrphan = message.getOrphanage().newOrphan<rpc::twoparty::VatId>();
      auto hostId = hostIdOrphan.get();
108
      hostId.setSide(rpc::twoparty::Side::SERVER);
109 110 111 112 113 114 115

      auto objectId = message.getRoot<AnyPointer>();
      objectId.setAs<Text>(name);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
      return rpcSystem.restore(hostId, objectId);
#pragma GCC diagnostic pop
116 117 118 119 120 121 122 123
    }
  };

  kj::ForkedPromise<void> setupPromise;

  kj::Maybe<kj::Own<ClientContext>> clientContext;
  // Filled in before `setupPromise` resolves.

124 125
  Impl(kj::StringPtr serverAddress, uint defaultPort,
       ReaderOptions readerOpts)
126 127
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(context->getIoProvider().getNetwork()
Kenton Varda's avatar
Kenton Varda committed
128
            .parseAddress(serverAddress, defaultPort)
129
            .then([](kj::Own<kj::NetworkAddress>&& addr) {
130
              return connectAttach(kj::mv(addr));
131 132 133
            }).then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
134 135
            }).fork()) {}

136 137
  Impl(const struct sockaddr* serverAddress, uint addrSize,
       ReaderOptions readerOpts)
138
      : context(EzRpcContext::getThreadLocal()),
139 140 141
        setupPromise(
            connectAttach(context->getIoProvider().getNetwork()
                .getSockaddr(serverAddress, addrSize))
142 143 144
            .then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
145 146
            }).fork()) {}

147
  Impl(int socketFd, ReaderOptions readerOpts)
148 149
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(kj::Promise<void>(kj::READY_NOW).fork()),
150
        clientContext(kj::heap<ClientContext>(
151 152
            context->getLowLevelIoProvider().wrapSocketFd(socketFd),
            readerOpts)) {}
153 154
};

155 156
EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(serverAddress, defaultPort, readerOpts)) {}
157

158 159
EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(serverAddress, addrSize, readerOpts)) {}
160

161 162
EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(socketFd, readerOpts)) {}
163 164 165

EzRpcClient::~EzRpcClient() noexcept(false) {}

166 167 168 169 170 171 172 173 174 175
Capability::Client EzRpcClient::getMain() {
  KJ_IF_MAYBE(client, impl->clientContext) {
    return client->get()->getMain();
  } else {
    return impl->setupPromise.addBranch().then([this]() {
      return KJ_ASSERT_NONNULL(impl->clientContext)->getMain();
    });
  }
}

176 177 178 179 180 181 182 183 184 185 186
Capability::Client EzRpcClient::importCap(kj::StringPtr name) {
  KJ_IF_MAYBE(client, impl->clientContext) {
    return client->get()->restore(name);
  } else {
    return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name),
        [this](kj::String&& name) {
      return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name);
    }));
  }
}

187 188 189 190
kj::WaitScope& EzRpcClient::getWaitScope() {
  return impl->context->getWaitScope();
}

191 192 193 194
kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
  return impl->context->getIoProvider();
}

195 196 197 198
kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
  return impl->context->getLowLevelIoProvider();
}

199 200
// =======================================================================================

201 202 203 204 205 206 207 208 209 210 211 212 213
namespace {

class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter {
public:
  bool shouldAllow(const struct sockaddr* addr, uint addrlen) override {
    return true;
  }
};

static DummyFilter DUMMY_FILTER;

}  // namespace

214 215 216
struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
                                public kj::TaskSet::ErrorHandler {
  Capability::Client mainInterface;
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
  kj::Own<EzRpcContext> context;

  struct ExportedCap {
    kj::String name;
    Capability::Client cap = nullptr;

    ExportedCap(kj::StringPtr name, Capability::Client cap)
        : name(kj::heapString(name)), cap(cap) {}

    ExportedCap() = default;
    ExportedCap(const ExportedCap&) = delete;
    ExportedCap(ExportedCap&&) = default;
    ExportedCap& operator=(const ExportedCap&) = delete;
    ExportedCap& operator=(ExportedCap&&) = default;
    // Make std::map happy...
  };

  std::map<kj::StringPtr, ExportedCap> exportMap;

  kj::ForkedPromise<uint> portPromise;

  kj::TaskSet tasks;

  struct ServerContext {
    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
243
    RpcSystem<rpc::twoparty::VatId> rpcSystem;
244

245 246 247
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
    ServerContext(kj::Own<kj::AsyncIoStream>&& stream, SturdyRefRestorer<AnyPointer>& restorer,
248
                  ReaderOptions readerOpts)
249
        : stream(kj::mv(stream)),
250
          network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts),
251
          rpcSystem(makeRpcServer(network, restorer)) {}
252
#pragma GCC diagnostic pop
253 254
  };

255 256 257 258
  Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
259 260 261
    auto paf = kj::newPromiseAndFulfiller<uint>();
    portPromise = paf.promise.fork();

Kenton Varda's avatar
Kenton Varda committed
262
    tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort)
263
        .then(kj::mvCapture(paf.fulfiller,
264 265
          [this, readerOpts](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller,
                             kj::Own<kj::NetworkAddress>&& addr) {
266 267
      auto listener = addr->listen();
      portFulfiller->fulfill(listener->getPort());
268
      acceptLoop(kj::mv(listener), readerOpts);
269 270 271
    })));
  }

272 273 274 275
  Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
276
    auto listener = context->getIoProvider().getNetwork()
Kenton Varda's avatar
Kenton Varda committed
277
        .getSockaddr(bindAddress, addrSize)->listen();
278
    portPromise = kj::Promise<uint>(listener->getPort()).fork();
279
    acceptLoop(kj::mv(listener), readerOpts);
280 281
  }

282 283 284
  Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()),
285 286
        portPromise(kj::Promise<uint>(port).fork()),
        tasks(*this) {
287 288
    acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER),
               readerOpts);
289 290
  }

291
  void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
292 293
    auto ptr = listener.get();
    tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener),
294 295 296
        [this, readerOpts](kj::Own<kj::ConnectionReceiver>&& listener,
                           kj::Own<kj::AsyncIoStream>&& connection) {
      acceptLoop(kj::mv(listener), readerOpts);
297

298
      auto server = kj::heap<ServerContext>(kj::mv(connection), *this, readerOpts);
299 300 301

      // Arrange to destroy the server context when all references are gone, or when the
      // EzRpcServer is destroyed (which will destroy the TaskSet).
302
      tasks.add(server->network.onDisconnect().attach(kj::mv(server)));
303 304 305
    })));
  }

306 307 308
  Capability::Client restore(AnyPointer::Reader objectId) override {
    if (objectId.isNull()) {
      return mainInterface;
309
    } else {
310 311 312 313 314 315 316 317
      auto name = objectId.getAs<Text>();
      auto iter = exportMap.find(name);
      if (iter == exportMap.end()) {
        KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; }
        return nullptr;
      } else {
        return iter->second.cap;
      }
318 319 320 321 322 323 324 325
    }
  }

  void taskFailed(kj::Exception&& exception) override {
    kj::throwFatalException(kj::mv(exception));
  }
};

EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress,
                         uint defaultPort, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {}

EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress,
                         uint addrSize, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {}

EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port,
                         ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), socketFd, port, readerOpts)) {}

// The overloads below take no main interface; they delegate with a null capability.

EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort,
                         ReaderOptions readerOpts)
    : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {}

EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize,
                         ReaderOptions readerOpts)
    : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {}

EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts)
    : EzRpcServer(nullptr, socketFd, port, readerOpts) {}

EzRpcServer::~EzRpcServer() noexcept(false) {}

void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) {
  // Publishes `cap` under `name`; Impl::restore() looks up this name when a client imports it.
  //
  // exportMap's keys are StringPtrs into the `name` buffer of the entry stored under them.
  // Re-exporting an existing name must therefore keep that entry (and its name buffer) intact
  // and only swap the capability. Move-assigning a whole new entry would free the buffer that
  // the existing key still points at, leaving a dangling key in the map.
  auto& exports = impl->exportMap;
  auto iter = exports.find(name);
  if (iter != exports.end()) {
    iter->second.cap = kj::mv(cap);
    return;
  }

  Impl::ExportedCap entry(name, kj::mv(cap));
  // Moving a kj::String hands over its heap buffer rather than copying it, so `key` still
  // points at valid storage after `entry` is moved into the map.
  kj::StringPtr key = entry.name;
  exports.emplace(key, kj::mv(entry));
}

kj::Promise<uint> EzRpcServer::getPort() {
  // Each caller gets its own branch of the shared forked port promise.
  return impl->portPromise.addBranch();
}

// The accessors below expose the shared per-thread I/O context owned through `impl->context`.

kj::WaitScope& EzRpcServer::getWaitScope() {
  return impl->context->getWaitScope();
}

kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
  return impl->context->getIoProvider();
}

kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() {
  return impl->context->getLowLevelIoProvider();
}

}  // namespace capnp