Implement equals / operator ==/!= for AnyPointer

parent 09759e9f
......@@ -261,6 +261,56 @@ TEST(Any, AnyStructListCapInSchema) {
}
}
TEST(Any, Equals) {
MallocMessageBuilder builderA;
auto rootA = builderA.getRoot<test::TestAllTypes>();
initTestMessage(rootA);
MallocMessageBuilder builderB;
auto rootB = builderB.getRoot<test::TestAllTypes>();
initTestMessage(rootB);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.setBoolField(false);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.setBoolField(false);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getStructField().setTextField("buzz");
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getStructField().setTextField("buzz");
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.initVoidList(3);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.initVoidList(3);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getBoolList().set(2, true);
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getBoolList().set(2, true);
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootB.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(StructEqualityResult::NOT_EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
rootA.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(StructEqualityResult::EQUAL, equal(builderA.getRoot<AnyPointer>(), builderB.getRoot<AnyPointer>()));
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -21,6 +21,8 @@
#include "any.h"
#include <kj/debug.h>
#if !CAPNP_LITE
#include "capability.h"
#endif // !CAPNP_LITE
......@@ -77,4 +79,127 @@ kj::Own<ClientHook> AnyPointer::Pipeline::asCap() {
#endif // !CAPNP_LITE
StructEqualityResult equal(AnyStruct::Reader left, AnyStruct::Reader right) {
auto dataL = left.getDataSection();
size_t dataSizeL = dataL.size();
while(dataSizeL > 0 && dataL[dataSizeL - 1] == 0) {
-- dataSizeL;
}
auto dataR = right.getDataSection();
size_t dataSizeR = dataR.size();
while(dataSizeR > 0 && dataR[dataSizeR - 1] == 0) {
-- dataSizeR;
}
if(dataSizeL != dataSizeR) {
return StructEqualityResult::NOT_EQUAL;
}
if(0 != memcmp(dataL.begin(), dataR.begin(), dataSizeL * sizeof(word))) {
return StructEqualityResult::NOT_EQUAL;
}
auto ptrsL = left.getPointerSection();
auto ptrsR = right.getPointerSection();
size_t i = 0;
for(; i < kj::min(ptrsL.size(), ptrsR.size()); i++) {
auto l = ptrsL[i];
auto r = ptrsR[i];
switch(equal(l, r)) {
case StructEqualityResult::EQUAL:
break;
case StructEqualityResult::NOT_EQUAL:
return StructEqualityResult::NOT_EQUAL;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
}
return StructEqualityResult::EQUAL;
}
StructEqualityResult equal(AnyList::Reader left, AnyList::Reader right) {
if(left.size() != right.size()) {
return StructEqualityResult::NOT_EQUAL;
}
switch(left.getElementSize()) {
case ElementSize::VOID:
case ElementSize::BIT:
case ElementSize::BYTE:
case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES:
if(left.getElementSize() == right.getElementSize()) {
if(memcmp(left.getData().begin(), right.getData().begin(), left.getData().size()) == 0) {
return StructEqualityResult::EQUAL;
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else {
return StructEqualityResult::NOT_EQUAL;
}
case ElementSize::POINTER:
case ElementSize::INLINE_COMPOSITE: {
auto llist = left.as<List<AnyStruct>>();
auto rlist = right.as<List<AnyStruct>>();
for(size_t i = 0; i < left.size(); i++) {
switch(equal(llist[i], rlist[i])) {
case StructEqualityResult::EQUAL:
break;
case StructEqualityResult::NOT_EQUAL:
return StructEqualityResult::NOT_EQUAL;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
}
return StructEqualityResult::EQUAL;
}
}
}
StructEqualityResult equal(AnyPointer::Reader left, AnyPointer::Reader right) {
if(right.isCapability()) {
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
}
if(left.isNull()) {
if(right.isNull()) {
return StructEqualityResult::EQUAL;
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isStruct()) {
if(right.isStruct()) {
return equal(left.getAs<AnyStruct>(), right.getAs<AnyStruct>());
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isList()) {
if(right.isList()) {
return equal(left.getAs<AnyList>(), right.getAs<AnyList>());
} else {
return StructEqualityResult::NOT_EQUAL;
}
} else if(left.isCapability()) {
return StructEqualityResult::UNKNOWN_CONTAINS_CAPS;
} else {
// There aren't currently any other types of pointers
KJ_FAIL_REQUIRE();
}
}
bool operator ==(AnyPointer::Reader left, AnyPointer::Reader right) {
switch(equal(left, right)) {
case StructEqualityResult::EQUAL:
return true;
case StructEqualityResult::NOT_EQUAL:
return false;
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
KJ_FAIL_REQUIRE();
}
}
} // namespace capnp
......@@ -90,13 +90,12 @@ struct AnyPointer {
inline MessageSize targetSize() const;
// Get the total size of the target object and all its children.
inline bool isNull() const;
inline bool isStruct() {
return reader.isStruct();
}
inline bool isList() {
return reader.isList();
}
inline PointerType getPointerType() const;
inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
inline bool isStruct() const { return getPointerType() == PointerType::STRUCT; }
inline bool isList() const { return getPointerType() == PointerType::LIST; }
inline bool isCapability() const { return getPointerType() == PointerType::CAPABILITY; }
template <typename T>
inline ReaderFor<T> getAs() const;
......@@ -138,13 +137,12 @@ struct AnyPointer {
inline MessageSize targetSize() const;
// Get the total size of the target object and all its children.
inline bool isNull();
inline bool isStruct() {
return builder.isStruct();
}
inline bool isList() {
return builder.isList();
}
inline PointerType getPointerType();
inline bool isNull() { return getPointerType() == PointerType::NULL_; }
inline bool isStruct() { return getPointerType() == PointerType::STRUCT; }
inline bool isList() { return getPointerType() == PointerType::LIST; }
inline bool isCapability() { return getPointerType() == PointerType::CAPABILITY; }
inline void clear();
// Set to null.
......@@ -573,6 +571,8 @@ public:
inline ElementSize getElementSize() { return _reader.getElementSize(); }
inline uint size() { return _reader.size() / ELEMENTS; }
inline Data::Reader getData() { return _reader.asDataOfAnySize(); }
template <typename T> ReaderFor<T> as() {
// T must be List<U>.
return ReaderFor<T>(_reader);
......@@ -665,8 +665,8 @@ inline MessageSize AnyPointer::Reader::targetSize() const {
return reader.targetSize().asPublic();
}
inline bool AnyPointer::Reader::isNull() const {
return reader.isNull();
inline PointerType AnyPointer::Reader::getPointerType() const {
return reader.getPointerType();
}
template <typename T>
......@@ -678,8 +678,8 @@ inline MessageSize AnyPointer::Builder::targetSize() const {
return asReader().targetSize();
}
inline bool AnyPointer::Builder::isNull() {
return builder.isNull();
inline PointerType AnyPointer::Builder::getPointerType() {
return builder.getPointerType();
}
inline void AnyPointer::Builder::clear() {
......@@ -806,6 +806,31 @@ inline Orphan<AnyPointer> Orphan<AnyPointer>::releaseAs() {
return kj::mv(*this);
}
enum class StructEqualityResult {
NOT_EQUAL,
EQUAL,
UNKNOWN_CONTAINS_CAPS
};
inline kj::StringPtr KJ_STRINGIFY(StructEqualityResult res) {
switch(res) {
case StructEqualityResult::NOT_EQUAL:
return "NOT_EQUAL";
case StructEqualityResult::EQUAL:
return "EQUAL";
case StructEqualityResult::UNKNOWN_CONTAINS_CAPS:
return "UNKNOWN_CONTAINS_CAPS";
}
}
StructEqualityResult equal(AnyStruct::Reader left, AnyStruct::Reader right);
StructEqualityResult equal(List<AnyStruct>::Reader left, List<AnyStruct>::Reader right);
StructEqualityResult equal(AnyPointer::Reader left, AnyPointer::Reader right);
bool operator ==(AnyPointer::Reader left, AnyPointer::Reader right);
inline bool operator !=(AnyPointer::Reader left, AnyPointer::Reader right) {
return !(left == right);
}
namespace _ { // private
// Specialize PointerHelpers for AnyPointer.
......
......@@ -99,6 +99,20 @@ enum class ElementSize: uint8_t {
INLINE_COMPOSITE = 7
};
enum class PointerType {
// Various wire types a pointer field can take
NULL_,
// Should be NULL, but that's #defined in stddef.h
STRUCT,
LIST,
CAPABILITY,
OTHER
// currently unused
};
namespace schemas {
template <typename T>
......
......@@ -2267,20 +2267,24 @@ void PointerBuilder::clear() {
memset(pointer, 0, sizeof(WirePointer));
}
bool PointerBuilder::isNull() {
return pointer->isNull();
}
bool PointerBuilder::isStruct() {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
return ptr->kind() == WirePointer::Kind::STRUCT;
}
bool PointerBuilder::isList() {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
return ptr->kind() == WirePointer::Kind::LIST;
PointerType PointerBuilder::getPointerType() {
if(pointer->isNull()) {
return PointerType::NULL_;
} else {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
switch(ptr->kind()) {
case WirePointer::Kind::FAR:
KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
return PointerType::STRUCT;
case WirePointer::Kind::LIST:
return PointerType::LIST;
case WirePointer::Kind::OTHER:
// TODO: make sure we're only looking at capability pointers
return PointerType::CAPABILITY;
}
}
}
void PointerBuilder::transferFrom(PointerBuilder other) {
......@@ -2368,24 +2372,26 @@ MessageSizeCounts PointerReader::targetSize() const {
: WireHelpers::totalSize(segment, pointer, nestingLimit);
}
bool PointerReader::isNull() const {
return pointer == nullptr || pointer->isNull();
}
bool PointerReader::isStruct() const {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
return ptr->kind() == WirePointer::Kind::STRUCT;
}
bool PointerReader::isList() const {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
return ptr->kind() == WirePointer::Kind::LIST;
PointerType PointerReader::getPointerType() const {
if(pointer->isNull()) {
return PointerType::NULL_;
} else {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
switch(ptr->kind()) {
case WirePointer::Kind::FAR:
KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
return PointerType::STRUCT;
case WirePointer::Kind::LIST:
return PointerType::LIST;
case WirePointer::Kind::OTHER:
// TODO: make sure we're only looking at capability pointers
return PointerType::CAPABILITY;
}
}
}
kj::Maybe<Arena&> PointerReader::getArena() const {
......@@ -2599,6 +2605,16 @@ Data::Reader ListReader::asData() {
return Data::Reader(reinterpret_cast<const byte*>(ptr), elementCount / ELEMENTS);
}
Data::Reader ListReader::asDataOfAnySize() {
KJ_REQUIRE(structPointerCount == 0 * POINTERS,
"Expected data only, got pointers.") {
return Data::Reader();
}
return Data::Reader(reinterpret_cast<const byte*>(ptr), structDataSize * elementCount / ELEMENTS);
}
StructReader ListReader::getStructElement(ElementCount index) const {
KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnp::ReaderOptions.") {
......
......@@ -274,9 +274,8 @@ public:
// Get a PointerBuilder representing a message root located in the given segment at the given
// location.
bool isNull();
bool isStruct();
bool isList();
inline bool isNull() { return getPointerType() == PointerType::NULL_; }
PointerType getPointerType();
StructBuilder getStruct(StructSize size, const word* defaultValue);
ListBuilder getList(ElementSize elementSize, const word* defaultValue);
......@@ -355,9 +354,8 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
bool isNull() const;
bool isStruct() const;
bool isList() const;
inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
PointerType getPointerType() const;
StructReader getStruct(const word* defaultValue) const;
ListReader getList(ElementSize expectedElementSize, const word* defaultValue) const;
......@@ -656,6 +654,8 @@ public:
Data::Reader asData();
// Reinterpret the list as a blob. Throws an exception if the elements are not byte-sized.
Data::Reader asDataOfAnySize();
template <typename T>
KJ_ALWAYS_INLINE(T getDataElement(ElementCount index) const);
// Get the element of the given type at the given index.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment