Commit 73942928 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Add Relu op (#1448)

* [ONNX] Add Relu op
parent da352aa1
...@@ -42,6 +42,7 @@ add_library(onnx_import STATIC ...@@ -42,6 +42,7 @@ add_library(onnx_import STATIC
op/add.hpp op/add.hpp
op/batch_norm.hpp op/batch_norm.hpp
op/constant.hpp op/constant.hpp
op/relu.hpp
op/split.hpp op/split.hpp
ops_bridge.cpp ops_bridge.cpp
tensor.hpp tensor.hpp
......
...@@ -36,4 +36,4 @@ namespace ngraph ...@@ -36,4 +36,4 @@ namespace ngraph
} // namespace onnx_import } // namespace onnx_import
} // namespace ngrahp } // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/relu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector relu(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/frontend/onnx_import/op/add.hpp" #include "ngraph/frontend/onnx_import/op/add.hpp"
#include "ngraph/frontend/onnx_import/op/batch_norm.hpp" #include "ngraph/frontend/onnx_import/op/batch_norm.hpp"
#include "ngraph/frontend/onnx_import/op/constant.hpp" #include "ngraph/frontend/onnx_import/op/constant.hpp"
#include "ngraph/frontend/onnx_import/op/relu.hpp"
#include "ngraph/frontend/onnx_import/op/split.hpp" #include "ngraph/frontend/onnx_import/op/split.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
...@@ -57,7 +58,7 @@ namespace ngraph ...@@ -57,7 +58,7 @@ namespace ngraph
{ {
return op::split(node, node.get_ng_inputs().at(0)); return op::split(node, node.get_ng_inputs().at(0));
} }
NodeVector relu(const Node& node) { return op::relu(node); }
class ops_bridge class ops_bridge
{ {
public: public:
...@@ -86,6 +87,7 @@ namespace ngraph ...@@ -86,6 +87,7 @@ namespace ngraph
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
std::bind(batch_norm, std::placeholders::_1)); std::bind(batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(constant, std::placeholders::_1));
m_map.emplace("Relu", std::bind(relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(split, std::placeholders::_1)); m_map.emplace("Split", std::bind(split, std::placeholders::_1));
} }
......
 backend-test:;
xy"Relu test_reluZ
x

b
y

B
\ No newline at end of file
...@@ -146,3 +146,16 @@ TEST(onnx, model_batchnorm_default) ...@@ -146,3 +146,16 @@ TEST(onnx, model_batchnorm_default)
auto result_vectors = execute(function, inputs, "INTERPRETER"); auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
} }
TEST(onnx, model_relu)
{
// Simple ReLU test
auto function{ngraph::onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx"))};
auto inputs = std::vector<std::vector<float>>{{-1, -2, 0, 1, 2, 3}};
auto expected_output = std::vector<std::vector<float>>{{0, 0, 0, 1, 2, 3}};
auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output.front(), result_vectors.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment