Commit e1b5cfbc authored by Adam Procter's avatar Adam Procter

Tests, and 'possibly_eq/possibly_neq' methods

parent 5f621476
......@@ -42,13 +42,20 @@ namespace ngraph
{
}
/// \brief Returns true if this dimension is known.
/// \brief Returns true if this dimension is determined.
bool is_determined() const { return m_dimension != s_undetermined_val; }
/// \brief Converts this dimension to size_t. If the dimension is unknown, behavior is
/// undefined.
/// \brief Converts this dimension to size_t. If the dimension is undetermined, return
/// value is implementation-dependent.
explicit operator size_t() const { return m_dimension; }
/// \brief Tests whether "this" is possibly equal to s.
bool possibly_eq(const Dimension& d) const { return !(*this != d); }
/// \brief Tests whether "this" is possibly not equal to s.
bool possibly_neq(const Dimension& d) const { return !(*this == d); }
/// \brief Constructs an unknown dimension.
static Dimension undetermined() { return s_undetermined_val; }
friend bool operator==(const Dimension& d1, const Dimension& d2);
friend bool operator!=(const Dimension& d1, const Dimension& d2);
private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
......
......@@ -53,6 +53,7 @@
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/dimension.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
......@@ -126,6 +127,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
......
......@@ -73,3 +73,60 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
return (str << "?");
}
}
bool ngraph::operator==(const PartialShape& s1, const PartialShape& s2)
{
// If we don't know that the ranks are equal, we don't know that s1 and s2 are equal.
if (s1.rank().possibly_neq(s2.rank()))
{
return false;
}
// If we do know that the ranks are equal, we check each component elementwise.
else
{
for (size_t i = 0; i < size_t(s1.rank()); i++)
{
// If we don't know that these two corresponding elements are equal, we don't know
// that s1 and s2 are equal.
if (s1.m_dimensions[i].possibly_neq(s2.m_dimensions[i]))
{
return false;
}
}
// If we are still here, we know that s1 and s2 have the same rank and are elementwise
// necessarily equal everywhere.
return true;
}
}
bool ngraph::operator!=(const PartialShape& s1, const PartialShape& s2)
{
// If we know that the ranks are unequal, we know s1 and s2 are unequal.
if (s1.rank() != s2.rank())
{
return true;
}
// If we do not know that the ranks are unequal, and we do not know that they are equal,
// then one of s1 or s2 has undetermined rank, and we do not know that s1 and s2 are unequal.
else if (s1.rank().possibly_neq(s2.rank()))
{
return false;
}
// If we do know that the ranks are equal, we check each component elementwise.
else
{
for (size_t i = 0; i < size_t(s1.rank()); i++)
{
// If we know that these two corresponding elemenats are not equal, we know that s1
// and s2 are not equal.
if (s1.m_dimensions[i] != s2.m_dimensions[i])
{
return true;
}
}
// If we are still here, then we know that s1 and s2 have the same rank, but there is
// nowhere that we know that s1 and s2 are elementwise unequal. Therefore we do not know
// that s1 and s2 are unequal.
return false;
}
}
......@@ -69,10 +69,16 @@ namespace ngraph
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other);
/// \brief Tests whether "this" is possibly equal to s.
bool possibly_eq(const PartialShape& s) const { return !(*this != s); }
/// \brief Tests whether "this" is possibly not equal to s.
bool possibly_neq(const PartialShape& s) const { return !(*this == s); }
/// \brief Returns the undetermined shape.
static PartialShape undetermined() { return PartialShape(false, {}); }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
friend bool operator==(const PartialShape& s1, const PartialShape& s2);
friend bool operator!=(const PartialShape& s1, const PartialShape& s2);
private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape.
......@@ -98,6 +104,19 @@ namespace ngraph
/// returns a new shape whose ith dimension is s1[i] + s2[i].
PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Tests whether two partial shapes are necessarily equal.
///
/// Returns true if s1 and s2's ranks are determined and equal, AND s1 and s2's
/// dimensions are elementwise determined and equal everywhere.
bool operator==(const PartialShape& s1, const PartialShape& s2);
/// \brief Tests whether two partial shapes are necessarily unequal.
///
/// Returns true if s1 and s2's ranks are determined and not equal; else, false if s1
/// and s2's ranks are not both determined; else, true if s1 and s2's dimensions are
/// elementwise determined and not equal somewhere.
bool operator!=(const PartialShape& s1, const PartialShape& s2);
/// \brief Inserts a human-readable representation of "shape" into "str".
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
}
......@@ -34,6 +34,7 @@ set(SRC
main.cpp
nop_elimination.cpp
op.cpp
partial_shape.cpp
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment