Commit 7693365a authored by Kenton Varda's avatar Kenton Varda

Add test verifying that backwards compatibility is maintained through various changes.

parent be6027ee
......@@ -258,7 +258,7 @@ $(test_capnpc_outputs): test_capnpc_middleman
BUILT_SOURCES = $(test_capnpc_outputs)
check_PROGRAMS = capnp-test
check_PROGRAMS = capnp-test capnp-evolution-test
capnp_test_LDADD = gtest/lib/libgtest.la gtest/lib/libgtest_main.la libcapnpc.la libcapnp.la libkj.la
capnp_test_CPPFLAGS = -Igtest/include -I$(srcdir)/gtest/include
capnp_test_SOURCES = \
......@@ -297,4 +297,7 @@ capnp_test_SOURCES = \
src/capnp/compiler/md5-test.c++
nodist_capnp_test_SOURCES = $(test_capnpc_outputs)
TESTS = capnp-test
capnp_evolution_test_LDADD = libcapnpc.la libcapnp.la libkj.la
capnp_evolution_test_SOURCES = src/capnp/compiler/evolution-test.c++
TESTS = capnp-test capnp-evolution-test
// Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This is a fuzz test which randomly generates a schema for a struct one change at a time.
// Each modification is known a priori to be compatible or incompatible. The type is compiled
// before and after the change and both versions are loaded into a SchemaLoader with the
// expectation that this will succeed if they are compatible and fail if they are not. If
// the types are expected to be compatible, the test also constructs an instance of the old
// type and reads it as the new type, and vice versa.
#include <capnp/compiler/grammar.capnp.h>
#include <capnp/schema-loader.h>
#include <capnp/message.h>
#include <capnp/pretty-print.h>
#include "compiler.h"
#include <kj/function.h>
#include <kj/debug.h>
#include <stdlib.h>
#include <time.h>
#include <kj/main.h>
#include <kj/io.h>
#include <unistd.h>
namespace capnp {
namespace compiler {
namespace {
static const kj::StringPtr RFC3092[] = {"foo", "bar", "baz", "qux"};
template <typename T, size_t size>
T& chooseFrom(T (&arr)[size]) {
return arr[rand() % size];
}
template <typename T>
auto chooseFrom(T arr) -> decltype(arr[0]) {
return arr[rand() % arr.size()];
}
static Declaration::Builder addNested(Declaration::Builder parent) {
auto oldNestedOrphan = parent.disownNestedDecls();
auto oldNested = oldNestedOrphan.get();
auto newNested = parent.initNestedDecls(oldNested.size() + 1);
uint index = rand() % (oldNested.size() + 1);
for (uint i = 0; i < index; i++) {
newNested.setWithCaveats(i, oldNested[i]);
}
for (uint i = index + 1; i < newNested.size(); i++) {
newNested.setWithCaveats(i, oldNested[i - 1]);
}
return newNested[index];
}
struct TypeOption {
kj::StringPtr name;
kj::ConstFunction<void(ValueExpression::Builder)> makeValue;
};
static const TypeOption TYPE_OPTIONS[] = {
{ "Int32",
[](ValueExpression::Builder builder) {
builder.setPositiveInt(rand() % (1 << 24));
}},
{ "Float64",
[](ValueExpression::Builder builder) {
builder.setPositiveInt(rand());
}},
{ "Int8",
[](ValueExpression::Builder builder) {
builder.setPositiveInt(rand() % 128);
}},
{ "UInt16",
[](ValueExpression::Builder builder) {
builder.setPositiveInt(rand() % (1 << 16));
}},
{ "Bool",
[](ValueExpression::Builder builder) {
builder.initName().getBase().initRelativeName().setValue("true");
}},
{ "Text",
[](ValueExpression::Builder builder) {
builder.setString(chooseFrom(RFC3092));
}},
{ "StructType",
[](ValueExpression::Builder builder) {
auto assignment = builder.initStruct(1)[0];
assignment.initFieldName().setValue("i");
assignment.initValue().setPositiveInt(rand() % (1 << 24));
}},
{ "EnumType",
[](ValueExpression::Builder builder) {
builder.initName().getBase().initRelativeName().setValue(chooseFrom(RFC3092));
}},
};
void setDeclName(DeclName::Builder decl, kj::StringPtr name) {
decl.getBase().initRelativeName().setValue(name);
}
static kj::ConstFunction<void(ValueExpression::Builder)> randomizeType(
TypeExpression::Builder type) {
auto option = &chooseFrom(TYPE_OPTIONS);
if (rand() % 4 == 0) {
setDeclName(type.initName(), "List");
setDeclName(type.initParams(1)[0].initName(), option->name);
return [option](ValueExpression::Builder builder) {
for (auto element: builder.initList(rand() % 4 + 1)) {
option->makeValue(element);
}
};
} else {
setDeclName(type.initName(), option->name);
return option->makeValue.reference();
}
}
enum ChangeKind {
NO_CHANGE,
COMPATIBLE,
INCOMPATIBLE,
SUBTLY_COMPATIBLE
// The change is technically compatible on the wire, but SchemaLoader will complain.
};
struct ChangeInfo {
ChangeKind kind;
kj::String description;
ChangeInfo(): kind(NO_CHANGE) {}
ChangeInfo(ChangeKind kind, kj::StringPtr description)
: kind(kind), description(kj::str(description)) {}
ChangeInfo(ChangeKind kind, kj::String&& description)
: kind(kind), description(kj::mv(description)) {}
};
extern kj::ArrayPtr<kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)>> STRUCT_MODS;
extern kj::ArrayPtr<kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)>> FIELD_MODS;
// ================================================================================
static ChangeInfo declChangeName(Declaration::Builder decl, uint& nextOrdinal,
bool scopeHasUnion) {
auto name = decl.getName();
if (name.getValue().size() == 0) {
// Naming an unnamed union.
name.setValue("unUnnamed");
return { SUBTLY_COMPATIBLE, "Assign name to unnamed union." };
} else {
name.setValue(kj::str(name.getValue(), "Xx"));
return { COMPATIBLE, "Rename declaration." };
}
}
static ChangeInfo structAddField(Declaration::Builder decl, uint& nextOrdinal, bool scopeHasUnion) {
auto fieldDecl = addNested(decl);
uint ordinal = nextOrdinal++;
fieldDecl.initName().setValue(kj::str("f", ordinal));
fieldDecl.getId().initOrdinal().setValue(ordinal);
auto field = fieldDecl.initField();
auto makeValue = randomizeType(field.initType());
if (rand() % 4 == 0) {
makeValue(field.getDefaultValue().initValue());
} else {
field.getDefaultValue().setNone();
}
return { COMPATIBLE, "Add field." };
}
static ChangeInfo structModifyField(Declaration::Builder decl, uint& nextOrdinal,
bool scopeHasUnion) {
auto nested = decl.getNestedDecls();
if (nested.size() == 0) {
return { NO_CHANGE, "Modify field, but there were none to modify." };
}
auto field = chooseFrom(nested);
bool hasUnion = false;
if (decl.isUnion()) {
hasUnion = true;
} else {
for (auto n: nested) {
if (n.isUnion() && n.getName().getValue().size() == 0) {
hasUnion = true;
break;
}
}
}
if (field.isGroup() || field.isUnion()) {
return chooseFrom(STRUCT_MODS)(field, nextOrdinal, hasUnion);
} else {
return chooseFrom(FIELD_MODS)(field, nextOrdinal, hasUnion);
}
}
static ChangeInfo structGroupifyFields(
Declaration::Builder decl, uint& nextOrdinal, bool scopeHasUnion) {
// Place a random subset of the fields into a group.
if (decl.isUnion()) {
return { NO_CHANGE,
"Randomly make a group out of some fields, but I can't do this to a union." };
}
kj::Vector<Orphan<Declaration>> groupified;
kj::Vector<Orphan<Declaration>> notGroupified;
auto orphanage = Orphanage::getForMessageContaining(decl);
for (auto nested: decl.getNestedDecls()) {
if (rand() % 2) {
groupified.add(orphanage.newOrphanCopy(nested.asReader()));
} else {
notGroupified.add(orphanage.newOrphanCopy(nested.asReader()));
}
}
if (groupified.size() == 0) {
return { NO_CHANGE,
"Randomly make a group out of some fields, but I ended up choosing none of them." };
}
auto newNested = decl.initNestedDecls(notGroupified.size() + 1);
uint index = rand() % (notGroupified.size() + 1);
for (uint i = 0; i < index; i++) {
newNested.adoptWithCaveats(i, kj::mv(notGroupified[i]));
}
for (uint i = index; i < notGroupified.size(); i++) {
newNested.adoptWithCaveats(i + 1, kj::mv(notGroupified[i]));
}
auto newGroup = newNested[index];
auto groupNested = newGroup.initNestedDecls(groupified.size());
for (uint i = 0; i < groupified.size(); i++) {
groupNested.adoptWithCaveats(i, kj::mv(groupified[i]));
}
newGroup.initName().setValue(kj::str("g", groupNested[0].getName().getValue()));
newGroup.getId().setUnspecified();
newGroup.setGroup();
return { SUBTLY_COMPATIBLE, "Randomly group some set of existing fields." };
}
static ChangeInfo structPermuteFields(
Declaration::Builder decl, uint& nextOrdinal, bool scopeHasUnion) {
if (decl.getNestedDecls().size() == 0) {
return { NO_CHANGE, "Permute field code order, but there were none." };
}
auto oldOrphan = decl.disownNestedDecls();
auto old = oldOrphan.get();
KJ_STACK_ARRAY(uint, mapping, old.size(), 16, 64);
for (uint i = 0; i < mapping.size(); i++) {
mapping[i] = i;
}
for (uint i = mapping.size() - 1; i > 0; i--) {
uint j = rand() % i;
uint temp = mapping[j];
mapping[j] = mapping[i];
mapping[i] = temp;
}
auto newNested = decl.initNestedDecls(old.size());
for (uint i = 0; i < old.size(); i++) {
newNested.setWithCaveats(i, old[mapping[i]]);
}
return { COMPATIBLE, "Permute field code order." };
}
kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)> STRUCT_MODS_[] = {
structAddField,
structAddField,
structAddField,
structModifyField,
structModifyField,
structModifyField,
structPermuteFields,
declChangeName,
structGroupifyFields // do more rarely because it creates slowness
};
kj::ArrayPtr<kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)>>
STRUCT_MODS = STRUCT_MODS_;
// ================================================================================
static ChangeInfo fieldUpgradeList(Declaration::Builder decl, uint& nextOrdinal,
bool scopeHasUnion) {
// Upgrades a non-struct list to a struct list.
auto field = decl.getField();
if (field.getDefaultValue().isValue()) {
return { NO_CHANGE, "Upgrade primitive list to struct list, but it had a default value." };
}
auto typeParams = field.getType().getParams();
if (typeParams.size() != 1) {
return { NO_CHANGE, "Upgrade primitive list to struct list, but it wasn't a list." };
}
auto elementType = typeParams[0];
auto relativeName = elementType.getName().getBase().getRelativeName();
auto nameText = relativeName.asReader().getValue();
if (nameText == "StructType" || nameText.endsWith("Struct")) {
return { NO_CHANGE, "Upgrade primitive list to struct list, but it was already a struct list."};
}
relativeName.setValue(kj::str(nameText, "Struct"));
return { COMPATIBLE, "Upgrade primitive list to struct list" };
}
static ChangeInfo fieldExpandGroup(Declaration::Builder decl, uint& nextOrdinal,
bool scopeHasUnion) {
Declaration::Builder newDecl = decl.initNestedDecls(1)[0];
newDecl.adoptName(decl.disownName());
newDecl.getId().adoptOrdinal(decl.getId().disownOrdinal());
auto field = decl.getField();
auto newField = newDecl.initField();
newField.adoptType(field.disownType());
if (field.getDefaultValue().isValue()) {
newField.getDefaultValue().adoptValue(field.getDefaultValue().disownValue());
} else {
newField.getDefaultValue().setNone();
}
decl.initName().setValue(kj::str("g", newDecl.getName().getValue()));
decl.getId().setUnspecified();
if (rand() % 2 == 0) {
decl.setGroup();
} else {
decl.setUnion();
if (!scopeHasUnion && rand() % 2 == 0) {
// Make it an unnamed union.
decl.getName().setValue("");
}
structAddField(decl, nextOrdinal, scopeHasUnion); // union must have two members
}
return { COMPATIBLE, "Wrap a field in a singleton group." };
}
static ChangeInfo fieldChangeType(Declaration::Builder decl, uint& nextOrdinal,
bool scopeHasUnion) {
auto field = decl.getField();
if (field.getDefaultValue().isNone()) {
// Change the type.
auto type = field.getType();
while (type.getParams().size() > 0) {
// Either change the list parameter, or revert to a non-list.
if (rand() % 2) {
type = type.getParams()[0];
} else {
type.disownParams();
}
}
auto typeName = type.getName().getBase().getRelativeName();
if (typeName.asReader().getValue().startsWith("Text")) {
typeName.setValue("Int32");
} else {
typeName.setValue("Text");
}
return { INCOMPATIBLE, "Change the type of a field." };
} else {
// Change the default value.
auto dval = field.getDefaultValue().getValue();
switch (dval.which()) {
case ValueExpression::UNKNOWN: KJ_FAIL_ASSERT("unknown value expression?");
case ValueExpression::POSITIVE_INT: dval.setPositiveInt(dval.getPositiveInt() ^ 1); break;
case ValueExpression::NEGATIVE_INT: dval.setNegativeInt(dval.getNegativeInt() ^ 1); break;
case ValueExpression::FLOAT: dval.setFloat(-dval.getFloat()); break;
case ValueExpression::NAME: {
auto name = dval.getName().getBase().getRelativeName();
auto nameText = name.asReader().getValue();
if (nameText == "true") {
name.setValue("false");
} else if (nameText == "false") {
name.setValue("true");
} else if (nameText == "foo") {
name.setValue("bar");
} else {
name.setValue("foo");
}
break;
}
case ValueExpression::STRING:
case ValueExpression::LIST:
case ValueExpression::STRUCT:
return { NO_CHANGE, "Change the default value of a field, but it's a pointer field." };
}
return { INCOMPATIBLE, "Change the default value of a pritimive field." };
}
}
kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)> FIELD_MODS_[] = {
fieldUpgradeList,
fieldExpandGroup,
fieldChangeType,
declChangeName
};
kj::ArrayPtr<kj::ConstFunction<ChangeInfo(Declaration::Builder, uint&, bool)>>
FIELD_MODS = FIELD_MODS_;
// ================================================================================
uint getOrdinal(StructSchema::Field field) {
auto proto = field.getProto();
if (proto.getOrdinal().isExplicit()) {
return proto.getOrdinal().getExplicit();
}
KJ_ASSERT(proto.isGroup());
auto group = field.getContainingStruct().getDependency(proto.getGroup()).asStruct();
return getOrdinal(group.getFields()[0]);
}
Orphan<DynamicStruct> makeExampleStruct(
Orphanage orphanage, StructSchema schema, uint sharedOrdinalCount);
void checkExampleStruct(DynamicStruct::Reader reader, uint sharedOrdinalCount);
Orphan<DynamicValue> makeExampleValue(
Orphanage orphanage, Schema scope, uint ordinal, schema::Type::Reader type,
uint sharedOrdinalCount) {
switch (type.which()) {
case schema::Type::INT32: return ordinal * 47327;
case schema::Type::FLOAT64: return ordinal * 313.25;
case schema::Type::INT8: return int(ordinal % 256) - 128;
case schema::Type::UINT16: return ordinal * 13;
case schema::Type::BOOL: return ordinal % 2 == 0;
case schema::Type::TEXT: return orphanage.newOrphanCopy(Text::Reader(kj::str(ordinal)));
case schema::Type::STRUCT: {
auto structType = scope.getDependency(type.getStruct()).asStruct();
auto result = orphanage.newOrphan(structType);
auto builder = result.get();
KJ_IF_MAYBE(fieldI, structType.findFieldByName("i")) {
// Type is "StructType"
builder.set(*fieldI, ordinal);
} else {
// Type is "Int32Struct" or the like.
auto field = structType.getFieldByName("f0");
builder.adopt(field, makeExampleValue(orphanage, structType, ordinal,
field.getProto().getNonGroup().getType(),
sharedOrdinalCount));
}
return kj::mv(result);
}
case schema::Type::ENUM: {
auto enumerants = scope.getDependency(type.getEnum()).asEnum().getEnumerants();
return DynamicEnum(enumerants[ordinal %enumerants.size()]);
}
case schema::Type::LIST: {
auto elementType = type.getList();
auto listType = ListSchema::of(elementType, scope);
auto result = orphanage.newOrphan(listType, 1);
result.get().adopt(0, makeExampleValue(
orphanage, scope, ordinal, elementType, sharedOrdinalCount));
return kj::mv(result);
}
default:
KJ_FAIL_ASSERT("You added a new possible field type!");
}
}
void checkExampleValue(DynamicValue::Reader value, uint ordinal, schema::Type::Reader type,
uint sharedOrdinalCount) {
switch (type.which()) {
case schema::Type::INT32: KJ_ASSERT(value.as<int32_t>() == ordinal * 47327); break;
case schema::Type::FLOAT64: KJ_ASSERT(value.as<double>() == ordinal * 313.25); break;
case schema::Type::INT8: KJ_ASSERT(value.as<int8_t>() == int(ordinal % 256) - 128); break;
case schema::Type::UINT16: KJ_ASSERT(value.as<uint16_t>() == ordinal * 13); break;
case schema::Type::BOOL: KJ_ASSERT(value.as<bool>() == (ordinal % 2 == 0)); break;
case schema::Type::TEXT: KJ_ASSERT(value.as<Text>() == kj::str(ordinal)); break;
case schema::Type::STRUCT: {
auto structValue = value.as<DynamicStruct>();
auto structType = structValue.getSchema();
KJ_IF_MAYBE(fieldI, structType.findFieldByName("i")) {
// Type is "StructType"
KJ_ASSERT(structValue.get(*fieldI).as<uint32_t>() == ordinal);
} else {
// Type is "Int32Struct" or the like.
auto field = structType.getFieldByName("f0");
checkExampleValue(structValue.get(field), ordinal,
field.getProto().getNonGroup().getType(), sharedOrdinalCount);
}
break;
}
case schema::Type::ENUM: {
auto enumerant = KJ_ASSERT_NONNULL(value.as<DynamicEnum>().getEnumerant());
KJ_ASSERT(enumerant.getIndex() ==
ordinal % enumerant.getContainingEnum().getEnumerants().size());
break;
}
case schema::Type::LIST:
checkExampleValue(value.as<DynamicList>()[0], ordinal, type.getList(), sharedOrdinalCount);
break;
default:
KJ_FAIL_ASSERT("You added a new possible field type!");
}
}
void setExampleField(DynamicStruct::Builder builder, StructSchema::Field field,
uint sharedOrdinalCount) {
auto fieldProto = field.getProto();
switch (fieldProto.which()) {
case schema::Field::NON_GROUP:
builder.adopt(field, makeExampleValue(
Orphanage::getForMessageContaining(builder), field.getContainingStruct(),
getOrdinal(field), fieldProto.getNonGroup().getType(), sharedOrdinalCount));
break;
case schema::Field::GROUP:
builder.adopt(field, makeExampleStruct(
Orphanage::getForMessageContaining(builder),
field.getContainingStruct().getDependency(fieldProto.getGroup()).asStruct(),
sharedOrdinalCount));
break;
}
}
void checkExampleField(DynamicStruct::Reader reader, StructSchema::Field field,
uint sharedOrdinalCount) {
auto fieldProto = field.getProto();
switch (fieldProto.which()) {
case schema::Field::NON_GROUP: {
uint ordinal = getOrdinal(field);
if (ordinal < sharedOrdinalCount) {
checkExampleValue(reader.get(field), ordinal,
fieldProto.getNonGroup().getType(), sharedOrdinalCount);
}
break;
}
case schema::Field::GROUP:
checkExampleStruct(reader.get(field).as<DynamicStruct>(), sharedOrdinalCount);
break;
}
}
Orphan<DynamicStruct> makeExampleStruct(
Orphanage orphanage, StructSchema schema, uint sharedOrdinalCount) {
// Initialize all fields of the struct via reflection, such that they can be verified using
// a different version of the struct. sharedOrdinalCount is the number of ordinals shared by
// the two versions. This is used mainly to avoid setting union members that the other version
// doesn't have.
Orphan<DynamicStruct> result = orphanage.newOrphan(schema);
auto builder = result.get();
for (auto field: schema.getNonUnionFields()) {
setExampleField(builder, field, sharedOrdinalCount);
}
auto unionFields = schema.getUnionFields();
// Pretend the union doesn't have any fields that aren't in the shared ordinal range.
uint range = unionFields.size();
while (range > 0 && getOrdinal(unionFields[range - 1]) >= sharedOrdinalCount) {
--range;
}
if (range > 0) {
auto field = unionFields[getOrdinal(unionFields[0]) % range];
setExampleField(builder, field, sharedOrdinalCount);
}
return kj::mv(result);
}
void checkExampleStruct(DynamicStruct::Reader reader, uint sharedOrdinalCount) {
auto schema = reader.getSchema();
for (auto field: schema.getNonUnionFields()) {
checkExampleField(reader, field, sharedOrdinalCount);
}
auto unionFields = schema.getUnionFields();
// Pretend the union doesn't have any fields that aren't in the shared ordinal range.
uint range = unionFields.size();
while (range > 0 && getOrdinal(unionFields[range - 1]) >= sharedOrdinalCount) {
--range;
}
if (range > 0) {
auto field = unionFields[getOrdinal(unionFields[0]) % range];
checkExampleField(reader, field, sharedOrdinalCount);
}
}
// ================================================================================
class ModuleImpl final: public Module {
public:
explicit ModuleImpl(ParsedFile::Reader content): content(content) {}
kj::StringPtr getSourceName() const override { return "evolving-schema.capnp"; }
Orphan<ParsedFile> loadContent(Orphanage orphanage) const override {
return orphanage.newOrphanCopy(content);
}
kj::Maybe<const Module&> importRelative(kj::StringPtr importPath) const override {
return nullptr;
}
void addError(uint32_t startByte, uint32_t endByte, kj::StringPtr message) const override {
KJ_FAIL_ASSERT("Unexpected parse error.", startByte, endByte, message);
}
bool hadErrors() const override {
return false;
}
private:
ParsedFile::Reader content;
};
static void loadStructAndGroups(const SchemaLoader& src, SchemaLoader& dst, uint64_t id) {
auto proto = src.get(id).getProto();
dst.load(proto);
for (auto field: proto.getStruct().getFields()) {
if (field.isGroup()) {
loadStructAndGroups(src, dst, field.getGroup());
}
}
}
static kj::Maybe<kj::Exception> loadFile(
ParsedFile::Reader file, SchemaLoader& loader, bool allNodes,
kj::Maybe<kj::Own<MallocMessageBuilder>>& messageBuilder,
uint sharedOrdinalCount) {
Compiler compiler;
ModuleImpl module(file);
KJ_ASSERT(compiler.add(module) == 0x8123456789abcdefllu);
if (allNodes) {
// Eagerly compile and load the whole thing.
compiler.eagerlyCompile(0x8123456789abcdefllu, Compiler::ALL_RELATED_NODES);
KJ_IF_MAYBE(m, messageBuilder) {
// Build an example struct using the compiled schema.
m->adoptRoot(makeExampleStruct(
m->getOrphanage(), compiler.getLoader().get(0x823456789abcdef1llu).asStruct(),
sharedOrdinalCount));
}
for (auto schema: compiler.getLoader().getAllLoaded()) {
loader.load(schema.getProto());
}
return nullptr;
} else {
// Compile the file root so that the children are findable, then load the specific child
// we want.
compiler.eagerlyCompile(0x8123456789abcdefllu, Compiler::NODE);
KJ_IF_MAYBE(m, messageBuilder) {
// Check that the example struct matches the compiled schema.
auto root = m->getRoot<DynamicStruct>(
compiler.getLoader().get(0x823456789abcdef1llu).asStruct()).asReader();
KJ_CONTEXT(root);
checkExampleStruct(root, sharedOrdinalCount);
}
return kj::runCatchingExceptions([&]() {
loadStructAndGroups(compiler.getLoader(), loader, 0x823456789abcdef1llu);
});
}
}
bool checkChange(ParsedFile::Reader file1, ParsedFile::Reader file2, ChangeKind changeKind,
uint sharedOrdinalCount) {
// Try loading file1 followed by file2 into the same SchemaLoader, expecting it to behave
// according to changeKind. Returns true if the files are both expected to be compatible and
// actually are -- the main loop uses this to decide which version to keep
kj::Maybe<kj::Own<MallocMessageBuilder>> exampleBuilder;
if (changeKind != INCOMPATIBLE) {
// For COMPATIBLE and SUBTLY_COMPATIBLE changes, build an example message with one schema
// and check it with the other.
exampleBuilder = kj::heap<MallocMessageBuilder>();
}
SchemaLoader loader;
loadFile(file1, loader, true, exampleBuilder, sharedOrdinalCount);
auto exception = loadFile(file2, loader, false, exampleBuilder, sharedOrdinalCount);
if (changeKind == COMPATIBLE) {
KJ_IF_MAYBE(e, exception) {
kj::getExceptionCallback().onFatalException(kj::mv(*e));
return false;
} else {
return true;
}
} else if (changeKind == INCOMPATIBLE) {
KJ_ASSERT(exception != nullptr, file1, file2);
return false;
} else {
KJ_ASSERT(changeKind == SUBTLY_COMPATIBLE);
// SchemaLoader is allowed to throw an exception in this case, but we ignore it.
return true;
}
}
kj::MainBuilder::Validity doTest(uint seed) {
srand(seed);
{
kj::String text = kj::str(
"Randomly testing backwards-compatibility scenarios...\n"
"seed = ", seed, " <- PLEASE RECORD THIS NUMBER IF THE TEST FAILS\n");
kj::FdOutputStream(STDOUT_FILENO).write(text.begin(), text.size());
}
KJ_CONTEXT(seed, "PLEASE REPORT THIS FAILURE AND INCLUDE THE SEED");
auto builder = kj::heap<MallocMessageBuilder>();
{
// Set up the basic file decl.
auto parsedFile = builder->initRoot<ParsedFile>();
auto file = parsedFile.initRoot();
file.setFile();
file.initId().initUid().setValue(0x8123456789abcdefllu);
auto decls = file.initNestedDecls(3 + KJ_ARRAY_SIZE(TYPE_OPTIONS));
{
auto decl = decls[0];
decl.initName().setValue("EvolvingStruct");
decl.initId().initUid().setValue(0x823456789abcdef1llu);
decl.setStruct();
}
{
auto decl = decls[1];
decl.initName().setValue("StructType");
decl.setStruct();
auto fieldDecl = decl.initNestedDecls(1)[0];
fieldDecl.initName().setValue("i");
fieldDecl.getId().initOrdinal().setValue(0);
auto field = fieldDecl.initField();
setDeclName(field.initType().initName(), "UInt32");
}
{
auto decl = decls[2];
decl.initName().setValue("EnumType");
decl.setEnum();
auto enumerants = decl.initNestedDecls(4);
for (uint i = 0; i < KJ_ARRAY_SIZE(RFC3092); i++) {
auto enumerantDecl = enumerants[i];
enumerantDecl.initName().setValue(RFC3092[i]);
enumerantDecl.getId().initOrdinal().setValue(i);
enumerantDecl.setEnumerant();
}
}
// For each of TYPE_OPTIONS, declare a struct type that contains that type as its @0 field.
for (uint i = 0; i < KJ_ARRAY_SIZE(TYPE_OPTIONS); i++) {
auto decl = decls[3 + i];
auto& option = TYPE_OPTIONS[i];
decl.initName().setValue(kj::str(option.name, "Struct"));
decl.setStruct();
auto fieldDecl = decl.initNestedDecls(1)[0];
fieldDecl.initName().setValue("f0");
fieldDecl.getId().initOrdinal().setValue(0);
auto field = fieldDecl.initField();
setDeclName(field.initType().initName(), option.name);
uint ordinal = 1;
for (auto j: kj::range(0, rand() % 4)) {
(void)j;
structAddField(decl, ordinal, false);
}
}
}
uint nextOrdinal = 0;
for (uint i = 0; i < 128; i++) {
uint oldOrdinalCount = nextOrdinal;
auto newBuilder = kj::heap<MallocMessageBuilder>();
newBuilder->setRoot(builder->getRoot<ParsedFile>().asReader());
auto parsedFile = newBuilder->getRoot<ParsedFile>();
Declaration::Builder decl = parsedFile.getRoot().getNestedDecls()[0];
// Apply a random modification.
ChangeInfo changeInfo;
while (changeInfo.kind == NO_CHANGE) {
auto& mod = chooseFrom(STRUCT_MODS);
changeInfo = mod(decl, nextOrdinal, false);
}
KJ_CONTEXT(changeInfo.description);
if (checkChange(builder->getRoot<ParsedFile>(), parsedFile, changeInfo.kind, oldOrdinalCount) &&
checkChange(parsedFile, builder->getRoot<ParsedFile>(), changeInfo.kind, oldOrdinalCount)) {
builder = kj::mv(newBuilder);
}
}
return true;
}
class EvolutionTestMain {
public:
explicit EvolutionTestMain(kj::ProcessContext& context)
: context(context) {}
kj::MainFunc getMain() {
return kj::MainBuilder(context, "(unknown version)",
"Integration test / fuzzer which randomly modifies schemas is backwards-compatible ways "
"and verifies that they do actually remain compatible.")
.addOptionWithArg({"seed"}, KJ_BIND_METHOD(*this, setSeed), "<num>",
"Set random number seed to <num>. By default, time() is used.")
.callAfterParsing([this]() { return doTest(seed); })
.build();
}
kj::MainBuilder::Validity setSeed(kj::StringPtr value) {
char* end;
seed = strtol(value.cStr(), &end, 0);
if (value.size() == 0 || *end != '\0') {
return "not an integer";
} else {
return true;
}
}
private:
kj::ProcessContext& context;
uint seed = time(nullptr);
};
} // namespace
} // namespace compiler
} // namespace capnp
KJ_MAIN(capnp::compiler::EvolutionTestMain);
......@@ -71,7 +71,7 @@ public:
// already allocated and therefore cannot be a hole.
kj::Maybe<UIntType> tryAllocate(UIntType lgSize) {
// Try to find space for a field of size lgSize^2 within the set of holes. If found,
// Try to find space for a field of size 2^lgSize within the set of holes. If found,
// remove it from the holes, and return its offset (as a multiple of its size). If there
// is no such space, returns zero (no hole can be at offset zero, as explained above).
......@@ -1126,7 +1126,28 @@ private:
}
case Declaration::UNION:
errorReporter.addErrorOn(member, "Unions cannot contain unions.");
if (member.getName().getValue() == "") {
errorReporter.addErrorOn(member, "Unions cannot contain unnamed unions.");
} else {
parent.childCount++;
// For layout purposes, pretend this union is enclosed in a one-member group.
StructLayout::Group& singletonGroup =
arena.allocate<StructLayout::Group>(layout);
StructLayout::Union& unionLayout = arena.allocate<StructLayout::Union>(singletonGroup);
memberInfo = &arena.allocate<MemberInfo>(
parent, codeOrder++, member,
newGroupNode(parent.node, member.getName().getValue()),
true);
allMembers.add(memberInfo);
memberInfo->unionScope = &unionLayout;
uint subCodeOrder = 0;
traverseUnion(member.getNestedDecls(), *memberInfo, unionLayout, subCodeOrder);
if (member.getId().isOrdinal()) {
ordinal = member.getId().getOrdinal().getValue();
}
}
break;
case Declaration::GROUP: {
......
......@@ -764,6 +764,43 @@ TEST(Encoding, BitListDowngrade) {
checkList(reader.getObjectField<List<uint16_t>>(), {0x1201u, 0x3400u, 0x5601u, 0x7801u});
}
TEST(Encoding, BitListDowngradeFromStruct) {
MallocMessageBuilder builder;
auto root = builder.initRoot<test::TestObject>();
{
auto list = root.initObjectField<List<test::TestLists::Struct1c>>(4);
list[0].setF(true);
list[1].setF(false);
list[2].setF(true);
list[3].setF(true);
}
checkList(root.getObjectField<List<bool>>(), {true, false, true, true});
{
auto l = root.getObjectField<List<test::TestLists::Struct1>>();
ASSERT_EQ(4u, l.size());
EXPECT_TRUE(l[0].getF());
EXPECT_FALSE(l[1].getF());
EXPECT_TRUE(l[2].getF());
EXPECT_TRUE(l[3].getF());
}
auto reader = root.asReader();
checkList(reader.getObjectField<List<bool>>(), {true, false, true, true});
{
auto l = reader.getObjectField<List<test::TestLists::Struct1>>();
ASSERT_EQ(4u, l.size());
EXPECT_TRUE(l[0].getF());
EXPECT_FALSE(l[1].getF());
EXPECT_TRUE(l[2].getF());
EXPECT_TRUE(l[3].getF());
}
}
TEST(Encoding, BitListUpgrade) {
MallocMessageBuilder builder;
auto root = builder.initRoot<test::TestObject>();
......
......@@ -1757,11 +1757,6 @@ struct WireHelpers {
break;
case FieldSize::BIT:
KJ_FAIL_REQUIRE("Expected a bit list, but got a list of structs.") {
goto useDefault;
}
break;
case FieldSize::BYTE:
case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES:
......
......@@ -111,6 +111,7 @@ struct List<T, Kind::PRIMITIVE> {
inline uint size() const { return reader.size() / ELEMENTS; }
inline T operator[](uint index) const {
KJ_IREQUIRE(index < size());
return reader.template getDataElement<T>(index * ELEMENTS);
}
......@@ -141,6 +142,7 @@ struct List<T, Kind::PRIMITIVE> {
inline uint size() const { return builder.size() / ELEMENTS; }
inline T operator[](uint index) {
KJ_IREQUIRE(index < size());
return builder.template getDataElement<T>(index * ELEMENTS);
}
inline void set(uint index, T value) {
......@@ -216,6 +218,7 @@ struct List<T, Kind::STRUCT> {
inline uint size() const { return reader.size() / ELEMENTS; }
inline typename T::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size());
return typename T::Reader(reader.getStructElement(index * ELEMENTS));
}
......@@ -246,6 +249,7 @@ struct List<T, Kind::STRUCT> {
inline uint size() const { return builder.size() / ELEMENTS; }
inline typename T::Builder operator[](uint index) {
KJ_IREQUIRE(index < size());
return typename T::Builder(builder.getStructElement(index * ELEMENTS));
}
......@@ -259,6 +263,8 @@ struct List<T, Kind::STRUCT> {
// using a newer version of the schema that has additional fields -- it will be truncated,
// losing data.
KJ_IREQUIRE(index < size());
// We pass a zero-valued StructSize to asStruct() because we do not want the struct to be
// expanded under any circumstances. We're just going to throw it away anyway, and
// transferContentFrom() already carefully compares the struct sizes before transferring.
......@@ -273,6 +279,7 @@ struct List<T, Kind::STRUCT> {
// using a newer version of the schema that has additional fields -- it will be truncated,
// losing data.
KJ_IREQUIRE(index < size());
builder.getStructElement(index * ELEMENTS).copyContentFrom(reader._reader);
}
......@@ -341,6 +348,7 @@ struct List<List<T>, Kind::LIST> {
inline uint size() const { return reader.size() / ELEMENTS; }
inline typename List<T>::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size());
return typename List<T>::Reader(List<T>::getAsElementOf(reader, index));
}
......@@ -371,15 +379,19 @@ struct List<List<T>, Kind::LIST> {
inline uint size() const { return builder.size() / ELEMENTS; }
inline typename List<T>::Builder operator[](uint index) {
KJ_IREQUIRE(index < size());
return typename List<T>::Builder(List<T>::getAsElementOf(builder, index));
}
inline typename List<T>::Builder init(uint index, uint size) {
KJ_IREQUIRE(index < this->size());
return typename List<T>::Builder(List<T>::initAsElementOf(builder, index, size));
}
inline void set(uint index, typename List<T>::Reader value) {
KJ_IREQUIRE(index < size());
builder.setListElement(index * ELEMENTS, value.reader);
}
void set(uint index, std::initializer_list<ReaderFor<T>> value) {
KJ_IREQUIRE(index < size());
auto l = init(index, value.size());
uint i = 0;
for (auto& element: value) {
......@@ -387,9 +399,11 @@ struct List<List<T>, Kind::LIST> {
}
}
inline void adopt(uint index, Orphan<T>&& value) {
KJ_IREQUIRE(index < size());
builder.adopt(index * ELEMENTS, kj::mv(value));
}
inline Orphan<T> disown(uint index) {
KJ_IREQUIRE(index < size());
return Orphan<T>(builder.disown(index * ELEMENTS));
}
......@@ -451,6 +465,7 @@ struct List<T, Kind::BLOB> {
inline uint size() const { return reader.size() / ELEMENTS; }
inline typename T::Reader operator[](uint index) const {
KJ_IREQUIRE(index < size());
return reader.getBlobElement<T>(index * ELEMENTS);
}
......@@ -481,18 +496,23 @@ struct List<T, Kind::BLOB> {
inline uint size() const { return builder.size() / ELEMENTS; }
inline typename T::Builder operator[](uint index) {
KJ_IREQUIRE(index < size());
return builder.getBlobElement<T>(index * ELEMENTS);
}
inline void set(uint index, typename T::Reader value) {
KJ_IREQUIRE(index < size());
builder.setBlobElement<T>(index * ELEMENTS, value);
}
inline typename T::Builder init(uint index, uint size) {
KJ_IREQUIRE(index < this->size());
return builder.initBlobElement<T>(index * ELEMENTS, size * BYTES);
}
inline void adopt(uint index, Orphan<T>&& value) {
KJ_IREQUIRE(index < size());
builder.adopt(index * ELEMENTS, kj::mv(value));
}
inline Orphan<T> disown(uint index) {
KJ_IREQUIRE(index < size());
return Orphan<T>(builder.disown(index * ELEMENTS));
}
......
......@@ -377,7 +377,8 @@ inline typename RootType::Builder MessageBuilder::initRoot() {
template <typename Reader>
inline void MessageBuilder::setRoot(Reader&& value) {
typedef FromReader<Reader> RootType;
static_assert(kind<RootType>() == Kind::STRUCT, "Root type must be a Cap'n Proto struct type.");
static_assert(kind<RootType>() == Kind::STRUCT,
"Parameter must be a Reader for a Cap'n Proto struct type.");
setRootInternal(value._reader);
}
......
......@@ -442,7 +442,8 @@ private:
}
if (hadCase) {
VALIDATE_SCHEMA(value.which() == expectedValueType, "Value did not match type.");
VALIDATE_SCHEMA(value.which() == expectedValueType, "Value did not match type.",
(uint)value.which(), (uint)expectedValueType);
}
}
......@@ -511,6 +512,9 @@ public:
bool shouldReplace(const schema::Node::Reader& existingNode,
const schema::Node::Reader& replacement,
bool preferReplacementIfEquivalent) {
this->existingNode = existingNode;
this->replacementNode = replacement;
KJ_CONTEXT("checking compatibility with previously-loaded node of the same id",
existingNode.getDisplayName());
......@@ -528,6 +532,8 @@ public:
private:
SchemaLoader::Impl& loader;
Text::Reader nodeName;
schema::Node::Reader existingNode;
schema::Node::Reader replacementNode;
enum Compatibility {
EQUIVALENT,
......@@ -633,6 +639,17 @@ private:
replacementIsOlder();
}
if (replacement.getDiscriminantCount() > structNode.getDiscriminantCount()) {
replacementIsNewer();
} else if (replacement.getDiscriminantCount() < structNode.getDiscriminantCount()) {
replacementIsOlder();
}
if (replacement.getDiscriminantCount() > 0 && structNode.getDiscriminantCount() > 0) {
VALIDATE_SCHEMA(replacement.getDiscriminantOffset() == structNode.getDiscriminantOffset(),
"union discriminant position changed");
}
// The shared members should occupy corresponding positions in the member lists, since the
// lists are sorted by ordinal.
auto fields = structNode.getFields();
......@@ -672,26 +689,48 @@ private:
const schema::Field::Reader& replacement) {
KJ_CONTEXT("comparing struct field", field.getName());
VALIDATE_SCHEMA(field.which() == replacement.which(),
"group field replaced with non-group or vice versa");
// A field that is initially not in a union can be upgraded to be in one, as long as it has
// discriminant 0.
uint discriminant = field.hasDiscriminantValue() ? field.getDiscriminantValue() : 0;
uint replacementDiscriminant =
replacement.hasDiscriminantValue() ? replacement.getDiscriminantValue() : 0;
VALIDATE_SCHEMA(discriminant == replacementDiscriminant, "Field discriminant changed.");
switch (field.which()) {
case schema::Field::NON_GROUP: {
auto nonGroup = field.getNonGroup();
auto replacementNonGroup = replacement.getNonGroup();
checkCompatibility(nonGroup.getType(), replacementNonGroup.getType(),
NO_UPGRADE_TO_STRUCT);
checkDefaultCompatibility(nonGroup.getDefaultValue(),
replacementNonGroup.getDefaultValue());
switch (replacement.which()) {
case schema::Field::NON_GROUP: {
auto replacementNonGroup = replacement.getNonGroup();
checkCompatibility(nonGroup.getType(), replacementNonGroup.getType(),
NO_UPGRADE_TO_STRUCT);
checkDefaultCompatibility(nonGroup.getDefaultValue(),
replacementNonGroup.getDefaultValue());
VALIDATE_SCHEMA(nonGroup.getOffset() == replacementNonGroup.getOffset(),
"field position changed");
break;
}
case schema::Field::GROUP:
checkUpgradeToStruct(nonGroup.getType(), replacement.getGroup(), existingNode, field);
break;
}
VALIDATE_SCHEMA(nonGroup.getOffset() == replacementNonGroup.getOffset(),
"field position changed");
break;
}
case schema::Field::GROUP:
VALIDATE_SCHEMA(field.getGroup() == replacement.getGroup(), "group id changed");
switch (replacement.which()) {
case schema::Field::NON_GROUP:
checkUpgradeToStruct(replacement.getNonGroup().getType(), field.getGroup(),
replacementNode, replacement);
break;
case schema::Field::GROUP:
VALIDATE_SCHEMA(field.getGroup() == replacement.getGroup(), "group id changed");
break;
}
break;
}
}
......@@ -859,7 +898,9 @@ private:
// We assume unknown types (from newer versions of Cap'n Proto?) are equivalent.
}
void checkUpgradeToStruct(const schema::Type::Reader& type, uint64_t structTypeId) {
void checkUpgradeToStruct(const schema::Type::Reader& type, uint64_t structTypeId,
kj::Maybe<schema::Node::Reader> matchSize = nullptr,
kj::Maybe<schema::Field::Reader> matchPosition = nullptr) {
// We can't just look up the target struct and check it because it may not have been loaded
// yet. Instead, we contrive a struct that looks like what we want and load() that, which
// guarantees that any incompatibility will be caught either now or when the real version of
......@@ -929,11 +970,55 @@ private:
break;
}
KJ_IF_MAYBE(s, matchSize) {
auto match = s->getStruct();
structNode.setDataWordCount(match.getDataWordCount());
structNode.setPointerCount(match.getPointerCount());
structNode.setPreferredListEncoding(match.getPreferredListEncoding());
}
auto field = structNode.initFields(1)[0];
field.setName("member0");
field.getOrdinal().setExplicit(0);
field.setCodeOrder(0);
field.initNonGroup().setType(type);
auto nongroup = field.initNonGroup();
nongroup.setType(type);
KJ_IF_MAYBE(p, matchPosition) {
if (p->getOrdinal().isExplicit()) {
field.getOrdinal().setExplicit(p->getOrdinal().getExplicit());
} else {
field.getOrdinal().setImplicit();
}
auto matchNongroup = p->getNonGroup();
nongroup.setOffset(matchNongroup.getOffset());
nongroup.setDefaultValue(matchNongroup.getDefaultValue());
} else {
field.getOrdinal().setExplicit(0);
nongroup.setOffset(0);
schema::Value::Builder value = nongroup.initDefaultValue();
switch (type.which()) {
case schema::Type::VOID: value.setVoid(); break;
case schema::Type::BOOL: value.setBool(false); break;
case schema::Type::INT8: value.setInt8(0); break;
case schema::Type::INT16: value.setInt16(0); break;
case schema::Type::INT32: value.setInt32(0); break;
case schema::Type::INT64: value.setInt64(0); break;
case schema::Type::UINT8: value.setUint8(0); break;
case schema::Type::UINT16: value.setUint16(0); break;
case schema::Type::UINT32: value.setUint32(0); break;
case schema::Type::UINT64: value.setUint64(0); break;
case schema::Type::FLOAT32: value.setFloat32(0); break;
case schema::Type::FLOAT64: value.setFloat64(0); break;
case schema::Type::ENUM: value.setEnum(0); break;
case schema::Type::TEXT: value.adoptText(Orphan<Text>()); break;
case schema::Type::DATA: value.adoptData(Orphan<Data>()); break;
case schema::Type::LIST: value.adoptList(Orphan<Data>()); break;
case schema::Type::STRUCT: value.adoptStruct(Orphan<Data>()); break;
case schema::Type::INTERFACE: value.setInterface(); break;
case schema::Type::OBJECT: value.adoptObject(Orphan<Data>()); break;
}
}
loader.load(node, true);
}
......
......@@ -269,6 +269,17 @@ struct TestUnnamedUnion {
after @4 :Text;
}
struct TestUnionInUnion {
# There is no reason to ever do this.
outer :union {
inner :union {
foo @0 :Int32;
bar @1 :Int32;
}
baz @2 :Int32;
}
}
struct TestGroups {
groups :union {
foo :group {
......@@ -381,6 +392,15 @@ struct TestLists {
struct Struct64 { f @0 :UInt64; }
struct StructP { f @0 :Text; }
# Versions of the above which cannot be encoded as primitive lists.
struct Struct0c { f @0 :Void; pad @1 :Text; }
struct Struct1c { f @0 :Bool; pad @1 :Text; }
struct Struct8c { f @0 :UInt8; pad @1 :Text; }
struct Struct16c { f @0 :UInt16; pad @1 :Text; }
struct Struct32c { f @0 :UInt32; pad @1 :Text; }
struct Struct64c { f @0 :UInt64; pad @1 :Text; }
struct StructPc { f @0 :Text; pad @1 :UInt64; }
list0 @0 :List(Struct0);
list1 @1 :List(Struct1);
list8 @2 :List(Struct8);
......
......@@ -79,5 +79,45 @@ TEST(Function, Method) {
EXPECT_EQ(9 + 2 + 5, f(2, 9));
}
struct TestConstType {
mutable int callCount;
TestConstType(int callCount = 0): callCount(callCount) {}
~TestConstType() { callCount = 1234; }
// Make sure we catch invalid post-destruction uses.
int foo(int a, int b) const {
return a + b + callCount++;
}
};
TEST(ConstFunction, Method) {
TestConstType obj;
ConstFunction<int(int, int)> f = KJ_BIND_METHOD(obj, foo);
ConstFunction<uint(uint, uint)> f2 = KJ_BIND_METHOD(obj, foo);
EXPECT_EQ(123 + 456, f(123, 456));
EXPECT_EQ(7 + 8 + 1, f(7, 8));
EXPECT_EQ(9u + 2u + 2u, f2(2, 9));
EXPECT_EQ(3, obj.callCount);
// Bind to a temporary.
f = KJ_BIND_METHOD(TestConstType(10), foo);
EXPECT_EQ(123 + 456 + 10, f(123, 456));
EXPECT_EQ(7 + 8 + 11, f(7, 8));
EXPECT_EQ(9 + 2 + 12, f(2, 9));
// Bind to a move.
f = KJ_BIND_METHOD(kj::mv(obj), foo);
obj.callCount = 1234;
EXPECT_EQ(123 + 456 + 3, f(123, 456));
EXPECT_EQ(7 + 8 + 4, f(7, 8));
EXPECT_EQ(9 + 2 + 5, f(2, 9));
}
} // namespace
} // namespace kj
......@@ -84,6 +84,10 @@ class Function;
// Notice how KJ_BIND_METHOD is able to figure out which overload to use depending on the kind of
// Function it is binding to.
template <typename Signature>
class ConstFunction;
// Like Function, but wraps a "const" (i.e. thread-safe) call.
template <typename Return, typename... Params>
class Function<Return(Params...)> {
public:
......@@ -91,10 +95,30 @@ public:
inline Function(F&& f): impl(heap<Impl<F>>(kj::fwd<F>(f))) {}
Function() = default;
// Make sure people don't accidentally end up wrapping a reference when they meant to return
// a function.
KJ_DISALLOW_COPY(Function);
Function(Function&) = delete;
Function& operator=(Function&) = delete;
template <typename T> Function(const Function<T>&) = delete;
template <typename T> Function& operator=(const Function<T>&) = delete;
template <typename T> Function(const ConstFunction<T>&) = delete;
template <typename T> Function& operator=(const ConstFunction<T>&) = delete;
Function(Function&&) = default;
Function& operator=(Function&&) = default;
inline Return operator()(Params... params) {
return (*impl)(kj::fwd<Params>(params)...);
}
Function reference() {
// Forms a new Function of the same type that delegates to this Function by reference.
// Therefore, this Function must outlive the returned Function, but otherwise they behave
// exactly the same.
return *impl;
}
private:
class Iface {
public:
......@@ -117,6 +141,59 @@ private:
Own<Iface> impl;
};
template <typename Return, typename... Params>
class ConstFunction<Return(Params...)> {
public:
template <typename F>
inline ConstFunction(F&& f): impl(heap<Impl<F>>(kj::fwd<F>(f))) {}
ConstFunction() = default;
// Make sure people don't accidentally end up wrapping a reference when they meant to return
// a function.
KJ_DISALLOW_COPY(ConstFunction);
ConstFunction(ConstFunction&) = delete;
ConstFunction& operator=(ConstFunction&) = delete;
template <typename T> ConstFunction(const ConstFunction<T>&) = delete;
template <typename T> ConstFunction& operator=(const ConstFunction<T>&) = delete;
template <typename T> ConstFunction(const Function<T>&) = delete;
template <typename T> ConstFunction& operator=(const Function<T>&) = delete;
ConstFunction(ConstFunction&&) = default;
ConstFunction& operator=(ConstFunction&&) = default;
inline Return operator()(Params... params) const {
return (*impl)(kj::fwd<Params>(params)...);
}
ConstFunction reference() const {
// Forms a new ConstFunction of the same type that delegates to this ConstFunction by reference.
// Therefore, this ConstFunction must outlive the returned ConstFunction, but otherwise they
// behave exactly the same.
return *impl;
}
private:
class Iface {
public:
virtual Return operator()(Params... params) const = 0;
};
template <typename F>
class Impl final: public Iface {
public:
explicit Impl(F&& f): f(kj::fwd<F>(f)) {}
Return operator()(Params... params) const override {
return f(kj::fwd<Params>(params)...);
}
private:
F f;
};
Own<Iface> impl;
};
#if 1
namespace _ { // private
......@@ -137,6 +214,20 @@ private:
T t;
};
template <typename T, typename Return, typename... Params,
Return (Decay<T>::*method)(Params...) const>
class BoundMethod<T, Return (Decay<T>::*)(Params...) const, method> {
public:
BoundMethod(T&& t): t(kj::fwd<T>(t)) {}
Return operator()(Params&&... params) const {
return (t.*method)(kj::fwd<Params>(params)...);
}
private:
T t;
};
} // namespace _ (private)
#define KJ_BIND_METHOD(obj, method) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment