// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "ez-rpc.h"
#include "rpc-twoparty.h"
#include <capnp/rpc.capnp.h>
#include <kj/async-io.h>
#include <kj/debug.h>
#include <kj/threadlocal.h>
#include <map>

namespace capnp {

KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr;
33 34 35

class EzRpcContext: public kj::Refcounted {
  // Refcounted owner of a kj::AsyncIoContext (from kj::setupAsyncIo()), shared by all EzRpc
  // clients and servers on the same thread. Every user in this file obtains it through
  // getThreadLocal() rather than constructing it directly.

public:
  EzRpcContext(): ioContext(kj::setupAsyncIo()) {
    // Register as this thread's context so later getThreadLocal() calls reuse this instance.
    threadEzContext = this;
  }

  ~EzRpcContext() noexcept(false) {
    // The thread-local registration can only be undone on the thread that made it. If the
    // check fails and execution continues (recovery block), leave the pointer untouched.
    KJ_REQUIRE(threadEzContext == this,
               "EzRpcContext destroyed from different thread than it was created.") {
      return;
    }
    threadEzContext = nullptr;
  }

  kj::WaitScope& getWaitScope() {
    return ioContext.waitScope;
  }

  kj::AsyncIoProvider& getIoProvider() {
    return *ioContext.provider;
  }

  kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() {
    return *ioContext.lowLevelProvider;
  }

  static kj::Own<EzRpcContext> getThreadLocal() {
    // Returns a new reference to this thread's context. If none exists yet, creates one; its
    // constructor registers it as the thread's context.
    EzRpcContext* existing = threadEzContext;
    if (existing != nullptr) {
      return kj::addRef(*existing);
    } else {
      return kj::refcounted<EzRpcContext>();
    }
  }

private:
  kj::AsyncIoContext ioContext;
};

// =======================================================================================

static kj::Promise<kj::Own<kj::AsyncIoStream>> connectAttach(kj::Own<kj::NetworkAddress>&& addr) {
  // Starts connecting to `addr` and attaches the address object to the returned promise, so
  // `addr` is kept alive for as long as the connection promise is.
  // File-local helper: `static` gives it internal linkage so it cannot collide with another
  // translation unit's definition of the same name.
  return addr->connect().attach(kj::mv(addr));
}

struct EzRpcClient::Impl {
  // Private state of EzRpcClient: the per-thread EzRpc context plus the connection and RPC
  // system. The connection may be established asynchronously, depending on the constructor.

  kj::Own<EzRpcContext> context;
  // Declared first, hence destroyed last. The promises and streams below were created from
  // this context's I/O provider, and they are destroyed before it.

  struct ClientContext {
    // One established connection: the byte stream, a client-side two-party network over it,
    // and the RPC system bound to that network.

    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
    RpcSystem<rpc::twoparty::VatId> rpcSystem;

    ClientContext(kj::Own<kj::AsyncIoStream>&& stream, ReaderOptions readerOpts)
        : stream(kj::mv(stream)),
          network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts),
          rpcSystem(makeRpcClient(network)) {}

    Capability::Client getMain() {
      // Requests the bootstrap capability of the SERVER side of the connection.
      // The scratch buffer is zeroed because it is handed to MallocMessageBuilder as the
      // message's first segment.
      word scratch[4];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);
      auto hostId = message.getRoot<rpc::twoparty::VatId>();
      hostId.setSide(rpc::twoparty::Side::SERVER);
      return rpcSystem.bootstrap(hostId);
    }

    Capability::Client restore(kj::StringPtr name) {
      // Requests the capability the server restores for the Text object ID `name`. This goes
      // through the deprecated restore() API, so its warning is suppressed locally.
      word scratch[64];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);

      // The message root is used for the object ID below, so the host ID is built as an
      // orphan in the same message.
      auto hostIdOrphan = message.getOrphanage().newOrphan<rpc::twoparty::VatId>();
      auto hostId = hostIdOrphan.get();
      hostId.setSide(rpc::twoparty::Side::SERVER);

      auto objectId = message.getRoot<AnyPointer>();
      objectId.setAs<Text>(name);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
      return rpcSystem.restore(hostId, objectId);
#pragma GCC diagnostic pop
    }
  };

  kj::ForkedPromise<void> setupPromise;
  // Resolves once `clientContext` has been filled in; rejects if resolving or connecting fails.

  kj::Maybe<kj::Own<ClientContext>> clientContext;
  // Filled in before `setupPromise` resolves.

  Impl(kj::StringPtr serverAddress, uint defaultPort,
       ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(context->getIoProvider().getNetwork()
            .parseAddress(serverAddress, defaultPort)
            .then([](kj::Own<kj::NetworkAddress>&& addr) {
              return connectAttach(kj::mv(addr));
            }).then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              // Runs later from the event loop, after construction has finished, so
              // `clientContext` is already initialized when this assigns to it.
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
            }).fork()) {}

  Impl(const struct sockaddr* serverAddress, uint addrSize,
       ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(
            connectAttach(context->getIoProvider().getNetwork()
                .getSockaddr(serverAddress, addrSize))
            .then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
            }).fork()) {}

  Impl(int socketFd, ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        // Already connected: setup is complete immediately.
        setupPromise(kj::Promise<void>(kj::READY_NOW).fork()),
        clientContext(kj::heap<ClientContext>(
            context->getLowLevelIoProvider().wrapSocketFd(socketFd),
            readerOpts)) {}
};

// Each constructor forwards its arguments to the matching Impl constructor.

EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(serverAddress, defaultPort, readerOpts)) {}

EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(serverAddress, addrSize, readerOpts)) {}

EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(socketFd, readerOpts)) {}

EzRpcClient::~EzRpcClient() noexcept(false) {}

Capability::Client EzRpcClient::getMain() {
  // Returns the server's bootstrap capability. If the connection is not set up yet, returns a
  // client backed by a promise that resolves once setup completes.
  //
  // NOTE(review): the deferred continuation captures `this`, which assumes this EzRpcClient
  // outlives connection setup — confirm against callers.
  KJ_IF_MAYBE(client, impl->clientContext) {
    return (*client)->getMain();
  } else {
    return impl->setupPromise.addBranch().then([this]() {
      return KJ_ASSERT_NONNULL(impl->clientContext)->getMain();
    });
  }
}

Capability::Client EzRpcClient::importCap(kj::StringPtr name) {
  // Requests the capability the server restores for `name`. If the connection is not set up
  // yet, returns a promise-backed client that issues the request once setup completes.
  //
  // The deferred path copies `name` into an owned string, because the caller's StringPtr may
  // not survive until setup finishes.
  // NOTE(review): the deferred continuation captures `this`, which assumes this EzRpcClient
  // outlives connection setup — confirm against callers.
  KJ_IF_MAYBE(client, impl->clientContext) {
    return (*client)->restore(name);
  } else {
    return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name),
        [this](kj::String&& ownedName) {
      return KJ_ASSERT_NONNULL(impl->clientContext)->restore(ownedName);
    }));
  }
}

// The accessors below expose the shared per-thread EzRpcContext.

kj::WaitScope& EzRpcClient::getWaitScope() {
  return impl->context->getWaitScope();
}

kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
  return impl->context->getIoProvider();
}

kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
  return impl->context->getLowLevelIoProvider();
}

// =======================================================================================

struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
                                public kj::TaskSet::ErrorHandler {
  // Private state of EzRpcServer. It serves two roles:
  //   - the restorer for every accepted connection: a null object ID maps to mainInterface,
  //     and a Text ID maps to an entry in exportMap;
  //   - the error handler for its own TaskSet.

  Capability::Client mainInterface;
  kj::Own<EzRpcContext> context;

  struct ExportedCap {
    kj::String name;
    Capability::Client cap = nullptr;

    ExportedCap(kj::StringPtr name, Capability::Client cap)
        : name(kj::heapString(name)), cap(kj::mv(cap)) {}

    ExportedCap() = default;
    ExportedCap(const ExportedCap&) = delete;
    ExportedCap(ExportedCap&&) = default;
    ExportedCap& operator=(const ExportedCap&) = delete;
    ExportedCap& operator=(ExportedCap&&) = default;
    // Make std::map happy...
  };

  std::map<kj::StringPtr, ExportedCap> exportMap;
  // Each key points into its own value's `name` string. A value's `name` must therefore not be
  // replaced or freed while the entry is in the map, or the key dangles.

  kj::ForkedPromise<uint> portPromise;
  // Resolves to the port the server is listening on.

  kj::TaskSet tasks;
  // Holds the accept loop and per-connection lifetimes. Failures go to taskFailed().

  struct ServerContext {
    // One accepted connection: the stream, a server-side two-party network over it, and an
    // RPC server that resolves object IDs through the given restorer.

    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
    RpcSystem<rpc::twoparty::VatId> rpcSystem;

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
    ServerContext(kj::Own<kj::AsyncIoStream>&& stream, SturdyRefRestorer<AnyPointer>& restorer,
                  ReaderOptions readerOpts)
        : stream(kj::mv(stream)),
          network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts),
          rpcSystem(makeRpcServer(network, restorer)) {}
#pragma GCC diagnostic pop
  };

  Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
    // The address is parsed asynchronously, so the port is not known yet. It is delivered
    // through a promise/fulfiller pair once the listener exists.
    auto paf = kj::newPromiseAndFulfiller<uint>();
    portPromise = paf.promise.fork();

    tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort)
        .then(kj::mvCapture(paf.fulfiller,
          [this, readerOpts](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller,
                             kj::Own<kj::NetworkAddress>&& addr) {
      auto listener = addr->listen();
      portFulfiller->fulfill(listener->getPort());
      acceptLoop(kj::mv(listener), readerOpts);
    })));
  }

  Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
    // Binding happens synchronously here, so the port is available immediately.
    auto listener = context->getIoProvider().getNetwork()
        .getSockaddr(bindAddress, addrSize)->listen();
    portPromise = kj::Promise<uint>(listener->getPort()).fork();
    acceptLoop(kj::mv(listener), readerOpts);
  }

  Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()),
        // The caller supplies the port; it is not queried from the fd.
        portPromise(kj::Promise<uint>(port).fork()),
        tasks(*this) {
    acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd), readerOpts);
  }

  void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
    // Accepts one connection and immediately re-arms for the next. The listener is moved into
    // each continuation, so it lives as long as the loop does. `ptr` is taken before the
    // move, because the move empties `listener`.
    auto ptr = listener.get();
    tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener),
        [this, readerOpts](kj::Own<kj::ConnectionReceiver>&& listener,
                           kj::Own<kj::AsyncIoStream>&& connection) {
      acceptLoop(kj::mv(listener), readerOpts);

      auto server = kj::heap<ServerContext>(kj::mv(connection), *this, readerOpts);

      // Arrange to destroy the server context when all references are gone, or when the
      // EzRpcServer is destroyed (which will destroy the TaskSet).
      tasks.add(server->network.onDisconnect().attach(kj::mv(server)));
    })));
  }

  Capability::Client restore(AnyPointer::Reader objectId) override {
    // A null object ID selects the main interface. Otherwise the ID is read as Text and looked
    // up in exportMap. An unknown name fails the KJ_FAIL_REQUIRE, and its recovery path
    // returns a null capability.
    if (objectId.isNull()) {
      return mainInterface;
    } else {
      auto name = objectId.getAs<Text>();
      auto iter = exportMap.find(name);
      if (iter == exportMap.end()) {
        KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; }
        return nullptr;
      } else {
        return iter->second.cap;
      }
    }
  }

  void taskFailed(kj::Exception&& exception) override {
    // Any failed task (e.g. a failed address parse or accept()) is escalated to a fatal
    // exception.
    kj::throwFatalException(kj::mv(exception));
  }
};

// Constructors taking a main interface forward to the matching Impl constructor.

EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress,
                         uint defaultPort, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {}

EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress,
                         uint addrSize, ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {}

EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port,
                         ReaderOptions readerOpts)
    : impl(kj::heap<Impl>(kj::mv(mainInterface), socketFd, port, readerOpts)) {}

// Constructors without a main interface delegate with a null capability as the main interface.

EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort,
                         ReaderOptions readerOpts)
    : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {}

EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize,
                         ReaderOptions readerOpts)
    : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {}

EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts)
    : EzRpcServer(nullptr, socketFd, port, readerOpts) {}

EzRpcServer::~EzRpcServer() noexcept(false) {}

void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) {
  // Makes `cap` restorable by clients under `name`, replacing any capability previously
  // exported under that name.
  //
  // exportMap's keys are StringPtrs into each entry's own `name` string, so an existing entry
  // must be updated in place. Move-assigning a whole new ExportedCap over it would free the
  // string the node's key still points to, leaving a dangling key.
  auto iter = impl->exportMap.find(name);
  if (iter != impl->exportMap.end()) {
    iter->second.cap = kj::mv(cap);
  } else {
    Impl::ExportedCap entry(name, kj::mv(cap));
    // Moving `entry` into the map transfers ownership of name's heap buffer without
    // reallocating, so `key` keeps pointing at valid memory owned by the map value.
    kj::StringPtr key = entry.name;
    impl->exportMap.emplace(key, kj::mv(entry));
  }
}

kj::Promise<uint> EzRpcServer::getPort() {
  // Each call gets its own branch of the forked port promise.
  return impl->portPromise.addBranch();
}

// The accessors below expose the shared per-thread EzRpcContext.

kj::WaitScope& EzRpcServer::getWaitScope() {
  return impl->context->getWaitScope();
}

kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
  return impl->context->getIoProvider();
}

kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() {
  return impl->context->getLowLevelIoProvider();
}

}  // namespace capnp