Commit 53a6af8d authored by Adam Rogowiec, committed by Michał Karzyński

[SPEC] LSTMCell, RNNCell updates. (#3733)

parent bc0fd13f
......@@ -242,11 +242,11 @@ def group_convolution(data_batch, # type: Node
@nameable_op
def rnn_cell(X, # type: Node
H_t, # type: Node
W, # type: Node
R, # type: Node
H_t, # type: Node
hidden_size, # type: int
B, # type: Node
hidden_size, # type: int
activations, # type: List[str]
activation_alpha, # type: List[float]
activation_beta, # type: List[float]
......@@ -261,29 +261,30 @@ def rnn_cell(X, # type: Node
Note this class represents only a single *cell* and not a whole RNN *layer*.
:param X: The input tensor with shape: [batch_size, input_size].
:param W: The weight tensor with shape: [hidden_size, input_size].
:param R: The recurrence weight tensor with shape: [hidden_size, hidden_size].
:param H_t: The hidden state tensor at current time step with
shape: [batch_size, hidden_size].
:param hidden_size: The number of hidden units for recurrent cell.
:param B: The bias tensor for input gate with shape: [2*hidden_size].
:param activations: The vector of activation functions used inside recurrent cell.
:param activation_alpha: The vector of alpha parameters for activation
functions in order respective to activation list.
:param activation_beta: The vector of beta parameters for activation functions
in order respective to activation list.
:param clip: The value defining clipping range [-clip, clip] on
input of activation functions.
:param name: Optional output node name.
:return: The new node performing a RNNCell operation on tensor from input node.
:param X: The input tensor with shape: [batch_size, input_size].
:param H_t: The hidden state tensor at current time step with shape:
[batch_size, hidden_size].
:param W: The weight tensor with shape: [hidden_size, input_size].
:param R: The recurrence weight tensor with shape: [hidden_size,
hidden_size].
:param B: The bias tensor for input gate with shape: [hidden_size].
:param hidden_size: The number of hidden units for recurrent cell.
:param activations: The vector of activation functions used inside recurrent cell.
:param activation_alpha: The vector of alpha parameters for activation functions in
order respective to activation list.
:param activation_beta: The vector of beta parameters for activation functions in order
respective to activation list.
:param clip: The value defining clipping range [-clip, clip] on input of
activation functions.
:param name: Optional output node name.
:returns: The new node performing an RNNCell operation on tensor from input node.
"""
return RNNCell(X,
H_t,
W,
R,
H_t,
hidden_size,
B,
hidden_size,
activations,
activation_alpha,
activation_beta,
......
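For reference, a sketch of what this single cell computes (based on the decompose_op comments in rnn_cell.cpp further down; f is the first activation, tanh by default, and its argument is clipped to [-clip, clip] when clip is non-zero):

    H_t' = f(X_t \cdot W^T + H_t \cdot R^T + B)

where B is the single bias of shape [hidden_size], i.e. the Wb and Rb vectors already folded together.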
......@@ -31,8 +31,8 @@ void regclass_pyngraph_op_RNNCell(py::module m)
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
int&,
const std::shared_ptr<ngraph::Node>&,
int&,
const std::vector<std::string>&,
const std::vector<float>&,
const std::vector<float>&,
......
......@@ -455,17 +455,20 @@ def test_rnn_cell_operator():
W_shape = [hidden_size, input_size]
R_shape = [hidden_size, hidden_size]
H_t_shape = [batch_size, hidden_size]
B_shape = [2 * hidden_size]
B_shape = [hidden_size]
parameter_X = ng.parameter(X_shape, name='X', dtype=np.float32)
parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
parameter_W = ng.parameter(W_shape, name='W', dtype=np.float32)
parameter_R = ng.parameter(R_shape, name='R', dtype=np.float32)
parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
parameter_B = ng.parameter(B_shape, name='B', dtype=np.float32)
X_value = np.array([0.3432185, 0.612268, 0.20272376,
0.9513413, 0.30585995, 0.7265472],
dtype=np.float32).reshape(X_shape)
H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
0.4983964, 0.7730452, 0.28439692],
dtype=np.float32).reshape(H_t_shape)
W_value = np.array([0.41930267, 0.7872176, 0.89940447,
0.23659843, 0.24676207, 0.17101714,
0.3147149, 0.6555601, 0.4559603],
......@@ -474,11 +477,7 @@ def test_rnn_cell_operator():
0.71549815, 0.18775631, 0.3182116,
0.25392973, 0.38301638, 0.85531586],
dtype=np.float32).reshape(R_shape)
H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
0.4983964, 0.7730452, 0.28439692],
dtype=np.float32).reshape(H_t_shape)
B_value = np.array([0.45513555, 0.96227735, 0.24737759,
0.57380486, 0.67398053, 0.18968852],
B_value = np.array([1.0289404, 1.6362579, 0.4370661],
dtype=np.float32).reshape(B_shape)
activations = ['sigmoid']
activation_alpha = []
......@@ -486,23 +485,23 @@ def test_rnn_cell_operator():
clip = 2.88
model = ng.rnn_cell(parameter_X,
parameter_H_t,
parameter_W,
parameter_R,
parameter_H_t,
hidden_size,
parameter_B,
hidden_size,
activations,
activation_alpha,
activation_beta,
clip)
computation = runtime.computation(model,
parameter_X,
parameter_H_t,
parameter_W,
parameter_R,
parameter_H_t,
parameter_B)
result = computation(X_value, W_value, R_value, H_t_value, B_value)
result = computation(X_value, H_t_value, W_value, R_value, B_value)
expected = np.array([0.94126844, 0.9036043, 0.841243,
0.9468489, 0.934215, 0.873708],
dtype=np.float32).reshape(batch_size, hidden_size)
......
......@@ -22,7 +22,9 @@
#include <vector>
#include "exceptions.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -82,17 +84,19 @@ namespace ngraph
m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();
// ------ Optional inputs ------
// The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
// The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{
m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
auto bias = ng_inputs.at(3);
auto split_bias = builder::split(bias, 2, 1);
m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
}
else
{
m_map[LSTMInput::LSTM_INPUT_B] = ngraph::op::Constant::create(
element::f32,
Shape{num_directions, 2 * gates_count * hidden_size},
std::vector<float>(num_directions * 2 * gates_count * hidden_size,
Shape{num_directions, gates_count * hidden_size},
std::vector<float>(num_directions * gates_count * hidden_size,
0.f));
}
// The lengths of the sequences in a batch. Shape [batch_size]
......@@ -224,6 +228,7 @@ namespace ngraph
input_map.at(LSTMInput::LSTM_INPUT_P),
attributes.m_hidden_size,
attributes.m_direction,
ngraph::op::LSTMWeightsFormat::IOFC,
attributes.m_activation_alpha,
attributes.m_activation_beta,
attributes.m_activations,
......
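The importer change above folds the ONNX bias, which concatenates Wb and Rb into a tensor of shape [num_directions, 8*hidden_size], into the [num_directions, 4*hidden_size] bias that LSTMCell now consumes. A minimal standalone sketch of that fold for one direction on a flat buffer (illustrative only; fold_onnx_lstm_bias is a hypothetical helper, not part of the importer):

    // Illustrative only (not importer code): fold one direction's ONNX LSTM bias,
    // laid out as [Wb | Rb] with 4 * hidden_size entries each, into the single
    // 4 * hidden_size bias, mirroring the builder::split + add done above.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    std::vector<float> fold_onnx_lstm_bias(const std::vector<float>& b, std::size_t hidden_size)
    {
        const std::size_t half = 4 * hidden_size; // Wb and Rb each cover the 4 gates
        assert(b.size() == 2 * half);
        std::vector<float> folded(half);
        for (std::size_t i = 0; i < half; ++i)
        {
            folded[i] = b[i] + b[half + i]; // Wb[i] + Rb[i]
        }
        return folded;
    }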
......@@ -33,12 +33,12 @@ constexpr NodeTypeInfo op::GRUCell::type_info;
op::GRUCell::GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
size_t hidden_size)
: GRUCell(X,
W,
R,
H_t,
initial_hidden_state,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
......@@ -51,15 +51,15 @@ op::GRUCell::GRUCell(const Output<Node>& X,
op::GRUCell::GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool linear_before_reset)
: FusedOp({X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
: FusedOp({X, W, R, initial_hidden_state})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_linear_before_reset{linear_before_reset}
......@@ -71,16 +71,16 @@ op::GRUCell::GRUCell(const Output<Node>& X,
op::GRUCell::GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
size_t hidden_size,
const Output<Node>& B,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool linear_before_reset)
: FusedOp({X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
: FusedOp({X, W, R, initial_hidden_state, B})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_linear_before_reset{linear_before_reset}
......@@ -129,7 +129,7 @@ void op::GRUCell::pre_validate_and_infer_types()
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
"Input tensor initial_hidden_state must have shape (",
batch_size,
", ",
get_hidden_size(),
......@@ -290,8 +290,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
new_args.at(3),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_linear_before_reset);
}
......@@ -304,8 +304,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
get_hidden_size(),
new_args.at(4),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_linear_before_reset);
}
......
......@@ -47,84 +47,90 @@ namespace ngraph
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] X The input tensor with shape: [batch_size,
/// input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] initial_hidden_state The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
std::size_t hidden_size);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] X The input tensor with shape: [batch_size,
/// input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] initial_hidden_state The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activations_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activations_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta,
float clip,
bool linear_before_reset);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape:
/// [2 * gates_count * hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] X The input tensor with shape: [batch_size,
/// input_size].
/// \param[in] W The weight tensor with shape: [gates_count *
/// hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] initial_hidden_state The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape:
/// [2 * gates_count * hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activations_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activations_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] linear_before_reset Whether or not to apply the linear transformation
/// before multiplying by the output of the reset
/// gate.
///
GRUCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
const Output<Node>& initial_hidden_state,
std::size_t hidden_size,
const Output<Node>& B,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f,
bool linear_before_reset = false);
......
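A minimal construction sketch for the renamed GRUCell signature (a sketch only, assuming the usual fused-op and parameter headers; shapes follow the doc comments above, with gates_count = 3 for GRU):

    #include "ngraph/op/fused/gru_cell.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    std::shared_ptr<op::GRUCell> make_gru_cell()
    {
        const std::size_t batch_size = 2, input_size = 4, hidden_size = 3, gates_count = 3;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto W = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, hidden_size});
        auto initial_hidden_state =
            std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        // GRUCell keeps the (X, W, R, initial_hidden_state) order; only the
        // parameter was renamed from H_t.
        return std::make_shared<op::GRUCell>(X, W, R, initial_hidden_state, hidden_size);
    }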
......@@ -58,21 +58,47 @@ NodeVector op::LSTMSequence::decompose_op() const
shared_ptr<Node> op::LSTMSequence::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
if (new_args.size() == 8)
{
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
else if (new_args.size() == 7)
{
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
shared_ptr<Node> op::LSTMSequence::get_masked_node(const shared_ptr<Node>& data,
......@@ -157,13 +183,14 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
for (const auto& in_x : in_seqs)
{
shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
W,
R,
H_t,
C_t,
m_hidden_size,
W,
R,
B,
P,
m_hidden_size,
m_weights_format,
m_activations,
m_activations_alpha,
m_activations_beta,
......
......@@ -24,6 +24,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
......@@ -36,6 +37,9 @@ namespace ngraph
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
/// \sa LSTMCell, RNNCell, GRUCell
///
///
class LSTMSequence : public util::FusedOp
{
public:
......@@ -61,6 +65,7 @@ namespace ngraph
const Output<Node>& P,
const std::int64_t hidden_size,
const direction lstm_direction,
LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
......@@ -77,6 +82,7 @@ namespace ngraph
, m_direction(lstm_direction)
, m_hidden_size(hidden_size)
, m_input_forget(input_forget)
, m_weights_format(weights_format)
{
constructor_validate_and_infer_types();
}
......@@ -90,6 +96,7 @@ namespace ngraph
const Output<Node>& B,
const std::int64_t hidden_size,
const direction lstm_direction,
LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
......@@ -111,6 +118,7 @@ namespace ngraph
std::vector<float>{0.f}),
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
......@@ -131,6 +139,7 @@ namespace ngraph
direction get_direction() const { return m_direction; }
std::int64_t get_hidden_size() const { return m_hidden_size; }
bool get_input_forget() const { return m_input_forget; }
LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
private:
///
/// \brief Gets the masked value according to sequence length in a batch.
......@@ -163,6 +172,7 @@ namespace ngraph
direction m_direction;
std::int64_t m_hidden_size;
bool m_input_forget;
LSTMWeightsFormat m_weights_format;
};
} // namespace op
} // namespace ngraph
......@@ -32,44 +32,34 @@ using namespace ngraph;
constexpr NodeTypeInfo op::RNNCell::type_info;
op::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
size_t hidden_size)
: RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
{
}
op::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: FusedOp({X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
: FusedOp({X, initial_hidden_state, W, R})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
, m_activation_f{get_activation_function(0)}
{
add_default_bias_input();
set_argument(4, get_default_bias_input());
constructor_validate_and_infer_types();
}
op::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& H_t,
size_t hidden_size,
const Output<Node>& B,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: FusedOp({X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
: FusedOp({X, initial_hidden_state, W, R, B})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
, m_activation_f{get_activation_function(0)}
{
constructor_validate_and_infer_types();
......@@ -83,9 +73,9 @@ void op::RNNCell::pre_validate_and_infer_types()
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& w_pshape = get_input_partial_shape(2);
const auto& r_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
......@@ -121,7 +111,7 @@ void op::RNNCell::pre_validate_and_infer_types()
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
"Input tensor initial_hidden_state must have shape (",
batch_size,
", ",
get_hidden_size(),
......@@ -137,9 +127,9 @@ void op::RNNCell::pre_validate_and_infer_types()
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * get_hidden_size()}),
(b_shape == Shape{get_hidden_size()}),
"Input tensor B must have shape (",
2 * get_hidden_size(),
get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
......@@ -157,8 +147,7 @@ NodeVector op::RNNCell::decompose_op() const
// W - The weight tensor for input gate. Shape: [hidden_size, input_size].
// R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// B - The bias tensor for the input gate. Shape: [2 * hidden_size].
// Concatenation of `[Wb, Rb]`.
// B - The bias tensor for the input gate. Shape: [hidden_size].
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
......@@ -174,10 +163,10 @@ NodeVector op::RNNCell::decompose_op() const
// --------------------
Output<Node> X = input_value(0);
Output<Node> W = input_value(1);
Output<Node> R = input_value(2);
Output<Node> H_t = input_value(3);
Output<Node> bias = get_bias();
Output<Node> H_t = input_value(1);
Output<Node> W = input_value(2);
Output<Node> R = input_value(3);
Output<Node> bias = input_value(4);
// Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
......@@ -192,22 +181,12 @@ NodeVector op::RNNCell::decompose_op() const
return {i_t};
}
Output<Node> op::RNNCell::get_bias() const
Output<Node> op::RNNCell::get_default_bias_input() const
{
Output<Node> bias;
// Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(input_value(4), 2);
bias = b_W_R.at(0) + b_W_R.at(1);
return bias;
}
void op::RNNCell::add_default_bias_input()
{
Output<Node> B =
return Output<Node>{
op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(4, B);
Shape{s_gates_count * get_hidden_size()},
vector<float>(s_gates_count * get_hidden_size(), 0.f))};
}
shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const
......@@ -221,8 +200,8 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
new_args.at(3),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_activations_alpha(),
get_activations_beta(),
get_clip());
}
else if (new_args.size() == 5)
......@@ -231,11 +210,11 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
new_args.at(1),
new_args.at(2),
new_args.at(3),
get_hidden_size(),
new_args.at(4),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_activations_alpha(),
get_activations_beta(),
get_clip());
}
else
......
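A short construction sketch with the reordered RNNCell inputs (X, initial_hidden_state, W, R[, B]) and hidden_size now trailing the tensor inputs. Shapes mirror the type_prop test further down; this is illustrative only, not test code:

    #include "ngraph/op/fused/rnn_cell.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    std::shared_ptr<op::RNNCell> make_rnn_cell()
    {
        const std::size_t batch_size = 2, input_size = 3, hidden_size = 3;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto W = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
        auto B = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size});
        // Full form with an explicit bias; RNNCell(X, H_t, W, R, hidden_size) instead
        // creates a zero bias of shape [hidden_size] via get_default_bias_input above.
        return std::make_shared<op::RNNCell>(X,
                                             H_t,
                                             W,
                                             R,
                                             B,
                                             hidden_size,
                                             std::vector<std::string>{"tanh"},
                                             std::vector<float>{},
                                             std::vector<float>{},
                                             0.f);
    }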
......@@ -39,13 +39,13 @@ static vector<string> to_lower_case(const vector<string>& vs)
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
float clip,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta)
const vector<float>& activations_alpha,
const vector<float>& activations_beta)
: m_hidden_size(hidden_size)
, m_clip(clip)
, m_activations(to_lower_case(activations))
, m_activation_alpha(activation_alpha)
, m_activation_beta(activation_beta)
, m_activations_alpha(activations_alpha)
, m_activations_beta(activations_beta)
{
}
......@@ -54,13 +54,13 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
// Set activation functions parameters (if any)
if (m_activation_alpha.size() > idx)
if (m_activations_alpha.size() > idx)
{
afunc.set_alpha(m_activation_alpha.at(idx));
afunc.set_alpha(m_activations_alpha.at(idx));
}
if (m_activation_beta.size() > idx)
if (m_activations_beta.size() > idx)
{
afunc.set_beta(m_activation_beta.at(idx));
afunc.set_beta(m_activations_beta.at(idx));
}
return afunc;
......
......@@ -40,30 +40,34 @@ namespace ngraph
///
/// \brief Constructs a RNNCellBase class.
///
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] clip The value defining clipping range [-clip, clip]
/// on input of activation functions.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activations_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activations_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
///
RNNCellBase(std::size_t hidden_size,
float clip,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta);
const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta);
std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; }
const std::vector<float>& get_activation_alpha() const
const std::vector<float>& get_activations_alpha() const
{
return m_activation_alpha;
return m_activations_alpha;
}
const std::vector<float>& get_activation_beta() const { return m_activation_beta; }
const std::vector<float>& get_activations_beta() const
{
return m_activations_beta;
}
protected:
///
/// \brief Constructs activation function object.
......@@ -117,9 +121,9 @@ namespace ngraph
const std::size_t m_hidden_size;
const float m_clip;
const std::vector<std::string> m_activations;
const std::vector<float> m_activation_alpha;
const std::vector<float> m_activation_beta;
const std::vector<float> m_activations_alpha;
const std::vector<float> m_activations_beta;
};
}
}
}
} // namespace util
} // namespace op
} // namespace ngraph
......@@ -77,8 +77,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
element::f32, Shape{ref_gates_count * ref_hidden_size, ref_input_size});
auto R = std::make_shared<pattern::op::Label>(
element::f32, Shape{ref_gates_count * ref_hidden_size, ref_hidden_size});
auto bias_ref = std::make_shared<pattern::op::Label>(
element::f32, Shape{2 * ref_gates_count * ref_hidden_size});
auto B = std::make_shared<pattern::op::Label>(element::f32,
Shape{ref_gates_count * ref_hidden_size});
auto peep_hole = std::make_shared<pattern::op::Label>(element::f32, Shape{3 * ref_hidden_size});
auto H_t =
std::make_shared<pattern::op::Label>(element::f32, Shape{ref_batch_size, ref_hidden_size});
......@@ -87,13 +87,14 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
auto ref_lstm_cell =
std::make_shared<op::LSTMCell>(X,
W,
R,
H_t,
C_t,
ref_hidden_size,
bias_ref,
W,
R,
B,
peep_hole,
ref_hidden_size,
op::LSTMWeightsFormat::IOFC,
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
std::vector<float>{},
std::vector<float>{},
......@@ -101,72 +102,27 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
false);
auto callback = [X, W, R, H_t, C_t](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto target_lstm_node = m.get_match_root();
auto lstmcell_op = as_type_ptr<op::LSTMCell>(m.get_match_root());
auto src_iter =
std::make_shared<ngraph::op::Concat>(NodeVector{pattern_map[H_t], pattern_map[C_t]}, 0);
auto bias_iofc = target_lstm_node->get_argument(5);
// we need to reorder W, R and bias from IOFC to IFCO gate order
// Note: ONNX runtime provides W, R and bias in the gate order [IOFC] but
// MKLDNN computes LSTM kernel in the [IFCO] order.
auto get_weights_ifco_gate_order =
[&](std::shared_ptr<Node> weights_graph_node) -> std::shared_ptr<Node> {
// slices will be in ICFO order
std::vector<std::shared_ptr<Node>> gate_slices;
size_t dim0 = weights_graph_node->get_shape()[0] / 4;
size_t dim1 = weights_graph_node->get_shape()[1];
for (size_t i = 0; i < 4; i++)
{
auto slice = std::make_shared<ngraph::op::Slice>(
weights_graph_node, Coordinate{i * dim0, 0}, Coordinate{(i + 1) * dim0, dim1});
gate_slices.push_back(slice);
}
auto weights_ifco = std::make_shared<ngraph::op::Concat>(
NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
return std::move(weights_ifco);
};
auto get_bias_ifco_gate_order =
[&](std::shared_ptr<Node> bias_graph_node) -> std::shared_ptr<Node> {
size_t hidden_size = lstmcell_op->get_hidden_size();
auto Wb_bias = std::make_shared<ngraph::op::Slice>(
bias_graph_node, Coordinate{0}, Coordinate{4 * hidden_size});
auto Rb_bias = std::make_shared<ngraph::op::Slice>(
bias_graph_node, Coordinate{4 * hidden_size}, Coordinate{2 * 4 * hidden_size});
auto bias = std::make_shared<op::Add>(Wb_bias, Rb_bias);
auto W_ifco = lstmcell_op->get_argument(3);
auto R_ifco = lstmcell_op->get_argument(4);
auto bias_ifco = lstmcell_op->get_argument(5);
// slices will be in ICFO order
std::vector<std::shared_ptr<Node>> gate_slices;
for (size_t i = 0; i < 4; i++)
{
auto slice = std::make_shared<ngraph::op::Slice>(
bias, Coordinate{i * hidden_size}, Coordinate{(i + 1) * hidden_size});
gate_slices.push_back(slice);
}
auto new_bias = std::make_shared<ngraph::op::Concat>(
NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
return std::move(new_bias);
};
auto W_iofc = pattern_map[W];
auto R_iofc = pattern_map[R];
auto W_ifco = get_weights_ifco_gate_order(W_iofc);
auto R_ifco = get_weights_ifco_gate_order(R_iofc);
// here onnx bias will be of shape (2 * gates_count * hidden_size) bias of Wb and Rb are
// concatenated, we will split the bias, add and rearrange in order IFCO
auto bias_ifco = get_bias_ifco_gate_order(bias_iofc);
// We need to reorder W, R and bias to IFCO gate order.
// Note: ONNX runtime provides W, R and bias in the gate order [IOFC], but
// MKLDNN computes the LSTM kernel in the [IFCO] order.
if (lstmcell_op->get_weights_format() != op::LSTMWeightsFormat::IFCO)
{
W_ifco = lstmcell_op->convert_node_format(W_ifco);
R_ifco = lstmcell_op->convert_node_format(R_ifco);
bias_ifco = lstmcell_op->convert_node_format(bias_ifco);
}
auto W_reshape = std::make_shared<op::Reshape>(
W_ifco, AxisVector{1, 0}, Shape{W_ifco->get_shape()[1], W_ifco->get_shape()[0]});
......@@ -595,7 +551,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
lstm_weights_layer_label,
lstm_weights_iter_label,
lstm_bias_label](pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In recurrent RNN fusion callback";
auto concat_rnn_inputs_across_timestep =
......@@ -800,7 +755,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
lstm_weights_layer_label,
lstm_weights_iter_label,
lstm_bias_label](pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In recurrent RNN fusion callback";
auto concat_rnn_inputs_across_timestep =
......@@ -1161,7 +1115,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
// Replace all the users of RNN cell state {ct} across different user.
auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe1, size_t layer) {
// multi-layered fused rnn second output {GOE1} holds the recurrent output state tensors
// for the last cell of all the layers, {{ht_1 | ct_1} || {ht2 | ct2} || ....{htn | ctn}}
// we will slice the cell state output tensor {ct_*} from the fused RNN kernel output
......@@ -1211,7 +1164,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
// Replace all the users of RNN cell state {ct} across different user.
auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe2, size_t layer) {
// multi-layered fused rnn second output {GOE2} holds the recurrent output state tensors
// for the last cell
// of all the layers, { ct_1 || ct2 || ....|| ctn}
......@@ -1302,7 +1254,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
// Define a call back that needs to called once the DFG matches the pattern
auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto rnn_ltor_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
......@@ -1351,7 +1302,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto construct_birnn_inputs = [&](int index) {
auto nodes =
NodeVector{rnn_ltor_node->get_argument(index), rnn_rtol_node->get_argument(index)};
return std::make_shared<ngraph::op::Concat>(nodes, 0);
......
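The pass no longer reorders gates by hand; LSTMCell::convert_node_format (used above) takes care of it. For reference, the IOFC to IFCO row permutation it applies is equivalent to the following standalone sketch on a flat, row-major, gate-stacked weight buffer (illustrative only, plain C++ rather than graph ops):

    #include <cstddef>
    #include <vector>

    // weights: (4 * hidden_size) x input_size, row-major, gate blocks stacked as I, O, F, C.
    std::vector<float> iofc_to_ifco(const std::vector<float>& weights,
                                    std::size_t hidden_size,
                                    std::size_t input_size)
    {
        const std::size_t block = hidden_size * input_size; // rows of one gate
        const std::size_t order[4] = {0, 2, 3, 1};          // pick I, F, C, O out of I, O, F, C
        std::vector<float> out;
        out.reserve(4 * block);
        for (std::size_t g : order)
        {
            out.insert(out.end(),
                       weights.begin() + g * block,
                       weights.begin() + (g + 1) * block);
        }
        return out;
    }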
......@@ -433,6 +433,13 @@ static element::Type read_element_type(json j)
return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string);
}
static op::LSTMWeightsFormat read_lstm_weights_format(const json& js)
{
return has_key(js, "weights_format")
? static_cast<op::LSTMWeightsFormat>(js.at("weights_format"))
: op::LSTMWeightsFormat::IFCO;
}
void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
{
ofstream out(path);
......@@ -1828,24 +1835,60 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::LSTMCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto weights_format = read_lstm_weights_format(node_js);
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
args[5],
args[6],
activations,
activation_alpha,
activation_beta,
clip,
input_forget);
if (args.size() == 7)
{
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
}
else if (args.size() == 6)
{
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
}
else
{
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
}
break;
}
case OP_TYPEID::LSTMSequence:
......@@ -1857,6 +1900,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
auto weights_format = read_lstm_weights_format(node_js);
if (args.size() == 8)
{
node = make_shared<op::LSTMSequence>(args[0],
......@@ -1869,6 +1913,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[7],
hidden_size,
direction,
weights_format,
activations_alpha,
activations_beta,
activations,
......@@ -1886,6 +1931,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[6],
hidden_size,
direction,
weights_format,
activations_alpha,
activations_beta,
activations,
......@@ -2393,8 +2439,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[1],
args[2],
args[3],
hidden_size,
args[4],
hidden_size,
activations,
activation_alpha,
activation_beta,
......@@ -3418,8 +3464,8 @@ json JSONSerializer::serialize_node(const Node& n)
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["activations_alpha"] = tmp->get_activations_alpha();
node["activations_beta"] = tmp->get_activations_beta();
node["linear_before_reset"] = tmp->get_linear_before_reset();
break;
}
......@@ -3552,10 +3598,11 @@ json JSONSerializer::serialize_node(const Node& n)
{
auto tmp = static_cast<const op::LSTMCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["weights_format"] = tmp->get_weights_format();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["activations_alpha"] = tmp->get_activations_alpha();
node["activations_beta"] = tmp->get_activations_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
......@@ -3564,6 +3611,7 @@ json JSONSerializer::serialize_node(const Node& n)
auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
node["direction"] = tmp->get_direction();
node["hidden_size"] = tmp->get_hidden_size();
node["weights_format"] = tmp->get_weights_format();
node["clip_threshold"] = tmp->get_clip_threshold();
node["activations"] = tmp->get_activations();
node["activations_alpha"] = tmp->get_activations_alpha();
......@@ -3936,8 +3984,8 @@ json JSONSerializer::serialize_node(const Node& n)
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["activations_alpha"] = tmp->get_activations_alpha();
node["activations_beta"] = tmp->get_activations_beta();
break;
}
case OP_TYPEID::ScalarConstantLike:
......
......@@ -3967,12 +3967,14 @@ TEST(cpu_fusion, lstm_cell)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
auto ht = make_shared<op::GetOutputElement>(lstm_cell, 0);
auto ct = make_shared<op::GetOutputElement>(lstm_cell, 1);
auto lstm_function =
make_shared<Function>(NodeVector{ht, ct}, ParameterVector{X, W, R, H_t, C_t});
auto lstm_function = make_shared<Function>(NodeVector{ht, ct},
ParameterVector{
X, H_t, C_t, W, R,
});
return lstm_function;
};
auto lstm_function_cpu = make_function();
......
......@@ -531,10 +531,10 @@ TEST(serialize, tensor_iterator_lstm)
auto R_body = make_shared<op::Parameter>(element::f32, Shape{4 * H, H});
auto LSTM_cell =
make_shared<op::LSTMCell>(make_shared<op::Reshape>(X, AxisVector{0, 1, 2}, Shape{N, I}),
W_body,
R_body,
make_shared<op::Reshape>(H_t, AxisVector{0, 1, 2}, Shape{N, H}),
make_shared<op::Reshape>(C_t, AxisVector{0, 1, 2}, Shape{N, H}),
W_body,
R_body,
H);
auto H_o = make_shared<op::Reshape>(LSTM_cell->output(0), AxisVector{0, 1}, Shape{N, 1, H});
auto C_o = make_shared<op::Reshape>(LSTM_cell->output(1), AxisVector{0, 1}, Shape{N, 1, H});
......
......@@ -87,7 +87,8 @@ TEST(type_prop, gru_cell_invalid_input)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input tensor initial_hidden_state must have shape"));
}
// Invalid B tensor shape.
......
......@@ -36,7 +36,16 @@ TEST(type_prop, lstm_cell)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_cell->get_clip(), 0.f);
EXPECT_TRUE(lstm_cell->get_activations_alpha().empty());
EXPECT_TRUE(lstm_cell->get_activations_beta().empty());
EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_cell->get_activations()[1], "tanh");
EXPECT_EQ(lstm_cell->get_activations()[2], "tanh");
EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_FALSE(lstm_cell->get_input_forget());
EXPECT_EQ(lstm_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->output(1).get_element_type(), element::f32);
......@@ -60,7 +69,7 @@ TEST(type_prop, lstm_cell_invalid_input)
auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......@@ -73,7 +82,7 @@ TEST(type_prop, lstm_cell_invalid_input)
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......@@ -86,12 +95,13 @@ TEST(type_prop, lstm_cell_invalid_input)
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input tensor initial_hidden_state must have shape"));
}
// Invalid C_t tensor shape.
......@@ -99,21 +109,22 @@ TEST(type_prop, lstm_cell_invalid_input)
C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor C_t must have shape"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input tensor initial_cell_state must have shape"));
}
// Invalid B tensor shape.
C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......@@ -122,11 +133,11 @@ TEST(type_prop, lstm_cell_invalid_input)
}
// Invalid P tensor shape.
B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......
......@@ -28,7 +28,7 @@ TEST(type_prop, lstm_sequence)
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 24});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto hidden_size = 3;
......@@ -41,6 +41,20 @@ TEST(type_prop, lstm_sequence)
B,
hidden_size,
op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 1, 2, 3}));
EXPECT_EQ(lstm_sequence->output(1).get_element_type(), element::f32);
EXPECT_EQ(lstm_sequence->output(1).get_shape(), (Shape{1, 2, 3}));
EXPECT_EQ(lstm_sequence->output(2).get_element_type(), element::f32);
EXPECT_EQ(lstm_sequence->output(2).get_shape(), (Shape{1, 2, 3}));
}
......@@ -28,11 +28,11 @@ TEST(type_prop, rnn_cell)
const size_t hidden_size = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(rnn_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(rnn_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
}
......@@ -51,7 +51,7 @@ TEST(type_prop, rnn_cell_invalid_input)
auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......@@ -64,7 +64,7 @@ TEST(type_prop, rnn_cell_invalid_input)
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......@@ -77,20 +77,21 @@ TEST(type_prop, rnn_cell_invalid_input)
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input tensor initial_hidden_state must have shape"));
}
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size, B);
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
......