Commit 3deed15c authored by Adam Procter's avatar Adam Procter

Ditch == and != operators, add 'compatible' function to mean 'possibly equal/not…

Ditch == and != operators, add 'compatible' function to mean 'possibly equal/not necessarily unequal'
parent 28acd2ea
......@@ -39,12 +39,7 @@ Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
: Dimension::undetermined());
}
bool ngraph::operator==(const Dimension& d1, const Dimension& d2)
bool Dimension::compatible(const Dimension& d) const
{
return (d1.is_determined() && d2.is_determined() && size_t(d1) == size_t(d2));
}
bool ngraph::operator!=(const Dimension& d1, const Dimension& d2)
{
return (d1.is_determined() && d2.is_determined() && size_t(d1) != size_t(d2));
return (!is_determined() || !d.is_determined() || m_dimension == size_t(d));
}
......@@ -59,15 +59,11 @@ namespace ngraph
return m_dimension;
}
/// \brief Tests whether "this" is possibly equal to s.
bool possibly_eq(const Dimension& d) const { return !(*this != d); }
/// \brief Tests whether "this" is possibly not equal to s.
bool possibly_neq(const Dimension& d) const { return !(*this == d); }
/// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions
/// is undetermined, or both dimensions are determined and equal.
bool compatible(const Dimension& d) const;
/// \brief Constructs an unknown dimension.
static Dimension undetermined() { return Dimension(); }
friend bool operator==(const Dimension& d1, const Dimension& d2);
friend bool operator!=(const Dimension& d1, const Dimension& d2);
private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
......@@ -85,16 +81,4 @@ namespace ngraph
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
/// Dimension::undetermined().
Dimension operator+(const Dimension& d1, const Dimension& d2);
/// \brief Equality operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)==size_t(d2). Otherwise, returns
/// false.
bool operator==(const Dimension& d1, const Dimension& d2);
/// \brief Inequality operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)!=size_t(d2). Otherwise, returns
/// false.
bool operator!=(const Dimension& d1, const Dimension& d2);
}
......@@ -37,7 +37,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
return PartialShape::undetermined();
}
if (s1.rank() != s2.rank())
if (!s1.rank().compatible(s2.rank()))
{
throw std::invalid_argument("rank mismatch");
}
......@@ -74,59 +74,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
}
}
bool ngraph::operator==(const PartialShape& s1, const PartialShape& s2)
bool PartialShape::compatible(const PartialShape& s) const
{
// If we don't know that the ranks are equal, we don't know that s1 and s2 are equal.
if (s1.rank().possibly_neq(s2.rank()))
// If we don't know *this's rank, or we don't know s's rank, they are compatible.
if (!rank_is_determined() || !s.rank_is_determined())
{
return true;
}
// If we do know *this's rank and s's rank, and they are unequal, they are incompatible.
else if (size_t(rank()) != size_t(s.rank()))
{
return false;
}
// If we do know that the ranks are equal, we check each component elementwise.
// If we know both the ranks and they are equal, we check each component elementwise. We are
// compatible iff the shapes are elementwise compatible.
else
{
for (size_t i = 0; i < size_t(s1.rank()); i++)
for (size_t i = 0; i < size_t(s.rank()); i++)
{
// If we don't know that these two corresponding elements are equal, we don't know
// that s1 and s2 are equal.
if (s1.m_dimensions[i].possibly_neq(s2.m_dimensions[i]))
if (!m_dimensions[i].compatible(s.m_dimensions[i]))
{
return false;
}
}
// If we are still here, we know that s1 and s2 have the same rank and are elementwise
// necessarily equal everywhere.
return true;
}
}
bool ngraph::operator!=(const PartialShape& s1, const PartialShape& s2)
{
// If we know that the ranks are unequal, we know s1 and s2 are unequal.
if (s1.rank() != s2.rank())
{
// compatible everywhere.
return true;
}
// If we do not know that the ranks are unequal, and we do not know that they are equal,
// then one of s1 or s2 has undetermined rank, and we do not know that s1 and s2 are unequal.
else if (s1.rank().possibly_neq(s2.rank()))
{
return false;
}
// If we do know that the ranks are equal, we check each component elementwise.
else
{
for (size_t i = 0; i < size_t(s1.rank()); i++)
{
// If we know that these two corresponding elemenats are not equal, we know that s1
// and s2 are not equal.
if (s1.m_dimensions[i] != s2.m_dimensions[i])
{
return true;
}
}
// If we are still here, then we know that s1 and s2 have the same rank, but there is
// nowhere that we know that s1 and s2 are elementwise unequal. Therefore we do not know
// that s1 and s2 are unequal.
return false;
}
}
......@@ -67,16 +67,16 @@ namespace ngraph
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other);
/// \brief Tests whether "this" is possibly equal to s.
bool possibly_eq(const PartialShape& s) const { return !(*this != s); }
/// \brief Tests whether "this" is possibly not equal to s.
bool possibly_neq(const PartialShape& s) const { return !(*this == s); }
/// \brief Returns the undetermined shape.
static PartialShape undetermined() { return PartialShape(false, {}); }
/// \brief Returns true if *this is compatible with s.
///
/// Two dimensions are compatible if one or both of them is undetermined, or if
/// they are both determined and equal.
bool compatible(const PartialShape& s) const;
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
friend bool operator==(const PartialShape& s1, const PartialShape& s2);
friend bool operator!=(const PartialShape& s1, const PartialShape& s2);
private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape.
......@@ -102,19 +102,6 @@ namespace ngraph
/// returns a new shape whose ith dimension is s1[i] + s2[i].
PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Tests whether two partial shapes are necessarily equal.
///
/// Returns true if s1 and s2's ranks are determined and equal, AND s1 and s2's
/// dimensions are elementwise determined and equal everywhere.
bool operator==(const PartialShape& s1, const PartialShape& s2);
/// \brief Tests whether two partial shapes are necessarily unequal.
///
/// Returns true if s1 and s2's ranks are determined and not equal; else, false if s1
/// and s2's ranks are not both determined; else, true if s1 and s2's dimensions are
/// elementwise determined and not equal somewhere.
bool operator!=(const PartialShape& s1, const PartialShape& s2);
/// \brief Inserts a human-readable representation of "shape" into "str".
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
}
......@@ -27,7 +27,7 @@ TEST(partial_shape, ps_construction_empty)
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(ps.rank(), 0);
ASSERT_EQ(size_t(ps.rank()), 0);
}
TEST(partial_shape, ps_construction_undetermined)
......@@ -44,7 +44,7 @@ TEST(partial_shape, ps_construction_incomplete)
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_FALSE(ps.is_complete());
ASSERT_EQ(ps.rank(), 3);
ASSERT_EQ(size_t(ps.rank()), 3);
}
TEST(partial_shape, ps_construction_complete)
......@@ -53,7 +53,7 @@ TEST(partial_shape, ps_construction_complete)
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(ps.rank(), 4);
ASSERT_EQ(size_t(ps.rank()), 4);
}
TEST(partial_shape, dim_construction_determined)
......@@ -103,136 +103,86 @@ TEST(partial_shape, rank_construction_undetermined)
ASSERT_FALSE(r.is_determined());
}
TEST(partial_shape, dim_equal_left_undetermined)
TEST(partial_shape, dim_compatible_left_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{3};
ASSERT_FALSE(d1 == d2);
ASSERT_TRUE(d1.possibly_eq(d2));
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_not_equal_left_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{3};
ASSERT_FALSE(d1 != d2);
ASSERT_TRUE(d1.possibly_neq(d2));
}
TEST(partial_shape, dim_equal_right_undetermined)
TEST(partial_shape, dim_compatible_right_undetermined)
{
Dimension d1{3};
Dimension d2{Dimension::undetermined()};
ASSERT_FALSE(d1 == d2);
ASSERT_TRUE(d1.possibly_eq(d2));
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_not_equal_right_undetermined)
{
Dimension d1{3};
Dimension d2{Dimension::undetermined()};
ASSERT_FALSE(d1 != d2);
ASSERT_TRUE(d1.possibly_neq(d2));
}
TEST(partial_shape, dim_equal_both_undetermined)
TEST(partial_shape, dim_compatible_both_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{Dimension::undetermined()};
ASSERT_FALSE(d1 == d2);
ASSERT_TRUE(d1.possibly_eq(d2));
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_not_equal_both_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{Dimension::undetermined()};
ASSERT_FALSE(d1 != d2);
ASSERT_TRUE(d1.possibly_neq(d2));
}
TEST(partial_shape, dim_equal_both_determined)
TEST(partial_shape, dim_compatible_both_determined)
{
Dimension d1{3};
Dimension d2{8};
Dimension d3{3};
ASSERT_FALSE(d1 == d2);
ASSERT_FALSE(d1.possibly_eq(d2));
ASSERT_TRUE(d1 == d3);
ASSERT_TRUE(d1.possibly_eq(d3));
ASSERT_FALSE(d1.compatible(d2));
ASSERT_TRUE(d1.compatible(d3));
}
TEST(partial_shape, dim_not_equal_both_determined)
{
Dimension d1{3};
Dimension d2{8};
Dimension d3{3};
ASSERT_TRUE(d1 != d2);
ASSERT_TRUE(d1.possibly_neq(d2));
ASSERT_FALSE(d1 != d3);
ASSERT_FALSE(d1.possibly_neq(d3));
}
TEST(partial_shape, shapes_equal_both_rank_undetermined)
TEST(partial_shape, shapes_compatible_both_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_FALSE(ps1 == ps2);
ASSERT_TRUE(ps1.possibly_eq(ps2));
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_left_rank_undetermined)
TEST(partial_shape, shapes_compatible_left_rank_undetermined)
{
PartialShape ps1{3};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_FALSE(ps1 == ps2);
ASSERT_TRUE(ps1.possibly_eq(ps2));
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_right_rank_undetermined)
TEST(partial_shape, shapes_compatible_right_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{4};
ASSERT_FALSE(ps1 == ps2);
ASSERT_TRUE(ps1.possibly_eq(ps2));
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_partial_all_known_equal)
TEST(partial_shape, shapes_compatible_both_partial_all_known_equal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{2, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_FALSE(ps1 == ps2);
ASSERT_TRUE(ps1.possibly_eq(ps2));
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_partial_some_known_unequal)
TEST(partial_shape, shapes_compatible_both_partial_some_known_unequal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{1, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_FALSE(ps1 == ps2);
ASSERT_FALSE(ps1.possibly_eq(ps2));
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_different_rank)
TEST(partial_shape, shapes_compatible_both_complete_different_rank)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8, 10};
ASSERT_FALSE(ps1 == ps2);
ASSERT_FALSE(ps1.possibly_eq(ps2));
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
......@@ -240,8 +190,7 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8};
ASSERT_TRUE(ps1 == ps2);
ASSERT_TRUE(ps1.possibly_eq(ps2));
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims)
......@@ -249,78 +198,5 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims)
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 3, 8};
ASSERT_FALSE(ps1 == ps2);
ASSERT_FALSE(ps1.possibly_eq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_FALSE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_left_rank_undetermined)
{
PartialShape ps1{3};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_FALSE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_right_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{4};
ASSERT_FALSE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_partial_all_known_equal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{2, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_FALSE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_partial_some_known_unequal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{1, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_TRUE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_complete_different_rank)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8, 10};
ASSERT_TRUE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_complete_same_rank_same_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8};
ASSERT_FALSE(ps1 != ps2);
ASSERT_FALSE(ps1.possibly_neq(ps2));
}
TEST(partial_shape, shapes_not_equal_both_complete_same_rank_different_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 3, 8};
ASSERT_TRUE(ps1 != ps2);
ASSERT_TRUE(ps1.possibly_neq(ps2));
ASSERT_FALSE(ps1.compatible(ps2));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment