Commit 8022982f authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Enable sub and div operators. (#1510)

parent 9a22ffc8
...@@ -39,6 +39,8 @@ add_library(onnx_import STATIC ...@@ -39,6 +39,8 @@ add_library(onnx_import STATIC
op/constant.cpp op/constant.cpp
op/constant.hpp op/constant.hpp
op/conv.cpp op/conv.cpp
op/conv.hpp
op/div.hpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/matmul.hpp op/matmul.hpp
...@@ -50,6 +52,7 @@ add_library(onnx_import STATIC ...@@ -50,6 +52,7 @@ add_library(onnx_import STATIC
op/softmax.hpp op/softmax.hpp
op/split.cpp op/split.cpp
op/split.hpp op/split.hpp
op/sub.hpp
ops_bridge.cpp ops_bridge.cpp
utils/broadcasting.cpp utils/broadcasting.cpp
utils/broadcasting.hpp utils/broadcasting.hpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/divide.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
// Import an ONNX "Div" node: element-wise division of the first
// input by the second, mapped onto ngraph::op::Divide.
inline NodeVector div(const Node& node)
{
    const NodeVector inputs{node.get_ng_inputs()};
    const auto& dividend = inputs.at(0);
    const auto& divisor = inputs.at(1);
    return {std::make_shared<ngraph::op::Divide>(dividend, divisor)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/subtract.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
// Import an ONNX "Sub" node: element-wise subtraction of the second
// input from the first, mapped onto ngraph::op::Subtract.
inline NodeVector sub(const Node& node)
{
    const NodeVector inputs{node.get_ng_inputs()};
    const auto& minuend = inputs.at(0);
    const auto& subtrahend = inputs.at(1);
    return {std::make_shared<ngraph::op::Subtract>(minuend, subtrahend)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "op/batch_norm.hpp" #include "op/batch_norm.hpp"
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/div.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/matmul.hpp" #include "op/matmul.hpp"
#include "op/max_pool.hpp" #include "op/max_pool.hpp"
...@@ -30,6 +31,7 @@ ...@@ -30,6 +31,7 @@
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/softmax.hpp" #include "op/softmax.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "op/sub.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
namespace ngraph namespace ngraph
...@@ -81,6 +83,7 @@ namespace ngraph ...@@ -81,6 +83,7 @@ namespace ngraph
std::bind(op::batch_norm, std::placeholders::_1)); std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1)); m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
...@@ -88,6 +91,7 @@ namespace ngraph ...@@ -88,6 +91,7 @@ namespace ngraph
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1)); m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
} }
NodeVector operator()(const Node& node) const NodeVector operator()(const Node& node) const
......
...@@ -413,3 +413,35 @@ TEST(onnx, model_softmax) ...@@ -413,3 +413,35 @@ TEST(onnx, model_softmax)
auto result_vectors = execute(function, inputs, "INTERPRETER"); auto result_vectors = execute(function, inputs, "INTERPRETER");
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
} }
// Verify the imported ONNX Sub model computes element-wise subtraction.
TEST(onnx, model_sub)
{
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));

    // {1,2,3} - {4,5,7} == {-3,-3,-4}
    Inputs inputs;
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{4, 5, 7}}}).get_vector());
    auto expected = ngraph::test::NDArray<float, 3>({{{-3, -3, -4}}}).get_vector();

    auto results = execute(function, inputs, "INTERPRETER");
    EXPECT_TRUE(test::all_close_f(expected, results.front()));
}
// Verify the imported ONNX Div model computes element-wise division.
TEST(onnx, model_div)
{
    auto function = ngraph::onnx_import::import_onnx_function(
        ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));

    // {1,2,3} / {1,4,12} == {1, 0.5, 0.25}
    Inputs inputs;
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
    inputs.emplace_back(ngraph::test::NDArray<float, 3>({{{1, 4, 12}}}).get_vector());
    auto expected = ngraph::test::NDArray<float, 3>({{{1, 0.5, 0.25}}}).get_vector();

    auto results = execute(function, inputs, "INTERPRETER");
    EXPECT_TRUE(test::all_close_f(expected, results.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment