Unverified Commit 6d299dab authored by Mateusz Bencer's avatar Mateusz Bencer Committed by GitHub

Extend normalize axes helpers to support dynamic shapes (#4299)

* Extend normalization. Part.1

* Normalize axis. Part.2

* Code review remarks introduced

* Fixed normalizes ranges type

* Trigger CI
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 6a369747
......@@ -34,8 +34,10 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, inputs.at(0)->get_shape().size());
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(),
axis,
inputs.at(0)->get_output_partial_shape(0).rank());
return {std::make_shared<default_opset::Concat>(inputs, normalized_axis)};
}
......
......@@ -25,6 +25,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -57,20 +58,13 @@ namespace ngraph
int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
const auto data_rank = x->get_output_partial_shape(0).rank();
AxisSet axes;
// if axis attribute is set
if (axis_0 == axis_1)
{
// positive axis
if (axis_0 >= 0)
{
axes.insert(axis_0);
}
// negative axis
else if (axis_0 < 0)
{
axes.insert(x->get_shape().size() + axis_0);
}
axes.insert(
ngraph::normalize_axis(node.get_description(), axis_0, data_rank));
}
if (x->get_element_type() != zero_point->get_element_type())
......
......@@ -34,10 +34,19 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size();
const auto data_rank = data->get_output_partial_shape(0).rank();
CHECK_VALID_NODE(node,
data_rank.is_static(),
"Data rank must be static in order to calculate flatten op");
const auto data_rank_value = static_cast<int64_t>(data_rank);
// Accepted range is [-r, r] where r = rank(input).
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, data_rank, -data_rank, data_rank);
const auto normalized_axis = ngraph::normalize_axis(node.get_description(),
axis,
data_rank_value,
-data_rank_value,
data_rank_value);
return {ngraph::builder::opset1::flatten(data, normalized_axis)};
}
......
......@@ -38,7 +38,7 @@ namespace ngraph
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);
const auto valid_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_shape().size());
node.get_description(), axis, data->get_output_partial_shape(0).rank());
return {std::make_shared<default_opset::Gather>(
data,
......
......@@ -34,11 +34,11 @@ namespace ngraph
NodeVector hardmax(const Node& node)
{
const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_shape();
const auto& input_shape = input->get_output_partial_shape(0);
const auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
ngraph::normalize_axis(node.get_description(), axis, input_shape.rank());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor =
......@@ -68,7 +68,17 @@ namespace ngraph
const auto converted_results = std::make_shared<default_opset::Convert>(
results, input->get_element_type());
return {ngraph::builder::opset1::reshape(converted_results, input_shape)};
if (input_shape.is_static())
{
return {ngraph::builder::opset1::reshape(converted_results,
input_shape.to_shape())};
}
else
{
const auto output_shape = std::make_shared<default_opset::ShapeOf>(input);
return {
std::make_shared<default_opset::Reshape>(input, output_shape, false)};
}
}
} // namespace set_1
......
......@@ -32,11 +32,11 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
const auto data = inputs.at(0);
const auto data_shape = data->get_shape();
const auto data_rank = data->get_output_partial_shape(0).rank();
const auto axis = node.get_attribute_value<int64_t>("axis", 1);
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
ngraph::normalize_axis(node.get_description(), axis, data_rank);
const auto softmax =
std::make_shared<default_opset::Softmax>(data, normalized_axis);
......
......@@ -39,11 +39,17 @@ namespace ngraph
NodeVector lp_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
const auto data_shape = data->get_output_partial_shape(0);
const auto data_rank = data_shape.rank();
CHECK_VALID_NODE(
node, data_shape.is_static(), "Data shape must be static for lp_norm op");
const auto data_rank_value = static_cast<size_t>(data_rank);
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const size_t normalize_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_shape().size());
const size_t normalize_axis =
ngraph::normalize_axis(node.get_description(), axis, data_rank);
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
......@@ -53,13 +59,13 @@ namespace ngraph
data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm));
const auto target_shape = default_opset::Constant::create(
element::i64, Shape{data->get_shape().size()}, data->get_shape());
element::i64, Shape{data_rank_value}, data_shape.to_shape());
// Create a default axes order matching the data tensor rank and erase the
// element at the 'normalize_axis' position. The erased element indicates the
// axis
// along which the data should be broadcasted.
std::vector<size_t> axes_values(data->get_shape().size());
std::vector<size_t> axes_values(data_rank_value);
std::iota(axes_values.begin(), axes_values.end(), 0);
axes_values.erase(axes_values.begin() + normalize_axis);
......
......@@ -50,8 +50,8 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(), axes, data->get_shape().size());
const std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(), axes, data->get_output_partial_shape(0).rank());
return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(normalized_axes))};
}
......
......@@ -39,13 +39,14 @@ namespace ngraph
// nGraph supports only int32 type of sequence_lengths
const auto sequence_lengths_i32 = std::make_shared<default_opset::Convert>(
node.get_ng_inputs().at(1), element::i32);
const auto data_rank = data->get_output_partial_shape(0).rank();
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
const auto normalized_batch_axis = ngraph::normalize_axis(
node.get_description(), batch_axis, data->get_shape().size());
const auto normalized_batch_axis =
ngraph::normalize_axis(node.get_description(), batch_axis, data_rank);
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
const auto normalized_time_axis = ngraph::normalize_axis(
node.get_description(), time_axis, data->get_shape().size());
const auto normalized_time_axis =
ngraph::normalize_axis(node.get_description(), time_axis, data_rank);
NGRAPH_CHECK(normalized_batch_axis == 0 || normalized_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence "
......
......@@ -32,11 +32,10 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1);
const auto data_rank = data->get_output_partial_shape(0).rank();
const auto axis = node.get_attribute_value<int64_t>("axis", 1);
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
ngraph::normalize_axis(node.get_description(), axis, data_rank);
return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
}
......
......@@ -36,8 +36,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0);
std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(), axes, data->get_shape().size());
const auto data_rank = data->get_output_partial_shape(0).rank();
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(), axes, data_rank);
auto axes_node = std::make_shared<default_opset::Constant>(
element::u64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
......
......@@ -35,8 +35,8 @@ namespace
{
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
auto data = node.get_ng_inputs().at(0);
auto data_rank = data->get_shape().size();
const auto data = node.get_ng_inputs().at(0);
const auto data_rank = data->get_output_partial_shape(0).rank();
return ngraph::normalize_axis(node.get_description(), axis, data_rank);
}
......
......@@ -17,6 +17,7 @@
#include <memory>
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
#include "unsqueeze.hpp"
......@@ -33,7 +34,12 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto expanded_rank = data->get_shape().size() + axes.size();
const auto data_rank = data->get_output_partial_shape(0).rank();
CHECK_VALID_NODE(node,
data_rank.is_static(),
"Data rank must be static for creation of ONNX Unsqueeze op");
const auto expanded_rank =
data->get_output_partial_shape(0).rank() + axes.size();
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(), axes, expanded_rank);
auto axes_node = std::make_shared<default_opset::Constant>(
......
......@@ -33,8 +33,8 @@ namespace ngraph
m_input_node = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
m_normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, m_input_node->get_shape().size());
const auto data_rank = m_input_node->get_output_partial_shape(0).rank();
m_normalized_axis = ngraph::normalize_axis(node.get_description(), axis, data_rank);
}
std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
......
......@@ -34,10 +34,10 @@ namespace ngraph
{
auto reduction_axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(),
reduction_axes,
node.get_ng_inputs().at(0)->get_shape().size());
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(),
reduction_axes,
node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank());
if (reduction_axes.empty())
{
......
......@@ -67,8 +67,8 @@ void op::v0::Split::pre_validate_and_infer_types()
const auto shape = input(0).get_shape();
m_axis = ngraph::normalize_axis(this, m_axis, shape.size());
const auto data_rank = get_input_partial_shape(0).rank();
m_axis = ngraph::normalize_axis(this, m_axis, data_rank);
const auto dimension_at_axis = shape.at(m_axis);
if (m_split_evenly)
{
......@@ -140,9 +140,10 @@ void op::v1::Split::validate_and_infer_types()
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>()[0];
const auto data_shape = data_ps.to_shape();
axis = ngraph::normalize_axis(this, axis, data_shape.size());
const auto data_rank = get_input_partial_shape(0).rank();
axis = ngraph::normalize_axis(this, axis, data_rank);
const auto data_shape = data_ps.to_shape();
const auto dimension_at_axis = data_shape.at(axis);
NODE_VALIDATION_CHECK(this,
......
......@@ -44,30 +44,8 @@ void op::ReverseSequence::validate_and_infer_types()
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
if (m_batch_axis < 0 || m_seq_axis < 0)
{
NODE_VALIDATION_CHECK(this,
input_rank.is_static(),
"In order to handle negative axes input_rank must be static (",
"batch_axis=",
m_batch_axis,
", seq_axis=",
m_seq_axis,
")");
}
else
{
m_normalized_batch_axis = m_batch_axis;
m_normalized_seq_axis = m_seq_axis;
}
if (input_rank.is_static())
{
m_normalized_batch_axis =
ngraph::normalize_axis(this, m_batch_axis, static_cast<int64_t>(input_rank));
m_normalized_seq_axis =
ngraph::normalize_axis(this, m_seq_axis, static_cast<int64_t>(input_rank));
}
m_normalized_batch_axis = ngraph::normalize_axis(this, m_batch_axis, input_rank);
m_normalized_seq_axis = ngraph::normalize_axis(this, m_seq_axis, input_rank);
auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank();
......
......@@ -66,7 +66,7 @@ void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_
void op::util::ArithmeticReduction::validate_and_infer_types()
{
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
const auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
......@@ -79,7 +79,7 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
{
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
axis = normalize_axis(this, axis, input_rank);
}
catch (const ngraph_error&)
{
......
......@@ -78,7 +78,7 @@ void op::util::LogicalReduction::validate_and_infer_types()
{
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
axis = normalize_axis(this, axis, input_rank);
}
catch (const ngraph_error&)
{
......
......@@ -52,7 +52,7 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
axis = normalize_axis(this, axis, input_rank);
}
catch (const ngraph_error&)
{
......
......@@ -60,12 +60,11 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
if (data_shape.is_static() && axis_input->is_constant() &&
split_lengths_input->is_constant())
{
auto data_rank = static_cast<size_t>(data_shape.rank());
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis_val = axis_input->cast_vector<int64_t>()[0];
// Adjust split axis in case of negatives
int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
int64_t axis = ngraph::normalize_axis(this, axis_val, data_shape.rank());
auto split_lengths =
as_type_ptr<op::Constant>(split_lengths_input)->cast_vector<int64_t>();
......
......@@ -41,8 +41,8 @@ void pass::ConstantFolding::construct_constant_split()
const auto split = static_pointer_cast<op::v1::Split>(m.get_match_root());
const auto axis_val = axis_node->cast_vector<int64_t>()[0];
const auto norm_axis_val =
ngraph::normalize_axis(split.get(), axis_val, data_node->get_shape().size());
const auto norm_axis_val = ngraph::normalize_axis(
split.get(), axis_val, data_node->get_output_partial_shape(0).rank());
const auto slices = builder::split(data_node, split->get_num_splits(), norm_axis_val);
for (size_t i = 0; i < split->get_output_size(); i++)
......
......@@ -46,8 +46,8 @@ void pass::ConstantFolding::construct_constant_variadic_split()
const auto variadic_split = static_pointer_cast<op::v1::VariadicSplit>(m.get_match_root());
const auto axis_val = axis_node->cast_vector<int64_t>()[0];
const auto norm_axis_val =
ngraph::normalize_axis(variadic_split.get(), axis_val, data_node->get_shape().size());
const auto norm_axis_val = ngraph::normalize_axis(
variadic_split.get(), axis_val, data_node->get_output_partial_shape(0).rank());
auto split_lengths = lengths_node->cast_vector<int64_t>();
// Adjust split lengths in case of negatives
......
......@@ -797,7 +797,7 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank)
const Rank& tensor_rank)
{
std::vector<size_t> new_axes;
......@@ -809,21 +809,36 @@ std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
return new_axes;
}
int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, const Rank& tensor_rank)
{
return normalize_axis(node->description(), axis, tensor_rank);
}
int64_t ngraph::normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank)
const Rank& tensor_rank)
{
return normalize_axis(node_description, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
if (axis < 0)
{
// Handling negative axis requires static tensor rank
NGRAPH_CHECK(tensor_rank.is_static(),
node_description,
" Rank must be static in order to normalize negative axis=",
axis);
}
if (tensor_rank.is_dynamic())
{
return axis;
}
const auto tensor_rank_value = static_cast<int64_t>(tensor_rank);
return normalize_axis(
node_description, axis, tensor_rank_value, -tensor_rank_value, tensor_rank_value - 1);
}
int64_t ngraph::normalize_axis(const Node* node,
std::int64_t axis,
std::int64_t tensor_rank,
std::uint64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
......@@ -833,7 +848,7 @@ int64_t ngraph::normalize_axis(const Node* node,
int64_t ngraph::normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank,
std::uint64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
......
......@@ -112,7 +112,7 @@ namespace ngraph
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error. If negative axis, it counts from the last to the first axis,
/// by adding tensor_rank to axis.
int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
int64_t normalize_axis(const Node* node, std::int64_t axis, const Rank& tensor_rank);
/// \brief Handle out of range axes in vector.
///
......@@ -125,7 +125,7 @@ namespace ngraph
///
std::vector<size_t> normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank);
const Rank& tensor_rank);
/// \brief Handle out of range axis.
///
......@@ -138,7 +138,7 @@ namespace ngraph
/// by adding tensor_rank to axis.
int64_t normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank);
const Rank& tensor_rank);
/// \brief Handle out of range axis.
///
......@@ -153,7 +153,7 @@ namespace ngraph
/// by adding tensor_rank to axis.
int64_t normalize_axis(const Node* node,
std::int64_t axis,
std::int64_t tensor_rank,
std::uint64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
......@@ -170,7 +170,7 @@ namespace ngraph
/// by adding tensor_rank to axis.
int64_t normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank,
std::uint64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
......
......@@ -301,11 +301,11 @@ TEST(type_prop, reverse_sequence_negative_axis_dynamic_input_rank)
auto rs = make_shared<op::ReverseSequence>(data, seq_lengths, batch_axis, seq_axis);
FAIL() << "Dynamic input rank for negative axis not detected";
}
catch (const NodeValidationFailure& error)
catch (const CheckFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("In order to handle negative axes input_rank must be "
"static (batch_axis=1, seq_axis=-2)"));
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank must be static in order to normalize negative axis=-2"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment