Unverified Commit 18561cbb authored by Ewa Tusień's avatar Ewa Tusień Committed by GitHub

[ONNX] Require the rank only if any value of axes is negative in Unsqueeze and Concat ops. (#4453)

* Added check if getting rank is needed.

* Removed normalization from importer.

* Added normalisation to Concat op.

* Added test for Concat.

* Added test for Unsqueeze

* Added missed test.
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
parent 8017c094
...@@ -34,12 +34,14 @@ namespace ngraph ...@@ -34,12 +34,14 @@ namespace ngraph
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis"); std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
const auto normalized_axis = if (axis < 0)
ngraph::normalize_axis(node.get_description(), {
axis, axis = ngraph::normalize_axis(
inputs.at(0)->get_output_partial_shape(0).rank()); node.get_description(),
axis,
return {std::make_shared<default_opset::Concat>(inputs, normalized_axis)}; inputs.at(0)->get_output_partial_shape(0).rank());
}
return {std::make_shared<default_opset::Concat>(inputs, axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -34,16 +34,8 @@ namespace ngraph ...@@ -34,16 +34,8 @@ namespace ngraph
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {}); auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto data_rank = data->get_output_partial_shape(0).rank();
CHECK_VALID_NODE(node,
data_rank.is_static(),
"Data rank must be static for creation of ONNX Unsqueeze op");
const auto expanded_rank =
data->get_output_partial_shape(0).rank() + axes.size();
std::vector<std::size_t> normalized_axes =
ngraph::normalize_axes(node.get_description(), axes, expanded_rank);
auto axes_node = std::make_shared<default_opset::Constant>( auto axes_node = std::make_shared<default_opset::Constant>(
element::i64, Shape{normalized_axes.size()}, normalized_axes); element::i64, Shape{axes.size()}, axes);
return {std::make_shared<default_opset::Unsqueeze>(data, axes_node)}; return {std::make_shared<default_opset::Unsqueeze>(data, axes_node)};
} }
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "value0"
input: "value1"
output: "output"
op_type: "Concat"
attribute {
name: "axis"
i: -2
type: INT
}
}
name: "test_concat_1d_axis_0"
input {
name: "value0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "value1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 7
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: -2
type: INTS
}
}
name: "test_unsqueeze"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 1
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 7
}
...@@ -364,6 +364,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_unsqueeze) ...@@ -364,6 +364,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_unsqueeze)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_unsqueeze_negative_axes)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze_negative_axes.prototxt"));
auto input = test::NDArray<float, 4>(
{{{{-1.8427763f, -1.0467733f, 0.50550157f, 1.4897262f, 0.33057404f}},
{{1.9244908f, -0.3804572f, 0.76275414f, -0.8183123f, 0.93889356f}},
{{-0.05270234f, 0.7113202f, -0.45783648f, -1.3378475f, 0.26926285f}}}})
.get_vector();
auto expected_output =
test::NDArray<float, 5>(
{{{{{-1.8427763f, -1.0467733f, 0.50550157f, 1.4897262f, 0.33057404f}}},
{{{1.9244908f, -0.3804572f, 0.76275414f, -0.8183123f, 0.93889356f}}},
{{{-0.05270234f, 0.7113202f, -0.45783648f, -1.3378475f, 0.26926285f}}}}})
.get_vector();
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(input);
test_case.add_expected_output(expected_output);
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_concat) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_concat)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_model(
...@@ -382,6 +406,24 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_concat) ...@@ -382,6 +406,24 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_concat)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_concat_negative_axis)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/concat_negative_axis.prototxt"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
inputs.emplace_back(test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
auto expected_output = test::NDArray<float, 2>({{1, 2}, {3, 4}, {5, 6}, {7, 8}}).get_vector();
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_multiple_inputs(inputs);
test_case.add_expected_output(expected_output);
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_flatten) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_flatten)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_model(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment