Unverified Commit d605e7fa authored by Ewa Tusień's avatar Ewa Tusień Committed by GitHub

Add Round op to ONNX importer (#4303)

* Added Round op to onnx importer.

* Added a new line.

* Excluded test from plaidml.

* Code formatting.

* Added header.

* Unabled test on gpu.
Co-authored-by: 's avatarChris Sullivan <chris.sullivan@intel.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent f8856e2b
......@@ -171,6 +171,8 @@ add_library(onnx_import STATIC
op/reshape.hpp
op/reverse_sequence.cpp
op/reverse_sequence.hpp
op/round.cpp
op/round.hpp
op/scatter_nd.cpp
op/scatter_nd.hpp
op/selu.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/opsets/opset0.hpp"
#include "round.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector round(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
return {std::make_shared<ngraph::opset0::Round>(data)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector round(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -101,6 +101,7 @@
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/reverse_sequence.hpp"
#include "op/round.hpp"
#include "op/scatter_nd.hpp"
#include "op/selu.hpp"
#include "op/shape.hpp"
......@@ -334,6 +335,7 @@ namespace ngraph
REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
REGISTER_OPERATOR("Round", 1, round);
REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
......
......@@ -453,6 +453,7 @@ model_gatherND_int32
model_gatherND_float
model_pad_constant
model_reciprocal
model_round
tile_3d_small_data_rank
tile_3d_few_repeats
select_v1
......
......@@ -282,6 +282,7 @@ model_argmax_int32
model_argmin_int32
model_lp_norm_default
model_instance_normalization
model_round
# passing locally, fails closeness checks in CI which may be too strict
elu
......
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
op_type: "Round"
}
name: "test_round"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
}
opset_import {
version: 11
}
......@@ -1963,3 +1963,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reciprocal)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_round)
{
const auto round_fn =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/round.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(round_fn, "${BACKEND_NAME}");
test_case.add_input<float>({0.1f,
0.5f,
0.9f,
1.2f,
1.5f,
1.8f,
2.3f,
2.5f,
2.7f,
-1.1f,
-1.5f,
-1.9f,
-2.2f,
-2.5f,
-2.8f});
test_case.add_expected_output<float>(
{0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f, 3.f, -1.f, -2.f, -2.f, -2.f, -2.f, -3.f});
test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment