Commit 7d3323c9 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Ops bridge refactoring (#1451)

* [ONNX] Add Relu op

* Refactoring

* Refactoring - move op::Constant implementation to .cpp file

* cmake file list order

* Code review
parent f0f50acf
...@@ -41,6 +41,7 @@ add_library(onnx_import STATIC ...@@ -41,6 +41,7 @@ add_library(onnx_import STATIC
node.cpp node.cpp
op/add.hpp op/add.hpp
op/batch_norm.hpp op/batch_norm.hpp
op/constant.cpp
op/constant.hpp op/constant.hpp
op/relu.hpp op/relu.hpp
op/split.hpp op/split.hpp
......
...@@ -30,8 +30,9 @@ namespace ngraph ...@@ -30,8 +30,9 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector batch_norm(const Node& node, const NodeVector& inputs) inline NodeVector batch_norm(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0); auto x = inputs.at(0);
auto scale = inputs.at(1); auto scale = inputs.at(1);
auto bias = inputs.at(2); auto bias = inputs.at(2);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/constant.hpp"
#include "ngraph/frontend/onnx_import/node.hpp"
#include "ngraph/frontend/onnx_import/tensor.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace
{
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
{
#define MAKE_NG_CONSTANT(data_type_) \
case data_type_: return make_ng_constant<data_type_>(tensor)
switch (tensor.get_type())
{
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
}
}
}
NodeVector constant(const onnx_import::Node& node)
{
return {make_constant(node.get_attribute_value<Tensor>("value"))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -25,85 +25,7 @@ namespace ngraph ...@@ -25,85 +25,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
namespace detail NodeVector constant(const Node& node);
{
namespace
{
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
} // namespace detail
inline std::shared_ptr<ngraph::op::Constant> constant(const Tensor& tensor)
{
#define _make_ng_constant(_data_type) \
case _data_type: return detail::make_ng_constant<_data_type>(tensor)
switch (tensor.get_type())
{
_make_ng_constant(Tensor::Type::float16);
_make_ng_constant(Tensor::Type::float32);
_make_ng_constant(Tensor::Type::float64);
_make_ng_constant(Tensor::Type::int32);
_make_ng_constant(Tensor::Type::uint32);
_make_ng_constant(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
}
}
} // namespace op } // namespace op
......
...@@ -112,8 +112,9 @@ namespace ngraph ...@@ -112,8 +112,9 @@ namespace ngraph
} // namespace detail } // namespace detail
inline NodeVector split(const Node& node, const std::shared_ptr<ngraph::Node>& input) inline NodeVector split(const Node& node)
{ {
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()}; std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)}; int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)}; std::size_t axis_to_split{static_cast<std::size_t>(axis)};
......
...@@ -43,22 +43,6 @@ namespace ngraph ...@@ -43,22 +43,6 @@ namespace ngraph
} // namespace error } // namespace error
NodeVector add(const Node& node) { return op::add(node); }
NodeVector batch_norm(const Node& node)
{
return op::batch_norm(node, node.get_ng_inputs());
}
NodeVector constant(const Node& node)
{
return {op::constant(node.get_attribute_value<Tensor>("value"))};
}
NodeVector split(const Node& node)
{
return op::split(node, node.get_ng_inputs().at(0));
}
NodeVector relu(const Node& node) { return op::relu(node); }
class ops_bridge class ops_bridge
{ {
public: public:
...@@ -83,12 +67,12 @@ namespace ngraph ...@@ -83,12 +67,12 @@ namespace ngraph
ops_bridge() ops_bridge()
{ {
m_map.emplace("Add", std::bind(add, std::placeholders::_1)); m_map.emplace("Add", std::bind(op::add, std::placeholders::_1));
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
std::bind(batch_norm, std::placeholders::_1)); std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Constant", std::bind(constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Relu", std::bind(relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Split", std::bind(split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
} }
NodeVector operator()(const Node& node) const NodeVector operator()(const Node& node) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment