Commit dddcd4a8 authored by Adam Rogowiec, committed by Scott Cyphers

[ONNX] Activation functions parameters for LSTM. (#2842)

* Move HardSigmoid to nGraph fused operators.

* UT for HardSigmoid fused operator.

* Add type_prop UT.

* Activation function parameters and hardsigmoid activation function.

* UT for LSTM with HardSigmoid activation function.

* Reorder operations in implementation.

* Fix unit tests.

* Fix typo.

* Change stored activation function to pure function pointer.

* Apply style-check.

* [ONNX] Refactor LSTM tests to use NgraphTestCase

* Enable passing instance values to comparator

* Style apply.

* Fix style, syntax

* Change order of class member to fix errors.

* Switch to single-precision parameters.

* Disable unit test for IGPU.
parent 5ec02d8e
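
In short, the change wraps each RNN activation in a small ActivationFunction object (a plain function pointer plus alpha and beta parameters) and lets the ONNX activation_alpha / activation_beta attributes override the per-function defaults (0.2 / 0.5 for HardSigmoid). Below is a minimal, standalone sketch of that pattern using scalar floats instead of nGraph nodes; apart from alpha/beta and set_alpha/set_beta, the names and the main() driver are illustrative only, not code from this commit.

#include <cmath>
#include <cstdio>
#include <string>
#include <unordered_map>

// Scalar stand-in for the node-building activation helpers in the diff.
using ActivationFunctionType = float (*)(float x, float alpha, float beta);

float sigmoid(float x, float /*alpha*/, float /*beta*/)
{
    return 1.0f / (1.0f + std::exp(-x));
}

float hardsigmoid(float x, float alpha, float beta)
{
    // ONNX HardSigmoid: clamp(alpha * x + beta, 0, 1).
    return std::fmax(0.0f, std::fmin(1.0f, alpha * x + beta));
}

class ActivationFunction
{
public:
    ActivationFunction(ActivationFunctionType f,
                       float alpha = std::nanf(""),
                       float beta = std::nanf(""))
        : m_function{f}
        , m_alpha{alpha}
        , m_beta{beta}
    {
    }
    float operator()(float x) const { return m_function(x, m_alpha, m_beta); }
    void set_alpha(float alpha) { m_alpha = alpha; }
    void set_beta(float beta) { m_beta = beta; }

private:
    ActivationFunctionType m_function;
    float m_alpha;
    float m_beta;
};

int main()
{
    // Per-name defaults; hardsigmoid ships with the ONNX defaults alpha=0.2, beta=0.5.
    std::unordered_map<std::string, ActivationFunction> func_map{
        {"sigmoid", ActivationFunction{sigmoid}},
        {"hardsigmoid", ActivationFunction{hardsigmoid, 0.2f, 0.5f}}};

    // The ONNX activation_alpha / activation_beta attributes override the defaults.
    ActivationFunction f = func_map.at("hardsigmoid");
    f.set_alpha(0.375f);
    f.set_beta(0.748f);

    std::printf("%f\n", f(-0.455351f)); // prints ~0.577243
    return 0;
}
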
...@@ -258,6 +258,12 @@ namespace ngraph
, m_activations{to_lower_case(
node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh", "tanh"}))}
// Default values for activation functions are the same as for the
// corresponding ONNX operator.
, m_activation_alpha{node.get_attribute_value<std::vector<float>>(
"activation_alpha", std::vector<float>{})}
, m_activation_beta{node.get_attribute_value<std::vector<float>>(
"activation_beta", std::vector<float>{})}
, m_input_forget{static_cast<bool>(
node.get_attribute_value<std::int64_t>("input_forget", 0))}
{
...@@ -276,6 +282,8 @@ namespace ngraph
std::int64_t m_hidden_size;
float m_clip_threshold;
std::vector<std::string> m_activations;
std::vector<float> m_activation_alpha;
std::vector<float> m_activation_beta;
bool m_input_forget;
};
...@@ -534,6 +542,25 @@ namespace ngraph
float m_clip_threshold;
};
rnn::ActivationFunction get_activation_function(const LSTMAttributes& attributes,
std::size_t idx)
{
rnn::ActivationFunction afunc =
rnn::get_activation_func_by_name(attributes.m_activations.at(idx));
// Set activation function parameters (if any).
if (attributes.m_activation_alpha.size() > idx)
{
afunc.set_alpha(attributes.m_activation_alpha.at(idx));
}
if (attributes.m_activation_beta.size() > idx)
{
afunc.set_beta(attributes.m_activation_beta.at(idx));
}
return afunc;
}
} // anonymous namespace
namespace set_1
...@@ -543,12 +570,13 @@ namespace ngraph
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
// Get activation functions.
const rnn::ActivationFunction& activation_f =
get_activation_function(attributes, 0);
const rnn::ActivationFunction& activation_g =
get_activation_function(attributes, 1);
const rnn::ActivationFunction& activation_h =
get_activation_function(attributes, 2);
NodeVector results;
...
...@@ -14,11 +14,13 @@
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
...@@ -31,30 +33,67 @@ namespace ngraph ...@@ -31,30 +33,67 @@ namespace ngraph
{ {
namespace detail namespace detail
{ {
std::shared_ptr<ngraph::Node>
sigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node>
tanh(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
std::shared_ptr<ngraph::Node>
relu(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Relu>(arg);
}
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::HardSigmoid>(arg, alpha, beta);
}
} // namespace detail
ActivationFunction::ActivationFunction(ActivationFunctionType f,
float alpha,
float beta)
: m_function{f}
, m_alpha{alpha}
, m_beta{beta}
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
: ActivationFunction(f, alpha, std::nanf(""))
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f)
: ActivationFunction(f, std::nanf(""), std::nanf(""))
{
}
std::shared_ptr<ngraph::Node> ActivationFunction::
operator()(const std::shared_ptr<ngraph::Node>& arg) const
{
return m_function(arg, m_alpha, m_beta);
}
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;
static ActivationFunctionMap func_map{
{"sigmoid", ActivationFunction{detail::sigmoid}},
{"tanh", ActivationFunction{detail::tanh}},
{"relu", ActivationFunction{detail::relu}},
{"hardsigmoid", ActivationFunction{detail::hardsigmoid, 0.2f, 0.5f}},
};
auto func_it = func_map.find(func_name);
if (func_it == std::end(func_map))
...
...@@ -22,6 +22,23 @@
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
#ifdef _WIN32
#pragma warning(push)
#pragma warning(disable : 4100)
#endif
// Prevents the compiler from complaining about or optimizing away variables
// that appear unused on Linux
#if (defined(__GNUC__) && !defined(__clang__))
#undef ONNX_ATTRIBUTE_UNUSED
#define ONNX_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define ONNX_ATTRIBUTE_UNUSED
#endif
#define UNUSED_PARAMETER ONNX_ATTRIBUTE_UNUSED = 0
namespace ngraph
{
namespace onnx_import
...@@ -41,13 +58,39 @@ namespace ngraph
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta);
}
using ActivationFunctionType = std::shared_ptr<ngraph::Node> (*)(
const std::shared_ptr<ngraph::Node>&, float, float);
class ActivationFunction
{
public:
ActivationFunction(ActivationFunctionType f, float alpha, float beta);
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
std::shared_ptr<ngraph::Node>
operator()(const std::shared_ptr<ngraph::Node>& arg) const;
void set_alpha(float alpha) { m_alpha = alpha; }
void set_beta(float beta) { m_beta = beta; }
private:
ActivationFunctionType m_function;
float m_alpha;
float m_beta;
};
/// \brief Gets the activation function by name.
///
...@@ -64,3 +107,14 @@ namespace ngraph
} // namespace onnx_import
} // namespace ngraph
#ifdef _WIN32
#pragma warning(pop)
#endif
#ifdef UNUSED_PARAMETER
#undef UNUSED_PARAMETER
#endif
#ifdef ONNX_ATTRIBUTE_UNUSED
#undef ONNX_ATTRIBUTE_UNUSED
#endif
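
For reference (a restatement of the macros above, not code from the commit): on GCC when not compiling with Clang, a declaration such as

std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
                                      float alpha UNUSED_PARAMETER,
                                      float beta UNUSED_PARAMETER);

preprocesses to

std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
                                      float alpha __attribute__((__unused__)) = 0,
                                      float beta __attribute__((__unused__)) = 0);

while on other compilers ONNX_ATTRIBUTE_UNUSED expands to nothing and only the = 0 default arguments remain; the #pragma warning(push) / disable : 4100 / pop wrapper silences the equivalent unreferenced-parameter warning on MSVC.
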
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
output: "Y_c"
op_type: "LSTM"
attribute {
name: "activation_alpha"
floats: 0.375
type: FLOATS
}
attribute {
name: "activation_beta"
floats: 0.748
type: FLOATS
}
attribute {
name: "activations"
strings: "HardSigmoid"
strings: "Tanh"
strings: "Tanh"
type: STRINGS
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 2
type: INT
}
attribute {
name: "input_forget"
i: 0
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 7
}
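
Note that activation_alpha and activation_beta above each carry a single value, so with the get_activation_function helper from the importer hunk only the activation at index 0 (HardSigmoid, the gate activation f) has its parameters overridden, to alpha = 0.375 and beta = 0.748; the Tanh activations at indices 1 and 2 keep their parameterless defaults. Assuming the standard ONNX definition of HardSigmoid, that gate therefore computes clamp(0.375 * x + 0.748, 0, 1).
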
...@@ -60,8 +60,8 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
-0.0888852f,
-0.428709f,
-0.283349f,
0.208792f}); // W
test_case.add_input<float>({0.146626f,
-0.0620289f,
-0.0815302f,
0.100482f,
...@@ -76,9 +76,9 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
-0.394864f,
0.42111f,
-0.386624f,
-0.390225f}); // R
test_case.add_input<float>({0.381619f,
0.0323954f,
-0.14449f,
0.420804f,
...@@ -93,7 +93,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
0.0f,
0.0f,
0.0f,
0.0f}); // B
test_case.add_input<float>({0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f}); // P
test_case.add_expected_output<float>(
...@@ -122,7 +122,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
test_case.add_input<int>({1, 2}); // seq_lengths
test_case.add_expected_output<float>(Shape{2, 1, 2, 3},
{0.28828835f,
0.36581863f,
0.45679406f,
0.34526032f,
...@@ -133,7 +133,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
0.f,
0.85882828f,
0.90703777f,
0.92382453f}); // Y_data
test_case.add_expected_output<float>(
Shape{1, 2, 3},
{0.28828835f, 0.36581863f, 0.45679406f, 0.85882828f, 0.90703777f, 0.92382453f}); // Y_h_data
...@@ -146,3 +146,60 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
test_case.set_tolerance(3);
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_hardsigmoid_activation.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// X
test_case.add_input<float>({-0.455351f, -0.276391f, -0.185934f, -0.269585f});
// W
test_case.add_input<float>({-0.494659f,
0.0453352f,
-0.487793f,
0.417264f,
-0.0175329f,
0.489074f,
-0.446013f,
0.414029f,
-0.0091708f,
-0.255364f,
-0.106952f,
-0.266717f,
-0.0888852f,
-0.428709f,
-0.283349f,
0.208792f});
// R
test_case.add_input<float>({0.146626f,
-0.0620289f,
-0.0815302f,
0.100482f,
-0.219535f,
-0.306635f,
-0.28515f,
-0.314112f,
-0.228172f,
0.405972f,
0.31576f,
0.281487f,
-0.394864f,
0.42111f,
-0.386624f,
-0.390225f});
// Y
test_case.add_expected_output<float>(Shape{2, 1, 1, 2},
{0.09086666f, 0.04378549f, 0.12914555f, 0.00257774f});
// Y_h
test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.12914555f, 0.00257774f});
// Y_c
test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f});
// The discrepancies occur at most at 18th mantissa bit - 8th decimal position.
test_case.set_tolerance(6);
test_case.run();
}