Commit c73c33f2 authored by Adam Procter's avatar Adam Procter

Change Dimension to be able to represent signed values

parent ec2f69bf
...@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type, ...@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type, void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape) const PartialShape& pshape)
{ {
NGRAPH_CHECK(pshape.all_non_negative(),
"set_tensor_type called on a PartialShape containing non-negative dimensions: ",
pshape);
if (pshape.is_static()) if (pshape.is_static())
{ {
m_shape = pshape.to_shape(); m_shape = pshape.to_shape();
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
using namespace ngraph; using namespace ngraph;
Dimension::Dimension(size_t dimension) Dimension::Dimension(int64_t dimension)
: m_dimension(dimension) : m_dimension(dimension)
{ {
if (dimension == s_dynamic_val) if (dimension == s_dynamic_val)
...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{ {
if (dimension.is_static()) if (dimension.is_static())
{ {
return (str << size_t(dimension)); return (str << int64_t(dimension));
} }
else else
{ {
...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension Dimension::operator+(const Dimension& dim) const Dimension Dimension::operator+(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension + int64_t(dim) : Dimension::dynamic());
} }
Dimension Dimension::operator-(const Dimension& dim) const Dimension Dimension::operator-(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension - size_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension - int64_t(dim) : Dimension::dynamic());
} }
Dimension Dimension::operator*(const Dimension& dim) const Dimension Dimension::operator*(const Dimension& dim) const
{ {
return ((is_static() && dim.is_static()) return ((is_static() && dim.is_static())
? m_dimension * size_t(dim) ? m_dimension * int64_t(dim)
: (is_static() && m_dimension == 0) : (is_static() && m_dimension == 0)
? 0 ? 0
: (dim.is_static() && size_t(dim) == 0) ? 0 : Dimension::dynamic()); : (dim.is_static() && int64_t(dim) == 0) ? 0 : Dimension::dynamic());
} }
bool Dimension::compatible(const Dimension& d) const bool Dimension::compatible(const Dimension& d) const
{ {
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d)); return (is_dynamic() || d.is_dynamic() || m_dimension == int64_t(d));
} }
bool Dimension::relaxes(const Dimension& d) const bool Dimension::relaxes(const Dimension& d) const
{ {
return (is_dynamic() || (d.is_static() && size_t(*this) == size_t(d))); return (is_dynamic() || (d.is_static() && int64_t(*this) == int64_t(d)));
} }
bool Dimension::refines(const Dimension& d) const bool Dimension::refines(const Dimension& d) const
{ {
return (d.is_dynamic() || (is_static() && size_t(d) == size_t(*this))); return (d.is_dynamic() || (is_static() && int64_t(d) == int64_t(*this)));
} }
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) ...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
dst = d1; dst = d1;
return true; return true;
} }
else if (size_t(d1) != size_t(d2)) else if (int64_t(d1) != int64_t(d2))
{ {
return false; return false;
} }
...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens ...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
else if (d1.is_dynamic() || d2.is_dynamic()) else if (d1.is_dynamic() || d2.is_dynamic())
{ {
// One static. Set dst to static size if >1 // One static. Set dst to static size if >1
auto ds = d1.is_dynamic() ? size_t(d2) : size_t(d1); auto ds = d1.is_dynamic() ? int64_t(d2) : int64_t(d1);
dst = (ds > 1) ? ds : Dimension::dynamic(); dst = (ds > 1) ? ds : Dimension::dynamic();
return true; return true;
} }
else else
{ {
// Static sizes. Both match or one of them is 1. // Static sizes. Both match or one of them is 1.
if (size_t(d1) == size_t(d2) || size_t(d1) == 1 || size_t(d2) == 1) if (int64_t(d1) == int64_t(d2) || int64_t(d1) == 1 || int64_t(d2) == 1)
{ {
dst = std::max(size_t(d1), size_t(d2)); dst = std::max(int64_t(d1), int64_t(d2));
return true; return true;
} }
else else
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime), /// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object. /// in a shape or shape-like object.
/// ///
/// Static dimensions may be implicitly converted from size_t. A dynamic dimension is /// Static dimensions may be implicitly converted from int64_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic(). /// constructed with Dimension() or Dimension::dynamic().
/// ///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// \param dimension Value of the dimension. Must not be equal to /// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val. /// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val. /// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension); Dimension(int64_t dimension);
/// \brief Construct a dynamic dimension. /// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; } Dimension() { m_dimension = s_dynamic_val; }
...@@ -46,14 +46,29 @@ namespace ngraph ...@@ -46,14 +46,29 @@ namespace ngraph
/// \brief Check whether this dimension is dynamic. /// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`. /// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); } bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `size_t`. This dimension must be static. /// \brief Convert this dimension to `int64_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic. /// \throws std::invalid_argument If this dimension is dynamic.
explicit operator int64_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to int64_t");
}
return m_dimension;
}
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const explicit operator size_t() const
{ {
if (is_dynamic()) if (is_dynamic())
{ {
throw std::invalid_argument("Cannot convert dynamic dimension to size_t"); throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
} }
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension; return m_dimension;
} }
/// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static. /// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static.
...@@ -75,7 +90,7 @@ namespace ngraph ...@@ -75,7 +90,7 @@ namespace ngraph
bool same_scheme(const Dimension& dim) const bool same_scheme(const Dimension& dim) const
{ {
return (is_dynamic() && dim.is_dynamic()) || return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim)); (is_static() && dim.is_static() && m_dimension == int64_t(dim));
} }
/// \brief Try to merge two Dimension objects together. /// \brief Try to merge two Dimension objects together.
...@@ -128,25 +143,25 @@ namespace ngraph ...@@ -128,25 +143,25 @@ namespace ngraph
/// \return A dynamic dimension. /// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); } static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension. /// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{(std::numeric_limits<size_t>::max())}; static const int64_t s_dynamic_val{(std::numeric_limits<int64_t>::max())};
/// \brief Addition operator for Dimension. /// \brief Addition operator for Dimension.
/// \param dim Right operand for addition. /// \param dim Right operand for addition.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static /// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)+size_t(dim)`. /// dimension with value `int64_t(*this)+in64_t(dim)`.
Dimension operator+(const Dimension& dim) const; Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension. /// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction. /// \param dim Right operand for subtraction.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static /// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)-size_t(dim)`. /// dimension with value `int64_t(*this)-int64_t(dim)`.
Dimension operator-(const Dimension& dim) const; Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension. /// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton. /// \param dim Right operand for multiplicaiton.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if /// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// either of `*this` or `dim` is dynamic; else, a static dimension with value /// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// `size_t(*this)*size_t(dim)`. /// `int64_t(*this)*int64_t(dim)`.
Dimension operator*(const Dimension& dim) const; Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension. /// \brief Add-into operator for Dimension.
...@@ -160,7 +175,7 @@ namespace ngraph ...@@ -160,7 +175,7 @@ namespace ngraph
private: private:
// The actual numerical value of the dimension. s_dynamic_val is a special case, // The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension. // representing a dynamic dimension.
size_t m_dimension; int64_t m_dimension;
}; };
/// \brief Insert a human-readable representation of a dimension into an output stream. /// \brief Insert a human-readable representation of a dimension into an output stream.
...@@ -168,6 +183,6 @@ namespace ngraph ...@@ -168,6 +183,6 @@ namespace ngraph
/// \param dimension The dimension to be inserted into `str`. /// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion. /// \return A reference to `str` after insertion.
/// ///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`. /// Inserts the string `?` if `dimension` is dynamic; else inserts `int64_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension); std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
} }
...@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst, ...@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
return success; return success;
} }
} }
bool PartialShape::all_non_negative() const
{
for (auto& d : m_dimensions)
{
if (d.is_static() && int64_t(d) < 0)
{
return false;
}
}
return true;
}
...@@ -164,6 +164,10 @@ namespace ngraph ...@@ -164,6 +164,10 @@ namespace ngraph
/// \throws std::invalid_argument If this PartialShape is dynamic. /// \throws std::invalid_argument If this PartialShape is dynamic.
Shape to_shape() const; Shape to_shape() const;
/// \brief Returns `true` if all static dimensions of the tensor are non-negative, else
/// `false`.
bool all_non_negative() const;
/// \brief Index operator for PartialShape. /// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected. /// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape. /// \return A reference to the `i`th Dimension of this shape.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment