ngraph — Commit 53a6af8d

[SPEC] LSTMCell, RNNCell updates. (#3733)

Authored Nov 08, 2019 by Adam Rogowiec; committed Nov 08, 2019 by Michał Karzyński.
Parent: bc0fd13f

In line with the operator specification work, this commit renames the cells'
activation_alpha/activation_beta parameters to activations_alpha/activations_beta,
renames the H_t input to initial_hidden_state, moves the state inputs of RNNCell and
LSTMCell ahead of the W/R weights, shrinks RNNCell's bias B from [2*hidden_size] to
[hidden_size] (the ONNX importer now folds Wb+Rb itself), and threads a new
LSTMWeightsFormat attribute through LSTMCell, LSTMSequence, the CPU RNN fusion pass,
and the serializer.

Showing 23 changed files with 375 additions and 317 deletions (+375 / -317).
Files changed:

  python/ngraph/ops.py                               +21  -20
  python/pyngraph/ops/fused/rnn_cell.cpp              +1   -1
  python/test/ngraph/test_ops_fused.py               +10  -11
  src/ngraph/frontend/onnx_import/op/lstm.cpp         +9   -4
  src/ngraph/op/fused/gru_cell.cpp                   +17  -17
  src/ngraph/op/fused/gru_cell.hpp                   +55  -49
  src/ngraph/op/fused/lstm_cell.cpp                   +0   -0
  src/ngraph/op/fused/lstm_cell.hpp                   +0   -0
  src/ngraph/op/fused/lstm_sequence.cpp              +45  -18
  src/ngraph/op/fused/lstm_sequence.hpp              +10   -0
  src/ngraph/op/fused/rnn_cell.cpp                   +32  -53
  src/ngraph/op/fused/rnn_cell.hpp                    +0   -0
  src/ngraph/op/util/rnn_cell_base.cpp                +8   -8
  src/ngraph/op/util/rnn_cell_base.hpp               +23  -19
  src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp     +19  -69
  src/ngraph/serializer.cpp                          +70  -22
  test/backend/fused_op.in.cpp                        +0   -0
  test/cpu_fusion.cpp                                 +5   -3
  test/serialize.cpp                                  +2   -2
  test/type_prop/gru_cell.cpp                         +2   -1
  test/type_prop/lstm_cell.cpp                       +22  -11
  test/type_prop/lstm_sequence.cpp                   +15   -1
  test/type_prop/rnn_cell.cpp                         +9   -8
python/ngraph/ops.py

@@ -242,11 +242,11 @@ def group_convolution(data_batch,  # type: Node
 @nameable_op
 def rnn_cell(X,                 # type: Node
+             H_t,               # type: Node
              W,                 # type: Node
              R,                 # type: Node
-             H_t,               # type: Node
-             hidden_size,       # type: int
              B,                 # type: Node
+             hidden_size,       # type: int
              activations,       # type: List[str]
              activation_alpha,  # type: List[float]
              activation_beta,   # type: List[float]

@@ -261,29 +261,30 @@ def rnn_cell(X,  # type: Node
     Note this class represents only single *cell* and not whole RNN *layer*.

-    :param X: The input tensor with shape: [batch_size, input_size].
-    :param W: The weight tensor with shape: [hidden_size, input_size].
-    :param R: The recurrence weight tensor with shape: [hidden_size, hidden_size].
-    :param H_t: The hidden state tensor at current time step with
-                shape: [batch_size, hidden_size].
-    :param hidden_size: The number of hidden units for recurrent cell.
-    :param B: The bias tensor for input gate with shape: [2*hidden_size].
-    :param activations: The vector of activation functions used inside recurrent cell.
-    :param activation_alpha: The vector of alpha parameters for activation
-                             functions in order respective to activation list.
-    :param activation_beta: The vector of beta parameters for activation functions
-                            in order respective to activation list.
-    :param clip: The value defining clipping range [-clip, clip] on
-                 input of activation functions.
-    :param name: Optional output node name.
-    :return: The new node performing a RNNCell operation on tensor from input node.
+    :param X: The input tensor with shape: [batch_size, input_size].
+    :param H_t: The hidden state tensor at current time step with shape:
+                [batch_size, hidden_size].
+    :param W: The weight tensor with shape: [hidden_size, input_size].
+    :param R: The recurrence weight tensor with shape: [hidden_size,
+              hidden_size].
+    :param B: The bias tensor for input gate with shape: [2*hidden_size].
+    :param hidden_size: The number of hidden units for recurrent cell.
+    :param activations: The vector of activation functions used inside recurrent cell.
+    :param activation_alpha: The vector of alpha parameters for activation functions in
+                             order respective to activation list.
+    :param activation_beta: The vector of beta parameters for activation functions in order
+                            respective to activation list.
+    :param clip: The value defining clipping range [-clip, clip] on input of
+                 activation functions.
+    :param name: Optional output node name.
+    :returns: The new node performing a RNNCell operation on tensor from input node.
     """
     return RNNCell(X,
+                   H_t,
                    W,
                    R,
-                   H_t,
-                   hidden_size,
                    B,
+                   hidden_size,
                    activations,
                    activation_alpha,
                    activation_beta,
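The Python wrapper above forwards to the C++ op::RNNCell, whose inputs follow the same
reordering: data tensors (X, then the state) first, then the weights, then hidden_size.
A minimal sketch of the reordered constructor, modeled on test/type_prop/rnn_cell.cpp
from this commit (shapes are illustrative):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/rnn_cell.hpp"

    using namespace ngraph;

    int main()
    {
        const size_t batch_size = 2, input_size = 3, hidden_size = 3;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto W = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});

        // New order: X, initial hidden state, then W, R, then hidden_size.
        auto cell = std::make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
        return cell->output(0).get_shape() == (Shape{batch_size, hidden_size}) ? 0 : 1;
    }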
python/pyngraph/ops/fused/rnn_cell.cpp

@@ -31,8 +31,8 @@ void regclass_pyngraph_op_RNNCell(py::module m)
                   const std::shared_ptr<ngraph::Node>&,
                   const std::shared_ptr<ngraph::Node>&,
                   const std::shared_ptr<ngraph::Node>&,
-                  int&,
                   const std::shared_ptr<ngraph::Node>&,
+                  int&,
                   const std::vector<std::string>&,
                   const std::vector<float>&,
                   const std::vector<float>&,
python/test/ngraph/test_ops_fused.py

@@ -455,17 +455,20 @@ def test_rnn_cell_operator():
     W_shape = [hidden_size, input_size]
     R_shape = [hidden_size, hidden_size]
     H_t_shape = [batch_size, hidden_size]
-    B_shape = [2 * hidden_size]
+    B_shape = [hidden_size]

     parameter_X = ng.parameter(X_shape, name='X', dtype=np.float32)
+    parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
     parameter_W = ng.parameter(W_shape, name='W', dtype=np.float32)
     parameter_R = ng.parameter(R_shape, name='R', dtype=np.float32)
-    parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
     parameter_B = ng.parameter(B_shape, name='B', dtype=np.float32)

     X_value = np.array([0.3432185, 0.612268, 0.20272376,
                         0.9513413, 0.30585995, 0.7265472],
                        dtype=np.float32).reshape(X_shape)
+    H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
+                          0.4983964, 0.7730452, 0.28439692],
+                         dtype=np.float32).reshape(H_t_shape)
     W_value = np.array([0.41930267, 0.7872176, 0.89940447,
                         0.23659843, 0.24676207, 0.17101714,
                         0.3147149, 0.6555601, 0.4559603],

@@ -474,11 +477,7 @@ def test_rnn_cell_operator():
                         0.71549815, 0.18775631, 0.3182116,
                         0.25392973, 0.38301638, 0.85531586],
                        dtype=np.float32).reshape(R_shape)
-    H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
-                          0.4983964, 0.7730452, 0.28439692],
-                         dtype=np.float32).reshape(H_t_shape)
-    B_value = np.array([0.45513555, 0.96227735, 0.24737759,
-                        0.57380486, 0.67398053, 0.18968852],
+    B_value = np.array([1.0289404, 1.6362579, 0.4370661],
                        dtype=np.float32).reshape(B_shape)
     activations = ['sigmoid']
     activation_alpha = []

@@ -486,23 +485,23 @@ def test_rnn_cell_operator():
     clip = 2.88

     model = ng.rnn_cell(parameter_X,
+                        parameter_H_t,
                         parameter_W,
                         parameter_R,
-                        parameter_H_t,
-                        hidden_size,
                         parameter_B,
+                        hidden_size,
                         activations,
                         activation_alpha,
                         activation_beta,
                         clip)
     computation = runtime.computation(model,
                                       parameter_X,
+                                      parameter_H_t,
                                       parameter_W,
                                       parameter_R,
-                                      parameter_H_t,
                                       parameter_B)

-    result = computation(X_value, W_value, R_value, H_t_value, B_value)
+    result = computation(X_value, H_t_value, W_value, R_value, B_value)
     expected = np.array([0.94126844, 0.9036043, 0.841243,
                          0.9468489, 0.934215, 0.873708],
                         dtype=np.float32).reshape(batch_size, hidden_size)
src/ngraph/frontend/onnx_import/op/lstm.cpp

@@ -22,7 +22,9 @@
 #include <vector>

 #include "exceptions.hpp"
+#include "ngraph/builder/split.hpp"
 #include "ngraph/frontend/onnx_import/op/lstm.hpp"
+#include "ngraph/op/add.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
 #include "ngraph/op/get_output_element.hpp"

@@ -82,17 +84,19 @@ namespace ngraph
                         m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();

                     // ------ Optional inputs ------
-                    // The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
+                    // The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
                     if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
                     {
-                        m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
+                        auto bias = ng_inputs.at(3);
+                        auto split_bias = builder::split(bias, 2, 1);
+                        m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
                     }
                     else
                     {
                         m_map[LSTMInput::LSTM_INPUT_B] = ngraph::op::Constant::create(
                             element::f32,
-                            Shape{num_directions, 2 * gates_count * hidden_size},
-                            std::vector<float>(num_directions * 2 * gates_count * hidden_size,
+                            Shape{num_directions, gates_count * hidden_size},
+                            std::vector<float>(num_directions * gates_count * hidden_size,
                                                0.f));
                     }
                     // The lengths of the sequences in a batch. Shape [batch_size]

@@ -224,6 +228,7 @@ namespace ngraph
                         input_map.at(LSTMInput::LSTM_INPUT_P),
                         attributes.m_hidden_size,
                         attributes.m_direction,
+                        ngraph::op::LSTMWeightsFormat::IOFC,
                         attributes.m_activation_alpha,
                         attributes.m_activation_beta,
                         attributes.m_activations,
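The importer now folds the ONNX bias itself: ONNX supplies B as the concatenation of Wb
and Rb ([num_directions, 8*hidden_size]), which is split in half along axis 1 and
summed, matching the [num_directions, 4*hidden_size] bias the fused op now expects. A
hedged, standalone sketch of that folding step (the helper name is illustrative):

    #include "ngraph/builder/split.hpp"
    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/add.hpp"

    using namespace ngraph;

    std::shared_ptr<Node> fold_onnx_lstm_bias(const std::shared_ptr<Node>& onnx_bias)
    {
        // onnx_bias: Shape{num_directions, 2 * 4 * hidden_size} = [Wb, Rb] concatenated.
        NodeVector halves = builder::split(onnx_bias, 2, 1); // two pieces along axis 1
        return halves.at(0) + halves.at(1);                  // Wb + Rb
    }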
src/ngraph/op/fused/gru_cell.cpp

@@ -33,12 +33,12 @@ constexpr NodeTypeInfo op::GRUCell::type_info;
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size)
     : GRUCell(X,
               W,
               R,
-              H_t,
+              initial_hidden_state,
               hidden_size,
               vector<string>{"sigmoid", "tanh"},
               vector<float>{},

@@ -51,15 +51,15 @@ op::GRUCell::GRUCell(const Output<Node>& X,
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip,
                      bool linear_before_reset)
-    : FusedOp({X, W, R, H_t})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, W, R, initial_hidden_state})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
     , m_activation_g{get_activation_function(1)}
     , m_linear_before_reset{linear_before_reset}

@@ -71,16 +71,16 @@ op::GRUCell::GRUCell(const Output<Node>& X,
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size,
                      const Output<Node>& B,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip,
                      bool linear_before_reset)
-    : FusedOp({X, W, R, H_t, B})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, W, R, initial_hidden_state, B})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
     , m_activation_g{get_activation_function(1)}
     , m_linear_before_reset{linear_before_reset}

@@ -129,7 +129,7 @@ void op::GRUCell::pre_validate_and_infer_types()
                           ".");
     NODE_VALIDATION_CHECK(this,
                           (ht_shape == Shape{batch_size, get_hidden_size()}),
-                          "Input tensor H_t must have shape (",
+                          "Input tensor initial_hidden_state must have shape (",
                           batch_size,
                           ", ",
                           get_hidden_size(),

@@ -290,8 +290,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
                                      new_args.at(3),
                                      get_hidden_size(),
                                      get_activations(),
-                                     get_activation_alpha(),
-                                     get_activation_beta(),
+                                     get_activations_alpha(),
+                                     get_activations_beta(),
                                      get_clip(),
                                      m_linear_before_reset);
     }

@@ -304,8 +304,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
                                      get_hidden_size(),
                                      new_args.at(4),
                                      get_activations(),
-                                     get_activation_alpha(),
-                                     get_activation_beta(),
+                                     get_activations_alpha(),
+                                     get_activations_beta(),
                                      get_clip(),
                                      m_linear_before_reset);
     }
src/ngraph/op/fused/gru_cell.hpp

@@ -47,84 +47,90 @@ namespace ngraph
                 ///
                 /// \brief      Constructs GRUCell node.
                 ///
-                /// \param[in]  X            The input tensor with shape: [batch_size, input_size].
-                /// \param[in]  W            The weight tensor with shape:
-                ///                          [gates_count * hidden_size, input_size].
-                /// \param[in]  R            The recurrence weight tensor with shape:
-                ///                          [gates_count * hidden_size, hidden_size].
-                /// \param[in]  H_t          The hidden state tensor at current time step with
-                ///                          shape: [batch_size, hidden_size].
-                /// \param[in]  hidden_size  The number of hidden units for recurrent cell.
+                /// \param[in]  X                     The input tensor with shape: [batch_size,
+                ///                                   input_size].
+                /// \param[in]  W                     The weight tensor with shape:
+                ///                                   [gates_count * hidden_size, input_size].
+                /// \param[in]  R                     The recurrence weight tensor with shape:
+                ///                                   [gates_count * hidden_size, hidden_size].
+                /// \param[in]  initial_hidden_state  The hidden state tensor at current time step
+                ///                                   with shape: [batch_size, hidden_size].
+                /// \param[in]  hidden_size           The number of hidden units for recurrent cell.
                 ///
                 GRUCell(const Output<Node>& X,
                         const Output<Node>& W,
                         const Output<Node>& R,
-                        const Output<Node>& H_t,
+                        const Output<Node>& initial_hidden_state,
                         std::size_t hidden_size);

                 ///
                 /// \brief      Constructs GRUCell node.
                 ///
-                /// \param[in]  X                 The input tensor with shape: [batch_size, input_size].
-                /// \param[in]  W                 The weight tensor with shape:
-                ///                               [gates_count * hidden_size, input_size].
-                /// \param[in]  R                 The recurrence weight tensor with shape:
-                ///                               [gates_count * hidden_size, hidden_size].
-                /// \param[in]  H_t               The hidden state tensor at current time step with
-                ///                               shape: [batch_size, hidden_size].
-                /// \param[in]  hidden_size       The number of hidden units for recurrent cell.
-                /// \param[in]  activations       The vector of activation functions used inside
-                ///                               recurrent cell.
-                /// \param[in]  activation_alpha  The vector of alpha parameters for activation
-                ///                               functions in order respective to activation list.
-                /// \param[in]  activation_beta   The vector of beta parameters for activation functions
-                ///                               in order respective to activation list.
-                /// \param[in]  clip              The value defining clipping range [-clip, clip] on
-                ///                               input of activation functions.
+                /// \param[in]  X                     The input tensor with shape: [batch_size,
+                ///                                   input_size].
+                /// \param[in]  W                     The weight tensor with shape:
+                ///                                   [gates_count * hidden_size, input_size].
+                /// \param[in]  R                     The recurrence weight tensor with shape:
+                ///                                   [gates_count * hidden_size, hidden_size].
+                /// \param[in]  initial_hidden_state  The hidden state tensor at current time step
+                ///                                   with shape: [batch_size, hidden_size].
+                /// \param[in]  hidden_size           The number of hidden units for recurrent cell.
+                /// \param[in]  activations           The vector of activation functions used inside
+                ///                                   recurrent cell.
+                /// \param[in]  activations_alpha     The vector of alpha parameters for activation
+                ///                                   functions in order respective to activation
+                ///                                   list.
+                /// \param[in]  activations_beta      The vector of beta parameters for activation
+                ///                                   functions in order respective to activation
+                ///                                   list.
+                /// \param[in]  clip                  The value defining clipping range [-clip, clip]
+                ///                                   on input of activation functions.
                 ///
                 GRUCell(const Output<Node>& X,
                         const Output<Node>& W,
                         const Output<Node>& R,
-                        const Output<Node>& H_t,
+                        const Output<Node>& initial_hidden_state,
                         std::size_t hidden_size,
                         const std::vector<std::string>& activations,
-                        const std::vector<float>& activation_alpha,
-                        const std::vector<float>& activation_beta,
+                        const std::vector<float>& activations_alpha,
+                        const std::vector<float>& activations_beta,
                         float clip,
                         bool linear_before_reset);

                 ///
                 /// \brief      Constructs GRUCell node.
                 ///
-                /// \param[in]  X                 The input tensor with shape: [batch_size, input_size].
-                /// \param[in]  W                 The weight tensor with shape:
-                ///                               [gates_count * hidden_size, input_size].
-                /// \param[in]  R                 The recurrence weight tensor with shape:
-                ///                               [gates_count * hidden_size, hidden_size].
-                /// \param[in]  H_t               The hidden state tensor at current time step with
-                ///                               shape: [batch_size, hidden_size].
-                /// \param[in]  hidden_size       The number of hidden units for recurrent cell.
-                /// \param[in]  B                 The bias tensor for input gate with shape:
-                ///                               [2 * gates_count * hidden_size].
-                /// \param[in]  activations       The vector of activation functions used inside
-                ///                               recurrent cell.
-                /// \param[in]  activation_alpha  The vector of alpha parameters for activation
-                ///                               functions in order respective to activation list.
-                /// \param[in]  activation_beta   The vector of beta parameters for activation functions
-                ///                               in order respective to activation list.
-                /// \param[in]  clip              The value defining clipping range [-clip, clip] on
-                ///                               input of activation functions.
+                /// \param[in]  X                     The input tensor with shape: [batch_size,
+                ///                                   input_size].
+                /// \param[in]  W                     The weight tensor with shape: [gates_count *
+                ///                                   hidden_size, input_size].
+                /// \param[in]  R                     The recurrence weight tensor with shape:
+                ///                                   [gates_count * hidden_size, hidden_size].
+                /// \param[in]  initial_hidden_state  The hidden state tensor at current time step
+                ///                                   with shape: [batch_size, hidden_size].
+                /// \param[in]  hidden_size           The number of hidden units for recurrent cell.
+                /// \param[in]  B                     The bias tensor for input gate with shape:
+                ///                                   [2 * gates_count * hidden_size].
+                /// \param[in]  activations           The vector of activation functions used inside
+                ///                                   recurrent cell.
+                /// \param[in]  activations_alpha     The vector of alpha parameters for activation
+                ///                                   functions in order respective to activation
+                ///                                   list.
+                /// \param[in]  activations_beta      The vector of beta parameters for activation
+                ///                                   functions in order respective to activation
+                ///                                   list.
+                /// \param[in]  clip                  The value defining clipping range [-clip, clip]
+                ///                                   on input of activation functions.
+                /// \param[in]  linear_before_reset   Whether or not to apply the linear
+                ///                                   transformation before multiplying by the
+                ///                                   output of the reset gate.
                 ///
                 GRUCell(const Output<Node>& X,
                         const Output<Node>& W,
                         const Output<Node>& R,
-                        const Output<Node>& H_t,
+                        const Output<Node>& initial_hidden_state,
                         std::size_t hidden_size,
                         const Output<Node>& B,
                         const std::vector<std::string>& activations =
                             std::vector<std::string>{"sigmoid", "tanh"},
-                        const std::vector<float>& activation_alpha = {},
-                        const std::vector<float>& activation_beta = {},
+                        const std::vector<float>& activations_alpha = {},
+                        const std::vector<float>& activations_beta = {},
                         float clip = 0.f,
                         bool linear_before_reset = false);
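Note that GRUCell keeps its positional order (X, W, R, state, hidden_size, B) in this
commit; only the state parameter is renamed and the alpha/beta vectors pluralized. A
hedged sketch of the defaulted constructor declared above (shapes illustrative;
gates_count for GRU is 3):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/gru_cell.hpp"

    using namespace ngraph;

    int main()
    {
        const size_t batch_size = 2, input_size = 3, hidden_size = 3, gates_count = 3;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto W = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, hidden_size});
        auto H = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto B = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{2 * gates_count * hidden_size});

        // activations default to {"sigmoid", "tanh"}; activations_alpha/_beta default to {}.
        auto gru = std::make_shared<op::GRUCell>(X, W, R, H, hidden_size, B);
        return gru->output(0).get_shape() == (Shape{batch_size, hidden_size}) ? 0 : 1;
    }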
src/ngraph/op/fused/lstm_cell.cpp

(This diff is collapsed.)
src/ngraph/op/fused/lstm_cell.hpp

(This diff is collapsed.)
src/ngraph/op/fused/lstm_sequence.cpp

@@ -58,21 +58,47 @@ NodeVector op::LSTMSequence::decompose_op() const
 shared_ptr<Node> op::LSTMSequence::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<LSTMSequence>(new_args.at(0), // X
-                                     new_args.at(1), // initial_hidden_state
-                                     new_args.at(2), // initial_cell_state
-                                     new_args.at(3), // sequence_lengths
-                                     new_args.at(4), // W
-                                     new_args.at(5), // R
-                                     new_args.at(6), // B
-                                     new_args.at(7), // P
-                                     m_hidden_size,
-                                     m_direction,
-                                     m_activations_alpha,
-                                     m_activations_beta,
-                                     m_activations,
-                                     m_clip_threshold,
-                                     m_input_forget);
+    if (new_args.size() == 8)
+    {
+        return make_shared<LSTMSequence>(new_args.at(0), // X
+                                         new_args.at(1), // initial_hidden_state
+                                         new_args.at(2), // initial_cell_state
+                                         new_args.at(3), // sequence_lengths
+                                         new_args.at(4), // W
+                                         new_args.at(5), // R
+                                         new_args.at(6), // B
+                                         new_args.at(7), // P
+                                         m_hidden_size,
+                                         m_direction,
+                                         m_weights_format,
+                                         m_activations_alpha,
+                                         m_activations_beta,
+                                         m_activations,
+                                         m_clip_threshold,
+                                         m_input_forget);
+    }
+    else if (new_args.size() == 7)
+    {
+        return make_shared<LSTMSequence>(new_args.at(0), // X
+                                         new_args.at(1), // initial_hidden_state
+                                         new_args.at(2), // initial_cell_state
+                                         new_args.at(3), // sequence_lengths
+                                         new_args.at(4), // W
+                                         new_args.at(5), // R
+                                         new_args.at(6), // B
+                                         m_hidden_size,
+                                         m_direction,
+                                         m_weights_format,
+                                         m_activations_alpha,
+                                         m_activations_beta,
+                                         m_activations,
+                                         m_clip_threshold,
+                                         m_input_forget);
+    }
+    else
+    {
+        throw ngraph_error("Incorrect number of new arguments");
+    }
 }

 shared_ptr<Node> op::LSTMSequence::get_masked_node(const shared_ptr<Node>& data,

@@ -157,13 +183,14 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
     for (const auto& in_x : in_seqs)
     {
         shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
-                                                               W,
-                                                               R,
                                                                H_t,
                                                                C_t,
-                                                               m_hidden_size,
+                                                               W,
+                                                               R,
                                                                B,
                                                                P,
+                                                               m_hidden_size,
+                                                               m_weights_format,
                                                                m_activations,
                                                                m_activations_alpha,
                                                                m_activations_beta,
src/ngraph/op/fused/lstm_sequence.hpp

@@ -24,6 +24,7 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op/constant.hpp"
+#include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/util/fused_op.hpp"

 namespace ngraph

@@ -36,6 +37,9 @@ namespace ngraph
         /// \note       It follows notation and equations defined as in ONNX standard:
         ///             https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
         ///
+        /// \sa         LSTMCell, RNNCell, GRUCell
+        ///
+        ///
         class LSTMSequence : public util::FusedOp
         {
         public:

@@ -61,6 +65,7 @@ namespace ngraph
                          const Output<Node>& P,
                          const std::int64_t hidden_size,
                          const direction lstm_direction,
+                         LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                          const std::vector<float> activations_alpha = {},
                          const std::vector<float> activations_beta = {},
                          const std::vector<std::string> activations = {"sigmoid",

@@ -77,6 +82,7 @@ namespace ngraph
                 , m_direction(lstm_direction)
                 , m_hidden_size(hidden_size)
                 , m_input_forget(input_forget)
+                , m_weights_format(weights_format)
             {
                 constructor_validate_and_infer_types();
             }

@@ -90,6 +96,7 @@ namespace ngraph
                          const Output<Node>& B,
                          const std::int64_t hidden_size,
                          const direction lstm_direction,
+                         LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                          const std::vector<float> activations_alpha = {},
                          const std::vector<float> activations_beta = {},
                          const std::vector<std::string> activations = {"sigmoid",

@@ -111,6 +118,7 @@ namespace ngraph
                                        std::vector<float>{0.f}),
                       hidden_size,
                       lstm_direction,
+                      weights_format,
                       activations_alpha,
                       activations_beta,
                       activations,

@@ -131,6 +139,7 @@ namespace ngraph
             direction get_direction() const { return m_direction; }
             std::int64_t get_hidden_size() const { return m_hidden_size; }
             bool get_input_forget() const { return m_input_forget; }
+            LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
         private:
             ///
             /// \brief      Gets the masked value according to sequence lenght in a batch.

@@ -163,6 +172,7 @@ namespace ngraph
             direction m_direction;
             std::int64_t m_hidden_size;
             bool m_input_forget;
+            LSTMWeightsFormat m_weights_format;
         };
     } // namespace op
 } // namespace ngraph
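Since the new weights_format parameter is defaulted to IFCO, pre-existing call sites
keep compiling unchanged; only producers with differently ordered gates (such as the
ONNX importer, which passes IOFC) need to name a format. A hedged sketch of a call
site that does, with all graph inputs assumed already built:

    #include <cstdint>
    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/lstm_cell.hpp"      // LSTMWeightsFormat
    #include "ngraph/op/fused/lstm_sequence.hpp"

    using namespace ngraph;

    std::shared_ptr<op::LSTMSequence> make_onnx_ordered_sequence(
        const Output<Node>& X,
        const Output<Node>& initial_hidden_state,
        const Output<Node>& initial_cell_state,
        const Output<Node>& sequence_lengths,
        const Output<Node>& W,
        const Output<Node>& R,
        const Output<Node>& B,
        std::int64_t hidden_size)
    {
        // Weights arrive in ONNX's [i, o, f, c] order, so state it explicitly.
        return std::make_shared<op::LSTMSequence>(X,
                                                  initial_hidden_state,
                                                  initial_cell_state,
                                                  sequence_lengths,
                                                  W,
                                                  R,
                                                  B,
                                                  hidden_size,
                                                  op::LSTMSequence::direction::FORWARD,
                                                  op::LSTMWeightsFormat::IOFC);
    }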
src/ngraph/op/fused/rnn_cell.cpp

@@ -32,44 +32,34 @@ using namespace ngraph;
 constexpr NodeTypeInfo op::RNNCell::type_info;

 op::RNNCell::RNNCell(const Output<Node>& X,
+                     const Output<Node>& initial_hidden_state,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
                      size_t hidden_size)
-    : RNNCell(X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
+    : RNNCell(X, initial_hidden_state, W, R, hidden_size, vector<string>{"tanh"},
+              vector<float>{}, vector<float>{}, 0.f)
 {
 }

 op::RNNCell::RNNCell(const Output<Node>& X,
+                     const Output<Node>& initial_hidden_state,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
                      size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip)
-    : FusedOp({X, W, R, H_t})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, initial_hidden_state, W, R})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
 {
-    add_default_bias_input();
+    set_argument(4, get_default_bias_input());
     constructor_validate_and_infer_types();
 }

 op::RNNCell::RNNCell(const Output<Node>& X,
+                     const Output<Node>& initial_hidden_state,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
-                     size_t hidden_size,
                      const Output<Node>& B,
+                     size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip)
-    : FusedOp({X, W, R, H_t, B})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, initial_hidden_state, W, R, B})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
 {
     constructor_validate_and_infer_types();

@@ -83,9 +73,9 @@ void op::RNNCell::pre_validate_and_infer_types()
     }

     const auto& x_pshape = get_input_partial_shape(0);
-    const auto& w_pshape = get_input_partial_shape(1);
-    const auto& r_pshape = get_input_partial_shape(2);
-    const auto& ht_pshape = get_input_partial_shape(3);
+    const auto& ht_pshape = get_input_partial_shape(1);
+    const auto& w_pshape = get_input_partial_shape(2);
+    const auto& r_pshape = get_input_partial_shape(3);

     NODE_VALIDATION_CHECK(this,
                           (x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||

@@ -121,7 +111,7 @@ void op::RNNCell::pre_validate_and_infer_types()
                           ".");
     NODE_VALIDATION_CHECK(this,
                           (ht_shape == Shape{batch_size, get_hidden_size()}),
-                          "Input tensor H_t must have shape (",
+                          "Input tensor initial_hidden_state must have shape (",
                           batch_size,
                           ", ",
                           get_hidden_size(),

@@ -137,9 +127,9 @@ void op::RNNCell::pre_validate_and_infer_types()
         const Shape& b_shape{b_pshape.to_shape()};
         NODE_VALIDATION_CHECK(this,
-                              (b_shape == Shape{2 * get_hidden_size()}),
+                              (b_shape == Shape{get_hidden_size()}),
                               "Input tensor B must have shape (",
-                              2 * get_hidden_size(),
+                              get_hidden_size(),
                               "). Actual shape is:",
                               b_shape,
                               ".");

@@ -157,8 +147,7 @@ NodeVector op::RNNCell::decompose_op() const
     // W   - The weight tensor for input gate. Shape: [hidden_size, input_size].
     // R   - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
     // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
-    // B   - The bias tensor for the input gate. Shape: [2 * hidden_size].
-    //       Concatenation of `[Wb, Rb]`.
+    // B   - The bias tensor for the input gate. Shape: [hidden_size].
     // Wb  - W bias vectors for input gate.
     // Rb  - R bias vectors for input gate.
     // ------ VARIABLE NAMES ------

@@ -174,10 +163,10 @@ NodeVector op::RNNCell::decompose_op() const
     // --------------------
     Output<Node> X = input_value(0);
-    Output<Node> W = input_value(1);
-    Output<Node> R = input_value(2);
-    Output<Node> H_t = input_value(3);
-    Output<Node> bias = get_bias();
+    Output<Node> H_t = input_value(1);
+    Output<Node> W = input_value(2);
+    Output<Node> R = input_value(3);
+    Output<Node> bias = input_value(4);

     // Xt*(W^T)
     auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));

@@ -192,22 +181,12 @@ NodeVector op::RNNCell::decompose_op() const
     return {i_t};
 }

-Output<Node> op::RNNCell::get_bias() const
-{
-    Output<Node> bias;
-    // Split B onto Wb an Rb and add them.
-    NodeVector b_W_R = builder::split(input_value(4), 2);
-    bias = b_W_R.at(0) + b_W_R.at(1);
-    return bias;
-}
-
-void op::RNNCell::add_default_bias_input()
-{
-    Output<Node> B =
-        op::Constant::create(input(0).get_element_type(),
-                             Shape{2 * s_gates_count * get_hidden_size()},
-                             vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
-    set_argument(4, B);
-}
+Output<Node> op::RNNCell::get_default_bias_input() const
+{
+    return Output<Node>{
+        op::Constant::create(input(0).get_element_type(),
+                             Shape{s_gates_count * get_hidden_size()},
+                             vector<float>(s_gates_count * get_hidden_size(), 0.f))};
+}

 shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const

@@ -221,8 +200,8 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
                                     new_args.at(3),
                                     get_hidden_size(),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip());
     }
     else if (new_args.size() == 5)

@@ -231,11 +210,11 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
-                                    get_hidden_size(),
                                     new_args.at(4),
+                                    get_hidden_size(),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip());
     }
     else
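With get_bias() gone, decompose_op() always reads the bias from input 4: when the user
supplies no B, the bias-less constructors attach get_default_bias_input() — a zero
constant of Shape{s_gates_count * hidden_size}, and s_gates_count is 1 for a vanilla
RNN. A hedged sketch showing that omitting B is equivalent to passing explicit zeros
in the new single-vector shape:

    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/rnn_cell.hpp"

    using namespace ngraph;

    int main()
    {
        const size_t batch_size = 2, input_size = 3, hidden_size = 3;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto W = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});

        // No bias given: set_argument(4, get_default_bias_input()) supplies zeros.
        auto implicit = std::make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);

        // Explicit zero bias in the new shape [hidden_size] (no longer [2 * hidden_size]).
        auto B = op::Constant::create(
            element::f32, Shape{hidden_size}, std::vector<float>(hidden_size, 0.f));
        auto with_bias = std::make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);

        return implicit->output(0).get_shape() == with_bias->output(0).get_shape() ? 0 : 1;
    }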
src/ngraph/op/fused/rnn_cell.hpp

(This diff is collapsed.)
src/ngraph/op/util/rnn_cell_base.cpp

@@ -39,13 +39,13 @@ static vector<string> to_lower_case(const vector<string>& vs)
 op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
                                    float clip,
                                    const vector<string>& activations,
-                                   const vector<float>& activation_alpha,
-                                   const vector<float>& activation_beta)
+                                   const vector<float>& activations_alpha,
+                                   const vector<float>& activations_beta)
     : m_hidden_size(hidden_size)
     , m_clip(clip)
     , m_activations(to_lower_case(activations))
-    , m_activation_alpha(activation_alpha)
-    , m_activation_beta(activation_beta)
+    , m_activations_alpha(activations_alpha)
+    , m_activations_beta(activations_beta)
 {
 }

@@ -54,13 +54,13 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
     op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));

     // Set activation functions parameters (if any)
-    if (m_activation_alpha.size() > idx)
+    if (m_activations_alpha.size() > idx)
     {
-        afunc.set_alpha(m_activation_alpha.at(idx));
+        afunc.set_alpha(m_activations_alpha.at(idx));
     }
-    if (m_activation_beta.size() > idx)
+    if (m_activations_beta.size() > idx)
     {
-        afunc.set_beta(m_activation_beta.at(idx));
+        afunc.set_beta(m_activations_beta.at(idx));
     }
     return afunc;
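The renamed vectors pair up positionally with the activation list:
get_activation_function(i) reads the name from activations[i] and, when present, its
alpha/beta from activations_alpha[i]/activations_beta[i]. A hedged sketch of a caller
supplying such per-activation parameters through GRUCell's full constructor (the
numeric values are purely illustrative):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/gru_cell.hpp"

    using namespace ngraph;

    std::shared_ptr<op::GRUCell> make_parameterized_gru(const Output<Node>& X,
                                                        const Output<Node>& W,
                                                        const Output<Node>& R,
                                                        const Output<Node>& initial_hidden_state,
                                                        size_t hidden_size)
    {
        std::vector<std::string> activations{"sigmoid", "tanh"};
        std::vector<float> activations_alpha{1.0f, 0.0f}; // index 0 -> "sigmoid", 1 -> "tanh"
        std::vector<float> activations_beta{};            // no beta parameters supplied
        float clip = 0.f;
        bool linear_before_reset = false;
        return std::make_shared<op::GRUCell>(X, W, R, initial_hidden_state, hidden_size,
                                             activations, activations_alpha,
                                             activations_beta, clip, linear_before_reset);
    }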
src/ngraph/op/util/rnn_cell_base.hpp

@@ -40,30 +40,34 @@ namespace ngraph
                 ///
                 /// \brief      Constructs a RNNCellBase class.
                 ///
-                /// \param[in]  hidden_size  The number of hidden units for recurrent cell.
-                /// \param[in]  clip         The value defining clipping range [-clip, clip] on
-                ///                          input of activation functions.
-                /// \param[in]  activations  The vector of activation functions used inside
-                ///                          recurrent cell.
-                /// \param[in]  activation_alpha  The vector of alpha parameters for activation
-                ///                               functions in order respective to activation list.
-                /// \param[in]  activation_beta   The vector of beta parameters for activation
-                ///                               functions in order respective to activation list.
+                /// \param[in]  hidden_size        The number of hidden units for recurrent cell.
+                /// \param[in]  clip               The value defining clipping range [-clip, clip]
+                ///                                on input of activation functions.
+                /// \param[in]  activations        The vector of activation functions used inside
+                ///                                recurrent cell.
+                /// \param[in]  activations_alpha  The vector of alpha parameters for activation
+                ///                                functions in order respective to activation list.
+                /// \param[in]  activations_beta   The vector of beta parameters for activation
+                ///                                functions in order respective to activation list.
                 ///
                 RNNCellBase(std::size_t hidden_size,
                             float clip,
                             const std::vector<std::string>& activations,
-                            const std::vector<float>& activation_alpha,
-                            const std::vector<float>& activation_beta);
+                            const std::vector<float>& activations_alpha,
+                            const std::vector<float>& activations_beta);

                 std::size_t get_hidden_size() const { return m_hidden_size; }
                 float get_clip() const { return m_clip; }
                 const std::vector<std::string>& get_activations() const { return m_activations; }
-                const std::vector<float>& get_activation_alpha() const
+                const std::vector<float>& get_activations_alpha() const
                 {
-                    return m_activation_alpha;
+                    return m_activations_alpha;
                 }
-                const std::vector<float>& get_activation_beta() const { return m_activation_beta; }
+                const std::vector<float>& get_activations_beta() const
+                {
+                    return m_activations_beta;
+                }
             protected:
                 ///
                 /// \brief      Constructs activation function object.

@@ -117,9 +121,9 @@ namespace ngraph
                 const std::size_t m_hidden_size;
                 const float m_clip;
                 const std::vector<std::string> m_activations;
-                const std::vector<float> m_activation_alpha;
-                const std::vector<float> m_activation_beta;
+                const std::vector<float> m_activations_alpha;
+                const std::vector<float> m_activations_beta;
             };
-        }
-    }
-}
+        } // namespace util
+    } // namespace op
+} // namespace ngraph
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp

@@ -77,8 +77,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
         element::f32, Shape{ref_gates_count * ref_hidden_size, ref_input_size});
     auto R = std::make_shared<pattern::op::Label>(
         element::f32, Shape{ref_gates_count * ref_hidden_size, ref_hidden_size});
-    auto bias_ref = std::make_shared<pattern::op::Label>(
-        element::f32, Shape{2 * ref_gates_count * ref_hidden_size});
+    auto B = std::make_shared<pattern::op::Label>(
+        element::f32, Shape{ref_gates_count * ref_hidden_size});
+    auto peep_hole = std::make_shared<pattern::op::Label>(element::f32, Shape{3 * ref_hidden_size});
     auto H_t = std::make_shared<pattern::op::Label>(
         element::f32, Shape{ref_batch_size, ref_hidden_size});

@@ -87,13 +87,14 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
     auto ref_lstm_cell = std::make_shared<op::LSTMCell>(X,
-                                                        W,
-                                                        R,
                                                         H_t,
                                                         C_t,
-                                                        ref_hidden_size,
-                                                        bias_ref,
+                                                        W,
+                                                        R,
+                                                        B,
+                                                        peep_hole,
+                                                        ref_hidden_size,
+                                                        op::LSTMWeightsFormat::IOFC,
                                                         std::vector<std::string>{"sigmoid", "tanh", "tanh"},
                                                         std::vector<float>{},
                                                         std::vector<float>{},

@@ -101,72 +102,27 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
                                                         false);

     auto callback = [X, W, R, H_t, C_t](pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
         ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
             ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
-        auto target_lstm_node = m.get_match_root();
+        auto lstmcell_op = as_type_ptr<op::LSTMCell>(m.get_match_root());
         auto src_iter = std::make_shared<ngraph::op::Concat>(
             NodeVector{pattern_map[H_t], pattern_map[C_t]}, 0);
-        auto bias_iofc = target_lstm_node->get_argument(5);
-
-        // we need to reorder W, R and bias from IOFC to IFCO gate order
-        // Note: ONNX runtime provides W, R and bias in the gate order [IOFC] but
-        // MKLDNN computes LSTM kernel in the [IFCO] order.
-        auto get_weights_ifco_gate_order =
-            [&](std::shared_ptr<Node> weights_graph_node) -> std::shared_ptr<Node> {
-            // slices will be in ICFO order
-            std::vector<std::shared_ptr<Node>> gate_slices;
-            size_t dim0 = weights_graph_node->get_shape()[0] / 4;
-            size_t dim1 = weights_graph_node->get_shape()[1];
-            for (size_t i = 0; i < 4; i++)
-            {
-                auto slice = std::make_shared<ngraph::op::Slice>(
-                    weights_graph_node, Coordinate{i * dim0, 0}, Coordinate{(i + 1) * dim0, dim1});
-                gate_slices.push_back(slice);
-            }
-            auto weights_ifco = std::make_shared<ngraph::op::Concat>(
-                NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
-            return std::move(weights_ifco);
-        };
-
-        auto get_bias_ifco_gate_order =
-            [&](std::shared_ptr<Node> bias_graph_node) -> std::shared_ptr<Node> {
-            size_t hidden_size = lstmcell_op->get_hidden_size();
-            auto Wb_bias = std::make_shared<ngraph::op::Slice>(
-                bias_graph_node, Coordinate{0}, Coordinate{4 * hidden_size});
-            auto Rb_bias = std::make_shared<ngraph::op::Slice>(
-                bias_graph_node, Coordinate{4 * hidden_size}, Coordinate{2 * 4 * hidden_size});
-            auto bias = std::make_shared<op::Add>(Wb_bias, Rb_bias);
-            // slices will be in ICFO order
-            std::vector<std::shared_ptr<Node>> gate_slices;
-            for (size_t i = 0; i < 4; i++)
-            {
-                auto slice = std::make_shared<ngraph::op::Slice>(
-                    bias, Coordinate{i * hidden_size}, Coordinate{(i + 1) * hidden_size});
-                gate_slices.push_back(slice);
-            }
-            auto new_bias = std::make_shared<ngraph::op::Concat>(
-                NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
-            return std::move(new_bias);
-        };
-
-        auto W_iofc = pattern_map[W];
-        auto R_iofc = pattern_map[R];
-        auto W_ifco = get_weights_ifco_gate_order(W_iofc);
-        auto R_ifco = get_weights_ifco_gate_order(R_iofc);
-        // here onnx bias will be of shape (2 * gates_count * hidden_size) bias of Wb and Rb are
-        // concatenated, we will split the bias, add and rearrange in order IFCO
-        auto bias_ifco = get_bias_ifco_gate_order(bias_iofc);
+        auto W_ifco = lstmcell_op->get_argument(3);
+        auto R_ifco = lstmcell_op->get_argument(4);
+        auto bias_ifco = lstmcell_op->get_argument(5);
+        // We need to reorder W, R and bias to IFCO gate order.
+        // Note: ie.: ONNX runtime provides W, R and bias in the gate order [IOFC] but
+        // MKLDNN computes LSTM kernel in the [IFCO] order.
+        if (lstmcell_op->get_weights_format() != op::LSTMWeightsFormat::IFCO)
+        {
+            W_ifco = lstmcell_op->convert_node_format(W_ifco);
+            R_ifco = lstmcell_op->convert_node_format(R_ifco);
+            bias_ifco = lstmcell_op->convert_node_format(bias_ifco);
+        }

         auto W_reshape = std::make_shared<op::Reshape>(
             W_ifco, AxisVector{1, 0}, Shape{W_ifco->get_shape()[1], W_ifco->get_shape()[0]});

@@ -595,7 +551,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
                      lstm_weights_layer_label,
                      lstm_weights_iter_label,
                      lstm_bias_label](pattern::RecurrentMatcher& m) {
         NGRAPH_DEBUG << " In recurrent RNN fusion callback";
         auto concat_rnn_inputs_across_timestep =

@@ -800,7 +755,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
                      lstm_weights_layer_label,
                      lstm_weights_iter_label,
                      lstm_bias_label](pattern::RecurrentMatcher& m) {
         NGRAPH_DEBUG << " In recurrent RNN fusion callback";
         auto concat_rnn_inputs_across_timestep =

@@ -1161,7 +1115,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     // Replace all the users of RNN cell state {ct} across different user.
     auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe1, size_t layer) {
         // multi layerd fused rnn second output {GOE1} holds the recurrent output state tensors
        // for the last cell of all the layers, {{ht_1 | ct_1} || {ht2 |ct2} || ....{htn | ctn}}
         // we will slice the cell state output tensor {ct_*} from the fused RNN kerenel output

@@ -1211,7 +1164,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     // Replace all the users of RNN cell state {ct} across different user.
     auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe2, size_t layer) {
         // multi layerd fused rnn second output {GOE2} holds the recurrent output state tensors
         // for the last cell
         // of all the layers, { ct_1 || ct2 || ....|| ctn}

@@ -1302,7 +1254,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
     // Define a call back that needs to called once the DFG matches the pattern
     auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
         auto rnn_ltor_node =
             std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);

@@ -1351,7 +1302,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
             ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
     auto construct_birnn_inputs = [&](int index) {
         auto nodes = NodeVector{rnn_ltor_node->get_argument(index),
                                 rnn_rtol_node->get_argument(index)};
         return std::make_shared<ngraph::op::Concat>(nodes, 0);
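The gate reorder the fusion pass performs (previously by hand, now via
convert_node_format()) is a fixed block permutation: weights arriving in ONNX's
[i, o, f, c] (IOFC) row-block order are re-concatenated as [i, f, c, o] (IFCO) for
MKLDNN, i.e. blocks are picked in the order {0, 2, 3, 1}. A standalone illustration
of just that permutation on a flat buffer:

    #include <array>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const std::size_t hidden_size = 2;
        // Four gate blocks of hidden_size rows; values tag their gate: i=1, o=2, f=3, c=4.
        std::vector<float> iofc{1, 1, 2, 2, 3, 3, 4, 4};

        const std::array<std::size_t, 4> perm{0, 2, 3, 1}; // pick i, then f, then c, then o
        std::vector<float> ifco;
        for (std::size_t g : perm)
        {
            ifco.insert(ifco.end(),
                        iofc.begin() + g * hidden_size,
                        iofc.begin() + (g + 1) * hidden_size);
        }

        for (float v : ifco)
            std::printf("%.0f ", v); // prints: 1 1 3 3 4 4 2 2  (i, f, c, o)
        return 0;
    }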
src/ngraph/serializer.cpp

@@ -433,6 +433,13 @@ static element::Type read_element_type(json j)
     return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string);
 }

+static op::LSTMWeightsFormat read_lstm_weights_format(const json& js)
+{
+    return has_key(js, "weights_format")
+               ? static_cast<op::LSTMWeightsFormat>(js.at("weights_format"))
+               : op::LSTMWeightsFormat::IFCO;
+}
+
 void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
 {
     ofstream out(path);

@@ -1828,24 +1835,60 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
         case OP_TYPEID::LSTMCell:
         {
             auto hidden_size = node_js.at("hidden_size").get<size_t>();
+            auto weights_format = read_lstm_weights_format(node_js);
             auto clip = node_js.at("clip").get<float>();
             auto activations = node_js.at("activations").get<vector<string>>();
-            auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
-            auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
+            auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
+            auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
             auto input_forget = node_js.at("input_forget").get<bool>();
-            node = make_shared<op::LSTMCell>(args[0],
-                                             args[1],
-                                             args[2],
-                                             args[3],
-                                             args[4],
-                                             hidden_size,
-                                             args[5],
-                                             args[6],
-                                             activations,
-                                             activation_alpha,
-                                             activation_beta,
-                                             clip,
-                                             input_forget);
+            if (args.size() == 7)
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 args[5],
+                                                 args[6],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
+            if (args.size() == 6)
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 args[5],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
+            else
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
             break;
         }
         case OP_TYPEID::LSTMSequence:

@@ -1857,6 +1900,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
             auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
             auto input_forget = node_js.at("input_forget").get<bool>();
             auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
+            auto weights_format = read_lstm_weights_format(node_js);
             if (args.size() == 8)
             {
                 node = make_shared<op::LSTMSequence>(args[0],

@@ -1869,6 +1913,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                                      args[7],
                                                      hidden_size,
                                                      direction,
+                                                     weights_format,
                                                      activations_alpha,
                                                      activations_beta,
                                                      activations,

@@ -1886,6 +1931,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                                      args[6],
                                                      hidden_size,
                                                      direction,
+                                                     weights_format,
                                                      activations_alpha,
                                                      activations_beta,
                                                      activations,

@@ -2393,8 +2439,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                             args[1],
                                             args[2],
                                             args[3],
-                                            hidden_size,
                                             args[4],
+                                            hidden_size,
                                             activations,
                                             activation_alpha,
                                             activation_beta,

@@ -3418,8 +3464,8 @@ json JSONSerializer::serialize_node(const Node& n)
             node["hidden_size"] = tmp->get_hidden_size();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             node["linear_before_reset"] = tmp->get_linear_before_reset();
             break;
         }

@@ -3552,10 +3598,11 @@ json JSONSerializer::serialize_node(const Node& n)
         {
             auto tmp = static_cast<const op::LSTMCell*>(&n);
             node["hidden_size"] = tmp->get_hidden_size();
+            node["weights_format"] = tmp->get_weights_format();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             node["input_forget"] = tmp->get_input_forget();
             break;
         }

@@ -3564,6 +3611,7 @@ json JSONSerializer::serialize_node(const Node& n)
             auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
             node["direction"] = tmp->get_direction();
             node["hidden_size"] = tmp->get_hidden_size();
+            node["weights_format"] = tmp->get_weights_format();
             node["clip_threshold"] = tmp->get_clip_threshold();
             node["activations"] = tmp->get_activations();
             node["activations_alpha"] = tmp->get_activations_alpha();

@@ -3936,8 +3984,8 @@ json JSONSerializer::serialize_node(const Node& n)
             node["hidden_size"] = tmp->get_hidden_size();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             break;
         }
         case OP_TYPEID::ScalarConstantLike:
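read_lstm_weights_format() is what keeps old serialized graphs loading: nodes written
before this commit carry no "weights_format" key, so the reader falls back to IFCO. A
hedged, standalone sketch of that fallback (ngraph's has_key() helper is approximated
with nlohmann::json::count(), and the enum here is a stand-in, not ngraph's
definition):

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    enum class LSTMWeightsFormat { IFCO = 0, IOFC = 1 }; // illustrative subset

    static LSTMWeightsFormat read_lstm_weights_format(const json& js)
    {
        return js.count("weights_format")
                   ? static_cast<LSTMWeightsFormat>(js.at("weights_format").get<int>())
                   : LSTMWeightsFormat::IFCO;
    }

    int main()
    {
        json old_node = {{"hidden_size", 3}};                      // pre-commit: no key
        json new_node = {{"hidden_size", 3}, {"weights_format", 1}};
        return (read_lstm_weights_format(old_node) == LSTMWeightsFormat::IFCO &&
                read_lstm_weights_format(new_node) == LSTMWeightsFormat::IOFC)
                   ? 0
                   : 1;
    }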
test/backend/fused_op.in.cpp

(This diff is collapsed.)
test/cpu_fusion.cpp

@@ -3967,12 +3967,14 @@ TEST(cpu_fusion, lstm_cell)
         const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
         const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         auto ht = make_shared<op::GetOutputElement>(lstm_cell, 0);
         auto ct = make_shared<op::GetOutputElement>(lstm_cell, 1);
-        auto lstm_function = make_shared<Function>(NodeVector{ht, ct},
-                                                   ParameterVector{X, W, R, H_t, C_t});
+        auto lstm_function = make_shared<Function>(NodeVector{ht, ct},
+                                                   ParameterVector{X, H_t, C_t, W, R});
         return lstm_function;
     };
     auto lstm_function_cpu = make_function();
test/serialize.cpp

@@ -531,10 +531,10 @@ TEST(serialize, tensor_iterator_lstm)
     auto R_body = make_shared<op::Parameter>(element::f32, Shape{4 * H, H});
     auto LSTM_cell =
         make_shared<op::LSTMCell>(make_shared<op::Reshape>(X, AxisVector{0, 1, 2}, Shape{N, I}),
-                                  W_body,
-                                  R_body,
                                   make_shared<op::Reshape>(H_t, AxisVector{0, 1, 2}, Shape{N, H}),
                                   make_shared<op::Reshape>(C_t, AxisVector{0, 1, 2}, Shape{N, H}),
+                                  W_body,
+                                  R_body,
                                   H);
     auto H_o = make_shared<op::Reshape>(LSTM_cell->output(0), AxisVector{0, 1}, Shape{N, 1, H});
     auto C_o = make_shared<op::Reshape>(LSTM_cell->output(1), AxisVector{0, 1}, Shape{N, 1, H});
test/type_prop/gru_cell.cpp

@@ -87,7 +87,8 @@ TEST(type_prop, gru_cell_invalid_input)
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }

     // Invalid B tensor shape.
test/type_prop/lstm_cell.cpp

@@ -36,7 +36,16 @@ TEST(type_prop, lstm_cell)
     const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
     const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

-    const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+    const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
+    EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size);
+    EXPECT_EQ(lstm_cell->get_clip(), 0.f);
+    EXPECT_TRUE(lstm_cell->get_activations_alpha().empty());
+    EXPECT_TRUE(lstm_cell->get_activations_beta().empty());
+    EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid");
+    EXPECT_EQ(lstm_cell->get_activations()[1], "tanh");
+    EXPECT_EQ(lstm_cell->get_activations()[2], "tanh");
+    EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO);
+    EXPECT_FALSE(lstm_cell->get_input_forget());
     EXPECT_EQ(lstm_cell->output(0).get_element_type(), element::f32);
     EXPECT_EQ(lstm_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
     EXPECT_EQ(lstm_cell->output(1).get_element_type(), element::f32);

@@ -60,7 +69,7 @@ TEST(type_prop, lstm_cell_invalid_input)
     auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)

@@ -73,7 +82,7 @@ TEST(type_prop, lstm_cell_invalid_input)
     R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)

@@ -86,12 +95,13 @@ TEST(type_prop, lstm_cell_invalid_input)
     H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }

     // Invalid C_t tensor shape.

@@ -99,21 +109,22 @@ TEST(type_prop, lstm_cell_invalid_input)
     C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor C_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_cell_state must have shape"));
     }

     // Invalid B tensor shape.
     C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
     auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)

@@ -122,11 +133,11 @@ TEST(type_prop, lstm_cell_invalid_input)
     }

     // Invalid P tensor shape.
-    B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
+    B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
     P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
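For reference, a hedged sketch of the reordered LSTMCell constructor these tests
exercise: the states (H_t, C_t) now precede the weights (W, R), and gates_count for
LSTM is 4 (shapes illustrative):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/op/fused/lstm_cell.hpp"

    using namespace ngraph;

    int main()
    {
        const size_t batch_size = 2, input_size = 3, hidden_size = 3, gates_count = 4;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
        auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto C_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
        auto W = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32,
                                                 Shape{gates_count * hidden_size, hidden_size});

        auto lstm = std::make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
        // Two outputs, Ht and Ct, both [batch_size, hidden_size]; weights default to IFCO.
        return lstm->output(1).get_shape() == (Shape{batch_size, hidden_size}) ? 0 : 1;
    }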
test/type_prop/lstm_sequence.cpp

@@ -28,7 +28,7 @@ TEST(type_prop, lstm_sequence)
     const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
     const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
     const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 24});
+    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
     const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});

     const auto hidden_size = 3;

@@ -41,6 +41,20 @@ TEST(type_prop, lstm_sequence)
                                                           B,
                                                           hidden_size,
                                                           op::LSTMSequence::direction::FORWARD);
+    EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
+    EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
+    EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
+    EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
+    EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
+    EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
+    EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
+    EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
+    EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
+    EXPECT_FALSE(lstm_sequence->get_input_forget());
     EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
     EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 1, 2, 3}));
     EXPECT_EQ(lstm_sequence->output(1).get_element_type(), element::f32);
     EXPECT_EQ(lstm_sequence->output(1).get_shape(), (Shape{1, 2, 3}));
     EXPECT_EQ(lstm_sequence->output(2).get_element_type(), element::f32);
     EXPECT_EQ(lstm_sequence->output(2).get_shape(), (Shape{1, 2, 3}));
 }
test/type_prop/rnn_cell.cpp

@@ -28,11 +28,11 @@ TEST(type_prop, rnn_cell)
     const size_t hidden_size = 3;

     const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
+    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
     const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
     const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
-    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

-    const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+    const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
     EXPECT_EQ(rnn_cell->output(0).get_element_type(), element::f32);
     EXPECT_EQ(rnn_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
 }

@@ -51,7 +51,7 @@ TEST(type_prop, rnn_cell_invalid_input)
     auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)

@@ -64,7 +64,7 @@ TEST(type_prop, rnn_cell_invalid_input)
     R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)

@@ -77,20 +77,21 @@ TEST(type_prop, rnn_cell_invalid_input)
     H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }

     // Invalid B tensor shape.
     H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size, B);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)