Commit ef1c5347 authored by Adam Rogowiec

LSTM cell fused operator.

parent 03dba84d
@@ -302,12 +302,16 @@ set (SRC
op/fused/grn.hpp
op/fused/group_conv.hpp
op/fused/group_conv.cpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/normalize.cpp
op/fused/normalize.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/fused/rnn_cell_base.cpp
op/fused/rnn_cell_base.hpp
op/fused/scale_shift.cpp
op/fused/scale_shift.hpp
op/fused/space_to_depth.cpp
@@ -320,6 +324,8 @@ set (SRC
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/activation_functions.cpp
op/util/activation_functions.hpp
op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp
@@ -189,8 +189,6 @@ add_library(onnx_import STATIC
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
@@ -103,6 +103,7 @@
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <functional>
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// ------------- HELPER FUNCTIONS ---------------------------------------------
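// The helpers below broadcast both operands numpy-style before constructing the
// corresponding element-wise arithmetic node.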
static shared_ptr<Node> add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
static shared_ptr<Node> sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
static shared_ptr<Node> mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
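// Clamps each element of `data` to [-threshold, threshold]; a threshold of zero
// means clipping is disabled and `data` is returned unchanged.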
static shared_ptr<Node> clip(const shared_ptr<Node>& data, float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
// ------------- LSTM_CELL ----------------------------------------------------
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size)
: LSTMCell(X,
W,
R,
H_t,
C_t,
hidden_size,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
0.f,
false)
{
}
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip,
bool input_forget)
: FusedOp("LSTMCell", {X, W, R, H_t, C_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_X{X}
, m_W{W}
, m_R{R}
, m_H_t{H_t}
, m_C_t{C_t}
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
constructor_validate_and_infer_types();
// Normally we would split B into Wb and Rb and add them; however, here they are
// all zeros, so we just initialize the bias with zeros of the appropriate shape.
m_bias = ngraph::op::Constant::create(element::f32,
Shape{m_gates_count * get_hidden_size()},
vector<float>(m_gates_count * get_hidden_size(), 0.f));
m_P = ngraph::op::Constant::create(element::f32,
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
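// Split the peephole weights into per-gate vectors: input, output, forget.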
m_p_iof = builder::split(m_P, m_peepholes_count);
}
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size,
const shared_ptr<Node>& B,
const shared_ptr<Node>& P,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip,
bool input_forget)
: FusedOp("LSTMCell", {X, W, R, H_t, C_t, B, P})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_X{X}
, m_W{W}
, m_R{R}
, m_H_t{H_t}
, m_C_t{C_t}
, m_P{P}
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
// If B is not supplied, the bias is all zeros, so just initialize it with zeros
// of the appropriate shape.
if (!B)
{
m_bias =
ngraph::op::Constant::create(element::f32,
Shape{m_gates_count * get_hidden_size()},
vector<float>(m_gates_count * get_hidden_size(), 0.f));
}
// Split B into Wb and Rb and add them.
else
{
NodeVector b_W_R = builder::split(B, 2);
m_bias = b_W_R.at(0) + b_W_R.at(1);
}
if (!m_P)
{
m_P =
ngraph::op::Constant::create(element::f32,
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
}
constructor_validate_and_infer_types();
m_p_iof = builder::split(m_P, m_peepholes_count);
}
void op::LSTMCell::pre_validate_and_infer_types()
{
}
NodeVector op::LSTMCell::decompose_op() const
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to those used in the ONNX documentation.
//
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// ------ VARIABLE NAMES ------
// p_[iof] - P peephole weight vectors for, respectively, the input, output,
// and forget gates.
// Xt_W - Input sequence multiplied by the weight tensor at the current
// time step.
// Ht_R - Hidden state multiplied by the recurrence weight tensor at the
// current time step.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
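// ------ REFERENCE EQUATIONS (per the ONNX LSTM specification) ------
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// Ct = ft (.) Ct-1 + it (.) ct
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)
// With input_forget set, the gates are coupled instead: ft = 1 - it.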
const auto& p_i = m_p_iof.at(0);
const auto& p_o = m_p_iof.at(1);
const auto& p_f = m_p_iof.at(2);
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = std::make_shared<ngraph::op::Dot>(m_X, ngraph::op::util::transpose(m_W));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = std::make_shared<ngraph::op::Dot>(m_H_t, ngraph::op::util::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, m_bias));
NodeVector split_gates = builder::split(gates, 4, -1);
auto i = split_gates.at(0);
auto o = split_gates.at(1);
auto f = split_gates.at(2);
auto c = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i = m_activation_f(clip(add(i, mul(p_i, m_C_t)), get_clip()));
if (m_input_forget)
{
// Couple input with forget gate: 1 - i
f = sub(ngraph::op::Constant::create(i->get_element_type(),
i->get_shape(),
std::vector<float>(shape_size(i->get_shape()), 1.f)),
i);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f = m_activation_f(clip(add(f, mul(p_f, m_C_t)), get_clip()));
}
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f, m_C_t), mul(i, m_activation_g(clip(c, get_clip()))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o = m_activation_f(clip(add(o, mul(p_o, C)), get_clip()));
// ot (.) h(Ct)
auto H = mul(o, m_activation_h(C));
return {H, C};
}
shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 5)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip(),
m_input_forget);
}
else if (new_args.size() == 7)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
new_args.at(5),
new_args.at(6),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip(),
m_input_forget);
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for a single LSTM cell node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
/// Note that this class represents only a single *cell*, not a whole LSTM *layer*.
///
class LSTMCell : public util::FusedOp, public RNNCellBase
{
public:
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool input_forget);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for gates (concatenation of `Wb` and `Rb`)
/// with shape: [8*hidden_size].
/// \param[in] P The weight tensor for peepholes with shape:
/// [3*hidden_size] - 3 corresponds to the i, o, f gates only.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B = nullptr,
const std::shared_ptr<Node>& P = nullptr,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool input_forget = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_input_forget() const { return m_input_forget; }
private:
///
/// \brief The input data tensor. Shape: [batch_size, input_size].
///
const std::shared_ptr<Node>& m_X;
///
/// \brief The weight tensor. Shape: [4*hidden_size, input_size].
///
const std::shared_ptr<Node>& m_W;
///
/// \brief The recurrence weight tensor. Shape: [4*hidden_size, hidden_size].
///
const std::shared_ptr<Node>& m_R;
///
/// \brief The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
///
const std::shared_ptr<Node>& m_H_t;
///
/// \brief The cell state tensor at current time step. Shape: [batch_size, hidden_size].
///
const std::shared_ptr<Node>& m_C_t;
///
/// \brief The weight tensor for peepholes with shape: [3*hidden_size] - 3
/// corresponds to the i, o, f gates only.
///
std::shared_ptr<Node> m_P;
///
/// \brief The Activation function f.
///
ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
ActivationFunction m_activation_g;
///
/// \brief The Activation function h.
///
ActivationFunction m_activation_h;
///
/// \brief Controls whether to couple input and forget gates.
///
bool m_input_forget = false;
static constexpr std::size_t m_gates_count{4};
static constexpr std::size_t m_peepholes_count{3};
///
/// \brief Peephole weight vectors for, respectively, the input, output, and forget gates.
///
NodeVector m_p_iof;
///
/// \brief Sum of biases (weight and recurrence) for input, output, forget, and cell gates.
///
/// Sum of `[Wb, Rb]`.
///
std::shared_ptr<Node> m_bias;
};
}
}
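For context, a minimal construction sketch using the simplest constructor above; the `op::Parameter` inputs, dimensions, and the `make_example_cell` helper are illustrative only, not part of this commit:

```cpp
#include <memory>

#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::LSTMCell> make_example_cell()
{
    // Illustrative dimensions only.
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 4;

    auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, input_size});
    auto R = std::make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, hidden_size});
    auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto C_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

    // Decomposes into two outputs: the next hidden state H and cell state C.
    return std::make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
}
```

With the defaulted arguments this yields the standard sigmoid/tanh/tanh activations, no clipping, and uncoupled input/forget gates.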
@@ -27,6 +27,7 @@ NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
@@ -74,6 +74,7 @@
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
@@ -1055,6 +1056,29 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::LSTMCell:
{
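// Attribute layout mirrors write() below; args order matches the seven-input
// constructor: X, W, R, H_t, C_t, then B and P.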
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
args[5],
args[6],
activations,
activation_alpha,
activation_beta,
clip,
input_forget);
break;
}
case OP_TYPEID::Max:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
@@ -1979,6 +2003,17 @@ static json write(const Node& n, bool binary_constant_data)
node["nsize"] = tmp->get_nsize();
break;
}
case OP_TYPEID::LSTMCell:
{
auto tmp = dynamic_cast<const op::LSTMCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);