Unverified Commit 21645142 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Add value_type to Dimension, deprecate int64_t cast (#4486)

* Add value_type to Dimension, deprecate int64_t cast

* Need a conversion
parent 6f9446e0
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
using namespace ngraph; using namespace ngraph;
Dimension::Dimension(int64_t dimension) Dimension::Dimension(value_type dimension)
: m_dimension(dimension) : m_dimension(dimension)
{ {
if (dimension == s_dynamic_val) if (dimension == s_dynamic_val)
...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{ {
if (dimension.is_static()) if (dimension.is_static())
{ {
return (str << int64_t(dimension)); return (str << dimension.get_length());
} }
else else
{ {
...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension Dimension::operator+(const Dimension& dim) const Dimension Dimension::operator+(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension + int64_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension + dim.get_length() : Dimension::dynamic());
} }
Dimension Dimension::operator-(const Dimension& dim) const Dimension Dimension::operator-(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension - int64_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension - dim.get_length() : Dimension::dynamic());
} }
Dimension Dimension::operator*(const Dimension& dim) const Dimension Dimension::operator*(const Dimension& dim) const
{ {
return ((is_static() && dim.is_static()) return ((is_static() && dim.is_static())
? m_dimension * int64_t(dim) ? m_dimension * dim.get_length()
: (is_static() && m_dimension == 0) : (is_static() && m_dimension == 0)
? 0 ? 0
: (dim.is_static() && int64_t(dim) == 0) ? 0 : Dimension::dynamic()); : (dim.is_static() && dim.get_length() == 0) ? 0 : Dimension::dynamic());
} }
bool Dimension::compatible(const Dimension& d) const bool Dimension::compatible(const Dimension& d) const
{ {
return (is_dynamic() || d.is_dynamic() || m_dimension == int64_t(d)); return (is_dynamic() || d.is_dynamic() || m_dimension == d.get_length());
} }
bool Dimension::relaxes(const Dimension& d) const bool Dimension::relaxes(const Dimension& d) const
{ {
return (is_dynamic() || (d.is_static() && int64_t(*this) == int64_t(d))); return (is_dynamic() || (d.is_static() && get_length() == d.get_length()));
} }
bool Dimension::refines(const Dimension& d) const bool Dimension::refines(const Dimension& d) const
{ {
return (d.is_dynamic() || (is_static() && int64_t(d) == int64_t(*this))); return (d.is_dynamic() || (is_static() && d.get_length() == get_length()));
} }
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) ...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
dst = d1; dst = d1;
return true; return true;
} }
else if (int64_t(d1) != int64_t(d2)) else if (d1.get_length() != d2.get_length())
{ {
return false; return false;
} }
...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens ...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
else if (d1.is_dynamic() || d2.is_dynamic()) else if (d1.is_dynamic() || d2.is_dynamic())
{ {
// One static. Set dst to static size if >1 // One static. Set dst to static size if >1
auto ds = d1.is_dynamic() ? int64_t(d2) : int64_t(d1); auto ds = d1.is_dynamic() ? d2.get_length() : d1.get_length();
dst = (ds > 1) ? ds : Dimension::dynamic(); dst = (ds > 1) ? ds : Dimension::dynamic();
return true; return true;
} }
else else
{ {
// Static sizes. Both match or one of them is 1. // Static sizes. Both match or one of them is 1.
if (int64_t(d1) == int64_t(d2) || int64_t(d1) == 1 || int64_t(d2) == 1) if (d1.get_length() == d2.get_length() || d1.get_length() == 1 || d2.get_length() == 1)
{ {
dst = std::max(int64_t(d1), int64_t(d2)); dst = std::max(d1.get_length(), d2.get_length());
return true; return true;
} }
else else
...@@ -134,7 +134,7 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens ...@@ -134,7 +134,7 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
} }
} }
uint64_t Dimension::get_length() const Dimension::value_type Dimension::get_length() const
{ {
if (is_dynamic()) if (is_dynamic())
{ {
......
...@@ -28,18 +28,20 @@ namespace ngraph ...@@ -28,18 +28,20 @@ namespace ngraph
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime), /// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object. /// in a shape or shape-like object.
/// ///
/// Static dimensions may be implicitly converted from int64_t. A dynamic dimension is /// Static dimensions may be implicitly converted from value_type. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic(). /// constructed with Dimension() or Dimension::dynamic().
/// ///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
class NGRAPH_API Dimension class NGRAPH_API Dimension
{ {
public: public:
using value_type = int64_t;
/// \brief Construct a static dimension. /// \brief Construct a static dimension.
/// \param dimension Value of the dimension. Must not be equal to /// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val. /// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val. /// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(int64_t dimension); Dimension(value_type dimension);
/// \brief Construct a dynamic dimension. /// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; } Dimension() { m_dimension = s_dynamic_val; }
...@@ -49,13 +51,13 @@ namespace ngraph ...@@ -49,13 +51,13 @@ namespace ngraph
/// \brief Check whether this dimension is dynamic. /// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`. /// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); } bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `int64_t`. This dimension must be static. /// \brief Convert this dimension to `value_type`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic. /// \throws std::invalid_argument If this dimension is dynamic.
explicit operator int64_t() const explicit operator value_type() const NGRAPH_DEPRECATED("use get_length() instead")
{ {
if (is_dynamic()) if (is_dynamic())
{ {
throw std::invalid_argument("Cannot convert dynamic dimension to int64_t"); throw std::invalid_argument("Cannot convert dynamic dimension to value_type");
} }
return m_dimension; return m_dimension;
} }
...@@ -65,10 +67,10 @@ namespace ngraph ...@@ -65,10 +67,10 @@ namespace ngraph
/// \throws std::invalid_argument If this dimension is dynamic or negative. /// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const NGRAPH_DEPRECATED("use get_length() instead"); explicit operator size_t() const NGRAPH_DEPRECATED("use get_length() instead");
/// \brief Convert this dimension to `uint64_t`. This dimension must be static and /// \brief Convert this dimension to `value_type`. This dimension must be static and
/// non-negative. /// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative. /// \throws std::invalid_argument If this dimension is dynamic or negative.
uint64_t get_length() const; value_type get_length() const;
/// \brief Check whether this dimension represents the same scheme as the argument (both /// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal). /// dynamic, or equal).
...@@ -78,7 +80,7 @@ namespace ngraph ...@@ -78,7 +80,7 @@ namespace ngraph
bool same_scheme(const Dimension& dim) const bool same_scheme(const Dimension& dim) const
{ {
return (is_dynamic() && dim.is_dynamic()) || return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == int64_t(dim)); (is_static() && dim.is_static() && get_length() == dim.get_length());
} }
/// \brief Try to merge two Dimension objects together. /// \brief Try to merge two Dimension objects together.
...@@ -131,7 +133,7 @@ namespace ngraph ...@@ -131,7 +133,7 @@ namespace ngraph
/// \return A dynamic dimension. /// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); } static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension. /// \brief Constant for the value used internally to represent a dynamic dimension.
static const int64_t s_dynamic_val{(std::numeric_limits<int64_t>::max())}; static const value_type s_dynamic_val{(std::numeric_limits<value_type>::max())};
/// \brief Addition operator for Dimension. /// \brief Addition operator for Dimension.
/// \param dim Right operand for addition. /// \param dim Right operand for addition.
...@@ -163,7 +165,7 @@ namespace ngraph ...@@ -163,7 +165,7 @@ namespace ngraph
private: private:
// The actual numerical value of the dimension. s_dynamic_val is a special case, // The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension. // representing a dynamic dimension.
int64_t m_dimension; value_type m_dimension;
}; };
/// \brief Insert a human-readable representation of a dimension into an output stream. /// \brief Insert a human-readable representation of a dimension into an output stream.
......
...@@ -59,7 +59,7 @@ namespace ngraph ...@@ -59,7 +59,7 @@ namespace ngraph
data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm)); data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm));
const auto target_shape = default_opset::Constant::create( const auto target_shape = default_opset::Constant::create(
element::i64, Shape{data_rank_value}, data_shape.to_shape()); element::i64, Shape{size_t(data_rank_value)}, data_shape.to_shape());
// Create a default axes order matching the data tensor rank and erase the // Create a default axes order matching the data tensor rank and erase the
// element at the 'normalize_axis' position. The erased element indicates the // element at the 'normalize_axis' position. The erased element indicates the
......
...@@ -62,13 +62,12 @@ void op::Concat::validate_and_infer_types() ...@@ -62,13 +62,12 @@ void op::Concat::validate_and_infer_types()
{ {
if (get_concatenation_axis() < 0) if (get_concatenation_axis() < 0)
{ {
set_concatenation_axis(get_axis() < 0 set_concatenation_axis(get_axis() < 0 ? get_axis() + this_input_rank.get_length()
? get_axis() + static_cast<int64_t>(this_input_rank)
: get_axis()); : get_axis());
} }
auto concat_axis = get_concatenation_axis(); auto concat_axis = get_concatenation_axis();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
concat_axis < static_cast<int64_t>(this_input_rank), concat_axis < this_input_rank.get_length(),
"Concatenation axis (", "Concatenation axis (",
concat_axis, concat_axis,
") is out of bounds for ", ") is out of bounds for ",
......
...@@ -52,8 +52,7 @@ void op::CropAndResize::validate_and_infer_types() ...@@ -52,8 +52,7 @@ void op::CropAndResize::validate_and_infer_types()
Dimension image_depth; Dimension image_depth;
if (image_shape.is_static()) if (image_shape.is_static())
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(this, image_shape.rank().get_length() == 4, "Image must be NHWC");
this, static_cast<int64_t>(image_shape.rank()) == 4, "Image must be NHWC");
image_depth = image_shape[3]; image_depth = image_shape[3];
} }
...@@ -62,10 +61,10 @@ void op::CropAndResize::validate_and_infer_types() ...@@ -62,10 +61,10 @@ void op::CropAndResize::validate_and_infer_types()
if (boxes_shape.is_static()) if (boxes_shape.is_static())
{ {
auto boxes_rank = boxes_shape.rank(); auto boxes_rank = boxes_shape.rank();
NODE_VALIDATION_CHECK(this, static_cast<int64_t>(boxes_rank) == 2, "Boxes must be 2d"); NODE_VALIDATION_CHECK(this, boxes_rank.get_length() == 2, "Boxes must be 2d");
auto boxes_dim1 = boxes_shape[1]; auto boxes_dim1 = boxes_shape[1];
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(boxes_dim1) == 4, "Second boxes dimension must be 4"); this, boxes_dim1.get_length() == 4, "Second boxes dimension must be 4");
} }
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, boxes.get_element_type().is_real(), "Boxes must be real values in [0, 1]"); this, boxes.get_element_type().is_real(), "Boxes must be real values in [0, 1]");
...@@ -75,9 +74,8 @@ void op::CropAndResize::validate_and_infer_types() ...@@ -75,9 +74,8 @@ void op::CropAndResize::validate_and_infer_types()
Dimension num_boxes; Dimension num_boxes;
if (box_indices_shape.is_static()) if (box_indices_shape.is_static())
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(
static_cast<int64_t>(box_indices_shape.rank()) == 1, this, box_indices_shape.rank().get_length() == 1, "Box indices must have rank 1");
"Box indices must have rank 1");
num_boxes = box_indices_shape[0]; num_boxes = box_indices_shape[0];
} }
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
...@@ -90,11 +88,9 @@ void op::CropAndResize::validate_and_infer_types() ...@@ -90,11 +88,9 @@ void op::CropAndResize::validate_and_infer_types()
crop_size_shape.is_static() || crop_size_rank.is_dynamic(), crop_size_shape.is_static() || crop_size_rank.is_dynamic(),
"Dynamic crop_size not supported"); "Dynamic crop_size not supported");
NODE_VALIDATION_CHECK(this, crop_size_rank.get_length() == 1, "crop_size must be a vector");
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(crop_size_rank) == 1, "crop_size must be a vector"); this, crop_size_shape[0].get_length() == 2, "crop_size must be a vector of length 2");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(crop_size_shape[0]) == 2,
"crop_size must be a vector of length 2");
auto& crop_size_et = crop_size.get_element_type(); auto& crop_size_et = crop_size.get_element_type();
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral"); NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
auto crop_size_node = crop_size.get_node_shared_ptr(); auto crop_size_node = crop_size.get_node_shared_ptr();
......
...@@ -185,7 +185,7 @@ void op::LayerNorm::pre_validate_and_infer_types() ...@@ -185,7 +185,7 @@ void op::LayerNorm::pre_validate_and_infer_types()
int64_t n_axis = -1; int64_t n_axis = -1;
if (data_rank.is_static()) if (data_rank.is_static())
{ {
d_rank = static_cast<int64_t>(data_rank); d_rank = data_rank.get_length();
n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis; n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis;
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range"); this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range");
...@@ -198,8 +198,8 @@ void op::LayerNorm::pre_validate_and_infer_types() ...@@ -198,8 +198,8 @@ void op::LayerNorm::pre_validate_and_infer_types()
Rank bias_rank = bias_shape.rank(); Rank bias_rank = bias_shape.rank();
if (scale_rank.is_static() && bias_rank.is_static()) if (scale_rank.is_static() && bias_rank.is_static())
{ {
int64_t s_rank = static_cast<int64_t>(scale_rank); int64_t s_rank = scale_rank.get_length();
int64_t b_rank = static_cast<int64_t>(bias_rank); int64_t b_rank = bias_rank.get_length();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
s_rank == b_rank && s_rank == b_rank &&
((s_rank == (d_rank - n_axis)) || s_rank == 1), ((s_rank == (d_rank - n_axis)) || s_rank == 1),
...@@ -524,7 +524,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types() ...@@ -524,7 +524,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types()
int64_t n_axis = -1; int64_t n_axis = -1;
if (data_rank.is_static()) if (data_rank.is_static())
{ {
d_rank = static_cast<int64_t>(data_rank); d_rank = data_rank.get_length();
n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis; n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis;
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range"); this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range");
...@@ -532,7 +532,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types() ...@@ -532,7 +532,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types()
const PartialShape& delta_shape = get_input_partial_shape(1); const PartialShape& delta_shape = get_input_partial_shape(1);
Rank delta_rank = delta_shape.rank(); Rank delta_rank = delta_shape.rank();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
delta_rank.is_dynamic() || static_cast<int64_t>(delta_rank) == d_rank, delta_rank.is_dynamic() || delta_rank.get_length() == d_rank,
"Delta rank is incorrect"); "Delta rank is incorrect");
if (m_use_stats) if (m_use_stats)
...@@ -543,8 +543,8 @@ void op::LayerNormBackprop::pre_validate_and_infer_types() ...@@ -543,8 +543,8 @@ void op::LayerNormBackprop::pre_validate_and_infer_types()
Rank var_rank = var_shape.rank(); Rank var_rank = var_shape.rank();
if (mean_rank.is_static() && var_rank.is_static()) if (mean_rank.is_static() && var_rank.is_static())
{ {
int64_t m_rank = static_cast<int64_t>(mean_rank); int64_t m_rank = mean_rank.get_length();
int64_t v_rank = static_cast<int64_t>(var_rank); int64_t v_rank = var_rank.get_length();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
m_rank == v_rank && m_rank == n_axis, m_rank == v_rank && m_rank == n_axis,
"Mean and/or variance rank is incorrect"); "Mean and/or variance rank is incorrect");
...@@ -557,7 +557,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types() ...@@ -557,7 +557,7 @@ void op::LayerNormBackprop::pre_validate_and_infer_types()
Rank scale_rank = scale_shape.rank(); Rank scale_rank = scale_shape.rank();
if (scale_rank.is_static()) if (scale_rank.is_static())
{ {
int64_t s_rank = static_cast<int64_t>(scale_rank); int64_t s_rank = scale_rank.get_length();
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, (s_rank == (d_rank - n_axis)) || s_rank == 1, "Scale rank is incorrect"); this, (s_rank == (d_rank - n_axis)) || s_rank == 1, "Scale rank is incorrect");
} }
......
...@@ -62,7 +62,7 @@ void op::MatMul::pre_validate_and_infer_types() ...@@ -62,7 +62,7 @@ void op::MatMul::pre_validate_and_infer_types()
if (A_rank.is_static() && B_rank.is_static()) if (A_rank.is_static() && B_rank.is_static())
{ {
Rank max_rank = int64_t(A_rank) > int64_t(B_rank) ? A_rank : B_rank; Rank max_rank = A_rank.get_length() > B_rank.get_length() ? A_rank : B_rank;
set_output_type(0, result_et, PartialShape::dynamic(max_rank)); set_output_type(0, result_et, PartialShape::dynamic(max_rank));
} }
} }
......
...@@ -144,9 +144,9 @@ void op::v1::NonMaxSuppression::validate_and_infer_types() ...@@ -144,9 +144,9 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() && if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
max_output_boxes_per_class->is_constant()) max_output_boxes_per_class->is_constant())
{ {
const auto num_boxes = static_cast<int64_t>(num_boxes_boxes); const auto num_boxes = num_boxes_boxes.get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input(); const auto max_output_boxes_per_class = max_boxes_output_from_input();
const auto num_classes = static_cast<int64_t>(scores_ps[1]); const auto num_classes = scores_ps[1].get_length();
out_shape[0] = std::min(num_boxes, max_output_boxes_per_class * num_classes); out_shape[0] = std::min(num_boxes, max_output_boxes_per_class * num_classes);
} }
......
...@@ -87,7 +87,7 @@ void op::v0::Pad::validate_and_infer_types() ...@@ -87,7 +87,7 @@ void op::v0::Pad::validate_and_infer_types()
if (arg_shape[i].is_static()) if (arg_shape[i].is_static())
{ {
ptrdiff_t result_dim = ptrdiff_t result_dim =
m_padding_below[i] + static_cast<int64_t>(arg_shape[i]) + m_padding_above[i]; m_padding_below[i] + arg_shape[i].get_length() + m_padding_above[i];
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
result_dim >= 0, result_dim >= 0,
"Inferred result dimension at axis ", "Inferred result dimension at axis ",
...@@ -321,7 +321,7 @@ void op::v1::Pad::validate_and_infer_types() ...@@ -321,7 +321,7 @@ void op::v1::Pad::validate_and_infer_types()
if (arg_shape[i].is_static()) if (arg_shape[i].is_static())
{ {
ptrdiff_t result_dim = ptrdiff_t result_dim =
pads_begin_coord[i] + static_cast<int64_t>(arg_shape[i]) + pads_end_coord[i]; pads_begin_coord[i] + arg_shape[i].get_length() + pads_end_coord[i];
result_dims[i] = static_cast<size_t>(result_dim); result_dims[i] = static_cast<size_t>(result_dim);
if (i > 1) if (i > 1)
{ {
......
...@@ -338,7 +338,7 @@ bool PartialShape::all_non_negative() const ...@@ -338,7 +338,7 @@ bool PartialShape::all_non_negative() const
{ {
for (auto& d : m_dimensions) for (auto& d : m_dimensions)
{ {
if (d.is_static() && int64_t(d) < 0) if (d.is_static() && d.get_length() < 0)
{ {
return false; return false;
} }
......
...@@ -462,7 +462,7 @@ namespace ...@@ -462,7 +462,7 @@ namespace
NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(), NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
"OneHot:v0 one hot axis dimension must be static ", "OneHot:v0 one hot axis dimension must be static ",
*node); *node);
const auto depth = static_cast<int64_t>(output_pshape[one_hot_axis]); const auto depth = output_pshape[one_hot_axis].get_length();
const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth}); const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
const auto on_value = op::Constant::create(element::i64, Shape{}, {1}); const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
......
...@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t data_padded_dilated_dim = -1; ptrdiff_t data_padded_dilated_dim = -1;
if (data_dim_static) if (data_dim_static)
{ {
data_padded_dilated_dim = (static_cast<int64_t>(data_dilation[i]) * data_padded_dilated_dim =
(static_cast<int64_t>(data_shape[i]) - 1)) + (static_cast<int64_t>(data_dilation[i]) * (data_shape[i].get_length() - 1)) +
1 + data_padding_below[i] + data_padding_above[i]; 1 + data_padding_below[i] + data_padding_above[i];
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
node, node,
...@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t window_dilated_dim = -1; ptrdiff_t window_dilated_dim = -1;
if (window_dim_static) if (window_dim_static)
{ {
window_dilated_dim = static_cast<int64_t>(window_dilation[i]) * window_dilated_dim =
(static_cast<int64_t>(window_shape[i]) - 1) + static_cast<int64_t>(window_dilation[i]) * (window_shape[i].get_length() - 1) +
1; 1;
NODE_VALIDATION_CHECK(node, NODE_VALIDATION_CHECK(node,
...@@ -719,17 +719,17 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -719,17 +719,17 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
// so according to tensorflow and numpy we just get 0 // so according to tensorflow and numpy we just get 0
if (lb < 0) if (lb < 0)
{ {
lb = std::max(int64_t(input_shape[input_shape_idx]) + lb, int64_t(0)); lb = std::max(input_shape[input_shape_idx].get_length() + lb, int64_t(0));
} }
if (ub < 0) if (ub < 0)
{ {
ub = std::max(int64_t(input_shape[input_shape_idx]) + ub, int64_t(0)); ub = std::max(input_shape[input_shape_idx].get_length() + ub, int64_t(0));
} }
// apply restrictions when begin or end values more than max possible values. // apply restrictions when begin or end values more than max possible values.
lb = std::min(int64_t(input_shape[input_shape_idx]), lb); lb = std::min(input_shape[input_shape_idx].get_length(), lb);
ub = std::min(int64_t(input_shape[input_shape_idx]), ub); ub = std::min(input_shape[input_shape_idx].get_length(), ub);
// set default value for stride or use given value // set default value for stride or use given value
int64_t stride = 1; int64_t stride = 1;
...@@ -746,14 +746,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -746,14 +746,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
// apply masks // apply masks
if (begin_mask.count(axis)) if (begin_mask.count(axis))
{ {
lb = int64_t(input_shape[input_shape_idx]) - 1; lb = input_shape[input_shape_idx].get_length() - 1;
} }
if (end_mask.count(axis)) if (end_mask.count(axis))
{ {
ub = -1; ub = -1;
} }
lb = std::min(lb, int64_t(input_shape[input_shape_idx]) - 1); lb = std::min(lb, input_shape[input_shape_idx].get_length() - 1);
lb -= 1; // we always get 1st element, so we need decrease range lb -= 1; // we always get 1st element, so we need decrease range
if (ub <= lb) if (ub <= lb)
{ {
...@@ -769,7 +769,7 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -769,7 +769,7 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
} }
if (end_mask.count(axis)) if (end_mask.count(axis))
{ {
ub = int64_t(input_shape[input_shape_idx]); ub = input_shape[input_shape_idx].get_length();
} }
lb += 1; // we always get 1st element, so we need decrease range lb += 1; // we always get 1st element, so we need decrease range
...@@ -829,7 +829,7 @@ int64_t ngraph::normalize_axis(const std::string& node_description, ...@@ -829,7 +829,7 @@ int64_t ngraph::normalize_axis(const std::string& node_description,
return axis; return axis;
} }
const auto tensor_rank_value = static_cast<int64_t>(tensor_rank); const auto tensor_rank_value = tensor_rank.get_length();
return normalize_axis( return normalize_axis(
node_description, axis, tensor_rank_value, -tensor_rank_value, tensor_rank_value - 1); node_description, axis, tensor_rank_value, -tensor_rank_value, tensor_rank_value - 1);
} }
...@@ -866,7 +866,7 @@ int64_t ngraph::normalize_axis(const std::string& node_description, ...@@ -866,7 +866,7 @@ int64_t ngraph::normalize_axis(const std::string& node_description,
axis = axis + tensor_rank; axis = axis + tensor_rank;
} }
return static_cast<int64_t>(axis); return int64_t(axis);
} }
void ngraph::opset1::infer_conv_backprop_output_spatial_shape(const Shape& input_data_shape, void ngraph::opset1::infer_conv_backprop_output_spatial_shape(const Shape& input_data_shape,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment