Unverified Commit bb665f19 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Avoid size_t issue in Gather (#4329)

parent a05b4823
...@@ -133,3 +133,29 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens ...@@ -133,3 +133,29 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
} }
} }
} }
uint64_t Dimension::get_length() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot get length of dynamic dimension");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot get_length of negative dimension");
}
return m_dimension;
}
Dimension::operator size_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension;
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdexcept> #include <stdexcept>
#include "ngraph/deprecated.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph
...@@ -61,18 +62,12 @@ namespace ngraph ...@@ -61,18 +62,12 @@ namespace ngraph
/// \brief Convert this dimension to `size_t`. This dimension must be static and /// \brief Convert this dimension to `size_t`. This dimension must be static and
/// non-negative. /// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative. /// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const explicit operator size_t() const NGRAPH_DEPRECATED("use get_length() instead");
{
if (is_dynamic()) /// \brief Convert this dimension to `uint64_t`. This dimension must be static and
{ /// non-negative.
throw std::invalid_argument("Cannot convert dynamic dimension to size_t"); /// \throws std::invalid_argument If this dimension is dynamic or negative.
} uint64_t get_length() const;
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension;
}
/// \brief Check whether this dimension represents the same scheme as the argument (both /// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal). /// dynamic, or equal).
......
...@@ -120,9 +120,9 @@ void op::v1::Gather::validate_and_infer_types() ...@@ -120,9 +120,9 @@ void op::v1::Gather::validate_and_infer_types()
if (axis_rank.is_static() && axis_shape.is_static()) if (axis_rank.is_static() && axis_shape.is_static())
{ {
const auto axis_is_scalar = static_cast<size_t>(axis_rank) == 0; const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem = const auto axis_has_one_elem =
static_cast<size_t>(axis_rank) == 1 && static_cast<size_t>(axis_shape[0]) == 1; axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1;
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis_is_scalar || axis_has_one_elem, axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element (shape: ", "Axes input must be scalar or have 1 element (shape: ",
...@@ -134,7 +134,7 @@ void op::v1::Gather::validate_and_infer_types() ...@@ -134,7 +134,7 @@ void op::v1::Gather::validate_and_infer_types()
if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE) if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis < static_cast<size_t>(input_rank), axis < input_rank.get_length(),
"The axis must => 0 and <= input_rank (axis: ", "The axis must => 0 and <= input_rank (axis: ",
axis, axis,
")."); ").");
...@@ -150,19 +150,18 @@ void op::v1::Gather::validate_and_infer_types() ...@@ -150,19 +150,18 @@ void op::v1::Gather::validate_and_infer_types()
if (params_shape.rank().is_static() && indices_shape.rank().is_static() && if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
axis != AXIS_NOT_SET_VALUE) axis != AXIS_NOT_SET_VALUE)
{ {
std::vector<Dimension> result_dims(static_cast<size_t>(params_shape.rank()) + std::vector<Dimension> result_dims(params_shape.rank().get_length() +
static_cast<size_t>(indices_shape.rank()) - 1); indices_shape.rank().get_length() - 1);
size_t i = 0; uint64_t i = 0;
for (; i < static_cast<size_t>(axis); i++) for (; i < axis; i++)
{ {
result_dims[i] = params_shape[i]; result_dims[i] = params_shape[i];
} }
for (size_t j = 0; j < static_cast<size_t>(indices_shape.rank()); i++, j++) for (uint64_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
{ {
result_dims[i] = indices_shape[j]; result_dims[i] = indices_shape[j];
} }
for (size_t j = static_cast<size_t>(axis) + 1; j < static_cast<size_t>(params_shape.rank()); for (uint64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++)
i++, j++)
{ {
result_dims[i] = params_shape[j]; result_dims[i] = params_shape[j];
} }
...@@ -177,7 +176,7 @@ void op::v1::Gather::validate_and_infer_types() ...@@ -177,7 +176,7 @@ void op::v1::Gather::validate_and_infer_types()
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
size_t op::v1::Gather::get_axis() const int64_t op::v1::Gather::get_axis() const
{ {
int64_t axis = AXIS_NOT_SET_VALUE; int64_t axis = AXIS_NOT_SET_VALUE;
auto axes_input_node = input_value(AXIS).get_node_shared_ptr(); auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
...@@ -190,10 +189,10 @@ size_t op::v1::Gather::get_axis() const ...@@ -190,10 +189,10 @@ size_t op::v1::Gather::get_axis() const
const auto& input_rank = get_input_partial_shape(PARAMS).rank(); const auto& input_rank = get_input_partial_shape(PARAMS).rank();
if (input_rank.is_static()) if (input_rank.is_static())
{ {
axis += static_cast<size_t>(input_rank); axis += input_rank.get_length();
} }
} }
return static_cast<size_t>(axis); return axis;
} }
void op::v1::Gather::generate_adjoints(autodiff::Adjoints& /* adjoints */, void op::v1::Gather::generate_adjoints(autodiff::Adjoints& /* adjoints */,
......
...@@ -67,8 +67,7 @@ namespace ngraph ...@@ -67,8 +67,7 @@ namespace ngraph
const Output<Node>& indices, const Output<Node>& indices,
const Output<Node>& axis); const Output<Node>& axis);
size_t get_version() const override { return 1; } int64_t get_axis() const;
size_t get_axis() const;
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment