Commit 280d97d5 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Sang Ik Lee

[ONNX] Use v1 ops in Softplus (#4092)

* Use v1 ops in ONNX Softplus

* Some extra comments
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent caa11583
...@@ -18,10 +18,6 @@ ...@@ -18,10 +18,6 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "softplus.hpp" #include "softplus.hpp"
namespace ngraph namespace ngraph
...@@ -34,34 +30,38 @@ namespace ngraph ...@@ -34,34 +30,38 @@ namespace ngraph
{ {
NodeVector softplus(const Node& node) NodeVector softplus(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); const auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> zero_node = const std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<default_opset::Constant>( std::make_shared<default_opset::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{0.f}); data->get_element_type(), data->get_shape(), std::vector<float>{0.f});
std::shared_ptr<ngraph::Node> one_node = const std::shared_ptr<ngraph::Node> one_node =
std::make_shared<default_opset::Constant>( std::make_shared<default_opset::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{1.f}); data->get_element_type(), data->get_shape(), std::vector<float>{1.f});
std::shared_ptr<ngraph::Node> positive_val_node = // data + log(exp(-data) + 1)
data + std::make_shared<default_opset::Log>( const std::shared_ptr<ngraph::Node> positive_val_node =
std::make_shared<default_opset::Add>(
data,
std::make_shared<default_opset::Log>(
std::make_shared<default_opset::Add>(
std::make_shared<default_opset::Exp>( std::make_shared<default_opset::Exp>(
std::make_shared<default_opset::Negative>(data)) + std::make_shared<default_opset::Negative>(data)),
one_node); one_node)));
std::shared_ptr<ngraph::Node> negative_val_node = // log(exp(data) + 1)
std::make_shared<default_opset::Log>( const std::shared_ptr<ngraph::Node> negative_val_node =
std::make_shared<default_opset::Exp>(data) + one_node); std::make_shared<default_opset::Log>(std::make_shared<default_opset::Add>(
std::make_shared<default_opset::Exp>(data), one_node));
std::shared_ptr<ngraph::Node> condition_node = const std::shared_ptr<ngraph::Node> condition_node =
std::make_shared<ngraph::opset0::Greater>(data, zero_node); std::make_shared<default_opset::Greater>(data, zero_node);
//
// This equation represents: // This equation represents:
// x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow, // x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow,
// log(exp(x) + 1) - elsewhere. // log(exp(x) + 1) - elsewhere.
// //
return {std::make_shared<ngraph::opset0::Select>( return {std::make_shared<default_opset::Select>(
condition_node, positive_val_node, negative_val_node)}; condition_node, positive_val_node, negative_val_node)};
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment