Commit 5761f145 authored by Adam Rogowiec, committed by arogowie-intel

Update LSTM ONNX operator to use LSTMCell fused op.

parent ef1c5347
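The gist of the change: the hand-rolled per-gate graph that the per-timestep loop used to build (Dot, Add, Multiply, Sigmoid/Tanh nodes plus the local add/sub/mul/clip helpers) is collapsed into a single ngraph::op::LSTMCell fused node whose outputs 0 and 1 are the new hidden and cell states. Below is a minimal sketch of what each loop iteration now constructs, distilled from the diff that follows; lstm_step is a hypothetical wrapper written only for this illustration and is not part of the commit.

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ngraph/node.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"

// Hypothetical helper (illustration only): one LSTM time step expressed as a
// single fused LSTMCell node. The argument order mirrors the LSTMCell
// construction in the diff below.
std::pair<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
    lstm_step(const std::shared_ptr<ngraph::Node>& X_t,
              const std::shared_ptr<ngraph::Node>& W,
              const std::shared_ptr<ngraph::Node>& R,
              const std::shared_ptr<ngraph::Node>& H_t,
              const std::shared_ptr<ngraph::Node>& C_t,
              std::size_t hidden_size,
              const std::shared_ptr<ngraph::Node>& B,
              const std::shared_ptr<ngraph::Node>& P,
              const std::vector<std::string>& activations,
              const std::vector<float>& activation_alpha,
              const std::vector<float>& activation_beta,
              float clip_threshold,
              bool input_forget)
{
    // A single fused node replaces roughly twenty elementwise/dot nodes.
    auto cell = std::make_shared<ngraph::op::LSTMCell>(X_t,
                                                       W,
                                                       R,
                                                       H_t,
                                                       C_t,
                                                       hidden_size,
                                                       B,
                                                       P,
                                                       activations,
                                                       activation_alpha,
                                                       activation_beta,
                                                       clip_threshold,
                                                       input_forget);
    // LSTMCell is a multi-output node: output 0 is the new hidden state,
    // output 1 is the new cell state.
    return {std::make_shared<ngraph::op::GetOutputElement>(cell, 0),
            std::make_shared<ngraph::op::GetOutputElement>(cell, 1)};
}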
@@ -14,46 +14,32 @@
 // limitations under the License.
 //*****************************************************************************
 
-#include <algorithm>
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
-#include <functional>
-#include <iterator>
 #include <map>
 #include <memory>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 #include "core/null_node.hpp"
 #include "exceptions.hpp"
 #include "lstm.hpp"
 #include "ngraph/axis_set.hpp"
-#include "ngraph/builder/make_constant.hpp"
 #include "ngraph/builder/split.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/op/add.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/constant.hpp"
-#include "ngraph/op/dot.hpp"
-#include "ngraph/op/get_output_element.hpp"
+#include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/greater.hpp"
-#include "ngraph/op/maximum.hpp"
-#include "ngraph/op/minimum.hpp"
-#include "ngraph/op/multiply.hpp"
-#include "ngraph/op/reshape.hpp"
 #include "ngraph/op/reverse.hpp"
 #include "ngraph/op/select.hpp"
-#include "ngraph/op/sigmoid.hpp"
-#include "ngraph/op/subtract.hpp"
-#include "ngraph/op/tanh.hpp"
 #include "ngraph/op/util/broadcasting.hpp"
-#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/type/element_type.hpp"
-#include "ngraph/util.hpp"
 #include "utils/reshape.hpp"
-#include "utils/rnn/activation_functions.hpp"
 
 namespace ngraph
 {
@@ -63,61 +49,6 @@ namespace ngraph
         {
             namespace
             {
-                std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
-                                                  const std::shared_ptr<ngraph::Node>& rhs)
-                {
-                    auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
-                    return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
-                }
-
-                std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
-                                                  const std::shared_ptr<ngraph::Node>& rhs)
-                {
-                    auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
-                    return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
-                }
-
-                std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
-                                                  const std::shared_ptr<ngraph::Node>& rhs)
-                {
-                    auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
-                    return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
-                }
-
-                std::shared_ptr<ngraph::Node> clip(const std::shared_ptr<ngraph::Node>& data,
-                                                   float threshold)
-                {
-                    if (threshold == 0.f)
-                    {
-                        return data;
-                    }
-
-                    float min_val = -threshold;
-                    float max_val = threshold;
-                    std::size_t size = ngraph::shape_size(data->get_shape());
-
-                    const std::shared_ptr<ngraph::Node> min_val_node =
-                        ngraph::op::Constant::create(data->get_element_type(),
-                                                     data->get_shape(),
-                                                     std::vector<float>(size, min_val));
-                    const std::shared_ptr<ngraph::Node> max_val_node =
-                        ngraph::op::Constant::create(data->get_element_type(),
-                                                     data->get_shape(),
-                                                     std::vector<float>(size, max_val));
-
-                    return std::make_shared<ngraph::op::Minimum>(
-                        max_val_node, std::make_shared<ngraph::op::Maximum>(data, min_val_node));
-                }
-
-                // Modify input vector in-place and return reference to modified vector.
-                std::vector<std::string>& to_lower_case(std::vector<std::string>&& vs)
-                {
-                    std::transform(std::begin(vs),
-                                   std::end(vs),
-                                   std::begin(vs),
-                                   [](std::string& s) { return ngraph::to_lower(s); });
-                    return vs;
-                }
-
                 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                 enum class LSTMInput
@@ -257,9 +188,8 @@ namespace ngraph
                     explicit LSTMAttributes(const Node& node)
                         : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
                         , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
-                        , m_activations{to_lower_case(
-                              node.get_attribute_value<std::vector<std::string>>(
-                                  "activations", {"sigmoid", "tanh", "tanh"}))}
+                        , m_activations{node.get_attribute_value<std::vector<std::string>>(
+                              "activations", {"sigmoid", "tanh", "tanh"})}
                         // Default values for activation functions are same as for corresponding
                         // ONNX operator.
                         , m_activation_alpha{node.get_attribute_value<std::vector<float>>(
@@ -293,20 +223,17 @@ namespace ngraph
                 {
                 public:
                     explicit LSTMForward(std::shared_ptr<ngraph::Node> X,
-                                         std::shared_ptr<ngraph::Node> W,
-                                         std::shared_ptr<ngraph::Node> R,
-                                         std::shared_ptr<ngraph::Node> B,
-                                         std::shared_ptr<ngraph::Node> P,
-                                         std::shared_ptr<ngraph::Node> initial_h,
-                                         std::shared_ptr<ngraph::Node> initial_c,
-                                         std::shared_ptr<ngraph::Node> seq_lengths,
-                                         rnn::ActivationFunction activation_f,
-                                         rnn::ActivationFunction activation_g,
-                                         rnn::ActivationFunction activation_h,
-                                         bool input_forget = false,
-                                         float clip_threshold = 0.f)
+                                         const std::shared_ptr<ngraph::Node>& W,
+                                         const std::shared_ptr<ngraph::Node>& R,
+                                         const std::shared_ptr<ngraph::Node>& B,
+                                         const std::shared_ptr<ngraph::Node>& P,
+                                         const std::shared_ptr<ngraph::Node>& initial_h,
+                                         const std::shared_ptr<ngraph::Node>& initial_c,
+                                         const std::shared_ptr<ngraph::Node>& seq_lengths,
+                                         const LSTMAttributes& attributes)
                         : m_X{X}
                         // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
                         , m_W{reshape::squeeze(W)}
                         , m_R{reshape::squeeze(R)}
                         , m_B{reshape::squeeze(B)}
@@ -314,11 +241,7 @@ namespace ngraph
                         , m_initial_h{reshape::squeeze(initial_h)}
                         , m_initial_c{reshape::squeeze(initial_c)}
                         , m_seq_lengths{seq_lengths}
-                        , m_activation_f{activation_f}
-                        , m_activation_g{activation_g}
-                        , m_activation_h{activation_h}
-                        , m_input_forget{input_forget}
-                        , m_clip_threshold{clip_threshold}
+                        , m_attributes{attributes}
                     {
                     }
@@ -332,7 +255,7 @@ namespace ngraph
                         // W - The weight tensor. [num_directions, 4*hidden_size, input_size]
                         // R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
                         // B - The bias tensor for input gate. [num_directions, 8*hidden_size]
-                        // P - The weight tensor forr peepholes. [num_directions, 3*hidde_size]
+                        // P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
                         // ------ ACRONYMS ------
                         // i - input gate
                         // o - output gate
@@ -340,32 +263,11 @@ namespace ngraph
                         // c - cell gate
                         // t - time step (t-1 means previous time step)
                         // ------ VARIABLE NAMES ------
-                        // W       - W parameter weight matrix for input, output, forget, and
-                        //           cell gates.
-                        // R       - R recurrence weight matrix for input, output, forget, and
-                        //           cell gates.
-                        // Wb      - W bias vectors for input, output, forget, and cell gates.
-                        // Rb      - R bias vectors for input, output, forget, and cell gates.
-                        // b_W_R   - Bias vectors for input, output, forget, and cell gates.
-                        //           Concatenation of `[Wb, Rb]`.
-                        // p_[iof] - P peephole weight vector for respectively: input, output,
-                        //           and forget gates.
                         // H_t     - Hidden state vector at current time step.
                         // C_t     - Cell state vector at current time step.
                         // h_list  - The list of hidden states at all processed time steps.
-                        //
-                        // Xt_W    - Input sequence multiplied by weights tensor at current time
-                        //           step.
-                        // Ht_R    - Hidden state multiplied by weights tensor at current time step.
 
-                        NodeVector p_iof = ngraph::builder::split(m_P, 3);
-                        const auto& p_i = p_iof.at(0);
-                        const auto& p_o = p_iof.at(1);
-                        const auto& p_f = p_iof.at(2);
-
-                        NodeVector h_list;
-                        NodeVector b_W_R = ngraph::builder::split(m_B, 2);
-                        std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
+                        NodeVector h_list;
                         std::shared_ptr<ngraph::Node> H_t = m_initial_h;
                         std::shared_ptr<ngraph::Node> C_t = m_initial_c;
@@ -393,47 +295,26 @@ namespace ngraph
                         std::int32_t time_step{1};
                         for (const auto& in_x : in_seqs)
                         {
-                            // (.) - Denotes element-wise multiplication.
-                            // *   - Denotes dot product.
-
-                            // Xt*(W^T) -- for [iofc] gates.
-                            auto Xt_W = std::make_shared<ngraph::op::Dot>(
-                                in_x, ngraph::op::util::transpose(m_W));
-                            // Ht-1*(R^T) -- for [iofc] gates.
-                            auto Ht_R = std::make_shared<ngraph::op::Dot>(
-                                H_t, ngraph::op::util::transpose(m_R));
-                            // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
-                            auto gates = add(Xt_W, add(Ht_R, bias));
-
-                            NodeVector split_gates = ngraph::builder::split(gates, 4, -1);
-                            auto i = split_gates.at(0);
-                            auto o = split_gates.at(1);
-                            auto f = split_gates.at(2);
-                            auto c = split_gates.at(3);
-
-                            // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-                            i = m_activation_f(clip(add(i, mul(p_i, C_t)), m_clip_threshold));
-                            if (m_input_forget)
-                            {
-                                // Couple input with forget gate: 1 - i
-                                f = sub(ngraph::op::Constant::create(
-                                            i->get_element_type(),
-                                            i->get_shape(),
-                                            std::vector<float>(shape_size(i->get_shape()), 1.f)),
-                                        i);
-                            }
-                            else
-                            {
-                                // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-                                f = m_activation_f(clip(add(f, mul(p_f, C_t)), m_clip_threshold));
-                            }
-                            // ft (.) Ct-1 + it (.) ct
-                            auto C =
-                                add(mul(f, C_t), mul(i, m_activation_g(clip(c, m_clip_threshold))));
-                            // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-                            o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold));
-                            // ot (.) h(Ct)
-                            auto H = mul(o, m_activation_h(C));
+                            const std::shared_ptr<ngraph::Node>& lstm_cell =
+                                std::make_shared<ngraph::op::LSTMCell>(
+                                    in_x,
+                                    m_W,
+                                    m_R,
+                                    H_t,
+                                    C_t,
+                                    m_attributes.m_hidden_size,
+                                    m_B,
+                                    m_P,
+                                    m_attributes.m_activations,
+                                    m_attributes.m_activation_alpha,
+                                    m_attributes.m_activation_beta,
+                                    m_attributes.m_clip_threshold,
+                                    m_attributes.m_input_forget);
+
+                            const std::shared_ptr<ngraph::Node>& H =
+                                get_output_element(lstm_cell, 0);
+                            const std::shared_ptr<ngraph::Node>& C =
+                                get_output_element(lstm_cell, 1);
 
                             // Expand tensors with empty outermost dim, so we can later concatenate
                             // them.
@@ -528,41 +409,16 @@ namespace ngraph
                     }
 
                     std::shared_ptr<ngraph::Node> m_X;
-                    std::shared_ptr<ngraph::Node> m_W;
-                    std::shared_ptr<ngraph::Node> m_R;
-                    std::shared_ptr<ngraph::Node> m_B;
-                    std::shared_ptr<ngraph::Node> m_P;
-                    std::shared_ptr<ngraph::Node> m_initial_h;
-                    std::shared_ptr<ngraph::Node> m_initial_c;
-                    std::shared_ptr<ngraph::Node> m_seq_lengths;
-                    rnn::ActivationFunction m_activation_f;
-                    rnn::ActivationFunction m_activation_g;
-                    rnn::ActivationFunction m_activation_h;
-                    // For coupling input and forget gates.
-                    bool m_input_forget;
-                    // For clipping cell input in the range [-clip_threshold, clip_threshold].
-                    float m_clip_threshold;
+                    const std::shared_ptr<ngraph::Node>& m_W;
+                    const std::shared_ptr<ngraph::Node>& m_R;
+                    const std::shared_ptr<ngraph::Node>& m_B;
+                    const std::shared_ptr<ngraph::Node>& m_P;
+                    const std::shared_ptr<ngraph::Node>& m_initial_h;
+                    const std::shared_ptr<ngraph::Node>& m_initial_c;
+                    const std::shared_ptr<ngraph::Node>& m_seq_lengths;
+                    const LSTMAttributes& m_attributes;
                 };
 
-                rnn::ActivationFunction get_activation_function(const LSTMAttributes& attributes,
-                                                                std::size_t idx)
-                {
-                    rnn::ActivationFunction afunc =
-                        rnn::get_activation_func_by_name(attributes.m_activations.at(idx));
-                    // Set activation functions parameters (if any)
-                    if (attributes.m_activation_alpha.size() > idx)
-                    {
-                        afunc.set_alpha(attributes.m_activation_alpha.at(idx));
-                    }
-                    if (attributes.m_activation_beta.size() > idx)
-                    {
-                        afunc.set_beta(attributes.m_activation_beta.at(idx));
-                    }
-                    return afunc;
-                }
-
             } // anonymous namespace
 
             namespace set_1
@@ -572,14 +428,6 @@ namespace ngraph
                     LSTMNgInputMap input_map{node};
                     LSTMAttributes attributes{node};
 
-                    // Get activation functions.
-                    const rnn::ActivationFunction& activation_f =
-                        get_activation_function(attributes, 0);
-                    const rnn::ActivationFunction& activation_g =
-                        get_activation_function(attributes, 1);
-                    const rnn::ActivationFunction& activation_h =
-                        get_activation_function(attributes, 2);
-
                     NodeVector results;
                     if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
@@ -593,11 +441,7 @@ namespace ngraph
                                             input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
                                             input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
                                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
-                                            activation_f,
-                                            activation_g,
-                                            activation_h,
-                                            attributes.m_input_forget,
-                                            attributes.m_clip_threshold);
+                                            attributes);
                         results = lstm_fwd.run(
                             (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
                     }
@@ -625,11 +469,7 @@ namespace ngraph
                                             H.at(0),
                                             C.at(0),
                                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
-                                            activation_f,
-                                            activation_g,
-                                            activation_h,
-                                            attributes.m_input_forget,
-                                            attributes.m_clip_threshold);
+                                            attributes);
                         LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
                                                   W.at(1),
                                                   R.at(1),
@@ -638,11 +478,7 @@ namespace ngraph
                                                   H.at(1),
                                                   C.at(1),
                                                   input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
-                                                  activation_f,
-                                                  activation_g,
-                                                  activation_h,
-                                                  attributes.m_input_forget,
-                                                  attributes.m_clip_threshold);
+                                                  attributes);
 
                         NodeVector fwd_results{lstm_fwd.run()};
                         NodeVector rev_results{lstm_reversed.run(true)};
...
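For reference, the per-timestep recurrence that the deleted inline implementation computed, and that the fused LSTMCell node is now expected to reproduce, transcribed from the removed code and comments above. Here f, g and h are the three configured activation functions, (.) denotes element-wise multiplication and * denotes dot product; Wb/Rb are the per-gate bias slices and Pi/Pf/Po the peephole weight vectors. Each pre-activation is additionally clipped to [-clip, clip] when the clip attribute is non-zero.

    it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)    (or ft = 1 - it when input_forget is set)
    ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    Ct = ft (.) Ct-1 + it (.) ct
    ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    Ht = ot (.) h(Ct)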