Commit 0a5c2b81 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[ONNX] Support for Gather op (#2772)

* Temp.

* Put all the dummy files.

* Remove some compile errors.

* WIP: Add gather and gather_nd kernels.

* Temp save.

* Update comments for gather.

* Implement reference gather.

* Validate and infer shape.

* Style.

* Fix compile issues.

* Add serializer support.

* Fix interpreter compilation issues.

* WIP: Add UT

* WIP: Add UT

* gather_nd UT passing.

* Fix gather with no axis.

* Fix gather issue.

* Update unit_test.manifest for backends and add gather, gather_nd  support for generic cpu.

* Add type_prop tests.

* Add CPU builders.

* Fix codegen.

* Make some UT numbers more readable.

* Style.

* Update Copyright Year

* Update Copyright Year

* Fix Typo.

* Remove unused variable.

* fix nv gpu build error

* Fix intel gpu compilation.

* Add basic docstring.

* Allow 1D indices for gather_nd.

* Allow scalar indices for gather.

* Update unit_test manifest files.

* Style.

* Add indices element type check and add failing type_prop checks.

* Update docstring.

* Fix incorrect test names in unit_test.manifest

* [ONNX] Support for Gather op

* Remove unneeded broadcast

* Remove unused include

* Set correct default value for axis
parent 6cfe5b41
......@@ -86,6 +86,7 @@ add_library(onnx_import STATIC
op/flatten.cpp
op/flatten.hpp
op/floor.hpp
op/gather.hpp
op/gemm.cpp
op/gemm.hpp
op/global_average_pool.cpp
......
......@@ -51,7 +51,7 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
} // namespace set_7
} //namespace op
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/gather.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector gather(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);
return {std::make_shared<ngraph::op::Gather>(data, indices, axis)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -54,6 +54,7 @@
#include "op/exp.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/gather.hpp"
#include "op/gemm.hpp"
#include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp"
......@@ -255,6 +256,7 @@ namespace ngraph
REGISTER_OPERATOR("Exp", 1, exp);
REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gather", 1, gather);
REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
......
......@@ -153,4 +153,3 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment