capability.c++ 31.6 KB
Newer Older
Kenton Varda's avatar
Kenton Varda committed
1 2
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
3
//
Kenton Varda's avatar
Kenton Varda committed
4 5 6 7 8 9
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
10
//
Kenton Varda's avatar
Kenton Varda committed
11 12
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
13
//
Kenton Varda's avatar
Kenton Varda committed
14 15 16 17 18 19 20
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21

22 23
#define CAPNP_PRIVATE

24
#include "capability.h"
Kenton Varda's avatar
Kenton Varda committed
25
#include "message.h"
26
#include "arena.h"
Kenton Varda's avatar
Kenton Varda committed
27
#include <kj/refcount.h>
28
#include <kj/debug.h>
Kenton Varda's avatar
Kenton Varda committed
29
#include <kj/vector.h>
30
#include <map>
31
#include "generated-header-support.h"
32 33 34

namespace capnp {

35 36 37 38 39 40 41 42 43
namespace _ {  // private

void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory);
// Registers `factory` with layout.c++, which is where this function is defined.

}  // namespace _

namespace {

44 45
static kj::Own<ClientHook> newNullCap();

46 47 48 49 50
class BrokenCapFactoryImpl: public _::BrokenCapFactory {
public:
  kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override {
    return capnp::newBrokenCap(description);
  }
51 52 53
  kj::Own<ClientHook> newNullCap() override {
    return capnp::newNullCap();
  }
54 55 56 57 58 59 60 61 62 63 64 65
};

static BrokenCapFactoryImpl brokenCapFactory;

}  // namespace

ClientHook::ClientHook() {
  // Each hook construction (re-)registers the file-local broken-cap factory with layout.c++.
  _::setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}

// =======================================================================================

66
// A Client built from `nullptr` is backed by the null capability.
Capability::Client::Client(decltype(nullptr)): hook(newNullCap()) {}
68

69 70 71
// A Client built from an exception is backed by a broken capability created from `exception`.
Capability::Client::Client(kj::Exception&& exception): hook(newBrokenCap(kj::mv(exception))) {}

72 73 74 75 76 77 78 79 80 81 82 83 84
kj::Promise<kj::Maybe<int>> Capability::Client::getFd() {
  // Ask the current hook first. If it has no FD but may still resolve further, wait for the next
  // resolution step and retry on the resolved hook; otherwise report "no FD".
  kj::Maybe<int> immediateFd = hook->getFd();
  if (immediateFd != nullptr) {
    return immediateFd;
  }

  KJ_IF_MAYBE(resolution, hook->whenMoreResolved()) {
    // Keep the current hook alive while its resolution is pending.
    return resolution->attach(hook->addRef())
        .then([](kj::Own<ClientHook> resolvedHook) {
      return Client(kj::mv(resolvedHook)).getFd();
    });
  }

  return kj::Maybe<int>(nullptr);
}

85
Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented(
    const char* actualInterfaceName, uint64_t requestedTypeId) {
  // The server was asked for an interface it does not implement: produce an already-failed,
  // non-streaming dispatch result.
  kj::Promise<void> failure = KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.",
                                           actualInterfaceName, requestedTypeId);
  return { kj::mv(failure), false };
}

94
Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented(
    const char* interfaceName, uint64_t typeId, uint16_t methodId) {
  // The interface is known but this method is not implemented: produce an already-failed,
  // non-streaming dispatch result.
  kj::Promise<void> failure =
      KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId);
  return { kj::mv(failure), false };
}

kj::Promise<void> Capability::Server::internalUnimplemented(
    const char* interfaceName, const char* methodName, uint64_t typeId, uint16_t methodId) {
  // Variant that also reports the method name; the failure is returned as a rejected promise.
  kj::Exception failure = KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName,
                                       typeId, methodName, methodId);
  return kj::mv(failure);
}

Kenton Varda's avatar
Kenton Varda committed
108 109
// Out-of-line, intentionally empty destructor definition.
ResponseHook::~ResponseHook() noexcept(false) {
}

110
kj::Promise<void> ClientHook::whenResolved() {
  // Follow the resolution chain one step at a time until a hook reports that it will not
  // resolve any further.
  KJ_IF_MAYBE(nextStep, whenMoreResolved()) {
    return nextStep->then([](kj::Own<ClientHook>&& next) { return next->whenResolved(); });
  }
  return kj::READY_NOW;
}
Kenton Varda's avatar
Kenton Varda committed
119

120
// =======================================================================================
Kenton Varda's avatar
Kenton Varda committed
121

122 123 124 125 126 127 128 129
static inline uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) {
  // First-segment size for a new message: the hint's word count if one was supplied, else the
  // library's suggested default.
  uint words = SUGGESTED_FIRST_SEGMENT_WORDS;
  KJ_IF_MAYBE(hint, sizeHint) {
    words = hint->wordCount;
  }
  return words;
}

130
class LocalResponse final: public ResponseHook, public kj::Refcounted {
Kenton Varda's avatar
Kenton Varda committed
131
public:
132
  LocalResponse(kj::Maybe<MessageSize> sizeHint)
133
      : message(firstSegmentSize(sizeHint)) {}
Kenton Varda's avatar
Kenton Varda committed
134

135
  MallocMessageBuilder message;
Kenton Varda's avatar
Kenton Varda committed
136 137
};

138
class LocalCallContext final: public CallContextHook, public kj::Refcounted {
Kenton Varda's avatar
Kenton Varda committed
139
public:
140
  LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<ClientHook> clientRef,
141 142 143
                   kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller)
      : request(kj::mv(request)), clientRef(kj::mv(clientRef)),
        cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {}
Kenton Varda's avatar
Kenton Varda committed
144

145
  AnyPointer::Reader getParams() override {
146
    KJ_IF_MAYBE(r, request) {
147
      return r->get()->getRoot<AnyPointer>();
148 149 150
    } else {
      KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams().");
    }
Kenton Varda's avatar
Kenton Varda committed
151 152 153 154
  }
  void releaseParams() override {
    request = nullptr;
  }
155
  AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override {
156
    if (response == nullptr) {
157
      auto localResponse = kj::refcounted<LocalResponse>(sizeHint);
158
      responseBuilder = localResponse->message.getRoot<AnyPointer>();
159
      response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse));
Kenton Varda's avatar
Kenton Varda committed
160
    }
161 162
    return responseBuilder;
  }
163 164 165
  kj::Promise<void> tailCall(kj::Own<RequestHook>&& request) override {
    auto result = directTailCall(kj::mv(request));
    KJ_IF_MAYBE(f, tailCallPipelineFulfiller) {
166
      f->get()->fulfill(AnyPointer::Pipeline(kj::mv(result.pipeline)));
167 168 169 170
    }
    return kj::mv(result.promise);
  }
  ClientHook::VoidPromiseAndPipeline directTailCall(kj::Own<RequestHook>&& request) override {
171 172 173 174
    KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct.");

    auto promise = request->send();

175
    auto voidPromise = promise.then([this](Response<AnyPointer>&& tailResponse) {
176 177
      response = kj::mv(tailResponse);
    });
178 179

    return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) };
180
  }
181 182
  kj::Promise<AnyPointer::Pipeline> onTailCall() override {
    auto paf = kj::newPromiseAndFulfiller<AnyPointer::Pipeline>();
183 184
    tailCallPipelineFulfiller = kj::mv(paf.fulfiller);
    return kj::mv(paf.promise);
Kenton Varda's avatar
Kenton Varda committed
185
  }
186
  void allowCancellation() override {
187
    cancelAllowedFulfiller->fulfill();
Kenton Varda's avatar
Kenton Varda committed
188
  }
189 190
  kj::Own<CallContextHook> addRef() override {
    return kj::addRef(*this);
191
  }
Kenton Varda's avatar
Kenton Varda committed
192

193
  kj::Maybe<kj::Own<MallocMessageBuilder>> request;
194 195
  kj::Maybe<Response<AnyPointer>> response;
  AnyPointer::Builder responseBuilder = nullptr;  // only valid if `response` is non-null
196
  kj::Own<ClientHook> clientRef;
197
  kj::Maybe<kj::Own<kj::PromiseFulfiller<AnyPointer::Pipeline>>> tailCallPipelineFulfiller;
198
  kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller;
Kenton Varda's avatar
Kenton Varda committed
199 200
};

201
class LocalRequest final: public RequestHook {
Kenton Varda's avatar
Kenton Varda committed
202
public:
203
  inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
204
                      kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client)
205
      : message(kj::heap<MallocMessageBuilder>(firstSegmentSize(sizeHint))),
206 207
        interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}

208
  RemotePromise<AnyPointer> send() override {
209 210
    KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request.");

211 212 213 214
    auto cancelPaf = kj::newPromiseAndFulfiller<void>();

    auto context = kj::refcounted<LocalCallContext>(
        kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller));
215
    auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context));
216

217 218
    // We have to make sure the call is not canceled unless permitted.  We need to fork the promise
    // so that if the client drops their copy, the promise isn't necessarily canceled.
219
    auto forked = promiseAndPipeline.promise.fork();
220 221 222

    // We daemonize one branch, but only after joining it with the promise that fires if
    // cancellation is allowed.
223 224 225 226
    forked.addBranch()
        .attach(kj::addRef(*context))
        .exclusiveJoin(kj::mv(cancelPaf.promise))
        .detach([](kj::Exception&&) {});  // ignore exceptions
227 228

    // Now the other branch returns the response from the context.
229 230
    auto promise = forked.addBranch().then(kj::mvCapture(context,
        [](kj::Own<LocalCallContext>&& context) {
231
      context->getResults(MessageSize { 0, 0 });  // force response allocation
232 233
      return kj::mv(KJ_ASSERT_NONNULL(context->response));
    }));
234

235
    // We return the other branch.
236 237
    return RemotePromise<AnyPointer>(
        kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
238 239
  }

240 241 242 243 244 245
  kj::Promise<void> sendStreaming() override {
    // We don't do any special handling of streaming in RequestHook for local requests, because
    // there is no latency to compensate for between the client and server in this case.
    return send().ignoreResult();
  }

246
  const void* getBrand() override {
247 248 249
    return nullptr;
  }

250
  kj::Own<MallocMessageBuilder> message;
251 252 253 254

private:
  uint64_t interfaceId;
  uint16_t methodId;
255
  kj::Own<ClientHook> client;
256 257 258 259 260 261 262 263 264 265 266 267 268
};

// =======================================================================================
// Call queues
//
// These classes handle pipelining in the case where calls need to be queued in-memory until some
// local operation completes.

class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
  // A PipelineHook which simply queues calls while waiting for a PipelineHook to which to forward
  // them.

public:
269 270
  QueuedPipeline(kj::Promise<kj::Own<PipelineHook>>&& promiseParam)
      : promise(promiseParam.fork()),
271 272 273 274 275
        selfResolutionOp(promise.addBranch().then([this](kj::Own<PipelineHook>&& inner) {
          redirect = kj::mv(inner);
        }, [this](kj::Exception&& exception) {
          redirect = newBrokenPipeline(kj::mv(exception));
        }).eagerlyEvaluate(nullptr)) {}
276

277
  kj::Own<PipelineHook> addRef() override {
278 279 280
    return kj::addRef(*this);
  }

281
  kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override {
282 283 284 285
    auto copy = kj::heapArrayBuilder<PipelineOp>(ops.size());
    for (auto& op: ops) {
      copy.add(op);
    }
286
    return getPipelinedCap(copy.finish());
287 288
  }

289
  kj::Own<ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) override;
290 291

private:
292
  kj::ForkedPromise<kj::Own<PipelineHook>> promise;
Kenton Varda's avatar
Kenton Varda committed
293

Kenton Varda's avatar
Kenton Varda committed
294
  kj::Maybe<kj::Own<PipelineHook>> redirect;
Kenton Varda's avatar
Kenton Varda committed
295 296 297 298
  // Once the promise resolves, this will become non-null and point to the underlying object.

  kj::Promise<void> selfResolutionOp;
  // Represents the operation which will set `redirect` when possible.
299 300 301 302 303 304 305
};

class QueuedClient final: public ClientHook, public kj::Refcounted {
  // A ClientHook which simply queues calls while waiting for a ClientHook to which to forward
  // them.

public:
306 307
  QueuedClient(kj::Promise<kj::Own<ClientHook>>&& promiseParam)
      : promise(promiseParam.fork()),
308 309 310 311 312
        selfResolutionOp(promise.addBranch().then([this](kj::Own<ClientHook>&& inner) {
          redirect = kj::mv(inner);
        }, [this](kj::Exception&& exception) {
          redirect = newBrokenCap(kj::mv(exception));
        }).eagerlyEvaluate(nullptr)),
Kenton Varda's avatar
Kenton Varda committed
313
        promiseForCallForwarding(promise.addBranch().fork()),
314
        promiseForClientResolution(promise.addBranch().fork()) {}
Kenton Varda's avatar
Kenton Varda committed
315

316
  Request<AnyPointer, AnyPointer> newCall(
317
      uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
318
    auto hook = kj::heap<LocalRequest>(
319
        interfaceId, methodId, sizeHint, kj::addRef(*this));
320
    auto root = hook->message->getRoot<AnyPointer>();
321
    return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
Kenton Varda's avatar
Kenton Varda committed
322 323 324
  }

  VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
325
                              kj::Own<CallContextHook>&& context) override {
326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
    // This is a bit complicated.  We need to initiate this call later on.  When we initiate the
    // call, we'll get a void promise for its completion and a pipeline object.  Right now, we have
    // to produce a similar void promise and pipeline that will eventually be chained to those.
    // The problem is, these are two independent objects, but they both depend on the result of
    // one future call.
    //
    // So, we need to set up a continuation that will initiate the call later, then we need to
    // fork the promise for that continuation in order to send the completion promise and the
    // pipeline to their respective places.
    //
    // TODO(perf):  Too much reference counting?  Can we do better?  Maybe a way to fork
    //   Promise<Tuple<T, U>> into Tuple<Promise<T>, Promise<U>>?

    struct CallResultHolder: public kj::Refcounted {
      // Essentially acts as a refcounted \VoidPromiseAndPipeline, so that we can create a promise
      // for it and fork that promise.

343
      VoidPromiseAndPipeline content;
344
      // One branch of the fork will use content.promise, the other branch will use
345
      // content.pipeline.  Neither branch will touch the other's piece.
346 347 348

      inline CallResultHolder(VoidPromiseAndPipeline&& content): content(kj::mv(content)) {}

349
      kj::Own<CallResultHolder> addRef() { return kj::addRef(*this); }
350 351 352
    };

    // Create a promise for the call initiation.
353
    kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise =
Kenton Varda's avatar
Kenton Varda committed
354
        promiseForCallForwarding.addBranch().then(kj::mvCapture(context,
355
        [=](kj::Own<CallContextHook>&& context, kj::Own<ClientHook>&& client){
356 357
          return kj::refcounted<CallResultHolder>(
              client->call(interfaceId, methodId, kj::mv(context)));
358
        })).fork();
359 360 361

    // Create a promise that extracts the pipeline from the call initiation, and construct our
    // QueuedPipeline to chain to it.
362 363
    auto pipelinePromise = callResultPromise.addBranch().then(
        [](kj::Own<CallResultHolder>&& callResult){
364 365
          return kj::mv(callResult->content.pipeline);
        });
366
    auto pipeline = kj::refcounted<QueuedPipeline>(kj::mv(pipelinePromise));
367 368

    // Create a promise that simply chains to the void promise produced by the call initiation.
369 370
    auto completionPromise = callResultPromise.addBranch().then(
        [](kj::Own<CallResultHolder>&& callResult){
371 372 373 374 375
          return kj::mv(callResult->content.promise);
        });

    // OK, now we can actually return our thing.
    return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) };
Kenton Varda's avatar
Kenton Varda committed
376 377
  }

378
  kj::Maybe<ClientHook&> getResolved() override {
Kenton Varda's avatar
Kenton Varda committed
379
    KJ_IF_MAYBE(inner, redirect) {
380 381 382 383 384 385
      return **inner;
    } else {
      return nullptr;
    }
  }

386
  kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
Kenton Varda's avatar
Kenton Varda committed
387
    return promiseForClientResolution.addBranch();
Kenton Varda's avatar
Kenton Varda committed
388 389
  }

390
  kj::Own<ClientHook> addRef() override {
Kenton Varda's avatar
Kenton Varda committed
391 392 393
    return kj::addRef(*this);
  }

394
  const void* getBrand() override {
Kenton Varda's avatar
Kenton Varda committed
395 396 397
    return nullptr;
  }

398 399 400 401 402 403 404 405
  kj::Maybe<int> getFd() override {
    KJ_IF_MAYBE(r, redirect) {
      return r->get()->getFd();
    } else {
      return nullptr;
    }
  }

Kenton Varda's avatar
Kenton Varda committed
406
private:
407
  typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork;
408

Kenton Varda's avatar
Kenton Varda committed
409 410 411
  kj::Maybe<kj::Own<ClientHook>> redirect;
  // Once the promise resolves, this will become non-null and point to the underlying object.

412 413 414
  ClientHookPromiseFork promise;
  // Promise that resolves when we have a new ClientHook to forward to.
  //
Kenton Varda's avatar
Kenton Varda committed
415
  // This fork shall only have three branches:  `selfResolutionOp`, `promiseForCallForwarding`, and
416 417
  // `promiseForClientResolution`, in that order.

Kenton Varda's avatar
Kenton Varda committed
418 419 420 421
  kj::Promise<void> selfResolutionOp;
  // Represents the operation which will set `redirect` when possible.

  ClientHookPromiseFork promiseForCallForwarding;
422 423 424 425
  // When this promise resolves, each queued call will be forwarded to the real client.  This needs
  // to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure
  // previously-queued calls are delivered before any new calls made in response to the resolution.

Kenton Varda's avatar
Kenton Varda committed
426
  ClientHookPromiseFork promiseForClientResolution;
427 428 429 430 431 432
  // whenMoreResolved() returns forks of this promise.  These must resolve *after* queued calls
  // have been initiated (so that any calls made in the whenMoreResolved() handler are correctly
  // delivered after calls made earlier), but *before* any queued calls return (because it might
  // confuse the application if a queued call returns before the capability on which it was made
  // resolves).  Luckily, we know that queued calls will involve, at the very least, an
  // eventLoop.evalLater.
Kenton Varda's avatar
Kenton Varda committed
433 434
};

435
kj::Own<ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) {
  // Once resolved, forward immediately. Until then, return a QueuedClient that applies `ops` to
  // the pipeline when it arrives.
  KJ_IF_MAYBE(resolved, redirect) {
    return resolved->get()->getPipelinedCap(kj::mv(ops));
  }

  auto pendingCap = promise.addBranch().then(kj::mvCapture(ops,
      [](kj::Array<PipelineOp>&& ops, kj::Own<PipelineHook> pipeline) {
    return pipeline->getPipelinedCap(kj::mv(ops));
  }));
  return kj::refcounted<QueuedClient>(kj::mv(pendingCap));
}
Kenton Varda's avatar
Kenton Varda committed
447

448
// =======================================================================================
Kenton Varda's avatar
Kenton Varda committed
449

450
class LocalPipeline final: public PipelineHook, public kj::Refcounted {
Kenton Varda's avatar
Kenton Varda committed
451
public:
452 453
  inline LocalPipeline(kj::Own<CallContextHook>&& contextParam)
      : context(kj::mv(contextParam)),
454
        results(context->getResults(MessageSize { 0, 0 })) {}
Kenton Varda's avatar
Kenton Varda committed
455

456
  kj::Own<PipelineHook> addRef() {
457
    return kj::addRef(*this);
Kenton Varda's avatar
Kenton Varda committed
458 459
  }

460
  kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
461
    return results.getPipelinedCap(ops);
462
  }
Kenton Varda's avatar
Kenton Varda committed
463 464

private:
465
  kj::Own<CallContextHook> context;
466
  AnyPointer::Reader results;
Kenton Varda's avatar
Kenton Varda committed
467 468 469 470
};

class LocalClient final: public ClientHook, public kj::Refcounted {
public:
471 472 473 474 475
  LocalClient(kj::Own<Capability::Server>&& serverParam)
      : server(kj::mv(serverParam)) {
    server->thisHook = this;
  }
  LocalClient(kj::Own<Capability::Server>&& serverParam,
476
              _::CapabilityServerSetBase& capServerSet, void* ptr)
477 478
      : server(kj::mv(serverParam)), capServerSet(&capServerSet), ptr(ptr) {
    server->thisHook = this;
479 480
  }

481 482
  ~LocalClient() noexcept(false) {
    server->thisHook = nullptr;
483
  }
Kenton Varda's avatar
Kenton Varda committed
484

485
  Request<AnyPointer, AnyPointer> newCall(
486
      uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
Kenton Varda's avatar
Kenton Varda committed
487
    auto hook = kj::heap<LocalRequest>(
488
        interfaceId, methodId, sizeHint, kj::addRef(*this));
489
    auto root = hook->message->getRoot<AnyPointer>();
490
    return Request<AnyPointer, AnyPointer>(root, kj::mv(hook));
Kenton Varda's avatar
Kenton Varda committed
491 492 493
  }

  VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
494
                              kj::Own<CallContextHook>&& context) override {
495 496
    auto contextPtr = context.get();

497 498 499
    // We don't want to actually dispatch the call synchronously, because we don't want the callee
    // to have any side effects before the promise is returned to the caller.  This helps avoid
    // race conditions.
500 501 502 503 504
    //
    // So, we do an evalLater() here.
    //
    // Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't
    // complete before 'whenMoreResolved()' promises resolve.
505
    auto promise = kj::evalLater([this,interfaceId,methodId,contextPtr]() {
506 507 508 509 510 511
      if (blocked) {
        return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(
            *this, interfaceId, methodId, *contextPtr);
      } else {
        return callInternal(interfaceId, methodId, *contextPtr);
      }
512
    }).attach(kj::addRef(*this));
513

514
    // We have to fork this promise for the pipeline to receive a copy of the answer.
515
    auto forked = promise.fork();
516

517 518
    auto pipelinePromise = forked.addBranch().then(kj::mvCapture(context->addRef(),
        [=](kj::Own<CallContextHook>&& context) -> kj::Own<PipelineHook> {
519 520 521
          context->releaseParams();
          return kj::refcounted<LocalPipeline>(kj::mv(context));
        }));
522

523
    auto tailPipelinePromise = context->onTailCall().then([](AnyPointer::Pipeline&& pipeline) {
524 525 526
      return kj::mv(pipeline.hook);
    });

527
    pipelinePromise = pipelinePromise.exclusiveJoin(kj::mv(tailPipelinePromise));
528

529
    auto completionPromise = forked.addBranch().attach(kj::mv(context));
Kenton Varda's avatar
Kenton Varda committed
530

531
    return VoidPromiseAndPipeline { kj::mv(completionPromise),
532
        kj::refcounted<QueuedPipeline>(kj::mv(pipelinePromise)) };
Kenton Varda's avatar
Kenton Varda committed
533 534
  }

535
  kj::Maybe<ClientHook&> getResolved() override {
536 537 538
    return nullptr;
  }

539
  kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
Kenton Varda's avatar
Kenton Varda committed
540 541 542
    return nullptr;
  }

543
  kj::Own<ClientHook> addRef() override {
Kenton Varda's avatar
Kenton Varda committed
544 545 546
    return kj::addRef(*this);
  }

547 548 549
  static const uint BRAND;
  // Value is irrelevant; used for pointer.

550
  const void* getBrand() override {
551
    return &BRAND;
Kenton Varda's avatar
Kenton Varda committed
552 553
  }

554 555 556 557 558
  kj::Promise<void*> getLocalServer(_::CapabilityServerSetBase& capServerSet) {
    // If this is a local capability created through `capServerSet`, return the underlying Server.
    // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should
    // use) always returns nullptr.

559
    if (this->capServerSet == &capServerSet) {
560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
      if (blocked) {
        // If streaming calls are in-flight, it could be the case that they were originally sent
        // over RPC and reflected back, before the capability had resolved to a local object. In
        // that case, the client may already perceive these calls as "done" because the RPC
        // implementation caused the client promise to resolve early. However, the capability is
        // now local, and the app is trying to break through the LocalClient wrapper and access
        // the server directly, bypassing the stream queue. Since the app thinks that all
        // previous calls already completed, it may then try to queue a new call directly on the
        // server, jumping the queue.
        //
        // We can solve this by delaying getLocalServer() until all current streaming calls have
        // finished. Note that if a new streaming call is started *after* this point, we need not
        // worry about that, because in this case it is presumably a local call and the caller
        // won't be informed of completion until the call actually does complete. Thus the caller
        // is well-aware that this call is still in-flight.
        //
        // However, the app still cannot assume that there aren't multiple clients, perhaps even
        // a malicious client that tries to send stream requests that overlap with the app's
        // direct use of the server... so it's up to the app to check for and guard against
        // concurrent calls after using getLocalServer().
        return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(*this)
            .then([this]() { return ptr; });
      } else {
        return ptr;
      }
585
    } else {
586
      return (void*)nullptr;
587 588 589
    }
  }

590 591 592 593
  kj::Maybe<int> getFd() override {
    return server->getFd();
  }

Kenton Varda's avatar
Kenton Varda committed
594
private:
595
  kj::Own<Capability::Server> server;
596 597
  _::CapabilityServerSetBase* capServerSet = nullptr;
  void* ptr = nullptr;
598 599 600 601 602 603 604 605 606 607 608 609

  class BlockedCall {
  public:
    BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client,
                uint64_t interfaceId, uint16_t methodId, CallContextHook& context)
        : fulfiller(fulfiller), client(client),
          interfaceId(interfaceId), methodId(methodId), context(context),
          prev(client.blockedCallsEnd) {
      *prev = *this;
      client.blockedCallsEnd = &next;
    }

610 611 612 613 614 615
    BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client)
        : fulfiller(fulfiller), client(client), prev(client.blockedCallsEnd) {
      *prev = *this;
      client.blockedCallsEnd = &next;
    }

616 617 618 619 620 621
    ~BlockedCall() noexcept(false) {
      unlink();
    }

    void unblock() {
      unlink();
622 623 624 625 626 627 628 629
      KJ_IF_MAYBE(c, context) {
        fulfiller.fulfill(kj::evalNow([&]() {
          return client.callInternal(interfaceId, methodId, *c);
        }));
      } else {
        // This is just a barrier.
        fulfiller.fulfill(kj::READY_NOW);
      }
630 631 632 633 634 635 636
    }

  private:
    kj::PromiseFulfiller<kj::Promise<void>>& fulfiller;
    LocalClient& client;
    uint64_t interfaceId;
    uint16_t methodId;
637
    kj::Maybe<CallContextHook&> context;
638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708

    kj::Maybe<BlockedCall&> next;
    kj::Maybe<BlockedCall&>* prev;

    void unlink() {
      if (prev != nullptr) {
        *prev = next;
        KJ_IF_MAYBE(n, next) {
          n->prev = prev;
        } else {
          client.blockedCallsEnd = prev;
        }
        prev = nullptr;
      }
    }
  };

  class BlockingScope {
  public:
    BlockingScope(LocalClient& client): client(client) { client.blocked = true; }
    BlockingScope(): client(nullptr) {}
    BlockingScope(BlockingScope&& other): client(other.client) { other.client = nullptr; }
    KJ_DISALLOW_COPY(BlockingScope);

    ~BlockingScope() noexcept(false) {
      KJ_IF_MAYBE(c, client) {
        c->unblock();
      }
    }

  private:
    kj::Maybe<LocalClient&> client;
  };

  bool blocked = false;
  kj::Maybe<kj::Exception> brokenException;
  kj::Maybe<BlockedCall&> blockedCalls;
  kj::Maybe<BlockedCall&>* blockedCallsEnd = &blockedCalls;

  void unblock() {
    blocked = false;
    while (!blocked) {
      KJ_IF_MAYBE(t, blockedCalls) {
        t->unblock();
      } else {
        break;
      }
    }
  }

  kj::Promise<void> callInternal(uint64_t interfaceId, uint16_t methodId,
                                 CallContextHook& context) {
    KJ_ASSERT(!blocked);

    KJ_IF_MAYBE(e, brokenException) {
      // Previous streaming call threw, so everything fails from now on.
      return kj::cp(*e);
    }

    auto result = server->dispatchCall(interfaceId, methodId,
                                       CallContext<AnyPointer, AnyPointer>(context));
    if (result.isStreaming) {
      return result.promise
          .catch_([this](kj::Exception&& e) {
        brokenException = kj::cp(e);
        kj::throwRecoverableException(kj::mv(e));
      }).attach(BlockingScope(*this));
    } else {
      return kj::mv(result.promise);
    }
  }
Kenton Varda's avatar
Kenton Varda committed
709 710
};

711 712
// Storage for LocalClient's brand tag; only its address is meaningful.
const uint LocalClient::BRAND(0);

713 714
kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) {
  // Wrap an in-process server in a LocalClient hook.
  auto localHook = kj::refcounted<LocalClient>(kj::mv(server));
  return kj::mv(localHook);
}

717 718
kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& promise) {
  // Calls made before `promise` resolves are queued by the QueuedClient.
  auto queued = kj::refcounted<QueuedClient>(kj::mv(promise));
  return kj::mv(queued);
}

721 722 723 724
kj::Own<PipelineHook> newLocalPromisePipeline(kj::Promise<kj::Own<PipelineHook>>&& promise) {
  // Pipelined requests made before `promise` resolves are queued by the QueuedPipeline.
  auto queued = kj::refcounted<QueuedPipeline>(kj::mv(promise));
  return kj::mv(queued);
}

725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752
// =======================================================================================

namespace {

class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
  BrokenPipeline(const kj::Exception& exception): exception(exception) {}

  kj::Own<PipelineHook> addRef() override {
    return kj::addRef(*this);
  }

  kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override;

private:
  kj::Exception exception;
};

class BrokenRequest final: public RequestHook {
public:
  BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint)
      : exception(exception), message(firstSegmentSize(sizeHint)) {}

  RemotePromise<AnyPointer> send() override {
    return RemotePromise<AnyPointer>(kj::cp(exception),
        AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
  }

753 754 755 756
  kj::Promise<void> sendStreaming() override {
    return kj::cp(exception);
  }

757
  const void* getBrand() override {
758 759 760 761 762 763 764 765 766
    return nullptr;
  }

  kj::Exception exception;
  MallocMessageBuilder message;
};

class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
767 768 769 770 771
  BrokenClient(const kj::Exception& exception, bool resolved, const void* brand = nullptr)
      : exception(exception), resolved(resolved), brand(brand) {}
  BrokenClient(const kj::StringPtr description, bool resolved, const void* brand = nullptr)
      : exception(kj::Exception::Type::FAILED, "", 0, kj::str(description)),
        resolved(resolved), brand(brand) {}
772 773 774

  Request<AnyPointer, AnyPointer> newCall(
      uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
775
    return newBrokenRequest(kj::cp(exception), sizeHint);
776 777 778 779
  }

  VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
                              kj::Own<CallContextHook>&& context) override {
780
    return VoidPromiseAndPipeline { kj::cp(exception), kj::refcounted<BrokenPipeline>(exception) };
781 782
  }

783
  kj::Maybe<ClientHook&> getResolved() override {
784 785 786 787
    return nullptr;
  }

  kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
788 789 790 791 792
    if (resolved) {
      return nullptr;
    } else {
      return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
    }
793 794 795 796 797 798 799
  }

  kj::Own<ClientHook> addRef() override {
    return kj::addRef(*this);
  }

  const void* getBrand() override {
800
    return brand;
801 802
  }

803 804 805 806
  kj::Maybe<int> getFd() override {
    return nullptr;
  }

807 808
private:
  kj::Exception exception;
809
  bool resolved;
810
  const void* brand;
811 812 813
};

kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
  // Whatever path `ops` describes, the capability at the end of it is broken with this
  // pipeline's exception and is reported as not resolved.
  const bool isResolved = false;
  return kj::refcounted<BrokenClient>(exception, isResolved);
}

kj::Own<ClientHook> newNullCap() {
  // Calls on a null capability fail like calls on any broken capability. Unlike other broken
  // capabilities, it counts as already resolved, and it reports NULL_CAPABILITY_BRAND from
  // getBrand() (presumably so callers can tell it apart).
  const bool alreadyResolved = true;
  return kj::refcounted<BrokenClient>("Called null capability.", alreadyResolved,
                                      &ClientHook::NULL_CAPABILITY_BRAND);
}

}  // namespace

kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
  // Every call on the returned capability fails with a FAILED exception described by `reason`.
  // The capability is reported as not resolved.
  kj::Own<ClientHook> cap = kj::refcounted<BrokenClient>(reason, false);
  return cap;
}

kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
  // Every call on the returned capability fails with `reason`. The capability is reported as
  // not resolved.
  kj::Own<ClientHook> cap = kj::refcounted<BrokenClient>(kj::mv(reason), false);
  return cap;
}

kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
  // Every capability obtained from the returned pipeline is broken with `reason`.
  kj::Own<PipelineHook> pipeline = kj::refcounted<BrokenPipeline>(kj::mv(reason));
  return pipeline;
}

837 838 839 840 841 842 843
Request<AnyPointer, AnyPointer> newBrokenRequest(
    kj::Exception&& reason, kj::Maybe<MessageSize> sizeHint) {
  // The returned Request builds into a scratch message owned by the BrokenRequest hook, so the
  // caller can still fill in parameters; sending it then fails with `reason`.
  auto brokenHook = kj::heap<BrokenRequest>(kj::mv(reason), sizeHint);
  auto params = brokenHook->message.getRoot<AnyPointer>();
  return Request<AnyPointer, AnyPointer>(params, kj::mv(brokenHook));
}

844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884
// =======================================================================================

ReaderCapabilityTable::ReaderCapabilityTable(
    kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTable)
    : table(kj::mv(capTable)) {
  // Takes ownership of `capTable` and registers brokenCapFactory with layout.c++. Doing this in
  // the constructor presumably ensures the factory is installed before any capability is read
  // through a table.
  setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory);
}

kj::Maybe<kj::Own<ClientHook>> ReaderCapabilityTable::extractCap(uint index) {
  // Returns a new reference to the capability at `index`. Returns null if `index` is out of
  // range or the slot is empty.
  if (index >= table.size()) {
    return nullptr;
  }
  return table[index].map([](kj::Own<ClientHook>& entry) { return entry->addRef(); });
}

BuilderCapabilityTable::BuilderCapabilityTable() {
  // Starts with an empty table and registers brokenCapFactory with layout.c++, mirroring
  // ReaderCapabilityTable's constructor.
  auto& factory = brokenCapFactory;
  setGlobalBrokenCapFactoryForLayoutCpp(factory);
}

kj::Maybe<kj::Own<ClientHook>> BuilderCapabilityTable::extractCap(uint index) {
  // Returns a new reference to the capability at `index`. Returns null if `index` is out of
  // range or the slot has been cleared (e.g. by dropCap()).
  if (index >= table.size()) {
    return nullptr;
  }
  KJ_IF_MAYBE(entry, table[index]) {
    return (*entry)->addRef();
  } else {
    return nullptr;
  }
}

uint BuilderCapabilityTable::injectCap(kj::Own<ClientHook>&& cap) {
  // Appends `cap` and returns its index, which is the last position in the table.
  table.add(kj::mv(cap));
  return table.size() - 1;
}

void BuilderCapabilityTable::dropCap(uint index) {
  // Drops this table's reference to the capability at `index` by nulling its slot. The slot is
  // not removed, so the indices of all other capabilities stay valid.
  KJ_ASSERT(index < table.size(), "Invalid capability descriptor in message.") {
    // Reached only if the failed assertion does not throw (KJ's recoverable-failure path):
    // ignore the bad index.
    return;
  }
  table[index] = nullptr;
}

885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909
// =======================================================================================
// CapabilityServerSet

namespace _ {  // private

Capability::Client CapabilityServerSetBase::addInternal(
    kj::Own<Capability::Server>&& server, void* ptr) {
  // Wraps `server` in a LocalClient tagged with this set and `ptr`. Presumably
  // LocalClient::getLocalServer() returns `ptr` when later queried with the same set.
  kj::Own<ClientHook> hook = kj::refcounted<LocalClient>(kj::mv(server), *this, ptr);
  return Capability::Client(kj::mv(hook));
}

kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::Client& client) {
  // Resolves to the server pointer registered via addInternal() if `client` is (or eventually
  // resolves to) a LocalClient. Otherwise it resolves to null. `this` is captured in the
  // continuation, so this set must outlive the returned promise.

  // Start from the most-resolved version of the hook known so far.
  ClientHook* current = client.hook.get();
  KJ_IF_MAYBE(resolvedHook, current->getResolved()) {
    current = resolvedHook;
  }

  KJ_IF_MAYBE(pending, current->whenMoreResolved()) {
    // Still an unresolved promise: hold a reference to the hook while waiting, then retry on
    // whatever it resolves to.
    return pending->attach(current->addRef())
        .then([this](kj::Own<ClientHook>&& next) {
      Capability::Client nextClient(kj::mv(next));
      return getLocalServerInternal(nextClient);
    });
  }

  if (current->getBrand() != &LocalClient::BRAND) {
    // Not a LocalClient, so it cannot have come from any CapabilityServerSet.
    return (void*)nullptr;
  }
  return kj::downcast<LocalClient>(*current).getLocalServer(*this);
}

}  // namespace _ (private)

919
}  // namespace capnp