Commit 7ceed92d authored by Kenton Varda's avatar Kenton Varda

More WIP.

parent 2e17fd43
......@@ -107,16 +107,22 @@ namespace {
class DummyClientHook final: public ClientHook {
public:
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId) const override {
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
kj::Promise<void> whenResolved() const override {
return kj::READY_NOW;
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
KJ_FAIL_REQUIRE("Calling capability that was extracted from a message that had no "
"capability context.");
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
kj::Own<ClientHook> addRef() const override {
kj::Own<const ClientHook> addRef() const override {
return kj::heap<DummyClientHook>();
}
......
......@@ -22,7 +22,10 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "capability.h"
#include "message.h"
#include <kj/refcount.h>
#include <kj/debug.h>
#include <kj/vector.h>
namespace capnp {
......@@ -63,4 +66,248 @@ TypelessResults::Pipeline TypelessResults::Pipeline::getPointerField(
return Pipeline(hook->addRef(), kj::mv(newOps));
}
ResponseHook::~ResponseHook() noexcept(false) {}
// =======================================================================================
namespace {
class LocalResponse final: public ResponseHook {
public:
LocalResponse(uint sizeHint)
: message(sizeHint == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : sizeHint) {}
MallocMessageBuilder message;
};
class LocalCallContext final: public CallContextHook {
public:
LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<const ClientHook> clientRef)
: request(kj::mv(request)), clientRef(kj::mv(clientRef)) {}
ObjectPointer::Reader getParams() override {
return request->getRoot<ObjectPointer>();
}
void releaseParams() override {
request = nullptr;
}
ObjectPointer::Builder getResults(uint firstSegmentWordSize) override {
if (!response) {
response = kj::heap<LocalResponse>(firstSegmentWordSize);
}
return response->message.getRoot<ObjectPointer>();
}
void allowAsyncCancellation(bool allow) override {
// ignored for local calls
}
bool isCanceled() override {
return false;
}
kj::Own<MallocMessageBuilder> request;
kj::Own<LocalResponse> response;
kj::Own<const ClientHook> clientRef;
};
class LocalPipelinedClient final: public ClientHook, public kj::Refcounted {
public:
LocalPipelinedClient(kj::Promise<kj::Own<const ClientHook>> promise)
: innerPromise(promise.then([this](kj::Own<const ClientHook>&& resolution) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(*lock);
for (auto& call: oldState.pending) {
call.fulfiller->fulfill(resolution->call(
call.interfaceId, call.methodId, call.context).promise);
}
for (auto& notify: oldState.notifyOnResolution) {
notify->fulfill(resolution->addRef());
}
lock->resolution = kj::mv(resolution);
}, [this](kj::Exception&& exception) {
auto lock = state.lockExclusive();
auto oldState = kj::mv(*lock);
for (auto& call: oldState.pending) {
call.fulfiller->reject(kj::Exception(exception));
}
for (auto& notify: oldState.notifyOnResolution) {
notify->reject(kj::Exception(exception));
}
lock->exception = kj::mv(exception);
})) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
return r->newCall(interfaceId, methodId, firstSegmentWordSize);
} else {
// TODO(now)
}
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
return r->call(interfaceId, methodId, context);
} else {
lock->pending.add(PendingCall { interfaceId, methodId, context });
}
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
auto lock = state.lockExclusive();
KJ_IF_MAYBE(r, lock->resolution) {
// Already resolved.
return kj::Promise<kj::Own<const ClientHook>>(r->addRef());
} else {
auto pair = kj::newPromiseAndFulfiller<kj::Own<const ClientHook>>();
lock->notifyOnResolution.add(kj::mv(pair.fulfiller));
return kj::mv(pair.promise);
}
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
return nullptr;
}
private:
struct PendingCall {
uint64_t interfaceId;
uint16_t methodId;
CallContext<ObjectPointer, ObjectPointer> context;
kj::Own<kj::PromiseFulfiller<void>> fulfiller;
};
struct State {
kj::Maybe<kj::Own<const ClientHook>> resolution;
kj::Vector<PendingCall> pending;
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::MutexGuarded<State> state;
kj::Promise<void> innerPromise;
};
class LocalPipeline final: public PipelineHook, public kj::Refcounted {
public:
kj::Own<const PipelineHook> addRef() const override {
return kj::addRef(*this);
}
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const override {
}
private:
struct Waiter {
};
struct State {
kj::Vector<kj::Own<kj::PromiseFulfiller<kj::Own<const ClientHook>>>> notifyOnResolution;
};
kj::MutexGuarded<State> state;
};
class LocalRequest final: public RequestHook {
public:
inline LocalRequest(kj::EventLoop& eventLoop, const Capability::Server* server,
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize,
kj::Own<const ClientHook> clientRef)
: message(kj::heap<MallocMessageBuilder>(
firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize)),
eventLoop(eventLoop), server(server), interfaceId(interfaceId), methodId(methodId),
clientRef(kj::mv(clientRef)) {}
RemotePromise<TypelessResults> send() override {
// For the lambda capture.
// We can const-cast the server pointer because we are synchronizing to its event loop here.
Capability::Server* server = const_cast<Capability::Server*>(this->server);
uint64_t interfaceId = this->interfaceId;
uint16_t methodId = this->methodId;
auto context = kj::heap<LocalCallContext>(kj::mv(message), kj::mv(clientRef));
auto promise = eventLoop.evalLater(
kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return server->dispatchCall(interfaceId, methodId,
CallContext<ObjectPointer, ObjectPointer>(*context))
.then(kj::mvCapture(context, [=](kj::Own<LocalCallContext> context) {
return Response<TypelessResults>(context->getResults(1).asReader(),
kj::mv(context->response));
}));
}));
return RemotePromise<TypelessResults>(
kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()));
}
kj::Own<MallocMessageBuilder> message;
private:
kj::EventLoop& eventLoop;
const Capability::Server* server;
uint64_t interfaceId;
uint16_t methodId;
kj::Own<const ClientHook> clientRef;
};
class LocalClient final: public ClientHook, public kj::Refcounted {
public:
LocalClient(kj::EventLoop& eventLoop, kj::Own<Capability::Server>&& server)
: eventLoop(eventLoop), server(kj::mv(server)) {}
Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const override {
auto hook = kj::heap<LocalRequest>(
eventLoop, server, interfaceId, methodId, firstSegmentWordSize, kj::addRef(*this));
return Request<ObjectPointer, TypelessResults>(
hook->message->getRoot<ObjectPointer>(), kj::mv(hook));
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const override {
// We can const-cast the server because we're synchronizing on the event loop.
auto server = const_cast<Capability::Server*>(this->server.get());
auto promise = eventLoop.evalLater([=]() mutable {
return server->dispatchCall(interfaceId, methodId, context);
});
return VoidPromiseAndPipeline { kj::mv(promise),
TypelessResults::Pipeline(kj::heap<LocalPipeline>()) };
}
kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const override {
return nullptr;
}
kj::Own<const ClientHook> addRef() const override {
return kj::addRef(*this);
}
void* getBrand() const override {
// We have no need to detect local objects.
return nullptr;
}
private:
kj::EventLoop& eventLoop;
kj::Own<Capability::Server> server;
};
} // namespace
kj::Own<const ClientHook> makeLocalClient(kj::Own<Capability::Server>&& server,
kj::EventLoop& eventLoop) {
return kj::refcounted<LocalClient>(eventLoop, kj::mv(server));
}
} // namespace capnp
......@@ -110,7 +110,7 @@ class Capability::Client {
// Base type for capability clients.
public:
explicit Client(kj::Own<ClientHook>&& hook);
explicit Client(kj::Own<const ClientHook>&& hook);
Client(const Client& other);
Client& operator=(const Client& other);
......@@ -131,13 +131,14 @@ public:
// TODO(soon): method(s) for Join
private:
kj::Own<ClientHook> hook;
kj::Own<const ClientHook> hook;
protected:
Client() = default;
template <typename Params, typename Results>
Request<Params, Results> newCall(uint64_t interfaceId, uint16_t methodId);
Request<Params, Results> newCall(uint64_t interfaceId, uint16_t methodId,
uint firstSegmentWordSize);
};
// =======================================================================================
......@@ -164,17 +165,19 @@ public:
// requests. Long-running asynchronous methods should try to call this as early as is
// convenient.
typename Results::Builder getResults();
typename Results::Builder initResults();
typename Results::Builder initResults(uint size);
typename Results::Builder getResults(uint firstSegmentWordSize = 0);
typename Results::Builder initResults(uint firstSegmentWordSize = 0);
void setResults(typename Results::Reader value);
void adoptResults(Orphan<Results>&& value);
Orphanage getResultsOrphanage();
Orphanage getResultsOrphanage(uint firstSegmentWordSize = 0);
// Manipulate the results payload. The "Return" message (part of the RPC protocol) will
// typically be allocated the first time one of these is called. Some RPC systems may
// allocate these messages in a limited space (such as a shared memory segment), therefore the
// application should delay calling these as long as is convenient to do so (but don't delay
// if doing so would require extra copies later).
//
// `firstSegmentWordSize` indicates the suggested size of the message's first segment. This
// is a hint only. If not specified, the system will decide on its own.
void allowAsyncCancellation(bool allow = true);
// Indicate that it is OK for the RPC system to discard its Promise for this call's result if
......@@ -189,6 +192,12 @@ public:
// Keep in mind that asynchronous cancellation cannot occur while the method is synchronously
// executing on a local thread. The method must perform an asynchronous operation or call
// `EventLoop::current().runLater()` to yield control.
//
// TODO(soon): This doesn't work for local calls, because there's no one to own the object
// in the meantime. What do we do about that? Is the security issue here actually a real
// threat? Maybe we can just always enable cancellation. After all, you need to be fault
// tolerant and exception-safe, and those are pretty similar to being cancel-tolerant, though
// with less direct control by the attacker...
bool isCanceled();
// As an alternative to `allowAsyncCancellation()`, a server can call this to check for
......@@ -230,8 +239,10 @@ protected:
uint64_t typeId, uint16_t methodId);
};
Capability::Client makeLocalClient(kj::Own<Capability::Server>&& server);
// Make a client capability that wraps the given server capability.
kj::Own<const ClientHook> makeLocalClient(kj::Own<Capability::Server>&& server,
kj::EventLoop& eventLoop = kj::EventLoop::current());
// Make a client capability that wraps the given server capability. The server's methods will
// only be executed in the given EventLoop, regardless of what thread calls the client's methods.
// =======================================================================================
......@@ -257,7 +268,7 @@ struct TypelessResults {
class Pipeline {
public:
inline explicit Pipeline(kj::Own<PipelineHook>&& hook): hook(kj::mv(hook)) {}
inline explicit Pipeline(kj::Own<const PipelineHook>&& hook): hook(kj::mv(hook)) {}
Pipeline getPointerField(uint16_t pointerIndex) const;
// Return a new Promise representing a sub-object of the result. `pointerIndex` is the index
......@@ -270,10 +281,10 @@ struct TypelessResults {
// Expect that the result is a capability and construct a pipelined version of it now.
private:
kj::Own<PipelineHook> hook;
kj::Own<const PipelineHook> hook;
kj::Array<PipelineOp> ops;
inline Pipeline(kj::Own<PipelineHook>&& hook, kj::Array<PipelineOp>&& ops)
inline Pipeline(kj::Own<const PipelineHook>&& hook, kj::Array<PipelineOp>&& ops)
: hook(kj::mv(hook)), ops(kj::mv(ops)) {}
};
};
......@@ -286,9 +297,6 @@ class RequestHook {
// Hook interface implemented by RPC system representing a request being built.
public:
virtual ObjectPointer::Builder getRequest() = 0;
// Get the request object for this call, to be filled in before sending.
virtual RemotePromise<TypelessResults> send() = 0;
// Send the call and return a promise for the result.
};
......@@ -300,26 +308,50 @@ class ResponseHook {
// ResponseHook is destroyed, the results can be freed.
public:
virtual ~ResponseHook() noexcept(false);
// Just here to make sure the type is dynamic.
};
class PipelineHook {
// Represents a currently-running call, and implements pipelined requests on its result.
public:
virtual kj::Own<PipelineHook> addRef() const = 0;
virtual kj::Own<const PipelineHook> addRef() const = 0;
// Increment this object's reference count.
virtual kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const = 0;
virtual kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) const = 0;
// Extract a promised Capability from the results.
};
class ClientHook {
public:
virtual Request<ObjectPointer, TypelessResults> newCall(
uint64_t interfaceId, uint16_t methodId) const = 0;
virtual kj::Promise<void> whenResolved() const = 0;
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) const = 0;
// Start a new call, allowing the client to allocate request/response objects as it sees fit.
// This version is used when calls are made from application code in the local process.
virtual kj::Own<ClientHook> addRef() const = 0;
struct VoidPromiseAndPipeline {
kj::Promise<void> promise;
TypelessResults::Pipeline pipeline;
};
virtual VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
CallContext<ObjectPointer, ObjectPointer> context) const = 0;
// Call the object, but the caller controls allocation of the request/response objects. If the
// callee insists on allocating this objects itself, it must make a copy. This version is used
// when calls come in over the network via an RPC system. During the call, the context object
// may be used from any thread so long as it is only used from one thread at a time. Once the
// returned promise resolves or has been canceled, the context can no longer be used. The caller
// must not allow the ClientHook to be destroyed until the call completes or is canceled.
//
// The call must not begin synchronously, as the caller may hold arbitrary mutexes.
virtual kj::Maybe<kj::Promise<kj::Own<const ClientHook>>> whenMoreResolved() const = 0;
// If this client is a settled reference (not a promise), return nullptr. Otherwise, return a
// promise that eventually resolves to a new client that is closer to being the final, settled
// client. Calling this repeatedly should eventually produce a settled client.
virtual kj::Own<const ClientHook> addRef() const = 0;
// Return a new reference to the same capability.
virtual void* getBrand() const = 0;
......@@ -335,7 +367,7 @@ class CallContextHook {
public:
virtual ObjectPointer::Reader getParams() = 0;
virtual void releaseParams() = 0;
virtual ObjectPointer::Builder getResults() = 0;
virtual ObjectPointer::Builder getResults(uint firstSegmentWordSize) = 0;
virtual void allowAsyncCancellation(bool allow) = 0;
virtual bool isCanceled() = 0;
};
......@@ -366,7 +398,7 @@ inline Capability::Client TypelessResults::Pipeline::asCap() const {
return Capability::Client(hook->getPipelinedCap(ops));
}
inline Capability::Client::Client(kj::Own<ClientHook>&& hook): hook(kj::mv(hook)) {}
inline Capability::Client::Client(kj::Own<const ClientHook>&& hook): hook(kj::mv(hook)) {}
inline Capability::Client::Client(const Client& other): hook(other.hook->addRef()) {}
inline Capability::Client& Capability::Client::operator=(const Client& other) {
hook = other.hook->addRef();
......@@ -377,8 +409,8 @@ inline kj::Promise<void> Capability::Client::whenResolved() const {
}
template <typename Params, typename Results>
inline Request<Params, Results> Capability::Client::newCall(
uint64_t interfaceId, uint16_t methodId) {
auto typeless = hook->newCall(interfaceId, methodId);
uint64_t interfaceId, uint16_t methodId, uint firstSegmentWordSize) {
auto typeless = hook->newCall(interfaceId, methodId, firstSegmentWordSize);
return Request<Params, Results>(typeless.template getAs<Params>(), kj::mv(typeless.hook));
}
......@@ -393,31 +425,28 @@ inline void CallContext<Params, Results>::releaseParams() {
hook->releaseParams();
}
template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::getResults() {
// `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401
return hook->getResults().template getAs<Results>();
}
template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::initResults() {
inline typename Results::Builder CallContext<Params, Results>::getResults(
uint firstSegmentWordSize) {
// `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401
return hook->getResults().template initAs<Results>();
return hook->getResults(firstSegmentWordSize).template getAs<Results>();
}
template <typename Params, typename Results>
inline typename Results::Builder CallContext<Params, Results>::initResults(uint size) {
inline typename Results::Builder CallContext<Params, Results>::initResults(
uint firstSegmentWordSize) {
// `template` keyword needed due to: http://llvm.org/bugs/show_bug.cgi?id=17401
return hook->getResults().template initAs<Results>(size);
return hook->getResults(firstSegmentWordSize).template initAs<Results>();
}
template <typename Params, typename Results>
inline void CallContext<Params, Results>::setResults(typename Results::Reader value) {
hook->getResults().set(value);
hook->getResults(value.totalSizeInWords() + 1).set(value);
}
template <typename Params, typename Results>
inline void CallContext<Params, Results>::adoptResults(Orphan<Results>&& value) {
hook->getResults().adopt(kj::mv(value));
hook->getResults(0).adopt(kj::mv(value));
}
template <typename Params, typename Results>
inline Orphanage CallContext<Params, Results>::getResultsOrphanage() {
return Orphanage::getForMessageContaining(hook->getResults());
inline Orphanage CallContext<Params, Results>::getResultsOrphanage(uint firstSegmentWordSize) {
return Orphanage::getForMessageContaining(hook->getResults(firstSegmentWordSize));
}
template <typename Params, typename Results>
inline void CallContext<Params, Results>::allowAsyncCancellation(bool allow) {
......
......@@ -1148,7 +1148,8 @@ private:
return MethodText {
kj::strTree(
" ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request();\n"),
" ::capnp::Request<", paramType, ", ", resultType, "> ", name, "Request(\n"
" uint firstSegmentWordSize = 0);\n"),
kj::strTree(
" virtual ::kj::Promise<void> ", name, "(\n"
......@@ -1161,9 +1162,9 @@ private:
kj::strTree(
"::capnp::Request<", paramType, ", ", resultType, ">\n",
interfaceName, "::Client::", name, "Request() {\n"
interfaceName, "::Client::", name, "Request(uint firstSegmentWordSize) {\n"
" return newCall<", paramType, ", ", resultType, ">(\n"
" 0x", interfaceIdHex, "ull, ", methodId, ");\n"
" 0x", interfaceIdHex, "ull, ", methodId, ", firstSegmentWordSize);\n"
"}\n"
"::kj::Promise<void> ", interfaceName, "::Server::", name, "(\n"
" ", paramType, "::Reader params,\n"
......@@ -1231,7 +1232,7 @@ private:
return kj::strTree(",\n public virtual ", e.typeName, "::Client");
}, " {\n"
"public:\n"
" inline Client(::kj::Own< ::capnp::ClientHook>&& hook)\n"
" inline Client(::kj::Own<const ::capnp::ClientHook>&& hook)\n"
" : ::capnp::Capability::Client(::kj::mv(hook)) {}\n"
"\n",
KJ_MAP(m, methods) { return kj::mv(m.clientDecls); },
......
......@@ -1529,40 +1529,6 @@ Void DynamicValue::Builder::AsImpl<Void>::apply(Builder& builder) {
// =======================================================================================
template <>
DynamicStruct::Reader MessageReader::getRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Reader(schema, getRootInternal());
}
template <>
DynamicStruct::Builder MessageBuilder::initRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Builder(schema, initRoot(structSizeFromSchema(schema)));
}
template <>
DynamicStruct::Builder MessageBuilder::getRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Builder(schema, getRoot(structSizeFromSchema(schema)));
}
template <>
void MessageBuilder::setRoot<DynamicStruct::Reader>(DynamicStruct::Reader&& value) {
setRootInternal(value.reader);
}
template <>
void MessageBuilder::setRoot<const DynamicStruct::Reader&>(const DynamicStruct::Reader& value) {
setRootInternal(value.reader);
}
template <>
void MessageBuilder::setRoot<DynamicStruct::Reader&>(DynamicStruct::Reader& value) {
setRootInternal(value.reader);
}
namespace _ { // private
DynamicStruct::Reader PointerHelpers<DynamicStruct, Kind::UNKNOWN>::getDynamic(
......
......@@ -771,27 +771,6 @@ template <>
Orphan<DynamicValue> Orphanage::newOrphanCopy<DynamicValue::Reader>(
const DynamicValue::Reader& copyFrom) const;
// -------------------------------------------------------------------
// Inject the ability to use DynamicStruct for message roots and Dynamic{Struct,List} for
// generated Object accessors.
template <>
DynamicStruct::Reader MessageReader::getRoot<DynamicStruct>(StructSchema schema);
template <>
DynamicStruct::Builder MessageBuilder::initRoot<DynamicStruct>(StructSchema schema);
template <>
DynamicStruct::Builder MessageBuilder::getRoot<DynamicStruct>(StructSchema schema);
template <>
void MessageBuilder::setRoot<DynamicStruct::Reader>(DynamicStruct::Reader&& value);
template <>
void MessageBuilder::setRoot<const DynamicStruct::Reader&>(const DynamicStruct::Reader& value);
template <>
void MessageBuilder::setRoot<DynamicStruct::Reader&>(DynamicStruct::Reader& value);
template <>
inline void MessageBuilder::adoptRoot<DynamicStruct>(Orphan<DynamicStruct>&& orphan) {
adoptRootInternal(kj::mv(orphan.builder));
}
namespace _ { // private
template <>
......
......@@ -130,6 +130,7 @@ inline kj::StringTree structString(StructReader reader) {
return structString(reader, rawSchema<T>());
}
// TODO(soon): Unify ConstStruct and ConstList.
template <typename T>
class ConstStruct {
public:
......@@ -138,7 +139,7 @@ public:
inline explicit constexpr ConstStruct(const word* ptr): ptr(ptr) {}
inline typename T::Reader get() const {
return typename T::Reader(StructReader::readRootUnchecked(ptr));
return ObjectPointer::Reader(PointerReader::getRootUnchecked(ptr)).getAs<T>();
}
inline operator typename T::Reader() const { return get(); }
......@@ -157,8 +158,7 @@ public:
inline explicit constexpr ConstList(const word* ptr): ptr(ptr) {}
inline typename List<T>::Reader get() const {
return typename List<T>::Reader(ListReader::readRootUnchecked(
ptr, elementSizeForType<T>()));
return ObjectPointer::Reader(PointerReader::getRootUnchecked(ptr)).getAs<List<T>>();
}
inline operator typename List<T>::Reader() const { return get(); }
......
......@@ -46,7 +46,7 @@ TEST(WireFormat, SimpleRawDataStruct) {
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef
}};
StructReader reader = StructReader::readRootUnchecked(data.words);
StructReader reader = PointerReader::getRootUnchecked(data.words).getStruct(nullptr);
EXPECT_EQ(0xefcdab8967452301ull, reader.getDataField<uint64_t>(0 * ELEMENTS));
EXPECT_EQ(0u, reader.getDataField<uint64_t>(1 * ELEMENTS));
......@@ -284,8 +284,8 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words;
StructBuilder builder = StructBuilder::initRoot(
segment, rootLocation, StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
StructBuilder builder = PointerBuilder::getRoot(segment, rootLocation)
.initStruct(StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
setupStruct(builder);
// word count:
......@@ -310,8 +310,8 @@ TEST(WireFormat, StructRoundTrip_OneSegment) {
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRootUnchecked(segment->getStartPtr()));
checkStruct(StructReader::readRoot(segment->getStartPtr(), segment, 4));
checkStruct(PointerReader::getRootUnchecked(segment->getStartPtr()).getStruct(nullptr));
checkStruct(PointerReader::getRoot(segment, segment->getStartPtr(), 4).getStruct(nullptr));
}
TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
......@@ -321,8 +321,8 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words;
StructBuilder builder = StructBuilder::initRoot(
segment, rootLocation, StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
StructBuilder builder = PointerBuilder::getRoot(segment, rootLocation)
.initStruct(StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
setupStruct(builder);
// Verify that we made 15 segments.
......@@ -349,7 +349,7 @@ TEST(WireFormat, StructRoundTrip_OneSegmentPerAllocation) {
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRoot(segment->getStartPtr(), segment, 4));
checkStruct(PointerReader::getRoot(segment, segment->getStartPtr(), 4).getStruct(nullptr));
}
TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
......@@ -359,8 +359,8 @@ TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
SegmentBuilder* segment = allocation.segment;
word* rootLocation = allocation.words;
StructBuilder builder = StructBuilder::initRoot(
segment, rootLocation, StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
StructBuilder builder = PointerBuilder::getRoot(segment, rootLocation)
.initStruct(StructSize(2 * WORDS, 4 * POINTERS, FieldSize::INLINE_COMPOSITE));
setupStruct(builder);
// Verify that we made 6 segments.
......@@ -378,7 +378,7 @@ TEST(WireFormat, StructRoundTrip_MultipleSegmentsWithMultipleAllocations) {
checkStruct(builder);
checkStruct(builder.asReader());
checkStruct(StructReader::readRoot(segment->getStartPtr(), segment, 4));
checkStruct(PointerReader::getRoot(segment, segment->getStartPtr(), 4).getStruct(nullptr));
}
} // namespace
......
......@@ -2115,6 +2115,16 @@ PointerBuilder PointerBuilder::imbue(ImbuedBuilderArena& newArena) const {
// =======================================================================================
// PointerReader
PointerReader PointerReader::getRoot(SegmentReader* segment, const word* location,
int nestingLimit) {
KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
return PointerReader(segment, reinterpret_cast<const WirePointer*>(location), nestingLimit);
}
StructReader PointerReader::getStruct(const word* defaultValue) const {
const WirePointer* ref = pointer == nullptr ? &zero.pointer : pointer;
return WireHelpers::readStructPointer(segment, ref, defaultValue, nestingLimit);
......@@ -2158,26 +2168,6 @@ PointerReader PointerReader::imbue(ImbuedReaderArena& newArena) const {
// =======================================================================================
// StructBuilder
StructBuilder StructBuilder::initRoot(
SegmentBuilder* segment, word* location, StructSize size) {
return WireHelpers::initStructPointer(
reinterpret_cast<WirePointer*>(location), segment, size);
}
void StructBuilder::setRoot(SegmentBuilder* segment, word* location, StructReader value) {
WireHelpers::setStructPointer(segment, reinterpret_cast<WirePointer*>(location), value);
}
StructBuilder StructBuilder::getRoot(
SegmentBuilder* segment, word* location, StructSize size) {
return WireHelpers::getWritableStructPointer(
reinterpret_cast<WirePointer*>(location), segment, size, nullptr);
}
void StructBuilder::adoptRoot(SegmentBuilder* segment, word* location, OrphanBuilder orphan) {
WireHelpers::adopt(segment, reinterpret_cast<WirePointer*>(location), kj::mv(orphan));
}
void StructBuilder::clearAll() {
if (dataSize == 1 * BITS) {
setDataField<bool>(1 * ELEMENTS, false);
......@@ -2279,22 +2269,6 @@ BuilderArena* StructBuilder::getArena() {
// =======================================================================================
// StructReader
StructReader StructReader::readRootUnchecked(const word* location) {
return WireHelpers::readStructPointer(nullptr, reinterpret_cast<const WirePointer*>(location),
nullptr, std::numeric_limits<int>::max());
}
StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) {
KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") {
location = nullptr;
}
return WireHelpers::readStructPointer(segment, reinterpret_cast<const WirePointer*>(location),
nullptr, nestingLimit);
}
WordCount64 StructReader::totalSize() const {
WordCount64 result = WireHelpers::roundBitsUpToWords(dataSize) + pointerCount * WORDS_PER_POINTER;
......@@ -2314,11 +2288,6 @@ WordCount64 StructReader::totalSize() const {
// =======================================================================================
// ListBuilder
ListReader ListReader::readRootUnchecked(const word* location, FieldSize elementSize) {
return WireHelpers::readListPointer(nullptr, reinterpret_cast<const WirePointer*>(location),
nullptr, elementSize, std::numeric_limits<int>::max());
}
Text::Builder ListBuilder::asText() {
KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS,
"Expected Text, got list of non-bytes.") {
......
......@@ -281,13 +281,17 @@ class PointerBuilder: public kj::DisallowConstCopy {
public:
inline PointerBuilder(): segment(nullptr), pointer(nullptr) {}
static inline PointerBuilder getRoot(SegmentBuilder* segment, word* location);
// Get a PointerBuilder representing a message root located in the given segment at the given
// location.
bool isNull();
StructBuilder getStruct(StructSize size, const word* defaultValue);
ListBuilder getList(FieldSize elementSize, const word* defaultValzue);
ListBuilder getStructList(StructSize elementSize, const word* defaultValue);
template <typename T> typename T::Builder getBlob(const void* defaultValue,ByteCount defaultSize);
kj::Own<ClientHook> getCapability();
kj::Own<const ClientHook> getCapability();
// Get methods: Get the value. If it is null, initialize it to a copy of the default value.
// The default value is encoded as an "unchecked message" for structs, lists, and objects, or a
// simple byte array for blobs.
......@@ -303,7 +307,7 @@ public:
void setStruct(const StructReader& value);
void setList(const ListReader& value);
template <typename T> void setBlob(typename T::Reader value);
void setCapability(kj::Own<ClientHook>&& cap);
void setCapability(kj::Own<const ClientHook>&& cap);
// Set methods: Initialize the pointer to a newly-allocated copy of the given value, discarding
// the existing object.
......@@ -345,13 +349,20 @@ class PointerReader {
public:
inline PointerReader(): segment(nullptr), pointer(nullptr), nestingLimit(0x7fffffff) {}
static PointerReader getRoot(SegmentReader* segment, const word* location, int nestingLimit);
// Get a PointerReader representing a message root located in the given segment at the given
// location.
static inline PointerReader getRootUnchecked(const word* location);
// Get a PointerReader for an unchecked message.
bool isNull() const;
StructReader getStruct(const word* defaultValue) const;
ListReader getList(FieldSize expectedElementSize, const word* defaultValue) const;
template <typename T>
typename T::Reader getBlob(const void* defaultValue, ByteCount defaultSize) const;
kj::Own<ClientHook> getCapability();
kj::Own<const ClientHook> getCapability();
// Get methods: Get the value. If it is null, return the default value instead.
// The default value is encoded as an "unchecked message" for structs, lists, and objects, or a
// simple byte array for blobs.
......@@ -390,11 +401,6 @@ class StructBuilder: public kj::DisallowConstCopy {
public:
inline StructBuilder(): segment(nullptr), data(nullptr), pointers(nullptr), bit0Offset(0) {}
static StructBuilder initRoot(SegmentBuilder* segment, word* location, StructSize size);
static void setRoot(SegmentBuilder* segment, word* location, StructReader value);
static StructBuilder getRoot(SegmentBuilder* segment, word* location, StructSize size);
static void adoptRoot(SegmentBuilder* segment, word* location, OrphanBuilder orphan);
inline word* getLocation() { return reinterpret_cast<word*>(data); }
// Get the object's location. Only valid for independently-allocated objects (i.e. not list
// elements).
......@@ -482,9 +488,6 @@ public:
: segment(nullptr), data(nullptr), pointers(nullptr), dataSize(0),
pointerCount(0), bit0Offset(0), nestingLimit(0x7fffffff) {}
static StructReader readRootUnchecked(const word* location);
static StructReader readRoot(const word* location, SegmentReader* segment, int nestingLimit);
inline BitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; }
inline Data::Reader getDataSectionAsBlob();
......@@ -633,8 +636,6 @@ public:
: segment(nullptr), ptr(nullptr), elementCount(0), step(0 * BITS / ELEMENTS),
structDataSize(0), structPointerCount(0), nestingLimit(0x7fffffff) {}
static ListReader readRootUnchecked(const word* location, FieldSize elementSize);
inline ElementCount size() const;
// The number of elements in the list.
......@@ -779,6 +780,14 @@ template <> void PointerBuilder::setBlob<Data>(typename Data::Reader value);
template <> typename Data::Builder PointerBuilder::getBlob<Data>(const void* defaultValue, ByteCount defaultSize);
template <> typename Data::Reader PointerReader::getBlob<Data>(const void* defaultValue, ByteCount defaultSize) const;
inline PointerBuilder PointerBuilder::getRoot(SegmentBuilder* segment, word* location) {
return PointerBuilder(segment, reinterpret_cast<WirePointer*>(location));
}
inline PointerReader PointerReader::getRootUnchecked(const word* location) {
return PointerReader(nullptr, reinterpret_cast<const WirePointer*>(location), 0x7fffffff);
}
// -------------------------------------------------------------------
inline Data::Builder StructBuilder::getDataSectionAsBlob() {
......
......@@ -42,7 +42,7 @@ MessageReader::~MessageReader() noexcept(false) {
}
}
_::StructReader MessageReader::getRootInternal() {
ObjectPointer::Reader MessageReader::getRootInternal() {
if (!allocatedArena) {
static_assert(sizeof(_::BasicReaderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BasicReaderArena. Please increase it. This will break "
......@@ -55,10 +55,11 @@ _::StructReader MessageReader::getRootInternal() {
KJ_REQUIRE(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
"Message did not contain a root pointer.") {
return _::StructReader();
return ObjectPointer::Reader();
}
return _::StructReader::readRoot(segment->getStartPtr(), segment, options.nestingLimit);
return ObjectPointer::Reader(_::PointerReader::getRoot(
segment, segment->getStartPtr(), options.nestingLimit));
}
// -------------------------------------------------------------------
......@@ -89,28 +90,10 @@ _::SegmentBuilder* MessageBuilder::getRootSegment() {
}
}
_::StructBuilder MessageBuilder::initRoot(_::StructSize size) {
ObjectPointer::Builder MessageBuilder::getRootInternal() {
_::SegmentBuilder* rootSegment = getRootSegment();
return _::StructBuilder::initRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), size);
}
void MessageBuilder::setRootInternal(_::StructReader reader) {
_::SegmentBuilder* rootSegment = getRootSegment();
_::StructBuilder::setRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), reader);
}
_::StructBuilder MessageBuilder::getRoot(_::StructSize size) {
_::SegmentBuilder* rootSegment = getRootSegment();
return _::StructBuilder::getRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), size);
}
void MessageBuilder::adoptRootInternal(_::OrphanBuilder orphan) {
_::SegmentBuilder* rootSegment = getRootSegment();
_::StructBuilder::adoptRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), kj::mv(orphan));
return ObjectPointer::Builder(_::PointerBuilder::getRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS)));
}
kj::ArrayPtr<const kj::ArrayPtr<const word>> MessageBuilder::getSegmentsForOutput() {
......
......@@ -26,6 +26,7 @@
#include <kj/mutex.h>
#include "common.h"
#include "layout.h"
#include "object.h"
#ifndef CAPNP_MESSAGE_H_
#define CAPNP_MESSAGE_H_
......@@ -111,9 +112,9 @@ public:
typename RootType::Reader getRoot();
// Get the root struct of the message, interpreting it as the given struct type.
template <typename RootType>
typename RootType::Reader getRoot(StructSchema schema);
// Dynamically interpret the root struct of the message using the given schema.
template <typename RootType, typename SchemaType>
typename RootType::Reader getRoot(SchemaType schema);
// Dynamically interpret the root struct of the message using the given schema (a StructSchema).
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this.
......@@ -128,7 +129,7 @@ private:
bool allocatedArena;
_::BasicReaderArena* arena() { return reinterpret_cast<_::BasicReaderArena*>(arenaSpace); }
_::StructReader getRootInternal();
ObjectPointer::Reader getRootInternal();
};
class MessageBuilder {
......@@ -168,15 +169,15 @@ public:
typename RootType::Builder getRoot();
// Get the root struct of the message, interpreting it as the given struct type.
template <typename RootType>
typename RootType::Builder getRoot(StructSchema schema);
// Dynamically interpret the root struct of the message using the given schema.
template <typename RootType, typename SchemaType>
typename RootType::Builder getRoot(SchemaType schema);
// Dynamically interpret the root struct of the message using the given schema (a StructSchema).
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this.
template <typename RootType>
typename RootType::Builder initRoot(StructSchema schema);
// Dynamically init the root struct of the message using the given schema.
template <typename RootType, typename SchemaType>
typename RootType::Builder initRoot(SchemaType schema);
// Dynamically init the root struct of the message using the given schema (a StructSchema).
// RootType in this case must be DynamicStruct, and you must #include <capnp/dynamic.h> to
// use this.
......@@ -204,10 +205,7 @@ private:
_::BasicBuilderArena* arena() { return reinterpret_cast<_::BasicBuilderArena*>(arenaSpace); }
_::SegmentBuilder* getRootSegment();
_::StructBuilder initRoot(_::StructSize size);
void setRootInternal(_::StructReader reader);
_::StructBuilder getRoot(_::StructSize size);
void adoptRootInternal(_::OrphanBuilder orphan);
ObjectPointer::Builder getRootInternal();
};
template <typename RootType>
......@@ -295,7 +293,7 @@ class MallocMessageBuilder: public MessageBuilder {
// a specific location in memory.
public:
explicit MallocMessageBuilder(uint firstSegmentWords = 1024,
explicit MallocMessageBuilder(uint firstSegmentWords = SUGGESTED_FIRST_SEGMENT_WORDS,
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// Creates a BuilderContext which allocates at least the given number of words for the first
// segment, and then uses the given strategy to decide how much to allocate for subsequent
......@@ -364,39 +362,47 @@ inline const ReaderOptions& MessageReader::getOptions() {
template <typename RootType>
inline typename RootType::Reader MessageReader::getRoot() {
static_assert(kind<RootType>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
return typename RootType::Reader(getRootInternal());
return getRootInternal().getAs<RootType>();
}
template <typename RootType>
inline typename RootType::Builder MessageBuilder::initRoot() {
static_assert(kind<RootType>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
return typename RootType::Builder(initRoot(_::structSize<RootType>()));
return getRootInternal().initAs<RootType>();
}
template <typename Reader>
inline void MessageBuilder::setRoot(Reader&& value) {
typedef FromReader<Reader> RootType;
static_assert(kind<RootType>() == Kind::STRUCT,
"Parameter must be a Reader for a Cap'n Proto struct type.");
setRootInternal(value._reader);
getRootInternal().setAs<FromReader<Reader>>(value);
}
template <typename RootType>
inline typename RootType::Builder MessageBuilder::getRoot() {
static_assert(kind<RootType>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
return typename RootType::Builder(getRoot(_::structSize<RootType>()));
return getRootInternal().getAs<RootType>();
}
template <typename T>
void MessageBuilder::adoptRoot(Orphan<T>&& orphan) {
static_assert(kind<T>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
adoptRootInternal(kj::mv(orphan.builder));
return getRootInternal().adopt(kj::mv(orphan));
}
template <typename RootType, typename SchemaType>
typename RootType::Reader MessageReader::getRoot(SchemaType schema) {
return getRootInternal().getAs<RootType>(schema);
}
template <typename RootType, typename SchemaType>
typename RootType::Builder MessageBuilder::getRoot(SchemaType schema) {
return getRootInternal().getAs<RootType>(schema);
}
template <typename RootType, typename SchemaType>
typename RootType::Builder MessageBuilder::initRoot(SchemaType schema) {
return getRootInternal().initAs<RootType>(schema);
}
template <typename RootType>
typename RootType::Reader readMessageUnchecked(const word* data) {
return typename RootType::Reader(_::StructReader::readRootUnchecked(data));
return ObjectPointer::Reader(_::PointerReader::getRootUnchecked(data)).getAs<RootType>();
}
template <typename Reader>
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "object.h"
#include "capability.h"
namespace capnp {
kj::Own<const ClientHook> ObjectPointer::Reader::getPipelinedCap(
kj::ArrayPtr<const PipelineOp> ops) {
_::PointerReader pointer = reader;
for (auto& op: ops) {
switch (op.type) {
case PipelineOp::Type::GET_POINTER_FIELD:
pointer = pointer.getStruct(nullptr).getPointerField(op.pointerIndex * POINTERS);
break;
}
}
return pointer.getCapability();
}
} // namespace capnp
......@@ -33,6 +33,8 @@ namespace capnp {
class StructSchema;
class ListSchema;
class Orphanage;
struct PipelineOp;
class ClientHook;
struct ObjectPointer {
// Reader/Builder for the `Object` field type, i.e. a pointer that can point to an arbitrary
......@@ -59,6 +61,10 @@ struct ObjectPointer {
inline typename T::Reader getAs(ListSchema schema);
// Only valid for T = DynamicList. Requires `#include <capnp/dynamic.h>`.
kj::Own<const ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops);
// Used by RPC system to implement pipelining. Applications generally shouldn't use this
// directly.
private:
_::PointerReader reader;
friend struct ObjectPointer;
......
......@@ -556,6 +556,8 @@ public:
private:
Promise(Own<_::PromiseNode>&& node): PromiseBase(kj::mv(node)) {}
template <typename>
friend class Promise;
friend class EventLoop;
template <typename U, typename Adapter, typename... Params>
friend Promise<U> newAdaptedPromise(Params&&... adapterConstructorParams);
......@@ -565,6 +567,41 @@ constexpr _::Void READY_NOW = _::Void();
// Use this when you need a Promise<void> that is already fulfilled -- this value can be implicitly
// cast to `Promise<void>`.
// -------------------------------------------------------------------
// Hack for creating a lambda that holds an owned pointer.
template <typename Func, typename MovedParam>
class CaptureByMove {
public:
inline CaptureByMove(Func&& func, MovedParam&& param)
: func(kj::mv(func)), param(kj::mv(param)) {}
template <typename... Params>
inline auto operator()(Params&&... params)
-> decltype(kj::instance<Func>()(kj::instance<MovedParam&&>(), kj::fwd<Params>(params)...)) {
return func(kj::mv(param), kj::fwd<Params>(params)...);
}
private:
Func func;
MovedParam param;
};
template <typename Func, typename MovedParam>
inline CaptureByMove<Func, Decay<MovedParam>> mvCapture(MovedParam&& param, Func&& func) {
// Hack to create a "lambda" which captures a variable by moving it rather than copying or
// referencing. C++14 generalized captures should make this obsolete, but for now in C++11 this
// is commonly needed for Promise continuations that own their state. Example usage:
//
// Own<Foo> ptr = makeFoo();
// Promise<int> promise = callRpc();
// promise.then(mvCapture(ptr, [](Own<Foo>&& ptr, int result) {
// return ptr->finish(result);
// }));
return CaptureByMove<Func, Decay<MovedParam>>(kj::fwd<Func>(func), kj::mv(param));
}
// -------------------------------------------------------------------
// Advanced promise construction
......@@ -1086,8 +1123,8 @@ private:
template <typename T, typename Adapter, typename... Params>
Promise<T> newAdaptedPromise(Params&&... adapterConstructorParams) {
return Promise<T>(heap<_::AdapterPromiseNode<_::FixVoid<T>, Adapter>>(
kj::fwd<Params>(adapterConstructorParams)...));
return Promise<T>(Own<_::PromiseNode>(heap<_::AdapterPromiseNode<_::FixVoid<T>, Adapter>>(
kj::fwd<Params>(adapterConstructorParams)...)));
}
template <typename T>
......
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "refcount.h"
#include <gtest/gtest.h>
namespace kj {
struct SetTrueInDestructor: public Refcounted {
SetTrueInDestructor(bool* ptr): ptr(ptr) {}
~SetTrueInDestructor() { *ptr = true; }
bool* ptr;
};
TEST(Refcount, Basic) {
bool b = false;
Own<SetTrueInDestructor> ref1 = kj::refcounted<SetTrueInDestructor>(&b);
Own<SetTrueInDestructor> ref2 = kj::addRef(*ref1);
Own<SetTrueInDestructor> ref3 = kj::addRef(*ref2);
EXPECT_FALSE(b);
ref1 = Own<SetTrueInDestructor>();
EXPECT_FALSE(b);
ref3 = Own<SetTrueInDestructor>();
EXPECT_FALSE(b);
ref2 = Own<SetTrueInDestructor>();
EXPECT_TRUE(b);
#if defined(KJ_DEBUG) && !KJ_NO_EXCEPTIONS
b = false;
SetTrueInDestructor obj(&b);
EXPECT_ANY_THROW(addRef(obj));
#endif
}
} // namespace kj
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "refcount.h"
namespace kj {
Refcounted::~Refcounted() noexcept(false) {}
void Refcounted::disposeImpl(void* pointer) const {
if (__atomic_sub_fetch(&refcount, 1, __ATOMIC_RELAXED) == 0) {
delete this;
}
}
} // namespace kj
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "memory.h"
#ifndef KJ_REFCOUNT_H_
#define KJ_REFCOUNT_H_
namespace kj {
class Refcounted: private Disposer {
// Subclass this to create a class that contains an atomic reference count. Then, use
// `kj::refcounted<T>()` to allocate a new refcounted pointer.
//
// Do NOT use this lightly. Refcounting is a crutch. Good designs should strive to make object
// ownership clear, so that refcounting is not necessary. Keep in mind that reference counting
// must use atomic operations and therefore is surprisingly slow -- often slower than allocating
// a copy on the heap. All that said, reference counting can sometimes simplify code that would
// otherwise become convoluted with explicit ownership, even when ownership relationships are
// clear at an abstract level.
//
// In general, abstract classes should _not_ subclass this. The concrete class at the bottom
// of the heirarchy should be the one to decide how it implements refcounting. Interfaces should
// expose only an `addRef()` method that returns `Own<InterfaceType>`. There are two reasons for
// this rule:
// 1. Interfaces would need to virtually inherit Refcounted, otherwise two refcounted interfaces
// could not be inherited by the same subclass. Virtual inheritance is awkward and
// inefficient.
// 2. An implementation may decide that it would rather return a copy than a refcount, or use
// some other strategy.
public:
virtual ~Refcounted() noexcept(false);
template <typename T>
static Own<T> addRef(T& object);
private:
mutable volatile uint refcount = 0;
void disposeImpl(void* pointer) const override;
template <typename T>
static Own<T> addRefInternal(T* object);
template <typename T>
friend Own<T> addRef(T& object);
template <typename T, typename... Params>
friend Own<T> refcounted(Params&&... params);
};
template <typename T, typename... Params>
inline Own<T> refcounted(Params&&... params) {
// Allocate a new refcounted instance of T, passing `params` to its constructor. Returns an
// initial reference to the object. More references can be created with `kj::addRef()`.
return Refcounted::addRefInternal(new T(kj::fwd<Params>(params)...));
}
template <typename T>
Own<T> addRef(T& object) {
// Return a new reference to `object`, which must subclass Refcounted and have been allocated
// using `kj::refcounted<>()`. It is suggested that subclasses implement a non-static addRef()
// method which wraps this and returns the appropriate type.
KJ_IREQUIRE(object.Refcounted::refcount > 0, "Object not allocated with kj::refcounted().");
return Refcounted::addRefInternal(&object);
}
template <typename T>
Own<T> Refcounted::addRefInternal(T* object) {
const Refcounted* refcounted = object;
__atomic_add_fetch(&refcounted->refcount, 1, __ATOMIC_RELAXED);
return Own<T>(object, *refcounted);
}
} // namespace kj
#endif // KJ_REFCOUNT_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment