Unverified Commit 6d299dab authored by Mateusz Bencer, committed by GitHub

Extend normalize axes helpers to support dynamic shapes (#4299)

* Extend normalization. Part.1

* Normalize axis. Part.2

* Code review remarks introduced

* Fixed normalizes ranges type

* Trigger CI
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: Robert Kimball <robert.kimball@intel.com>
parent 6a369747
@@ -34,8 +34,10 @@ namespace ngraph
 {
     NodeVector inputs{node.get_ng_inputs()};
     std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
-    const auto normalized_axis = ngraph::normalize_axis(
-        node.get_description(), axis, inputs.at(0)->get_shape().size());
+    const auto normalized_axis =
+        ngraph::normalize_axis(node.get_description(),
+                               axis,
+                               inputs.at(0)->get_output_partial_shape(0).rank());
     return {std::make_shared<default_opset::Concat>(inputs, normalized_axis)};
 }
...
@@ -25,6 +25,7 @@
 #include "ngraph/op/dequantize.hpp"
 #include "ngraph/opsets/opset0.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/validation_util.hpp"

 namespace ngraph
 {
@@ -57,20 +58,13 @@ namespace ngraph
     int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
     int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
+    const auto data_rank = x->get_output_partial_shape(0).rank();

     AxisSet axes;
     // if axis attribute is set
     if (axis_0 == axis_1)
     {
-        // positive axis
-        if (axis_0 >= 0)
-        {
-            axes.insert(axis_0);
-        }
-        // negative axis
-        else if (axis_0 < 0)
-        {
-            axes.insert(x->get_shape().size() + axis_0);
-        }
+        axes.insert(
+            ngraph::normalize_axis(node.get_description(), axis_0, data_rank));
     }

     if (x->get_element_type() != zero_point->get_element_type())
...
@@ -34,10 +34,19 @@ namespace ngraph
     NodeVector inputs{node.get_ng_inputs()};
     auto data = inputs.at(0);
     auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
-    auto data_rank = data->get_shape().size();
+    const auto data_rank = data->get_output_partial_shape(0).rank();
+    CHECK_VALID_NODE(node,
+                     data_rank.is_static(),
+                     "Data rank must be static in order to calculate flatten op");
+    const auto data_rank_value = static_cast<int64_t>(data_rank);

     // Accepted range is [-r, r] where r = rank(input).
-    const auto normalized_axis = ngraph::normalize_axis(
-        node.get_description(), axis, data_rank, -data_rank, data_rank);
+    const auto normalized_axis = ngraph::normalize_axis(node.get_description(),
+                                                        axis,
+                                                        data_rank_value,
+                                                        -data_rank_value,
+                                                        data_rank_value);

     return {ngraph::builder::opset1::flatten(data, normalized_axis)};
 }
...
@@ -38,7 +38,7 @@ namespace ngraph
     auto indices = ng_inputs.at(1);
     auto axis = node.get_attribute_value<int64_t>("axis", 0);
     const auto valid_axis = ngraph::normalize_axis(
-        node.get_description(), axis, data->get_shape().size());
+        node.get_description(), axis, data->get_output_partial_shape(0).rank());
     return {std::make_shared<default_opset::Gather>(
         data,
...
@@ -34,11 +34,11 @@ namespace ngraph
 NodeVector hardmax(const Node& node)
 {
     const auto input = node.get_ng_inputs().at(0);
-    const auto& input_shape = input->get_shape();
+    const auto& input_shape = input->get_output_partial_shape(0);

     const auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
     const auto normalized_axis =
-        ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
+        ngraph::normalize_axis(node.get_description(), axis, input_shape.rank());

     // reshape to 2D - "batch size" x "input feature dimensions" (NxD)
     const auto coerced_tensor =
@@ -68,7 +68,17 @@
     const auto converted_results = std::make_shared<default_opset::Convert>(
         results, input->get_element_type());

-    return {ngraph::builder::opset1::reshape(converted_results, input_shape)};
+    if (input_shape.is_static())
+    {
+        return {ngraph::builder::opset1::reshape(converted_results,
+                                                 input_shape.to_shape())};
+    }
+    else
+    {
+        const auto output_shape = std::make_shared<default_opset::ShapeOf>(input);
+        return {
+            std::make_shared<default_opset::Reshape>(input, output_shape, false)};
+    }
 }
 } // namespace set_1
...
@@ -32,11 +32,11 @@ namespace ngraph
 {
     NodeVector inputs{node.get_ng_inputs()};
     const auto data = inputs.at(0);
-    const auto data_shape = data->get_shape();
+    const auto data_rank = data->get_output_partial_shape(0).rank();
     const auto axis = node.get_attribute_value<int64_t>("axis", 1);

     const auto normalized_axis =
-        ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
+        ngraph::normalize_axis(node.get_description(), axis, data_rank);

     const auto softmax =
         std::make_shared<default_opset::Softmax>(data, normalized_axis);
...
@@ -39,11 +39,17 @@ namespace ngraph
 NodeVector lp_norm(const Node& node)
 {
     const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
+    const auto data_shape = data->get_output_partial_shape(0);
+    const auto data_rank = data_shape.rank();
+    CHECK_VALID_NODE(
+        node, data_shape.is_static(), "Data shape must be static for lp_norm op");
+    const auto data_rank_value = static_cast<size_t>(data_rank);
     const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};

     const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
-    const size_t normalize_axis = ngraph::normalize_axis(
-        node.get_description(), axis, data->get_shape().size());
+    const size_t normalize_axis =
+        ngraph::normalize_axis(node.get_description(), axis, data_rank);

     ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
         << "Invalid `p` attribute value: " << p_norm
@@ -53,13 +59,13 @@ namespace ngraph
     data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm));

     const auto target_shape = default_opset::Constant::create(
-        element::i64, Shape{data->get_shape().size()}, data->get_shape());
+        element::i64, Shape{data_rank_value}, data_shape.to_shape());

     // Create a default axes order matching the data tensor rank and erase the
     // element at the 'normalize_axis' position. The erased element indicates the
     // axis
     // along which the data should be broadcasted.
-    std::vector<size_t> axes_values(data->get_shape().size());
+    std::vector<size_t> axes_values(data_rank_value);
     std::iota(axes_values.begin(), axes_values.end(), 0);
     axes_values.erase(axes_values.begin() + normalize_axis);
...
@@ -50,8 +50,8 @@ namespace ngraph
 {
     auto data = node.get_ng_inputs().at(0);
     auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
-    std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
-        node.get_description(), axes, data->get_shape().size());
+    const std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
+        node.get_description(), axes, data->get_output_partial_shape(0).rank());
     return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(normalized_axes))};
 }
...
@@ -39,13 +39,14 @@ namespace ngraph
     // nGraph supports only int32 type of sequence_lengths
     const auto sequence_lengths_i32 = std::make_shared<default_opset::Convert>(
         node.get_ng_inputs().at(1), element::i32);
+    const auto data_rank = data->get_output_partial_shape(0).rank();

     const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
-    const auto normalized_batch_axis = ngraph::normalize_axis(
-        node.get_description(), batch_axis, data->get_shape().size());
+    const auto normalized_batch_axis =
+        ngraph::normalize_axis(node.get_description(), batch_axis, data_rank);
     const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
-    const auto normalized_time_axis = ngraph::normalize_axis(
-        node.get_description(), time_axis, data->get_shape().size());
+    const auto normalized_time_axis =
+        ngraph::normalize_axis(node.get_description(), time_axis, data_rank);

     NGRAPH_CHECK(normalized_batch_axis == 0 || normalized_batch_axis == 1,
                  "Allowed values of the 'batch_axis' attribute for ReverseSequence "
...
@@ -32,11 +32,10 @@ namespace ngraph
 {
     NodeVector inputs{node.get_ng_inputs()};
     auto data = inputs.at(0);
-    auto data_shape = data->get_shape();
-
-    int axis = node.get_attribute_value<int64_t>("axis", 1);
+    const auto data_rank = data->get_output_partial_shape(0).rank();
+    const auto axis = node.get_attribute_value<int64_t>("axis", 1);

     const auto normalized_axis =
-        ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
+        ngraph::normalize_axis(node.get_description(), axis, data_rank);

     return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
 }
...
@@ -36,8 +36,10 @@ namespace ngraph
     auto data = node.get_ng_inputs().at(0);
     std::vector<std::int64_t> axes =
         node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
-    std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
-        node.get_description(), axes, data->get_shape().size());
+    const auto data_rank = data->get_output_partial_shape(0).rank();
+
+    std::vector<std::size_t> normalized_axes =
+        ngraph::normalize_axes(node.get_description(), axes, data_rank);
     auto axes_node = std::make_shared<default_opset::Constant>(
         element::u64, Shape{normalized_axes.size()}, normalized_axes);
     return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
...
@@ -35,8 +35,8 @@ namespace
 {
     std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
-    auto data = node.get_ng_inputs().at(0);
-    auto data_rank = data->get_shape().size();
+    const auto data = node.get_ng_inputs().at(0);
+    const auto data_rank = data->get_output_partial_shape(0).rank();

     return ngraph::normalize_axis(node.get_description(), axis, data_rank);
 }
...
@@ -17,6 +17,7 @@
 #include <memory>

 #include "default_opset.hpp"
+#include "exceptions.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/validation_util.hpp"
 #include "unsqueeze.hpp"
@@ -33,7 +34,12 @@ namespace ngraph
 {
     auto data = node.get_ng_inputs().at(0);
     auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
-    const auto expanded_rank = data->get_shape().size() + axes.size();
+    const auto data_rank = data->get_output_partial_shape(0).rank();
+    CHECK_VALID_NODE(node,
+                     data_rank.is_static(),
+                     "Data rank must be static for creation of ONNX Unsqueeze op");
+    const auto expanded_rank =
+        data->get_output_partial_shape(0).rank() + axes.size();
     std::vector<std::size_t> normalized_axes =
         ngraph::normalize_axes(node.get_description(), axes, expanded_rank);
     auto axes_node = std::make_shared<default_opset::Constant>(
...
@@ -33,8 +33,8 @@ namespace ngraph
     m_input_node = node.get_ng_inputs().at(0);

     const auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
-    m_normalized_axis = ngraph::normalize_axis(
-        node.get_description(), axis, m_input_node->get_shape().size());
+    const auto data_rank = m_input_node->get_output_partial_shape(0).rank();
+    m_normalized_axis = ngraph::normalize_axis(node.get_description(), axis, data_rank);
 }

 std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
...
@@ -34,10 +34,10 @@ namespace ngraph
 {
     auto reduction_axes =
         node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
-    std::vector<std::size_t> normalized_axes =
-        ngraph::normalize_axes(node.get_description(),
-                               reduction_axes,
-                               node.get_ng_inputs().at(0)->get_shape().size());
+    std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
+        node.get_description(),
+        reduction_axes,
+        node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank());
     if (reduction_axes.empty())
     {
...
@@ -67,8 +67,8 @@ void op::v0::Split::pre_validate_and_infer_types()
     const auto shape = input(0).get_shape();
-    m_axis = ngraph::normalize_axis(this, m_axis, shape.size());
+    const auto data_rank = get_input_partial_shape(0).rank();
+    m_axis = ngraph::normalize_axis(this, m_axis, data_rank);

     const auto dimension_at_axis = shape.at(m_axis);
     if (m_split_evenly)
     {
@@ -140,9 +140,10 @@ void op::v1::Split::validate_and_infer_types()
     const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
     auto axis = axis_input->cast_vector<int64_t>()[0];
-    const auto data_shape = data_ps.to_shape();
-    axis = ngraph::normalize_axis(this, axis, data_shape.size());
+    const auto data_rank = get_input_partial_shape(0).rank();
+    axis = ngraph::normalize_axis(this, axis, data_rank);
+
+    const auto data_shape = data_ps.to_shape();
     const auto dimension_at_axis = data_shape.at(axis);

     NODE_VALIDATION_CHECK(this,
...
@@ -44,30 +44,8 @@ void op::ReverseSequence::validate_and_infer_types()
     auto input_shape = get_input_partial_shape(0);
     auto input_rank = input_shape.rank();

-    if (m_batch_axis < 0 || m_seq_axis < 0)
-    {
-        NODE_VALIDATION_CHECK(this,
-                              input_rank.is_static(),
-                              "In order to handle negative axes input_rank must be static (",
-                              "batch_axis=",
-                              m_batch_axis,
-                              ", seq_axis=",
-                              m_seq_axis,
-                              ")");
-    }
-    else
-    {
-        m_normalized_batch_axis = m_batch_axis;
-        m_normalized_seq_axis = m_seq_axis;
-    }
-
-    if (input_rank.is_static())
-    {
-        m_normalized_batch_axis =
-            ngraph::normalize_axis(this, m_batch_axis, static_cast<int64_t>(input_rank));
-        m_normalized_seq_axis =
-            ngraph::normalize_axis(this, m_seq_axis, static_cast<int64_t>(input_rank));
-    }
+    m_normalized_batch_axis = ngraph::normalize_axis(this, m_batch_axis, input_rank);
+    m_normalized_seq_axis = ngraph::normalize_axis(this, m_seq_axis, input_rank);

     auto indices_shape = get_input_partial_shape(1);
     auto indices_rank = indices_shape.rank();
...
@@ -66,7 +66,7 @@ void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_
 void op::util::ArithmeticReduction::validate_and_infer_types()
 {
     auto input_shape = get_input_partial_shape(0);
-    auto input_rank = input_shape.rank();
+    const auto input_rank = input_shape.rank();

     PartialShape result_shape{PartialShape::dynamic()};
@@ -79,7 +79,7 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
     {
         try
         {
-            axis = normalize_axis(this, axis, size_t(input_rank));
+            axis = normalize_axis(this, axis, input_rank);
         }
         catch (const ngraph_error&)
         {
...
@@ -78,7 +78,7 @@ void op::util::LogicalReduction::validate_and_infer_types()
     {
         try
         {
-            axis = normalize_axis(this, axis, size_t(input_rank));
+            axis = normalize_axis(this, axis, input_rank);
         }
         catch (const ngraph_error&)
         {
...
@@ -52,7 +52,7 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
     {
         try
         {
-            axis = normalize_axis(this, axis, size_t(input_rank));
+            axis = normalize_axis(this, axis, input_rank);
         }
         catch (const ngraph_error&)
         {
...
@@ -60,12 +60,11 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
     if (data_shape.is_static() && axis_input->is_constant() &&
         split_lengths_input->is_constant())
     {
-        auto data_rank = static_cast<size_t>(data_shape.rank());
         const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
         auto axis_val = axis_input->cast_vector<int64_t>()[0];
         // Adjust split axis in case of negatives
-        int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
+        int64_t axis = ngraph::normalize_axis(this, axis_val, data_shape.rank());

         auto split_lengths =
             as_type_ptr<op::Constant>(split_lengths_input)->cast_vector<int64_t>();
...
@@ -41,8 +41,8 @@ void pass::ConstantFolding::construct_constant_split()
     const auto split = static_pointer_cast<op::v1::Split>(m.get_match_root());

     const auto axis_val = axis_node->cast_vector<int64_t>()[0];
-    const auto norm_axis_val =
-        ngraph::normalize_axis(split.get(), axis_val, data_node->get_shape().size());
+    const auto norm_axis_val = ngraph::normalize_axis(
+        split.get(), axis_val, data_node->get_output_partial_shape(0).rank());
     const auto slices = builder::split(data_node, split->get_num_splits(), norm_axis_val);

     for (size_t i = 0; i < split->get_output_size(); i++)
...
@@ -46,8 +46,8 @@ void pass::ConstantFolding::construct_constant_variadic_split()
     const auto variadic_split = static_pointer_cast<op::v1::VariadicSplit>(m.get_match_root());

     const auto axis_val = axis_node->cast_vector<int64_t>()[0];
-    const auto norm_axis_val =
-        ngraph::normalize_axis(variadic_split.get(), axis_val, data_node->get_shape().size());
+    const auto norm_axis_val = ngraph::normalize_axis(
+        variadic_split.get(), axis_val, data_node->get_output_partial_shape(0).rank());
     auto split_lengths = lengths_node->cast_vector<int64_t>();
     // Adjust split lengths in case of negatives
...
@@ -797,7 +797,7 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
 std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
                                            const std::vector<int64_t>& axes,
-                                           std::int64_t tensor_rank)
+                                           const Rank& tensor_rank)
 {
     std::vector<size_t> new_axes;
@@ -809,21 +809,36 @@ std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
     return new_axes;
 }

-int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
+int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, const Rank& tensor_rank)
 {
     return normalize_axis(node->description(), axis, tensor_rank);
 }

 int64_t ngraph::normalize_axis(const std::string& node_description,
                                std::int64_t axis,
-                               std::int64_t tensor_rank)
+                               const Rank& tensor_rank)
 {
-    return normalize_axis(node_description, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
+    if (axis < 0)
+    {
+        // Handling negative axis requires static tensor rank
+        NGRAPH_CHECK(tensor_rank.is_static(),
+                     node_description,
+                     " Rank must be static in order to normalize negative axis=",
+                     axis);
+    }
+    if (tensor_rank.is_dynamic())
+    {
+        return axis;
+    }
+
+    const auto tensor_rank_value = static_cast<int64_t>(tensor_rank);
+    return normalize_axis(
+        node_description, axis, tensor_rank_value, -tensor_rank_value, tensor_rank_value - 1);
 }

 int64_t ngraph::normalize_axis(const Node* node,
                                std::int64_t axis,
-                               std::int64_t tensor_rank,
+                               std::uint64_t tensor_rank,
                                std::int64_t axis_range_min,
                                std::int64_t axis_range_max)
 {
@@ -833,7 +848,7 @@ int64_t ngraph::normalize_axis(const Node* node,
 int64_t ngraph::normalize_axis(const std::string& node_description,
                                std::int64_t axis,
-                               std::int64_t tensor_rank,
+                               std::uint64_t tensor_rank,
                                std::int64_t axis_range_min,
                                std::int64_t axis_range_max)
 {
...
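The behaviour of the new Rank-aware overload can be illustrated with a short usage sketch. This is not part of the commit: it assumes a build of nGraph that contains the change above, the header paths from the nGraph source tree of that period, and a placeholder node description string "demo".

// Minimal sketch of ngraph::normalize_axis(description, axis, Rank), based on the diff above.
#include <iostream>

#include "ngraph/check.hpp"           // ngraph::CheckFailure (thrown by NGRAPH_CHECK)
#include "ngraph/rank.hpp"            // ngraph::Rank
#include "ngraph/validation_util.hpp" // ngraph::normalize_axis

int main()
{
    using ngraph::Rank;

    // Static rank: a negative axis is wrapped as before, e.g. -1 becomes 3 for rank 4.
    std::cout << ngraph::normalize_axis("demo", -1, Rank{4}) << std::endl; // prints 3

    // Dynamic rank with a non-negative axis: the axis is returned unchanged,
    // because there is no rank to validate it against yet.
    std::cout << ngraph::normalize_axis("demo", 2, Rank::dynamic()) << std::endl; // prints 2

    // Dynamic rank with a negative axis: NGRAPH_CHECK fires, since a static rank
    // is required to resolve the axis.
    try
    {
        ngraph::normalize_axis("demo", -2, Rank::dynamic());
    }
    catch (const ngraph::CheckFailure& error)
    {
        std::cout << error.what() << std::endl;
    }
}

Returning the axis unchanged when the rank is dynamic and the axis is non-negative is what lets the ONNX importers above defer validation until the rank becomes known, while negative axes still fail early with the "Rank must be static" check exercised by the updated type_prop test.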
@@ -112,7 +112,7 @@ namespace ngraph
     /// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
     /// returns error. If negative axis, it counts from the last to the first axis,
     /// by adding tensor_rank to axis.
-    int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
+    int64_t normalize_axis(const Node* node, std::int64_t axis, const Rank& tensor_rank);

     /// \brief Handle out of range axes in vector.
     ///
@@ -125,7 +125,7 @@
     ///
     std::vector<size_t> normalize_axes(const std::string& node_description,
                                        const std::vector<int64_t>& axes,
-                                       std::int64_t tensor_rank);
+                                       const Rank& tensor_rank);

     /// \brief Handle out of range axis.
     ///
@@ -138,7 +138,7 @@
     /// by adding tensor_rank to axis.
     int64_t normalize_axis(const std::string& node_description,
                            std::int64_t axis,
-                           std::int64_t tensor_rank);
+                           const Rank& tensor_rank);

     /// \brief Handle out of range axis.
     ///
@@ -153,7 +153,7 @@
     /// by adding tensor_rank to axis.
     int64_t normalize_axis(const Node* node,
                            std::int64_t axis,
-                           std::int64_t tensor_rank,
+                           std::uint64_t tensor_rank,
                            std::int64_t axis_range_min,
                            std::int64_t axis_range_max);
@@ -170,7 +170,7 @@
     /// by adding tensor_rank to axis.
     int64_t normalize_axis(const std::string& node_description,
                            std::int64_t axis,
-                           std::int64_t tensor_rank,
+                           std::uint64_t tensor_rank,
                            std::int64_t axis_range_min,
                            std::int64_t axis_range_max);
...
@@ -301,11 +301,11 @@ TEST(type_prop, reverse_sequence_negative_axis_dynamic_input_rank)
         auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
         FAIL() << "Dynamic input rank for negative axis not detected";
     }
-    catch (const NodeValidationFailure& error)
+    catch (const CheckFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(),
-                             std::string("In order to handle negative axes input_rank must be "
-                                         "static (batch_axis=1, seq_axis=-2)"));
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Rank must be static in order to normalize negative axis=-2"));
     }
     catch (...)
     {
...