Unverified Commit bb665f19 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Avoid size_t issue in Gather (#4329)

parent a05b4823
......@@ -133,3 +133,29 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
}
}
}
uint64_t Dimension::get_length() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot get length of dynamic dimension");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot get_length of negative dimension");
}
return m_dimension;
}
Dimension::operator size_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension;
}
......@@ -20,6 +20,7 @@
#include <stddef.h>
#include <stdexcept>
#include "ngraph/deprecated.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
......@@ -61,18 +62,12 @@ namespace ngraph
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const
{
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
}
return m_dimension;
}
explicit operator size_t() const NGRAPH_DEPRECATED("use get_length() instead");
/// \brief Convert this dimension to `uint64_t`. This dimension must be static and
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
uint64_t get_length() const;
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
......
......@@ -120,9 +120,9 @@ void op::v1::Gather::validate_and_infer_types()
if (axis_rank.is_static() && axis_shape.is_static())
{
const auto axis_is_scalar = static_cast<size_t>(axis_rank) == 0;
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
static_cast<size_t>(axis_rank) == 1 && static_cast<size_t>(axis_shape[0]) == 1;
axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1;
NODE_VALIDATION_CHECK(this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element (shape: ",
......@@ -134,7 +134,7 @@ void op::v1::Gather::validate_and_infer_types()
if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
{
NODE_VALIDATION_CHECK(this,
axis < static_cast<size_t>(input_rank),
axis < input_rank.get_length(),
"The axis must => 0 and <= input_rank (axis: ",
axis,
").");
......@@ -150,19 +150,18 @@ void op::v1::Gather::validate_and_infer_types()
if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
axis != AXIS_NOT_SET_VALUE)
{
std::vector<Dimension> result_dims(static_cast<size_t>(params_shape.rank()) +
static_cast<size_t>(indices_shape.rank()) - 1);
size_t i = 0;
for (; i < static_cast<size_t>(axis); i++)
std::vector<Dimension> result_dims(params_shape.rank().get_length() +
indices_shape.rank().get_length() - 1);
uint64_t i = 0;
for (; i < axis; i++)
{
result_dims[i] = params_shape[i];
}
for (size_t j = 0; j < static_cast<size_t>(indices_shape.rank()); i++, j++)
for (uint64_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
{
result_dims[i] = indices_shape[j];
}
for (size_t j = static_cast<size_t>(axis) + 1; j < static_cast<size_t>(params_shape.rank());
i++, j++)
for (uint64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++)
{
result_dims[i] = params_shape[j];
}
......@@ -177,7 +176,7 @@ void op::v1::Gather::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
size_t op::v1::Gather::get_axis() const
int64_t op::v1::Gather::get_axis() const
{
int64_t axis = AXIS_NOT_SET_VALUE;
auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
......@@ -190,10 +189,10 @@ size_t op::v1::Gather::get_axis() const
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
if (input_rank.is_static())
{
axis += static_cast<size_t>(input_rank);
axis += input_rank.get_length();
}
}
return static_cast<size_t>(axis);
return axis;
}
void op::v1::Gather::generate_adjoints(autodiff::Adjoints& /* adjoints */,
......
......@@ -67,8 +67,7 @@ namespace ngraph
const Output<Node>& indices,
const Output<Node>& axis);
size_t get_version() const override { return 1; }
size_t get_axis() const;
int64_t get_axis() const;
void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment