Commit f566cd4d authored by Kenton Varda's avatar Kenton Varda

Capability accessor support code and initial tests of local calls.

parent bdd06585
...@@ -25,12 +25,12 @@ ...@@ -25,12 +25,12 @@
#include "arena.h" #include "arena.h"
#include "message.h" #include "message.h"
#include "capability.h" #include "capability.h"
#include "capability-context.h"
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/refcount.h>
#include <vector> #include <vector>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#include "capability.h"
#include "capability-context.h"
namespace capnp { namespace capnp {
namespace _ { // private namespace _ { // private
...@@ -38,6 +38,82 @@ namespace _ { // private ...@@ -38,6 +38,82 @@ namespace _ { // private
Arena::~Arena() noexcept(false) {} Arena::~Arena() noexcept(false) {}
BuilderArena::~BuilderArena() noexcept(false) {} BuilderArena::~BuilderArena() noexcept(false) {}
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override;
private:
kj::Exception exception;
};
class BrokenRequest final: public RequestHook {
public:
BrokenRequest(const kj::Exception& exception, uint firstSegmentWordSize)
: exception(exception), message(firstSegmentWordSize) {}
RemotePromise<TypelessResults> send() override {
return RemotePromise<TypelessResults>(kj::cp(exception),
TypelessResults::Pipeline(kj::refcounted<BrokenPipeline>(exception)));
}
kj::Exception exception;
MallocMessageBuilder message;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
BrokenClient(const char* description)
: exception(kj::Exception::Nature::PRECONDITION, kj::Exception::Durability::PERMANENT,
"", 0, kj::str(description)) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<BrokenRequest>(exception, firstSegmentWordSize);
return Request<ObjectPointer, TypelessResults>(
hook->message.getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(exception));
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
}
} // namespace
kj::Own<const ClientHook> Arena::extractNullCap() {
return kj::refcounted<BrokenClient>("Calling null capability pointer.");
}
void ReadLimiter::unread(WordCount64 amount) { void ReadLimiter::unread(WordCount64 amount) {
// Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that // Be careful not to overflow here. Since ReadLimiter has no thread-safety, it's possible that
// the limit value was not updated correctly for one or more reads, and therefore unread() could // the limit value was not updated correctly for one or more reads, and therefore unread() could
...@@ -102,40 +178,11 @@ void BasicReaderArena::reportReadLimitReached() { ...@@ -102,40 +178,11 @@ void BasicReaderArena::reportReadLimitReached() {
} }
} }
namespace { kj::Own<const ClientHook> BasicReaderArena::extractCap(const _::StructReader& capDescriptor) {
class DummyClientHook final: public ClientHook {
public:
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
kj::Own<const ClientHook> addRef() const override {
return kj::heap<DummyClientHook>();
}
void* getBrand() const override {
return nullptr;
}
};
} // namespace
kj::Own<ClientHook> BasicReaderArena::extractCap(const _::StructReader& capDescriptor) {
KJ_FAIL_REQUIRE("Message contained a capability but is not imbued with a capability context.") { KJ_FAIL_REQUIRE("Message contained a capability but is not imbued with a capability context.") {
return kj::heap<DummyClientHook>(); return kj::heap<BrokenClient>(
"Calling capability extracted from message that was not imbued with a capability "
"context.");
} }
} }
...@@ -188,7 +235,7 @@ void ImbuedReaderArena::reportReadLimitReached() { ...@@ -188,7 +235,7 @@ void ImbuedReaderArena::reportReadLimitReached() {
return base->reportReadLimitReached(); return base->reportReadLimitReached();
} }
kj::Own<ClientHook> ImbuedReaderArena::extractCap(const _::StructReader& capDescriptor) { kj::Own<const ClientHook> ImbuedReaderArena::extractCap(const _::StructReader& capDescriptor) {
return capExtractor->extractCapInternal(capDescriptor); return capExtractor->extractCapInternal(capDescriptor);
} }
...@@ -331,13 +378,19 @@ void BasicBuilderArena::reportReadLimitReached() { ...@@ -331,13 +378,19 @@ void BasicBuilderArena::reportReadLimitReached() {
} }
} }
kj::Own<ClientHook> BasicBuilderArena::extractCap(const _::StructReader& capDescriptor) { kj::Own<const ClientHook> BasicBuilderArena::extractCap(const _::StructReader& capDescriptor) {
KJ_FAIL_REQUIRE("Message contains no capabilities."); KJ_FAIL_REQUIRE("Message contains no capabilities.");
} }
void BasicBuilderArena::injectCap(_::PointerBuilder pointer, kj::Own<ClientHook>&& cap) { OrphanBuilder BasicBuilderArena::injectCap(kj::Own<const ClientHook>&& cap) {
KJ_FAIL_REQUIRE("Cannot inject capability into a builder that has not been imbued with a " KJ_FAIL_REQUIRE("Cannot inject capability into a builder that has not been imbued with a "
"capability context."); "capability context.") {
return OrphanBuilder();
}
}
void BasicBuilderArena::dropCap(const _::StructReader& capDescriptor) {
// They only way we could have a cap in the first place is if the error was already reported...
} }
// ======================================================================================= // =======================================================================================
...@@ -386,7 +439,7 @@ void ImbuedBuilderArena::reportReadLimitReached() { ...@@ -386,7 +439,7 @@ void ImbuedBuilderArena::reportReadLimitReached() {
base->reportReadLimitReached(); base->reportReadLimitReached();
} }
kj::Own<ClientHook> ImbuedBuilderArena::extractCap(const _::StructReader& capDescriptor) { kj::Own<const ClientHook> ImbuedBuilderArena::extractCap(const _::StructReader& capDescriptor) {
return capInjector->getInjectedCapInternal(capDescriptor); return capInjector->getInjectedCapInternal(capDescriptor);
} }
...@@ -400,8 +453,12 @@ BuilderArena::AllocateResult ImbuedBuilderArena::allocate(WordCount amount) { ...@@ -400,8 +453,12 @@ BuilderArena::AllocateResult ImbuedBuilderArena::allocate(WordCount amount) {
return result; return result;
} }
void ImbuedBuilderArena::injectCap(_::PointerBuilder pointer, kj::Own<ClientHook>&& cap) { OrphanBuilder ImbuedBuilderArena::injectCap(kj::Own<const ClientHook>&& cap) {
return capInjector->injectCapInternal(pointer, kj::mv(cap)); return capInjector->injectCapInternal(this, kj::mv(cap));
}
void ImbuedBuilderArena::dropCap(const StructReader& capDescriptor) {
capInjector->dropCapInternal(capDescriptor);
} }
} // namespace _ (private) } // namespace _ (private)
......
...@@ -185,9 +185,13 @@ public: ...@@ -185,9 +185,13 @@ public:
// the VALIDATE_INPUT() macro which may throw an exception; if it return normally, the caller // the VALIDATE_INPUT() macro which may throw an exception; if it return normally, the caller
// will need to continue with default values. // will need to continue with default values.
virtual kj::Own<ClientHook> extractCap(const _::StructReader& capDescriptor) = 0; virtual kj::Own<const ClientHook> extractCap(const _::StructReader& capDescriptor) = 0;
// Given a StructReader for a capability descriptor embedded in the message, return the // Given a StructReader for a capability descriptor embedded in the message, return the
// corresponding capability. // corresponding capability.
kj::Own<const ClientHook> extractNullCap();
// Like extractCap() but called when the pointer was null. This just returns a dummy capability
// that throws exceptions on any call.
}; };
class BasicReaderArena final: public Arena { class BasicReaderArena final: public Arena {
...@@ -199,7 +203,7 @@ public: ...@@ -199,7 +203,7 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
kj::Own<ClientHook> extractCap(const _::StructReader& capDescriptor); kj::Own<const ClientHook> extractCap(const _::StructReader& capDescriptor);
private: private:
MessageReader* message; MessageReader* message;
...@@ -223,7 +227,7 @@ public: ...@@ -223,7 +227,7 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
kj::Own<ClientHook> extractCap(const _::StructReader& capDescriptor); kj::Own<const ClientHook> extractCap(const _::StructReader& capDescriptor);
private: private:
Arena* base; Arena* base;
...@@ -254,9 +258,12 @@ public: ...@@ -254,9 +258,12 @@ public:
// the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific // the arena is guaranteed to succeed. Therefore callers should try to allocate from a specific
// segment first if there is one, then fall back to the arena. // segment first if there is one, then fall back to the arena.
virtual void injectCap(_::PointerBuilder pointer, kj::Own<ClientHook>&& cap) = 0; virtual OrphanBuilder injectCap(kj::Own<const ClientHook>&& cap) = 0;
// Add the capability to the message and initialize the given pointer as an interface pointer // Add the capability to the message and initialize the given pointer as an interface pointer
// pointing to this cap. // pointing to this cap.
virtual void dropCap(const StructReader& capDescriptor) = 0;
// Remove a capability injected earlier. Called when the pointer is overwritten or zero'd out.
}; };
class BasicBuilderArena final: public BuilderArena { class BasicBuilderArena final: public BuilderArena {
...@@ -277,12 +284,13 @@ public: ...@@ -277,12 +284,13 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
kj::Own<ClientHook> extractCap(const _::StructReader& capDescriptor); kj::Own<const ClientHook> extractCap(const _::StructReader& capDescriptor);
// implements BuilderArena ----------------------------------------- // implements BuilderArena -----------------------------------------
SegmentBuilder* getSegment(SegmentId id) override; SegmentBuilder* getSegment(SegmentId id) override;
AllocateResult allocate(WordCount amount) override; AllocateResult allocate(WordCount amount) override;
void injectCap(_::PointerBuilder pointer, kj::Own<ClientHook>&& cap); OrphanBuilder injectCap(kj::Own<const ClientHook>&& cap);
void dropCap(const StructReader& capDescriptor);
private: private:
MessageBuilder* message; MessageBuilder* message;
...@@ -312,12 +320,13 @@ public: ...@@ -312,12 +320,13 @@ public:
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
kj::Own<ClientHook> extractCap(const _::StructReader& capDescriptor); kj::Own<const ClientHook> extractCap(const _::StructReader& capDescriptor);
// implements BuilderArena ----------------------------------------- // implements BuilderArena -----------------------------------------
SegmentBuilder* getSegment(SegmentId id) override; SegmentBuilder* getSegment(SegmentId id) override;
AllocateResult allocate(WordCount amount) override; AllocateResult allocate(WordCount amount) override;
void injectCap(_::PointerBuilder pointer, kj::Own<ClientHook>&& cap); OrphanBuilder injectCap(kj::Own<const ClientHook>&& cap);
void dropCap(const StructReader& capDescriptor);
private: private:
BuilderArena* base; BuilderArena* base;
......
...@@ -54,7 +54,8 @@ class CapExtractorBase { ...@@ -54,7 +54,8 @@ class CapExtractorBase {
// Non-template base class for CapExtractor<T>. // Non-template base class for CapExtractor<T>.
private: private:
virtual kj::Own<ClientHook> extractCapInternal(const _::StructReader& capDescriptor) = 0; virtual kj::Own<const ClientHook> extractCapInternal(
const _::StructReader& capDescriptor) const = 0;
friend class _::ImbuedReaderArena; friend class _::ImbuedReaderArena;
}; };
...@@ -62,8 +63,11 @@ class CapInjectorBase { ...@@ -62,8 +63,11 @@ class CapInjectorBase {
// Non-template base class for CapInjector<T>. // Non-template base class for CapInjector<T>.
private: private:
virtual void injectCapInternal(_::PointerBuilder builder, kj::Own<ClientHook>&& cap) = 0; virtual _::OrphanBuilder injectCapInternal(
virtual kj::Own<ClientHook> getInjectedCapInternal(const _::StructReader& capDescriptor) = 0; _::BuilderArena* arena, kj::Own<const ClientHook>&& cap) const = 0;
virtual void dropCapInternal(const _::StructReader& capDescriptor) const = 0;
virtual kj::Own<const ClientHook> getInjectedCapInternal(
const _::StructReader& capDescriptor) const = 0;
friend class _::ImbuedBuilderArena; friend class _::ImbuedBuilderArena;
}; };
...@@ -74,11 +78,12 @@ class CapExtractor: public CapExtractorBase { ...@@ -74,11 +78,12 @@ class CapExtractor: public CapExtractorBase {
// capabilities. (On the wire, an interface pointer actually points to a struct of this type.) // capabilities. (On the wire, an interface pointer actually points to a struct of this type.)
public: public:
virtual kj::Own<ClientHook> extractCap(typename CapDescriptor::Reader descriptor) = 0; virtual kj::Own<const ClientHook> extractCap(typename CapDescriptor::Reader descriptor) const = 0;
// Given the descriptor read off the wire, construct a live capability. // Given the descriptor read off the wire, construct a live capability.
private: private:
kj::Own<ClientHook> extractCapInternal(const _::StructReader& capDescriptor) override final { kj::Own<const ClientHook> extractCapInternal(
const _::StructReader& capDescriptor) const override final {
return extractCap(typename CapDescriptor::Reader(capDescriptor)); return extractCap(typename CapDescriptor::Reader(capDescriptor));
} }
}; };
...@@ -90,21 +95,34 @@ class CapInjector: public CapInjectorBase { ...@@ -90,21 +95,34 @@ class CapInjector: public CapInjectorBase {
// capabilities. (On the wire, an interface pointer actually points to a struct of this type.) // capabilities. (On the wire, an interface pointer actually points to a struct of this type.)
public: public:
virtual void injectCap(typename CapDescriptor::Builder descriptor, kj::Own<ClientHook>&& cap) = 0; virtual void injectCap(typename CapDescriptor::Builder descriptor,
kj::Own<const ClientHook>&& cap) const = 0;
// Fill in the given descriptor so that it describes the given capability. // Fill in the given descriptor so that it describes the given capability.
virtual kj::Own<ClientHook> getInjectedCap(typename CapDescriptor::Reader descriptor) = 0; virtual kj::Own<const ClientHook> getInjectedCap(
typename CapDescriptor::Reader descriptor) const = 0;
// Read back a cap that was previously injected with `injectCap`. This should return a new
// reference.
virtual void dropCap(typename CapDescriptor::Reader descriptor) const = 0;
// Read back a cap that was previously injected with `injectCap`. This should return a new // Read back a cap that was previously injected with `injectCap`. This should return a new
// reference. // reference.
private: private:
void injectCapInternal(_::PointerBuilder builder, kj::Own<ClientHook>&& cap) override final { _::OrphanBuilder injectCapInternal(_::BuilderArena* arena,
injectCap( kj::Own<const ClientHook>&& cap) const override final {
typename CapDescriptor::Builder(builder.initCapDescriptor(_::structSize<CapDescriptor>())), auto result = _::OrphanBuilder::initStruct(arena, _::structSize<CapDescriptor>());
injectCap(typename CapDescriptor::Builder(result.asStruct(_::structSize<CapDescriptor>())),
kj::mv(cap)); kj::mv(cap));
return kj::mv(result);
}
void dropCapInternal(const _::StructReader& capDescriptor) const override final {
dropCap(typename CapDescriptor::Reader(capDescriptor));
} }
kj::Own<ClientHook> getInjectedCapInternal(const _::StructReader& capDescriptor) { kj::Own<const ClientHook> getInjectedCapInternal(
const _::StructReader& capDescriptor) const override final {
return getInjectedCap(typename CapDescriptor::Reader(capDescriptor)); return getInjectedCap(typename CapDescriptor::Reader(capDescriptor));
} }
}; };
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "capability.h"
#include "test-util.h"
#include <kj/debug.h>
#include <gtest/gtest.h>
namespace capnp {
namespace _ {
namespace {
class TestInterfaceImpl final: public test::TestInterface::Server {
public:
TestInterfaceImpl(int& callCount): callCount(callCount) {}
int& callCount;
virtual ::kj::Promise<void> foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount;
EXPECT_EQ(123, params.getI());
EXPECT_TRUE(params.getJ());
result.setX("foo");
return kj::READY_NOW;
}
virtual ::kj::Promise<void> bazAdvanced(
::capnp::CallContext<test::TestInterface::BazParams,
test::TestInterface::BazResults> context) {
++callCount;
auto params = context.getParams();
checkTestMessage(params.getS());
context.releaseParams();
#if !KJ_NO_EXCEPTIONS
EXPECT_ANY_THROW(context.getParams());
#endif
return kj::READY_NOW;
}
};
TEST(Capability, Basic) {
kj::SimpleEventLoop loop;
int callCount;
test::TestInterface::Client client(makeLocalClient(kj::heap<TestInterfaceImpl>(callCount), loop));
auto request1 = client.fooRequest();
request1.setI(123);
request1.setJ(true);
auto promise1 = request1.send();
auto request2 = client.bazRequest();
initTestMessage(request2.initS());
auto promise2 = request2.send();
auto request3 = client.barRequest();
auto promise3 = loop.there(request3.send(),
[](Response<test::TestInterface::BarResults>&& response) {
ADD_FAILURE() << "Expected bar() call to fail.";
}, [](kj::Exception&& e) {
// success
});
EXPECT_EQ(0, callCount);
auto response1 = loop.wait(kj::mv(promise1));
EXPECT_EQ("foo", response1.getX());
auto response2 = loop.wait(kj::mv(promise2));
loop.wait(kj::mv(promise3));
EXPECT_EQ(2, callCount);
}
class TestExtendsImpl final: public test::TestExtends::Server {
public:
TestExtendsImpl(int& callCount): callCount(callCount) {}
int& callCount;
virtual ::kj::Promise<void> foo(
test::TestInterface::FooParams::Reader params,
test::TestInterface::FooResults::Builder result) {
++callCount;
EXPECT_EQ(321, params.getI());
EXPECT_FALSE(params.getJ());
result.setX("bar");
return kj::READY_NOW;
}
virtual ::kj::Promise<void> graultAdvanced(
::capnp::CallContext<test::TestExtends::GraultParams, test::TestAllTypes> context) {
++callCount;
context.releaseParams();
initTestMessage(context.getResults());
return kj::READY_NOW;
}
};
TEST(Capability, Inheritance) {
kj::SimpleEventLoop loop;
int callCount;
test::TestExtends::Client client(makeLocalClient(kj::heap<TestExtendsImpl>(callCount), loop));
auto request1 = client.fooRequest();
request1.setI(321);
auto promise1 = request1.send();
auto request2 = client.graultRequest();
auto promise2 = request2.send();
EXPECT_EQ(0, callCount);
auto response2 = loop.wait(kj::mv(promise2));
checkTestMessage(response2);
auto response1 = loop.wait(kj::mv(promise1));
EXPECT_EQ("bar", response1.getX());
EXPECT_EQ(2, callCount);
}
} // namespace
} // namespace _
} // namespace capnp
...@@ -97,7 +97,11 @@ public: ...@@ -97,7 +97,11 @@ public:
: request(kj::mv(request)), clientRef(kj::mv(clientRef)) {} : request(kj::mv(request)), clientRef(kj::mv(clientRef)) {}
ObjectPointer::Reader getParams() override { ObjectPointer::Reader getParams() override {
return request->getRoot<ObjectPointer>(); KJ_IF_MAYBE(r, request) {
return r->get()->getRoot<ObjectPointer>();
} else {
KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams().");
}
} }
void releaseParams() override { void releaseParams() override {
request = nullptr; request = nullptr;
...@@ -118,7 +122,7 @@ public: ...@@ -118,7 +122,7 @@ public:
return kj::addRef(*this); return kj::addRef(*this);
} }
kj::Own<MallocMessageBuilder> request; kj::Maybe<kj::Own<MallocMessageBuilder>> request;
kj::Own<LocalResponse> response; kj::Own<LocalResponse> response;
kj::Own<const ClientHook> clientRef; kj::Own<const ClientHook> clientRef;
}; };
...@@ -137,7 +141,7 @@ public: ...@@ -137,7 +141,7 @@ public:
uint64_t interfaceId = this->interfaceId; uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId; uint16_t methodId = this->methodId;
auto context = kj::refcounted<LocalCallContext>(kj::mv(message), kj::mv(client)); auto context = kj::refcounted<LocalCallContext>(kj::mv(message), client->addRef());
auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context)); auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context));
auto promise = loop.there(kj::mv(promiseAndPipeline.promise), auto promise = loop.there(kj::mv(promiseAndPipeline.promise),
...@@ -328,8 +332,8 @@ kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>& ...@@ -328,8 +332,8 @@ kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&
class LocalPipeline final: public PipelineHook, public kj::Refcounted { class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public: public:
inline LocalPipeline(kj::Own<CallContextHook>&& context) inline LocalPipeline(kj::Own<CallContextHook>&& contextParam)
: context(kj::mv(context)), : context(kj::mv(contextParam)),
results(context->getResults(1)) {} results(context->getResults(1)) {}
kj::Own<const PipelineHook> addRef() const { kj::Own<const PipelineHook> addRef() const {
...@@ -376,7 +380,7 @@ public: ...@@ -376,7 +380,7 @@ public:
// Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't // Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't
// complete before 'whenMoreResolved()' promises resolve. // complete before 'whenMoreResolved()' promises resolve.
auto promise = eventLoop.evalLater( auto promise = eventLoop.evalLater(
[=]() mutable { [=]() {
return server->dispatchCall(interfaceId, methodId, return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(*contextPtr)); CallContext<ObjectPointer, ObjectPointer>(*contextPtr));
}); });
...@@ -394,6 +398,9 @@ public: ...@@ -394,6 +398,9 @@ public:
[=](kj::Own<CallContextHook>&& context) { [=](kj::Own<CallContextHook>&& context) {
// Nothing to do here. We just wanted to make sure to hold on to a reference to the // Nothing to do here. We just wanted to make sure to hold on to a reference to the
// context even if the pipeline was discarded. // context even if the pipeline was discarded.
//
// TODO(someday): We could probably make this less ugly if we had the ability to
// convert Promise<Tuple<T, U>> -> Tuple<Promise<T>, Promise<U>>...
})); }));
return VoidPromiseAndPipeline { kj::mv(completionPromise), return VoidPromiseAndPipeline { kj::mv(completionPromise),
......
...@@ -243,6 +243,8 @@ kj::Own<const ClientHook> makeLocalClient(kj::Own<Capability::Server>&& server, ...@@ -243,6 +243,8 @@ kj::Own<const ClientHook> makeLocalClient(kj::Own<Capability::Server>&& server,
kj::EventLoop& eventLoop = kj::EventLoop::current()); kj::EventLoop& eventLoop = kj::EventLoop::current());
// Make a client capability that wraps the given server capability. The server's methods will // Make a client capability that wraps the given server capability. The server's methods will
// only be executed in the given EventLoop, regardless of what thread calls the client's methods. // only be executed in the given EventLoop, regardless of what thread calls the client's methods.
//
// TODO(now): Templated version or something.
// ======================================================================================= // =======================================================================================
...@@ -395,7 +397,7 @@ RemotePromise<Results> Request<Params, Results>::send() { ...@@ -395,7 +397,7 @@ RemotePromise<Results> Request<Params, Results>::send() {
// Explicitly upcast to kj::Promise to make clear that calling .then() doesn't invalidate the // Explicitly upcast to kj::Promise to make clear that calling .then() doesn't invalidate the
// Pipeline part of the RemotePromise. // Pipeline part of the RemotePromise.
auto typedPromise = kj::implicitCast<kj::Promise<Response<TypelessResults>>&>(typelessPromise) auto typedPromise = kj::implicitCast<kj::Promise<Response<TypelessResults>>&>(typelessPromise)
.then([](Response<TypelessResults>&& response) -> Response<Results> { .thenInAnyThread([](Response<TypelessResults>&& response) -> Response<Results> {
return Response<Results>(response.getAs<Results>(), kj::mv(response.hook)); return Response<Results>(response.getAs<Results>(), kj::mv(response.hook));
}); });
......
...@@ -527,6 +527,7 @@ private: ...@@ -527,6 +527,7 @@ private:
struct FieldText { struct FieldText {
kj::StringTree readerMethodDecls; kj::StringTree readerMethodDecls;
kj::StringTree builderMethodDecls; kj::StringTree builderMethodDecls;
kj::StringTree pipelineMethodDecls;
kj::StringTree inlineMethodDefs; kj::StringTree inlineMethodDefs;
}; };
...@@ -570,6 +571,8 @@ private: ...@@ -570,6 +571,8 @@ private:
" inline ", titleCase, "::Builder init", titleCase, "();\n" " inline ", titleCase, "::Builder init", titleCase, "();\n"
"\n"), "\n"),
kj::strTree(),
kj::strTree( kj::strTree(
kj::mv(unionDiscrim.isDefs), kj::mv(unionDiscrim.isDefs),
"inline bool ", scope, "Reader::has", titleCase, "() const {\n", "inline bool ", scope, "Reader::has", titleCase, "() const {\n",
...@@ -774,6 +777,8 @@ private: ...@@ -774,6 +777,8 @@ private:
" inline void set", titleCase, "(", type, " value", setterDefault, ");\n" " inline void set", titleCase, "(", type, " value", setterDefault, ");\n"
"\n"), "\n"),
kj::strTree(),
kj::strTree( kj::strTree(
kj::mv(unionDiscrim.isDefs), kj::mv(unionDiscrim.isDefs),
"inline bool ", scope, "Reader::has", titleCase, "() const {\n", "inline bool ", scope, "Reader::has", titleCase, "() const {\n",
...@@ -806,7 +811,7 @@ private: ...@@ -806,7 +811,7 @@ private:
} else if (kind == FieldKind::INTERFACE) { } else if (kind == FieldKind::INTERFACE) {
// Not implemented. // Not implemented.
return FieldText { kj::strTree(), kj::strTree(), kj::strTree() }; return FieldText { kj::strTree(), kj::strTree(), kj::strTree(), kj::strTree() };
} else if (kind == FieldKind::OBJECT) { } else if (kind == FieldKind::OBJECT) {
return FieldText { return FieldText {
...@@ -823,6 +828,8 @@ private: ...@@ -823,6 +828,8 @@ private:
" inline ::capnp::ObjectPointer::Builder init", titleCase, "();\n" " inline ::capnp::ObjectPointer::Builder init", titleCase, "();\n"
"\n"), "\n"),
kj::strTree(),
kj::strTree( kj::strTree(
kj::mv(unionDiscrim.isDefs), kj::mv(unionDiscrim.isDefs),
"inline bool ", scope, "Reader::has", titleCase, "() const {\n", "inline bool ", scope, "Reader::has", titleCase, "() const {\n",
...@@ -925,6 +932,8 @@ private: ...@@ -925,6 +932,8 @@ private:
" inline ::capnp::Orphan<", type, "> disown", titleCase, "();\n" " inline ::capnp::Orphan<", type, "> disown", titleCase, "();\n"
"\n"), "\n"),
kj::strTree(),
kj::strTree( kj::strTree(
kj::mv(unionDiscrim.isDefs), kj::mv(unionDiscrim.isDefs),
"inline bool ", scope, "Reader::has", titleCase, "() const {\n", "inline bool ", scope, "Reader::has", titleCase, "() const {\n",
...@@ -1063,6 +1072,25 @@ private: ...@@ -1063,6 +1072,25 @@ private:
"\n"); "\n");
} }
kj::StringTree makePipelineDef(kj::StringPtr fullName, kj::StringPtr unqualifiedParentType,
bool isUnion, kj::Array<kj::StringTree>&& methodDecls) {
return kj::strTree(
"class ", fullName, "::Pipeline {\n"
"public:\n"
" typedef ", unqualifiedParentType, " Pipelines;\n"
"\n"
" inline explicit Pipeline(::capnp::TypelessResults::Pipeline&& typeless)\n"
" : _typeless(kj::mv(typeless)) {}\n"
"\n",
kj::mv(methodDecls),
"private:\n"
" ::capnp::TypelessResults::Pipeline _typeless;\n"
" template <typename T, ::capnp::Kind k>\n"
" friend struct ::capnp::ToDynamic_;\n"
"};\n"
"\n");
}
StructText makeStructText(kj::StringPtr scope, kj::StringPtr name, StructSchema schema, StructText makeStructText(kj::StringPtr scope, kj::StringPtr name, StructSchema schema,
kj::Array<kj::StringTree> nestedTypeDecls) { kj::Array<kj::StringTree> nestedTypeDecls) {
auto proto = schema.getProto(); auto proto = schema.getProto();
...@@ -1082,7 +1110,8 @@ private: ...@@ -1082,7 +1110,8 @@ private:
" ", name, "() = delete;\n" " ", name, "() = delete;\n"
"\n" "\n"
" class Reader;\n" " class Reader;\n"
" class Builder;\n", " class Builder;\n"
" class Pipeline;\n",
structNode.getDiscriminantCount() == 0 ? kj::strTree() : kj::strTree( structNode.getDiscriminantCount() == 0 ? kj::strTree() : kj::strTree(
" enum Which: uint16_t {\n", " enum Which: uint16_t {\n",
KJ_MAP(f, structNode.getFields()) { KJ_MAP(f, structNode.getFields()) {
...@@ -1101,7 +1130,9 @@ private: ...@@ -1101,7 +1130,9 @@ private:
makeReaderDef(fullName, name, structNode.getDiscriminantCount() != 0, makeReaderDef(fullName, name, structNode.getDiscriminantCount() != 0,
KJ_MAP(f, fieldTexts) { return kj::mv(f.readerMethodDecls); }), KJ_MAP(f, fieldTexts) { return kj::mv(f.readerMethodDecls); }),
makeBuilderDef(fullName, name, structNode.getDiscriminantCount() != 0, makeBuilderDef(fullName, name, structNode.getDiscriminantCount() != 0,
KJ_MAP(f, fieldTexts) { return kj::mv(f.builderMethodDecls); })), KJ_MAP(f, fieldTexts) { return kj::mv(f.builderMethodDecls); }),
makePipelineDef(fullName, name, structNode.getDiscriminantCount() != 0,
KJ_MAP(f, fieldTexts) { return kj::mv(f.pipelineMethodDecls); })),
kj::strTree( kj::strTree(
structNode.getDiscriminantCount() == 0 ? kj::strTree() : kj::strTree( structNode.getDiscriminantCount() == 0 ? kj::strTree() : kj::strTree(
...@@ -1267,7 +1298,9 @@ private: ...@@ -1267,7 +1298,9 @@ private:
"::kj::Promise<void> ", fullName, "::Server::dispatchCall(\n" "::kj::Promise<void> ", fullName, "::Server::dispatchCall(\n"
" uint64_t interfaceId, uint16_t methodId,\n" " uint64_t interfaceId, uint16_t methodId,\n"
" ::capnp::CallContext< ::capnp::ObjectPointer, ::capnp::ObjectPointer> context) {\n" " ::capnp::CallContext< ::capnp::ObjectPointer, ::capnp::ObjectPointer> context) {\n"
" switch (interfaceId) {\n", " switch (interfaceId) {\n"
" case 0x", kj::hex(proto.getId()), "ull:\n"
" return dispatchCallInternal(methodId, context);\n",
KJ_MAP(e, extends) { KJ_MAP(e, extends) {
return kj::strTree( return kj::strTree(
" case 0x", kj::hex(e.id), "ull:\n" " case 0x", kj::hex(e.id), "ull:\n"
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "layout.h" #include "layout.h"
#include <kj/debug.h> #include <kj/debug.h>
#include "arena.h" #include "arena.h"
#include "capability.h"
#include <string.h> #include <string.h>
#include <limits> #include <limits>
#include <stdlib.h> #include <stdlib.h>
...@@ -69,8 +70,9 @@ struct WirePointer { ...@@ -69,8 +70,9 @@ struct WirePointer {
// Reference is a "far pointer", which points at data located in a different segment. The // Reference is a "far pointer", which points at data located in a different segment. The
// eventual target is one of the other kinds. // eventual target is one of the other kinds.
RESERVED_3 = 3 CAPABILITY = 3
// Reserved for future use. // Reference points at a capability descriptor struct. Other than the kind, the pointer has
// the same format as a struct.
}; };
WireValue<uint32_t> offsetAndKind; WireValue<uint32_t> offsetAndKind;
...@@ -430,9 +432,8 @@ struct WireHelpers { ...@@ -430,9 +432,8 @@ struct WireHelpers {
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: case WirePointer::STRUCT:
zeroObject(segment, ref, ref->target());
break;
case WirePointer::LIST: case WirePointer::LIST:
case WirePointer::CAPABILITY:
zeroObject(segment, ref, ref->target()); zeroObject(segment, ref, ref->target());
break; break;
case WirePointer::FAR: { case WirePointer::FAR: {
...@@ -450,16 +451,20 @@ struct WireHelpers { ...@@ -450,16 +451,20 @@ struct WireHelpers {
} }
break; break;
} }
case WirePointer::RESERVED_3:
KJ_FAIL_ASSERT("Don't know how to handle RESERVED_3.") {
break;
}
break;
} }
} }
static void zeroObject(SegmentBuilder* segment, WirePointer* tag, word* ptr) { static void zeroObject(SegmentBuilder* segment, WirePointer* tag, word* ptr) {
switch (tag->kind()) { switch (tag->kind()) {
case WirePointer::CAPABILITY:
segment->getArena()->dropCap(StructReader(
segment, ptr,
reinterpret_cast<const WirePointer*>(ptr + tag->structRef.dataSize.get()),
tag->structRef.dataSize.get() * BITS_PER_WORD,
tag->structRef.ptrCount.get(),
0 * BITS, std::numeric_limits<int>::max()));
// no break: treat like struct pointer
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
WirePointer* pointerSection = WirePointer* pointerSection =
reinterpret_cast<WirePointer*>(ptr + tag->structRef.dataSize.get()); reinterpret_cast<WirePointer*>(ptr + tag->structRef.dataSize.get());
...@@ -524,11 +529,6 @@ struct WireHelpers { ...@@ -524,11 +529,6 @@ struct WireHelpers {
break; break;
} }
break; break;
case WirePointer::RESERVED_3:
KJ_FAIL_ASSERT("Don't know how to handle RESERVED_3.") {
break;
}
break;
} }
} }
...@@ -565,7 +565,8 @@ struct WireHelpers { ...@@ -565,7 +565,8 @@ struct WireHelpers {
WordCount64 result = 0 * WORDS; WordCount64 result = 0 * WORDS;
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT:
case WirePointer::CAPABILITY: {
KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
return result; return result;
...@@ -661,11 +662,6 @@ struct WireHelpers { ...@@ -661,11 +662,6 @@ struct WireHelpers {
break; break;
} }
break; break;
case WirePointer::RESERVED_3:
KJ_FAIL_REQUIRE("Don't know how to handle RESERVED_3.") {
return result;
}
break;
} }
return result; return result;
...@@ -776,9 +772,11 @@ struct WireHelpers { ...@@ -776,9 +772,11 @@ struct WireHelpers {
} }
break; break;
} }
case WirePointer::RESERVED_3: case WirePointer::CAPABILITY:
default: KJ_FAIL_REQUIRE("Unchecked messages cannot contain capabilities.");
KJ_FAIL_REQUIRE("Copy source message contained unexpected kind."); break;
case WirePointer::FAR:
KJ_FAIL_REQUIRE("Unchecked messages cannot contain far pointers.");
break; break;
} }
...@@ -1497,6 +1495,42 @@ struct WireHelpers { ...@@ -1497,6 +1495,42 @@ struct WireHelpers {
return { segment, ptr }; return { segment, ptr };
} }
static SegmentAnd<word*> setCapabilityPointer(
SegmentBuilder* segment, WirePointer* ref, kj::Own<const ClientHook>&& cap,
BuilderArena* orphanArena = nullptr) {
if (orphanArena == nullptr) {
auto orphan = segment->getArena()->injectCap(kj::mv(cap));
SegmentAnd<word*> result = { orphan.segment, orphan.location };
adopt(segment, ref, kj::mv(orphan));
if (ref->kind() == WirePointer::STRUCT) {
ref->setKindAndTarget(WirePointer::CAPABILITY, ref->target(), segment);
}
return result;
} else {
// We're actually writing into another OrphanBuilder. If we had direct access to it, we
// could just use the move constructor, but we don't quite...
auto orphan = orphanArena->injectCap(kj::mv(cap));
SegmentAnd<word*> result = { orphan.segment, orphan.location };
memcpy(ref, orphan.tagAsPtr(), sizeof(*ref));
if (ref->kind() == WirePointer::STRUCT) {
ref->setKindForOrphan(WirePointer::CAPABILITY);
}
// Zero out the orphan because we have transferred ownership manually.
memset(orphan.tagAsPtr(), 0, sizeof(WirePointer));
orphan.location = nullptr;
orphan.segment = nullptr;
return result;
}
}
static SegmentAnd<word*> setListPointer( static SegmentAnd<word*> setListPointer(
SegmentBuilder* segment, WirePointer* ref, ListReader value, SegmentBuilder* segment, WirePointer* ref, ListReader value,
BuilderArena* orphanArena = nullptr) { BuilderArena* orphanArena = nullptr) {
...@@ -1667,9 +1701,29 @@ struct WireHelpers { ...@@ -1667,9 +1701,29 @@ struct WireHelpers {
} }
} }
case WirePointer::RESERVED_3: case WirePointer::CAPABILITY: {
default: KJ_REQUIRE(nestingLimit > 0,
KJ_FAIL_REQUIRE("Message contained invalid pointer.") { "Message is too deeply-nested or contains cycles. See capnp::ReadOptions.") {
goto useDefault;
}
KJ_REQUIRE(boundsCheck(srcSegment, ptr, ptr + src->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") {
goto useDefault;
}
setCapabilityPointer(dstSegment, dst,
srcSegment->getArena()->extractCap(StructReader(
srcSegment, ptr,
reinterpret_cast<const WirePointer*>(ptr + src->structRef.dataSize.get()),
src->structRef.dataSize.get() * BITS_PER_WORD,
src->structRef.ptrCount.get(),
0 * BITS, nestingLimit - 1)),
orphanArena);
}
case WirePointer::FAR:
KJ_FAIL_ASSERT("Far pointer should have been handled above.") {
goto useDefault; goto useDefault;
} }
} }
...@@ -1732,6 +1786,28 @@ struct WireHelpers { ...@@ -1732,6 +1786,28 @@ struct WireHelpers {
static KJ_ALWAYS_INLINE(StructReader readStructPointer( static KJ_ALWAYS_INLINE(StructReader readStructPointer(
SegmentReader* segment, const WirePointer* ref, const word* refTarget, SegmentReader* segment, const WirePointer* ref, const word* refTarget,
const word* defaultValue, int nestingLimit)) { const word* defaultValue, int nestingLimit)) {
return readStructOrCapDescPointer(WirePointer::STRUCT, segment, ref, refTarget, defaultValue,
nestingLimit);
}
static KJ_ALWAYS_INLINE(kj::Own<const ClientHook> readCapabilityPointer(
SegmentReader* segment, const WirePointer* ref, int nestingLimit)) {
return readCapabilityPointer(segment, ref, ref->target(), nestingLimit);
}
static KJ_ALWAYS_INLINE(kj::Own<const ClientHook> readCapabilityPointer(
SegmentReader* segment, const WirePointer* ref, const word* refTarget, int nestingLimit)) {
if (ref->isNull()) {
return segment->getArena()->extractNullCap();
} else {
return segment->getArena()->extractCap(readStructOrCapDescPointer(
WirePointer::CAPABILITY, segment, ref, refTarget, nullptr, nestingLimit));
}
}
static KJ_ALWAYS_INLINE(StructReader readStructOrCapDescPointer(WirePointer::Kind kind,
SegmentReader* segment, const WirePointer* ref, const word* refTarget,
const word* defaultValue, int nestingLimit)) {
if (ref == nullptr || ref->isNull()) { if (ref == nullptr || ref->isNull()) {
useDefault: useDefault:
if (defaultValue == nullptr || if (defaultValue == nullptr ||
...@@ -1755,8 +1831,10 @@ struct WireHelpers { ...@@ -1755,8 +1831,10 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
KJ_REQUIRE(ref->kind() == WirePointer::STRUCT, KJ_REQUIRE(ref->kind() == kind,
"Message contains non-struct pointer where struct pointer was expected.") { kind == WirePointer::CAPABILITY
? "Message contains non-capability pointer where capability pointer was expected."
: "Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault; goto useDefault;
} }
...@@ -2075,6 +2153,15 @@ void PointerBuilder::setList(const ListReader& value) { ...@@ -2075,6 +2153,15 @@ void PointerBuilder::setList(const ListReader& value) {
WireHelpers::setListPointer(segment, pointer, value); WireHelpers::setListPointer(segment, pointer, value);
} }
kj::Own<const ClientHook> PointerBuilder::getCapability() {
return WireHelpers::readCapabilityPointer(
segment, pointer, std::numeric_limits<int>::max());
}
void PointerBuilder::setCapability(kj::Own<const ClientHook>&& cap) {
WireHelpers::setCapabilityPointer(segment, pointer, kj::mv(cap));
}
void PointerBuilder::adopt(OrphanBuilder&& value) { void PointerBuilder::adopt(OrphanBuilder&& value) {
WireHelpers::adopt(segment, pointer, kj::mv(value)); WireHelpers::adopt(segment, pointer, kj::mv(value));
} }
...@@ -2148,6 +2235,11 @@ Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount de ...@@ -2148,6 +2235,11 @@ Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount de
return WireHelpers::readDataPointer(segment, ref, defaultValue, defaultSize); return WireHelpers::readDataPointer(segment, ref, defaultValue, defaultSize);
} }
kj::Own<const ClientHook> PointerReader::getCapability() const {
const WirePointer* ref = pointer == nullptr ? &zero.pointer : pointer;
return WireHelpers::readCapabilityPointer(segment, ref, nestingLimit);
}
const word* PointerReader::getUnchecked() const { const word* PointerReader::getUnchecked() const {
KJ_REQUIRE(segment == nullptr, "getUncheckedPointer() only allowed on unchecked messages."); KJ_REQUIRE(segment == nullptr, "getUncheckedPointer() only allowed on unchecked messages.");
return reinterpret_cast<const word*>(pointer); return reinterpret_cast<const word*>(pointer);
......
...@@ -300,7 +300,6 @@ public: ...@@ -300,7 +300,6 @@ public:
ListBuilder initList(FieldSize elementSize, ElementCount elementCount); ListBuilder initList(FieldSize elementSize, ElementCount elementCount);
ListBuilder initStructList(ElementCount elementCount, StructSize size); ListBuilder initStructList(ElementCount elementCount, StructSize size);
template <typename T> typename T::Builder initBlob(ByteCount size); template <typename T> typename T::Builder initBlob(ByteCount size);
StructBuilder initCapDescriptor(StructSize size);
// Init methods: Initialize the pointer to a newly-allocated object, discarding the existing // Init methods: Initialize the pointer to a newly-allocated object, discarding the existing
// object. // object.
...@@ -362,7 +361,7 @@ public: ...@@ -362,7 +361,7 @@ public:
ListReader getList(FieldSize expectedElementSize, const word* defaultValue) const; ListReader getList(FieldSize expectedElementSize, const word* defaultValue) const;
template <typename T> template <typename T>
typename T::Reader getBlob(const void* defaultValue, ByteCount defaultSize) const; typename T::Reader getBlob(const void* defaultValue, ByteCount defaultSize) const;
kj::Own<const ClientHook> getCapability(); kj::Own<const ClientHook> getCapability() const;
// Get methods: Get the value. If it is null, return the default value instead. // Get methods: Get the value. If it is null, return the default value instead.
// The default value is encoded as an "unchecked message" for structs, lists, and objects, or a // The default value is encoded as an "unchecked message" for structs, lists, and objects, or a
// simple byte array for blobs. // simple byte array for blobs.
......
...@@ -77,6 +77,8 @@ private: ...@@ -77,6 +77,8 @@ private:
friend class Orphan; friend class Orphan;
friend class Orphanage; friend class Orphanage;
friend class MessageBuilder; friend class MessageBuilder;
template <typename>
friend class CapInjector;
}; };
class Orphanage: private kj::DisallowConstCopy { class Orphanage: private kj::DisallowConstCopy {
...@@ -131,6 +133,8 @@ private: ...@@ -131,6 +133,8 @@ private:
struct NewOrphanListImpl; struct NewOrphanListImpl;
friend class MessageBuilder; friend class MessageBuilder;
template <typename>
friend class CapInjector;
}; };
// ======================================================================================= // =======================================================================================
......
...@@ -156,6 +156,18 @@ TEST(Async, Then) { ...@@ -156,6 +156,18 @@ TEST(Async, Then) {
EXPECT_TRUE(innerDone); EXPECT_TRUE(innerDone);
} }
TEST(Async, ThenInAnyThread) {
SimpleEventLoop loop;
Promise<int> a = 123;
bool done = false;
Promise<int> promise = a.thenInAnyThread([&](int ai) { done = true; return ai + 321; });
EXPECT_FALSE(done);
EXPECT_EQ(444, loop.wait(kj::mv(promise)));
EXPECT_TRUE(done);
}
TEST(Async, Chain) { TEST(Async, Chain) {
SimpleEventLoop loop; SimpleEventLoop loop;
......
...@@ -311,7 +311,7 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept { ...@@ -311,7 +311,7 @@ void ImmediateBrokenPromiseNode::get(ExceptionOrValue& output) noexcept {
// ------------------------------------------------------------------- // -------------------------------------------------------------------
TransformPromiseNodeBase::TransformPromiseNodeBase( TransformPromiseNodeBase::TransformPromiseNodeBase(
const EventLoop& loop, Own<PromiseNode>&& dependency) Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency)
: loop(loop), dependency(kj::mv(dependency)) {} : loop(loop), dependency(kj::mv(dependency)) {}
bool TransformPromiseNodeBase::onReady(EventLoop::Event& event) noexcept { bool TransformPromiseNodeBase::onReady(EventLoop::Event& event) noexcept {
...@@ -327,7 +327,7 @@ void TransformPromiseNodeBase::get(ExceptionOrValue& output) noexcept { ...@@ -327,7 +327,7 @@ void TransformPromiseNodeBase::get(ExceptionOrValue& output) noexcept {
} }
Maybe<const EventLoop&> TransformPromiseNodeBase::getSafeEventLoop() noexcept { Maybe<const EventLoop&> TransformPromiseNodeBase::getSafeEventLoop() noexcept {
return loop; return loop == nullptr ? dependency->getSafeEventLoop() : loop;
} }
void TransformPromiseNodeBase::dropDependency() { void TransformPromiseNodeBase::dropDependency() {
......
...@@ -58,6 +58,15 @@ template <typename T> ...@@ -58,6 +58,15 @@ template <typename T>
using JoinPromises = typename JoinPromises_<T>::Type; using JoinPromises = typename JoinPromises_<T>::Type;
// If T is Promise<U>, resolves to U, otherwise resolves to T. // If T is Promise<U>, resolves to U, otherwise resolves to T.
template <typename T> struct DisallowChain_ { typedef T Type; };
template <typename T> struct DisallowChain_<Promise<T>> {
static_assert(sizeof(T) < 0, "Continuation passed to thenInAnyThread() cannot return a promise.");
};
template <typename T>
using DisallowChain = typename DisallowChain_<T>::Type;
// If T is Promise<U>, error, otherwise resolves to T.
class PropagateException { class PropagateException {
// A functor which accepts a kj::Exception as a parameter and returns a broken promise of // A functor which accepts a kj::Exception as a parameter and returns a broken promise of
// arbitrary type which simply propagates the exception. // arbitrary type which simply propagates the exception.
...@@ -66,11 +75,6 @@ public: ...@@ -66,11 +75,6 @@ public:
public: public:
Bottom(Exception&& exception): exception(kj::mv(exception)) {} Bottom(Exception&& exception): exception(kj::mv(exception)) {}
template <typename T>
operator T() {
throwFatalException(kj::mv(exception));
}
Exception asException() { return kj::mv(exception); } Exception asException() { return kj::mv(exception); }
private: private:
...@@ -190,6 +194,10 @@ using PromiseForResult = Promise<_::JoinPromises<_::ReturnType<Func, T>>>; ...@@ -190,6 +194,10 @@ using PromiseForResult = Promise<_::JoinPromises<_::ReturnType<Func, T>>>;
// T. If T is void, then the promise is for the result of calling Func with no arguments. If // T. If T is void, then the promise is for the result of calling Func with no arguments. If
// Func itself returns a promise, the promises are joined, so you never get Promise<Promise<T>>. // Func itself returns a promise, the promises are joined, so you never get Promise<Promise<T>>.
template <typename Func, typename T>
using PromiseForResultNoChaining = Promise<_::DisallowChain<_::ReturnType<Func, T>>>;
// Like PromiseForResult but chaining (continuations that return another promise) is now allowed.
class EventLoop { class EventLoop {
// Represents a queue of events being executed in a loop. Most code won't interact with // Represents a queue of events being executed in a loop. Most code won't interact with
// EventLoop directly, but instead use `Promise`s to interact with it indirectly. See the // EventLoop directly, but instead use `Promise`s to interact with it indirectly. See the
...@@ -582,6 +590,14 @@ public: ...@@ -582,6 +590,14 @@ public:
// to yield control; this way, all other events in the queue will get a chance to run before your // to yield control; this way, all other events in the queue will get a chance to run before your
// callback is executed. // callback is executed.
template <typename Func, typename ErrorFunc = _::PropagateException>
PromiseForResultNoChaining<Func, T> thenInAnyThread(
Func&& func, ErrorFunc&& errorHandler = _::PropagateException()) KJ_WARN_UNUSED_RESULT;
// Like then(), but the continuation will be executed in an arbitrary thread, not the calling
// thread. The continuation MUST NOT return another promise. It's suggested that you use a
// lambda with an empty capture as the continuation. In the vast majority of cases, this ends
// up doing the same thing as then(); don't use this unless you really know you need it.
T wait(); T wait();
// Equivalent to `EventLoop::current().wait(kj::mv(*this))`. WARNING: Although `wait()` // Equivalent to `EventLoop::current().wait(kj::mv(*this))`. WARNING: Although `wait()`
// advances the event loop, calls to `wait()` obviously can only return in the reverse of the // advances the event loop, calls to `wait()` obviously can only return in the reverse of the
...@@ -892,14 +908,14 @@ private: ...@@ -892,14 +908,14 @@ private:
class TransformPromiseNodeBase: public PromiseNode { class TransformPromiseNodeBase: public PromiseNode {
public: public:
TransformPromiseNodeBase(const EventLoop& loop, Own<PromiseNode>&& dependency); TransformPromiseNodeBase(Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency);
bool onReady(EventLoop::Event& event) noexcept override; bool onReady(EventLoop::Event& event) noexcept override;
void get(ExceptionOrValue& output) noexcept override; void get(ExceptionOrValue& output) noexcept override;
Maybe<const EventLoop&> getSafeEventLoop() noexcept override; Maybe<const EventLoop&> getSafeEventLoop() noexcept override;
private: private:
const EventLoop& loop; Maybe<const EventLoop&> loop;
Own<PromiseNode> dependency; Own<PromiseNode> dependency;
void dropDependency(); void dropDependency();
...@@ -916,7 +932,7 @@ class TransformPromiseNode final: public TransformPromiseNodeBase { ...@@ -916,7 +932,7 @@ class TransformPromiseNode final: public TransformPromiseNodeBase {
// function (implements `then()`). // function (implements `then()`).
public: public:
TransformPromiseNode(const EventLoop& loop, Own<PromiseNode>&& dependency, TransformPromiseNode(Maybe<const EventLoop&> loop, Own<PromiseNode>&& dependency,
Func&& func, ErrorFunc&& errorHandler) Func&& func, ErrorFunc&& errorHandler)
: TransformPromiseNodeBase(loop, kj::mv(dependency)), : TransformPromiseNodeBase(loop, kj::mv(dependency)),
func(kj::fwd<Func>(func)), errorHandler(kj::fwd<ErrorFunc>(errorHandler)) {} func(kj::fwd<Func>(func)), errorHandler(kj::fwd<ErrorFunc>(errorHandler)) {}
...@@ -936,7 +952,8 @@ private: ...@@ -936,7 +952,8 @@ private:
ExceptionOr<DepT> depResult; ExceptionOr<DepT> depResult;
dependency->get(depResult); dependency->get(depResult);
KJ_IF_MAYBE(depException, depResult.exception) { KJ_IF_MAYBE(depException, depResult.exception) {
output.as<T>() = handle(MaybeVoidCaller<Exception&&, T>::apply( output.as<T>() = handle(
MaybeVoidCaller<Exception, FixVoid<ReturnType<ErrorFunc, Exception>>>::apply(
errorHandler, kj::mv(*depException))); errorHandler, kj::mv(*depException)));
} else KJ_IF_MAYBE(depValue, depResult.value) { } else KJ_IF_MAYBE(depValue, depResult.value) {
output.as<T>() = handle(MaybeVoidCaller<DepT, T>::apply(func, kj::mv(*depValue))); output.as<T>() = handle(MaybeVoidCaller<DepT, T>::apply(func, kj::mv(*depValue)));
...@@ -1285,6 +1302,17 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler ...@@ -1285,6 +1302,17 @@ PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler
EventLoop::Event::PREEMPT)); EventLoop::Event::PREEMPT));
} }
template <typename T>
template <typename Func, typename ErrorFunc>
PromiseForResultNoChaining<Func, T> Promise<T>::thenInAnyThread(
Func&& func, ErrorFunc&& errorHandler) {
typedef _::FixVoid<_::ReturnType<Func, T>> ResultT;
return PromiseForResultNoChaining<Func, T>(false,
heap<_::TransformPromiseNode<ResultT, _::FixVoid<T>, Func, ErrorFunc>>(
nullptr, kj::mv(node), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler)));
}
template <typename T> template <typename T>
T Promise<T>::wait() { T Promise<T>::wait() {
return EventLoop::current().wait(kj::mv(*this)); return EventLoop::current().wait(kj::mv(*this));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment