Commit 698aeb2f authored by Tomasz Socha, committed by Scott Cyphers

[FUSED] Add lstm sequence operator (#3595)

* unfold attributes

* Remove unnecessary if

* Rename run() -> lstm_pass()

* Unify usage of LSTMForward for uni- and bidirectional LSTM

* Unify LSTMForward return values for uni- and bidirectional LSTM

* Dirty moving LSTMForward into fused directory

* Accept lstm direction as string instead of enum

* Fused op which uses decompose_op in onnx_importer

* Rename LSTMForward -> LSTMSequence

* Split LSTMSequence to cpp and hpp

* Remove LSTMDirection enum

* Add getters for class fields

* Adjust constructors

* Add direction validation.

* Add support of LSTMSequence op in serializer

* Reorder fused op input order

* Style fix

* Fix for reorder of inputs

* Use NodeTypeInfo instead of static string

* Node -> value in doc

* Add doc for prepare_input method

* Fix shape inference

* Use enum instead of string for direction

* Add Type prop unit test

* Fix style
parent d92ef6b6
@@ -342,6 +342,8 @@ set (SRC
op/fused/layer_norm.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/lstm_sequence.cpp
op/fused/lstm_sequence.hpp
op/fused/matmul.cpp
op/fused/matmul.hpp
op/fused/mvn.cpp
...
@@ -14,7 +14,6 @@
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <map>
@@ -22,23 +21,13 @@
#include <string>
#include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
@@ -168,32 +157,6 @@ namespace ngraph
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class LSTMDirection
{
LSTM_DIRECTION_FORWARD,
LSTM_DIRECTION_REVERSE,
LSTM_DIRECTION_BIDIRECTIONAL,
LSTM_DIRECTION_UNKNOWN,
};
LSTMDirection getLSTMDirection(const std::string& s)
{
if (s == "forward")
{
return LSTMDirection::LSTM_DIRECTION_FORWARD;
}
if (s == "reverse")
{
return LSTMDirection::LSTM_DIRECTION_REVERSE;
}
if (s == "bidirectional")
{
return LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL;
}
return LSTMDirection::LSTM_DIRECTION_UNKNOWN;
}
struct LSTMAttributes
{
explicit LSTMAttributes(const Node& node)
@@ -211,223 +174,34 @@ namespace ngraph
node.get_attribute_value<std::int64_t>("input_forget", 0))}
{
m_clip_threshold = std::abs(m_clip_threshold);
std::string direction{ngraph::to_lower(
    node.get_attribute_value<std::string>("direction", {"forward"}))};
ASSERT_VALID_ARGUMENT(node,
                      getLSTMDirection(direction) !=
                          LSTMDirection::LSTM_DIRECTION_UNKNOWN)
    << "Provided attribute \"direction\" value is incorrect: " << direction;
m_direction = getLSTMDirection(direction);
std::string direction = ngraph::to_lower(
    node.get_attribute_value<std::string>("direction", "forward"));
NGRAPH_CHECK(direction == "bidirectional" || direction == "forward" ||
                 direction == "reverse",
             "Provided direction: ",
             direction,
             " is invalid");
if (direction == "forward")
{
    m_direction = ngraph::op::LSTMSequence::direction::FORWARD;
}
else if (direction == "reverse")
{
    m_direction = ngraph::op::LSTMSequence::direction::REVERSE;
}
else // (direction == "bidirectional")
{
    m_direction = ngraph::op::LSTMSequence::direction::BIDIRECTIONAL;
}
}
LSTMDirection m_direction;
ngraph::op::LSTMSequence::direction m_direction;
std::int64_t m_hidden_size;
float m_clip_threshold;
std::vector<std::string> m_activations;
std::vector<float> m_activation_alpha;
std::vector<float> m_activation_beta;
bool m_input_forget;
};
class LSTMForward
{
public:
explicit LSTMForward(const std::shared_ptr<ngraph::Node>& X,
const std::shared_ptr<ngraph::Node>& W,
const std::shared_ptr<ngraph::Node>& R,
const std::shared_ptr<ngraph::Node>& B,
const std::shared_ptr<ngraph::Node>& P,
const std::shared_ptr<ngraph::Node>& initial_h,
const std::shared_ptr<ngraph::Node>& initial_c,
const std::shared_ptr<ngraph::Node>& seq_lengths,
const LSTMAttributes& attributes)
: m_X{X} // Since we have forward LSTM we can squeeze `num_directions` axis
// from inputs.
, m_W(builder::squeeze(W))
, m_R(builder::squeeze(R))
, m_B(builder::squeeze(B))
, m_P(builder::squeeze(P))
, m_initial_h(builder::squeeze(initial_h))
, m_initial_c(builder::squeeze(initial_c))
, m_seq_lengths(seq_lengths)
, m_attributes(attributes)
{
}
NodeVector run(bool reverse = false)
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size,
// hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
NodeVector h_list;
std::shared_ptr<ngraph::Node> H_t = m_initial_h;
std::shared_ptr<ngraph::Node> C_t = m_initial_c;
if (reverse)
{
    m_X = std::make_shared<ngraph::op::ReverseSequence>(
        m_X, m_seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
NodeVector in_seqs{};
if (m_X->get_shape().at(0) != 1)
{
    in_seqs = ngraph::builder::split(m_X, m_X->get_shape().at(0));
}
else
{
    in_seqs = NodeVector{m_X};
}
for (auto& in_x : in_seqs)
{
    // Remove the first (empty) dim left by the split above.
    in_x = builder::squeeze(in_x);
}
std::int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
std::shared_ptr<ngraph::Node> lstm_cell =
std::make_shared<ngraph::op::LSTMCell>(
in_x,
m_W,
m_R,
H_t,
C_t,
m_attributes.m_hidden_size,
m_B,
m_P,
m_attributes.m_activations,
m_attributes.m_activation_alpha,
m_attributes.m_activation_beta,
m_attributes.m_clip_threshold,
m_attributes.m_input_forget);
std::shared_ptr<ngraph::Node> H = get_output_element(lstm_cell, 0);
std::shared_ptr<ngraph::Node> C = get_output_element(lstm_cell, 1);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
// Mask the hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequences shorter
// than the current time_step.
h_list.push_back(
get_masked_node(builder::expand_dims(H), time_step, 1));
// The reference implementation in ONNX Runtime doesn't mask the values
// of the Y_h and Y_c outputs, so here we make sure that only the
// appropriate batches (with respect to their sequence lengths) are
// updated. Batches with shorter sequences preserve their last values.
H_t = get_masked_node(H, time_step, 0, H_t);
C_t = get_masked_node(C, time_step, 0, C_t);
time_step++;
}
// The tensor that concatenates all the intermediate output values of the
// hidden state. It has shape [seq_length, batch_size, hidden_size].
std::shared_ptr<ngraph::Node> Y{
std::make_shared<ngraph::op::Concat>(h_list, 0)};
// Get back the original order of the output data.
if (reverse)
{
Y = std::make_shared<ngraph::op::ReverseSequence>(
Y, m_seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = builder::expand_dims(Y, 1);
// Expand H_t and C_t so that they have the expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = builder::expand_dims(H_t);
auto Y_c = builder::expand_dims(C_t);
return {Y, Y_h, Y_c};
}
private:
///
/// \brief Gets the masked node according to sequence length in a batch.
///
/// \note Zeros out values or sets them to default value for inputs with
/// sequence length shorter than the currently processed time step.
///
/// \param[in] data The input node.
/// \param[in] time_step The current time step denoting sequence length.
/// \param[in] batch_axis The batch axis index of data tensor.
/// \param[in] default_value The default value for masked elements.
///
/// \return The masked node.
///
std::shared_ptr<ngraph::Node> get_masked_node(
const std::shared_ptr<ngraph::Node>& data,
std::int32_t time_step,
std::size_t batch_axis = 0,
const std::shared_ptr<ngraph::Node>& default_value = {nullptr})
{
std::shared_ptr<ngraph::Node> mask_value = default_value;
// Create zero mask value node.
if (!mask_value)
{
mask_value = ngraph::op::Constant::create(
data->get_element_type(),
data->get_shape(),
std::vector<float>(shape_size(data->get_shape()), 0.f));
}
// Create predicate nodes. The condition is whether current time step value
// is greater than sequence length for respective batch inputs.
std::shared_ptr<ngraph::Node> curr_time_step_node =
ngraph::op::Constant::create(
element::i32,
data->get_shape(),
std::vector<std::int32_t>(shape_size(data->get_shape()),
time_step));
std::shared_ptr<ngraph::Node> batch_seq_length =
ngraph::op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, m_seq_lengths, batch_axis)
.at(1);
// Create mask node deciding whether or not to mask batch data.
std::shared_ptr<ngraph::Node> mask_condition =
std::make_shared<ngraph::op::Greater>(curr_time_step_node,
batch_seq_length);
// Select values depending on mask_condition.
// Select(<condition>, <true_value>, <false_value>)
return std::make_shared<ngraph::op::Select>(
mask_condition, mask_value, data);
}
std::shared_ptr<ngraph::Node> m_X;
std::shared_ptr<ngraph::Node> m_W;
std::shared_ptr<ngraph::Node> m_R;
std::shared_ptr<ngraph::Node> m_B;
std::shared_ptr<ngraph::Node> m_P;
std::shared_ptr<ngraph::Node> m_initial_h;
std::shared_ptr<ngraph::Node> m_initial_c;
std::shared_ptr<ngraph::Node> m_seq_lengths;
const LSTMAttributes& m_attributes;
};
} // anonymous namespace
@@ -439,73 +213,25 @@ namespace ngraph
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
NodeVector results;
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
    attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE)
{
    LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
                         input_map.at(LSTMInput::LSTM_INPUT_W),
                         input_map.at(LSTMInput::LSTM_INPUT_R),
                         input_map.at(LSTMInput::LSTM_INPUT_B),
                         input_map.at(LSTMInput::LSTM_INPUT_P),
                         input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
                         input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
                         input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
                         attributes);
    results = lstm_fwd.run(
        (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
}
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL)
{
// In bidirectional mode weights are stacked together, so we must split
// them.
NodeVector W{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
NodeVector R{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
NodeVector B{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
NodeVector P{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
NodeVector H{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
NodeVector C{
ngraph::builder::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};
LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(0),
R.at(0),
B.at(0),
P.at(0),
H.at(0),
C.at(0),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
attributes);
LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(1),
R.at(1),
B.at(1),
P.at(1),
H.at(1),
C.at(1),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
attributes);
NodeVector fwd_results{lstm_fwd.run()};
NodeVector rev_results{lstm_reversed.run(true)};
// Stack together respective outputs from both forward and reverse passes.
std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
std::shared_ptr<ngraph::Node> Y_h{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
std::shared_ptr<ngraph::Node> Y_c{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
results = NodeVector{Y, Y_h, Y_c};
}
return results;
auto lstmSequence = std::make_shared<ngraph::op::LSTMSequence>(
    input_map.at(LSTMInput::LSTM_INPUT_X),
    input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
    input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
    input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
    input_map.at(LSTMInput::LSTM_INPUT_W),
    input_map.at(LSTMInput::LSTM_INPUT_R),
    input_map.at(LSTMInput::LSTM_INPUT_B),
    input_map.at(LSTMInput::LSTM_INPUT_P),
    attributes.m_hidden_size,
    attributes.m_direction,
    attributes.m_activation_alpha,
    attributes.m_activation_beta,
    attributes.m_activations,
    attributes.m_clip_threshold,
    attributes.m_input_forget);
return {std::make_shared<ngraph::op::GetOutputElement>(lstmSequence, 0),
        std::make_shared<ngraph::op::GetOutputElement>(lstmSequence, 1),
        std::make_shared<ngraph::op::GetOutputElement>(lstmSequence, 2)};
}
} // namespace set_1
...
@@ -138,6 +138,7 @@ namespace ngraph
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::LSTMSequence::type_info;
NodeVector op::LSTMSequence::decompose_op() const
{
NodeVector results;
if (m_direction == direction::FORWARD || m_direction == direction::REVERSE)
{
results = lstm_pass(m_direction == direction::REVERSE);
}
if (m_direction == direction::BIDIRECTIONAL)
{
NodeVector fwd_results{lstm_pass()};
NodeVector rev_results{lstm_pass(true)};
// Stack together respective outputs from both forward and reverse passes.
shared_ptr<Node> Y{
make_shared<op::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
shared_ptr<Node> Y_h{
make_shared<op::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
shared_ptr<Node> Y_c{
make_shared<op::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
results = NodeVector{Y, Y_h, Y_c};
}
return results;
}
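The bidirectional branch above stacks the per-direction results on different axes: Y on axis 1 (num_directions) and Y_h/Y_c on axis 0. A shape-only sketch of that stacking in plain C++ (Shape and concat_shape are local helpers for illustration, not the library's types):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

// Concatenating two equal-rank tensors along `axis` just sums that dimension.
Shape concat_shape(Shape a, const Shape& b, std::size_t axis)
{
    assert(a.size() == b.size());
    a[axis] += b[axis];
    return a;
}

int main()
{
    const std::size_t S = 5, B = 2, H = 3; // seq_length, batch_size, hidden_size
    Shape Y_fwd{S, 1, B, H}, Y_rev{S, 1, B, H};
    Shape Yh_fwd{1, B, H}, Yh_rev{1, B, H};
    // Y is stacked on the num_directions axis (1), Y_h/Y_c on axis 0.
    assert((concat_shape(Y_fwd, Y_rev, 1) == Shape{S, 2, B, H}));
    assert((concat_shape(Yh_fwd, Yh_rev, 0) == Shape{2, B, H}));
    return 0;
}
```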
shared_ptr<Node> op::LSTMSequence::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
shared_ptr<Node> op::LSTMSequence::get_masked_node(const shared_ptr<Node>& data,
int32_t time_step,
size_t batch_axis,
const shared_ptr<Node>& default_value) const
{
shared_ptr<Node> mask_value = default_value;
// Create zero mask value node.
if (!mask_value)
{
mask_value = op::Constant::create(data->get_element_type(),
data->get_shape(),
vector<float>(shape_size(data->get_shape()), 0.f));
}
// Create predicate nodes. The condition is whether current time step value
// is greater than sequence length for respective batch inputs.
shared_ptr<Node> curr_time_step_node = op::Constant::create(
element::i32, data->get_shape(), vector<int32_t>(shape_size(data->get_shape()), time_step));
shared_ptr<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1);
// Create mask node deciding whether or not to mask batch data.
shared_ptr<Node> mask_condition =
make_shared<op::Greater>(curr_time_step_node, batch_seq_length);
// Select values depending on mask_condition.
// Select(<condition>, <true_value>, <false_value>)
return make_shared<op::Select>(mask_condition, mask_value, data);
}
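The Select at the end implements per-batch masking: for every batch whose sequence length is shorter than the current time step, the mask value wins. A minimal standalone illustration of that predicate in plain C++ (hypothetical sample values, no nGraph dependency):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::vector<int> seq_lengths{3, 1}; // per-batch sequence lengths
    const std::vector<float> H{0.5f, 0.7f};   // freshly computed hidden values
    const int time_step = 2;
    for (std::size_t b = 0; b < H.size(); ++b)
    {
        // Select(time_step > seq_length, mask_value, data) with mask_value = 0
        const float masked = (time_step > seq_lengths[b]) ? 0.f : H[b];
        std::cout << "batch " << b << ": " << masked << "\n"; // prints 0.5, then 0
    }
    return 0;
}
```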
NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
NodeVector h_list;
shared_ptr<Node> X = input_value(0).get_node_shared_ptr();
shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse);
shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse);
shared_ptr<Node> seq_lengths = input_value(3).get_node_shared_ptr();
shared_ptr<Node> W = prepare_input(input_value(4), is_reverse);
shared_ptr<Node> R = prepare_input(input_value(5), is_reverse);
shared_ptr<Node> B = prepare_input(input_value(6), is_reverse);
shared_ptr<Node> P = prepare_input(input_value(7), is_reverse);
if (is_reverse)
{
X = make_shared<op::ReverseSequence>(X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
NodeVector in_seqs = builder::split(X, X->get_shape().at(0));
for (auto& in_x : in_seqs)
{
// Remove the first (empty) dim left by the split above.
in_x = builder::squeeze(in_x);
}
int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
W,
R,
H_t,
C_t,
m_hidden_size,
B,
P,
m_activations,
m_activations_alpha,
m_activations_beta,
m_clip_threshold,
m_input_forget);
shared_ptr<Node> H = get_output_element(lstm_cell, 0);
shared_ptr<Node> C = get_output_element(lstm_cell, 1);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
// Mask the hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequences shorter
// than the current time_step.
h_list.push_back(get_masked_node(builder::expand_dims(H), time_step, 1));
// The reference implementation in ONNX Runtime doesn't mask the values
// of the Y_h and Y_c outputs, so here we make sure that only the
// appropriate batches (with respect to their sequence lengths) are
// updated. Batches with shorter sequences preserve their last values.
H_t = get_masked_node(H, time_step, 0, H_t);
C_t = get_masked_node(C, time_step, 0, C_t);
time_step++;
}
// The tensor that concatenates all the intermediate output values of the
// hidden state. It has shape [seq_length, batch_size, hidden_size].
shared_ptr<Node> Y{make_shared<op::Concat>(h_list, 0)};
// Get back the original order of the output data.
if (is_reverse)
{
Y = make_shared<op::ReverseSequence>(Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = builder::expand_dims(Y, 1);
// Expand H_t and C_t so that they have the expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = builder::expand_dims(H_t);
auto Y_c = builder::expand_dims(C_t);
return {Y, Y_h, Y_c};
}
shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reverse) const
{
// In bidirectional mode inputs are stacked together, so we must split them.
shared_ptr<Node> tmp = node.get_node_shared_ptr();
if (m_direction == direction::BIDIRECTIONAL)
{
tmp = builder::split(node, 2).at(is_reverse ? 1 : 0);
}
// At this point each input holds a single direction, so we can squeeze the `num_directions` axis.
return builder::squeeze(tmp);
}
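A shape-only sketch of what prepare_input() does to a stacked bidirectional tensor, e.g. W of shape [num_directions, 4*hidden_size, input_size] (plain C++; the builder calls are paraphrased in comments):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

int main()
{
    const std::size_t hidden_size = 3, input_size = 4;
    Shape W{2, 4 * hidden_size, input_size}; // forward + reverse weights stacked on axis 0
    // builder::split(W, 2) yields two [1, 4*hidden_size, input_size] pieces;
    // lstm_pass(is_reverse) consumes piece 1 for the reverse pass, piece 0 otherwise.
    Shape piece{1, W.at(1), W.at(2)};
    // builder::squeeze then drops the trivial num_directions axis, producing the
    // [4*hidden_size, input_size] tensor that LSTMCell expects.
    Shape squeezed{piece.at(1), piece.at(2)};
    assert((squeezed == Shape{4 * hidden_size, input_size}));
    return 0;
}
```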
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for the LSTM sequence node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
class LSTMSequence : public util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LSTMSequence", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LSTMSequence() = default;
enum class direction
{
FORWARD,
REVERSE,
BIDIRECTIONAL
};
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& sequence_lengths,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const Output<Node>& P,
const std::int64_t hidden_size,
const direction lstm_direction,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
"tanh",
"tanh"},
const float clip_threshold = 0,
const bool input_forget = false)
: FusedOp(
{X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B, P})
, m_activations_alpha(activations_alpha)
, m_activations_beta(activations_beta)
, m_activations(activations)
, m_clip_threshold(clip_threshold)
, m_direction(lstm_direction)
, m_hidden_size(hidden_size)
, m_input_forget(input_forget)
{
constructor_validate_and_infer_types();
}
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& sequence_lengths,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const std::int64_t hidden_size,
const direction lstm_direction,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
"tanh",
"tanh"},
const float clip_threshold = 0,
const bool input_forget = false)
: LSTMSequence(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
Constant::create(
element::f32,
Shape{(lstm_direction == direction::BIDIRECTIONAL ? 2UL : 1UL),
3UL * static_cast<size_t>(hidden_size)},
std::vector<float>{0.f}),
hidden_size,
lstm_direction,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget)
{
}
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::vector<float> get_activations_alpha() const { return m_activations_alpha; }
std::vector<float> get_activations_beta() const { return m_activations_beta; }
std::vector<std::string> get_activations() const { return m_activations; }
float get_clip_threshold() const { return m_clip_threshold; }
direction get_direction() const { return m_direction; }
std::int64_t get_hidden_size() const { return m_hidden_size; }
bool get_input_forget() const { return m_input_forget; }
private:
///
/// \brief Gets the masked value according to sequence length in a batch.
///
/// \note Zeros out values or sets them to default value for inputs with
/// sequence length shorter than the currently processed time step.
///
/// \param[in] data The input value.
/// \param[in] time_step The current time step denoting sequence length.
/// \param[in] batch_axis The batch axis index of data tensor.
/// \param[in] default_value The default value for masked elements.
///
/// \return The masked value.
///
std::shared_ptr<Node> get_masked_node(const std::shared_ptr<Node>& data,
std::int32_t time_step,
std::size_t batch_axis = 0,
const std::shared_ptr<Node>& default_value = {
nullptr}) const;
NodeVector lstm_pass(bool is_reverse = false) const;
// Split (in the bidirectional case) and squeeze input data to remove the 'num_directions' dimension.
std::shared_ptr<Node> prepare_input(Output<Node> node, bool is_reverse) const;
const std::vector<float> m_activations_alpha;
const std::vector<float> m_activations_beta;
const std::vector<std::string> m_activations;
const float m_clip_threshold;
const direction m_direction;
const std::int64_t m_hidden_size;
const bool m_input_forget;
};
} // namespace op
} // namespace ngraph
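For reference, a minimal construction sketch for the seven-input overload (the peephole tensor P defaults to zeros); the shapes follow the ONNX convention quoted in lstm_pass() and mirror the unit test added below. It assumes ngraph/ngraph.hpp pulls in the fused-op headers listed above:

```cpp
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<op::LSTMSequence> make_forward_lstm_sequence()
{
    const size_t seq_len = 5, batch = 2, input = 4, hidden = 3;
    auto X = std::make_shared<op::Parameter>(element::f32, Shape{seq_len, batch, input});
    auto H0 = std::make_shared<op::Parameter>(element::f32, Shape{1, batch, hidden});
    auto C0 = std::make_shared<op::Parameter>(element::f32, Shape{1, batch, hidden});
    auto seq_lengths = std::make_shared<op::Parameter>(element::i32, Shape{batch});
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{1, 4 * hidden, input});
    auto R = std::make_shared<op::Parameter>(element::f32, Shape{1, 4 * hidden, hidden});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{1, 8 * hidden});
    return std::make_shared<op::LSTMSequence>(
        X, H0, C0, seq_lengths, W, R, B, hidden, op::LSTMSequence::direction::FORWARD);
}
```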
@@ -40,6 +40,7 @@ NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
...
@@ -82,6 +82,7 @@
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
@@ -1556,6 +1557,52 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
input_forget);
break;
}
case OP_TYPEID::LSTMSequence:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
if (args.size() == 8)
{
node = make_shared<op::LSTMSequence>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7],
hidden_size,
direction,
activations_alpha,
activations_beta,
activations,
clip,
input_forget);
}
else
{
node = make_shared<op::LSTMSequence>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
hidden_size,
direction,
activations_alpha,
activations_beta,
activations,
clip,
input_forget);
}
break;
}
case OP_TYPEID::MatMul:
{
bool transpose_a = node_js.at("transpose_a").get<bool>();
@@ -2953,6 +3000,18 @@ json JSONSerializer::serialize_node(const Node& n)
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::LSTMSequence:
{
auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
node["direction"] = tmp->get_direction();
node["hidden_size"] = tmp->get_hidden_size();
node["clip_threshold"] = tmp->get_clip_threshold();
node["activations"] = tmp->get_activations();
node["activations_alpha"] = tmp->get_activations_alpha();
node["activations_beta"] = tmp->get_activations_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::MatMul:
{
auto tmp = static_cast<const op::MatMul*>(&n);
...
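These two switch cases are what let an LSTMSequence node survive a save/load cycle. A hedged round-trip sketch, assuming the serialize()/deserialize() entry points declared in ngraph/serializer.hpp:

```cpp
#include <sstream>
#include <string>
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"

// Serialize a Function containing an LSTMSequence node and read it back.
// This only round-trips if serialize_node() and deserialize_node() agree on
// every attribute key (e.g. "clip_threshold" above).
std::shared_ptr<ngraph::Function> round_trip(const std::shared_ptr<ngraph::Function>& f)
{
    const std::string js = ngraph::serialize(f);
    std::stringstream ss{js};
    return ngraph::deserialize(ss);
}
```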
@@ -131,6 +131,7 @@ set(SRC
type_prop/layer_norm.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/lstm_sequence.cpp
type_prop/matmul.cpp
type_prop/max_pool.cpp
type_prop/mvn.cpp
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, lstm_sequence)
{
const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 24});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto hidden_size = 3;
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 1, 2, 3}));
}
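A hedged companion check (not part of the original commit) suggesting how the bidirectional case could be exercised: the num_directions dimension of every stacked input grows to 2, and so does axis 1 of Y:

```cpp
TEST(type_prop, lstm_sequence_bidirectional)
{
    const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{2, 12, 4});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{2, 12, 3});
    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 24});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
    const auto lstm_sequence =
        make_shared<op::LSTMSequence>(X,
                                      initial_hidden_state,
                                      initial_cell_state,
                                      sequence_lengths,
                                      W,
                                      R,
                                      B,
                                      3 /*hidden_size*/,
                                      op::LSTMSequence::direction::BIDIRECTIONAL);
    EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
    EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 2, 2, 3}));
}
```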