Commit 1fe02337 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Non-linear operators (#1580)

* [ONNX] Non-linear operators

* Review fix pt. 1

* Review fix pt. 2

* Non-linear tests

* style check

* Exception fix

* Test fix
parent 17af4266
......@@ -44,12 +44,16 @@ add_library(onnx_import STATIC
op/conv.cpp
op/conv.hpp
op/div.hpp
op/elu.cpp
op/elu.hpp
op/equal.hpp
op/flatten.cpp
op/flatten.hpp
op/gemm.cpp
op/gemm.hpp
op/greater.hpp
op/leaky_relu.cpp
op/leaky_relu.hpp
op/less.hpp
op/matmul.hpp
op/max_pool.cpp
......@@ -62,6 +66,8 @@ add_library(onnx_import STATIC
op/not.hpp
op/or.hpp
op/pow.hpp
op/prelu.cpp
op/prelu.hpp
op/reduce.cpp
op/reduce.hpp
op/relu.hpp
......@@ -69,12 +75,18 @@ add_library(onnx_import STATIC
op/reshape.hpp
op/shape.cpp
op/shape.hpp
op/selu.cpp
op/selu.hpp
op/sigmoid.hpp
op/softmax.cpp
op/softmax.hpp
op/split.cpp
op/split.hpp
op/sub.hpp
op/sum.hpp
op/tanh.hpp
op/thresholded_relu.cpp
op/thresholded_relu.hpp
op/unsqueeze.cpp
op/unsqueeze.hpp
op/xor.hpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "utils/broadcasting.hpp"
#include "elu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector elu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector elu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "exceptions.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "leaky_relu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector leaky_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector leaky_relu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iterator>
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "prelu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector prelu(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it =
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector prelu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "selu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector selu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node *
(std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector selu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector sigmoid(const Node& node)
{
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/tanh.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector tanh(const Node& node)
{
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "thresholded_relu.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector thresholded_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
data->get_element_type());
return {data * data_map};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector thresholded_relu(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -26,10 +26,12 @@
#include "op/constant.hpp"
#include "op/conv.hpp"
#include "op/div.hpp"
#include "op/elu.hpp"
#include "op/equal.hpp"
#include "op/flatten.hpp"
#include "op/gemm.hpp"
#include "op/greater.hpp"
#include "op/leaky_relu.hpp"
#include "op/less.hpp"
#include "op/matmul.hpp"
#include "op/max.hpp"
......@@ -40,14 +42,19 @@
#include "op/not.hpp"
#include "op/or.hpp"
#include "op/pow.hpp"
#include "op/prelu.hpp"
#include "op/reduce.hpp"
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/selu.hpp"
#include "op/shape.hpp"
#include "op/sigmoid.hpp"
#include "op/softmax.hpp"
#include "op/split.hpp"
#include "op/sub.hpp"
#include "op/sum.hpp"
#include "op/tanh.hpp"
#include "op/thresholded_relu.hpp"
#include "op/unsqueeze.hpp"
#include "op/xor.hpp"
#include "ops_bridge.hpp"
......@@ -104,10 +111,12 @@ namespace ngraph
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Elu", std::bind(op::elu, std::placeholders::_1));
m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1));
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1));
m_map.emplace("LeakyRelu", std::bind(op::leaky_relu, std::placeholders::_1));
m_map.emplace("Less", std::bind(op::less, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
......@@ -118,6 +127,7 @@ namespace ngraph
m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1));
m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1));
m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1));
m_map.emplace("PRelu", std::bind(op::prelu, std::placeholders::_1));
m_map.emplace("ReduceLogSum",
std::bind(op::reduce_log_sum, std::placeholders::_1));
m_map.emplace("ReduceLogSumExp",
......@@ -134,10 +144,15 @@ namespace ngraph
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1));
m_map.emplace("Selu", std::bind(op::selu, std::placeholders::_1));
m_map.emplace("Sigmoid", std::bind(op::sigmoid, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Tanh", std::bind(op::tanh, std::placeholders::_1));
m_map.emplace("ThresholdedRelu",
std::bind(op::thresholded_relu, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1));
}
......
......@@ -16,6 +16,8 @@
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
......@@ -89,6 +91,17 @@ namespace ngraph
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
ngraph::Shape new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace onnx_import
} // namespace ngraph
 backend-test:f
"
xy" LeakyRelu*
alpha=test_leakyreluZ
x



b
y



B
\ No newline at end of file
 backend-test:y

x
slopey"PRelutest_prelu_exampleZ
x



Z
slope



b
y



B
\ No newline at end of file
 backend-test:Q

xy"Sigmoid test_sigmoidZ
x



b
y



B
\ No newline at end of file
 backend-test:K
xy"Tanh test_tanhZ
x



b
y



B
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment