Commit 1dcb66b1 authored by Kenton Varda's avatar Kenton Varda

Continuing schema rewrite WIP.

parent 863afbe2
...@@ -903,6 +903,26 @@ Compiler::Impl::Impl(AnnotationFlag annotationFlag) ...@@ -903,6 +903,26 @@ Compiler::Impl::Impl(AnnotationFlag annotationFlag)
: annotationFlag(annotationFlag), finalLoader(*this), workspace(*this) { : annotationFlag(annotationFlag), finalLoader(*this), workspace(*this) {
// Reflectively interpret the members of Declaration.body. Any member prefixed by "builtin" // Reflectively interpret the members of Declaration.body. Any member prefixed by "builtin"
// defines a builtin declaration visible in the global scope. // defines a builtin declaration visible in the global scope.
#warning "temporary hack for schema transition"
builtinDecls["Void"] = nodeArena.allocateOwn<Node>("Void", Declaration::Body::BUILTIN_VOID);
builtinDecls["Bool"] = nodeArena.allocateOwn<Node>("Bool", Declaration::Body::BUILTIN_BOOL);
builtinDecls["Int8"] = nodeArena.allocateOwn<Node>("Int8", Declaration::Body::BUILTIN_INT8);
builtinDecls["Int16"] = nodeArena.allocateOwn<Node>("Int16", Declaration::Body::BUILTIN_INT16);
builtinDecls["Int32"] = nodeArena.allocateOwn<Node>("Int32", Declaration::Body::BUILTIN_INT32);
builtinDecls["Int64"] = nodeArena.allocateOwn<Node>("Int64", Declaration::Body::BUILTIN_INT64);
builtinDecls["UInt8"] = nodeArena.allocateOwn<Node>("UInt8", Declaration::Body::BUILTIN_U_INT8);
builtinDecls["UInt16"] = nodeArena.allocateOwn<Node>("UInt16", Declaration::Body::BUILTIN_U_INT16);
builtinDecls["UInt32"] = nodeArena.allocateOwn<Node>("UInt32", Declaration::Body::BUILTIN_U_INT32);
builtinDecls["UInt64"] = nodeArena.allocateOwn<Node>("UInt64", Declaration::Body::BUILTIN_U_INT64);
builtinDecls["Float32"] = nodeArena.allocateOwn<Node>("Float32", Declaration::Body::BUILTIN_FLOAT32);
builtinDecls["Float64"] = nodeArena.allocateOwn<Node>("Float64", Declaration::Body::BUILTIN_FLOAT64);
builtinDecls["Text"] = nodeArena.allocateOwn<Node>("Text", Declaration::Body::BUILTIN_TEXT);
builtinDecls["Data"] = nodeArena.allocateOwn<Node>("Data", Declaration::Body::BUILTIN_DATA);
builtinDecls["List"] = nodeArena.allocateOwn<Node>("List", Declaration::Body::BUILTIN_LIST);
builtinDecls["Object"] = nodeArena.allocateOwn<Node>("Object", Declaration::Body::BUILTIN_OBJECT);
#if 0
StructSchema::Union declBodySchema = StructSchema::Union declBodySchema =
Schema::from<Declaration>().getMemberByName("body").asUnion(); Schema::from<Declaration>().getMemberByName("body").asUnion();
for (auto member: declBodySchema.getMembers()) { for (auto member: declBodySchema.getMembers()) {
...@@ -913,6 +933,7 @@ Compiler::Impl::Impl(AnnotationFlag annotationFlag) ...@@ -913,6 +933,7 @@ Compiler::Impl::Impl(AnnotationFlag annotationFlag)
symbolName, static_cast<Declaration::Body::Which>(member.getIndex())); symbolName, static_cast<Declaration::Body::Which>(member.getIndex()));
} }
} }
#endif
} }
Compiler::Impl::~Impl() noexcept(false) {} Compiler::Impl::~Impl() noexcept(false) {}
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "node-translator.h" #include "node-translator.h"
#include "parser.h" // only for generateChildId() #include "parser.h" // only for generateGroupId()
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/arena.h> #include <kj/arena.h>
#include <set> #include <set>
...@@ -906,6 +906,7 @@ public: ...@@ -906,6 +906,7 @@ public:
case Declaration::Body::GROUP_DECL: case Declaration::Body::GROUP_DECL:
member->setDiscriminantOffsetInSchema(); // in case it contains an unnamed union member->setDiscriminantOffsetInSchema(); // in case it contains an unnamed union
member->node.setId(generateGroupId(member->parent->node.getId(), member->index));
targetsFlagName = "targetsGroup"; targetsFlagName = "targetsGroup";
break; break;
...@@ -964,6 +965,9 @@ private: ...@@ -964,6 +965,9 @@ private:
uint codeOrder; uint codeOrder;
// Code order within the parent. // Code order within the parent.
uint index = 0;
// Index within the parent.
uint childCount = 0; uint childCount = 0;
// Number of children this member has. // Number of children this member has.
...@@ -1017,6 +1021,7 @@ private: ...@@ -1017,6 +1021,7 @@ private:
KJ_IF_MAYBE(result, schema) { KJ_IF_MAYBE(result, schema) {
return *result; return *result;
} else { } else {
index = parent->childInitializedCount;
auto builder = parent->addMemberSchema(); auto builder = parent->addMemberSchema();
if (isInUnion) { if (isInUnion) {
builder.setDiscriminantValue(parent->unionDiscriminantCount++); builder.setDiscriminantValue(parent->unionDiscriminantCount++);
...@@ -1189,7 +1194,7 @@ private: ...@@ -1189,7 +1194,7 @@ private:
.newOrphan<schema2::Node>(); .newOrphan<schema2::Node>();
auto node = orphan.get(); auto node = orphan.get();
node.setId(generateChildId(parent.getId(), name)); // We'll set the ID later.
node.setDisplayName(kj::str(parent.getDisplayName(), '.', name)); node.setDisplayName(kj::str(parent.getDisplayName(), '.', name));
node.setDisplayNamePrefixLength(node.getDisplayName().size() - name.size()); node.setDisplayNamePrefixLength(node.getDisplayName().size() - name.size());
node.setScopeId(parent.getId()); node.setScopeId(parent.getId());
......
...@@ -70,6 +70,31 @@ uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName) { ...@@ -70,6 +70,31 @@ uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName) {
return result | (1ull << 63); return result | (1ull << 63);
} }
uint64_t generateGroupId(uint64_t parentId, uint16_t groupIndex) {
// Compute ID by MD5 hashing the concatenation of the parent ID and the group index, and
// then taking the first 8 bytes.
kj::byte bytes[sizeof(uint64_t) + sizeof(uint16_t)];
for (uint i = 0; i < sizeof(uint64_t); i++) {
bytes[i] = (parentId >> (i * 8)) & 0xff;
}
for (uint i = 0; i < sizeof(uint16_t); i++) {
bytes[sizeof(uint64_t) + i] = (groupIndex >> (i * 8)) & 0xff;
}
Md5 md5;
md5.update(bytes);
kj::ArrayPtr<const kj::byte> resultBytes = md5.finish();
uint64_t result = 0;
for (uint i = 0; i < sizeof(uint64_t); i++) {
result = (result << 8) | resultBytes[i];
}
return result | (1ull << 63);
}
void parseFile(List<Statement>::Reader statements, ParsedFile::Builder result, void parseFile(List<Statement>::Reader statements, ParsedFile::Builder result,
const ErrorReporter& errorReporter) { const ErrorReporter& errorReporter) {
CapnpParser parser(Orphanage::getForMessageContaining(result), errorReporter); CapnpParser parser(Orphanage::getForMessageContaining(result), errorReporter);
...@@ -858,6 +883,8 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, const ErrorReporter& errorRep ...@@ -858,6 +883,8 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, const ErrorReporter& errorRep
DynamicStruct::Builder dynamicBuilder = builder; DynamicStruct::Builder dynamicBuilder = builder;
for (auto& maybeTarget: targets.value) { for (auto& maybeTarget: targets.value) {
KJ_IF_MAYBE(target, maybeTarget) { KJ_IF_MAYBE(target, maybeTarget) {
#warning "temporary hack for schema transition"
#if 0
if (target->value == "*") { if (target->value == "*") {
// Set all. // Set all.
if (targets.value.size() > 1) { if (targets.value.size() > 1) {
...@@ -892,6 +919,7 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, const ErrorReporter& errorRep ...@@ -892,6 +919,7 @@ CapnpParser::CapnpParser(Orphanage orphanageParam, const ErrorReporter& errorRep
} }
} }
} }
#endif
} }
} }
return DeclParserResult(kj::mv(decl)); return DeclParserResult(kj::mv(decl));
......
...@@ -45,8 +45,11 @@ uint64_t generateRandomId(); ...@@ -45,8 +45,11 @@ uint64_t generateRandomId();
uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName); uint64_t generateChildId(uint64_t parentId, kj::StringPtr childName);
// Generate the ID for a child node given its parent ID and name. // Generate the ID for a child node given its parent ID and name.
uint64_t generateGroupId(uint64_t parentId, uint16_t groupIndex);
// Generate the ID for a group within a struct.
// //
// TODO(cleanup): Move generateRandomId() and generateChildId() somewhere more sensible. // TODO(cleanup): Move generate*Id() somewhere more sensible.
class CapnpParser { class CapnpParser {
// Advanced parser interface. This interface exposes the inner parsers so that you can embed // Advanced parser interface. This interface exposes the inner parsers so that you can embed
......
...@@ -858,7 +858,7 @@ DynamicValue::Builder DynamicStruct::Builder::init(kj::StringPtr name, uint size ...@@ -858,7 +858,7 @@ DynamicValue::Builder DynamicStruct::Builder::init(kj::StringPtr name, uint size
return init(schema.getFieldByName(name), size); return init(schema.getFieldByName(name), size);
} }
void DynamicStruct::Builder::clear(kj::StringPtr name) { void DynamicStruct::Builder::clear(kj::StringPtr name) {
return clear(schema.getFieldByName(name)); clear(schema.getFieldByName(name));
} }
DynamicStruct::Builder DynamicStruct::Builder::getObject( DynamicStruct::Builder DynamicStruct::Builder::getObject(
kj::StringPtr name, StructSchema type) { kj::StringPtr name, StructSchema type) {
...@@ -1332,16 +1332,22 @@ Void DynamicValue::Builder::AsImpl<Void>::apply(Builder& builder) { ...@@ -1332,16 +1332,22 @@ Void DynamicValue::Builder::AsImpl<Void>::apply(Builder& builder) {
template <> template <>
DynamicStruct::Reader MessageReader::getRoot<DynamicStruct>(StructSchema schema) { DynamicStruct::Reader MessageReader::getRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Reader(schema, getRootInternal()); return DynamicStruct::Reader(schema, getRootInternal());
} }
template <> template <>
DynamicStruct::Builder MessageBuilder::initRoot<DynamicStruct>(StructSchema schema) { DynamicStruct::Builder MessageBuilder::initRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Builder(schema, initRoot(structSizeFromSchema(schema))); return DynamicStruct::Builder(schema, initRoot(structSizeFromSchema(schema)));
} }
template <> template <>
DynamicStruct::Builder MessageBuilder::getRoot<DynamicStruct>(StructSchema schema) { DynamicStruct::Builder MessageBuilder::getRoot<DynamicStruct>(StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Can't use group type as the root of a message.");
return DynamicStruct::Builder(schema, getRoot(structSizeFromSchema(schema))); return DynamicStruct::Builder(schema, getRoot(structSizeFromSchema(schema)));
} }
...@@ -1349,19 +1355,27 @@ namespace _ { // private ...@@ -1349,19 +1355,27 @@ namespace _ { // private
DynamicStruct::Reader PointerHelpers<DynamicStruct, Kind::UNKNOWN>::getDynamic( DynamicStruct::Reader PointerHelpers<DynamicStruct, Kind::UNKNOWN>::getDynamic(
StructReader reader, WirePointerCount index, StructSchema schema) { StructReader reader, WirePointerCount index, StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Cannot form pointer to group type.");
return DynamicStruct::Reader(schema, reader.getStructField(index, nullptr)); return DynamicStruct::Reader(schema, reader.getStructField(index, nullptr));
} }
DynamicStruct::Builder PointerHelpers<DynamicStruct, Kind::UNKNOWN>::getDynamic( DynamicStruct::Builder PointerHelpers<DynamicStruct, Kind::UNKNOWN>::getDynamic(
StructBuilder builder, WirePointerCount index, StructSchema schema) { StructBuilder builder, WirePointerCount index, StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Cannot form pointer to group type.");
return DynamicStruct::Builder(schema, builder.getStructField( return DynamicStruct::Builder(schema, builder.getStructField(
index, structSizeFromSchema(schema), nullptr)); index, structSizeFromSchema(schema), nullptr));
} }
void PointerHelpers<DynamicStruct, Kind::UNKNOWN>::set( void PointerHelpers<DynamicStruct, Kind::UNKNOWN>::set(
StructBuilder builder, WirePointerCount index, const DynamicStruct::Reader& value) { StructBuilder builder, WirePointerCount index, const DynamicStruct::Reader& value) {
KJ_REQUIRE(!value.schema.getProto().getStruct().getIsGroup(),
"Cannot form pointer to group type.");
builder.setStructField(index, value.reader); builder.setStructField(index, value.reader);
} }
DynamicStruct::Builder PointerHelpers<DynamicStruct, Kind::UNKNOWN>::init( DynamicStruct::Builder PointerHelpers<DynamicStruct, Kind::UNKNOWN>::init(
StructBuilder builder, WirePointerCount index, StructSchema schema) { StructBuilder builder, WirePointerCount index, StructSchema schema) {
KJ_REQUIRE(!schema.getProto().getStruct().getIsGroup(),
"Cannot form pointer to group type.");
return DynamicStruct::Builder(schema, return DynamicStruct::Builder(schema,
builder.initStructField(index, structSizeFromSchema(schema))); builder.initStructField(index, structSizeFromSchema(schema)));
} }
......
...@@ -167,6 +167,7 @@ struct RawSchema { ...@@ -167,6 +167,7 @@ struct RawSchema {
uint16_t value; uint16_t value;
inline operator uint16_t() const { return value; } inline operator uint16_t() const { return value; }
MemberInfo() = default;
constexpr MemberInfo(uint16_t value): value(value) {} constexpr MemberInfo(uint16_t value): value(value) {}
constexpr MemberInfo(uint16_t value, uint16_t dummy): value(value) {} constexpr MemberInfo(uint16_t value, uint16_t dummy): value(value) {}
}; };
......
...@@ -234,7 +234,7 @@ template <typename Reader> ...@@ -234,7 +234,7 @@ template <typename Reader>
void copyToUnchecked(Reader&& reader, kj::ArrayPtr<word> uncheckedBuffer); void copyToUnchecked(Reader&& reader, kj::ArrayPtr<word> uncheckedBuffer);
// Copy the content of the given reader into the given buffer, such that it can safely be passed to // Copy the content of the given reader into the given buffer, such that it can safely be passed to
// readMessageUnchecked(). The buffer's size must be exactly reader.totalSizeInWords() + 1, // readMessageUnchecked(). The buffer's size must be exactly reader.totalSizeInWords() + 1,
// otherwise an exception will be thrown. // otherwise an exception will be thrown. The buffer must be zero'd before calling.
template <typename Type> template <typename Type>
static typename Type::Reader defaultValue(); static typename Type::Reader defaultValue();
...@@ -332,7 +332,7 @@ private: ...@@ -332,7 +332,7 @@ private:
class FlatMessageBuilder: public MessageBuilder { class FlatMessageBuilder: public MessageBuilder {
// A message builder implementation which allocates from a single flat array, throwing an // A message builder implementation which allocates from a single flat array, throwing an
// exception if it runs out of space. // exception if it runs out of space. The array must be zero'd before use.
public: public:
explicit FlatMessageBuilder(kj::ArrayPtr<word> array); explicit FlatMessageBuilder(kj::ArrayPtr<word> array);
......
...@@ -56,11 +56,11 @@ public: ...@@ -56,11 +56,11 @@ public:
inline Impl(const SchemaLoader& loader, const LazyLoadCallback& callback) inline Impl(const SchemaLoader& loader, const LazyLoadCallback& callback)
: initializer(loader, callback) {} : initializer(loader, callback) {}
_::RawSchema* load(const schema::Node::Reader& reader, bool isPlaceholder); _::RawSchema* load(const schema2::Node::Reader& reader, bool isPlaceholder);
_::RawSchema* loadNative(const _::RawSchema* nativeSchema); _::RawSchema* loadNative(const _::RawSchema* nativeSchema);
_::RawSchema* loadEmpty(uint64_t id, kj::StringPtr name, schema::Node::Body::Which kind); _::RawSchema* loadEmpty(uint64_t id, kj::StringPtr name, schema2::Node::Which kind);
// Create a dummy empty schema of the given kind for the given id and load it. // Create a dummy empty schema of the given kind for the given id and load it.
struct TryGetResult { struct TryGetResult {
...@@ -71,45 +71,82 @@ public: ...@@ -71,45 +71,82 @@ public:
TryGetResult tryGet(uint64_t typeId) const; TryGetResult tryGet(uint64_t typeId) const;
kj::Array<Schema> getAllLoaded() const; kj::Array<Schema> getAllLoaded() const;
void requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount);
// Require any struct nodes loaded with this ID -- in the past and in the future -- to have at
// least the given sizes. Struct nodes that don't comply will simply be rewritten to comply.
// This is used to ensure that parents of group nodes have at least the size of the group node,
// so that allocating a struct that contains a group then getting the group node and setting
// its fields can't possibly write outside of the allocated space.
kj::Arena arena; kj::Arena arena;
private: private:
std::unordered_map<uint64_t, _::RawSchema*> schemas; std::unordered_map<uint64_t, _::RawSchema*> schemas;
struct RequiredSize {
uint16_t dataWordCount;
uint16_t pointerCount;
};
std::unordered_map<uint64_t, RequiredSize> structSizeRequirements;
InitializerImpl initializer; InitializerImpl initializer;
kj::ArrayPtr<word> makeUncheckedNode(schema2::Node::Reader node);
// Construct a copy of the given schema node, allocated as a single-segment ("unchecked") node
// within the loader's arena.
kj::ArrayPtr<word> makeUncheckedNodeEnforcingSizeRequirements(schema2::Node::Reader node);
// Like makeUncheckedNode() but if structSizeRequirements has a requirement for this node which
// is larger than the node claims to be, the size will be edited to comply. This should be rare.
// If the incoming node is not a struct, any struct size requirements will be ignored, but if
// such requirements exist, this indicates an inconsistency that could cause exceptions later on
// (but at least can't cause memory corruption).
kj::ArrayPtr<word> rewriteStructNodeWithSizes(
schema2::Node::Reader node, uint dataWordCount, uint pointerCount);
// Make a copy of the given node (which must be a struct node) and set its sizes to be the max
// of what it said already and the given sizes.
// If the encoded node does not meet the given struct size requirements, make a new copy that
// does.
void applyStructSizeRequirement(_::RawSchema* raw, uint dataWordCount, uint pointerCount);
}; };
// ======================================================================================= // =======================================================================================
inline static void verifyVoid(Void value) {}
// Calls to this will break if the parameter type changes to non-void. We use this to detect
// when the code needs updating.
class SchemaLoader::Validator { class SchemaLoader::Validator {
public: public:
Validator(SchemaLoader::Impl& loader): loader(loader) {} Validator(SchemaLoader::Impl& loader): loader(loader) {}
bool validate(const schema::Node::Reader& node) { bool validate(const schema2::Node::Reader& node) {
isValid = true; isValid = true;
nodeName = node.getDisplayName(); nodeName = node.getDisplayName();
dependencies.clear(); dependencies.clear();
KJ_CONTEXT("validating schema node", nodeName, (uint)node.getBody().which()); KJ_CONTEXT("validating schema node", nodeName, (uint)node.which());
switch (node.getBody().which()) { switch (node.which()) {
case schema::Node::Body::FILE_NODE: case schema2::Node::FILE:
validate(node.getBody().getFileNode()); verifyVoid(node.getFile());
break; break;
case schema::Node::Body::STRUCT_NODE: case schema2::Node::STRUCT:
validate(node.getBody().getStructNode()); validate(node.getStruct(), node.getScopeId());
break; break;
case schema::Node::Body::ENUM_NODE: case schema2::Node::ENUM:
validate(node.getBody().getEnumNode()); validate(node.getEnum());
break; break;
case schema::Node::Body::INTERFACE_NODE: case schema2::Node::INTERFACE:
validate(node.getBody().getInterfaceNode()); validate(node.getInterface());
break; break;
case schema::Node::Body::CONST_NODE: case schema2::Node::CONST:
validate(node.getBody().getConstNode()); validate(node.getConst());
break; break;
case schema::Node::Body::ANNOTATION_NODE: case schema2::Node::ANNOTATION:
validate(node.getBody().getAnnotationNode()); validate(node.getAnnotation());
break; break;
} }
...@@ -135,96 +172,76 @@ public: ...@@ -135,96 +172,76 @@ public:
loader.arena.allocateArray<_::RawSchema::MemberInfo>(*count); loader.arena.allocateArray<_::RawSchema::MemberInfo>(*count);
uint pos = 0; uint pos = 0;
for (auto& member: members) { for (auto& member: members) {
result[pos++] = {kj::implicitCast<uint16_t>(member.first.first), result[pos++] = member.second;
kj::implicitCast<uint16_t>(member.second)};
} }
KJ_DASSERT(pos == *count); KJ_DASSERT(pos == *count);
return result.begin(); return result.begin();
} }
const uint16_t* makeMembersByDiscriminantArray() {
return membersByDiscriminant.begin();
}
private: private:
SchemaLoader::Impl& loader; SchemaLoader::Impl& loader;
Text::Reader nodeName; Text::Reader nodeName;
bool isValid; bool isValid;
std::map<uint64_t, _::RawSchema*> dependencies; std::map<uint64_t, _::RawSchema*> dependencies;
// Maps (scopeOrdinal, name) -> index for each member. // Maps name -> index for each member.
std::map<std::pair<uint, Text::Reader>, uint> members; std::map<Text::Reader, uint> members;
kj::ArrayPtr<uint16_t> membersByDiscriminant;
#define VALIDATE_SCHEMA(condition, ...) \ #define VALIDATE_SCHEMA(condition, ...) \
KJ_REQUIRE(condition, ##__VA_ARGS__) { isValid = false; return; } KJ_REQUIRE(condition, ##__VA_ARGS__) { isValid = false; return; }
#define FAIL_VALIDATE_SCHEMA(...) \ #define FAIL_VALIDATE_SCHEMA(...) \
KJ_FAIL_REQUIRE(__VA_ARGS__) { isValid = false; return; } KJ_FAIL_REQUIRE(__VA_ARGS__) { isValid = false; return; }
void validate(const schema::FileNode::Reader& fileNode) { void validateMemberName(kj::StringPtr name, uint index) {
// Nothing needs validation. bool isNewName = members.insert(std::make_pair(name, index)).second;
} VALIDATE_SCHEMA(isNewName, "duplicate name", name);
uint countOrdinals(const List<schema::StructNode::Member>::Reader& members) {
uint result = 0;
for (auto member: members) {
switch (member.getBody().which()) {
case schema::StructNode::Member::Body::FIELD_MEMBER:
++result;
break;
case schema::StructNode::Member::Body::UNION_MEMBER: {
auto uMembers = member.getBody().getUnionMember().getMembers();
if (uMembers.size() == 0 || member.getOrdinal() != uMembers[0].getOrdinal()) {
// Union has explicit ordinal.
++result;
}
result += countOrdinals(uMembers);
break;
}
case schema::StructNode::Member::Body::GROUP_MEMBER:
result += countOrdinals(member.getBody().getGroupMember().getMembers());
break;
}
}
return result;
} }
void validate(const schema::StructNode::Reader& structNode) { void validate(const schema2::Node::Struct::Reader& structNode, uint64_t scopeId) {
uint dataSizeInBits; uint dataSizeInBits;
uint pointerCount; uint pointerCount;
switch (structNode.getPreferredListEncoding()) { switch (structNode.getPreferredListEncoding()) {
case schema::ElementSize::EMPTY: case schema2::ElementSize::EMPTY:
dataSizeInBits = 0; dataSizeInBits = 0;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::BIT: case schema2::ElementSize::BIT:
dataSizeInBits = 1; dataSizeInBits = 1;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::BYTE: case schema2::ElementSize::BYTE:
dataSizeInBits = 8; dataSizeInBits = 8;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::TWO_BYTES: case schema2::ElementSize::TWO_BYTES:
dataSizeInBits = 16; dataSizeInBits = 16;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::FOUR_BYTES: case schema2::ElementSize::FOUR_BYTES:
dataSizeInBits = 32; dataSizeInBits = 32;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::EIGHT_BYTES: case schema2::ElementSize::EIGHT_BYTES:
dataSizeInBits = 64; dataSizeInBits = 64;
pointerCount = 0; pointerCount = 0;
break; break;
case schema::ElementSize::POINTER: case schema2::ElementSize::POINTER:
dataSizeInBits = 0; dataSizeInBits = 0;
pointerCount = 1; pointerCount = 1;
break; break;
case schema::ElementSize::INLINE_COMPOSITE: case schema2::ElementSize::INLINE_COMPOSITE:
dataSizeInBits = structNode.getDataSectionWordSize() * 64; dataSizeInBits = structNode.getDataSectionWordSize() * 64;
pointerCount = structNode.getPointerSectionSize(); pointerCount = structNode.getPointerSectionSize();
break; break;
default: default:
FAIL_VALIDATE_SCHEMA("Invalid preferredListEncoding."); FAIL_VALIDATE_SCHEMA("invalid preferredListEncoding");
dataSizeInBits = 0; dataSizeInBits = 0;
pointerCount = 0; pointerCount = 0;
break; break;
...@@ -232,139 +249,111 @@ private: ...@@ -232,139 +249,111 @@ private:
VALIDATE_SCHEMA(structNode.getDataSectionWordSize() == (dataSizeInBits + 63) / 64 && VALIDATE_SCHEMA(structNode.getDataSectionWordSize() == (dataSizeInBits + 63) / 64 &&
structNode.getPointerSectionSize() == pointerCount, structNode.getPointerSectionSize() == pointerCount,
"Struct size does not match preferredListEncoding."); "struct size does not match preferredListEncoding");
auto members = structNode.getMembers(); auto fields = structNode.getFields();
uint ordinalCount = countOrdinals(members);
KJ_STACK_ARRAY(bool, sawCodeOrder, members.size(), 32, 256); KJ_STACK_ARRAY(bool, sawCodeOrder, fields.size(), 32, 256);
memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0])); memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0]));
KJ_STACK_ARRAY(bool, sawOrdinal, ordinalCount, 32, 256);
memset(sawOrdinal.begin(), 0, sawOrdinal.size() * sizeof(sawOrdinal[0]));
uint index = 0; KJ_STACK_ARRAY(bool, sawDiscriminantValue, structNode.getDiscriminantCount(), 32, 256);
for (auto member: members) { memset(sawDiscriminantValue.begin(), 0,
KJ_CONTEXT("validating struct member", member.getName()); sawDiscriminantValue.size() * sizeof(sawDiscriminantValue[0]));
validate(member, sawCodeOrder, sawOrdinal, dataSizeInBits, pointerCount, 0, members.size(),
index++); if (structNode.getDiscriminantCount() > 0) {
VALIDATE_SCHEMA(structNode.getDiscriminantCount() != 1,
"union must have at least two members");
VALIDATE_SCHEMA(structNode.getDiscriminantCount() <= fields.size(),
"struct can't have more union fields than total fields");
VALIDATE_SCHEMA((structNode.getDiscriminantOffset() + 1) * 16 <= dataSizeInBits,
"union discriminant is out-of-bounds");
} }
}
void validateMemberName(kj::StringPtr name, uint scopeOrdinal, uint adjustedIndex) { membersByDiscriminant = loader.arena.allocateArray<uint16_t>(fields.size());
bool isNewName = members.insert(std::make_pair( uint discriminantPos = 0;
std::pair<uint, Text::Reader>(scopeOrdinal, name), adjustedIndex)).second; uint nonDiscriminantPos = structNode.getDiscriminantCount();
VALIDATE_SCHEMA(isNewName, "duplicate name", name);
}
void validate(const schema::StructNode::Member::Reader& member, uint index = 0;
kj::ArrayPtr<bool> sawCodeOrder, kj::ArrayPtr<bool> sawOrdinal, uint nextOrdinal = 0;
uint dataSizeInBits, uint pointerCount, for (auto field: fields) {
uint scopeOrdinal, uint scopeMemberCount, uint adjustedIndex) { KJ_CONTEXT("validating struct field", field.getName());
validateMemberName(member.getName(), scopeOrdinal, adjustedIndex);
VALIDATE_SCHEMA(member.getCodeOrder() < sawCodeOrder.size() &&
!sawCodeOrder[member.getCodeOrder()],
"Invalid codeOrder.");
sawCodeOrder[member.getCodeOrder()] = true;
switch (member.getBody().which()) {
case schema::StructNode::Member::Body::FIELD_MEMBER: {
VALIDATE_SCHEMA(member.getOrdinal() < sawOrdinal.size() &&
!sawOrdinal[member.getOrdinal()],
"Invalid ordinal.", member.getOrdinal());
sawOrdinal[member.getOrdinal()] = true;
auto field = member.getBody().getFieldMember();
uint fieldBits;
bool fieldIsPointer;
validate(field.getType(), field.getDefaultValue(), &fieldBits, &fieldIsPointer);
VALIDATE_SCHEMA(fieldBits * (field.getOffset() + 1) <= dataSizeInBits &&
fieldIsPointer * (field.getOffset() + 1) <= pointerCount,
"field offset out-of-bounds",
field.getOffset(), dataSizeInBits, pointerCount);
break;
}
case schema::StructNode::Member::Body::UNION_MEMBER: { validateMemberName(field.getName(), index);
auto u = member.getBody().getUnionMember(); VALIDATE_SCHEMA(field.getCodeOrder() < sawCodeOrder.size() &&
!sawCodeOrder[field.getCodeOrder()],
VALIDATE_SCHEMA((u.getDiscriminantOffset() + 1) * 16 <= dataSizeInBits, "invalid codeOrder");
"Schema invalid: Union discriminant out-of-bounds."); sawCodeOrder[field.getCodeOrder()] = true;
auto uMembers = u.getMembers();
VALIDATE_SCHEMA(uMembers.size() >= 2, "Union must have at least two members.");
KJ_STACK_ARRAY(bool, uSawCodeOrder, uMembers.size(), 32, 256);
memset(uSawCodeOrder.begin(), 0, uSawCodeOrder.size() * sizeof(uSawCodeOrder[0]));
uint subIndex = 0;
for (auto uMember: uMembers) {
KJ_CONTEXT("validating union member", uMember.getName());
VALIDATE_SCHEMA(
uMember.getBody().which() != schema::StructNode::Member::Body::UNION_MEMBER,
"Union members must be fields or groups.");
uint subScopeOrdinal;
uint indexAdjustment;
if (member.getName().size() == 0) {
subScopeOrdinal = scopeOrdinal;
indexAdjustment = scopeMemberCount;
} else {
subScopeOrdinal = member.getOrdinal() + 1;
indexAdjustment = 0;
}
validate(uMember, uSawCodeOrder, sawOrdinal, dataSizeInBits, pointerCount,
subScopeOrdinal, uMembers.size(), subIndex++ + indexAdjustment);
}
// Union ordinal may match the ordinal of its first member, meaning it was unspecified in auto ordinal = field.getOrdinal();
// the schema file. Otherwise, it must be unique. if (ordinal.which() == schema2::Field::Ordinal::EXPLICIT) {
if (member.getOrdinal() != uMembers[0].getOrdinal()) { VALIDATE_SCHEMA(ordinal.getExplicit() >= nextOrdinal,
VALIDATE_SCHEMA(member.getOrdinal() < sawOrdinal.size() && "fields were not ordered by ordinal");
!sawOrdinal[member.getOrdinal()], nextOrdinal = ordinal.getExplicit() + 1;
"Invalid ordinal.", member.getOrdinal());
sawOrdinal[member.getOrdinal()] = true;
}
break;
} }
case schema::StructNode::Member::Body::GROUP_MEMBER: { if (field.hasDiscriminantValue()) {
auto g = member.getBody().getGroupMember(); VALIDATE_SCHEMA(field.getDiscriminantValue() < sawDiscriminantValue.size() &&
!sawDiscriminantValue[field.getDiscriminantValue()],
auto gMembers = g.getMembers(); "invalid discriminantValue");
VALIDATE_SCHEMA(gMembers.size() >= 2, "Group must have at least two members."); sawDiscriminantValue[field.getDiscriminantValue()] = true;
membersByDiscriminant[discriminantPos++] = index;
} else {
VALIDATE_SCHEMA(nonDiscriminantPos <= fields.size(),
"discriminantCount did not match fields");
membersByDiscriminant[nonDiscriminantPos++] = index;
}
KJ_STACK_ARRAY(bool, uSawCodeOrder, gMembers.size(), 32, 256); switch (field.which()) {
memset(uSawCodeOrder.begin(), 0, uSawCodeOrder.size() * sizeof(uSawCodeOrder[0])); case schema2::Field::REGULAR: {
auto regularField = field.getRegular();
uint subIndex = 0; uint fieldBits;
for (auto gMember: gMembers) { bool fieldIsPointer;
KJ_CONTEXT("validating group member", gMember.getName()); validate(regularField.getType(), regularField.getDefaultValue(),
VALIDATE_SCHEMA( &fieldBits, &fieldIsPointer);
gMember.getBody().which() != schema::StructNode::Member::Body::GROUP_MEMBER, VALIDATE_SCHEMA(fieldBits * (regularField.getOffset() + 1) <= dataSizeInBits &&
"Group members must be fields or unions."); fieldIsPointer * (regularField.getOffset() + 1) <= pointerCount,
"field offset out-of-bounds",
regularField.getOffset(), dataSizeInBits, pointerCount);
validate(gMember, uSawCodeOrder, sawOrdinal, dataSizeInBits, pointerCount, break;
member.getOrdinal() + 1, gMembers.size(), subIndex++);
} }
// Group ordinal must match the ordinal of its first member. case schema2::Field::GROUP:
VALIDATE_SCHEMA(member.getOrdinal() == gMembers[0].getOrdinal(), // Require that the group is a struct node.
"Invalid ordinal.", member.getOrdinal()); validateTypeId(field.getGroup(), schema2::Node::STRUCT);
break; break;
} }
++index;
} }
}
void validate(const schema::EnumNode::Reader& enumNode) { // If the above code is correct, these should pass.
auto enumerants = enumNode.getEnumerants(); KJ_ASSERT(discriminantPos == structNode.getDiscriminantCount());
KJ_ASSERT(nonDiscriminantPos == fields.size());
if (structNode.getIsGroup()) {
VALIDATE_SCHEMA(scopeId != 0, "group node missing scopeId");
// Require that the group's scope has at least the same size as the group, so that anyone
// constructing an instance of the outer scope can safely read/write the group.
loader.requireStructSize(scopeId, structNode.getDataSectionWordSize(),
structNode.getPointerSectionSize());
// Require that the parent type is a struct.
validateTypeId(scopeId, schema2::Node::STRUCT);
}
}
void validate(const List<schema2::Enumerant>::Reader& enumerants) {
KJ_STACK_ARRAY(bool, sawCodeOrder, enumerants.size(), 32, 256); KJ_STACK_ARRAY(bool, sawCodeOrder, enumerants.size(), 32, 256);
memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0])); memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0]));
uint index = 0; uint index = 0;
for (auto enumerant: enumerants) { for (auto enumerant: enumerants) {
validateMemberName(enumerant.getName(), 0, index++); validateMemberName(enumerant.getName(), index++);
VALIDATE_SCHEMA(enumerant.getCodeOrder() < enumerants.size() && VALIDATE_SCHEMA(enumerant.getCodeOrder() < enumerants.size() &&
!sawCodeOrder[enumerant.getCodeOrder()], !sawCodeOrder[enumerant.getCodeOrder()],
...@@ -373,16 +362,14 @@ private: ...@@ -373,16 +362,14 @@ private:
} }
} }
void validate(const schema::InterfaceNode::Reader& interfaceNode) { void validate(const List<schema2::Method>::Reader& methods) {
auto methods = interfaceNode.getMethods();
KJ_STACK_ARRAY(bool, sawCodeOrder, methods.size(), 32, 256); KJ_STACK_ARRAY(bool, sawCodeOrder, methods.size(), 32, 256);
memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0])); memset(sawCodeOrder.begin(), 0, sawCodeOrder.size() * sizeof(sawCodeOrder[0]));
uint index = 0; uint index = 0;
for (auto method: methods) { for (auto method: methods) {
KJ_CONTEXT("validating method", method.getName()); KJ_CONTEXT("validating method", method.getName());
validateMemberName(method.getName(), 0, index++); validateMemberName(method.getName(), index++);
VALIDATE_SCHEMA(method.getCodeOrder() < methods.size() && VALIDATE_SCHEMA(method.getCodeOrder() < methods.size() &&
!sawCodeOrder[method.getCodeOrder()], !sawCodeOrder[method.getCodeOrder()],
...@@ -403,26 +390,26 @@ private: ...@@ -403,26 +390,26 @@ private:
} }
} }
void validate(const schema::ConstNode::Reader& constNode) { void validate(const schema2::Node::Const::Reader& constNode) {
uint dummy1; uint dummy1;
bool dummy2; bool dummy2;
validate(constNode.getType(), constNode.getValue(), &dummy1, &dummy2); validate(constNode.getType(), constNode.getValue(), &dummy1, &dummy2);
} }
void validate(const schema::AnnotationNode::Reader& annotationNode) { void validate(const schema2::Node::Annotation::Reader& annotationNode) {
validate(annotationNode.getType()); validate(annotationNode.getType());
} }
void validate(const schema::Type::Reader& type, const schema::Value::Reader& value, void validate(const schema2::Type::Reader& type, const schema2::Value::Reader& value,
uint* dataSizeInBits, bool* isPointer) { uint* dataSizeInBits, bool* isPointer) {
validate(type); validate(type);
schema::Value::Body::Which expectedValueType = schema::Value::Body::VOID_VALUE; schema2::Value::Which expectedValueType = schema2::Value::VOID;
bool hadCase = false; bool hadCase = false;
switch (type.getBody().which()) { switch (type.which()) {
#define HANDLE_TYPE(name, bits, ptr) \ #define HANDLE_TYPE(name, bits, ptr) \
case schema::Type::Body::name##_TYPE: \ case schema2::Type::name: \
expectedValueType = schema::Value::Body::name##_VALUE; \ expectedValueType = schema2::Value::name; \
*dataSizeInBits = bits; *isPointer = ptr; \ *dataSizeInBits = bits; *isPointer = ptr; \
hadCase = true; \ hadCase = true; \
break; break;
...@@ -449,54 +436,54 @@ private: ...@@ -449,54 +436,54 @@ private:
} }
if (hadCase) { if (hadCase) {
VALIDATE_SCHEMA(value.getBody().which() == expectedValueType, "Value did not match type."); VALIDATE_SCHEMA(value.which() == expectedValueType, "Value did not match type.");
} }
} }
void validate(const schema::Type::Reader& type) { void validate(const schema2::Type::Reader& type) {
switch (type.getBody().which()) { switch (type.which()) {
case schema::Type::Body::VOID_TYPE: case schema2::Type::VOID:
case schema::Type::Body::BOOL_TYPE: case schema2::Type::BOOL:
case schema::Type::Body::INT8_TYPE: case schema2::Type::INT8:
case schema::Type::Body::INT16_TYPE: case schema2::Type::INT16:
case schema::Type::Body::INT32_TYPE: case schema2::Type::INT32:
case schema::Type::Body::INT64_TYPE: case schema2::Type::INT64:
case schema::Type::Body::UINT8_TYPE: case schema2::Type::UINT8:
case schema::Type::Body::UINT16_TYPE: case schema2::Type::UINT16:
case schema::Type::Body::UINT32_TYPE: case schema2::Type::UINT32:
case schema::Type::Body::UINT64_TYPE: case schema2::Type::UINT64:
case schema::Type::Body::FLOAT32_TYPE: case schema2::Type::FLOAT32:
case schema::Type::Body::FLOAT64_TYPE: case schema2::Type::FLOAT64:
case schema::Type::Body::TEXT_TYPE: case schema2::Type::TEXT:
case schema::Type::Body::DATA_TYPE: case schema2::Type::DATA:
case schema::Type::Body::OBJECT_TYPE: case schema2::Type::OBJECT:
break; break;
case schema::Type::Body::STRUCT_TYPE: case schema2::Type::STRUCT:
validateTypeId(type.getBody().getStructType(), schema::Node::Body::STRUCT_NODE); validateTypeId(type.getStruct(), schema2::Node::STRUCT);
break; break;
case schema::Type::Body::ENUM_TYPE: case schema2::Type::ENUM:
validateTypeId(type.getBody().getEnumType(), schema::Node::Body::ENUM_NODE); validateTypeId(type.getEnum(), schema2::Node::ENUM);
break; break;
case schema::Type::Body::INTERFACE_TYPE: case schema2::Type::INTERFACE:
validateTypeId(type.getBody().getInterfaceType(), schema::Node::Body::INTERFACE_NODE); validateTypeId(type.getInterface(), schema2::Node::INTERFACE);
break; break;
case schema::Type::Body::LIST_TYPE: case schema2::Type::LIST:
validate(type.getBody().getListType()); validate(type.getList());
break; break;
} }
// We intentionally allow unknown types. // We intentionally allow unknown types.
} }
void validateTypeId(uint64_t id, schema::Node::Body::Which expectedKind) { void validateTypeId(uint64_t id, schema2::Node::Which expectedKind) {
_::RawSchema* existing = loader.tryGet(id).schema; _::RawSchema* existing = loader.tryGet(id).schema;
if (existing != nullptr) { if (existing != nullptr) {
auto node = readMessageUnchecked<schema::Node>(existing->encodedNode); auto node = readMessageUnchecked<schema2::Node>(existing->encodedNode);
VALIDATE_SCHEMA(node.getBody().which() == expectedKind, VALIDATE_SCHEMA(node.which() == expectedKind,
"expected a different kind of node for this ID", "expected a different kind of node for this ID",
id, (uint)expectedKind, (uint)node.getBody().which(), node.getDisplayName()); id, (uint)expectedKind, (uint)node.which(), node.getDisplayName());
dependencies.insert(std::make_pair(id, existing)); dependencies.insert(std::make_pair(id, existing));
return; return;
} }
...@@ -515,8 +502,8 @@ class SchemaLoader::CompatibilityChecker { ...@@ -515,8 +502,8 @@ class SchemaLoader::CompatibilityChecker {
public: public:
CompatibilityChecker(SchemaLoader::Impl& loader): loader(loader) {} CompatibilityChecker(SchemaLoader::Impl& loader): loader(loader) {}
bool shouldReplace(const schema::Node::Reader& existingNode, bool shouldReplace(const schema2::Node::Reader& existingNode,
const schema::Node::Reader& replacement, const schema2::Node::Reader& replacement,
bool preferReplacementIfEquivalent) { bool preferReplacementIfEquivalent) {
KJ_CONTEXT("checking compatibility with previously-loaded node of the same id", KJ_CONTEXT("checking compatibility with previously-loaded node of the same id",
existingNode.getDisplayName()); existingNode.getDisplayName());
...@@ -581,53 +568,44 @@ private: ...@@ -581,53 +568,44 @@ private:
} }
} }
void checkCompatibility(const schema::Node::Reader& node, void checkCompatibility(const schema2::Node::Reader& node,
const schema::Node::Reader& replacement) { const schema2::Node::Reader& replacement) {
// Returns whether `replacement` is equivalent, older than, newer than, or incompatible with // Returns whether `replacement` is equivalent, older than, newer than, or incompatible with
// `node`. If exceptions are enabled, this will throw an exception on INCOMPATIBLE. // `node`. If exceptions are enabled, this will throw an exception on INCOMPATIBLE.
VALIDATE_SCHEMA(node.getBody().which() == replacement.getBody().which(), VALIDATE_SCHEMA(node.which() == replacement.which(),
"kind of declaration changed"); "kind of declaration changed");
// No need to check compatibility of the non-body parts of the node: // No need to check compatibility of the non-body parts of the node:
// - Arbitrary renaming and moving between scopes is allowed. // - Arbitrary renaming and moving between scopes is allowed.
// - Annotations are ignored for compatibility purposes. // - Annotations are ignored for compatibility purposes.
switch (node.getBody().which()) { switch (node.which()) {
case schema::Node::Body::FILE_NODE: case schema2::Node::FILE:
checkCompatibility(node.getBody().getFileNode(), verifyVoid(node.getFile());
replacement.getBody().getFileNode());
break; break;
case schema::Node::Body::STRUCT_NODE: case schema2::Node::STRUCT:
checkCompatibility(node.getBody().getStructNode(), checkCompatibility(node.getStruct(), replacement.getStruct(),
replacement.getBody().getStructNode()); node.getScopeId(), replacement.getScopeId());
break; break;
case schema::Node::Body::ENUM_NODE: case schema2::Node::ENUM:
checkCompatibility(node.getBody().getEnumNode(), checkCompatibility(node.getEnum(), replacement.getEnum());
replacement.getBody().getEnumNode());
break; break;
case schema::Node::Body::INTERFACE_NODE: case schema2::Node::INTERFACE:
checkCompatibility(node.getBody().getInterfaceNode(), checkCompatibility(node.getInterface(), replacement.getInterface());
replacement.getBody().getInterfaceNode());
break; break;
case schema::Node::Body::CONST_NODE: case schema2::Node::CONST:
checkCompatibility(node.getBody().getConstNode(), checkCompatibility(node.getConst(), replacement.getConst());
replacement.getBody().getConstNode());
break; break;
case schema::Node::Body::ANNOTATION_NODE: case schema2::Node::ANNOTATION:
checkCompatibility(node.getBody().getAnnotationNode(), checkCompatibility(node.getAnnotation(), replacement.getAnnotation());
replacement.getBody().getAnnotationNode());
break; break;
} }
} }
void checkCompatibility(const schema::FileNode::Reader& file, void checkCompatibility(const schema2::Node::Struct::Reader& structNode,
const schema::FileNode::Reader& replacement) { const schema2::Node::Struct::Reader& replacement,
// Nothing to compare. uint64_t scopeId, uint64_t replacementScopeId) {
}
void checkCompatibility(const schema::StructNode::Reader& structNode,
const schema::StructNode::Reader& replacement) {
if (replacement.getDataSectionWordSize() > structNode.getDataSectionWordSize()) { if (replacement.getDataSectionWordSize() > structNode.getDataSectionWordSize()) {
replacementIsNewer(); replacementIsNewer();
} else if (replacement.getDataSectionWordSize() < structNode.getDataSectionWordSize()) { } else if (replacement.getDataSectionWordSize() < structNode.getDataSectionWordSize()) {
...@@ -651,87 +629,71 @@ private: ...@@ -651,87 +629,71 @@ private:
// The shared members should occupy corresponding positions in the member lists, since the // The shared members should occupy corresponding positions in the member lists, since the
// lists are sorted by ordinal. // lists are sorted by ordinal.
auto members = structNode.getMembers(); auto fields = structNode.getFields();
auto replacementMembers = replacement.getMembers(); auto replacementFields = replacement.getFields();
uint count = std::min(members.size(), replacementMembers.size()); uint count = std::min(fields.size(), replacementFields.size());
if (replacementMembers.size() > members.size()) { if (replacementFields.size() > fields.size()) {
replacementIsNewer(); replacementIsNewer();
} else if (replacementMembers.size() < members.size()) { } else if (replacementFields.size() < fields.size()) {
replacementIsOlder(); replacementIsOlder();
} }
for (uint i = 0; i < count; i++) { for (uint i = 0; i < count; i++) {
checkCompatibility(members[i], replacementMembers[i]); checkCompatibility(fields[i], replacementFields[i]);
}
// For the moment, we allow "upgrading" from non-group to group, mainly so that the
// placeholders we generate for group parents (which in the absence of more info, we assume to
// be non-groups) can be replaced with groups.
//
// TODO(cleanup): The placeholder approach is really breaking down. Maybe we need to maintain
// a list of expectations for nodes we haven't loaded yet.
if (structNode.getIsGroup()) {
if (replacement.getIsGroup()) {
VALIDATE_SCHEMA(replacementScopeId == scopeId, "group node's scope changed");
} else {
replacementIsOlder();
}
} else {
if (replacement.getIsGroup()) {
replacementIsNewer();
}
} }
} }
void checkCompatibility(const schema::StructNode::Member::Reader& member, void checkCompatibility(const schema2::Field::Reader& field,
const schema::StructNode::Member::Reader& replacement) { const schema2::Field::Reader& replacement) {
KJ_CONTEXT("comparing struct member", member.getName()); KJ_CONTEXT("comparing struct field", field.getName());
switch (member.getBody().which()) { VALIDATE_SCHEMA(field.which() == replacement.which(),
case schema::StructNode::Member::Body::FIELD_MEMBER: { "group field replaced with non-group or vice versa");
auto field = member.getBody().getFieldMember();
auto replacementField = replacement.getBody().getFieldMember();
checkCompatibility(field.getType(), replacementField.getType(), switch (field.which()) {
case schema2::Field::REGULAR: {
auto regularField = field.getRegular();
auto replacementRegularField = replacement.getRegular();
checkCompatibility(regularField.getType(), replacementRegularField.getType(),
NO_UPGRADE_TO_STRUCT); NO_UPGRADE_TO_STRUCT);
checkDefaultCompatibility(field.getDefaultValue(), replacementField.getDefaultValue()); checkDefaultCompatibility(regularField.getDefaultValue(),
replacementRegularField.getDefaultValue());
VALIDATE_SCHEMA(field.getOffset() == replacementField.getOffset(), VALIDATE_SCHEMA(regularField.getOffset() == replacementRegularField.getOffset(),
"field position changed"); "field position changed");
break; break;
} }
case schema::StructNode::Member::Body::UNION_MEMBER: {
auto existingUnion = member.getBody().getUnionMember();
auto replacementUnion = replacement.getBody().getUnionMember();
VALIDATE_SCHEMA(
existingUnion.getDiscriminantOffset() == replacementUnion.getDiscriminantOffset(),
"union discriminant position changed");
auto members = existingUnion.getMembers();
auto replacementMembers = replacementUnion.getMembers();
uint count = std::min(members.size(), replacementMembers.size());
if (replacementMembers.size() > members.size()) {
replacementIsNewer();
} else if (replacementMembers.size() < members.size()) {
replacementIsOlder();
}
for (uint i = 0; i < count; i++) {
checkCompatibility(members[i], replacementMembers[i]);
}
break;
}
case schema::StructNode::Member::Body::GROUP_MEMBER: {
auto existingGroup = member.getBody().getGroupMember();
auto replacementGroup = replacement.getBody().getGroupMember();
auto members = existingGroup.getMembers();
auto replacementMembers = replacementGroup.getMembers();
uint count = std::min(members.size(), replacementMembers.size());
if (replacementMembers.size() > members.size()) {
replacementIsNewer();
} else if (replacementMembers.size() < members.size()) {
replacementIsOlder();
}
for (uint i = 0; i < count; i++) { case schema2::Field::GROUP:
checkCompatibility(members[i], replacementMembers[i]); VALIDATE_SCHEMA(field.getGroup() == replacement.getGroup(), "group id changed");
}
break; break;
}
} }
} }
void checkCompatibility(const schema::EnumNode::Reader& enumNode, void checkCompatibility(const List<schema2::Enumerant>::Reader& enumerants,
const schema::EnumNode::Reader& replacement) { const List<schema2::Enumerant>::Reader& replacementEnumerants) {
uint size = enumNode.getEnumerants().size(); uint size = enumerants.size();
uint replacementSize = replacement.getEnumerants().size(); uint replacementSize = replacementEnumerants.size();
if (replacementSize > size) { if (replacementSize > size) {
replacementIsNewer(); replacementIsNewer();
} else if (replacementSize < size) { } else if (replacementSize < size) {
...@@ -739,11 +701,8 @@ private: ...@@ -739,11 +701,8 @@ private:
} }
} }
void checkCompatibility(const schema::InterfaceNode::Reader& interfaceNode, void checkCompatibility(const List<schema2::Method>::Reader& methods,
const schema::InterfaceNode::Reader& replacement) { const List<schema2::Method>::Reader& replacementMethods) {
auto methods = interfaceNode.getMethods();
auto replacementMethods = replacement.getMethods();
if (replacementMethods.size() > methods.size()) { if (replacementMethods.size() > methods.size()) {
replacementIsNewer(); replacementIsNewer();
} else if (replacementMethods.size() < methods.size()) { } else if (replacementMethods.size() < methods.size()) {
...@@ -757,8 +716,8 @@ private: ...@@ -757,8 +716,8 @@ private:
} }
} }
void checkCompatibility(const schema::InterfaceNode::Method::Reader& method, void checkCompatibility(const schema2::Method::Reader& method,
const schema::InterfaceNode::Method::Reader& replacement) { const schema2::Method::Reader& replacement) {
KJ_CONTEXT("comparing method", method.getName()); KJ_CONTEXT("comparing method", method.getName());
auto params = method.getParams(); auto params = method.getParams();
...@@ -797,13 +756,13 @@ private: ...@@ -797,13 +756,13 @@ private:
ALLOW_UPGRADE_TO_STRUCT); ALLOW_UPGRADE_TO_STRUCT);
} }
void checkCompatibility(const schema::ConstNode::Reader& constNode, void checkCompatibility(const schema2::Node::Const::Reader& constNode,
const schema::ConstNode::Reader& replacement) { const schema2::Node::Const::Reader& replacement) {
// Who cares? These don't appear on the wire. // Who cares? These don't appear on the wire.
} }
void checkCompatibility(const schema::AnnotationNode::Reader& annotationNode, void checkCompatibility(const schema2::Node::Annotation::Reader& annotationNode,
const schema::AnnotationNode::Reader& replacement) { const schema2::Node::Annotation::Reader& replacement) {
// Who cares? These don't appear on the wire. // Who cares? These don't appear on the wire.
} }
...@@ -812,35 +771,31 @@ private: ...@@ -812,35 +771,31 @@ private:
NO_UPGRADE_TO_STRUCT NO_UPGRADE_TO_STRUCT
}; };
void checkCompatibility(const schema::Type::Reader& type, void checkCompatibility(const schema2::Type::Reader& type,
const schema::Type::Reader& replacement, const schema2::Type::Reader& replacement,
UpgradeToStructMode upgradeToStructMode) { UpgradeToStructMode upgradeToStructMode) {
if (replacement.getBody().which() != type.getBody().which()) { if (replacement.which() != type.which()) {
// Check for allowed "upgrade" to Data or Object. // Check for allowed "upgrade" to Data or Object.
if (replacement.getBody().which() == schema::Type::Body::DATA_TYPE && if (replacement.which() == schema2::Type::DATA && canUpgradeToData(type)) {
canUpgradeToData(type)) {
replacementIsNewer(); replacementIsNewer();
return; return;
} else if (type.getBody().which() == schema::Type::Body::DATA_TYPE && } else if (type.which() == schema2::Type::DATA && canUpgradeToData(replacement)) {
canUpgradeToData(replacement)) {
replacementIsOlder(); replacementIsOlder();
return; return;
} else if (replacement.getBody().which() == schema::Type::Body::OBJECT_TYPE && } else if (replacement.which() == schema2::Type::OBJECT && canUpgradeToObject(type)) {
canUpgradeToObject(type)) {
replacementIsNewer(); replacementIsNewer();
return; return;
} else if (type.getBody().which() == schema::Type::Body::OBJECT_TYPE && } else if (type.which() == schema2::Type::OBJECT && canUpgradeToObject(replacement)) {
canUpgradeToObject(replacement)) {
replacementIsOlder(); replacementIsOlder();
return; return;
} }
if (upgradeToStructMode == ALLOW_UPGRADE_TO_STRUCT) { if (upgradeToStructMode == ALLOW_UPGRADE_TO_STRUCT) {
if (type.getBody().which() == schema::Type::Body::STRUCT_TYPE) { if (type.which() == schema2::Type::STRUCT) {
checkUpgradeToStruct(replacement, type.getBody().getStructType()); checkUpgradeToStruct(replacement, type.getStruct());
return; return;
} else if (replacement.getBody().which() == schema::Type::Body::STRUCT_TYPE) { } else if (replacement.which() == schema2::Type::STRUCT) {
checkUpgradeToStruct(type, replacement.getBody().getStructType()); checkUpgradeToStruct(type, replacement.getStruct());
return; return;
} }
} }
...@@ -848,36 +803,33 @@ private: ...@@ -848,36 +803,33 @@ private:
FAIL_VALIDATE_SCHEMA("a type was changed"); FAIL_VALIDATE_SCHEMA("a type was changed");
} }
switch (type.getBody().which()) { switch (type.which()) {
case schema::Type::Body::VOID_TYPE: case schema2::Type::VOID:
case schema::Type::Body::BOOL_TYPE: case schema2::Type::BOOL:
case schema::Type::Body::INT8_TYPE: case schema2::Type::INT8:
case schema::Type::Body::INT16_TYPE: case schema2::Type::INT16:
case schema::Type::Body::INT32_TYPE: case schema2::Type::INT32:
case schema::Type::Body::INT64_TYPE: case schema2::Type::INT64:
case schema::Type::Body::UINT8_TYPE: case schema2::Type::UINT8:
case schema::Type::Body::UINT16_TYPE: case schema2::Type::UINT16:
case schema::Type::Body::UINT32_TYPE: case schema2::Type::UINT32:
case schema::Type::Body::UINT64_TYPE: case schema2::Type::UINT64:
case schema::Type::Body::FLOAT32_TYPE: case schema2::Type::FLOAT32:
case schema::Type::Body::FLOAT64_TYPE: case schema2::Type::FLOAT64:
case schema::Type::Body::TEXT_TYPE: case schema2::Type::TEXT:
case schema::Type::Body::DATA_TYPE: case schema2::Type::DATA:
case schema::Type::Body::OBJECT_TYPE: case schema2::Type::OBJECT:
return; return;
case schema::Type::Body::LIST_TYPE: case schema2::Type::LIST:
checkCompatibility(type.getBody().getListType(), checkCompatibility(type.getList(), replacement.getList(), ALLOW_UPGRADE_TO_STRUCT);
replacement.getBody().getListType(),
ALLOW_UPGRADE_TO_STRUCT);
return; return;
case schema::Type::Body::ENUM_TYPE: case schema2::Type::ENUM:
VALIDATE_SCHEMA(replacement.getBody().getEnumType() == type.getBody().getEnumType(), VALIDATE_SCHEMA(replacement.getEnum() == type.getEnum(), "type changed enum type");
"type changed enum type");
return; return;
case schema::Type::Body::STRUCT_TYPE: case schema2::Type::STRUCT:
// TODO(someday): If the IDs don't match, we should compare the two structs for // TODO(someday): If the IDs don't match, we should compare the two structs for
// compatibility. This is tricky, though, because the new type's target may not yet be // compatibility. This is tricky, though, because the new type's target may not yet be
// loaded. In that case we could take the old type, make a copy of it, assign the new // loaded. In that case we could take the old type, make a copy of it, assign the new
...@@ -885,21 +837,20 @@ private: ...@@ -885,21 +837,20 @@ private:
// be compatible. However, that has another problem, which is that it could be that the // be compatible. However, that has another problem, which is that it could be that the
// whole reason the type was replaced was to fork that type, and so an incompatibility // whole reason the type was replaced was to fork that type, and so an incompatibility
// could be very much expected. This could be a rat hole... // could be very much expected. This could be a rat hole...
VALIDATE_SCHEMA(replacement.getBody().getStructType() == type.getBody().getStructType(), VALIDATE_SCHEMA(replacement.getStruct() == type.getStruct(),
"type changed to incompatible struct type"); "type changed to incompatible struct type");
return; return;
case schema::Type::Body::INTERFACE_TYPE: case schema2::Type::INTERFACE:
VALIDATE_SCHEMA( VALIDATE_SCHEMA(replacement.getInterface() == type.getInterface(),
replacement.getBody().getInterfaceType() == type.getBody().getInterfaceType(), "type changed to incompatible interface type");
"type changed to incompatible interface type");
return; return;
} }
// We assume unknown types (from newer versions of Cap'n Proto?) are equivalent. // We assume unknown types (from newer versions of Cap'n Proto?) are equivalent.
} }
void checkUpgradeToStruct(const schema::Type::Reader& type, uint64_t structTypeId) { void checkUpgradeToStruct(const schema2::Type::Reader& type, uint64_t structTypeId) {
// We can't just look up the target struct and check it because it may not have been loaded // We can't just look up the target struct and check it because it may not have been loaded
// yet. Instead, we contrive a struct that looks like what we want and load() that, which // yet. Instead, we contrive a struct that looks like what we want and load() that, which
// guarantees that any incompatibility will be caught either now or when the real version of // guarantees that any incompatibility will be caught either now or when the real version of
...@@ -907,84 +858,84 @@ private: ...@@ -907,84 +858,84 @@ private:
word scratch[32]; word scratch[32];
memset(scratch, 0, sizeof(scratch)); memset(scratch, 0, sizeof(scratch));
MallocMessageBuilder builder(kj::arrayPtr(scratch, sizeof(scratch))); MallocMessageBuilder builder(scratch);
auto node = builder.initRoot<schema::Node>(); auto node = builder.initRoot<schema2::Node>();
node.setId(structTypeId); node.setId(structTypeId);
node.setDisplayName(kj::str("(unknown type used in ", nodeName, ")")); node.setDisplayName(kj::str("(unknown type used in ", nodeName, ")"));
auto structNode = node.getBody().initStructNode(); auto structNode = node.initStruct();
switch (type.getBody().which()) { switch (type.which()) {
case schema::Type::Body::VOID_TYPE: case schema2::Type::VOID:
structNode.setDataSectionWordSize(0); structNode.setDataSectionWordSize(0);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::EMPTY); structNode.setPreferredListEncoding(schema2::ElementSize::EMPTY);
break; break;
case schema::Type::Body::BOOL_TYPE: case schema2::Type::BOOL:
structNode.setDataSectionWordSize(1); structNode.setDataSectionWordSize(1);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::BIT); structNode.setPreferredListEncoding(schema2::ElementSize::BIT);
break; break;
case schema::Type::Body::INT8_TYPE: case schema2::Type::INT8:
case schema::Type::Body::UINT8_TYPE: case schema2::Type::UINT8:
structNode.setDataSectionWordSize(1); structNode.setDataSectionWordSize(1);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::BYTE); structNode.setPreferredListEncoding(schema2::ElementSize::BYTE);
break; break;
case schema::Type::Body::INT16_TYPE: case schema2::Type::INT16:
case schema::Type::Body::UINT16_TYPE: case schema2::Type::UINT16:
case schema::Type::Body::ENUM_TYPE: case schema2::Type::ENUM:
structNode.setDataSectionWordSize(1); structNode.setDataSectionWordSize(1);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::TWO_BYTES); structNode.setPreferredListEncoding(schema2::ElementSize::TWO_BYTES);
break; break;
case schema::Type::Body::INT32_TYPE: case schema2::Type::INT32:
case schema::Type::Body::UINT32_TYPE: case schema2::Type::UINT32:
case schema::Type::Body::FLOAT32_TYPE: case schema2::Type::FLOAT32:
structNode.setDataSectionWordSize(1); structNode.setDataSectionWordSize(1);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::FOUR_BYTES); structNode.setPreferredListEncoding(schema2::ElementSize::FOUR_BYTES);
break; break;
case schema::Type::Body::INT64_TYPE: case schema2::Type::INT64:
case schema::Type::Body::UINT64_TYPE: case schema2::Type::UINT64:
case schema::Type::Body::FLOAT64_TYPE: case schema2::Type::FLOAT64:
structNode.setDataSectionWordSize(1); structNode.setDataSectionWordSize(1);
structNode.setPointerSectionSize(0); structNode.setPointerSectionSize(0);
structNode.setPreferredListEncoding(schema::ElementSize::EIGHT_BYTES); structNode.setPreferredListEncoding(schema2::ElementSize::EIGHT_BYTES);
break; break;
case schema::Type::Body::TEXT_TYPE: case schema2::Type::TEXT:
case schema::Type::Body::DATA_TYPE: case schema2::Type::DATA:
case schema::Type::Body::LIST_TYPE: case schema2::Type::LIST:
case schema::Type::Body::STRUCT_TYPE: case schema2::Type::STRUCT:
case schema::Type::Body::INTERFACE_TYPE: case schema2::Type::INTERFACE:
case schema::Type::Body::OBJECT_TYPE: case schema2::Type::OBJECT:
structNode.setDataSectionWordSize(0); structNode.setDataSectionWordSize(0);
structNode.setPointerSectionSize(1); structNode.setPointerSectionSize(1);
structNode.setPreferredListEncoding(schema::ElementSize::POINTER); structNode.setPreferredListEncoding(schema2::ElementSize::POINTER);
break; break;
} }
auto member = structNode.initMembers(1)[0]; auto field = structNode.initFields(1)[0];
member.setName("member0"); field.setName("member0");
member.setOrdinal(0); field.getOrdinal().setExplicit(0);
member.setCodeOrder(0); field.setCodeOrder(0);
member.getBody().initFieldMember().setType(type); field.initRegular().setType(type);
loader.load(node, true); loader.load(node, true);
} }
bool canUpgradeToData(const schema::Type::Reader& type) { bool canUpgradeToData(const schema2::Type::Reader& type) {
if (type.getBody().which() == schema::Type::Body::TEXT_TYPE) { if (type.which() == schema2::Type::TEXT) {
return true; return true;
} else if (type.getBody().which() == schema::Type::Body::LIST_TYPE) { } else if (type.which() == schema2::Type::LIST) {
switch (type.getBody().getListType().getBody().which()) { switch (type.getList().which()) {
case schema::Type::Body::INT8_TYPE: case schema2::Type::INT8:
case schema::Type::Body::UINT8_TYPE: case schema2::Type::UINT8:
return true; return true;
default: default:
return false; return false;
...@@ -994,29 +945,29 @@ private: ...@@ -994,29 +945,29 @@ private:
} }
} }
bool canUpgradeToObject(const schema::Type::Reader& type) { bool canUpgradeToObject(const schema2::Type::Reader& type) {
switch (type.getBody().which()) { switch (type.which()) {
case schema::Type::Body::VOID_TYPE: case schema2::Type::VOID:
case schema::Type::Body::BOOL_TYPE: case schema2::Type::BOOL:
case schema::Type::Body::INT8_TYPE: case schema2::Type::INT8:
case schema::Type::Body::INT16_TYPE: case schema2::Type::INT16:
case schema::Type::Body::INT32_TYPE: case schema2::Type::INT32:
case schema::Type::Body::INT64_TYPE: case schema2::Type::INT64:
case schema::Type::Body::UINT8_TYPE: case schema2::Type::UINT8:
case schema::Type::Body::UINT16_TYPE: case schema2::Type::UINT16:
case schema::Type::Body::UINT32_TYPE: case schema2::Type::UINT32:
case schema::Type::Body::UINT64_TYPE: case schema2::Type::UINT64:
case schema::Type::Body::FLOAT32_TYPE: case schema2::Type::FLOAT32:
case schema::Type::Body::FLOAT64_TYPE: case schema2::Type::FLOAT64:
case schema::Type::Body::ENUM_TYPE: case schema2::Type::ENUM:
return false; return false;
case schema::Type::Body::TEXT_TYPE: case schema2::Type::TEXT:
case schema::Type::Body::DATA_TYPE: case schema2::Type::DATA:
case schema::Type::Body::LIST_TYPE: case schema2::Type::LIST:
case schema::Type::Body::STRUCT_TYPE: case schema2::Type::STRUCT:
case schema::Type::Body::INTERFACE_TYPE: case schema2::Type::INTERFACE:
case schema::Type::Body::OBJECT_TYPE: case schema2::Type::OBJECT:
return true; return true;
} }
...@@ -1024,21 +975,19 @@ private: ...@@ -1024,21 +975,19 @@ private:
return true; return true;
} }
void checkDefaultCompatibility(const schema::Value::Reader& value, void checkDefaultCompatibility(const schema2::Value::Reader& value,
const schema::Value::Reader& replacement) { const schema2::Value::Reader& replacement) {
// Note that we test default compatibility only after testing type compatibility, and default // Note that we test default compatibility only after testing type compatibility, and default
// values have already been validated as matching their types, so this should pass. // values have already been validated as matching their types, so this should pass.
KJ_ASSERT(value.getBody().which() == replacement.getBody().which()) { KJ_ASSERT(value.which() == replacement.which()) {
compatibility = INCOMPATIBLE; compatibility = INCOMPATIBLE;
return; return;
} }
switch (value.getBody().which()) { switch (value.which()) {
#define HANDLE_TYPE(discrim, name) \ #define HANDLE_TYPE(discrim, name) \
case schema::Value::Body::discrim##_VALUE: \ case schema2::Value::discrim: \
VALIDATE_SCHEMA(value.getBody().get##name##Value() == \ VALIDATE_SCHEMA(value.get##name() == replacement.get##name(), "default value changed"); \
replacement.getBody().get##name##Value(), \
"default value changed"); \
break; break;
HANDLE_TYPE(VOID, Void); HANDLE_TYPE(VOID, Void);
HANDLE_TYPE(BOOL, Bool); HANDLE_TYPE(BOOL, Bool);
...@@ -1055,12 +1004,12 @@ private: ...@@ -1055,12 +1004,12 @@ private:
HANDLE_TYPE(ENUM, Enum); HANDLE_TYPE(ENUM, Enum);
#undef HANDLE_TYPE #undef HANDLE_TYPE
case schema::Value::Body::TEXT_VALUE: case schema2::Value::TEXT:
case schema::Value::Body::DATA_VALUE: case schema2::Value::DATA:
case schema::Value::Body::LIST_VALUE: case schema2::Value::LIST:
case schema::Value::Body::STRUCT_VALUE: case schema2::Value::STRUCT:
case schema::Value::Body::INTERFACE_VALUE: case schema2::Value::INTERFACE:
case schema::Value::Body::OBJECT_VALUE: case schema2::Value::OBJECT:
// It's not a big deal if default values for pointers change, and it would be difficult for // It's not a big deal if default values for pointers change, and it would be difficult for
// us to compare these defaults here, so just let it slide. // us to compare these defaults here, so just let it slide.
break; break;
...@@ -1070,22 +1019,19 @@ private: ...@@ -1070,22 +1019,19 @@ private:
// ======================================================================================= // =======================================================================================
_::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool isPlaceholder) { _::RawSchema* SchemaLoader::Impl::load(const schema2::Node::Reader& reader, bool isPlaceholder) {
// Make a copy of the node which can be used unchecked. // Make a copy of the node which can be used unchecked.
size_t size = reader.totalSizeInWords() + 1; kj::ArrayPtr<word> validated = makeUncheckedNodeEnforcingSizeRequirements(reader);
kj::ArrayPtr<word> validated = arena.allocateArray<word>(size);
memset(validated.begin(), 0, size * sizeof(word));
copyToUnchecked(reader, validated);
// Validate the copy. // Validate the copy.
Validator validator(*this); Validator validator(*this);
auto validatedReader = readMessageUnchecked<schema::Node>(validated.begin()); auto validatedReader = readMessageUnchecked<schema2::Node>(validated.begin());
if (!validator.validate(validatedReader)) { if (!validator.validate(validatedReader)) {
// Not valid. Construct an empty schema of the same type and return that. // Not valid. Construct an empty schema of the same type and return that.
return loadEmpty(validatedReader.getId(), return loadEmpty(validatedReader.getId(),
validatedReader.getDisplayName(), validatedReader.getDisplayName(),
validatedReader.getBody().which()); validatedReader.which());
} }
// Check if we already have a schema for this ID. // Check if we already have a schema for this ID.
...@@ -1106,7 +1052,7 @@ _::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool ...@@ -1106,7 +1052,7 @@ _::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool
isPlaceholder = false; isPlaceholder = false;
} }
auto existing = readMessageUnchecked<schema::Node>(slot->encodedNode); auto existing = readMessageUnchecked<schema2::Node>(slot->encodedNode);
CompatibilityChecker checker(*this); CompatibilityChecker checker(*this);
// Prefer to replace the existing schema if the existing schema is a placeholder. Otherwise, // Prefer to replace the existing schema if the existing schema is a placeholder. Otherwise,
...@@ -1121,6 +1067,7 @@ _::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool ...@@ -1121,6 +1067,7 @@ _::RawSchema* SchemaLoader::Impl::load(const schema::Node::Reader& reader, bool
slot->encodedSize = validated.size(); slot->encodedSize = validated.size();
slot->dependencies = validator.makeDependencyArray(&slot->dependencyCount); slot->dependencies = validator.makeDependencyArray(&slot->dependencyCount);
slot->membersByName = validator.makeMemberInfoArray(&slot->memberCount); slot->membersByName = validator.makeMemberInfoArray(&slot->memberCount);
slot->membersByDiscriminant = validator.makeMembersByDiscriminantArray();
} }
if (isPlaceholder) { if (isPlaceholder) {
...@@ -1147,12 +1094,12 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) { ...@@ -1147,12 +1094,12 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) {
KJ_REQUIRE(slot->canCastTo == nativeSchema, KJ_REQUIRE(slot->canCastTo == nativeSchema,
"two different compiled-in type have the same type ID", "two different compiled-in type have the same type ID",
nativeSchema->id, nativeSchema->id,
readMessageUnchecked<schema::Node>(nativeSchema->encodedNode).getDisplayName(), readMessageUnchecked<schema2::Node>(nativeSchema->encodedNode).getDisplayName(),
readMessageUnchecked<schema::Node>(slot->canCastTo->encodedNode).getDisplayName()); readMessageUnchecked<schema2::Node>(slot->canCastTo->encodedNode).getDisplayName());
return slot; return slot;
} else { } else {
auto existing = readMessageUnchecked<schema::Node>(slot->encodedNode); auto existing = readMessageUnchecked<schema2::Node>(slot->encodedNode);
auto native = readMessageUnchecked<schema::Node>(nativeSchema->encodedNode); auto native = readMessageUnchecked<schema2::Node>(nativeSchema->encodedNode);
CompatibilityChecker checker(*this); CompatibilityChecker checker(*this);
shouldReplace = checker.shouldReplace(existing, native, true); shouldReplace = checker.shouldReplace(existing, native, true);
} }
...@@ -1162,20 +1109,30 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) { ...@@ -1162,20 +1109,30 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) {
_::RawSchema* result = slot; _::RawSchema* result = slot;
if (shouldReplace) { if (shouldReplace) {
// Set the schema to a copy of the native schema. // Set the schema to a copy of the native schema, but make sure not to null out lazyInitializer
*kj::implicitCast<_::RawSchema*>(result) = *nativeSchema; // yet.
_::RawSchema temp = *nativeSchema;
temp.lazyInitializer = result->lazyInitializer;
*result = temp;
// Indicate that casting is safe. Note that it's important to set this before recursively // Indicate that casting is safe. Note that it's important to set this before recursively
// loading dependencies, so that cycles don't cause infinite loops! // loading dependencies, so that cycles don't cause infinite loops!
result->canCastTo = nativeSchema; result->canCastTo = nativeSchema;
// Except that we need to set the dependency list to point at other loader-owned RawSchemas. // We need to set the dependency list to point at other loader-owned RawSchemas.
kj::ArrayPtr<const _::RawSchema*> dependencies = kj::ArrayPtr<const _::RawSchema*> dependencies =
arena.allocateArray<const _::RawSchema*>(result->dependencyCount); arena.allocateArray<const _::RawSchema*>(result->dependencyCount);
for (uint i = 0; i < nativeSchema->dependencyCount; i++) { for (uint i = 0; i < nativeSchema->dependencyCount; i++) {
dependencies[i] = loadNative(nativeSchema->dependencies[i]); dependencies[i] = loadNative(nativeSchema->dependencies[i]);
} }
result->dependencies = dependencies.begin(); result->dependencies = dependencies.begin();
// If there is a struct size requirement, we need to make sure that it is satisfied.
auto reqIter = structSizeRequirements.find(nativeSchema->id);
if (reqIter != structSizeRequirements.end()) {
applyStructSizeRequirement(result, reqIter->second.dataWordCount,
reqIter->second.pointerCount);
}
} else { } else {
// The existing schema is newer. // The existing schema is newer.
...@@ -1198,21 +1155,21 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) { ...@@ -1198,21 +1155,21 @@ _::RawSchema* SchemaLoader::Impl::loadNative(const _::RawSchema* nativeSchema) {
} }
_::RawSchema* SchemaLoader::Impl::loadEmpty( _::RawSchema* SchemaLoader::Impl::loadEmpty(
uint64_t id, kj::StringPtr name, schema::Node::Body::Which kind) { uint64_t id, kj::StringPtr name, schema2::Node::Which kind) {
word scratch[32]; word scratch[32];
memset(scratch, 0, sizeof(scratch)); memset(scratch, 0, sizeof(scratch));
MallocMessageBuilder builder(kj::arrayPtr(scratch, sizeof(scratch))); MallocMessageBuilder builder(scratch);
auto node = builder.initRoot<schema::Node>(); auto node = builder.initRoot<schema2::Node>();
node.setId(id); node.setId(id);
node.setDisplayName(name); node.setDisplayName(name);
switch (kind) { switch (kind) {
case schema::Node::Body::STRUCT_NODE: node.getBody().initStructNode(); break; case schema2::Node::STRUCT: node.initStruct(); break;
case schema::Node::Body::ENUM_NODE: node.getBody().initEnumNode(); break; case schema2::Node::ENUM: node.initEnum(0); break;
case schema::Node::Body::INTERFACE_NODE: node.getBody().initInterfaceNode(); break; case schema2::Node::INTERFACE: node.initInterface(0); break;
case schema::Node::Body::FILE_NODE: case schema2::Node::FILE:
case schema::Node::Body::CONST_NODE: case schema2::Node::CONST:
case schema::Node::Body::ANNOTATION_NODE: case schema2::Node::ANNOTATION:
KJ_FAIL_REQUIRE("Not a type."); KJ_FAIL_REQUIRE("Not a type.");
break; break;
} }
...@@ -1243,6 +1200,73 @@ kj::Array<Schema> SchemaLoader::Impl::getAllLoaded() const { ...@@ -1243,6 +1200,73 @@ kj::Array<Schema> SchemaLoader::Impl::getAllLoaded() const {
return result; return result;
} }
void SchemaLoader::Impl::requireStructSize(uint64_t id, uint dataWordCount, uint pointerCount) {
auto& slot = structSizeRequirements[id];
slot.dataWordCount = kj::max(slot.dataWordCount, dataWordCount);
slot.pointerCount = kj::max(slot.pointerCount, pointerCount);
auto iter = schemas.find(id);
if (iter != schemas.end()) {
applyStructSizeRequirement(iter->second, dataWordCount, pointerCount);
}
}
kj::ArrayPtr<word> SchemaLoader::Impl::makeUncheckedNode(schema2::Node::Reader node) {
size_t size = node.totalSizeInWords() + 1;
kj::ArrayPtr<word> result = arena.allocateArray<word>(size);
memset(result.begin(), 0, size * sizeof(word));
copyToUnchecked(node, result);
return result;
}
kj::ArrayPtr<word> SchemaLoader::Impl::makeUncheckedNodeEnforcingSizeRequirements(
schema2::Node::Reader node) {
if (node.which() == schema2::Node::STRUCT) {
auto iter = structSizeRequirements.find(node.getId());
if (iter != structSizeRequirements.end()) {
auto requirement = iter->second;
auto structNode = node.getStruct();
if (structNode.getDataSectionWordSize() < requirement.dataWordCount ||
structNode.getPointerSectionSize() < requirement.pointerCount) {
return rewriteStructNodeWithSizes(node, requirement.dataWordCount,
requirement.pointerCount);
}
}
}
return makeUncheckedNode(node);
}
kj::ArrayPtr<word> SchemaLoader::Impl::rewriteStructNodeWithSizes(
schema2::Node::Reader node, uint dataWordCount, uint pointerCount) {
MallocMessageBuilder builder;
builder.setRoot(node);
auto root = builder.getRoot<schema2::Node>();
auto newStruct = root.getStruct();
newStruct.setDataSectionWordSize(kj::max(newStruct.getDataSectionWordSize(), dataWordCount));
newStruct.setPointerSectionSize(kj::max(newStruct.getPointerSectionSize(), pointerCount));
return makeUncheckedNode(root);
}
void SchemaLoader::Impl::applyStructSizeRequirement(
_::RawSchema* raw, uint dataWordCount, uint pointerCount) {
auto node = readMessageUnchecked<schema2::Node>(raw->encodedNode);
auto structNode = node.getStruct();
if (structNode.getDataSectionWordSize() < dataWordCount ||
structNode.getPointerSectionSize() < pointerCount) {
// Sizes need to be increased. Must rewrite.
kj::ArrayPtr<word> words = rewriteStructNodeWithSizes(node, dataWordCount, pointerCount);
// We don't need to re-validate the node because we know this change could not possibly have
// invalidated it. Just remake the unchecked message.
raw->encodedNode = words.begin();
raw->encodedSize = words.size();
}
}
void SchemaLoader::InitializerImpl::init(const _::RawSchema* schema) const { void SchemaLoader::InitializerImpl::init(const _::RawSchema* schema) const {
KJ_IF_MAYBE(c, callback) { KJ_IF_MAYBE(c, callback) {
c->load(loader, schema->id); c->load(loader, schema->id);
...@@ -1296,11 +1320,11 @@ kj::Maybe<Schema> SchemaLoader::tryGet(uint64_t id) const { ...@@ -1296,11 +1320,11 @@ kj::Maybe<Schema> SchemaLoader::tryGet(uint64_t id) const {
} }
} }
Schema SchemaLoader::load(const schema::Node::Reader& reader) { Schema SchemaLoader::load(const schema2::Node::Reader& reader) {
return Schema(impl.lockExclusive()->get()->load(reader, false)); return Schema(impl.lockExclusive()->get()->load(reader, false));
} }
Schema SchemaLoader::loadOnce(const schema::Node::Reader& reader) const { Schema SchemaLoader::loadOnce(const schema2::Node::Reader& reader) const {
auto locked = impl.lockExclusive(); auto locked = impl.lockExclusive();
auto getResult = locked->get()->tryGet(reader.getId()); auto getResult = locked->get()->tryGet(reader.getId());
if (getResult.schema == nullptr || getResult.schema->lazyInitializer != nullptr) { if (getResult.schema == nullptr || getResult.schema->lazyInitializer != nullptr) {
......
...@@ -31,6 +31,16 @@ ...@@ -31,6 +31,16 @@
namespace capnp { namespace capnp {
class SchemaLoader { class SchemaLoader {
// Class which can be used to construct Schema objects from schema::Nodes as defined in
// schema.capnp.
//
// It is a bad idea to use this class on untrusted input with exceptions disabled -- you may
// be exposing yourself to denial-of-service attacks, as attackers can easily construct schemas
// that are subtly inconsistent in a way that causes exceptions to be thrown either by
// SchemaLoader or by the dynamic API when the schemas are subsequently used. If you enable and
// properly catch exceptions, you should be OK -- assuming no bugs in the Cap'n Proto
// implementation, of course.
public: public:
class LazyLoadCallback { class LazyLoadCallback {
public: public:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment