Commit 7693365a authored by Kenton Varda

Add test verifying that backwards compatibility is maintained through various changes.

parent be6027ee
@@ -258,7 +258,7 @@ $(test_capnpc_outputs): test_capnpc_middleman
 BUILT_SOURCES = $(test_capnpc_outputs)
-check_PROGRAMS = capnp-test
+check_PROGRAMS = capnp-test capnp-evolution-test
 capnp_test_LDADD = gtest/lib/libgtest.la gtest/lib/libgtest_main.la libcapnpc.la libcapnp.la libkj.la
 capnp_test_CPPFLAGS = -Igtest/include -I$(srcdir)/gtest/include
 capnp_test_SOURCES = \
@@ -297,4 +297,7 @@ capnp_test_SOURCES = \
   src/capnp/compiler/md5-test.c++
 nodist_capnp_test_SOURCES = $(test_capnpc_outputs)
-TESTS = capnp-test
+capnp_evolution_test_LDADD = libcapnpc.la libcapnp.la libkj.la
+capnp_evolution_test_SOURCES = src/capnp/compiler/evolution-test.c++
+TESTS = capnp-test capnp-evolution-test
This diff is collapsed.
@@ -71,7 +71,7 @@ public:
   // already allocated and therefore cannot be a hole.

   kj::Maybe<UIntType> tryAllocate(UIntType lgSize) {
-    // Try to find space for a field of size lgSize^2 within the set of holes.  If found,
+    // Try to find space for a field of size 2^lgSize within the set of holes.  If found,
     // remove it from the holes, and return its offset (as a multiple of its size).  If there
     // is no such space, returns zero (no hole can be at offset zero, as explained above).
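The corrected comment describes the layout strategy this class implements: free space is tracked as power-of-two "holes", and a request for a field of size 2^lgSize either reuses a hole of exactly that size or splits a larger one, returning the offset as a multiple of the field's own size. The standalone C++ sketch below illustrates that idea; it is a hypothetical toy (names and bounds invented here), not the code from this file, whose real layout logic also deals with groups, unions, pointers, and growing the struct.

#include <cstdio>

// Toy power-of-two hole allocator (illustrative only).  holes[i] holds the offset of a
// free hole of size 2^i bits, expressed as a multiple of that hole's own size, or 0 if
// there is no such hole.  Offset 0 can never be a hole because the first field laid out
// always occupies it.
struct ToyHoleSet {
  static constexpr unsigned kMaxLgSize = 7;   // indices 0..6 cover 1-bit through 64-bit fields
  unsigned holes[kMaxLgSize] = {};

  // Try to allocate a field of size 2^lgSize bits.  Returns the offset as a multiple of
  // the field's size, or 0 if no hole fits.
  unsigned tryAllocate(unsigned lgSize) {
    if (lgSize >= kMaxLgSize) return 0;
    if (holes[lgSize] != 0) {
      unsigned result = holes[lgSize];
      holes[lgSize] = 0;
      return result;
    }
    // No exact-size hole: split the next larger hole and remember its unused second half.
    unsigned offset = tryAllocate(lgSize + 1);
    if (offset == 0) return 0;
    offset *= 2;                 // convert to multiples of 2^lgSize
    holes[lgSize] = offset + 1;  // keep the upper half for later
    return offset;
  }
};

int main() {
  ToyHoleSet set;
  set.holes[5] = 1;                      // one free 32-bit hole at bit offset 32
  printf("%u\n", set.tryAllocate(3));    // 4: bits 32..39 (byte offset 4)
  printf("%u\n", set.tryAllocate(3));    // 5: the second half remembered above
  printf("%u\n", set.tryAllocate(3));    // 6: carved from the remaining 16-bit hole
  printf("%u\n", set.tryAllocate(5));    // 0: no 32-bit hole is left
  return 0;
}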
@@ -1126,7 +1126,28 @@ private:
         }
         case Declaration::UNION:
-          errorReporter.addErrorOn(member, "Unions cannot contain unions.");
+          if (member.getName().getValue() == "") {
+            errorReporter.addErrorOn(member, "Unions cannot contain unnamed unions.");
+          } else {
+            parent.childCount++;
+            // For layout purposes, pretend this union is enclosed in a one-member group.
+            StructLayout::Group& singletonGroup =
+                arena.allocate<StructLayout::Group>(layout);
+            StructLayout::Union& unionLayout = arena.allocate<StructLayout::Union>(singletonGroup);
+            memberInfo = &arena.allocate<MemberInfo>(
+                parent, codeOrder++, member,
+                newGroupNode(parent.node, member.getName().getValue()),
+                true);
+            allMembers.add(memberInfo);
+            memberInfo->unionScope = &unionLayout;
+            uint subCodeOrder = 0;
+            traverseUnion(member.getNestedDecls(), *memberInfo, unionLayout, subCodeOrder);
+            if (member.getId().isOrdinal()) {
+              ordinal = member.getId().getOrdinal().getValue();
+            }
+          }
           break;
         case Declaration::GROUP: {
...
@@ -764,6 +764,43 @@ TEST(Encoding, BitListDowngrade) {
   checkList(reader.getObjectField<List<uint16_t>>(), {0x1201u, 0x3400u, 0x5601u, 0x7801u});
 }

+TEST(Encoding, BitListDowngradeFromStruct) {
+  MallocMessageBuilder builder;
+  auto root = builder.initRoot<test::TestObject>();
+
+  {
+    auto list = root.initObjectField<List<test::TestLists::Struct1c>>(4);
+    list[0].setF(true);
+    list[1].setF(false);
+    list[2].setF(true);
+    list[3].setF(true);
+  }
+
+  checkList(root.getObjectField<List<bool>>(), {true, false, true, true});
+
+  {
+    auto l = root.getObjectField<List<test::TestLists::Struct1>>();
+    ASSERT_EQ(4u, l.size());
+    EXPECT_TRUE(l[0].getF());
+    EXPECT_FALSE(l[1].getF());
+    EXPECT_TRUE(l[2].getF());
+    EXPECT_TRUE(l[3].getF());
+  }
+
+  auto reader = root.asReader();
+
+  checkList(reader.getObjectField<List<bool>>(), {true, false, true, true});
+
+  {
+    auto l = reader.getObjectField<List<test::TestLists::Struct1>>();
+    ASSERT_EQ(4u, l.size());
+    EXPECT_TRUE(l[0].getF());
+    EXPECT_FALSE(l[1].getF());
+    EXPECT_TRUE(l[2].getF());
+    EXPECT_TRUE(l[3].getF());
+  }
+}
+
 TEST(Encoding, BitListUpgrade) {
   MallocMessageBuilder builder;
   auto root = builder.initRoot<test::TestObject>();
...
@@ -1757,11 +1757,6 @@ struct WireHelpers {
           break;

         case FieldSize::BIT:
-          KJ_FAIL_REQUIRE("Expected a bit list, but got a list of structs.") {
-            goto useDefault;
-          }
-          break;
-
         case FieldSize::BYTE:
         case FieldSize::TWO_BYTES:
         case FieldSize::FOUR_BYTES:
...
@@ -111,6 +111,7 @@ struct List<T, Kind::PRIMITIVE> {
   inline uint size() const { return reader.size() / ELEMENTS; }
   inline T operator[](uint index) const {
+    KJ_IREQUIRE(index < size());
     return reader.template getDataElement<T>(index * ELEMENTS);
   }
@@ -141,6 +142,7 @@ struct List<T, Kind::PRIMITIVE> {
   inline uint size() const { return builder.size() / ELEMENTS; }
   inline T operator[](uint index) {
+    KJ_IREQUIRE(index < size());
     return builder.template getDataElement<T>(index * ELEMENTS);
   }
   inline void set(uint index, T value) {
@@ -216,6 +218,7 @@ struct List<T, Kind::STRUCT> {
   inline uint size() const { return reader.size() / ELEMENTS; }
   inline typename T::Reader operator[](uint index) const {
+    KJ_IREQUIRE(index < size());
     return typename T::Reader(reader.getStructElement(index * ELEMENTS));
   }
@@ -246,6 +249,7 @@ struct List<T, Kind::STRUCT> {
   inline uint size() const { return builder.size() / ELEMENTS; }
   inline typename T::Builder operator[](uint index) {
+    KJ_IREQUIRE(index < size());
     return typename T::Builder(builder.getStructElement(index * ELEMENTS));
   }
@@ -259,6 +263,8 @@ struct List<T, Kind::STRUCT> {
     // using a newer version of the schema that has additional fields -- it will be truncated,
     // losing data.
+    KJ_IREQUIRE(index < size());
+
     // We pass a zero-valued StructSize to asStruct() because we do not want the struct to be
     // expanded under any circumstances.  We're just going to throw it away anyway, and
     // transferContentFrom() already carefully compares the struct sizes before transferring.
@@ -273,6 +279,7 @@ struct List<T, Kind::STRUCT> {
     // using a newer version of the schema that has additional fields -- it will be truncated,
     // losing data.
+    KJ_IREQUIRE(index < size());
     builder.getStructElement(index * ELEMENTS).copyContentFrom(reader._reader);
   }
@@ -341,6 +348,7 @@ struct List<List<T>, Kind::LIST> {
   inline uint size() const { return reader.size() / ELEMENTS; }
   inline typename List<T>::Reader operator[](uint index) const {
+    KJ_IREQUIRE(index < size());
     return typename List<T>::Reader(List<T>::getAsElementOf(reader, index));
   }
@@ -371,15 +379,19 @@ struct List<List<T>, Kind::LIST> {
   inline uint size() const { return builder.size() / ELEMENTS; }
   inline typename List<T>::Builder operator[](uint index) {
+    KJ_IREQUIRE(index < size());
     return typename List<T>::Builder(List<T>::getAsElementOf(builder, index));
   }
   inline typename List<T>::Builder init(uint index, uint size) {
+    KJ_IREQUIRE(index < this->size());
     return typename List<T>::Builder(List<T>::initAsElementOf(builder, index, size));
   }
   inline void set(uint index, typename List<T>::Reader value) {
+    KJ_IREQUIRE(index < size());
     builder.setListElement(index * ELEMENTS, value.reader);
   }
   void set(uint index, std::initializer_list<ReaderFor<T>> value) {
+    KJ_IREQUIRE(index < size());
     auto l = init(index, value.size());
     uint i = 0;
     for (auto& element: value) {
@@ -387,9 +399,11 @@ struct List<List<T>, Kind::LIST> {
     }
   }
   inline void adopt(uint index, Orphan<T>&& value) {
+    KJ_IREQUIRE(index < size());
     builder.adopt(index * ELEMENTS, kj::mv(value));
   }
   inline Orphan<T> disown(uint index) {
+    KJ_IREQUIRE(index < size());
     return Orphan<T>(builder.disown(index * ELEMENTS));
   }
@@ -451,6 +465,7 @@ struct List<T, Kind::BLOB> {
   inline uint size() const { return reader.size() / ELEMENTS; }
   inline typename T::Reader operator[](uint index) const {
+    KJ_IREQUIRE(index < size());
     return reader.getBlobElement<T>(index * ELEMENTS);
   }
@@ -481,18 +496,23 @@ struct List<T, Kind::BLOB> {
   inline uint size() const { return builder.size() / ELEMENTS; }
   inline typename T::Builder operator[](uint index) {
+    KJ_IREQUIRE(index < size());
     return builder.getBlobElement<T>(index * ELEMENTS);
   }
   inline void set(uint index, typename T::Reader value) {
+    KJ_IREQUIRE(index < size());
     builder.setBlobElement<T>(index * ELEMENTS, value);
   }
   inline typename T::Builder init(uint index, uint size) {
+    KJ_IREQUIRE(index < this->size());
     return builder.initBlobElement<T>(index * ELEMENTS, size * BYTES);
   }
   inline void adopt(uint index, Orphan<T>&& value) {
+    KJ_IREQUIRE(index < size());
     builder.adopt(index * ELEMENTS, kj::mv(value));
   }
   inline Orphan<T> disown(uint index) {
+    KJ_IREQUIRE(index < size());
     return Orphan<T>(builder.disown(index * ELEMENTS));
   }
...
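These KJ_IREQUIRE additions mean an out-of-range list index is now caught by a debug-only assertion instead of silently reading or writing past the end of the list. A minimal sketch of the effect, written in the style of the encoding tests above (it assumes the generated test schema providing capnp::test::TestObject; the function name is invented for illustration):

#include <capnp/message.h>
#include <capnp/list.h>
// Assumes the generated header for the test schema used by the encoding tests above
// has been included, providing capnp::test::TestObject.

void boundsCheckDemo() {
  capnp::MallocMessageBuilder message;
  auto root = message.initRoot<capnp::test::TestObject>();

  auto list = root.initObjectField<capnp::List<uint32_t>>(4);
  list.set(0, 10);
  auto last = list[3];       // index 3 is in range: fine
  (void)last;

  // auto oops = list[4];    // out of range: with this change, KJ_IREQUIRE(index < size())
  //                         // fails in debug builds instead of reading past the list.
}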
@@ -377,7 +377,8 @@ inline typename RootType::Builder MessageBuilder::initRoot() {
 template <typename Reader>
 inline void MessageBuilder::setRoot(Reader&& value) {
   typedef FromReader<Reader> RootType;
-  static_assert(kind<RootType>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
+  static_assert(kind<RootType>() == Kind::STRUCT,
+                "Parameter must be a Reader for a Cap'n Proto struct type.");
   setRootInternal(value._reader);
 }
...
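The reworded static_assert fires when setRoot() is handed a Reader for something other than a struct, and the new message describes the parameter rather than a "root type". A small hypothetical usage sketch (it assumes the same generated test schema as the tests above; capnp::test::TestAllTypes and its uInt32Field come from that schema):

#include <capnp/message.h>
// Assumes the generated header for the test schema, providing capnp::test::TestAllTypes.

void copyRootDemo() {
  capnp::MallocMessageBuilder source;
  auto root = source.initRoot<capnp::test::TestAllTypes>();
  root.setUInt32Field(12345);

  capnp::MallocMessageBuilder dest;
  dest.setRoot(root.asReader());  // OK: the argument is a Reader for a struct type.

  // A Reader for a non-struct type (e.g. a List or Text reader) is rejected at compile
  // time; when it reaches this assertion, the diagnostic now reads
  // "Parameter must be a Reader for a Cap'n Proto struct type."
}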
@@ -442,7 +442,8 @@ private:
     }
     if (hadCase) {
-      VALIDATE_SCHEMA(value.which() == expectedValueType, "Value did not match type.");
+      VALIDATE_SCHEMA(value.which() == expectedValueType, "Value did not match type.",
+                      (uint)value.which(), (uint)expectedValueType);
     }
   }
@@ -511,6 +512,9 @@ public:
   bool shouldReplace(const schema::Node::Reader& existingNode,
                      const schema::Node::Reader& replacement,
                      bool preferReplacementIfEquivalent) {
+    this->existingNode = existingNode;
+    this->replacementNode = replacement;
+
     KJ_CONTEXT("checking compatibility with previously-loaded node of the same id",
                existingNode.getDisplayName());
@@ -528,6 +532,8 @@ public:
 private:
   SchemaLoader::Impl& loader;
   Text::Reader nodeName;
+  schema::Node::Reader existingNode;
+  schema::Node::Reader replacementNode;

   enum Compatibility {
     EQUIVALENT,
@@ -633,6 +639,17 @@ private:
       replacementIsOlder();
     }

+    if (replacement.getDiscriminantCount() > structNode.getDiscriminantCount()) {
+      replacementIsNewer();
+    } else if (replacement.getDiscriminantCount() < structNode.getDiscriminantCount()) {
+      replacementIsOlder();
+    }
+
+    if (replacement.getDiscriminantCount() > 0 && structNode.getDiscriminantCount() > 0) {
+      VALIDATE_SCHEMA(replacement.getDiscriminantOffset() == structNode.getDiscriminantOffset(),
+                      "union discriminant position changed");
+    }
+
     // The shared members should occupy corresponding positions in the member lists, since the
     // lists are sorted by ordinal.
     auto fields = structNode.getFields();
@@ -672,12 +689,19 @@ private:
                           const schema::Field::Reader& replacement) {
     KJ_CONTEXT("comparing struct field", field.getName());

-    VALIDATE_SCHEMA(field.which() == replacement.which(),
-                    "group field replaced with non-group or vice versa");
+    // A field that is initially not in a union can be upgraded to be in one, as long as it has
+    // discriminant 0.
+    uint discriminant = field.hasDiscriminantValue() ? field.getDiscriminantValue() : 0;
+    uint replacementDiscriminant =
+        replacement.hasDiscriminantValue() ? replacement.getDiscriminantValue() : 0;
+    VALIDATE_SCHEMA(discriminant == replacementDiscriminant, "Field discriminant changed.");

     switch (field.which()) {
       case schema::Field::NON_GROUP: {
         auto nonGroup = field.getNonGroup();
+        switch (replacement.which()) {
+          case schema::Field::NON_GROUP: {
             auto replacementNonGroup = replacement.getNonGroup();
             checkCompatibility(nonGroup.getType(), replacementNonGroup.getType(),
@@ -689,11 +713,26 @@ private:
                             "field position changed");
             break;
           }
+          case schema::Field::GROUP:
+            checkUpgradeToStruct(nonGroup.getType(), replacement.getGroup(), existingNode, field);
+            break;
+        }
+        break;
+      }
+
+      case schema::Field::GROUP:
+        switch (replacement.which()) {
+          case schema::Field::NON_GROUP:
+            checkUpgradeToStruct(replacement.getNonGroup().getType(), field.getGroup(),
+                                 replacementNode, replacement);
+            break;
           case schema::Field::GROUP:
             VALIDATE_SCHEMA(field.getGroup() == replacement.getGroup(), "group id changed");
             break;
         }
+        break;
+    }
   }

   void checkCompatibility(const schema::Node::Enum::Reader& enumNode,
@@ -859,7 +898,9 @@ private:
     // We assume unknown types (from newer versions of Cap'n Proto?) are equivalent.
   }

-  void checkUpgradeToStruct(const schema::Type::Reader& type, uint64_t structTypeId) {
+  void checkUpgradeToStruct(const schema::Type::Reader& type, uint64_t structTypeId,
+                            kj::Maybe<schema::Node::Reader> matchSize = nullptr,
+                            kj::Maybe<schema::Field::Reader> matchPosition = nullptr) {
     // We can't just look up the target struct and check it because it may not have been loaded
     // yet.  Instead, we contrive a struct that looks like what we want and load() that, which
     // guarantees that any incompatibility will be caught either now or when the real version of
@@ -929,11 +970,55 @@ private:
         break;
     }

+    KJ_IF_MAYBE(s, matchSize) {
+      auto match = s->getStruct();
+      structNode.setDataWordCount(match.getDataWordCount());
+      structNode.setPointerCount(match.getPointerCount());
+      structNode.setPreferredListEncoding(match.getPreferredListEncoding());
+    }
+
     auto field = structNode.initFields(1)[0];
     field.setName("member0");
-    field.getOrdinal().setExplicit(0);
     field.setCodeOrder(0);
-    field.initNonGroup().setType(type);
+    auto nongroup = field.initNonGroup();
+    nongroup.setType(type);
+
+    KJ_IF_MAYBE(p, matchPosition) {
+      if (p->getOrdinal().isExplicit()) {
+        field.getOrdinal().setExplicit(p->getOrdinal().getExplicit());
+      } else {
+        field.getOrdinal().setImplicit();
+      }
+      auto matchNongroup = p->getNonGroup();
+      nongroup.setOffset(matchNongroup.getOffset());
+      nongroup.setDefaultValue(matchNongroup.getDefaultValue());
+    } else {
+      field.getOrdinal().setExplicit(0);
+      nongroup.setOffset(0);
+
+      schema::Value::Builder value = nongroup.initDefaultValue();
+      switch (type.which()) {
+        case schema::Type::VOID: value.setVoid(); break;
+        case schema::Type::BOOL: value.setBool(false); break;
+        case schema::Type::INT8: value.setInt8(0); break;
+        case schema::Type::INT16: value.setInt16(0); break;
+        case schema::Type::INT32: value.setInt32(0); break;
+        case schema::Type::INT64: value.setInt64(0); break;
+        case schema::Type::UINT8: value.setUint8(0); break;
+        case schema::Type::UINT16: value.setUint16(0); break;
+        case schema::Type::UINT32: value.setUint32(0); break;
+        case schema::Type::UINT64: value.setUint64(0); break;
+        case schema::Type::FLOAT32: value.setFloat32(0); break;
+        case schema::Type::FLOAT64: value.setFloat64(0); break;
+        case schema::Type::ENUM: value.setEnum(0); break;
+        case schema::Type::TEXT: value.adoptText(Orphan<Text>()); break;
+        case schema::Type::DATA: value.adoptData(Orphan<Data>()); break;
+        case schema::Type::LIST: value.adoptList(Orphan<Data>()); break;
+        case schema::Type::STRUCT: value.adoptStruct(Orphan<Data>()); break;
+        case schema::Type::INTERFACE: value.setInterface(); break;
+        case schema::Type::OBJECT: value.adoptObject(Orphan<Data>()); break;
+      }
+    }
+
     loader.load(node, true);
   }
...
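Taken together, these loader changes enforce the rule stated in the new comment above: a field that was not originally a union member may later be moved into a union, provided it ends up with discriminant 0 (i.e. it is the lowest-ordinal member) and keeps its type, offset, and default value. The extended checkUpgradeToStruct() covers the related plain-field-to-group upgrade by synthesizing a single-field struct whose section sizes and field position must match the original. A hypothetical pair of schema revisions of the kind these checks are meant to accept (written in the schema language; not taken from the test suite):

# Revision 1 of some schema file.
struct Person {
  name @0 :Text;
  email @1 :Text;
}

# Revision 2 of the same file: 'email' has moved into a new unnamed union.  It keeps its
# ordinal and storage location and, as the lowest-ordinal member, receives discriminant 0,
# so the compatibility checks above treat the two revisions as interchangeable.
struct Person {
  name @0 :Text;
  union {
    email @1 :Text;
    phone @2 :Text;
  }
}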
@@ -269,6 +269,17 @@ struct TestUnnamedUnion {
   after @4 :Text;
 }

+struct TestUnionInUnion {
+  # There is no reason to ever do this.
+  outer :union {
+    inner :union {
+      foo @0 :Int32;
+      bar @1 :Int32;
+    }
+    baz @2 :Int32;
+  }
+}
+
 struct TestGroups {
   groups :union {
     foo :group {
@@ -381,6 +392,15 @@ struct TestLists {
   struct Struct64 { f @0 :UInt64; }
   struct StructP { f @0 :Text; }

+  # Versions of the above which cannot be encoded as primitive lists.
+  struct Struct0c { f @0 :Void; pad @1 :Text; }
+  struct Struct1c { f @0 :Bool; pad @1 :Text; }
+  struct Struct8c { f @0 :UInt8; pad @1 :Text; }
+  struct Struct16c { f @0 :UInt16; pad @1 :Text; }
+  struct Struct32c { f @0 :UInt32; pad @1 :Text; }
+  struct Struct64c { f @0 :UInt64; pad @1 :Text; }
+  struct StructPc { f @0 :Text; pad @1 :UInt64; }
+
   list0 @0 :List(Struct0);
   list1 @1 :List(Struct1);
   list8 @2 :List(Struct8);
...
@@ -79,5 +79,45 @@ TEST(Function, Method) {
   EXPECT_EQ(9 + 2 + 5, f(2, 9));
 }

+struct TestConstType {
+  mutable int callCount;
+
+  TestConstType(int callCount = 0): callCount(callCount) {}
+  ~TestConstType() { callCount = 1234; }
+  // Make sure we catch invalid post-destruction uses.
+
+  int foo(int a, int b) const {
+    return a + b + callCount++;
+  }
+};
+
+TEST(ConstFunction, Method) {
+  TestConstType obj;
+  ConstFunction<int(int, int)> f = KJ_BIND_METHOD(obj, foo);
+  ConstFunction<uint(uint, uint)> f2 = KJ_BIND_METHOD(obj, foo);
+
+  EXPECT_EQ(123 + 456, f(123, 456));
+  EXPECT_EQ(7 + 8 + 1, f(7, 8));
+  EXPECT_EQ(9u + 2u + 2u, f2(2, 9));
+
+  EXPECT_EQ(3, obj.callCount);
+
+  // Bind to a temporary.
+  f = KJ_BIND_METHOD(TestConstType(10), foo);
+
+  EXPECT_EQ(123 + 456 + 10, f(123, 456));
+  EXPECT_EQ(7 + 8 + 11, f(7, 8));
+  EXPECT_EQ(9 + 2 + 12, f(2, 9));
+
+  // Bind to a move.
+  f = KJ_BIND_METHOD(kj::mv(obj), foo);
+  obj.callCount = 1234;
+
+  EXPECT_EQ(123 + 456 + 3, f(123, 456));
+  EXPECT_EQ(7 + 8 + 4, f(7, 8));
+  EXPECT_EQ(9 + 2 + 5, f(2, 9));
+}
+
 }  // namespace
 }  // namespace kj
@@ -84,6 +84,10 @@ class Function;
 // Notice how KJ_BIND_METHOD is able to figure out which overload to use depending on the kind of
 // Function it is binding to.

+template <typename Signature>
+class ConstFunction;
+// Like Function, but wraps a "const" (i.e. thread-safe) call.
+
 template <typename Return, typename... Params>
 class Function<Return(Params...)> {
 public:
@@ -91,10 +95,30 @@ public:
   inline Function(F&& f): impl(heap<Impl<F>>(kj::fwd<F>(f))) {}
   Function() = default;

+  // Make sure people don't accidentally end up wrapping a reference when they meant to return
+  // a function.
+  KJ_DISALLOW_COPY(Function);
+  Function(Function&) = delete;
+  Function& operator=(Function&) = delete;
+  template <typename T> Function(const Function<T>&) = delete;
+  template <typename T> Function& operator=(const Function<T>&) = delete;
+  template <typename T> Function(const ConstFunction<T>&) = delete;
+  template <typename T> Function& operator=(const ConstFunction<T>&) = delete;
+  Function(Function&&) = default;
+  Function& operator=(Function&&) = default;
+
   inline Return operator()(Params... params) {
     return (*impl)(kj::fwd<Params>(params)...);
   }

+  Function reference() {
+    // Forms a new Function of the same type that delegates to this Function by reference.
+    // Therefore, this Function must outlive the returned Function, but otherwise they behave
+    // exactly the same.
+    return *impl;
+  }
+
 private:
   class Iface {
   public:
@@ -117,6 +141,59 @@ private:
   Own<Iface> impl;
 };

+template <typename Return, typename... Params>
+class ConstFunction<Return(Params...)> {
+public:
+  template <typename F>
+  inline ConstFunction(F&& f): impl(heap<Impl<F>>(kj::fwd<F>(f))) {}
+  ConstFunction() = default;
+
+  // Make sure people don't accidentally end up wrapping a reference when they meant to return
+  // a function.
+  KJ_DISALLOW_COPY(ConstFunction);
+  ConstFunction(ConstFunction&) = delete;
+  ConstFunction& operator=(ConstFunction&) = delete;
+  template <typename T> ConstFunction(const ConstFunction<T>&) = delete;
+  template <typename T> ConstFunction& operator=(const ConstFunction<T>&) = delete;
+  template <typename T> ConstFunction(const Function<T>&) = delete;
+  template <typename T> ConstFunction& operator=(const Function<T>&) = delete;
+  ConstFunction(ConstFunction&&) = default;
+  ConstFunction& operator=(ConstFunction&&) = default;
+
+  inline Return operator()(Params... params) const {
+    return (*impl)(kj::fwd<Params>(params)...);
+  }
+
+  ConstFunction reference() const {
+    // Forms a new ConstFunction of the same type that delegates to this ConstFunction by reference.
+    // Therefore, this ConstFunction must outlive the returned ConstFunction, but otherwise they
+    // behave exactly the same.
+    return *impl;
+  }
+
+private:
+  class Iface {
+  public:
+    virtual Return operator()(Params... params) const = 0;
+  };
+
+  template <typename F>
+  class Impl final: public Iface {
+  public:
+    explicit Impl(F&& f): f(kj::fwd<F>(f)) {}
+
+    Return operator()(Params... params) const override {
+      return f(kj::fwd<Params>(params)...);
+    }
+
+  private:
+    F f;
+  };
+
+  Own<Iface> impl;
+};
+
 #if 1
 namespace _ {  // private
@@ -137,6 +214,20 @@ private:
   T t;
 };

+template <typename T, typename Return, typename... Params,
+          Return (Decay<T>::*method)(Params...) const>
+class BoundMethod<T, Return (Decay<T>::*)(Params...) const, method> {
+public:
+  BoundMethod(T&& t): t(kj::fwd<T>(t)) {}
+
+  Return operator()(Params&&... params) const {
+    return (t.*method)(kj::fwd<Params>(params)...);
+  }
+
+private:
+  T t;
+};
+
 }  // namespace _ (private)

 #define KJ_BIND_METHOD(obj, method) \
...
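ConstFunction, together with the new const specialization of BoundMethod, lets a const (and therefore potentially thread-safe) method be wrapped and then invoked through a const reference. A hypothetical usage sketch (the Counter class and applyTwice() are invented for illustration; the binding pattern follows the new ConstFunction test above):

#include <kj/function.h>

namespace {

class Counter {
public:
  // A const method: suitable for wrapping in a ConstFunction.
  int scaled(int value) const { return value * factor; }
private:
  int factor = 3;
};

// An API that only needs const access to its callback can accept a ConstFunction by
// const reference, advertising that it will never require exclusive access to the callee.
int applyTwice(const kj::ConstFunction<int(int)>& f, int start) {
  return f(f(start));
}

}  // namespace

int main() {
  Counter counter;
  kj::ConstFunction<int(int)> f = KJ_BIND_METHOD(counter, scaled);
  return applyTwice(f, 2) == 18 ? 0 : 1;  // 2 -> 6 -> 18
}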