Commit e8d256ab authored by Kenton Varda's avatar Kenton Varda

More capabilities work. I think I need to implement Promise forking, though.

parent 7ceed92d
...@@ -71,11 +71,11 @@ SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) { ...@@ -71,11 +71,11 @@ SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) {
SegmentMap* segments = nullptr; SegmentMap* segments = nullptr;
KJ_IF_MAYBE(s, *lock) { KJ_IF_MAYBE(s, *lock) {
auto iter = s->find(id.value); auto iter = s->get()->find(id.value);
if (iter != s->end()) { if (iter != s->get()->end()) {
return iter->second; return iter->second;
} }
segments = s; segments = *s;
} }
kj::ArrayPtr<const word> newSegment = message->getSegment(id.value); kj::ArrayPtr<const word> newSegment = message->getSegment(id.value);
...@@ -113,7 +113,7 @@ public: ...@@ -113,7 +113,7 @@ public:
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override { CallContextHook& context) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no " KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context."); "capability context.");
} }
...@@ -162,12 +162,12 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) { ...@@ -162,12 +162,12 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
SegmentMap* segments = nullptr; SegmentMap* segments = nullptr;
KJ_IF_MAYBE(s, *lock) { KJ_IF_MAYBE(s, *lock) {
auto iter = s->find(baseSegment); auto iter = s->get()->find(baseSegment);
if (iter != s->end()) { if (iter != s->get()->end()) {
KJ_DASSERT(iter->second->getArray().begin() == baseSegment->getArray().begin()); KJ_DASSERT(iter->second->getArray().begin() == baseSegment->getArray().begin());
return iter->second; return iter->second;
} }
segments = s; segments = *s;
} else { } else {
auto newMap = kj::heap<SegmentMap>(); auto newMap = kj::heap<SegmentMap>();
segments = newMap; segments = newMap;
...@@ -205,10 +205,10 @@ SegmentBuilder* BasicBuilderArena::getSegment(SegmentId id) { ...@@ -205,10 +205,10 @@ SegmentBuilder* BasicBuilderArena::getSegment(SegmentId id) {
} else { } else {
auto lock = moreSegments.lockShared(); auto lock = moreSegments.lockShared();
KJ_IF_MAYBE(s, *lock) { KJ_IF_MAYBE(s, *lock) {
KJ_REQUIRE(id.value - 1 < s->builders.size(), "invalid segment id", id.value); KJ_REQUIRE(id.value - 1 < s->get()->builders.size(), "invalid segment id", id.value);
// TODO(cleanup): Return a const SegmentBuilder and tediously constify all SegmentBuilder // TODO(cleanup): Return a const SegmentBuilder and tediously constify all SegmentBuilder
// pointers throughout the codebase. // pointers throughout the codebase.
return const_cast<BasicSegmentBuilder*>(s->builders[id.value - 1].get()); return const_cast<BasicSegmentBuilder*>(s->get()->builders[id.value - 1].get());
} else { } else {
KJ_FAIL_REQUIRE("invalid segment id", id.value); KJ_FAIL_REQUIRE("invalid segment id", id.value);
} }
...@@ -245,11 +245,11 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount) ...@@ -245,11 +245,11 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount)
// on the last-known available size, and then re-check the size when we pop segments off it // on the last-known available size, and then re-check the size when we pop segments off it
// and shove them to the back of the queue if they have become too small. // and shove them to the back of the queue if they have become too small.
attempt = s->builders.back()->allocate(amount); attempt = s->get()->builders.back()->allocate(amount);
if (attempt != nullptr) { if (attempt != nullptr) {
return AllocateResult { s->builders.back().get(), attempt }; return AllocateResult { s->get()->builders.back().get(), attempt };
} }
segmentState = s; segmentState = *s;
} else { } else {
auto newSegmentState = kj::heap<MultiSegmentState>(); auto newSegmentState = kj::heap<MultiSegmentState>();
segmentState = newSegmentState; segmentState = newSegmentState;
...@@ -279,15 +279,15 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOu ...@@ -279,15 +279,15 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOu
// problem regardless of locking here. // problem regardless of locking here.
KJ_IF_MAYBE(segmentState, moreSegments.getWithoutLock()) { KJ_IF_MAYBE(segmentState, moreSegments.getWithoutLock()) {
KJ_DASSERT(segmentState->forOutput.size() == segmentState->builders.size() + 1, KJ_DASSERT(segmentState->get()->forOutput.size() == segmentState->get()->builders.size() + 1,
"segmentState->forOutput wasn't resized correctly when the last builder was added.", "segmentState->forOutput wasn't resized correctly when the last builder was added.",
segmentState->forOutput.size(), segmentState->builders.size()); segmentState->get()->forOutput.size(), segmentState->get()->builders.size());
kj::ArrayPtr<kj::ArrayPtr<const word>> result( kj::ArrayPtr<kj::ArrayPtr<const word>> result(
&segmentState->forOutput[0], segmentState->forOutput.size()); &segmentState->get()->forOutput[0], segmentState->get()->forOutput.size());
uint i = 0; uint i = 0;
result[i++] = segment0.currentlyAllocated(); result[i++] = segment0.currentlyAllocated();
for (auto& builder: segmentState->builders) { for (auto& builder: segmentState->get()->builders) {
result[i++] = builder->currentlyAllocated(); result[i++] = builder->currentlyAllocated();
} }
return result; return result;
...@@ -314,11 +314,11 @@ SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) { ...@@ -314,11 +314,11 @@ SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) {
} else { } else {
auto lock = moreSegments.lockShared(); auto lock = moreSegments.lockShared();
KJ_IF_MAYBE(segmentState, *lock) { KJ_IF_MAYBE(segmentState, *lock) {
if (id.value <= segmentState->builders.size()) { if (id.value <= segmentState->get()->builders.size()) {
// TODO(cleanup): Return a const SegmentReader and tediously constify all SegmentBuilder // TODO(cleanup): Return a const SegmentReader and tediously constify all SegmentBuilder
// pointers throughout the codebase. // pointers throughout the codebase.
return const_cast<SegmentReader*>(kj::implicitCast<const SegmentReader*>( return const_cast<SegmentReader*>(kj::implicitCast<const SegmentReader*>(
segmentState->builders[id.value - 1].get())); segmentState->get()->builders[id.value - 1].get()));
} }
} }
return nullptr; return nullptr;
...@@ -360,15 +360,15 @@ SegmentBuilder* ImbuedBuilderArena::imbue(SegmentBuilder* baseSegment) { ...@@ -360,15 +360,15 @@ SegmentBuilder* ImbuedBuilderArena::imbue(SegmentBuilder* baseSegment) {
auto lock = moreSegments.lockExclusive(); auto lock = moreSegments.lockExclusive();
KJ_IF_MAYBE(segmentState, *lock) { KJ_IF_MAYBE(segmentState, *lock) {
auto id = baseSegment->getSegmentId().value; auto id = baseSegment->getSegmentId().value;
if (id >= segmentState->builders.size()) { if (id >= segmentState->get()->builders.size()) {
segmentState->builders.resize(id + 1); segmentState->get()->builders.resize(id + 1);
} }
KJ_IF_MAYBE(segment, segmentState->builders[id]) { KJ_IF_MAYBE(segment, segmentState->get()->builders[id]) {
result = segment; result = *segment;
} else { } else {
auto newBuilder = kj::heap<ImbuedSegmentBuilder>(baseSegment); auto newBuilder = kj::heap<ImbuedSegmentBuilder>(baseSegment);
result = newBuilder; result = newBuilder;
segmentState->builders[id] = kj::mv(newBuilder); segmentState->get()->builders[id] = kj::mv(newBuilder);
} }
} }
return nullptr; return nullptr;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <kj/refcount.h> #include <kj/refcount.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/vector.h> #include <kj/vector.h>
#include <kj/one-of.h>
namespace capnp { namespace capnp {
...@@ -68,11 +69,19 @@ TypelessResults::Pipeline TypelessResults::Pipeline::getPointerField( ...@@ -68,11 +69,19 @@ TypelessResults::Pipeline TypelessResults::Pipeline::getPointerField(
ResponseHook::~ResponseHook() noexcept(false) {} ResponseHook::~ResponseHook() noexcept(false) {}
// ======================================================================================= kj::Promise<void> ClientHook::whenResolved() const {
KJ_IF_MAYBE(promise, whenMoreResolved()) {
return promise->then([](kj::Own<const ClientHook>&& resolution) {
return resolution->whenResolved();
});
} else {
return kj::READY_NOW;
}
}
namespace { // =======================================================================================
class LocalResponse final: public ResponseHook { class LocalResponse final: public ResponseHook, public kj::Refcounted {
public: public:
LocalResponse(uint sizeHint) LocalResponse(uint sizeHint)
: message(sizeHint == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : sizeHint) {} : message(sizeHint == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : sizeHint) {}
...@@ -93,7 +102,7 @@ public: ...@@ -93,7 +102,7 @@ public:
} }
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override { ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
if (!response) { if (!response) {
response = kj::heap<LocalResponse>(firstSegmentWordSize); response = kj::refcounted<LocalResponse>(firstSegmentWordSize);
} }
return response->message.getRoot<ObjectPointer>(); return response->message.getRoot<ObjectPointer>();
} }
...@@ -103,70 +112,275 @@ public: ...@@ -103,70 +112,275 @@ public:
bool isCanceled() override { bool isCanceled() override {
return false; return false;
} }
Response<ObjectPointer> getResponseForPipeline() override {
auto reader = getResults(1); // Needs to be a separate line since it may allocate the response.
return Response<ObjectPointer>(reader, kj::addRef(*response));
}
kj::Own<MallocMessageBuilder> request; kj::Own<MallocMessageBuilder> request;
kj::Own<LocalResponse> response; kj::Own<LocalResponse> response;
kj::Own<const ClientHook> clientRef; kj::Own<const ClientHook> clientRef;
}; };
class LocalPipelinedClient final: public ClientHook, public kj::Refcounted { class LocalRequest final: public RequestHook {
public: public:
LocalPipelinedClient(kj::Promise<kj::Own<const ClientHook>> promise) inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
: innerPromise(promise.then([this](kj::Own<const ClientHook>&& resolution) { uint firstSegmentWordSize, kj::Own<const ClientHook> client)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<TypelessResults> send() override {
// For the lambda capture.
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(client));
auto promiseAndPipeline = client->call(interfaceId, methodId, *context);
auto promise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
}));
return RemotePromise<TypelessResults>(
kj::mv(promise), TypelessResults::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
}
kj::Own<MallocMessageBuilder> message;
private:
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> client;
};
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override;
private:
kj::Exception exception;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(exception));
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
}
// =======================================================================================
// Call queues
//
// These classes handle pipelining in the case where calls need to be queued in-memory until some
// local operation completes.
class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
// A PipelineHook which simply queues calls while waiting for a PipelineHook to which to forward
// them.
public:
QueuedPipeline(kj::EventLoop& loop, kj::Promise<kj::Own<const PipelineHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const PipelineHook>&& resolution) {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
auto oldState = kj::mv(*lock); for (auto& waiter: oldState) {
for (auto& call: oldState.pending) { waiter.fulfiller->fulfill(resolution->getPipelinedCap(kj::mv(waiter.ops)));
call.fulfiller->fulfill(resolution->call(
call.interfaceId, call.methodId, call.context).promise);
} }
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
for (auto& waiter: oldState) {
waiter.fulfiller->reject(kj::cp(exception));
}
lock->init<kj::Exception>(kj::mv(exception));
})) {
state.getWithoutLock().init<Waiting>();
}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override {
auto copy = kj::heapArrayBuilder<PipelineOp>(ops.size());
for (auto& op: ops) {
copy.add(op);
}
return getPipelinedCap(kj::mv(copy));
}
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override;
private:
struct Waiter {
kj::Array<PipelineOp> ops;
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller;
};
typedef kj::Vector<Waiter> Waiting;
typedef kj::Own<const PipelineHook> Resolved;
kj::EventLoop& loop;
kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
};
class QueuedClient final: public ClientHook, public kj::Refcounted {
// A ClientHook which simply queues calls while waiting for a ClientHook to which to forward
// them.
public:
QueuedClient(kj::EventLoop& loop, kj::Promise<kj::Own<const ClientHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const ClientHook>&& resolution) {
// The promised capability has resolved. Forward all queued calls to it.
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
// First we want to initiate all the queued calls, and notify the QueuedPipelines to
// transfer their queues to the new call's own pipeline. It's important that this all
// happen before the application receives any notification that the promise resolved,
// so that any new calls it makes in response to the resolution don't end up being
// delivered before the previously-queued calls.
auto realCallPromises = kj::heapArrayBuilder<kj::Promise<void>>(oldState.pending.size());
for (auto& pendingCall: oldState.pending) {
auto realCall = resolution->call(
pendingCall.interfaceId, pendingCall.methodId, *pendingCall.context);
pendingCall.pipelineFulfiller->fulfill(kj::mv(realCall.pipeline));
realCallPromises.add(kj::mv(realCall.promise));
}
// Fire the "whenMoreResolved" callbacks.
for (auto& notify: oldState.notifyOnResolution) { for (auto& notify: oldState.notifyOnResolution) {
notify->fulfill(resolution->addRef()); notify->fulfill(resolution->addRef());
} }
lock->resolution = kj::mv(resolution); // For each queued call, chain the pipelined promise to the real promise. It's important
// that this happens after the "whenMoreResolved" callbacks because applications may get
// confused if a pipelined call completes before the promise on which it was made
// resolves.
for (uint i: kj::indices(realCallPromises)) {
oldState.pending[i].fulfiller->fulfill(kj::mv(realCallPromises[i]));
}
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) { }, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
auto oldState = kj::mv(*lock); // Reject outer promises before dependent promises.
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::Exception(exception));
}
for (auto& notify: oldState.notifyOnResolution) { for (auto& notify: oldState.notifyOnResolution) {
notify->reject(kj::Exception(exception)); notify->reject(kj::cp(exception));
}
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::cp(exception));
call.pipelineFulfiller->reject(kj::cp(exception));
} }
lock->exception = kj::mv(exception); lock->init<kj::Exception>(kj::mv(exception));
})) {} })) {
state.getWithoutLock().init<Waiting>();
}
Request<ObjectPointer, TypelessResults> newCall( Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) { if (lock->is<Resolved>()) {
return r->newCall(interfaceId, methodId, firstSegmentWordSize); return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else { } else {
// TODO(now) auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
} }
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override { CallContextHook& context) const override {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) { if (lock->is<Resolved>()) {
return r->call(interfaceId, methodId, context); return lock->get<Resolved>()->call(interfaceId, methodId, context);
} else if (lock->is<kj::Exception>()) {
return VoidPromiseAndPipeline { kj::cp(lock->get<kj::Exception>()),
kj::heap<BrokenPipeline>(lock->get<kj::Exception>()) };
} else { } else {
lock->pending.add(PendingCall { interfaceId, methodId, context }); auto pair = kj::newPromiseAndFulfiller<kj::Promise<void>>(loop);
auto pipelinePromise = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>(loop);
auto pipeline = kj::heap<QueuedPipeline>(loop, kj::mv(pipelinePromise.promise));
lock->get<Waiting>().pending.add(PendingCall {
interfaceId, methodId, &context, kj::mv(pair.fulfiller),
kj::mv(pipelinePromise.fulfiller) });
// TODO(now): returned promise must hold a reference to this.
return VoidPromiseAndPipeline { kj::mv(pair.promise), kj::mv(pipeline) };
} }
} }
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockExclusive(); auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) { if (lock->is<Resolved>()) {
// Already resolved. // Already resolved.
return kj::Promise<kj::Own<const ClientHook>>(r->addRef()); return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else if (lock->is<kj::Exception>()) {
// Already broken.
return kj::Promise<kj::Own<const ClientHook>>(kj::Own<const ClientHook>(
kj::heap<BrokenClient>(lock->get<kj::Exception>())));
} else { } else {
// Waiting.
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>(); auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->notifyOnResolution.add(kj::mv(pair.fulfiller)); lock->get<Waiting>().notifyOnResolution.add(kj::mv(pair.fulfiller));
// TODO(now): returned promise must hold a reference to this.
return kj::mv(pair.promise); return kj::mv(pair.promise);
} }
} }
...@@ -183,81 +397,52 @@ private: ...@@ -183,81 +397,52 @@ private:
struct PendingCall { struct PendingCall {
uint64_t interfaceId; uint64_t interfaceId;
uint16_t methodId; uint16_t methodId;
CallContext<ObjectPointer, ObjectPointer> context; CallContextHook* context;
kj::Own<kj::PromiseFulfiller<void>> fulfiller; kj::Own<kj::PromiseFulfiller<kj::Promise<void>>> fulfiller;
kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>> pipelineFulfiller;
}; };
struct State { struct Waiting {
kj::Maybe<kj::Own<const ClientHook>> resolution;
kj::Vector<PendingCall> pending; kj::Vector<PendingCall> pending;
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution; kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
}; };
kj::MutexGuarded<State> state;
typedef kj::Own<const ClientHook> Resolved;
kj::EventLoop& loop;
kj::Promise<void> innerPromise; kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
}; };
class LocalPipeline final: public PipelineHook, public kj::Refcounted { kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
public: auto lock = state.lockExclusive();
kj::Own<const PipelineHook> addRef() const override { if (lock->is<Resolved>()) {
return kj::addRef(*this); return lock->get<Resolved>()->getPipelinedCap(ops);
} } else if (lock->is<kj::Exception>()) {
return kj::heap<BrokenClient>(lock->get<kj::Exception>());
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override { } else {
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->get<Waiting>().add(Waiter { kj::mv(ops), kj::mv(pair.fulfiller) });
return kj::heap<QueuedClient>(loop, kj::mv(pair.promise));
} }
}
private: // =======================================================================================
struct Waiter {
};
struct State {
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::MutexGuarded<State> state;
};
class LocalRequest final: public RequestHook { class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public: public:
inline LocalRequest(kj::EventLoop& eventLoop, const Capability::Server* server, inline LocalPipeline(Response<ObjectPointer> response): response(kj::mv(response)) {}
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize,
kj::Own<const ClientHook> clientRef)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
eventLoop(eventLoop), server(server), interfaceId(interfaceId), methodId(methodId),
clientRef(kj::mv(clientRef)) {}
RemotePromise<TypelessResults> send() override {
// For the lambda capture.
// We can const-cast the server pointer because we are synchronizing to its event loop here.
Capability::Server* server = const_cast<Capability::Server*>(this->server);
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(clientRef)); kj::Own<const PipelineHook> addRef() const {
auto promise = eventLoop.evalLater( return kj::addRef(*this);
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(*context))
.then(kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
}));
}));
return RemotePromise<TypelessResults>(
kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()));
} }
kj::Own<MallocMessageBuilder> message; kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const {
return response.getPipelinedCap(ops);
}
private: private:
kj::EventLoop& eventLoop; Response<ObjectPointer> response;
const Capability::Server* server;
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> clientRef;
}; };
class LocalClient final: public ClientHook, public kj::Refcounted { class LocalClient final: public ClientHook, public kj::Refcounted {
...@@ -268,21 +453,30 @@ public: ...@@ -268,21 +453,30 @@ public:
Request<ObjectPointer, TypelessResults> newCall( Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override { uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>( auto hook = kj::heap<LocalRequest>(
eventLoop, server, interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this)); interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>( return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook)); hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
} }
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override { CallContextHook& context) const override {
// We can const-cast the server because we're synchronizing on the event loop. // We can const-cast the server because we're synchronizing on the event loop.
auto server = const_cast<Capability::Server*>(this->server.get()); auto server = const_cast<Capability::Server*>(this->server.get());
auto promise = eventLoop.evalLater([=]() mutable {
return server->dispatchCall(interfaceId, methodId, context); auto pipelineFulfiller = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>();
});
auto promise = eventLoop.evalLater(kj::mvCapture(pipelineFulfiller.fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) mutable {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(context))
.then(kj::mvCapture(fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) {
fulfiller->fulfill(kj::heap<LocalPipeline>(context.getResponseForPipeline()));
}));
}));
return VoidPromiseAndPipeline { kj::mv(promise), return VoidPromiseAndPipeline { kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()) }; kj::heap<QueuedPipeline>(eventLoop, kj::mv(pipelineFulfiller.promise)) };
} }
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override { kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
......
...@@ -321,6 +321,12 @@ public: ...@@ -321,6 +321,12 @@ public:
virtual kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const = 0; virtual kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const = 0;
// Extract a promised Capability from the results. // Extract a promised Capability from the results.
virtual kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
// Version of getPipelinedCap() passing the array by move. May avoid a copy in some cases.
// Default implementation just calls the other version.
return getPipelinedCap(ops.asPtr());
}
}; };
class ClientHook { class ClientHook {
...@@ -332,11 +338,11 @@ public: ...@@ -332,11 +338,11 @@ public:
struct VoidPromiseAndPipeline { struct VoidPromiseAndPipeline {
kj::Promise<void> promise; kj::Promise<void> promise;
TypelessResults::Pipeline pipeline; kj::Own<const PipelineHook> pipeline;
}; };
virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const = 0; CallContextHook& context) const = 0;
// Call the object, but the caller controls allocation of the request/response objects. If the // Call the object, but the caller controls allocation of the request/response objects. If the
// callee insists on allocating this objects itself, it must make a copy. This version is used // callee insists on allocating this objects itself, it must make a copy. This version is used
// when calls come in over the network via an RPC system. During the call, the context object // when calls come in over the network via an RPC system. During the call, the context object
...@@ -351,6 +357,9 @@ public: ...@@ -351,6 +357,9 @@ public:
// promise that eventually resolves to a new client that is closer to being the final, settled // promise that eventually resolves to a new client that is closer to being the final, settled
// client. Calling this repeatedly should eventually produce a settled client. // client. Calling this repeatedly should eventually produce a settled client.
kj::Promise<void> whenResolved() const;
// Repeatedly calls whenMoreResolved() until it returns nullptr.
virtual kj::Own<const ClientHook> addRef() const = 0; virtual kj::Own<const ClientHook> addRef() const = 0;
// Return a new reference to the same capability. // Return a new reference to the same capability.
...@@ -370,6 +379,11 @@ public: ...@@ -370,6 +379,11 @@ public:
virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0; virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0;
virtual void allowAsyncCancellation(bool allow) = 0; virtual void allowAsyncCancellation(bool allow) = 0;
virtual bool isCanceled() = 0; virtual bool isCanceled() = 0;
virtual Response<ObjectPointer> getResponseForPipeline() = 0;
// Get a copy or reference to the response which will be used to execute pipelined calls. This
// will be called no more than once, just after the server implementation successfully returns
// from the call.
}; };
// ======================================================================================= // =======================================================================================
......
...@@ -1149,7 +1149,7 @@ private: ...@@ -1149,7 +1149,7 @@ private:
return MethodText { return MethodText {
kj::strTree( kj::strTree(
" ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n" " ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n"
" uint firstSegmentWordSize = 0);\n"), " unsigned int firstSegmentWordSize = 0);\n"),
kj::strTree( kj::strTree(
" virtual ::kj::Promise<void> ", name, "(\n" " virtual ::kj::Promise<void> ", name, "(\n"
...@@ -1162,7 +1162,7 @@ private: ...@@ -1162,7 +1162,7 @@ private:
kj::strTree( kj::strTree(
"::capnp::Request<", paramType, ", ", resultType, ">\n", "::capnp::Request<", paramType, ", ", resultType, ">\n",
interfaceName, "::Client::", name, "Request(uint firstSegmentWordSize) {\n" interfaceName, "::Client::", name, "Request(unsigned int firstSegmentWordSize) {\n"
" return newCall<", paramType, ", ", resultType, ">(\n" " return newCall<", paramType, ", ", resultType, ">(\n"
" 0x", interfaceIdHex, "ull, ", methodId, ", firstSegmentWordSize);\n" " 0x", interfaceIdHex, "ull, ", methodId, ", firstSegmentWordSize);\n"
"}\n" "}\n"
......
...@@ -676,8 +676,8 @@ static kj::Maybe<kj::Exception> loadFile( ...@@ -676,8 +676,8 @@ static kj::Maybe<kj::Exception> loadFile(
KJ_IF_MAYBE(m, messageBuilder) { KJ_IF_MAYBE(m, messageBuilder) {
// Build an example struct using the compiled schema. // Build an example struct using the compiled schema.
m->adoptRoot(makeExampleStruct( m->get()->adoptRoot(makeExampleStruct(
m->getOrphanage(), compiler.getLoader().get(0x823456789abcdef1llu).asStruct(), m->get()->getOrphanage(), compiler.getLoader().get(0x823456789abcdef1llu).asStruct(),
sharedOrdinalCount)); sharedOrdinalCount));
} }
...@@ -692,7 +692,7 @@ static kj::Maybe<kj::Exception> loadFile( ...@@ -692,7 +692,7 @@ static kj::Maybe<kj::Exception> loadFile(
KJ_IF_MAYBE(m, messageBuilder) { KJ_IF_MAYBE(m, messageBuilder) {
// Check that the example struct matches the compiled schema. // Check that the example struct matches the compiled schema.
auto root = m->getRoot<DynamicStruct>( auto root = m->get()->getRoot<DynamicStruct>(
compiler.getLoader().get(0x823456789abcdef1llu).asStruct()).asReader(); compiler.getLoader().get(0x823456789abcdef1llu).asStruct()).asReader();
KJ_CONTEXT(root); KJ_CONTEXT(root);
checkExampleStruct(root, sharedOrdinalCount); checkExampleStruct(root, sharedOrdinalCount);
......
...@@ -810,11 +810,11 @@ struct PointerHelpers<DynamicList, Kind::UNKNOWN> { ...@@ -810,11 +810,11 @@ struct PointerHelpers<DynamicList, Kind::UNKNOWN> {
} // namespace _ (private) } // namespace _ (private)
template <typename T> template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs(StructSchema schema) { inline typename T::Reader ObjectPointer::Reader::getAs(StructSchema schema) const {
return _::PointerHelpers<T>::getDynamic(reader, schema); return _::PointerHelpers<T>::getDynamic(reader, schema);
} }
template <typename T> template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs(ListSchema schema) { inline typename T::Reader ObjectPointer::Reader::getAs(ListSchema schema) const {
return _::PointerHelpers<T>::getDynamic(reader, schema); return _::PointerHelpers<T>::getDynamic(reader, schema);
} }
template <typename T> template <typename T>
......
...@@ -165,7 +165,7 @@ MallocMessageBuilder::~MallocMessageBuilder() noexcept(false) { ...@@ -165,7 +165,7 @@ MallocMessageBuilder::~MallocMessageBuilder() noexcept(false) {
} }
KJ_IF_MAYBE(s, moreSegments) { KJ_IF_MAYBE(s, moreSegments) {
for (void* ptr: s->segments) { for (void* ptr: s->get()->segments) {
free(ptr); free(ptr);
} }
} }
...@@ -201,7 +201,7 @@ kj::ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) { ...@@ -201,7 +201,7 @@ kj::ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) {
} else { } else {
MoreSegments* segments; MoreSegments* segments;
KJ_IF_MAYBE(s, moreSegments) { KJ_IF_MAYBE(s, moreSegments) {
segments = s; segments = *s;
} else { } else {
auto newSegments = kj::heap<MoreSegments>(); auto newSegments = kj::heap<MoreSegments>();
segments = newSegments; segments = newSegments;
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
namespace capnp { namespace capnp {
kj::Own<const ClientHook> ObjectPointer::Reader::getPipelinedCap( kj::Own<const ClientHook> ObjectPointer::Reader::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) { kj::ArrayPtr<const PipelineOp> ops) const {
_::PointerReader pointer = reader; _::PointerReader pointer = reader;
for (auto& op: ops) { for (auto& op: ops) {
......
...@@ -47,21 +47,21 @@ struct ObjectPointer { ...@@ -47,21 +47,21 @@ struct ObjectPointer {
Reader() = default; Reader() = default;
inline Reader(_::PointerReader reader): reader(reader) {} inline Reader(_::PointerReader reader): reader(reader) {}
inline bool isNull(); inline bool isNull() const;
template <typename T> template <typename T>
inline typename T::Reader getAs(); inline typename T::Reader getAs() const;
// Valid for T = any generated struct type, List<U>, Text, or Data. // Valid for T = any generated struct type, List<U>, Text, or Data.
template <typename T> template <typename T>
inline typename T::Reader getAs(StructSchema schema); inline typename T::Reader getAs(StructSchema schema) const;
// Only valid for T = DynamicStruct. Requires `#include <capnp/dynamic.h>`. // Only valid for T = DynamicStruct. Requires `#include <capnp/dynamic.h>`.
template <typename T> template <typename T>
inline typename T::Reader getAs(ListSchema schema); inline typename T::Reader getAs(ListSchema schema) const;
// Only valid for T = DynamicList. Requires `#include <capnp/dynamic.h>`. // Only valid for T = DynamicList. Requires `#include <capnp/dynamic.h>`.
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops); kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const;
// Used by RPC system to implement pipelining. Applications generally shouldn't use this // Used by RPC system to implement pipelining. Applications generally shouldn't use this
// directly. // directly.
...@@ -209,12 +209,12 @@ private: ...@@ -209,12 +209,12 @@ private:
// ======================================================================================= // =======================================================================================
// Inline implementation details // Inline implementation details
inline bool ObjectPointer::Reader::isNull() { inline bool ObjectPointer::Reader::isNull() const {
return reader.isNull(); return reader.isNull();
} }
template <typename T> template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs() { inline typename T::Reader ObjectPointer::Reader::getAs() const {
return _::PointerHelpers<T>::get(reader); return _::PointerHelpers<T>::get(reader);
} }
...@@ -292,7 +292,7 @@ inline Orphan<T> Orphan<ObjectPointer>::releaseAs() { ...@@ -292,7 +292,7 @@ inline Orphan<T> Orphan<ObjectPointer>::releaseAs() {
// Using ObjectPointer as the template type should work... // Using ObjectPointer as the template type should work...
template <> template <>
inline typename ObjectPointer::Reader ObjectPointer::Reader::getAs<ObjectPointer>() { inline typename ObjectPointer::Reader ObjectPointer::Reader::getAs<ObjectPointer>() const {
return *this; return *this;
} }
template <> template <>
......
...@@ -101,7 +101,7 @@ public: ...@@ -101,7 +101,7 @@ public:
kj::Maybe<const Module&> importRelative(kj::StringPtr importPath) const override { kj::Maybe<const Module&> importRelative(kj::StringPtr importPath) const override {
KJ_IF_MAYBE(importedFile, file->import(importPath)) { KJ_IF_MAYBE(importedFile, file->import(importPath)) {
return parser.getModuleImpl(kj::mv(importedFile)); return parser.getModuleImpl(kj::mv(*importedFile));
} else { } else {
return nullptr; return nullptr;
} }
......
...@@ -206,6 +206,21 @@ TEST(Async, SeparateFulfillerCanceled) { ...@@ -206,6 +206,21 @@ TEST(Async, SeparateFulfillerCanceled) {
EXPECT_FALSE(pair.fulfiller->isWaiting()); EXPECT_FALSE(pair.fulfiller->isWaiting());
} }
TEST(Async, SeparateFulfillerChained) {
SimpleEventLoop loop;
auto pair = newPromiseAndFulfiller<Promise<int>>(loop);
auto inner = newPromiseAndFulfiller<int>();
EXPECT_TRUE(pair.fulfiller->isWaiting());
pair.fulfiller->fulfill(kj::mv(inner.promise));
EXPECT_FALSE(pair.fulfiller->isWaiting());
inner.fulfiller->fulfill(123);
EXPECT_EQ(123, loop.wait(kj::mv(pair.promise)));
}
#if KJ_NO_EXCEPTIONS #if KJ_NO_EXCEPTIONS
#undef EXPECT_ANY_THROW #undef EXPECT_ANY_THROW
#define EXPECT_ANY_THROW(code) EXPECT_DEATH(code, ".") #define EXPECT_ANY_THROW(code) EXPECT_DEATH(code, ".")
......
...@@ -36,6 +36,8 @@ template <typename T> ...@@ -36,6 +36,8 @@ template <typename T>
class Promise; class Promise;
template <typename T> template <typename T>
class PromiseFulfiller; class PromiseFulfiller;
template <typename T>
struct PromiseFulfillerPair;
// ======================================================================================= // =======================================================================================
// *************************************************************************************** // ***************************************************************************************
...@@ -468,11 +470,14 @@ public: ...@@ -468,11 +470,14 @@ public:
// //
// For void promises, use `kj::READY_NOW` as the value, e.g. `return kj::READY_NOW`. // For void promises, use `kj::READY_NOW` as the value, e.g. `return kj::READY_NOW`.
Promise(kj::Exception&& e);
// Construct an already-broken Promise.
inline Promise(decltype(nullptr)) {} inline Promise(decltype(nullptr)) {}
template <typename Func, typename ErrorFunc = _::PropagateException> template <typename Func, typename ErrorFunc = _::PropagateException>
auto then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException()) PromiseForResult<Func, T> then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException())
-> PromiseForResult<Func, T>; KJ_WARN_UNUSED_RESULT;
// Register a continuation function to be executed when the promise completes. The continuation // Register a continuation function to be executed when the promise completes. The continuation
// (`func`) takes the promised value (an rvalue of type `T`) as its parameter. The continuation // (`func`) takes the promised value (an rvalue of type `T`) as its parameter. The continuation
// may return a new value; `then()` itself returns a promise for the continuation's eventual // may return a new value; `then()` itself returns a promise for the continuation's eventual
...@@ -561,6 +566,10 @@ private: ...@@ -561,6 +566,10 @@ private:
friend class EventLoop; friend class EventLoop;
template <typename U, typename Adapter, typename... Params> template <typename U, typename Adapter, typename... Params>
friend Promise<U> newAdaptedPromise(Params&&... adapterConstructorParams); friend Promise<U> newAdaptedPromise(Params&&... adapterConstructorParams);
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller(const EventLoop& loop);
}; };
constexpr _::Void READY_NOW = _::Void(); constexpr _::Void READY_NOW = _::Void();
...@@ -672,18 +681,28 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams); ...@@ -672,18 +681,28 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams);
template <typename T> template <typename T>
struct PromiseFulfillerPair { struct PromiseFulfillerPair {
Promise<T> promise; Promise<_::JoinPromises<T>> promise;
Own<PromiseFulfiller<T>> fulfiller; Own<PromiseFulfiller<T>> fulfiller;
}; };
template <typename T> template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller(); PromiseFulfillerPair<T> newPromiseAndFulfiller();
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop);
// Construct a Promise and a separate PromiseFulfiller which can be used to fulfill the promise. // Construct a Promise and a separate PromiseFulfiller which can be used to fulfill the promise.
// If the PromiseFulfiller is destroyed before either of its methods are called, the Promise is // If the PromiseFulfiller is destroyed before either of its methods are called, the Promise is
// implicitly rejected. // implicitly rejected.
// //
// Although this function is easier to use than `newAdaptedPromise()`, it has the serious drawback // Although this function is easier to use than `newAdaptedPromise()`, it has the serious drawback
// that there is no way to handle cancellation (i.e. detect when the Promise is discarded). // that there is no way to handle cancellation (i.e. detect when the Promise is discarded).
//
// You can arrange to fulfill a promise with another promise by using a promise type for T. E.g.
// if `T` is `Promise<U>`, then the returned promise will be of type `Promise<U>` but the fulfiller
// will be of type `PromiseFulfiller<Promise<U>>`. Thus you pass a `Promise<U>` to the `fulfill()`
// callback, and the promises are chained. In this case, an `EventLoop` is needed in order to wait
// on the chained promise; you can specify one as a parameter, otherwise the current loop in the
// thread that called `newPromiseAndFulfiller` will be used. If `T` is *not* a promise type, then
// no `EventLoop` is needed and the `loop` parameter, if specified, is ignored.
// ======================================================================================= // =======================================================================================
// internal implementation details follow // internal implementation details follow
...@@ -891,6 +910,16 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, const EventLoop& loop, ...@@ -891,6 +910,16 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, const EventLoop& loop,
return kj::mv(node); return kj::mv(node);
} }
template <typename T>
Own<PromiseNode> maybeChain(Own<PromiseNode>&& node, Promise<T>*) {
return heap<ChainPromiseNode>(EventLoop::current(), kj::mv(node), EventLoop::Event::PREEMPT);
}
template <typename T>
Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) {
return kj::mv(node);
}
class CrossThreadPromiseNodeBase: public PromiseNode, private EventLoop::Event { class CrossThreadPromiseNodeBase: public PromiseNode, private EventLoop::Event {
// A PromiseNode that safely imports a promised value from one EventLoop to another (which // A PromiseNode that safely imports a promised value from one EventLoop to another (which
// implies crossing threads). // implies crossing threads).
...@@ -1048,9 +1077,13 @@ template <typename T> ...@@ -1048,9 +1077,13 @@ template <typename T>
Promise<T>::Promise(_::FixVoid<T> value) Promise<T>::Promise(_::FixVoid<T> value)
: PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid<T>>>(kj::mv(value))) {} : PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid<T>>>(kj::mv(value))) {}
template <typename T>
Promise<T>::Promise(kj::Exception&& exception)
: PromiseBase(heap<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {}
template <typename T> template <typename T>
template <typename Func, typename ErrorFunc> template <typename Func, typename ErrorFunc>
auto Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) -> PromiseForResult<Func, T> { PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) {
return EventLoop::current().thereImpl( return EventLoop::current().thereImpl(
kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler), kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler),
EventLoop::Event::PREEMPT); EventLoop::Event::PREEMPT);
...@@ -1130,7 +1163,25 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams) { ...@@ -1130,7 +1163,25 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams) {
template <typename T> template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller() { PromiseFulfillerPair<T> newPromiseAndFulfiller() {
auto wrapper = heap<_::WeakFulfiller<T>>(); auto wrapper = heap<_::WeakFulfiller<T>>();
Promise<T> promise = newAdaptedPromise<T, _::PromiseAndFulfillerAdapter<T>>(*wrapper);
Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise(
_::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
}
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop) {
auto wrapper = heap<_::WeakFulfiller<T>>();
Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise(
_::maybeChain(kj::mv(intermediate), loop, EventLoop::Event::YIELD,
implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) }; return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
} }
......
...@@ -342,6 +342,10 @@ T refIfLvalue(T&&); ...@@ -342,6 +342,10 @@ T refIfLvalue(T&&);
template<typename T> constexpr T&& mv(T& t) noexcept { return static_cast<T&&>(t); } template<typename T> constexpr T&& mv(T& t) noexcept { return static_cast<T&&>(t); }
template<typename T> constexpr T&& fwd(NoInfer<T>& t) noexcept { return static_cast<T&&>(t); } template<typename T> constexpr T&& fwd(NoInfer<T>& t) noexcept { return static_cast<T&&>(t); }
template<typename T> constexpr T cp(T& t) noexcept { return t; }
template<typename T> constexpr T cp(const T& t) noexcept { return t; }
// Useful to force a copy, particularly to pass into a function that expects T&&.
template <typename T, typename U> template <typename T, typename U>
inline constexpr auto min(T&& a, U&& b) -> decltype(a < b ? a : b) { return a < b ? a : b; } inline constexpr auto min(T&& a, U&& b) -> decltype(a < b ? a : b) { return a < b ? a : b; }
template <typename T, typename U> template <typename T, typename U>
......
...@@ -207,7 +207,7 @@ Exception::Exception(const Exception& other) noexcept ...@@ -207,7 +207,7 @@ Exception::Exception(const Exception& other) noexcept
memcpy(trace, other.trace, sizeof(trace[0]) * traceCount); memcpy(trace, other.trace, sizeof(trace[0]) * traceCount);
KJ_IF_MAYBE(c, other.context) { KJ_IF_MAYBE(c, other.context) {
context = heap(*c); context = heap(**c);
} }
} }
...@@ -216,7 +216,7 @@ Exception::~Exception() noexcept {} ...@@ -216,7 +216,7 @@ Exception::~Exception() noexcept {}
Exception::Context::Context(const Context& other) noexcept Exception::Context::Context(const Context& other) noexcept
: file(other.file), line(other.line), description(str(other.description)) { : file(other.file), line(other.line), description(str(other.description)) {
KJ_IF_MAYBE(n, other.next) { KJ_IF_MAYBE(n, other.next) {
next = heap(*n); next = heap(**n);
} }
} }
......
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
inline Maybe<const Context&> getContext() const { inline Maybe<const Context&> getContext() const {
KJ_IF_MAYBE(c, context) { KJ_IF_MAYBE(c, context) {
return *c; return **c;
} else { } else {
return nullptr; return nullptr;
} }
......
...@@ -168,11 +168,27 @@ private: ...@@ -168,11 +168,27 @@ private:
namespace _ { // private namespace _ { // private
template <typename T> template <typename T>
Own<T>&& readMaybe(Maybe<Own<T>>&& maybe) { return kj::mv(maybe.ptr); } class OwnOwn {
public:
inline OwnOwn(Own<T>&& value) noexcept: value(kj::mv(value)) {}
inline Own<T>& operator*() { return value; }
inline const Own<T>& operator*() const { return value; }
inline Own<T>* operator->() { return &value; }
inline const Own<T>* operator->() const { return &value; }
inline operator Own<T>*() { return value ? &value : nullptr; }
inline operator const Own<T>*() const { return value ? &value : nullptr; }
private:
Own<T> value;
};
template <typename T>
OwnOwn<T> readMaybe(Maybe<Own<T>>&& maybe) { return OwnOwn<T>(kj::mv(maybe.ptr)); }
template <typename T> template <typename T>
T* readMaybe(Maybe<Own<T>>& maybe) { return maybe.ptr; } Own<T>* readMaybe(Maybe<Own<T>>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; }
template <typename T> template <typename T>
const T* readMaybe(const Maybe<Own<T>>& maybe) { return maybe.ptr; } const Own<T>* readMaybe(const Maybe<Own<T>>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; }
} // namespace _ (private) } // namespace _ (private)
...@@ -223,11 +239,11 @@ private: ...@@ -223,11 +239,11 @@ private:
template <typename U> template <typename U>
friend class Maybe; friend class Maybe;
template <typename U> template <typename U>
friend Own<U>&& _::readMaybe(Maybe<Own<U>>&& maybe); friend _::OwnOwn<U> _::readMaybe(Maybe<Own<U>>&& maybe);
template <typename U> template <typename U>
friend U* _::readMaybe(Maybe<Own<U>>& maybe); friend Own<U>* _::readMaybe(Maybe<Own<U>>& maybe);
template <typename U> template <typename U>
friend const U* _::readMaybe(const Maybe<Own<U>>& maybe); friend const Own<U>* _::readMaybe(const Maybe<Own<U>>& maybe);
}; };
namespace _ { // private namespace _ { // private
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "one-of.h"
#include "string.h"
#include <gtest/gtest.h>
namespace kj {
TEST(OneOf, Basic) {
OneOf<int, float, String> var;
EXPECT_FALSE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_FALSE(var.is<String>());
var.init<int>(123);
EXPECT_TRUE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_FALSE(var.is<String>());
EXPECT_EQ(123, var.get<int>());
#if !KJ_NO_EXCEPTIONS
EXPECT_ANY_THROW(var.get<float>());
EXPECT_ANY_THROW(var.get<String>());
#endif
var.init<String>(kj::str("foo"));
EXPECT_FALSE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_TRUE(var.is<String>());
EXPECT_EQ("foo", var.get<String>());
OneOf<int, float, String> var2 = kj::mv(var);
EXPECT_EQ("", var.get<String>());
EXPECT_EQ("foo", var2.get<String>());
var = kj::mv(var2);
EXPECT_EQ("foo", var.get<String>());
EXPECT_EQ("", var2.get<String>());
}
TEST(OneOf, Copy) {
OneOf<int, float, const char*> var;
OneOf<int, float, const char*> var2 = var;
EXPECT_FALSE(var2.is<int>());
EXPECT_FALSE(var2.is<float>());
EXPECT_FALSE(var2.is<const char*>());
var.init<int>(123);
var2 = var;
EXPECT_TRUE(var2.is<int>());
EXPECT_EQ(123, var2.get<int>());
var.init<const char*>("foo");
var2 = var;
EXPECT_TRUE(var2.is<const char*>());
EXPECT_STREQ("foo", var2.get<const char*>());
}
} // namespace kj
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef KJ_ONE_OF_H_
#define KJ_ONE_OF_H_
#include "common.h"
namespace kj {
namespace _ { // private
template <uint i, typename Key, typename First, typename... Rest>
struct TypeIndex_ { static constexpr uint value = TypeIndex_<i + 1, Key, Rest...>::value; };
template <uint i, typename Key, typename... Rest>
struct TypeIndex_<i, Key, Key, Rest...> { static constexpr uint value = i; };
} // namespace _ (private)
template <typename... Variants>
class OneOf {
template <typename Key>
static inline constexpr uint typeIndex() { return _::TypeIndex_<1, Key, Variants...>::value; }
// Get the 1-based index of Key within the type list Types.
public:
inline OneOf(): tag(0) {}
OneOf(const OneOf& other) { copyFrom(other); }
OneOf(OneOf&& other) { moveFrom(other); }
~OneOf() { destroy(); }
OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; }
OneOf& operator=(OneOf&& other) { if (tag != 0) destroy(); moveFrom(other); return *this; }
template <typename T>
bool is() const {
return tag == typeIndex<T>();
}
template <typename T>
T& get() {
KJ_IREQUIRE(is<T>(), "Must check OneOf::is<T>() before calling get<T>().");
return *reinterpret_cast<T*>(space);
}
template <typename T>
const T& get() const {
KJ_IREQUIRE(is<T>(), "Must check OneOf::is<T>() before calling get<T>().");
return *reinterpret_cast<const T*>(space);
}
template <typename T, typename... Params>
void init(Params&&... params) {
if (tag != 0) destroy();
ctor(*reinterpret_cast<T*>(space), kj::fwd<Params>(params)...);
tag = typeIndex<T>();
}
private:
uint tag;
static inline constexpr size_t maxSize(size_t a) {
return a;
}
template <typename... Rest>
static inline constexpr size_t maxSize(size_t a, size_t b, Rest... rest) {
return maxSize(kj::max(a, b), rest...);
}
// Returns the maximum of all the parameters.
// TODO(someday): Generalize the above template and make it common. I tried, but C++ decided to
// be difficult so I cut my losses.
union {
byte space[maxSize(sizeof(Variants)...)];
void* forceAligned;
// TODO(someday): Use C++11 alignas() once we require GCC 4.8 / Clang 3.3.
};
template <typename... T>
inline void doAll(T... t) {}
template <typename T>
KJ_ALWAYS_INLINE(bool destroyVariant()) {
if (tag == typeIndex<T>()) {
tag = 0;
dtor(*reinterpret_cast<T*>(space));
}
return false;
}
void destroy() {
doAll(destroyVariant<Variants>()...);
}
template <typename T>
KJ_ALWAYS_INLINE(bool copyVariantFrom(const OneOf& other)) {
if (other.is<T>()) {
ctor(*reinterpret_cast<T*>(space), other.get<T>());
tag = typeIndex<T>();
}
return false;
}
void copyFrom(const OneOf& other) {
// Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag
// is invalid.
doAll(copyVariantFrom<Variants>(other)...);
}
template <typename T>
KJ_ALWAYS_INLINE(bool moveVariantFrom(OneOf& other)) {
if (other.is<T>()) {
ctor(*reinterpret_cast<T*>(space), kj::mv(other.get<T>()));
tag = typeIndex<T>();
}
return false;
}
void moveFrom(OneOf& other) {
// Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag
// is invalid.
doAll(moveVariantFrom<Variants>(other)...);
}
};
} // namespace kj
#endif // KJ_ONE_OF_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment