Commit c73c33f2 authored by Adam Procter's avatar Adam Procter

Change Dimension to be able to represent signed values

parent ec2f69bf
......@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape)
{
NGRAPH_CHECK(pshape.all_non_negative(),
"set_tensor_type called on a PartialShape containing non-negative dimensions: ",
pshape);
if (pshape.is_static())
{
m_shape = pshape.to_shape();
......
......@@ -23,7 +23,7 @@
using namespace ngraph;
Dimension::Dimension(size_t dimension)
Dimension::Dimension(int64_t dimension)
: m_dimension(dimension)
{
if (dimension == s_dynamic_val)
......@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{
if (dimension.is_static())
{
return (str << size_t(dimension));
return (str << int64_t(dimension));
}
else
{
......@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension Dimension::operator+(const Dimension& dim) const
{
return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic());
return (is_static() && dim.is_static() ? m_dimension + int64_t(dim) : Dimension::dynamic());
}
Dimension Dimension::operator-(const Dimension& dim) const
{
return (is_static() && dim.is_static() ? m_dimension - size_t(dim) : Dimension::dynamic());
return (is_static() && dim.is_static() ? m_dimension - int64_t(dim) : Dimension::dynamic());
}
Dimension Dimension::operator*(const Dimension& dim) const
{
return ((is_static() && dim.is_static())
? m_dimension * size_t(dim)
? m_dimension * int64_t(dim)
: (is_static() && m_dimension == 0)
? 0
: (dim.is_static() && size_t(dim) == 0) ? 0 : Dimension::dynamic());
: (dim.is_static() && int64_t(dim) == 0) ? 0 : Dimension::dynamic());
}
bool Dimension::compatible(const Dimension& d) const
{
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
return (is_dynamic() || d.is_dynamic() || m_dimension == int64_t(d));
}
bool Dimension::relaxes(const Dimension& d) const
{
return (is_dynamic() || (d.is_static() && size_t(*this) == size_t(d)));
return (is_dynamic() || (d.is_static() && int64_t(*this) == int64_t(d)));
}
bool Dimension::refines(const Dimension& d) const
{
return (d.is_dynamic() || (is_static() && size_t(d) == size_t(*this)));
return (d.is_dynamic() || (is_static() && int64_t(d) == int64_t(*this)));
}
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
......@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
dst = d1;
return true;
}
else if (size_t(d1) != size_t(d2))
else if (int64_t(d1) != int64_t(d2))
{
return false;
}
......@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
else if (d1.is_dynamic() || d2.is_dynamic())
{
// One static. Set dst to static size if >1
auto ds = d1.is_dynamic() ? size_t(d2) : size_t(d1);
auto ds = d1.is_dynamic() ? int64_t(d2) : int64_t(d1);
dst = (ds > 1) ? ds : Dimension::dynamic();
return true;
}
else
{
// Static sizes. Both match or one of them is 1.
if (size_t(d1) == size_t(d2) || size_t(d1) == 1 || size_t(d2) == 1)
if (int64_t(d1) == int64_t(d2) || int64_t(d1) == 1 || int64_t(d2) == 1)
{
dst = std::max(size_t(d1), size_t(d2));
dst = std::max(int64_t(d1), int64_t(d2));
return true;
}
else
......
......@@ -25,7 +25,7 @@ namespace ngraph
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
///
/// Static dimensions may be implicitly converted from size_t. A dynamic dimension is
/// Static dimensions may be implicitly converted from int64_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
......@@ -36,7 +36,7 @@ namespace ngraph
/// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension);
Dimension(int64_t dimension);
/// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; }
......@@ -46,14 +46,29 @@ namespace ngraph
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `size_t`. This dimension must be static.
/// \brief Convert this dimension to `int64_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit operator int64_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to int64_t");
}
return m_dimension;
}
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension;
}
/// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static.
......@@ -75,7 +90,7 @@ namespace ngraph
bool same_scheme(const Dimension& dim) const
{
return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim));
(is_static() && dim.is_static() && m_dimension == int64_t(dim));
}
/// \brief Try to merge two Dimension objects together.
......@@ -128,25 +143,25 @@ namespace ngraph
/// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{(std::numeric_limits<size_t>::max())};
static const int64_t s_dynamic_val{(std::numeric_limits<int64_t>::max())};
/// \brief Addition operator for Dimension.
/// \param dim Right operand for addition.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)+size_t(dim)`.
/// dimension with value `int64_t(*this)+in64_t(dim)`.
Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)-size_t(dim)`.
/// dimension with value `int64_t(*this)-int64_t(dim)`.
Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// `size_t(*this)*size_t(dim)`.
/// `int64_t(*this)*int64_t(dim)`.
Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension.
......@@ -160,7 +175,7 @@ namespace ngraph
private:
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension.
size_t m_dimension;
int64_t m_dimension;
};
/// \brief Insert a human-readable representation of a dimension into an output stream.
......@@ -168,6 +183,6 @@ namespace ngraph
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
/// Inserts the string `?` if `dimension` is dynamic; else inserts `int64_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
}
......@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
return success;
}
}
bool PartialShape::all_non_negative() const
{
for (auto& d : m_dimensions)
{
if (d.is_static() && int64_t(d) < 0)
{
return false;
}
}
return true;
}
......@@ -164,6 +164,10 @@ namespace ngraph
/// \throws std::invalid_argument If this PartialShape is dynamic.
Shape to_shape() const;
/// \brief Returns `true` if all static dimensions of the tensor are non-negative, else
/// `false`.
bool all_non_negative() const;
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment