Unverified commit 28432321 authored by Michał Karzyński, committed by GitHub

[ONNX] Add Tile op (#2990)

* [ONNX] Add Tile op

* Fix test model output shape

* Add static test for Tile

* Test Case: use expected shape, not computed output shape

* Exclude dynamic Tile op tests

* Add tests using BackendMode::DYNAMIC

* Code review comments
Co-Authored-By: Katarzyna Mitrus <katarzyna.mitrus@intel.com>

* Apply suggestions from code review
Co-Authored-By: Katarzyna Mitrus <katarzyna.mitrus@intel.com>

* Code review
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: Sang Ik Lee <sang.ik.lee@intel.com>
parent 2dc3f8b5
......@@ -212,6 +212,8 @@ add_library(onnx_import STATIC
op/tanh.hpp
op/thresholded_relu.cpp
op/thresholded_relu.hpp
op/tile.cpp
op/tile.hpp
op/topk.cpp
op/topk.hpp
op/transpose.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "core/node.hpp"
#include "default_opset.hpp"
#include "tile.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Build the nGraph subgraph for an ONNX Tile node.
///
/// \param node The ONNX node object representing this operation.
/// \return A vector holding the nGraph Tile node producing the output.
NodeVector tile(const Node& node)
{
    // Fetch the node's inputs once: element 0 is the data tensor,
    // element 1 is the per-axis repeat counts.
    const auto ng_inputs = node.get_ng_inputs();
    const auto data = ng_inputs.at(0);
    auto repeats = ng_inputs.at(1);

    // Workaround for backends which require repeats to be i64.
    // Remove the following line when no longer needed.
    repeats = std::make_shared<default_opset::Convert>(repeats, element::i64);

    return {std::make_shared<default_opset::Tile>(data, repeats)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs ONNX Tile operation.
///
/// Tile repeats the input tensor along each axis; the number of
/// repetitions per axis is given by the node's second ("repeats") input.
///
/// \param node The ONNX node object representing this operation.
/// \return The vector containing the nGraph node producing the output of the
///         Tile op.
NodeVector tile(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -125,6 +125,7 @@
#include "op/tan.hpp"
#include "op/tanh.hpp"
#include "op/thresholded_relu.hpp"
#include "op/tile.hpp"
#include "op/topk.hpp"
#include "op/transpose.hpp"
#include "op/unsqueeze.hpp"
......@@ -362,6 +363,7 @@ namespace ngraph
REGISTER_OPERATOR("Tan", 1, tan);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("Tile", 1, tile);
REGISTER_OPERATOR("TopK", 1, topk);
REGISTER_OPERATOR("TopK", 10, topk);
REGISTER_OPERATOR("TopK", 11, topk);
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "input"
input: "repeats"
output: "output"
op_type: "Tile"
}
name: "test_tile"
input {
name: "input"
type {
tensor_type {
elem_type: 5
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "repeats"
type {
tensor_type {
elem_type: 5
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 5
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 8
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "input"
input: "repeats"
output: "output"
op_type: "Tile"
}
name: "test_tile"
initializer {
dims: 2
data_type: 7
raw_data: "\002\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000"
name: "repeats"
}
input {
name: "input"
type {
tensor_type {
elem_type: 5
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "repeats"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 5
shape {
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
}
opset_import {
version: 8
}
......@@ -550,3 +550,29 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, expand_uint16_dyn_shape)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_tile)
{
    // Tile a 2x3 int16 tensor with runtime-supplied repeats {2, 1};
    // the result duplicates the rows once, giving shape {4, 3}.
    const auto model =
        onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tile.prototxt"));

    auto test = ngraph::test::NgraphTestCase(model, "${BACKEND_NAME}", BackendMode::DYNAMIC);
    test.add_input<std::int16_t>({0, 1, 2, 3, 4, 5}); // input
    test.add_input<std::int16_t>({2, 1});             // repeats
    test.add_expected_output<std::int16_t>(Shape{4, 3}, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5});
    test.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_tile_static)
{
    // Tile a 2x3 int16 tensor with repeats {2, 2} baked into the model as an
    // initializer; the result has shape {4, 6}.
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/tile_static.prototxt"));

    auto test = ngraph::test::NgraphTestCase(model, "${BACKEND_NAME}", BackendMode::DYNAMIC);
    test.add_input<std::int16_t>({0, 1, 2, 3, 4, 5}); // input
    test.add_expected_output<std::int16_t>(
        Shape{4, 6}, {0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5});
    test.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment