Commit 7b9f581a authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Fixed Selu op decomposition (#3976)

* Fixed Selu op decomposition

* Updated Selu op
parent 420c8c0f
...@@ -39,12 +39,18 @@ NodeVector op::v0::Selu::decompose_op() const ...@@ -39,12 +39,18 @@ NodeVector op::v0::Selu::decompose_op() const
const auto data = input_value(0); const auto data = input_value(0);
const auto alpha = input_value(1); const auto alpha = input_value(1);
const auto lambda = input_value(2); const auto lambda = input_value(2);
const auto zero_node = std::make_shared<ngraph::op::Constant>( const auto zero_node = op::Constant::create(data.get_element_type(), Shape{1}, {0});
data.get_element_type(), data.get_shape(), std::vector<double>{0});
return {lambda * // lambda * ((max(data, 0) + (alpha * exp(min(data, 0)) - alpha))
(std::make_shared<op::Maximum>(data, zero_node) + return {std::make_shared<op::v1::Multiply>(
alpha * std::make_shared<op::Exp>(std::make_shared<op::Minimum>(data, zero_node)) - lambda,
alpha)}; std::make_shared<op::v1::Add>(
std::make_shared<op::v1::Maximum>(data, zero_node),
std::make_shared<op::v1::Subtract>(
std::make_shared<op::v1::Multiply>(
alpha,
std::make_shared<op::Exp>(std::make_shared<op::v1::Minimum>(data, zero_node))),
alpha)))};
} }
shared_ptr<Node> op::v0::Selu::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Selu::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment