Commit 698aeb2f authored by Tomasz Socha, committed by Scott Cyphers

[FUSED] Add lstm sequence operator (#3595)

* unfold attributes

* Remove unnecessary if

* Rename run() -> lstm_pass()

* Unify usage of LSTMForward for uni- and bi-directional LSTM

* Unify LSTMForward return values for uni- and bi-directional LSTM

* Dirty moving LSTMForward into fused directory

* Accept lstm direction as string instead of enum

* Fused op which uses decompose_op in onnx_importer

* Rename LSTMForward -> LSTMSequence

* Split LSTMSequence to cpp and hpp

* Remove LSTMDirection enum

* Add getters for class fields

* Adjust constructors

* Add direction validation.

* Add support of LSTMSequence op in serializer

* Reorder fused op input order

* Style fix

* Fix for reorder of inputs

* Use NodeTypeInfo instead of static string

* Node -> value in doc

* Add doc for prepare_input method

* Fix shape inference

* Use enum instead of string for direction

* Add Type prop unit test

* Fix style
parent d92ef6b6
@@ -342,6 +342,8 @@ set (SRC
op/fused/layer_norm.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/lstm_sequence.cpp
op/fused/lstm_sequence.hpp
op/fused/matmul.cpp
op/fused/matmul.hpp
op/fused/mvn.cpp
......
@@ -138,6 +138,7 @@ namespace ngraph
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::LSTMSequence::type_info;
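// Decomposes LSTMSequence into a per-time-step chain of LSTMCell ops.
// The returned NodeVector follows the ONNX LSTM output layout:
//   Y   - all hidden states, [seq_length, num_directions, batch_size, hidden_size]
//   Y_h - last hidden state, [num_directions, batch_size, hidden_size]
//   Y_c - last cell state,   [num_directions, batch_size, hidden_size]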
NodeVector op::LSTMSequence::decompose_op() const
{
NodeVector results;
if (m_direction == direction::FORWARD || m_direction == direction::REVERSE)
{
results = lstm_pass(m_direction == direction::REVERSE);
}
if (m_direction == direction::BIDIRECTIONAL)
{
NodeVector fwd_results{lstm_pass()};
NodeVector rev_results{lstm_pass(true)};
// Stack together respective outputs from both forward and reverse passes.
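// Y concatenates along its num_directions axis (1); Y_h and Y_c along theirs (0).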
shared_ptr<Node> Y{
make_shared<op::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
shared_ptr<Node> Y_h{
make_shared<op::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
shared_ptr<Node> Y_c{
make_shared<op::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
results = NodeVector{Y, Y_h, Y_c};
}
return results;
}
shared_ptr<Node> op::LSTMSequence::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
shared_ptr<Node> op::LSTMSequence::get_masked_node(const shared_ptr<Node>& data,
int32_t time_step,
size_t batch_axis,
const shared_ptr<Node>& default_value) const
{
shared_ptr<Node> mask_value = default_value;
// Create zero mask value node.
if (!mask_value)
{
mask_value = op::Constant::create(data->get_element_type(),
data->get_shape(),
vector<float>(shape_size(data->get_shape()), 0.f));
}
// Create predicate nodes. The condition is whether the current time step value
// is greater than the sequence length for the respective batch input.
shared_ptr<Node> curr_time_step_node = op::Constant::create(
element::i32, data->get_shape(), vector<int32_t>(shape_size(data->get_shape()), time_step));
shared_ptr<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1);
// Create mask node deciding whether or not to mask batch data.
shared_ptr<Node> mask_condition =
make_shared<op::Greater>(curr_time_step_node, batch_seq_length);
// Select values depending on mask_condition.
// Select(<condition>, <true_value>, <false_value>)
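// Illustrative example: for sequence_lengths = {3, 5} and time_step = 4,
// the first batch entry satisfies 4 > 3 and is replaced with mask_value,
// while the second batch entry (4 > 5 is false) keeps its data.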
return make_shared<op::Select>(mask_condition, mask_value, data);
}
NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to those used in the ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
NodeVector h_list;
shared_ptr<Node> X = input_value(0).get_node_shared_ptr();
shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse);
shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse);
shared_ptr<Node> seq_lengths = input_value(3).get_node_shared_ptr();
shared_ptr<Node> W = prepare_input(input_value(4), is_reverse);
shared_ptr<Node> R = prepare_input(input_value(5), is_reverse);
shared_ptr<Node> B = prepare_input(input_value(6), is_reverse);
shared_ptr<Node> P = prepare_input(input_value(7), is_reverse);
if (is_reverse)
{
X = make_shared<op::ReverseSequence>(X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
NodeVector in_seqs = builder::split(X, X->get_shape().at(0));
for (auto& in_x : in_seqs)
{
// Remove the first (size 1) dimension left over from the split above.
in_x = builder::squeeze(in_x);
}
int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
W,
R,
H_t,
C_t,
m_hidden_size,
B,
P,
m_activations,
m_activations_alpha,
m_activations_beta,
m_clip_threshold,
m_input_forget);
shared_ptr<Node> H = get_output_element(lstm_cell, 0);
shared_ptr<Node> C = get_output_element(lstm_cell, 1);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
// Mask hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequences shorter
// than the current time_step.
h_list.push_back(get_masked_node(builder::expand_dims(H), time_step, 1));
// The reference implementation in ONNX Runtime doesn't mask the values of
// the Y_h and Y_c outputs, so here we make sure that only the appropriate
// batches (with respect to their sequence lengths) are updated. Batches
// with shorter sequences preserve their last values.
H_t = get_masked_node(H, time_step, 0, H_t);
C_t = get_masked_node(C, time_step, 0, C_t);
time_step++;
}
// The tensor that concatenates all the intermediate hidden state output values.
// It has shape [seq_length, batch_size, hidden_size].
shared_ptr<Node> Y{make_shared<op::Concat>(h_list, 0)};
// Get back the original order of the output data.
if (is_reverse)
{
Y = make_shared<op::ReverseSequence>(Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
// Expand Y so that it has the expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = builder::expand_dims(Y, 1);
// Expand H_t and C_t so that they have the expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = builder::expand_dims(H_t);
auto Y_c = builder::expand_dims(C_t);
return {Y, Y_h, Y_c};
}
shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reverse) const
{
// In bidirectional mode inputs are stacked together, so we must split them.
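// e.g. a bidirectional W of shape [2, 4*hidden_size, input_size] splits into
// two [1, 4*hidden_size, input_size] halves; the forward pass takes half 0
// and the reverse pass takes half 1.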
shared_ptr<Node> tmp = node.get_node_shared_ptr();
if (m_direction == direction::BIDIRECTIONAL)
{
tmp = builder::split(node, 2).at(is_reverse ? 1 : 0);
}
// Now that a single direction is selected, we can squeeze the `num_directions` axis from the inputs.
return builder::squeeze(tmp);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for the LSTM sequence node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
class LSTMSequence : public util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LSTMSequence", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LSTMSequence() = default;
enum class direction
{
FORWARD,
REVERSE,
BIDIRECTIONAL
};
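// These values correspond to the ONNX LSTM `direction` attribute values
// "forward", "reverse" and "bidirectional".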
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& sequence_lengths,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const Output<Node>& P,
const std::int64_t hidden_size,
const direction lstm_direction,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
"tanh",
"tanh"},
const float clip_threshold = 0,
const bool input_forget = false)
: FusedOp(
{X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B, P})
, m_activations_alpha(activations_alpha)
, m_activations_beta(activations_beta)
, m_activations(activations)
, m_clip_threshold(clip_threshold)
, m_direction(lstm_direction)
, m_hidden_size(hidden_size)
, m_input_forget(input_forget)
{
constructor_validate_and_infer_types();
}
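// Delegating constructor for the common case without peephole weights:
// P defaults to a zero constant of shape [num_directions, 3*hidden_size].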
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& sequence_lengths,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const std::int64_t hidden_size,
const direction lstm_direction,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
"tanh",
"tanh"},
const float clip_threshold = 0,
const bool input_forget = false)
: LSTMSequence(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
Constant::create(
element::f32,
Shape{(lstm_direction == direction::BIDIRECTIONAL ? 2UL : 1UL),
3UL * static_cast<size_t>(hidden_size)},
std::vector<float>{0.f}),
hidden_size,
lstm_direction,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget)
{
}
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::vector<float> get_activations_alpha() const { return m_activations_alpha; }
std::vector<float> get_activations_beta() const { return m_activations_beta; }
std::vector<std::string> get_activations() const { return m_activations; }
float get_clip_threshold() const { return m_clip_threshold; }
direction get_direction() const { return m_direction; }
std::int64_t get_hidden_size() const { return m_hidden_size; }
bool get_input_forget() const { return m_input_forget; }
private:
///
/// \brief Gets the masked value according to the sequence length in a batch.
///
/// \note Zeros out values or sets them to the default value for inputs with
/// sequence length shorter than the currently processed time step.
///
/// \param[in] data The input value.
/// \param[in] time_step The current time step denoting sequence length.
/// \param[in] batch_axis The batch axis index of data tensor.
/// \param[in] default_value The default value for masked elements.
///
/// \return The masked value.
///
std::shared_ptr<Node> get_masked_node(const std::shared_ptr<Node>& data,
std::int32_t time_step,
std::size_t batch_axis = 0,
const std::shared_ptr<Node>& default_value = {
nullptr}) const;
NodeVector lstm_pass(bool is_reverse = false) const;
// Split (in bidirectional mode) and squeeze input data to remove the 'num_directions' dimension.
std::shared_ptr<Node> prepare_input(Output<Node> node, bool is_reverse) const;
const std::vector<float> m_activations_alpha;
const std::vector<float> m_activations_beta;
const std::vector<std::string> m_activations;
const float m_clip_threshold;
const direction m_direction;
const std::int64_t m_hidden_size;
const bool m_input_forget;
};
} // namespace op
} // namespace ngraph
@@ -40,6 +40,7 @@ NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
......
@@ -82,6 +82,7 @@
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
@@ -1556,6 +1557,52 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
input_forget);
break;
}
case OP_TYPEID::LSTMSequence:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip_threshold").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
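// With 8 inputs the optional peephole tensor P was serialized;
// otherwise the 7-input constructor substitutes a zero-filled P.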
if (args.size() == 8)
{
node = make_shared<op::LSTMSequence>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7],
hidden_size,
direction,
activations_alpha,
activations_beta,
activations,
clip,
input_forget);
}
else
{
node = make_shared<op::LSTMSequence>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
hidden_size,
direction,
activations_alpha,
activations_beta,
activations,
clip,
input_forget);
}
break;
}
case OP_TYPEID::MatMul:
{
bool transpose_a = node_js.at("transpose_a").get<bool>();
@@ -2953,6 +3000,18 @@ json JSONSerializer::serialize_node(const Node& n)
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::LSTMSequence:
{
auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
node["direction"] = tmp->get_direction();
node["hidden_size"] = tmp->get_hidden_size();
node["clip_threshold"] = tmp->get_clip_threshold();
node["activations"] = tmp->get_activations();
node["activations_alpha"] = tmp->get_activations_alpha();
node["activations_beta"] = tmp->get_activations_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::MatMul:
{
auto tmp = static_cast<const op::MatMul*>(&n);
......
@@ -131,6 +131,7 @@ set(SRC
type_prop/layer_norm.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/lstm_sequence.cpp
type_prop/matmul.cpp
type_prop/max_pool.cpp
type_prop/mvn.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, lstm_sequence)
{
const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 24});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto hidden_size = 3;
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 1, 2, 3}));
}
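// A hedged sketch (not part of this commit) of a bidirectional variant of the
// test above; the doubled leading dimensions assume num_directions = 2 and the
// ONNX input layout documented in lstm_sequence.cpp.
TEST(type_prop, lstm_sequence_bidirectional)
{
    const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{2, 12, 4});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{2, 12, 3});
    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 24});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
    const auto hidden_size = 3;
    const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
                                                             initial_hidden_state,
                                                             initial_cell_state,
                                                             sequence_lengths,
                                                             W,
                                                             R,
                                                             B,
                                                             hidden_size,
                                                             op::LSTMSequence::direction::BIDIRECTIONAL);
    EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
    // Y has shape [seq_length, num_directions, batch_size, hidden_size].
    EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 2, 2, 3}));
}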