Commit 280d97d5 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Sang Ik Lee

[ONNX] Use v1 ops in Softplus (#4092)

* Use v1 ops in ONNX Softplus

* Some extra comments
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent caa11583
......@@ -18,10 +18,6 @@
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "softplus.hpp"
namespace ngraph
......@@ -34,34 +30,38 @@ namespace ngraph
{
NodeVector softplus(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
const auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> zero_node =
const std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<default_opset::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{0.f});
std::shared_ptr<ngraph::Node> one_node =
const std::shared_ptr<ngraph::Node> one_node =
std::make_shared<default_opset::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{1.f});
std::shared_ptr<ngraph::Node> positive_val_node =
data + std::make_shared<default_opset::Log>(
// data + log(exp(-data) + 1)
const std::shared_ptr<ngraph::Node> positive_val_node =
std::make_shared<default_opset::Add>(
data,
std::make_shared<default_opset::Log>(
std::make_shared<default_opset::Add>(
std::make_shared<default_opset::Exp>(
std::make_shared<default_opset::Negative>(data)) +
one_node);
std::make_shared<default_opset::Negative>(data)),
one_node)));
std::shared_ptr<ngraph::Node> negative_val_node =
std::make_shared<default_opset::Log>(
std::make_shared<default_opset::Exp>(data) + one_node);
// log(exp(data) + 1)
const std::shared_ptr<ngraph::Node> negative_val_node =
std::make_shared<default_opset::Log>(std::make_shared<default_opset::Add>(
std::make_shared<default_opset::Exp>(data), one_node));
std::shared_ptr<ngraph::Node> condition_node =
std::make_shared<ngraph::opset0::Greater>(data, zero_node);
const std::shared_ptr<ngraph::Node> condition_node =
std::make_shared<default_opset::Greater>(data, zero_node);
//
// This equation represents:
// x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow,
// log(exp(x) + 1) - elsewhere.
//
return {std::make_shared<ngraph::opset0::Select>(
return {std::make_shared<default_opset::Select>(
condition_node, positive_val_node, negative_val_node)};
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment