Commit e8d256ab authored by Kenton Varda's avatar Kenton Varda

More capabilities work. I think I need to implement Promise forking, though.

parent 7ceed92d
......@@ -71,11 +71,11 @@ SegmentReader* BasicReaderArena::tryGetSegment(SegmentId id) {
SegmentMap* segments = nullptr;
KJ_IF_MAYBE(s, *lock) {
auto iter = s->find(id.value);
if (iter != s->end()) {
auto iter = s->get()->find(id.value);
if (iter != s->get()->end()) {
return iter->second;
}
segments = s;
segments = *s;
}
kj::ArrayPtr<const word> newSegment = message->getSegment(id.value);
......@@ -113,7 +113,7 @@ public:
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
CallContextHook& context) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
......@@ -162,12 +162,12 @@ SegmentReader* ImbuedReaderArena::imbue(SegmentReader* baseSegment) {
SegmentMap* segments = nullptr;
KJ_IF_MAYBE(s, *lock) {
auto iter = s->find(baseSegment);
if (iter != s->end()) {
auto iter = s->get()->find(baseSegment);
if (iter != s->get()->end()) {
KJ_DASSERT(iter->second->getArray().begin() == baseSegment->getArray().begin());
return iter->second;
}
segments = s;
segments = *s;
} else {
auto newMap = kj::heap<SegmentMap>();
segments = newMap;
......@@ -205,10 +205,10 @@ SegmentBuilder* BasicBuilderArena::getSegment(SegmentId id) {
} else {
auto lock = moreSegments.lockShared();
KJ_IF_MAYBE(s, *lock) {
KJ_REQUIRE(id.value - 1 < s->builders.size(), "invalid segment id", id.value);
KJ_REQUIRE(id.value - 1 < s->get()->builders.size(), "invalid segment id", id.value);
// TODO(cleanup): Return a const SegmentBuilder and tediously constify all SegmentBuilder
// pointers throughout the codebase.
return const_cast<BasicSegmentBuilder*>(s->builders[id.value - 1].get());
return const_cast<BasicSegmentBuilder*>(s->get()->builders[id.value - 1].get());
} else {
KJ_FAIL_REQUIRE("invalid segment id", id.value);
}
......@@ -245,11 +245,11 @@ BasicBuilderArena::AllocateResult BasicBuilderArena::allocate(WordCount amount)
// on the last-known available size, and then re-check the size when we pop segments off it
// and shove them to the back of the queue if they have become too small.
attempt = s->builders.back()->allocate(amount);
attempt = s->get()->builders.back()->allocate(amount);
if (attempt != nullptr) {
return AllocateResult { s->builders.back().get(), attempt };
return AllocateResult { s->get()->builders.back().get(), attempt };
}
segmentState = s;
segmentState = *s;
} else {
auto newSegmentState = kj::heap<MultiSegmentState>();
segmentState = newSegmentState;
......@@ -279,15 +279,15 @@ kj::ArrayPtr<const kj::ArrayPtr<const word>> BasicBuilderArena::getSegmentsForOu
// problem regardless of locking here.
KJ_IF_MAYBE(segmentState, moreSegments.getWithoutLock()) {
KJ_DASSERT(segmentState->forOutput.size() == segmentState->builders.size() + 1,
KJ_DASSERT(segmentState->get()->forOutput.size() == segmentState->get()->builders.size() + 1,
"segmentState->forOutput wasn't resized correctly when the last builder was added.",
segmentState->forOutput.size(), segmentState->builders.size());
segmentState->get()->forOutput.size(), segmentState->get()->builders.size());
kj::ArrayPtr<kj::ArrayPtr<const word>> result(
&segmentState->forOutput[0], segmentState->forOutput.size());
&segmentState->get()->forOutput[0], segmentState->get()->forOutput.size());
uint i = 0;
result[i++] = segment0.currentlyAllocated();
for (auto& builder: segmentState->builders) {
for (auto& builder: segmentState->get()->builders) {
result[i++] = builder->currentlyAllocated();
}
return result;
......@@ -314,11 +314,11 @@ SegmentReader* BasicBuilderArena::tryGetSegment(SegmentId id) {
} else {
auto lock = moreSegments.lockShared();
KJ_IF_MAYBE(segmentState, *lock) {
if (id.value <= segmentState->builders.size()) {
if (id.value <= segmentState->get()->builders.size()) {
// TODO(cleanup): Return a const SegmentReader and tediously constify all SegmentBuilder
// pointers throughout the codebase.
return const_cast<SegmentReader*>(kj::implicitCast<const SegmentReader*>(
segmentState->builders[id.value - 1].get()));
segmentState->get()->builders[id.value - 1].get()));
}
}
return nullptr;
......@@ -360,15 +360,15 @@ SegmentBuilder* ImbuedBuilderArena::imbue(SegmentBuilder* baseSegment) {
auto lock = moreSegments.lockExclusive();
KJ_IF_MAYBE(segmentState, *lock) {
auto id = baseSegment->getSegmentId().value;
if (id >= segmentState->builders.size()) {
segmentState->builders.resize(id + 1);
if (id >= segmentState->get()->builders.size()) {
segmentState->get()->builders.resize(id + 1);
}
KJ_IF_MAYBE(segment, segmentState->builders[id]) {
result = segment;
KJ_IF_MAYBE(segment, segmentState->get()->builders[id]) {
result = *segment;
} else {
auto newBuilder = kj::heap<ImbuedSegmentBuilder>(baseSegment);
result = newBuilder;
segmentState->builders[id] = kj::mv(newBuilder);
segmentState->get()->builders[id] = kj::mv(newBuilder);
}
}
return nullptr;
......
......@@ -26,6 +26,7 @@
#include <kj/refcount.h>
#include <kj/debug.h>
#include <kj/vector.h>
#include <kj/one-of.h>
namespace capnp {
......@@ -68,11 +69,19 @@ TypelessResults::Pipeline TypelessResults::Pipeline::getPointerField(
ResponseHook::~ResponseHook() noexcept(false) {}
// =======================================================================================
kj::Promise<void> ClientHook::whenResolved() const {
KJ_IF_MAYBE(promise, whenMoreResolved()) {
return promise->then([](kj::Own<const ClientHook>&& resolution) {
return resolution->whenResolved();
});
} else {
return kj::READY_NOW;
}
}
namespace {
// =======================================================================================
class LocalResponse final: public ResponseHook {
class LocalResponse final: public ResponseHook, public kj::Refcounted {
public:
LocalResponse(uint sizeHint)
: message(sizeHint == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : sizeHint) {}
......@@ -93,7 +102,7 @@ public:
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
if (!response) {
response = kj::heap<LocalResponse>(firstSegmentWordSize);
response = kj::refcounted<LocalResponse>(firstSegmentWordSize);
}
return response->message.getRoot<ObjectPointer>();
}
......@@ -103,70 +112,275 @@ public:
bool isCanceled() override {
return false;
}
Response<ObjectPointer> getResponseForPipeline() override {
auto reader = getResults(1); // Needs to be a separate line since it may allocate the response.
return Response<ObjectPointer>(reader, kj::addRef(*response));
}
kj::Own<MallocMessageBuilder> request;
kj::Own<LocalResponse> response;
kj::Own<const ClientHook> clientRef;
};
class LocalPipelinedClient final: public ClientHook, public kj::Refcounted {
class LocalRequest final: public RequestHook {
public:
LocalPipelinedClient(kj::Promise<kj::Own<const ClientHook>> promise)
: innerPromise(promise.then([this](kj::Own<const ClientHook>&& resolution) {
inline LocalRequest(uint64_t interfaceId, uint16_t methodId,
uint firstSegmentWordSize, kj::Own<const ClientHook> client)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {}
RemotePromise<TypelessResults> send() override {
// For the lambda capture.
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(client));
auto promiseAndPipeline = client->call(interfaceId, methodId, *context);
auto promise = promiseAndPipeline.promise.then(
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
}));
return RemotePromise<TypelessResults>(
kj::mv(promise), TypelessResults::Pipeline(kj::mv(promiseAndPipeline.pipeline)));
}
kj::Own<MallocMessageBuilder> message;
private:
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> client;
};
// =======================================================================================
namespace {
class BrokenPipeline final: public PipelineHook, public kj::Refcounted {
public:
BrokenPipeline(const kj::Exception& exception): exception(exception) {}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override;
private:
kj::Exception exception;
};
class BrokenClient final: public ClientHook, public kj::Refcounted {
public:
BrokenClient(const kj::Exception& exception): exception(exception) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContextHook& context) const override {
return VoidPromiseAndPipeline { kj::cp(exception), kj::heap<BrokenPipeline>(exception) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return kj::Promise<kj::Own<const ClientHook>>(kj::cp(exception));
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
return nullptr;
}
private:
kj::Exception exception;
};
kj::Own<const ClientHook> BrokenPipeline::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) const {
return kj::heap<BrokenClient>(exception);
}
// =======================================================================================
// Call queues
//
// These classes handle pipelining in the case where calls need to be queued in-memory until some
// local operation completes.
class QueuedPipeline final: public PipelineHook, public kj::Refcounted {
// A PipelineHook which simply queues calls while waiting for a PipelineHook to which to forward
// them.
public:
QueuedPipeline(kj::EventLoop& loop, kj::Promise<kj::Own<const PipelineHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const PipelineHook>&& resolution) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
auto oldState = kj::mv(*lock);
for (auto& call: oldState.pending) {
call.fulfiller->fulfill(resolution->call(
call.interfaceId, call.methodId, call.context).promise);
for (auto& waiter: oldState) {
waiter.fulfiller->fulfill(resolution->getPipelinedCap(kj::mv(waiter.ops)));
}
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
for (auto& waiter: oldState) {
waiter.fulfiller->reject(kj::cp(exception));
}
lock->init<kj::Exception>(kj::mv(exception));
})) {
state.getWithoutLock().init<Waiting>();
}
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override {
auto copy = kj::heapArrayBuilder<PipelineOp>(ops.size());
for (auto& op: ops) {
copy.add(op);
}
return getPipelinedCap(kj::mv(copy));
}
kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const override;
private:
struct Waiter {
kj::Array<PipelineOp> ops;
kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>> fulfiller;
};
typedef kj::Vector<Waiter> Waiting;
typedef kj::Own<const PipelineHook> Resolved;
kj::EventLoop& loop;
kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
};
class QueuedClient final: public ClientHook, public kj::Refcounted {
// A ClientHook which simply queues calls while waiting for a ClientHook to which to forward
// them.
public:
QueuedClient(kj::EventLoop& loop, kj::Promise<kj::Own<const ClientHook>>&& promise)
: loop(loop),
innerPromise(loop.there(kj::mv(promise), [this](kj::Own<const ClientHook>&& resolution) {
// The promised capability has resolved. Forward all queued calls to it.
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
// First we want to initiate all the queued calls, and notify the QueuedPipelines to
// transfer their queues to the new call's own pipeline. It's important that this all
// happen before the application receives any notification that the promise resolved,
// so that any new calls it makes in response to the resolution don't end up being
// delivered before the previously-queued calls.
auto realCallPromises = kj::heapArrayBuilder<kj::Promise<void>>(oldState.pending.size());
for (auto& pendingCall: oldState.pending) {
auto realCall = resolution->call(
pendingCall.interfaceId, pendingCall.methodId, *pendingCall.context);
pendingCall.pipelineFulfiller->fulfill(kj::mv(realCall.pipeline));
realCallPromises.add(kj::mv(realCall.promise));
}
// Fire the "whenMoreResolved" callbacks.
for (auto& notify: oldState.notifyOnResolution) {
notify->fulfill(resolution->addRef());
}
lock->resolution = kj::mv(resolution);
// For each queued call, chain the pipelined promise to the real promise. It's important
// that this happens after the "whenMoreResolved" callbacks because applications may get
// confused if a pipelined call completes before the promise on which it was made
// resolves.
for (uint i: kj::indices(realCallPromises)) {
oldState.pending[i].fulfiller->fulfill(kj::mv(realCallPromises[i]));
}
lock->init<Resolved>(kj::mv(resolution));
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(lock->get<Waiting>());
auto oldState = kj::mv(*lock);
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::Exception(exception));
}
// Reject outer promises before dependent promises.
for (auto& notify: oldState.notifyOnResolution) {
notify->reject(kj::Exception(exception));
notify->reject(kj::cp(exception));
}
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::cp(exception));
call.pipelineFulfiller->reject(kj::cp(exception));
}
lock->exception = kj::mv(exception);
})) {}
lock->init<kj::Exception>(kj::mv(exception));
})) {
state.getWithoutLock().init<Waiting>();
}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
return r->newCall(interfaceId, methodId, firstSegmentWordSize);
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
// TODO(now)
auto hook = kj::heap<LocalRequest>(
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
CallContextHook& context) const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
return r->call(interfaceId, methodId, context);
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->call(interfaceId, methodId, context);
} else if (lock->is<kj::Exception>()) {
return VoidPromiseAndPipeline { kj::cp(lock->get<kj::Exception>()),
kj::heap<BrokenPipeline>(lock->get<kj::Exception>()) };
} else {
lock->pending.add(PendingCall { interfaceId, methodId, context });
auto pair = kj::newPromiseAndFulfiller<kj::Promise<void>>(loop);
auto pipelinePromise = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>(loop);
auto pipeline = kj::heap<QueuedPipeline>(loop, kj::mv(pipelinePromise.promise));
lock->get<Waiting>().pending.add(PendingCall {
interfaceId, methodId, &context, kj::mv(pair.fulfiller),
kj::mv(pipelinePromise.fulfiller) });
// TODO(now): returned promise must hold a reference to this.
return VoidPromiseAndPipeline { kj::mv(pair.promise), kj::mv(pipeline) };
}
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
if (lock->is<Resolved>()) {
// Already resolved.
return kj::Promise<kj::Own<const ClientHook>>(r->addRef());
return kj::Promise<kj::Own<const ClientHook>>(lock->get<Resolved>()->addRef());
} else if (lock->is<kj::Exception>()) {
// Already broken.
return kj::Promise<kj::Own<const ClientHook>>(kj::Own<const ClientHook>(
kj::heap<BrokenClient>(lock->get<kj::Exception>())));
} else {
// Waiting.
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->notifyOnResolution.add(kj::mv(pair.fulfiller));
lock->get<Waiting>().notifyOnResolution.add(kj::mv(pair.fulfiller));
// TODO(now): returned promise must hold a reference to this.
return kj::mv(pair.promise);
}
}
......@@ -183,81 +397,52 @@ private:
struct PendingCall {
uint64_t interfaceId;
uint16_t methodId;
CallContext<ObjectPointer, ObjectPointer> context;
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
CallContextHook* context;
kj::Own<kj::PromiseFulfiller<kj::Promise<void>>> fulfiller;
kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>> pipelineFulfiller;
};
struct State {
kj::Maybe<kj::Own<const ClientHook>> resolution;
struct Waiting {
kj::Vector<PendingCall> pending;
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::MutexGuarded<State> state;
typedef kj::Own<const ClientHook> Resolved;
kj::EventLoop& loop;
kj::Promise<void> innerPromise;
kj::MutexGuarded<kj::OneOf<Waiting, Resolved, kj::Exception>> state;
};
class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public:
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override {
kj::Own<const ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
auto lock = state.lockExclusive();
if (lock->is<Resolved>()) {
return lock->get<Resolved>()->getPipelinedCap(ops);
} else if (lock->is<kj::Exception>()) {
return kj::heap<BrokenClient>(lock->get<kj::Exception>());
} else {
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->get<Waiting>().add(Waiter { kj::mv(ops), kj::mv(pair.fulfiller) });
return kj::heap<QueuedClient>(loop, kj::mv(pair.promise));
}
}
private:
struct Waiter {
};
struct State {
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::MutexGuarded<State> state;
};
// =======================================================================================
class LocalRequest final: public RequestHook {
class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public:
inline LocalRequest(kj::EventLoop& eventLoop, const Capability::Server* server,
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize,
kj::Own<const ClientHook> clientRef)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
eventLoop(eventLoop), server(server), interfaceId(interfaceId), methodId(methodId),
clientRef(kj::mv(clientRef)) {}
RemotePromise<TypelessResults> send() override {
// For the lambda capture.
// We can const-cast the server pointer because we are synchronizing to its event loop here.
Capability::Server* server = const_cast<Capability::Server*>(this->server);
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
inline LocalPipeline(Response<ObjectPointer> response): response(kj::mv(response)) {}
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(clientRef));
auto promise = eventLoop.evalLater(
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(*context))
.then(kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
}));
}));
return RemotePromise<TypelessResults>(
kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()));
kj::Own<const PipelineHook> addRef() const {
return kj::addRef(*this);
}
kj::Own<MallocMessageBuilder> message;
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const {
return response.getPipelinedCap(ops);
}
private:
kj::EventLoop& eventLoop;
const Capability::Server* server;
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> clientRef;
Response<ObjectPointer> response;
};
class LocalClient final: public ClientHook, public kj::Refcounted {
......@@ -268,21 +453,30 @@ public:
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
eventLoop, server, interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
CallContextHook& context) const override {
// We can const-cast the server because we're synchronizing on the event loop.
auto server = const_cast<Capability::Server*>(this->server.get());
auto promise = eventLoop.evalLater([=]() mutable {
return server->dispatchCall(interfaceId, methodId, context);
});
auto pipelineFulfiller = kj::newPromiseAndFulfiller<kj::Own<const PipelineHook>>();
auto promise = eventLoop.evalLater(kj::mvCapture(pipelineFulfiller.fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) mutable {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(context))
.then(kj::mvCapture(fulfiller,
[=,&context](kj::Own<kj::PromiseFulfiller<kj::Own<const PipelineHook>>>&& fulfiller) {
fulfiller->fulfill(kj::heap<LocalPipeline>(context.getResponseForPipeline()));
}));
}));
return VoidPromiseAndPipeline { kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()) };
kj::heap<QueuedPipeline>(eventLoop, kj::mv(pipelineFulfiller.promise)) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
......
......@@ -321,6 +321,12 @@ public:
virtual kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const = 0;
// Extract a promised Capability from the results.
virtual kj::Own<const ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) const {
// Version of getPipelinedCap() passing the array by move. May avoid a copy in some cases.
// Default implementation just calls the other version.
return getPipelinedCap(ops.asPtr());
}
};
class ClientHook {
......@@ -332,11 +338,11 @@ public:
struct VoidPromiseAndPipeline {
kj::Promise<void> promise;
TypelessResults::Pipeline pipeline;
kj::Own<const PipelineHook> pipeline;
};
virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const = 0;
CallContextHook& context) const = 0;
// Call the object, but the caller controls allocation of the request/response objects. If the
// callee insists on allocating this objects itself, it must make a copy. This version is used
// when calls come in over the network via an RPC system. During the call, the context object
......@@ -351,6 +357,9 @@ public:
// promise that eventually resolves to a new client that is closer to being the final, settled
// client. Calling this repeatedly should eventually produce a settled client.
kj::Promise<void> whenResolved() const;
// Repeatedly calls whenMoreResolved() until it returns nullptr.
virtual kj::Own<const ClientHook> addRef() const = 0;
// Return a new reference to the same capability.
......@@ -370,6 +379,11 @@ public:
virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0;
virtual void allowAsyncCancellation(bool allow) = 0;
virtual bool isCanceled() = 0;
virtual Response<ObjectPointer> getResponseForPipeline() = 0;
// Get a copy or reference to the response which will be used to execute pipelined calls. This
// will be called no more than once, just after the server implementation successfully returns
// from the call.
};
// =======================================================================================
......
......@@ -1149,7 +1149,7 @@ private:
return MethodText {
kj::strTree(
" ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n"
" uint firstSegmentWordSize = 0);\n"),
" unsigned int firstSegmentWordSize = 0);\n"),
kj::strTree(
" virtual ::kj::Promise<void> ", name, "(\n"
......@@ -1162,7 +1162,7 @@ private:
kj::strTree(
"::capnp::Request<", paramType, ", ", resultType, ">\n",
interfaceName, "::Client::", name, "Request(uint firstSegmentWordSize) {\n"
interfaceName, "::Client::", name, "Request(unsigned int firstSegmentWordSize) {\n"
" return newCall<", paramType, ", ", resultType, ">(\n"
" 0x", interfaceIdHex, "ull, ", methodId, ", firstSegmentWordSize);\n"
"}\n"
......
......@@ -676,8 +676,8 @@ static kj::Maybe<kj::Exception> loadFile(
KJ_IF_MAYBE(m, messageBuilder) {
// Build an example struct using the compiled schema.
m->adoptRoot(makeExampleStruct(
m->getOrphanage(), compiler.getLoader().get(0x823456789abcdef1llu).asStruct(),
m->get()->adoptRoot(makeExampleStruct(
m->get()->getOrphanage(), compiler.getLoader().get(0x823456789abcdef1llu).asStruct(),
sharedOrdinalCount));
}
......@@ -692,7 +692,7 @@ static kj::Maybe<kj::Exception> loadFile(
KJ_IF_MAYBE(m, messageBuilder) {
// Check that the example struct matches the compiled schema.
auto root = m->getRoot<DynamicStruct>(
auto root = m->get()->getRoot<DynamicStruct>(
compiler.getLoader().get(0x823456789abcdef1llu).asStruct()).asReader();
KJ_CONTEXT(root);
checkExampleStruct(root, sharedOrdinalCount);
......
......@@ -810,11 +810,11 @@ struct PointerHelpers<DynamicList, Kind::UNKNOWN> {
} // namespace _ (private)
template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs(StructSchema schema) {
inline typename T::Reader ObjectPointer::Reader::getAs(StructSchema schema) const {
return _::PointerHelpers<T>::getDynamic(reader, schema);
}
template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs(ListSchema schema) {
inline typename T::Reader ObjectPointer::Reader::getAs(ListSchema schema) const {
return _::PointerHelpers<T>::getDynamic(reader, schema);
}
template <typename T>
......
......@@ -165,7 +165,7 @@ MallocMessageBuilder::~MallocMessageBuilder() noexcept(false) {
}
KJ_IF_MAYBE(s, moreSegments) {
for (void* ptr: s->segments) {
for (void* ptr: s->get()->segments) {
free(ptr);
}
}
......@@ -201,7 +201,7 @@ kj::ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) {
} else {
MoreSegments* segments;
KJ_IF_MAYBE(s, moreSegments) {
segments = s;
segments = *s;
} else {
auto newSegments = kj::heap<MoreSegments>();
segments = newSegments;
......
......@@ -27,7 +27,7 @@
namespace capnp {
kj::Own<const ClientHook> ObjectPointer::Reader::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) {
kj::ArrayPtr<const PipelineOp> ops) const {
_::PointerReader pointer = reader;
for (auto& op: ops) {
......
......@@ -47,21 +47,21 @@ struct ObjectPointer {
Reader() = default;
inline Reader(_::PointerReader reader): reader(reader) {}
inline bool isNull();
inline bool isNull() const;
template <typename T>
inline typename T::Reader getAs();
inline typename T::Reader getAs() const;
// Valid for T = any generated struct type, List<U>, Text, or Data.
template <typename T>
inline typename T::Reader getAs(StructSchema schema);
inline typename T::Reader getAs(StructSchema schema) const;
// Only valid for T = DynamicStruct. Requires `#include <capnp/dynamic.h>`.
template <typename T>
inline typename T::Reader getAs(ListSchema schema);
inline typename T::Reader getAs(ListSchema schema) const;
// Only valid for T = DynamicList. Requires `#include <capnp/dynamic.h>`.
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops);
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const;
// Used by RPC system to implement pipelining. Applications generally shouldn't use this
// directly.
......@@ -209,12 +209,12 @@ private:
// =======================================================================================
// Inline implementation details
inline bool ObjectPointer::Reader::isNull() {
inline bool ObjectPointer::Reader::isNull() const {
return reader.isNull();
}
template <typename T>
inline typename T::Reader ObjectPointer::Reader::getAs() {
inline typename T::Reader ObjectPointer::Reader::getAs() const {
return _::PointerHelpers<T>::get(reader);
}
......@@ -292,7 +292,7 @@ inline Orphan<T> Orphan<ObjectPointer>::releaseAs() {
// Using ObjectPointer as the template type should work...
template <>
inline typename ObjectPointer::Reader ObjectPointer::Reader::getAs<ObjectPointer>() {
inline typename ObjectPointer::Reader ObjectPointer::Reader::getAs<ObjectPointer>() const {
return *this;
}
template <>
......
......@@ -101,7 +101,7 @@ public:
kj::Maybe<const Module&> importRelative(kj::StringPtr importPath) const override {
KJ_IF_MAYBE(importedFile, file->import(importPath)) {
return parser.getModuleImpl(kj::mv(importedFile));
return parser.getModuleImpl(kj::mv(*importedFile));
} else {
return nullptr;
}
......
......@@ -206,6 +206,21 @@ TEST(Async, SeparateFulfillerCanceled) {
EXPECT_FALSE(pair.fulfiller->isWaiting());
}
TEST(Async, SeparateFulfillerChained) {
SimpleEventLoop loop;
auto pair = newPromiseAndFulfiller<Promise<int>>(loop);
auto inner = newPromiseAndFulfiller<int>();
EXPECT_TRUE(pair.fulfiller->isWaiting());
pair.fulfiller->fulfill(kj::mv(inner.promise));
EXPECT_FALSE(pair.fulfiller->isWaiting());
inner.fulfiller->fulfill(123);
EXPECT_EQ(123, loop.wait(kj::mv(pair.promise)));
}
#if KJ_NO_EXCEPTIONS
#undef EXPECT_ANY_THROW
#define EXPECT_ANY_THROW(code) EXPECT_DEATH(code, ".")
......
......@@ -36,6 +36,8 @@ template <typename T>
class Promise;
template <typename T>
class PromiseFulfiller;
template <typename T>
struct PromiseFulfillerPair;
// =======================================================================================
// ***************************************************************************************
......@@ -468,11 +470,14 @@ public:
//
// For void promises, use `kj::READY_NOW` as the value, e.g. `return kj::READY_NOW`.
Promise(kj::Exception&& e);
// Construct an already-broken Promise.
inline Promise(decltype(nullptr)) {}
template <typename Func, typename ErrorFunc = _::PropagateException>
auto then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException())
-> PromiseForResult<Func, T>;
PromiseForResult<Func, T> then(Func&& func, ErrorFunc&& errorHandler = _::PropagateException())
KJ_WARN_UNUSED_RESULT;
// Register a continuation function to be executed when the promise completes. The continuation
// (`func`) takes the promised value (an rvalue of type `T`) as its parameter. The continuation
// may return a new value; `then()` itself returns a promise for the continuation's eventual
......@@ -561,6 +566,10 @@ private:
friend class EventLoop;
template <typename U, typename Adapter, typename... Params>
friend Promise<U> newAdaptedPromise(Params&&... adapterConstructorParams);
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller();
template <typename U>
friend PromiseFulfillerPair<U> newPromiseAndFulfiller(const EventLoop& loop);
};
constexpr _::Void READY_NOW = _::Void();
......@@ -672,18 +681,28 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams);
template <typename T>
struct PromiseFulfillerPair {
Promise<T> promise;
Promise<_::JoinPromises<T>> promise;
Own<PromiseFulfiller<T>> fulfiller;
};
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller();
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop);
// Construct a Promise and a separate PromiseFulfiller which can be used to fulfill the promise.
// If the PromiseFulfiller is destroyed before either of its methods are called, the Promise is
// implicitly rejected.
//
// Although this function is easier to use than `newAdaptedPromise()`, it has the serious drawback
// that there is no way to handle cancellation (i.e. detect when the Promise is discarded).
//
// You can arrange to fulfill a promise with another promise by using a promise type for T. E.g.
// if `T` is `Promise<U>`, then the returned promise will be of type `Promise<U>` but the fulfiller
// will be of type `PromiseFulfiller<Promise<U>>`. Thus you pass a `Promise<U>` to the `fulfill()`
// callback, and the promises are chained. In this case, an `EventLoop` is needed in order to wait
// on the chained promise; you can specify one as a parameter, otherwise the current loop in the
// thread that called `newPromiseAndFulfiller` will be used. If `T` is *not* a promise type, then
// no `EventLoop` is needed and the `loop` parameter, if specified, is ignored.
// =======================================================================================
// internal implementation details follow
......@@ -891,6 +910,16 @@ Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, const EventLoop& loop,
return kj::mv(node);
}
template <typename T>
Own<PromiseNode> maybeChain(Own<PromiseNode>&& node, Promise<T>*) {
return heap<ChainPromiseNode>(EventLoop::current(), kj::mv(node), EventLoop::Event::PREEMPT);
}
template <typename T>
Own<PromiseNode>&& maybeChain(Own<PromiseNode>&& node, T*) {
return kj::mv(node);
}
class CrossThreadPromiseNodeBase: public PromiseNode, private EventLoop::Event {
// A PromiseNode that safely imports a promised value from one EventLoop to another (which
// implies crossing threads).
......@@ -1048,9 +1077,13 @@ template <typename T>
Promise<T>::Promise(_::FixVoid<T> value)
: PromiseBase(heap<_::ImmediatePromiseNode<_::FixVoid<T>>>(kj::mv(value))) {}
template <typename T>
Promise<T>::Promise(kj::Exception&& exception)
: PromiseBase(heap<_::ImmediateBrokenPromiseNode>(kj::mv(exception))) {}
template <typename T>
template <typename Func, typename ErrorFunc>
auto Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) -> PromiseForResult<Func, T> {
PromiseForResult<Func, T> Promise<T>::then(Func&& func, ErrorFunc&& errorHandler) {
return EventLoop::current().thereImpl(
kj::mv(*this), kj::fwd<Func>(func), kj::fwd<ErrorFunc>(errorHandler),
EventLoop::Event::PREEMPT);
......@@ -1130,7 +1163,25 @@ Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams) {
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller() {
auto wrapper = heap<_::WeakFulfiller<T>>();
Promise<T> promise = newAdaptedPromise<T, _::PromiseAndFulfillerAdapter<T>>(*wrapper);
Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise(
_::maybeChain(kj::mv(intermediate), implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
}
template <typename T>
PromiseFulfillerPair<T> newPromiseAndFulfiller(const EventLoop& loop) {
auto wrapper = heap<_::WeakFulfiller<T>>();
Own<_::PromiseNode> intermediate(
heap<_::AdapterPromiseNode<_::FixVoid<T>, _::PromiseAndFulfillerAdapter<T>>>(*wrapper));
Promise<_::JoinPromises<T>> promise(
_::maybeChain(kj::mv(intermediate), loop, EventLoop::Event::YIELD,
implicitCast<T*>(nullptr)));
return PromiseFulfillerPair<T> { kj::mv(promise), kj::mv(wrapper) };
}
......
......@@ -342,6 +342,10 @@ T refIfLvalue(T&&);
template<typename T> constexpr T&& mv(T& t) noexcept { return static_cast<T&&>(t); }
template<typename T> constexpr T&& fwd(NoInfer<T>& t) noexcept { return static_cast<T&&>(t); }
template<typename T> constexpr T cp(T& t) noexcept { return t; }
template<typename T> constexpr T cp(const T& t) noexcept { return t; }
// Useful to force a copy, particularly to pass into a function that expects T&&.
template <typename T, typename U>
inline constexpr auto min(T&& a, U&& b) -> decltype(a < b ? a : b) { return a < b ? a : b; }
template <typename T, typename U>
......
......@@ -207,7 +207,7 @@ Exception::Exception(const Exception& other) noexcept
memcpy(trace, other.trace, sizeof(trace[0]) * traceCount);
KJ_IF_MAYBE(c, other.context) {
context = heap(*c);
context = heap(**c);
}
}
......@@ -216,7 +216,7 @@ Exception::~Exception() noexcept {}
Exception::Context::Context(const Context& other) noexcept
: file(other.file), line(other.line), description(str(other.description)) {
KJ_IF_MAYBE(n, other.next) {
next = heap(*n);
next = heap(**n);
}
}
......
......@@ -97,7 +97,7 @@ public:
inline Maybe<const Context&> getContext() const {
KJ_IF_MAYBE(c, context) {
return *c;
return **c;
} else {
return nullptr;
}
......
......@@ -168,11 +168,27 @@ private:
namespace _ { // private
template <typename T>
Own<T>&& readMaybe(Maybe<Own<T>>&& maybe) { return kj::mv(maybe.ptr); }
class OwnOwn {
public:
inline OwnOwn(Own<T>&& value) noexcept: value(kj::mv(value)) {}
inline Own<T>& operator*() { return value; }
inline const Own<T>& operator*() const { return value; }
inline Own<T>* operator->() { return &value; }
inline const Own<T>* operator->() const { return &value; }
inline operator Own<T>*() { return value ? &value : nullptr; }
inline operator const Own<T>*() const { return value ? &value : nullptr; }
private:
Own<T> value;
};
template <typename T>
OwnOwn<T> readMaybe(Maybe<Own<T>>&& maybe) { return OwnOwn<T>(kj::mv(maybe.ptr)); }
template <typename T>
T* readMaybe(Maybe<Own<T>>& maybe) { return maybe.ptr; }
Own<T>* readMaybe(Maybe<Own<T>>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; }
template <typename T>
const T* readMaybe(const Maybe<Own<T>>& maybe) { return maybe.ptr; }
const Own<T>* readMaybe(const Maybe<Own<T>>& maybe) { return maybe.ptr ? &maybe.ptr : nullptr; }
} // namespace _ (private)
......@@ -223,11 +239,11 @@ private:
template <typename U>
friend class Maybe;
template <typename U>
friend Own<U>&& _::readMaybe(Maybe<Own<U>>&& maybe);
friend _::OwnOwn<U> _::readMaybe(Maybe<Own<U>>&& maybe);
template <typename U>
friend U* _::readMaybe(Maybe<Own<U>>& maybe);
friend Own<U>* _::readMaybe(Maybe<Own<U>>& maybe);
template <typename U>
friend const U* _::readMaybe(const Maybe<Own<U>>& maybe);
friend const Own<U>* _::readMaybe(const Maybe<Own<U>>& maybe);
};
namespace _ { // private
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "one-of.h"
#include "string.h"
#include <gtest/gtest.h>
namespace kj {
TEST(OneOf, Basic) {
OneOf<int, float, String> var;
EXPECT_FALSE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_FALSE(var.is<String>());
var.init<int>(123);
EXPECT_TRUE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_FALSE(var.is<String>());
EXPECT_EQ(123, var.get<int>());
#if !KJ_NO_EXCEPTIONS
EXPECT_ANY_THROW(var.get<float>());
EXPECT_ANY_THROW(var.get<String>());
#endif
var.init<String>(kj::str("foo"));
EXPECT_FALSE(var.is<int>());
EXPECT_FALSE(var.is<float>());
EXPECT_TRUE(var.is<String>());
EXPECT_EQ("foo", var.get<String>());
OneOf<int, float, String> var2 = kj::mv(var);
EXPECT_EQ("", var.get<String>());
EXPECT_EQ("foo", var2.get<String>());
var = kj::mv(var2);
EXPECT_EQ("foo", var.get<String>());
EXPECT_EQ("", var2.get<String>());
}
TEST(OneOf, Copy) {
OneOf<int, float, const char*> var;
OneOf<int, float, const char*> var2 = var;
EXPECT_FALSE(var2.is<int>());
EXPECT_FALSE(var2.is<float>());
EXPECT_FALSE(var2.is<const char*>());
var.init<int>(123);
var2 = var;
EXPECT_TRUE(var2.is<int>());
EXPECT_EQ(123, var2.get<int>());
var.init<const char*>("foo");
var2 = var;
EXPECT_TRUE(var2.is<const char*>());
EXPECT_STREQ("foo", var2.get<const char*>());
}
} // namespace kj
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef KJ_ONE_OF_H_
#define KJ_ONE_OF_H_
#include "common.h"
namespace kj {
namespace _ { // private
template <uint i, typename Key, typename First, typename... Rest>
struct TypeIndex_ { static constexpr uint value = TypeIndex_<i + 1, Key, Rest...>::value; };
template <uint i, typename Key, typename... Rest>
struct TypeIndex_<i, Key, Key, Rest...> { static constexpr uint value = i; };
} // namespace _ (private)
template <typename... Variants>
class OneOf {
template <typename Key>
static inline constexpr uint typeIndex() { return _::TypeIndex_<1, Key, Variants...>::value; }
// Get the 1-based index of Key within the type list Types.
public:
inline OneOf(): tag(0) {}
OneOf(const OneOf& other) { copyFrom(other); }
OneOf(OneOf&& other) { moveFrom(other); }
~OneOf() { destroy(); }
OneOf& operator=(const OneOf& other) { if (tag != 0) destroy(); copyFrom(other); return *this; }
OneOf& operator=(OneOf&& other) { if (tag != 0) destroy(); moveFrom(other); return *this; }
template <typename T>
bool is() const {
return tag == typeIndex<T>();
}
template <typename T>
T& get() {
KJ_IREQUIRE(is<T>(), "Must check OneOf::is<T>() before calling get<T>().");
return *reinterpret_cast<T*>(space);
}
template <typename T>
const T& get() const {
KJ_IREQUIRE(is<T>(), "Must check OneOf::is<T>() before calling get<T>().");
return *reinterpret_cast<const T*>(space);
}
template <typename T, typename... Params>
void init(Params&&... params) {
if (tag != 0) destroy();
ctor(*reinterpret_cast<T*>(space), kj::fwd<Params>(params)...);
tag = typeIndex<T>();
}
private:
uint tag;
static inline constexpr size_t maxSize(size_t a) {
return a;
}
template <typename... Rest>
static inline constexpr size_t maxSize(size_t a, size_t b, Rest... rest) {
return maxSize(kj::max(a, b), rest...);
}
// Returns the maximum of all the parameters.
// TODO(someday): Generalize the above template and make it common. I tried, but C++ decided to
// be difficult so I cut my losses.
union {
byte space[maxSize(sizeof(Variants)...)];
void* forceAligned;
// TODO(someday): Use C++11 alignas() once we require GCC 4.8 / Clang 3.3.
};
template <typename... T>
inline void doAll(T... t) {}
template <typename T>
KJ_ALWAYS_INLINE(bool destroyVariant()) {
if (tag == typeIndex<T>()) {
tag = 0;
dtor(*reinterpret_cast<T*>(space));
}
return false;
}
void destroy() {
doAll(destroyVariant<Variants>()...);
}
template <typename T>
KJ_ALWAYS_INLINE(bool copyVariantFrom(const OneOf& other)) {
if (other.is<T>()) {
ctor(*reinterpret_cast<T*>(space), other.get<T>());
tag = typeIndex<T>();
}
return false;
}
void copyFrom(const OneOf& other) {
// Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag
// is invalid.
doAll(copyVariantFrom<Variants>(other)...);
}
template <typename T>
KJ_ALWAYS_INLINE(bool moveVariantFrom(OneOf& other)) {
if (other.is<T>()) {
ctor(*reinterpret_cast<T*>(space), kj::mv(other.get<T>()));
tag = typeIndex<T>();
}
return false;
}
void moveFrom(OneOf& other) {
// Initialize as a copy of `other`. Expects that `this` starts out uninitialized, so the tag
// is invalid.
doAll(moveVariantFrom<Variants>(other)...);
}
};
} // namespace kj
#endif // KJ_ONE_OF_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment