Commit bcc424ab authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Sang Ik Lee

Use ngraph::normalize in onnx_importer (#4098)

* Use normalize in onnx importer

* Introduced changes after code review
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent 47abb1cb
......@@ -20,7 +20,7 @@
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/concat.hpp"
#include "utils/common.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -34,10 +34,10 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis =
common::validate_axis(node, axis, inputs.at(0)->get_shape().size());
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, inputs.at(0)->get_shape().size());
return {std::make_shared<default_opset::Concat>(inputs, valid_axis)};
return {std::make_shared<default_opset::Concat>(inputs, normalized_axis)};
}
} // namespace set_1
......
......@@ -19,7 +19,7 @@
#include "exceptions.hpp"
#include "flatten.hpp"
#include "ngraph/builder/reshape.hpp"
#include "utils/common.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -36,10 +36,10 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size();
// Accepted range is [-r, r] where r = rank(input).
auto valid_axis =
common::validate_axis(node, axis, data_rank, -data_rank, data_rank);
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, data_rank, -data_rank, data_rank);
return {ngraph::builder::flatten(data, valid_axis)};
return {ngraph::builder::flatten(data, normalized_axis)};
}
} // namespace set_1
......
......@@ -21,7 +21,7 @@
#include "core/node.hpp"
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "utils/common.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -37,7 +37,8 @@ namespace ngraph
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto valid_axis = common::validate_axis(node, axis, data->get_shape().size());
const auto valid_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_shape().size());
return {std::make_shared<default_opset::Gather>(
data,
......
......@@ -17,8 +17,9 @@
#include "hardmax.hpp"
#include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -34,10 +35,11 @@ namespace ngraph
const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto valid_axis = common::validate_axis(node, axis, input_shape.size());
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, valid_axis);
const auto coerced_tensor = ngraph::builder::flatten(input, normalized_axis);
const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d =
......
......@@ -26,7 +26,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/divide.hpp"
#include "utils/common.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -42,25 +42,26 @@ namespace ngraph
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::size_t valid_axis =
common::validate_axis(node, axis, data->get_shape().size());
const size_t normalize_axis = ngraph::normalize_axis(
node.get_description(), axis, data->get_shape().size());
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, AxisSet{valid_axis}, static_cast<std::size_t>(p_norm));
data, AxisSet{normalize_axis}, static_cast<std::size_t>(p_norm));
const auto target_shape = default_opset::Constant::create(
element::i64, Shape{data->get_shape().size()}, data->get_shape());
// Create a default axes order matching the data tensor rank and erase the
// element at the 'valid_axis' position. The erased element indicates the axis
// element at the 'normalize_axis' position. The erased element indicates the
// axis
// along which the data should be broadcasted.
std::vector<size_t> axes_values(data->get_shape().size());
std::iota(axes_values.begin(), axes_values.end(), 0);
axes_values.erase(axes_values.begin() + valid_axis);
axes_values.erase(axes_values.begin() + normalize_axis);
const auto axes_mapping = default_opset::Constant::create(
element::i64, Shape{axes_values.size()}, axes_values);
......
......@@ -20,7 +20,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/common.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
......@@ -50,10 +50,10 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(), axes, data->get_shape().size());
return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(valid_axes))};
return {std::make_shared<ngraph::opset0::MVN>(data, AxisSet(normalized_axes))};
}
} // namespace set_9
......
......@@ -20,8 +20,8 @@
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/validation_util.hpp"
#include "reverse_sequence.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -41,26 +41,26 @@ namespace ngraph
node.get_ng_inputs().at(1), element::i32);
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
std::size_t valid_batch_axis =
common::validate_axis(node, batch_axis, data->get_shape().size());
const auto normalized_batch_axis = ngraph::normalize_axis(
node.get_description(), batch_axis, data->get_shape().size());
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
std::size_t valid_time_axis =
common::validate_axis(node, time_axis, data->get_shape().size());
const auto normalized_time_axis = ngraph::normalize_axis(
node.get_description(), time_axis, data->get_shape().size());
NGRAPH_CHECK(valid_batch_axis == 0 || valid_batch_axis == 1,
NGRAPH_CHECK(normalized_batch_axis == 0 || normalized_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1");
NGRAPH_CHECK(valid_time_axis == 0 || valid_time_axis == 1,
NGRAPH_CHECK(normalized_time_axis == 0 || normalized_time_axis == 1,
"Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1");
NGRAPH_CHECK(valid_batch_axis != valid_time_axis,
NGRAPH_CHECK(normalized_batch_axis != normalized_time_axis,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension");
return {std::make_shared<default_opset::ReverseSequence>(
data, sequence_lengths_i32, valid_batch_axis, valid_time_axis)};
data, sequence_lengths_i32, normalized_batch_axis, normalized_time_axis)};
}
} // namespace set_1
......
......@@ -17,8 +17,8 @@
#include <memory>
#include "default_opset.hpp"
#include "ngraph/validation_util.hpp"
#include "softmax.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -35,9 +35,10 @@ namespace ngraph
auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, data_shape.size());
return {std::make_shared<default_opset::Softmax>(data, valid_axis)};
return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
}
} // namespace set_1
......
......@@ -16,11 +16,12 @@
#include <vector>
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/validation_util.hpp"
#include "squeeze.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -35,10 +36,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0);
std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
node.get_description(), axes, data->get_shape().size());
auto axes_node = std::make_shared<default_opset::Constant>(
element::u64, Shape{valid_axes.size()}, valid_axes);
element::u64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
}
......
......@@ -24,8 +24,8 @@
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/validation_util.hpp"
#include "topk.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace
......@@ -37,7 +37,7 @@ namespace
auto data = node.get_ng_inputs().at(0);
auto data_rank = data->get_shape().size();
return ngraph::onnx_import::common::validate_axis(node, axis, data_rank);
return ngraph::normalize_axis(node.get_description(), axis, data_rank);
}
/// \return Return the second input to the TopK node reshaped to a scalar.
......
......@@ -18,8 +18,8 @@
#include "default_opset.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
#include "unsqueeze.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -34,10 +34,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto expanded_rank = data->get_shape().size() + axes.size();
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, expanded_rank);
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(), axes, expanded_rank);
auto axes_node = std::make_shared<default_opset::Constant>(
element::i64, Shape{valid_axes.size()}, valid_axes);
element::i64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Unsqueeze>(data, axes_node)};
}
......
......@@ -50,38 +50,6 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return validate_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
return ngraph::normalize_axis(
node.get_description(), axis, tensor_rank, axis_range_min, axis_range_max);
}
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank)
{
std::vector<std::size_t> new_axes;
for (auto a : axes)
{
new_axes.push_back(validate_axis(node, a, tensor_rank));
}
return new_axes;
}
ngraph::NodeVector get_outputs(const std::shared_ptr<ngraph::Node>& node)
{
const auto outputs_number = node->get_output_size();
......
......@@ -68,53 +68,6 @@ namespace ngraph
return range;
}
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error.
/// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis.
/// \param[in] axis_range_max The max value of accepted range for axis.
///
/// \return Checking if axis is in range [axis_range_min, axis_range_max], otherwise
/// returns error.
//// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node The node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank);
/// \brief Return the outputs of the node as vector.
///
/// \param[in] node Node with multiple outputs.
......
......@@ -17,10 +17,10 @@
#include <cstddef> // std::size_t
#include <vector>
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/constant.hpp"
#include "reduction.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -34,16 +34,17 @@ namespace ngraph
{
auto reduction_axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_reduction_axes = common::validate_axes(
node, reduction_axes, node.get_ng_inputs().at(0)->get_shape().size());
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(),
reduction_axes,
node.get_ng_inputs().at(0)->get_shape().size());
if (reduction_axes.empty())
{
valid_reduction_axes =
onnx_import::common::get_monotonic_range<std::size_t>(
normalized_axes = onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size());
}
return AxisSet{valid_reduction_axes};
return AxisSet{normalized_axes};
}
} // namespace detail
......
......@@ -26,6 +26,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
......@@ -83,10 +84,11 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0);
auto valid_axis = common::validate_axis(node, axis, input_node->get_shape().size());
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, input_node->get_shape().size());
auto op_node =
std::make_shared<IndexReduction>(input_node, valid_axis, element::i64);
std::make_shared<IndexReduction>(input_node, normalized_axis, element::i64);
if (keepdims == 0)
{
......@@ -97,7 +99,7 @@ namespace ngraph
auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);
auto output_shape = input_node->get_shape();
output_shape.at(valid_axis) = 1;
output_shape.at(normalized_axis) = 1;
auto reshape_node = std::make_shared<ngraph::op::Reshape>(
convert_node,
ngraph::get_default_order(op_node->get_shape().size()),
......
......@@ -795,9 +795,30 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
return dim;
}
std::vector<size_t> ngraph::normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank)
{
std::vector<size_t> new_axes;
for (const auto& axis : axes)
{
new_axes.push_back(normalize_axis(node_description, axis, tensor_rank));
}
return new_axes;
}
int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
{
return normalize_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
return normalize_axis(node->description(), axis, tensor_rank);
}
int64_t ngraph::normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank)
{
return normalize_axis(node_description, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
int64_t ngraph::normalize_axis(const Node* node,
......
......@@ -114,6 +114,32 @@ namespace ngraph
/// by adding tensor_rank to axis.
int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node_description The name of node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<size_t> normalize_axes(const std::string& node_description,
const std::vector<int64_t>& axes,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
/// \param[in] node_description The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error. If negative axis, it counts from the last to the first axis,
/// by adding tensor_rank to axis.
int64_t normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
......@@ -133,7 +159,7 @@ namespace ngraph
/// \brief Handle out of range axis.
///
/// \param[in] node The name of node with requested axis.
/// \param[in] node_description The name of node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment