Commit 5761f145 authored by Adam Rogowiec, committed by arogowie-intel

Update LSTM ONNX operator to use LSTMCell fused op.

parent ef1c5347
@@ -14,46 +14,32 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
namespace ngraph
{
@@ -63,61 +49,6 @@ namespace ngraph
{
namespace
{
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class LSTMInput
@@ -257,9 +188,8 @@ namespace ngraph
explicit LSTMAttributes(const Node& node)
: m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
, m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
, m_activations{node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh", "tanh"})}
// Default values for the activation functions are the same as for the
// corresponding ONNX operator.
, m_activation_alpha{node.get_attribute_value<std::vector<float>>(
@@ -293,20 +223,17 @@ namespace ngraph
{
public:
explicit LSTMForward(std::shared_ptr<ngraph::Node> X,
const std::shared_ptr<ngraph::Node>& W,
const std::shared_ptr<ngraph::Node>& R,
const std::shared_ptr<ngraph::Node>& B,
const std::shared_ptr<ngraph::Node>& P,
const std::shared_ptr<ngraph::Node>& initial_h,
const std::shared_ptr<ngraph::Node>& initial_c,
const std::shared_ptr<ngraph::Node>& seq_lengths,
const LSTMAttributes& attributes)
: m_X{X}
// Since this is a forward LSTM, we can squeeze the `num_directions` axis from the inputs.
, m_W{reshape::squeeze(W)}
, m_R{reshape::squeeze(R)}
, m_B{reshape::squeeze(B)}
@@ -314,11 +241,7 @@ namespace ngraph
, m_initial_h{reshape::squeeze(initial_h)}
, m_initial_c{reshape::squeeze(initial_c)}
, m_seq_lengths{seq_lengths}
, m_attributes{attributes}
{
}
@@ -332,7 +255,7 @@ namespace ngraph
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
@@ -340,32 +263,11 @@ namespace ngraph
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
NodeVector h_list;
std::shared_ptr<ngraph::Node> H_t = m_initial_h;
std::shared_ptr<ngraph::Node> C_t = m_initial_c;
@@ -393,47 +295,26 @@ namespace ngraph
std::int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
const std::shared_ptr<ngraph::Node>& lstm_cell =
std::make_shared<ngraph::op::LSTMCell>(
in_x,
m_W,
m_R,
H_t,
C_t,
m_attributes.m_hidden_size,
m_B,
m_P,
m_attributes.m_activations,
m_attributes.m_activation_alpha,
m_attributes.m_activation_beta,
m_attributes.m_clip_threshold,
m_attributes.m_input_forget);
const std::shared_ptr<ngraph::Node>& H =
get_output_element(lstm_cell, 0);
const std::shared_ptr<ngraph::Node>& C =
get_output_element(lstm_cell, 1);
// Expand tensors with an empty outermost dimension so that we can
// later concatenate them.
@@ -528,41 +409,16 @@ namespace ngraph
}
std::shared_ptr<ngraph::Node> m_X;
const std::shared_ptr<ngraph::Node>& m_W;
const std::shared_ptr<ngraph::Node>& m_R;
const std::shared_ptr<ngraph::Node>& m_B;
const std::shared_ptr<ngraph::Node>& m_P;
const std::shared_ptr<ngraph::Node>& m_initial_h;
const std::shared_ptr<ngraph::Node>& m_initial_c;
const std::shared_ptr<ngraph::Node>& m_seq_lengths;
const LSTMAttributes& m_attributes;
};
} // anonymous namespace
namespace set_1
@@ -572,14 +428,6 @@ namespace ngraph
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
NodeVector results;
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
@@ -593,11 +441,7 @@ namespace ngraph
input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
attributes);
results = lstm_fwd.run(
(attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
}
@@ -625,11 +469,7 @@ namespace ngraph
H.at(0),
C.at(0),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
attributes);
LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(1),
R.at(1),
@@ -638,11 +478,7 @@ namespace ngraph
H.at(1),
C.at(1),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
attributes);
NodeVector fwd_results{lstm_fwd.run()};
NodeVector rev_results{lstm_reversed.run(true)};
......