Implement equals / operator ==/!= for AnyPointer

parent 09759e9f
...@@ -261,6 +261,56 @@ TEST(Any, AnyStructListCapInSchema) { ...@@ -261,6 +261,56 @@ TEST(Any, AnyStructListCapInSchema) {
} }
} }
TEST(Any, Equals) {
MallocMessageBuilder builderA;
auto rootA = builderA.getRoot<test::TestAllTypes>();
initTestMessage(rootA);
MallocMessageBuilder builderB;
auto rootB = builderB.getRoot<test::TestAllTypes>();
initTestMessage(rootB);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.setBoolField(false);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.setBoolField(false);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getStructField().setTextField("buzz");
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getStructField().setTextField("buzz");
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.initVoidList(3);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.initVoidList(3);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getBoolList().set(2, true);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getBoolList().set(2, true);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
}
} // namespace } // namespace
} // namespace _ (private) } // namespace _ (private)
} // namespace capnp } // namespace capnp
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "any.h" #include "any.h"
#include <kj/debug.h>
#if !CAPNP_LITE #if !CAPNP_LITE
#include "capability.h" #include "capability.h"
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
...@@ -77,4 +79,127 @@ kj::Own<ClientHook> AnyPointer::Pipeline::asCap() { ...@@ -77,4 +79,127 @@ kj::Own<ClientHook> AnyPointer::Pipeline::asCap() {
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
StructEqualityResult equal(AnyStruct::Reader left, AnyStruct::Reader right) {
auto dataL = left.getDataSection();
size_t dataSizeL = dataL.size();
while(dataSizeL > 0 && dataL[dataSizeL - 1] == 0) {
-- dataSizeL;
}
auto dataR = right.getDataSection();
size_t dataSizeR = dataR.size();
while(dataSizeR > 0 && dataR[dataSizeR - 1] == 0) {
-- dataSizeR;
}
if(dataSizeL != dataSizeR) {
return StructEqualityResult::NOT_EQUAL;
}
if(0 != memcmp(dataL.begin(), dataR.begin(), dataSizeL * sizeof(word))) {
return StructEqualityResult::NOT_EQUAL;
}
auto ptrsL = left.getPointerSection();
auto ptrsR = right.getPointerSection();
size_t i = 0;
for(; i < kj::min(ptrsL.size(), ptrsR.size()); i++) {
auto l = ptrsL[i];
auto r = ptrsR[i];
switch(equal(l, r)) {
case StructEqualityResult::EQUAL:
break;
case StructEqualityResult::NOT_EQUAL:
return StructEqualityResult::NOT_EQUAL;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
}
return StructEqualityResult::EQUAL;
}
StructEqualityResult equal(AnyList::Reader left, AnyList::Reader right) {
if(left.size() != right.size()) {
return StructEqualityResult::NOT_EQUAL;
}
switch(left.getElementSize()) {
case ElementSize::VOID:
case ElementSize::BIT:
case ElementSize::BYTE:
case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES:
if(left.getElementSize() == right.getElementSize()) {
if(memcmp(left.getData().begin(), right.getData().begin(), left.getData().size()) == 0) {
return StructEqualityResult::EQUAL;
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else {
return StructEqualityResult::NOT_EQUAL;
}
case ElementSize::POINTER:
case ElementSize::INLINE_COMPOSITE: {
auto llist = left.as<List<AnyStruct>>();
auto rlist = right.as<List<AnyStruct>>();
for(size_t i = 0; i < left.size(); i++) {
switch(equal(llist[i], rlist[i])) {
case StructEqualityResult::EQUAL:
break;
case StructEqualityResult::NOT_EQUAL:
return StructEqualityResult::NOT_EQUAL;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
}
return StructEqualityResult::EQUAL;
}
}
}
StructEqualityResult equal(AnyPointer::Reader left, AnyPointer::Reader right) {
if(right.isCapability()) {
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
if(left.isNull()) {
if(right.isNull()) {
return StructEqualityResult::EQUAL;
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isStruct()) {
if(right.isStruct()) {
return equal(left.getAs<AnyStruct>(), right.getAs<AnyStruct>());
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isList()) {
if(right.isList()) {
return equal(left.getAs<AnyList>(), right.getAs<AnyList>());
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isCapability()) {
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
} else {
// There aren't currently any other types of pointers
KJ_FAIL_REQUIRE();
}
}
bool operator ==(AnyPointer::Reader left, AnyPointer::Reader right) {
switch(equal(left, right)) {
case StructEqualityResult::EQUAL:
return true;
case StructEqualityResult::NOT_EQUAL:
return false;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
KJ_FAIL_REQUIRE();
}
}
} // namespace capnp } // namespace capnp
...@@ -90,13 +90,12 @@ struct AnyPointer { ...@@ -90,13 +90,12 @@ struct AnyPointer {
inline MessageSize targetSize() const; inline MessageSize targetSize() const;
// Get the total size of the target object and all its children. // Get the total size of the target object and all its children.
inline bool isNull() const; inline PointerType getPointerType() const;
inline bool isStruct() {
return reader.isStruct(); inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
} inline bool isStruct() const { return getPointerType() == PointerType::STRUCT; }
inline bool isList() { inline bool isList() const { return getPointerType() == PointerType::LIST; }
return reader.isList(); inline bool isCapability() const { return getPointerType() == PointerType::CAPABILITY; }
}
template <typename T> template <typename T>
inline ReaderFor<T> getAs() const; inline ReaderFor<T> getAs() const;
...@@ -138,13 +137,12 @@ struct AnyPointer { ...@@ -138,13 +137,12 @@ struct AnyPointer {
inline MessageSize targetSize() const; inline MessageSize targetSize() const;
// Get the total size of the target object and all its children. // Get the total size of the target object and all its children.
inline bool isNull(); inline PointerType getPointerType();
inline bool isStruct() {
return builder.isStruct(); inline bool isNull() { return getPointerType() == PointerType::NULL_; }
} inline bool isStruct() { return getPointerType() == PointerType::STRUCT; }
inline bool isList() { inline bool isList() { return getPointerType() == PointerType::LIST; }
return builder.isList(); inline bool isCapability() { return getPointerType() == PointerType::CAPABILITY; }
}
inline void clear(); inline void clear();
// Set to null. // Set to null.
...@@ -573,6 +571,8 @@ public: ...@@ -573,6 +571,8 @@ public:
inline ElementSize getElementSize() { return _reader.getElementSize(); } inline ElementSize getElementSize() { return _reader.getElementSize(); }
inline uint size() { return _reader.size() / ELEMENTS; } inline uint size() { return _reader.size() / ELEMENTS; }
inline Data::Reader getData() { return _reader.asDataOfAnySize(); }
template <typename T> ReaderFor<T> as() { template <typename T> ReaderFor<T> as() {
// T must be List<U>. // T must be List<U>.
return ReaderFor<T>(_reader); return ReaderFor<T>(_reader);
...@@ -665,8 +665,8 @@ inline MessageSize AnyPointer::Reader::targetSize() const { ...@@ -665,8 +665,8 @@ inline MessageSize AnyPointer::Reader::targetSize() const {
return reader.targetSize().asPublic(); return reader.targetSize().asPublic();
} }
inline bool AnyPointer::Reader::isNull() const { inline PointerType AnyPointer::Reader::getPointerType() const {
return reader.isNull(); return reader.getPointerType();
} }
template <typename T> template <typename T>
...@@ -678,8 +678,8 @@ inline MessageSize AnyPointer::Builder::targetSize() const { ...@@ -678,8 +678,8 @@ inline MessageSize AnyPointer::Builder::targetSize() const {
return asReader().targetSize(); return asReader().targetSize();
} }
inline bool AnyPointer::Builder::isNull() { inline PointerType AnyPointer::Builder::getPointerType() {
return builder.isNull(); return builder.getPointerType();
} }
inline void AnyPointer::Builder::clear() { inline void AnyPointer::Builder::clear() {
...@@ -806,6 +806,31 @@ inline Orphan<AnyPointer> Orphan<AnyPointer>::releaseAs() { ...@@ -806,6 +806,31 @@ inline Orphan<AnyPointer> Orphan<AnyPointer>::releaseAs() {
return kj::mv(*this); return kj::mv(*this);
} }
enum class StructEqualityResult {
NOT_EQUAL,
EQUAL,
UNKNOWN_CONTAINS_CAPS
};
inline kj::StringPtr KJ_STRINGIFY(StructEqualityResult res) {
switch(res) {
case StructEqualityResult::NOT_EQUAL:
return "NOT_EQUAL";
case StructEqualityResult::EQUAL:
return "EQUAL";
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return "UNKNOWN_CONTAINS_CAPS";
}
}
StructEqualityResult equal(AnyStruct::Reader left, AnyStruct::Reader right);
StructEqualityResult equal(List<AnyStruct>::Reader left, List<AnyStruct>::Reader right);
StructEqualityResult equal(AnyPointer::Reader left, AnyPointer::Reader right);
bool operator ==(AnyPointer::Reader left, AnyPointer::Reader right);
inline bool operator !=(AnyPointer::Reader left, AnyPointer::Reader right) {
return !(left == right);
}
namespace _ { // private namespace _ { // private
// Specialize PointerHelpers for AnyPointer. // Specialize PointerHelpers for AnyPointer.
......
...@@ -99,6 +99,20 @@ enum class ElementSize: uint8_t { ...@@ -99,6 +99,20 @@ enum class ElementSize: uint8_t {
INLINE_COMPOSITE = 7 INLINE_COMPOSITE = 7
}; };
enum class PointerType {
// Various wire types a pointer field can take
NULL_,
// Should be NULL, but that's #defined in stddef.h
STRUCT,
LIST,
CAPABILITY,
OTHER
// currently unused
};
namespace schemas { namespace schemas {
template <typename T> template <typename T>
......
...@@ -2267,20 +2267,24 @@ void PointerBuilder::clear() { ...@@ -2267,20 +2267,24 @@ void PointerBuilder::clear() {
memset(pointer, 0, sizeof(WirePointer)); memset(pointer, 0, sizeof(WirePointer));
} }
bool PointerBuilder::isNull() { PointerType PointerBuilder::getPointerType() {
return pointer->isNull(); if(pointer->isNull()) {
} return PointerType::NULL_;
} else {
bool PointerBuilder::isStruct() { WirePointer* ptr = pointer;
WirePointer* ptr = pointer; WireHelpers::followFars(ptr, ptr->target(), segment);
WireHelpers::followFars(ptr, ptr->target(), segment); switch(ptr->kind()) {
return ptr->kind() == WirePointer::Kind::STRUCT; case WirePointer::Kind::FAR:
} KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
bool PointerBuilder::isList() { return PointerType::STRUCT;
WirePointer* ptr = pointer; case WirePointer::Kind::LIST:
WireHelpers::followFars(ptr, ptr->target(), segment); return PointerType::LIST;
return ptr->kind() == WirePointer::Kind::LIST; case WirePointer::Kind::OTHER:
// TODO: make sure we're only looking at capability pointers
return PointerType::CAPABILITY;
}
}
} }
void PointerBuilder::transferFrom(PointerBuilder other) { void PointerBuilder::transferFrom(PointerBuilder other) {
...@@ -2368,24 +2372,26 @@ MessageSizeCounts PointerReader::targetSize() const { ...@@ -2368,24 +2372,26 @@ MessageSizeCounts PointerReader::targetSize() const {
: WireHelpers::totalSize(segment, pointer, nestingLimit); : WireHelpers::totalSize(segment, pointer, nestingLimit);
} }
bool PointerReader::isNull() const { PointerType PointerReader::getPointerType() const {
return pointer == nullptr || pointer->isNull(); if(pointer->isNull()) {
} return PointerType::NULL_;
} else {
bool PointerReader::isStruct() const { word* refTarget = nullptr;
word* refTarget = nullptr; const WirePointer* ptr = pointer;
const WirePointer* ptr = pointer; SegmentReader* sgmt = segment;
SegmentReader* sgmt = segment; WireHelpers::followFars(ptr, refTarget, sgmt);
WireHelpers::followFars(ptr, refTarget, sgmt); switch(ptr->kind()) {
return ptr->kind() == WirePointer::Kind::STRUCT; case WirePointer::Kind::FAR:
} KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
bool PointerReader::isList() const { return PointerType::STRUCT;
word* refTarget = nullptr; case WirePointer::Kind::LIST:
const WirePointer* ptr = pointer; return PointerType::LIST;
SegmentReader* sgmt = segment; case WirePointer::Kind::OTHER:
WireHelpers::followFars(ptr, refTarget, sgmt); // TODO: make sure we're only looking at capability pointers
return ptr->kind() == WirePointer::Kind::LIST; return PointerType::CAPABILITY;
}
}
} }
kj::Maybe<Arena&> PointerReader::getArena() const { kj::Maybe<Arena&> PointerReader::getArena() const {
...@@ -2599,6 +2605,16 @@ Data::Reader ListReader::asData() { ...@@ -2599,6 +2605,16 @@ Data::Reader ListReader::asData() {
return Data::Reader(reinterpret_cast<const byte*>(ptr), elementCount / ELEMENTS); return Data::Reader(reinterpret_cast<const byte*>(ptr), elementCount / ELEMENTS);
} }
Data::Reader ListReader::asDataOfAnySize() {
KJ_REQUIRE(structPointerCount == 0 * POINTERS,
"Expected data only, got pointers.") {
return Data::Reader();
}
return Data::Reader(reinterpret_cast<const byte*>(ptr), structDataSize * elementCount / ELEMENTS);
}
StructReader ListReader::getStructElement(ElementCount index) const { StructReader ListReader::getStructElement(ElementCount index) const {
KJ_REQUIRE(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnp::ReaderOptions.") { "Message is too deeply-nested or contains cycles. See capnp::ReaderOptions.") {
......
...@@ -274,9 +274,8 @@ public: ...@@ -274,9 +274,8 @@ public:
// Get a PointerBuilder representing a message root located in the given segment at the given // Get a PointerBuilder representing a message root located in the given segment at the given
// location. // location.
bool isNull(); inline bool isNull() { return getPointerType() == PointerType::NULL_; }
bool isStruct(); PointerType getPointerType();
bool isList();
StructBuilder getStruct(StructSize size, const word* defaultValue); StructBuilder getStruct(StructSize size, const word* defaultValue);
ListBuilder getList(ElementSize elementSize, const word* defaultValue); ListBuilder getList(ElementSize elementSize, const word* defaultValue);
...@@ -355,9 +354,8 @@ public: ...@@ -355,9 +354,8 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an // use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns. // exception if it overruns.
bool isNull() const; inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
bool isStruct() const; PointerType getPointerType() const;
bool isList() const;
StructReader getStruct(const word* defaultValue) const; StructReader getStruct(const word* defaultValue) const;
ListReader getList(ElementSize expectedElementSize, const word* defaultValue) const; ListReader getList(ElementSize expectedElementSize, const word* defaultValue) const;
...@@ -656,6 +654,8 @@ public: ...@@ -656,6 +654,8 @@ public:
Data::Reader asData(); Data::Reader asData();
// Reinterpret the list as a blob. Throws an exception if the elements are not byte-sized. // Reinterpret the list as a blob. Throws an exception if the elements are not byte-sized.
Data::Reader asDataOfAnySize();
template <typename T> template <typename T>
KJ_ALWAYS_INLINE(T getDataElement(ElementCount index) const); KJ_ALWAYS_INLINE(T getDataElement(ElementCount index) const);
// Get the element of the given type at the given index. // Get the element of the given type at the given index.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment