Commit 226b595e authored by Ewa Tusień's avatar Ewa Tusień Committed by Sang Ik Lee

[ONNX] Add GatherND op to ONNX importer (#3963)

* [ONNX] Added gatherND op to ONNX importer.

* Added tests.

* Removed new line.

* Update onnx_import.in.cpp

* Changed tests.
parent e47cebea
...@@ -96,6 +96,8 @@ add_library(onnx_import STATIC ...@@ -96,6 +96,8 @@ add_library(onnx_import STATIC
op/flatten.hpp op/flatten.hpp
op/floor.hpp op/floor.hpp
op/gather.hpp op/gather.hpp
op/gather_nd.hpp
op/gather_nd.cpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/global_average_pool.cpp op/global_average_pool.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/gather_nd.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector gather_nd(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
return {std::make_shared<ngraph::op::GatherND>(data, indices)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector gather_nd(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -59,6 +59,7 @@ ...@@ -59,6 +59,7 @@
#include "op/flatten.hpp" #include "op/flatten.hpp"
#include "op/floor.hpp" #include "op/floor.hpp"
#include "op/gather.hpp" #include "op/gather.hpp"
#include "op/gather_nd.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/global_average_pool.hpp" #include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp" #include "op/global_max_pool.hpp"
...@@ -275,6 +276,7 @@ namespace ngraph ...@@ -275,6 +276,7 @@ namespace ngraph
REGISTER_OPERATOR("Flatten", 1, flatten); REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor); REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gather", 1, gather); REGISTER_OPERATOR("Gather", 1, gather);
REGISTER_OPERATOR("GatherND", 1, gather_nd);
REGISTER_OPERATOR("Gemm", 1, gemm); REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("Gemm", 6, gemm); REGISTER_OPERATOR("Gemm", 6, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool); REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
......
...@@ -23,3 +23,6 @@ lrn_2d_across_outermost_axis ...@@ -23,3 +23,6 @@ lrn_2d_across_outermost_axis
# ONNX TopK with dynamic K # ONNX TopK with dynamic K
top_k_opset_10 top_k_opset_10
# ONNX GatherND with int32
model_gatherND_int32
...@@ -263,6 +263,8 @@ model_lstm_fwd_hardsigmoid_activation ...@@ -263,6 +263,8 @@ model_lstm_fwd_hardsigmoid_activation
model_lstm_fwd_with_clip model_lstm_fwd_with_clip
model_lstm_fwd_mixed_seq model_lstm_fwd_mixed_seq
model_lstm_fwd_large_batch_no_clip model_lstm_fwd_large_batch_no_clip
model_gatherND_int32
model_gatherND_float
model_global_lp_pool_p3 model_global_lp_pool_p3
model_argmin_no_keepdims model_argmin_no_keepdims
model_reduce_log_sum_exp model_reduce_log_sum_exp
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "i"
output: "y"
op_type: "GatherND"
}
name: "test_gatherND_float"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "i"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 7
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "i"
output: "y"
op_type: "GatherND"
}
name: "test_gatherND_int32"
input {
name: "x"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "i"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 7
}
...@@ -1776,3 +1776,29 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_mod) ...@@ -1776,3 +1776,29 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_mod)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_gatherND_int32)
{
const auto gatherND_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gatherND_int32.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(gatherND_fn, "${BACKEND_NAME}");
test_case.add_input<int32_t>({0, 1, 2, 3});
test_case.add_input<int64_t>({1, 0});
test_case.add_expected_output<int32_t>(Shape{2, 2}, {2, 3, 0, 1});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_gatherND_float)
{
const auto gatherND_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gatherND_float.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(gatherND_fn, "${BACKEND_NAME}");
test_case.add_input<float>({0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f});
test_case.add_input<int64_t>({0, 1, 1, 0});
test_case.add_expected_output<float>(Shape{2, 2}, {2.f, 3.f, 4.f, 5.f});
test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment