Commit bcc424ab authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Sang Ik Lee

Use ngraph::normalize in onnx_importer (#4098)

* Use normalize in onnx importer

* Introduced changes after code review
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent 47abb1cb
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "utils/common.hpp" #include "ngraph/validation_util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,10 +34,10 @@ namespace ngraph ...@@ -34,10 +34,10 @@ namespace ngraph
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis"); std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis = const auto normalized_axis = ngraph::normalize_axis(
common::validate_axis(node, axis, inputs.at(0)->get_shape().size()); node.get_description(), axis, inputs.at(0)->get_shape().size());
return {std::make_shared<default_opset::Concat>(inputs, valid_axis)}; return {std::make_shared<default_opset::Concat>(inputs, normalized_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "flatten.hpp" #include "flatten.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "utils/common.hpp" #include "ngraph/validation_util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,10 +36,10 @@ namespace ngraph ...@@ -36,10 +36,10 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size(); auto data_rank = data->get_shape().size();
// Accepted range is [-r, r] where r = rank(input). // Accepted range is [-r, r] where r = rank(input).
auto valid_axis = const auto normalized_axis = ngraph::normalize_axis(
common::validate_axis(node, axis, data_rank, -data_rank, data_rank); node.get_description(), axis, data_rank, -data_rank, data_rank);
return {ngraph::builder::flatten(data, valid_axis)}; return {ngraph::builder::flatten(data, normalized_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "utils/common.hpp" #include "ngraph/validation_util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,8 @@ namespace ngraph ...@@ -37,7 +37,8 @@ namespace ngraph
auto data = ng_inputs.at(0); auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1); auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0); auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto valid_axis = common::validate_axis(node, axis, data->get_shape().size()); const auto valid_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_shape().size());
return {std::make_shared<default_opset::Gather>( return {std::make_shared<default_opset::Gather>(
data, data,
......
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
#include "hardmax.hpp" #include "hardmax.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,10 +35,11 @@ namespace ngraph ...@@ -34,10 +35,11 @@ namespace ngraph
const auto& input_shape = input->get_shape(); const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto valid_axis = common::validate_axis(node, axis, input_shape.size()); const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD) // reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, valid_axis); const auto coerced_tensor = ngraph::builder::flatten(input, normalized_axis);
const auto& coerced_shape = coerced_tensor->get_shape(); const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d = const std::shared_ptr<ngraph::Node> argmax_2d =
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "utils/common.hpp" #include "ngraph/validation_util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -42,25 +42,26 @@ namespace ngraph ...@@ -42,25 +42,26 @@ namespace ngraph
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)}; const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)}; const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::size_t valid_axis = const size_t normalize_axis = ngraph::normalize_axis(
common::validate_axis(node, axis, data->get_shape().size()); node.get_description(), axis, data->get_shape().size());
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2) ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm << "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported."; << "Only normalization of 1st or 2nd order is supported.";
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm( std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, AxisSet{valid_axis}, static_cast<std::size_t>(p_norm)); data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm));
const auto target_shape = default_opset::Constant::create( const auto target_shape = default_opset::Constant::create(
element::i64, Shape{data->get_shape().size()}, data->get_shape()); element::i64, Shape{data->get_shape().size()}, data->get_shape());
// Create a default axes order matching the data tensor rank and erase the // Create a default axes order matching the data tensor rank and erase the
// element at the 'valid_axis' position. The erased element indicates the axis // element at the 'normalize_axis' position. The erased element indicates the
// axis
// along which the data should be broadcasted. // along which the data should be broadcasted.
std::vector<size_t> axes_values(data->get_shape().size()); std::vector<size_t> axes_values(data->get_shape().size());
std::iota(axes_values.begin(), axes_values.end(), 0); std::iota(axes_values.begin(), axes_values.end(), 0);
axes_values.erase(axes_values.begin() + valid_axis); axes_values.erase(axes_values.begin() + normalize_axis);
const auto axes_mapping = default_opset::Constant::create( const auto axes_mapping = default_opset::Constant::create(
element::i64, Shape{axes_values.size()}, axes_values); element::i64, Shape{axes_values.size()}, axes_values);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp" #include "ngraph/op/fused/mvn.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "utils/common.hpp" #include "ngraph/validation_util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -50,10 +50,10 @@ namespace ngraph ...@@ -50,10 +50,10 @@ namespace ngraph
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3}); auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> valid_axes = std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
common::validate_axes(node, axes, data->get_shape().size()); node.get_description(), axes, data->get_shape().size());
return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(valid_axes))}; return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(normalized_axes))};
} }
} // namespace set_9 } // namespace set_9
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/validation_util.hpp"
#include "reverse_sequence.hpp" #include "reverse_sequence.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,26 +41,26 @@ namespace ngraph ...@@ -41,26 +41,26 @@ namespace ngraph
node.get_ng_inputs().at(1), element::i32); node.get_ng_inputs().at(1), element::i32);
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1); const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
std::size_t valid_batch_axis = const auto normalized_batch_axis = ngraph::normalize_axis(
common::validate_axis(node, batch_axis, data->get_shape().size()); node.get_description(), batch_axis, data->get_shape().size());
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0); const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
std::size_t valid_time_axis = const auto normalized_time_axis = ngraph::normalize_axis(
common::validate_axis(node, time_axis, data->get_shape().size()); node.get_description(), time_axis, data->get_shape().size());
NGRAPH_CHECK(valid_batch_axis == 0 || valid_batch_axis == 1, NGRAPH_CHECK(normalized_batch_axis == 0 || normalized_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence " "Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1"); "operator are 0 and 1");
NGRAPH_CHECK(valid_time_axis == 0 || valid_time_axis == 1, NGRAPH_CHECK(normalized_time_axis == 0 || normalized_time_axis == 1,
"Allowed values of the 'time_axis' attribute for ReverseSequence " "Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1"); "operator are 0 and 1");
NGRAPH_CHECK(valid_batch_axis != valid_time_axis, NGRAPH_CHECK(normalized_batch_axis != normalized_time_axis,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence " "'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension"); "operator can't point to the same dimension");
return {std::make_shared<default_opset::ReverseSequence>( return {std::make_shared<default_opset::ReverseSequence>(
data, sequence_lengths_i32, valid_batch_axis, valid_time_axis)}; data, sequence_lengths_i32, normalized_batch_axis, normalized_time_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/validation_util.hpp"
#include "softmax.hpp" #include "softmax.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,9 +35,10 @@ namespace ngraph ...@@ -35,9 +35,10 @@ namespace ngraph
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1); int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size()); const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
return {std::make_shared<default_opset::Softmax>(data, valid_axis)}; return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,11 +16,12 @@ ...@@ -16,11 +16,12 @@
#include <vector> #include <vector>
#include "default_opset.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/validation_util.hpp"
#include "squeeze.hpp" #include "squeeze.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,10 +36,10 @@ namespace ngraph ...@@ -35,10 +36,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
std::vector<std::int64_t> axes = std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {}); node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_axes = std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
common::validate_axes(node, axes, data->get_shape().size()); node.get_description(), axes, data->get_shape().size());
auto axes_node = std::make_shared<default_opset::Constant>( auto axes_node = std::make_shared<default_opset::Constant>(
element::u64, Shape{valid_axes.size()}, valid_axes); element::u64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Squeeze>(data, axes_node)}; return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
} }
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/validation_util.hpp"
#include "topk.hpp" #include "topk.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
namespace namespace
...@@ -37,7 +37,7 @@ namespace ...@@ -37,7 +37,7 @@ namespace
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto data_rank = data->get_shape().size(); auto data_rank = data->get_shape().size();
return ngraph::onnx_import::common::validate_axis(node, axis, data_rank); return ngraph::normalize_axis(node.get_description(), axis, data_rank);
} }
/// \return Return the second input to the TopK node reshaped to a scalar. /// \return Return the second input to the TopK node reshaped to a scalar.
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
#include "unsqueeze.hpp" #include "unsqueeze.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,10 +34,10 @@ namespace ngraph ...@@ -34,10 +34,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {}); auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto expanded_rank = data->get_shape().size() + axes.size(); const auto expanded_rank = data->get_shape().size() + axes.size();
std::vector<std::size_t> valid_axes = std::vector<std::size_t> normalized_axes =
common::validate_axes(node, axes, expanded_rank); ngraph::normalize_axes(node.get_description(), axes, expanded_rank);
auto axes_node = std::make_shared<default_opset::Constant>( auto axes_node = std::make_shared<default_opset::Constant>(
element::i64, Shape{valid_axes.size()}, valid_axes); element::i64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Unsqueeze>(data, axes_node)}; return {std::make_shared<default_opset::Unsqueeze>(data, axes_node)};
} }
......
...@@ -50,38 +50,6 @@ namespace ngraph ...@@ -50,38 +50,6 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type))); static_cast<onnx::TensorProto_DataType>(onnx_type)));
} }
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return validate_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
return ngraph::normalize_axis(
node.get_description(), axis, tensor_rank, axis_range_min, axis_range_max);
}
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank)
{
std::vector<std::size_t> new_axes;
for (auto a : axes)
{
new_axes.push_back(validate_axis(node, a, tensor_rank));
}
return new_axes;
}
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node) ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{ {
const auto outputs_number = node->get_output_size(); const auto outputs_number = node->get_output_size();
......
...@@ -68,53 +68,6 @@ namespace ngraph ...@@ -68,53 +68,6 @@ namespace ngraph
return range; return range;
} }
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error.
/// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis.
/// \param[in] axis_range_max The max value of accepted range for axis.
///
/// \return Checking if axis is in range [axis_range_min, axis_range_max], otherwise
/// returns error.
//// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node The node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank);
/// \brief Return the outputs of the node as vector. /// \brief Return the outputs of the node as vector.
/// ///
/// \param[in] node Node with multiple outputs. /// \param[in] node Node with multiple outputs.
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include <vector> #include <vector>
#include "default_opset.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "reduction.hpp" #include "reduction.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,16 +34,17 @@ namespace ngraph ...@@ -34,16 +34,17 @@ namespace ngraph
{ {
auto reduction_axes = auto reduction_axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {}); node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_reduction_axes = common::validate_axes( std::vector<std::size_t> normalized_axes =
node, reduction_axes, node.get_ng_inputs().at(0)->get_shape().size()); ngraph::normalize_axes(node.get_description(),
reduction_axes,
node.get_ng_inputs().at(0)->get_shape().size());
if (reduction_axes.empty()) if (reduction_axes.empty())
{ {
valid_reduction_axes = normalized_axes = onnx_import::common::get_monotonic_range<std::size_t>(
onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size()); node.get_ng_inputs().at(0)->get_shape().size());
} }
return AxisSet{valid_reduction_axes}; return AxisSet{normalized_axes};
} }
} // namespace detail } // namespace detail
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
...@@ -83,10 +84,11 @@ namespace ngraph ...@@ -83,10 +84,11 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 0); auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1); auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0); auto input_node = node.get_ng_inputs().at(0);
auto valid_axis = common::validate_axis(node, axis, input_node->get_shape().size()); const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, input_node->get_shape().size());
auto op_node = auto op_node =
std::make_shared<IndexReduction>(input_node, valid_axis, element::i64); std::make_shared<IndexReduction>(input_node, normalized_axis, element::i64);
if (keepdims == 0) if (keepdims == 0)
{ {
...@@ -97,7 +99,7 @@ namespace ngraph ...@@ -97,7 +99,7 @@ namespace ngraph
auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32); auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);
auto output_shape = input_node->get_shape(); auto output_shape = input_node->get_shape();
output_shape.at(valid_axis) = 1; output_shape.at(normalized_axis) = 1;
auto reshape_node = std::make_shared<ngraph::op::Reshape>( auto reshape_node = std::make_shared<ngraph::op::Reshape>(
convert_node, convert_node,
ngraph::get_default_order(op_node->get_shape().size()), ngraph::get_default_order(op_node->get_shape().size()),
......
...@@ -795,9 +795,30 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -795,9 +795,30 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
return dim; return dim;
} }
std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank)
{
std::vector<size_t> new_axes;
for (const auto& axis : axes)
{
new_axes.push_back(normalize_axis(node_description, axis, tensor_rank));
}
return new_axes;
}
int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank) int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
{ {
return normalize_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1); return normalize_axis(node->description(), axis, tensor_rank);
}
int64_t ngraph::normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank)
{
return normalize_axis(node_description, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
} }
int64_t ngraph::normalize_axis(const Node* node, int64_t ngraph::normalize_axis(const Node* node,
......
...@@ -114,6 +114,32 @@ namespace ngraph ...@@ -114,6 +114,32 @@ namespace ngraph
/// by adding tensor_rank to axis. /// by adding tensor_rank to axis.
int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank); int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node_description The name of node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<size_t> normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
/// \param[in] node_description The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error. If negative axis, it counts from the last to the first axis,
/// by adding tensor_rank to axis.
int64_t normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis. /// \brief Handle out of range axis.
/// ///
/// \param[in] node The node with requested axis. /// \param[in] node The node with requested axis.
...@@ -133,7 +159,7 @@ namespace ngraph ...@@ -133,7 +159,7 @@ namespace ngraph
/// \brief Handle out of range axis. /// \brief Handle out of range axis.
/// ///
/// \param[in] node The name of node with requested axis. /// \param[in] node_description The name of node with requested axis.
/// \param[in] axis The requested axis value. /// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank. /// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis. /// \param[in] axis_range_min The min value of accepted range for axis.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment