Commit a804c3d7 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Non-linear ops (#1864)

* [ONNX] Non-linear ops

* Style check
parent fbc3a940
......@@ -20,16 +20,18 @@
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "prelu.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -54,13 +56,33 @@ namespace ngraph
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
else if (data_shape != slope_shape)
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
// x < 0 => f(x) = x * slope
// x >= 0 => f(x) = x
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> negative_map =
std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Less>(data, zero_node),
data->get_element_type());
std::shared_ptr<ngraph::Node> positive_map =
std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, zero_node),
data->get_element_type());
slope = negative_map * slope + positive_map;
return {data * slope};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment