Commit 41773d2f authored by Kenton Varda's avatar Kenton Varda

Merge pull request #192 from joshuawarner32/equal

Add equal / operator == method
parents 9f42d66a 69dea08e
......@@ -261,6 +261,58 @@ TEST(Any, AnyStructListCapInSchema) {
}
}
TEST(Any, Equals) {
MallocMessageBuilder builderA;
auto rootA = builderA.getRoot<test::TestAllTypes>();
auto anyA = builderA.getRoot<AnyPointer>();
initTestMessage(rootA);
MallocMessageBuilder builderB;
auto rootB = builderB.getRoot<test::TestAllTypes>();
auto anyB = builderB.getRoot<AnyPointer>();
initTestMessage(rootB);
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootA.setBoolField(false);
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootB.setBoolField(false);
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootB.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootA.setEnumField(test::TestEnum::GARPLY);
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootA.getStructField().setTextField("buzz");
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootB.getStructField().setTextField("buzz");
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootA.initVoidList(3);
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootB.initVoidList(3);
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootA.getBoolList().set(2, true);
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootB.getBoolList().set(2, true);
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
rootB.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(Equality::NOT_EQUAL, anyA.equals(anyB));
rootA.getStructList()[1].setTextField("my NEW structlist 2");
EXPECT_EQ(Equality::EQUAL, anyA.equals(anyB));
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -21,6 +21,8 @@
#include "any.h"
#include <kj/debug.h>
#if !CAPNP_LITE
#include "capability.h"
#endif // !CAPNP_LITE
......@@ -77,4 +79,163 @@ kj::Own<ClientHook> AnyPointer::Pipeline::asCap() {
#endif // !CAPNP_LITE
Equality AnyStruct::Reader::equals(AnyStruct::Reader right) {
auto dataL = getDataSection();
size_t dataSizeL = dataL.size();
while(dataSizeL > 0 && dataL[dataSizeL - 1] == 0) {
-- dataSizeL;
}
auto dataR = right.getDataSection();
size_t dataSizeR = dataR.size();
while(dataSizeR > 0 && dataR[dataSizeR - 1] == 0) {
-- dataSizeR;
}
if(dataSizeL != dataSizeR) {
return Equality::NOT_EQUAL;
}
if(0 != memcmp(dataL.begin(), dataR.begin(), dataSizeL)) {
return Equality::NOT_EQUAL;
}
auto ptrsL = getPointerSection();
auto ptrsR = right.getPointerSection();
size_t i = 0;
auto eqResult = Equality::EQUAL;
for(; i < kj::min(ptrsL.size(), ptrsR.size()); i++) {
auto l = ptrsL[i];
auto r = ptrsR[i];
switch(l.equals(r)) {
case Equality::EQUAL:
break;
case Equality::NOT_EQUAL:
return Equality::NOT_EQUAL;
case Equality::UNKNOWN_CONTAINS_CAPS:
eqResult = Equality::UNKNOWN_CONTAINS_CAPS;
break;
default:
KJ_UNREACHABLE;
}
}
return eqResult;
}
kj::StringPtr KJ_STRINGIFY(Equality res) {
switch(res) {
case Equality::NOT_EQUAL:
return "NOT_EQUAL";
case Equality::EQUAL:
return "EQUAL";
case Equality::UNKNOWN_CONTAINS_CAPS:
return "UNKNOWN_CONTAINS_CAPS";
}
KJ_UNREACHABLE;
}
Equality AnyList::Reader::equals(AnyList::Reader right) {
if(size() != right.size()) {
return Equality::NOT_EQUAL;
}
auto eqResult = Equality::EQUAL;
switch(getElementSize()) {
case ElementSize::VOID:
case ElementSize::BIT:
case ElementSize::BYTE:
case ElementSize::TWO_BYTES:
case ElementSize::FOUR_BYTES:
case ElementSize::EIGHT_BYTES:
if(getElementSize() == right.getElementSize()) {
if(memcmp(getRawBytes().begin(), right.getRawBytes().begin(), getRawBytes().size()) == 0) {
return Equality::EQUAL;
} else {
return Equality::NOT_EQUAL;
}
} else {
return Equality::NOT_EQUAL;
}
case ElementSize::POINTER:
case ElementSize::INLINE_COMPOSITE: {
auto llist = as<List<AnyStruct>>();
auto rlist = right.as<List<AnyStruct>>();
for(size_t i = 0; i < size(); i++) {
switch(llist[i].equals(rlist[i])) {
case Equality::EQUAL:
break;
case Equality::NOT_EQUAL:
return Equality::NOT_EQUAL;
case Equality::UNKNOWN_CONTAINS_CAPS:
eqResult = Equality::UNKNOWN_CONTAINS_CAPS;
break;
default:
KJ_UNREACHABLE;
}
}
return eqResult;
}
}
KJ_UNREACHABLE;
}
Equality AnyPointer::Reader::equals(AnyPointer::Reader right) {
if(getPointerType() != right.getPointerType()) {
return Equality::NOT_EQUAL;
}
switch(getPointerType()) {
case PointerType::NULL_:
return Equality::EQUAL;
case PointerType::STRUCT:
return getAs<AnyStruct>().equals(right.getAs<AnyStruct>());
case PointerType::LIST:
return getAs<AnyList>().equals(right.getAs<AnyList>());
case PointerType::CAPABILITY:
return Equality::UNKNOWN_CONTAINS_CAPS;
}
// There aren't currently any other types of pointers
KJ_UNREACHABLE;
}
bool AnyPointer::Reader::operator ==(AnyPointer::Reader right) {
switch(equals(right)) {
case Equality::EQUAL:
return true;
case Equality::NOT_EQUAL:
return false;
case Equality::UNKNOWN_CONTAINS_CAPS:
KJ_FAIL_REQUIRE(
"operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
}
KJ_UNREACHABLE;
}
bool AnyStruct::Reader::operator ==(AnyStruct::Reader right) {
switch(equals(right)) {
case Equality::EQUAL:
return true;
case Equality::NOT_EQUAL:
return false;
case Equality::UNKNOWN_CONTAINS_CAPS:
KJ_FAIL_REQUIRE(
"operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
}
KJ_UNREACHABLE;
}
bool AnyList::Reader::operator ==(AnyList::Reader right) {
switch(equals(right)) {
case Equality::EQUAL:
return true;
case Equality::NOT_EQUAL:
return false;
case Equality::UNKNOWN_CONTAINS_CAPS:
KJ_FAIL_REQUIRE(
"operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
}
KJ_UNREACHABLE;
}
} // namespace capnp
......@@ -74,6 +74,14 @@ template <> struct Kind_<AnyList> { static constexpr Kind kind = Kind::OTHER; };
// =======================================================================================
// AnyPointer!
enum class Equality {
NOT_EQUAL,
EQUAL,
UNKNOWN_CONTAINS_CAPS
};
kj::StringPtr KJ_STRINGIFY(Equality res);
struct AnyPointer {
// Reader/Builder for the `AnyPointer` field type, i.e. a pointer that can point to an arbitrary
// object.
......@@ -90,12 +98,17 @@ struct AnyPointer {
inline MessageSize targetSize() const;
// Get the total size of the target object and all its children.
inline bool isNull() const;
inline bool isStruct() {
return reader.isStruct();
}
inline bool isList() {
return reader.isList();
inline PointerType getPointerType() const;
inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
inline bool isStruct() const { return getPointerType() == PointerType::STRUCT; }
inline bool isList() const { return getPointerType() == PointerType::LIST; }
inline bool isCapability() const { return getPointerType() == PointerType::CAPABILITY; }
Equality equals(AnyPointer::Reader right);
bool operator ==(AnyPointer::Reader right);
inline bool operator !=(AnyPointer::Reader right) {
return !(*this == right);
}
template <typename T>
......@@ -138,12 +151,21 @@ struct AnyPointer {
inline MessageSize targetSize() const;
// Get the total size of the target object and all its children.
inline bool isNull();
inline bool isStruct() {
return builder.isStruct();
inline PointerType getPointerType();
inline bool isNull() { return getPointerType() == PointerType::NULL_; }
inline bool isStruct() { return getPointerType() == PointerType::STRUCT; }
inline bool isList() { return getPointerType() == PointerType::LIST; }
inline bool isCapability() { return getPointerType() == PointerType::CAPABILITY; }
inline Equality equals(AnyPointer::Reader right) {
return asReader().equals(right);
}
inline bool operator ==(AnyPointer::Reader right) {
return asReader() == right;
}
inline bool isList() {
return builder.isList();
inline bool operator !=(AnyPointer::Reader right) {
return !(*this == right);
}
inline void clear();
......@@ -432,13 +454,19 @@ public:
: _reader(_::PointerHelpers<FromReader<T>>::getInternalReader(kj::fwd<T>(value))) {}
#endif
Data::Reader getDataSection() {
kj::ArrayPtr<const byte> getDataSection() {
return _reader.getDataSectionAsBlob();
}
List<AnyPointer>::Reader getPointerSection() {
return List<AnyPointer>::Reader(_reader.getPointerSectionAsList());
}
Equality equals(AnyStruct::Reader right);
bool operator ==(AnyStruct::Reader right);
inline bool operator !=(AnyStruct::Reader right) {
return !(*this == right);
}
template <typename T>
ReaderFor<T> as() const {
// T must be a struct type.
......@@ -462,13 +490,23 @@ public:
: _builder(_::PointerHelpers<FromBuilder<T>>::getInternalBuilder(kj::fwd<T>(value))) {}
#endif
inline Data::Builder getDataSection() {
inline kj::ArrayPtr<byte> getDataSection() {
return _builder.getDataSectionAsBlob();
}
List<AnyPointer>::Builder getPointerSection() {
return List<AnyPointer>::Builder(_builder.getPointerSectionAsList());
}
inline Equality equals(AnyStruct::Reader right) {
return asReader().equals(right);
}
inline bool operator ==(AnyStruct::Reader right) {
return asReader() == right;
}
inline bool operator !=(AnyStruct::Reader right) {
return !(*this == right);
}
inline operator Reader() const { return Reader(_builder.asReader()); }
inline Reader asReader() const { return Reader(_builder.asReader()); }
......@@ -576,6 +614,14 @@ public:
inline ElementSize getElementSize() { return _reader.getElementSize(); }
inline uint size() { return _reader.size() / ELEMENTS; }
inline kj::ArrayPtr<const byte> getRawBytes() { return _reader.asRawBytes(); }
Equality equals(AnyList::Reader right);
inline bool operator ==(AnyList::Reader right);
inline bool operator !=(AnyList::Reader right) {
return !(*this == right);
}
template <typename T> ReaderFor<T> as() {
// T must be List<U>.
return ReaderFor<T>(_reader);
......@@ -601,6 +647,14 @@ public:
inline ElementSize getElementSize() { return _builder.getElementSize(); }
inline uint size() { return _builder.size() / ELEMENTS; }
Equality equals(AnyList::Reader right);
inline bool operator ==(AnyList::Reader right) {
return asReader() == right;
}
inline bool operator !=(AnyList::Reader right) {
return !(*this == right);
}
template <typename T> BuilderFor<T> as() {
// T must be List<U>.
return BuilderFor<T>(_builder);
......@@ -668,8 +722,8 @@ inline MessageSize AnyPointer::Reader::targetSize() const {
return reader.targetSize().asPublic();
}
inline bool AnyPointer::Reader::isNull() const {
return reader.isNull();
inline PointerType AnyPointer::Reader::getPointerType() const {
return reader.getPointerType();
}
template <typename T>
......@@ -681,8 +735,8 @@ inline MessageSize AnyPointer::Builder::targetSize() const {
return asReader().targetSize();
}
inline bool AnyPointer::Builder::isNull() {
return builder.isNull();
inline PointerType AnyPointer::Builder::getPointerType() {
return builder.getPointerType();
}
inline void AnyPointer::Builder::clear() {
......
......@@ -99,6 +99,17 @@ enum class ElementSize: uint8_t {
INLINE_COMPOSITE = 7
};
enum class PointerType {
// Various wire types a pointer field can take
NULL_,
// Should be NULL, but that's #defined in stddef.h
STRUCT,
LIST,
CAPABILITY
};
namespace schemas {
template <typename T>
......
......@@ -2267,20 +2267,24 @@ void PointerBuilder::clear() {
memset(pointer, 0, sizeof(WirePointer));
}
bool PointerBuilder::isNull() {
return pointer->isNull();
}
bool PointerBuilder::isStruct() {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
return ptr->kind() == WirePointer::Kind::STRUCT;
}
bool PointerBuilder::isList() {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
return ptr->kind() == WirePointer::Kind::LIST;
PointerType PointerBuilder::getPointerType() {
if(pointer->isNull()) {
return PointerType::NULL_;
} else {
WirePointer* ptr = pointer;
WireHelpers::followFars(ptr, ptr->target(), segment);
switch(ptr->kind()) {
case WirePointer::Kind::FAR:
KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
return PointerType::STRUCT;
case WirePointer::Kind::LIST:
return PointerType::LIST;
case WirePointer::Kind::OTHER:
// TODO: make sure we're only looking at capability pointers
return PointerType::CAPABILITY;
}
}
}
void PointerBuilder::transferFrom(PointerBuilder other) {
......@@ -2368,24 +2372,26 @@ MessageSizeCounts PointerReader::targetSize() const {
: WireHelpers::totalSize(segment, pointer, nestingLimit);
}
bool PointerReader::isNull() const {
return pointer == nullptr || pointer->isNull();
}
bool PointerReader::isStruct() const {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
return ptr->kind() == WirePointer::Kind::STRUCT;
}
bool PointerReader::isList() const {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
return ptr->kind() == WirePointer::Kind::LIST;
PointerType PointerReader::getPointerType() const {
if(pointer == nullptr || pointer->isNull()) {
return PointerType::NULL_;
} else {
word* refTarget = nullptr;
const WirePointer* ptr = pointer;
SegmentReader* sgmt = segment;
WireHelpers::followFars(ptr, refTarget, sgmt);
switch(ptr->kind()) {
case WirePointer::Kind::FAR:
KJ_FAIL_REQUIRE();
case WirePointer::Kind::STRUCT:
return PointerType::STRUCT;
case WirePointer::Kind::LIST:
return PointerType::LIST;
case WirePointer::Kind::OTHER:
// TODO: make sure we're only looking at capability pointers
return PointerType::CAPABILITY;
}
}
}
kj::Maybe<Arena&> PointerReader::getArena() const {
......@@ -2599,6 +2605,16 @@ Data::Reader ListReader::asData() {
return Data::Reader(reinterpret_cast<const byte*>(ptr), elementCount / ELEMENTS);
}
kj::ArrayPtr<const byte> ListReader::asRawBytes() {
KJ_REQUIRE(structPointerCount == 0 * POINTERS,
"Expected data only, got pointers.") {
return kj::ArrayPtr<const byte>();
}
return kj::ArrayPtr<const byte>(reinterpret_cast<const byte*>(ptr), structDataSize * elementCount / ELEMENTS);
}
StructReader ListReader::getStructElement(ElementCount index) const {
KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnp::ReaderOptions.") {
......
......@@ -274,9 +274,8 @@ public:
// Get a PointerBuilder representing a message root located in the given segment at the given
// location.
bool isNull();
bool isStruct();
bool isList();
inline bool isNull() { return getPointerType() == PointerType::NULL_; }
PointerType getPointerType();
StructBuilder getStruct(StructSize size, const word* defaultValue);
ListBuilder getList(ElementSize elementSize, const word* defaultValue);
......@@ -355,9 +354,8 @@ public:
// use the result as a hint for allocating the first segment, do the copy, and then throw an
// exception if it overruns.
bool isNull() const;
bool isStruct() const;
bool isList() const;
inline bool isNull() const { return getPointerType() == PointerType::NULL_; }
PointerType getPointerType() const;
StructReader getStruct(const word* defaultValue) const;
ListReader getList(ElementSize expectedElementSize, const word* defaultValue) const;
......@@ -408,7 +406,7 @@ public:
inline BitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; }
inline Data::Builder getDataSectionAsBlob();
inline kj::ArrayPtr<byte> getDataSectionAsBlob();
inline _::ListBuilder getPointerSectionAsList();
template <typename T>
......@@ -489,7 +487,7 @@ public:
inline BitCount getDataSectionSize() const { return dataSize; }
inline WirePointerCount getPointerSectionSize() const { return pointerCount; }
inline Data::Reader getDataSectionAsBlob();
inline kj::ArrayPtr<const byte> getDataSectionAsBlob();
inline _::ListReader getPointerSectionAsList();
template <typename T>
......@@ -656,6 +654,8 @@ public:
Data::Reader asData();
// Reinterpret the list as a blob. Throws an exception if the elements are not byte-sized.
kj::ArrayPtr<const byte> asRawBytes();
template <typename T>
KJ_ALWAYS_INLINE(T getDataElement(ElementCount index) const);
// Get the element of the given type at the given index.
......@@ -821,8 +821,8 @@ inline PointerReader PointerReader::getRootUnchecked(const word* location) {
// -------------------------------------------------------------------
inline Data::Builder StructBuilder::getDataSectionAsBlob() {
return Data::Builder(reinterpret_cast<byte*>(data), dataSize / BITS_PER_BYTE / BYTES);
inline kj::ArrayPtr<byte> StructBuilder::getDataSectionAsBlob() {
return kj::ArrayPtr<byte>(reinterpret_cast<byte*>(data), dataSize / BITS_PER_BYTE / BYTES);
}
inline _::ListBuilder StructBuilder::getPointerSectionAsList() {
......@@ -905,8 +905,8 @@ inline PointerBuilder StructBuilder::getPointerField(WirePointerCount ptrIndex)
// -------------------------------------------------------------------
inline Data::Reader StructReader::getDataSectionAsBlob() {
return Data::Reader(reinterpret_cast<const byte*>(data), dataSize / BITS_PER_BYTE / BYTES);
inline kj::ArrayPtr<const byte> StructReader::getDataSectionAsBlob() {
return kj::ArrayPtr<const byte>(reinterpret_cast<const byte*>(data), dataSize / BITS_PER_BYTE / BYTES);
}
inline _::ListReader StructReader::getPointerSectionAsList() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment