Commit 53a6af8d
authored Nov 08, 2019 by Adam Rogowiec, committed by Michał Karzyński, Nov 08, 2019
[SPEC] LSTMCell, RNNCell updates. (#3733)
parent bc0fd13f
Showing 23 changed files with 375 additions and 317 deletions (+375 -317)
python/ngraph/ops.py                             +21  -20
python/pyngraph/ops/fused/rnn_cell.cpp            +1   -1
python/test/ngraph/test_ops_fused.py             +10  -11
src/ngraph/frontend/onnx_import/op/lstm.cpp       +9   -4
src/ngraph/op/fused/gru_cell.cpp                 +17  -17
src/ngraph/op/fused/gru_cell.hpp                 +55  -49
src/ngraph/op/fused/lstm_cell.cpp                 +0   -0
src/ngraph/op/fused/lstm_cell.hpp                 +0   -0
src/ngraph/op/fused/lstm_sequence.cpp            +45  -18
src/ngraph/op/fused/lstm_sequence.hpp            +10   -0
src/ngraph/op/fused/rnn_cell.cpp                 +32  -53
src/ngraph/op/fused/rnn_cell.hpp                  +0   -0
src/ngraph/op/util/rnn_cell_base.cpp              +8   -8
src/ngraph/op/util/rnn_cell_base.hpp             +23  -19
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp   +19  -69
src/ngraph/serializer.cpp                        +70  -22
test/backend/fused_op.in.cpp                      +0   -0
test/cpu_fusion.cpp                               +5   -3
test/serialize.cpp                                +2   -2
test/type_prop/gru_cell.cpp                       +2   -1
test/type_prop/lstm_cell.cpp                     +22  -11
test/type_prop/lstm_sequence.cpp                 +15   -1
test/type_prop/rnn_cell.cpp                       +9   -8
python/ngraph/ops.py

...
@@ -242,11 +242,11 @@ def group_convolution(data_batch,  # type: Node
 @nameable_op
 def rnn_cell(X,                 # type: Node
+             H_t,               # type: Node
              W,                 # type: Node
              R,                 # type: Node
-             H_t,               # type: Node
-             hidden_size,       # type: int
              B,                 # type: Node
+             hidden_size,       # type: int
              activations,       # type: List[str]
              activation_alpha,  # type: List[float]
              activation_beta,   # type: List[float]
...
@@ -261,29 +261,30 @@ def rnn_cell(X,  # type: Node
     Note this class represents only single *cell* and not whole RNN *layer*.

     :param X: The input tensor with shape: [batch_size, input_size].
-    :param W: The weight tensor with shape: [hidden_size, input_size].
-    :param R: The recurrence weight tensor with shape: [hidden_size, hidden_size].
-    :param H_t: The hidden state tensor at current time step with shape:
-                [batch_size, hidden_size].
-    :param hidden_size: The number of hidden units for recurrent cell.
-    :param B: The bias tensor for input gate with shape: [2*hidden_size].
+    :param H_t: The hidden state tensor at current time step with
+                shape: [batch_size, hidden_size].
+    :param W: The weight tensor with shape: [hidden_size, input_size].
+    :param R: The recurrence weight tensor with shape: [hidden_size,
+              hidden_size].
+    :param B: The bias tensor for input gate with shape: [2*hidden_size].
+    :param hidden_size: The number of hidden units for recurrent cell.
     :param activations: The vector of activation functions used inside recurrent cell.
-    :param activation_alpha: The vector of alpha parameters for activation
-                             functions in order respective to activation list.
-    :param activation_beta: The vector of beta parameters for activation functions
-                            in order respective to activation list.
-    :param clip: The value defining clipping range [-clip, clip] on
-                 input of activation functions.
+    :param activation_alpha: The vector of alpha parameters for activation functions in
+                             order respective to activation list.
+    :param activation_beta: The vector of beta parameters for activation functions in order
+                            respective to activation list.
+    :param clip: The value defining clipping range [-clip, clip] on input of
+                 activation functions.
     :param name: Optional output node name.
-    :return: The new node performing a RNNCell operation on tensor from input node.
+    :returns: The new node performing a RNNCell operation on tensor from input node.
     """
-    return RNNCell(X,
-                   W,
-                   R,
-                   H_t,
-                   hidden_size,
-                   B,
+    return RNNCell(X,
+                   H_t,
+                   W,
+                   R,
+                   B,
+                   hidden_size,
                    activations,
                    activation_alpha,
                    activation_beta,
...
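For reference, the computation behind this wrapper (per the ONNX RNN operator the docstring mirrors; f is the single activation named in `activations`, and B concatenates the input and recurrence biases):

\[
H_t' = f\left(X \cdot W^{T} + H_t \cdot R^{T} + W_b + R_b\right),
\qquad B = [W_b; R_b], \quad \dim(B) = 2 \cdot \mathrm{hidden\_size}
\]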
python/pyngraph/ops/fused/rnn_cell.cpp

...
@@ -31,8 +31,8 @@ void regclass_pyngraph_op_RNNCell(py::module m)
              const std::shared_ptr<ngraph::Node>&,
              const std::shared_ptr<ngraph::Node>&,
              const std::shared_ptr<ngraph::Node>&,
-             int&,
              const std::shared_ptr<ngraph::Node>&,
+             int&,
              const std::vector<std::string>&,
              const std::vector<float>&,
              const std::vector<float>&,
...
python/test/ngraph/test_ops_fused.py

...
@@ -455,17 +455,20 @@ def test_rnn_cell_operator():
     W_shape = [hidden_size, input_size]
     R_shape = [hidden_size, hidden_size]
     H_t_shape = [batch_size, hidden_size]
-    B_shape = [2 * hidden_size]
+    B_shape = [hidden_size]

     parameter_X = ng.parameter(X_shape, name='X', dtype=np.float32)
+    parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
     parameter_W = ng.parameter(W_shape, name='W', dtype=np.float32)
     parameter_R = ng.parameter(R_shape, name='R', dtype=np.float32)
-    parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
     parameter_B = ng.parameter(B_shape, name='B', dtype=np.float32)

     X_value = np.array([0.3432185, 0.612268, 0.20272376,
                         0.9513413, 0.30585995, 0.7265472],
                        dtype=np.float32).reshape(X_shape)
+    H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
+                          0.4983964, 0.7730452, 0.28439692],
+                         dtype=np.float32).reshape(H_t_shape)
     W_value = np.array([0.41930267, 0.7872176, 0.89940447,
                         0.23659843, 0.24676207, 0.17101714,
                         0.3147149, 0.6555601, 0.4559603],
...
@@ -474,11 +477,7 @@ def test_rnn_cell_operator():
                         0.71549815, 0.18775631, 0.3182116,
                         0.25392973, 0.38301638, 0.85531586],
                        dtype=np.float32).reshape(R_shape)
-    H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
-                          0.4983964, 0.7730452, 0.28439692],
-                         dtype=np.float32).reshape(H_t_shape)
-    B_value = np.array([0.45513555, 0.96227735, 0.24737759,
-                        0.57380486, 0.67398053, 0.18968852],
+    B_value = np.array([1.0289404, 1.6362579, 0.4370661],
                        dtype=np.float32).reshape(B_shape)
     activations = ['sigmoid']
     activation_alpha = []
...
@@ -486,23 +485,23 @@ def test_rnn_cell_operator():
     clip = 2.88

     model = ng.rnn_cell(parameter_X,
+                        parameter_H_t,
                         parameter_W,
                         parameter_R,
-                        parameter_H_t,
-                        hidden_size,
                         parameter_B,
+                        hidden_size,
                         activations,
                         activation_alpha,
                         activation_beta,
                         clip)
     computation = runtime.computation(model,
                                       parameter_X,
+                                      parameter_H_t,
                                       parameter_W,
                                       parameter_R,
-                                      parameter_H_t,
                                       parameter_B)

-    result = computation(X_value, W_value, R_value, H_t_value, B_value)
+    result = computation(X_value, H_t_value, W_value, R_value, B_value)
     expected = np.array([0.94126844, 0.9036043, 0.841243,
                          0.9468489, 0.934215, 0.873708],
                         dtype=np.float32).reshape(batch_size, hidden_size)
...
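A consistency check on the new test data: the new B_value {1.0289404, 1.6362579, 0.4370661} is exactly the element-wise sum of the two halves of the removed [2*hidden_size] bias (e.g. 0.45513555 + 0.57380486 = 1.0289404), matching the move to a pre-combined [hidden_size] bias. Below is a minimal sketch of the computation the test exercises, assuming rnn_cell decomposes to H' = sigmoid(clip(X*W^T + H_t*R^T + B)) with clip range [-2.88, 2.88]; plain C++ with no ngraph dependency, and the helper name is illustrative.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One step of a single-gate RNN cell with a scalar clip and sigmoid activation.
std::vector<float> rnn_cell_step(const std::vector<float>& X,   // [batch * input_size]
                                 const std::vector<float>& H_t, // [batch * hidden_size]
                                 const std::vector<float>& W,   // [hidden_size * input_size]
                                 const std::vector<float>& R,   // [hidden_size * hidden_size]
                                 const std::vector<float>& B,   // [hidden_size] (= Wb + Rb)
                                 std::size_t batch,
                                 std::size_t input_size,
                                 std::size_t hidden_size,
                                 float clip)
{
    std::vector<float> H_next(batch * hidden_size);
    for (std::size_t b = 0; b < batch; ++b)
    {
        for (std::size_t h = 0; h < hidden_size; ++h)
        {
            float acc = B[h]; // bias arrives pre-combined as Wb + Rb
            for (std::size_t i = 0; i < input_size; ++i)
                acc += X[b * input_size + i] * W[h * input_size + i]; // X * W^T
            for (std::size_t j = 0; j < hidden_size; ++j)
                acc += H_t[b * hidden_size + j] * R[h * hidden_size + j]; // H_t * R^T
            if (clip > 0.f)
                acc = std::max(-clip, std::min(clip, acc)); // clamp to [-clip, clip]
            H_next[b * hidden_size + h] = 1.f / (1.f + std::exp(-acc)); // 'sigmoid'
        }
    }
    return H_next;
}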
src/ngraph/frontend/onnx_import/op/lstm.cpp

...
@@ -22,7 +22,9 @@
 #include <vector>

 #include "exceptions.hpp"
+#include "ngraph/builder/split.hpp"
 #include "ngraph/frontend/onnx_import/op/lstm.hpp"
+#include "ngraph/op/add.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
 #include "ngraph/op/get_output_element.hpp"
...
@@ -82,17 +84,19 @@ namespace ngraph
                         m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();

                     // ------ Optional inputs ------
-                    // The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
+                    // The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
                     if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
                     {
-                        m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
+                        auto bias = ng_inputs.at(3);
+                        auto split_bias = builder::split(bias, 2, 1);
+                        m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
                     }
                     else
                     {
                         m_map[LSTMInput::LSTM_INPUT_B] = ngraph::op::Constant::create(
                             element::f32,
-                            Shape{num_directions, 2 * gates_count * hidden_size},
-                            std::vector<float>(num_directions * 2 * gates_count * hidden_size,
-                                               0.f));
+                            Shape{num_directions, gates_count * hidden_size},
+                            std::vector<float>(num_directions * gates_count * hidden_size,
+                                               0.f));
                     }
                     // The lengths of the sequences in a batch. Shape [batch_size]
...
@@ -224,6 +228,7 @@ namespace ngraph
                         input_map.at(LSTMInput::LSTM_INPUT_P),
                         attributes.m_hidden_size,
                         attributes.m_direction,
+                        ngraph::op::LSTMWeightsFormat::IOFC,
                         attributes.m_activation_alpha,
                         attributes.m_activation_beta,
                         attributes.m_activations,
...
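The bias preparation the importer now performs, sketched in plain C++: ONNX supplies B as the concatenation [Wb, Rb] of shape [num_directions, 8 * hidden_size]; splitting in half along axis 1 and adding the halves yields the combined [num_directions, 4 * hidden_size] bias that the cell consumes (the function name is illustrative).

#include <cstddef>
#include <vector>

std::vector<float> combine_lstm_bias(const std::vector<float>& B, // [dirs * 8 * hidden]
                                     std::size_t num_directions,
                                     std::size_t hidden_size)
{
    const std::size_t half = 4 * hidden_size; // size of one of {Wb, Rb} per direction
    std::vector<float> out(num_directions * half);
    for (std::size_t d = 0; d < num_directions; ++d)
        for (std::size_t k = 0; k < half; ++k)
            out[d * half + k] = B[d * 2 * half + k]         // Wb slice
                                + B[d * 2 * half + half + k]; // Rb slice
    return out;
}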
src/ngraph/op/fused/gru_cell.cpp

...
@@ -33,12 +33,12 @@ constexpr NodeTypeInfo op::GRUCell::type_info;
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size)
     : GRUCell(X,
               W,
               R,
-              H_t,
+              initial_hidden_state,
               hidden_size,
               vector<string>{"sigmoid", "tanh"},
               vector<float>{},
...
@@ -51,15 +51,15 @@ op::GRUCell::GRUCell(const Output<Node>& X,
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip,
                      bool linear_before_reset)
-    : FusedOp({X, W, R, H_t})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, W, R, initial_hidden_state})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
     , m_activation_g{get_activation_function(1)}
     , m_linear_before_reset{linear_before_reset}
...
@@ -71,16 +71,16 @@ op::GRUCell::GRUCell(const Output<Node>& X,
 op::GRUCell::GRUCell(const Output<Node>& X,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
+                     const Output<Node>& initial_hidden_state,
                      size_t hidden_size,
                      const Output<Node>& B,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip,
                      bool linear_before_reset)
-    : FusedOp({X, W, R, H_t, B})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, W, R, initial_hidden_state, B})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
     , m_activation_g{get_activation_function(1)}
     , m_linear_before_reset{linear_before_reset}
...
@@ -129,7 +129,7 @@ void op::GRUCell::pre_validate_and_infer_types()
                           ".");
     NODE_VALIDATION_CHECK(this,
                           (ht_shape == Shape{batch_size, get_hidden_size()}),
-                          "Input tensor H_t must have shape (",
+                          "Input tensor initial_hidden_state must have shape (",
                           batch_size,
                           ", ",
                           get_hidden_size(),
...
@@ -290,8 +290,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
                                     new_args.at(3),
                                     get_hidden_size(),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip(),
                                     m_linear_before_reset);
     }
...
@@ -304,8 +304,8 @@ shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) con
                                     get_hidden_size(),
                                     new_args.at(4),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip(),
                                     m_linear_before_reset);
     }
...
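A hypothetical construction using the renamed argument (a sketch only; shapes follow the Doxygen comments, gates_count for GRU is 3, and the header paths are as in this commit):

#include <memory>

#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::GRUCell> make_gru_cell()
{
    const size_t batch = 2, input_size = 4, hidden = 16, gates = 3;
    auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch, input_size});
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{gates * hidden, input_size});
    auto R = std::make_shared<op::Parameter>(element::f32, Shape{gates * hidden, hidden});
    // The fourth node input is now named initial_hidden_state (was H_t).
    auto initial_hidden_state =
        std::make_shared<op::Parameter>(element::f32, Shape{batch, hidden});
    return std::make_shared<op::GRUCell>(X, W, R, initial_hidden_state, hidden);
}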
src/ngraph/op/fused/gru_cell.hpp

...
@@ -47,84 +47,90 @@ namespace ngraph
             ///
             /// \brief Constructs GRUCell node.
             ///
-            /// \param[in] X           The input tensor with shape: [batch_size, input_size].
-            /// \param[in] W           The weight tensor with shape:
-            ///                        [gates_count * hidden_size, input_size].
-            /// \param[in] R           The recurrence weight tensor with shape:
-            ///                        [gates_count * hidden_size, hidden_size].
-            /// \param[in] H_t         The hidden state tensor at current time step with
-            ///                        shape: [batch_size, hidden_size].
-            /// \param[in] hidden_size The number of hidden units for recurrent cell.
+            /// \param[in] X                    The input tensor with shape: [batch_size,
+            ///                                 input_size].
+            /// \param[in] W                    The weight tensor with shape:
+            ///                                 [gates_count * hidden_size, input_size].
+            /// \param[in] R                    The recurrence weight tensor with shape:
+            ///                                 [gates_count * hidden_size, hidden_size].
+            /// \param[in] initial_hidden_state The hidden state tensor at current time step with
+            ///                                 shape: [batch_size, hidden_size].
+            /// \param[in] hidden_size          The number of hidden units for recurrent cell.
             ///
             GRUCell(const Output<Node>& X,
                     const Output<Node>& W,
                     const Output<Node>& R,
-                    const Output<Node>& H_t,
+                    const Output<Node>& initial_hidden_state,
                     std::size_t hidden_size);

             ///
             /// \brief Constructs GRUCell node.
             ///
-            /// \param[in] X                The input tensor with shape: [batch_size, input_size].
-            /// \param[in] W                The weight tensor with shape:
-            ///                             [gates_count * hidden_size, input_size].
-            /// \param[in] R                The recurrence weight tensor with shape:
-            ///                             [gates_count * hidden_size, hidden_size].
-            /// \param[in] H_t              The hidden state tensor at current time step with
-            ///                             shape: [batch_size, hidden_size].
-            /// \param[in] hidden_size      The number of hidden units for recurrent cell.
-            /// \param[in] activations      The vector of activation functions used inside
-            ///                             recurrent cell.
-            /// \param[in] activation_alpha The vector of alpha parameters for activation
-            ///                             functions in order respective to activation list.
-            /// \param[in] activation_beta  The vector of beta parameters for activation functions
-            ///                             in order respective to activation list.
-            /// \param[in] clip             The value defining clipping range [-clip, clip] on
-            ///                             input of activation functions.
+            /// \param[in] X                    The input tensor with shape: [batch_size,
+            ///                                 input_size].
+            /// \param[in] W                    The weight tensor with shape:
+            ///                                 [gates_count * hidden_size, input_size].
+            /// \param[in] R                    The recurrence weight tensor with shape:
+            ///                                 [gates_count * hidden_size, hidden_size].
+            /// \param[in] initial_hidden_state The hidden state tensor at current time step with
+            ///                                 shape: [batch_size, hidden_size].
+            /// \param[in] hidden_size          The number of hidden units for recurrent cell.
+            /// \param[in] activations          The vector of activation functions used inside
+            ///                                 recurrent cell.
+            /// \param[in] activations_alpha    The vector of alpha parameters for activation
+            ///                                 functions in order respective to activation list.
+            /// \param[in] activations_beta     The vector of beta parameters for activation
+            ///                                 functions in order respective to activation list.
+            /// \param[in] clip                 The value defining clipping range [-clip, clip] on
+            ///                                 input of activation functions.
             ///
             GRUCell(const Output<Node>& X,
                     const Output<Node>& W,
                     const Output<Node>& R,
-                    const Output<Node>& H_t,
+                    const Output<Node>& initial_hidden_state,
                     std::size_t hidden_size,
                     const std::vector<std::string>& activations,
-                    const std::vector<float>& activation_alpha,
-                    const std::vector<float>& activation_beta,
+                    const std::vector<float>& activations_alpha,
+                    const std::vector<float>& activations_beta,
                     float clip,
                     bool linear_before_reset);

             ///
             /// \brief Constructs GRUCell node.
             ///
-            /// \param[in] X                The input tensor with shape: [batch_size, input_size].
-            /// \param[in] W                The weight tensor with shape: [gates_count *
-            ///                             hidden_size, input_size].
-            /// \param[in] R                The recurrence weight tensor with shape:
-            ///                             [gates_count * hidden_size, hidden_size].
-            /// \param[in] H_t              The hidden state tensor at current time step with
-            ///                             shape: [batch_size, hidden_size].
-            /// \param[in] hidden_size      The number of hidden units for recurrent cell.
-            /// \param[in] B                The bias tensor for input gate with shape:
-            ///                             [2 * gates_count * hidden_size].
-            /// \param[in] activations      The vector of activation functions used inside
-            ///                             recurrent cell.
-            /// \param[in] activation_alpha The vector of alpha parameters for activation
-            ///                             functions in order respective to activation list.
-            /// \param[in] activation_beta  The vector of beta parameters for activation functions
-            ///                             in order respective to activation list.
-            /// \param[in] clip             The value defining clipping range [-clip, clip] on
-            ///                             input of activation functions.
+            /// \param[in] X                    The input tensor with shape: [batch_size,
+            ///                                 input_size].
+            /// \param[in] W                    The weight tensor with shape: [gates_count *
+            ///                                 hidden_size, input_size].
+            /// \param[in] R                    The recurrence weight tensor with shape:
+            ///                                 [gates_count * hidden_size, hidden_size].
+            /// \param[in] initial_hidden_state The hidden state tensor at current time step with
+            ///                                 shape: [batch_size, hidden_size].
+            /// \param[in] hidden_size          The number of hidden units for recurrent cell.
+            /// \param[in] B                    The bias tensor for input gate with shape:
+            ///                                 [2 * gates_count * hidden_size].
+            /// \param[in] activations          The vector of activation functions used inside
+            ///                                 recurrent cell.
+            /// \param[in] activations_alpha    The vector of alpha parameters for activation
+            ///                                 functions in order respective to activation list.
+            /// \param[in] activations_beta     The vector of beta parameters for activation
+            ///                                 functions in order respective to activation list.
+            /// \param[in] clip                 The value defining clipping range [-clip, clip] on
+            ///                                 input of activation functions.
+            /// \param[in] linear_before_reset  Whether or not to apply the linear transformation
+            ///                                 before multiplying by the output of the reset
+            ///                                 gate.
             ///
             GRUCell(const Output<Node>& X,
                     const Output<Node>& W,
                     const Output<Node>& R,
-                    const Output<Node>& H_t,
+                    const Output<Node>& initial_hidden_state,
                     std::size_t hidden_size,
                     const Output<Node>& B,
                     const std::vector<std::string>& activations =
                         std::vector<std::string>{"sigmoid", "tanh"},
-                    const std::vector<float>& activation_alpha = {},
-                    const std::vector<float>& activation_beta = {},
+                    const std::vector<float>& activations_alpha = {},
+                    const std::vector<float>& activations_beta = {},
                     float clip = 0.f,
                     bool linear_before_reset = false);
...
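For context on the linear_before_reset flag documented above, these are the ONNX GRU equations this op follows (f, g the gate and candidate activations, \odot element-wise multiplication):

\begin{align*}
z_t &= f(X_t W_z^T + H_{t-1} R_z^T + W_{bz} + R_{bz}) \\
r_t &= f(X_t W_r^T + H_{t-1} R_r^T + W_{br} + R_{br}) \\
\tilde{h}_t &= g(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + W_{bh} + R_{bh})
    && \text{(linear\_before\_reset = false)} \\
\tilde{h}_t &= g(X_t W_h^T + r_t \odot (H_{t-1} R_h^T + R_{bh}) + W_{bh})
    && \text{(linear\_before\_reset = true)} \\
H_t &= (1 - z_t) \odot \tilde{h}_t + z_t \odot H_{t-1}
\end{align*}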
src/ngraph/op/fused/lstm_cell.cpp

(diff collapsed; contents not shown)

src/ngraph/op/fused/lstm_cell.hpp

(diff collapsed; contents not shown)
src/ngraph/op/fused/lstm_sequence.cpp

...
@@ -58,21 +58,47 @@ NodeVector op::LSTMSequence::decompose_op() const
 shared_ptr<Node> op::LSTMSequence::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<LSTMSequence>(new_args.at(0), // X
-                                     new_args.at(1), // initial_hidden_state
-                                     new_args.at(2), // initial_cell_state
-                                     new_args.at(3), // sequence_lengths
-                                     new_args.at(4), // W
-                                     new_args.at(5), // R
-                                     new_args.at(6), // B
-                                     new_args.at(7), // P
-                                     m_hidden_size,
-                                     m_direction,
-                                     m_activations_alpha,
-                                     m_activations_beta,
-                                     m_activations,
-                                     m_clip_threshold,
-                                     m_input_forget);
+    if (new_args.size() == 8)
+    {
+        return make_shared<LSTMSequence>(new_args.at(0), // X
+                                         new_args.at(1), // initial_hidden_state
+                                         new_args.at(2), // initial_cell_state
+                                         new_args.at(3), // sequence_lengths
+                                         new_args.at(4), // W
+                                         new_args.at(5), // R
+                                         new_args.at(6), // B
+                                         new_args.at(7), // P
+                                         m_hidden_size,
+                                         m_direction,
+                                         m_weights_format,
+                                         m_activations_alpha,
+                                         m_activations_beta,
+                                         m_activations,
+                                         m_clip_threshold,
+                                         m_input_forget);
+    }
+    else if (new_args.size() == 7)
+    {
+        return make_shared<LSTMSequence>(new_args.at(0), // X
+                                         new_args.at(1), // initial_hidden_state
+                                         new_args.at(2), // initial_cell_state
+                                         new_args.at(3), // sequence_lengths
+                                         new_args.at(4), // W
+                                         new_args.at(5), // R
+                                         new_args.at(6), // B
+                                         m_hidden_size,
+                                         m_direction,
+                                         m_weights_format,
+                                         m_activations_alpha,
+                                         m_activations_beta,
+                                         m_activations,
+                                         m_clip_threshold,
+                                         m_input_forget);
+    }
+    else
+    {
+        throw ngraph_error("Incorrect number of new arguments");
+    }
 }

 shared_ptr<Node> op::LSTMSequence::get_masked_node(const shared_ptr<Node>& data,
...
@@ -157,13 +183,14 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
     for (const auto& in_x : in_seqs)
     {
         shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
-                                                               W,
-                                                               R,
                                                                H_t,
                                                                C_t,
-                                                               m_hidden_size,
+                                                               W,
+                                                               R,
                                                                B,
                                                                P,
+                                                               m_hidden_size,
+                                                               m_weights_format,
                                                                m_activations,
                                                                m_activations_alpha,
                                                                m_activations_beta,
...
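The lstm_pass change above only reorders how each per-timestep cell is built; the unrolling itself is unchanged. A minimal sketch of that pattern with stand-in types (the real code builds op::LSTMCell nodes rather than calling a function):

#include <functional>
#include <vector>

// Stand-in for the pair of recurrent tensors an LSTM cell threads through time.
struct State
{
    std::vector<float> h; // hidden state
    std::vector<float> c; // cell state
};

using CellFn = std::function<State(const std::vector<float>& x_t, const State& prev)>;

// One cell per time step, as in lstm_pass; `cell` stands in for constructing an
// op::LSTMCell node with the (possibly format-converted) weights.
std::vector<State> unroll(const std::vector<std::vector<float>>& in_seqs,
                          State state,
                          const CellFn& cell)
{
    std::vector<State> per_step;
    for (const auto& in_x : in_seqs)
    {
        state = cell(in_x, state); // consumes (H_t, C_t), produces the next pair
        per_step.push_back(state);
    }
    return per_step;
}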
src/ngraph/op/fused/lstm_sequence.hpp

...
@@ -24,6 +24,7 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op/constant.hpp"
+#include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/util/fused_op.hpp"

 namespace ngraph
...
@@ -36,6 +37,9 @@ namespace ngraph
         /// \note It follows notation and equations defined as in ONNX standard:
         /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
         ///
+        /// \sa LSTMCell, RNNCell, GRUCell
+        ///
         class LSTMSequence : public util::FusedOp
         {
         public:
...
@@ -61,6 +65,7 @@ namespace ngraph
                          const Output<Node>& P,
                          const std::int64_t hidden_size,
                          const direction lstm_direction,
+                         LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                          const std::vector<float> activations_alpha = {},
                          const std::vector<float> activations_beta = {},
                          const std::vector<std::string> activations = {"sigmoid",
...
@@ -77,6 +82,7 @@ namespace ngraph
                 , m_direction(lstm_direction)
                 , m_hidden_size(hidden_size)
                 , m_input_forget(input_forget)
+                , m_weights_format(weights_format)
             {
                 constructor_validate_and_infer_types();
             }
...
@@ -90,6 +96,7 @@ namespace ngraph
                          const Output<Node>& B,
                          const std::int64_t hidden_size,
                          const direction lstm_direction,
+                         LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                          const std::vector<float> activations_alpha = {},
                          const std::vector<float> activations_beta = {},
                          const std::vector<std::string> activations = {"sigmoid",
...
@@ -111,6 +118,7 @@ namespace ngraph
                               std::vector<float>{0.f}),
                       hidden_size,
                       lstm_direction,
+                      weights_format,
                       activations_alpha,
                       activations_beta,
                       activations,
...
@@ -131,6 +139,7 @@ namespace ngraph
             direction get_direction() const { return m_direction; }
             std::int64_t get_hidden_size() const { return m_hidden_size; }
             bool get_input_forget() const { return m_input_forget; }
+            LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
         private:
             ///
             /// \brief Gets the masked value according to sequence lenght in a batch.
...
@@ -163,6 +172,7 @@ namespace ngraph
             direction m_direction;
             std::int64_t m_hidden_size;
             bool m_input_forget;
+            LSTMWeightsFormat m_weights_format;
         };
-    }
-}
+    } // namespace op
+} // namespace ngraph
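A minimal sketch of what the new attribute adds to the class surface (stand-in types only; the enumerator list is assumed from the LSTMCell header):

// Gate orders a producer may use for W, R and B.
enum class LSTMWeightsFormat { FICO, ICOF, IFCO, IFOC, IOFC };

class LSTMSequenceSketch
{
public:
    explicit LSTMSequenceSketch(LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO)
        : m_weights_format(weights_format)
    {
    }
    // Consumers (e.g. the CPU fusion pass) query this instead of assuming IOFC.
    LSTMWeightsFormat get_weights_format() const { return m_weights_format; }

private:
    LSTMWeightsFormat m_weights_format; // preserved by copy_with_new_args and the serializer
};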
src/ngraph/op/fused/rnn_cell.cpp

...
@@ -32,44 +32,34 @@ using namespace ngraph;
 constexpr NodeTypeInfo op::RNNCell::type_info;

 op::RNNCell::RNNCell(const Output<Node>& X,
+                     const Output<Node>& initial_hidden_state,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
-                     size_t hidden_size)
-    : RNNCell(
-          X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
-{
-}
-
-op::RNNCell::RNNCell(const Output<Node>& X,
-                     const Output<Node>& W,
-                     const Output<Node>& R,
-                     const Output<Node>& H_t,
                      size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip)
-    : FusedOp({X, W, R, H_t})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, initial_hidden_state, W, R})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
 {
-    add_default_bias_input();
+    set_argument(4, get_default_bias_input());
     constructor_validate_and_infer_types();
 }

 op::RNNCell::RNNCell(const Output<Node>& X,
+                     const Output<Node>& initial_hidden_state,
                      const Output<Node>& W,
                      const Output<Node>& R,
-                     const Output<Node>& H_t,
-                     size_t hidden_size,
                      const Output<Node>& B,
+                     size_t hidden_size,
                      const vector<string>& activations,
-                     const vector<float>& activation_alpha,
-                     const vector<float>& activation_beta,
+                     const vector<float>& activations_alpha,
+                     const vector<float>& activations_beta,
                      float clip)
-    : FusedOp({X, W, R, H_t, B})
-    , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
+    : FusedOp({X, initial_hidden_state, W, R, B})
+    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
     , m_activation_f{get_activation_function(0)}
 {
     constructor_validate_and_infer_types();
...
@@ -83,9 +73,9 @@ void op::RNNCell::pre_validate_and_infer_types()
     }

     const auto& x_pshape = get_input_partial_shape(0);
-    const auto& w_pshape = get_input_partial_shape(1);
-    const auto& r_pshape = get_input_partial_shape(2);
-    const auto& ht_pshape = get_input_partial_shape(3);
+    const auto& ht_pshape = get_input_partial_shape(1);
+    const auto& w_pshape = get_input_partial_shape(2);
+    const auto& r_pshape = get_input_partial_shape(3);

     NODE_VALIDATION_CHECK(this,
                           (x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
...
@@ -121,7 +111,7 @@ void op::RNNCell::pre_validate_and_infer_types()
                           ".");
     NODE_VALIDATION_CHECK(this,
                           (ht_shape == Shape{batch_size, get_hidden_size()}),
-                          "Input tensor H_t must have shape (",
+                          "Input tensor initial_hidden_state must have shape (",
                           batch_size,
                           ", ",
                           get_hidden_size(),
...
@@ -137,9 +127,9 @@ void op::RNNCell::pre_validate_and_infer_types()
     const Shape& b_shape{b_pshape.to_shape()};
     NODE_VALIDATION_CHECK(this,
-                          (b_shape == Shape{2 * get_hidden_size()}),
+                          (b_shape == Shape{get_hidden_size()}),
                           "Input tensor B must have shape (",
-                          2 * get_hidden_size(),
+                          get_hidden_size(),
                           "). Actual shape is:",
                           b_shape,
                           ".");
...
@@ -157,8 +147,7 @@ NodeVector op::RNNCell::decompose_op() const
     // W - The weight tensor for input gate. Shape: [hidden_size, input_size].
     // R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
     // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
-    // B - The bias tensor for the input gate. Shape: [2 * hidden_size].
-    //     Concatenation of `[Wb, Rb]`.
+    // B - The bias tensor for the input gate. Shape: [hidden_size].
     // Wb - W bias vectors for input gate.
     // Rb - R bias vectors for input gate.

     // ------ VARIABLE NAMES ------
...
@@ -174,10 +163,10 @@ NodeVector op::RNNCell::decompose_op() const
     // --------------------

     Output<Node> X = input_value(0);
-    Output<Node> W = input_value(1);
-    Output<Node> R = input_value(2);
-    Output<Node> H_t = input_value(3);
-    Output<Node> bias = get_bias();
+    Output<Node> H_t = input_value(1);
+    Output<Node> W = input_value(2);
+    Output<Node> R = input_value(3);
+    Output<Node> bias = input_value(4);

     // Xt*(W^T)
     auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
...
@@ -192,22 +181,12 @@ NodeVector op::RNNCell::decompose_op() const
     return {i_t};
 }

-Output<Node> op::RNNCell::get_bias() const
-{
-    Output<Node> bias;
-    // Split B onto Wb an Rb and add them.
-    NodeVector b_W_R = builder::split(input_value(4), 2);
-    bias = b_W_R.at(0) + b_W_R.at(1);
-    return bias;
-}
-
-void op::RNNCell::add_default_bias_input()
-{
-    Output<Node> B =
-        op::Constant::create(input(0).get_element_type(),
-                             Shape{2 * s_gates_count * get_hidden_size()},
-                             vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
-    set_argument(4, B);
+Output<Node> op::RNNCell::get_default_bias_input() const
+{
+    return Output<Node>{
+        op::Constant::create(input(0).get_element_type(),
+                             Shape{s_gates_count * get_hidden_size()},
+                             vector<float>(s_gates_count * get_hidden_size(), 0.f))};
 }

 shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const
...
@@ -221,8 +200,8 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
                                     new_args.at(3),
                                     get_hidden_size(),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip());
     }
     else if (new_args.size() == 5)
...
@@ -231,11 +210,11 @@ shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) con
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
-                                    get_hidden_size(),
                                     new_args.at(4),
+                                    get_hidden_size(),
                                     get_activations(),
-                                    get_activation_alpha(),
-                                    get_activation_beta(),
+                                    get_activations_alpha(),
+                                    get_activations_beta(),
                                     get_clip());
     }
     else
...
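Hypothetical usage of the reordered constructor (a sketch: the hidden state is now the second input, B precedes hidden_size, and when B is omitted the op installs a zero bias of shape [hidden_size] itself via get_default_bias_input):

#include <memory>
#include <string>
#include <vector>

#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::RNNCell> make_rnn_cell()
{
    const size_t batch = 2, input_size = 3, hidden = 3;
    auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch, input_size});
    // initial_hidden_state is now the second argument (was fourth, as H_t).
    auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch, hidden});
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{hidden, input_size});
    auto R = std::make_shared<op::Parameter>(element::f32, Shape{hidden, hidden});
    // B is [hidden_size], not [2 * hidden_size]: Wb and Rb arrive pre-summed.
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{hidden});
    return std::make_shared<op::RNNCell>(X, H_t, W, R, B, hidden,
                                         std::vector<std::string>{"tanh"},
                                         std::vector<float>{},
                                         std::vector<float>{},
                                         0.f);
}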
src/ngraph/op/fused/rnn_cell.hpp

(diff collapsed; contents not shown)
src/ngraph/op/util/rnn_cell_base.cpp

...
@@ -39,13 +39,13 @@ static vector<string> to_lower_case(const vector<string>& vs)
 op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
                                    float clip,
                                    const vector<string>& activations,
-                                   const vector<float>& activation_alpha,
-                                   const vector<float>& activation_beta)
+                                   const vector<float>& activations_alpha,
+                                   const vector<float>& activations_beta)
     : m_hidden_size(hidden_size)
     , m_clip(clip)
     , m_activations(to_lower_case(activations))
-    , m_activation_alpha(activation_alpha)
-    , m_activation_beta(activation_beta)
+    , m_activations_alpha(activations_alpha)
+    , m_activations_beta(activations_beta)
 {
 }
...
@@ -54,13 +54,13 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
     op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));

     // Set activation functions parameters (if any)
-    if (m_activation_alpha.size() > idx)
+    if (m_activations_alpha.size() > idx)
     {
-        afunc.set_alpha(m_activation_alpha.at(idx));
+        afunc.set_alpha(m_activations_alpha.at(idx));
     }
-    if (m_activation_beta.size() > idx)
+    if (m_activations_beta.size() > idx)
     {
-        afunc.set_beta(m_activation_beta.at(idx));
+        afunc.set_beta(m_activations_beta.at(idx));
     }
     return afunc;
...
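The per-index lookup in get_activation_function, in isolation: the activation at position idx picks up alpha/beta only when the corresponding (renamed) vector is long enough. A sketch with stand-in types:

#include <cstddef>
#include <string>
#include <vector>

struct ActivationFunction
{
    std::string name;
    float alpha = 0.f, beta = 0.f;
    bool has_alpha = false, has_beta = false;
};

ActivationFunction make_activation(const std::vector<std::string>& activations,
                                   const std::vector<float>& activations_alpha,
                                   const std::vector<float>& activations_beta,
                                   std::size_t idx)
{
    ActivationFunction f{activations.at(idx)};
    if (activations_alpha.size() > idx) // parameters are optional per activation
    {
        f.alpha = activations_alpha.at(idx);
        f.has_alpha = true;
    }
    if (activations_beta.size() > idx)
    {
        f.beta = activations_beta.at(idx);
        f.has_beta = true;
    }
    return f;
}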
src/ngraph/op/util/rnn_cell_base.hpp

...
@@ -40,30 +40,34 @@ namespace ngraph
                 ///
                 /// \brief Constructs a RNNCellBase class.
                 ///
-                /// \param[in] hidden_size      The number of hidden units for recurrent cell.
-                /// \param[in] clip             The value defining clipping range [-clip, clip] on
-                ///                             input of activation functions.
-                /// \param[in] activations      The vector of activation functions used inside
-                ///                             recurrent cell.
-                /// \param[in] activation_alpha The vector of alpha parameters for activation
-                ///                             functions in order respective to activation list.
-                /// \param[in] activation_beta  The vector of beta parameters for activation
-                ///                             functions in order respective to activation list.
+                /// \param[in] hidden_size       The number of hidden units for recurrent cell.
+                /// \param[in] clip              The value defining clipping range [-clip, clip]
+                ///                              on input of activation functions.
+                /// \param[in] activations       The vector of activation functions used inside
+                ///                              recurrent cell.
+                /// \param[in] activations_alpha The vector of alpha parameters for activation
+                ///                              functions in order respective to activation list.
+                /// \param[in] activations_beta  The vector of beta parameters for activation
+                ///                              functions in order respective to activation list.
                 ///
                 RNNCellBase(std::size_t hidden_size,
                             float clip,
                             const std::vector<std::string>& activations,
-                            const std::vector<float>& activation_alpha,
-                            const std::vector<float>& activation_beta);
+                            const std::vector<float>& activations_alpha,
+                            const std::vector<float>& activations_beta);

                 std::size_t get_hidden_size() const { return m_hidden_size; }
                 float get_clip() const { return m_clip; }
                 const std::vector<std::string>& get_activations() const { return m_activations; }
-                const std::vector<float>& get_activation_alpha() const
+                const std::vector<float>& get_activations_alpha() const
                 {
-                    return m_activation_alpha;
+                    return m_activations_alpha;
                 }
-                const std::vector<float>& get_activation_beta() const { return m_activation_beta; }
+                const std::vector<float>& get_activations_beta() const { return m_activations_beta; }
             protected:
                 ///
                 /// \brief Constructs activation function object.
...
@@ -117,9 +121,9 @@ namespace ngraph
                 const std::size_t m_hidden_size;
                 const float m_clip;
                 const std::vector<std::string> m_activations;
-                const std::vector<float> m_activation_alpha;
-                const std::vector<float> m_activation_beta;
+                const std::vector<float> m_activations_alpha;
+                const std::vector<float> m_activations_beta;
             };
-        }
-    }
-}
+        } // namespace util
+    } // namespace op
+} // namespace ngraph
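The clip attribute documented above clamps activation inputs to [-clip, clip]; a sketch (assuming, per the 0.f defaults used throughout this commit, that zero disables clipping):

#include <algorithm>

float apply_clip(float value, float clip)
{
    if (clip == 0.f)
        return value; // assumption: 0 means "no clipping", matching the 0.f defaults
    return std::max(-clip, std::min(clip, value));
}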
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp
View file @
53a6af8d
...
@@ -77,8 +77,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
...
@@ -77,8 +77,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
element
::
f32
,
Shape
{
ref_gates_count
*
ref_hidden_size
,
ref_input_size
});
element
::
f32
,
Shape
{
ref_gates_count
*
ref_hidden_size
,
ref_input_size
});
auto
R
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
auto
R
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
ref_gates_count
*
ref_hidden_size
,
ref_hidden_size
});
element
::
f32
,
Shape
{
ref_gates_count
*
ref_hidden_size
,
ref_hidden_size
});
auto
bias_ref
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
auto
B
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
element
::
f32
,
Shape
{
2
*
ref_gates_count
*
ref_hidden_size
});
Shape
{
ref_gates_count
*
ref_hidden_size
});
auto
peep_hole
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
3
*
ref_hidden_size
});
auto
peep_hole
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
3
*
ref_hidden_size
});
auto
H_t
=
auto
H_t
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
ref_batch_size
,
ref_hidden_size
});
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
ref_batch_size
,
ref_hidden_size
});
...
@@ -87,13 +87,14 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
...
@@ -87,13 +87,14 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
auto
ref_lstm_cell
=
auto
ref_lstm_cell
=
std
::
make_shared
<
op
::
LSTMCell
>
(
X
,
std
::
make_shared
<
op
::
LSTMCell
>
(
X
,
W
,
R
,
H_t
,
H_t
,
C_t
,
C_t
,
ref_hidden_size
,
W
,
bias_ref
,
R
,
B
,
peep_hole
,
peep_hole
,
ref_hidden_size
,
op
::
LSTMWeightsFormat
::
IOFC
,
std
::
vector
<
std
::
string
>
{
"sigmoid"
,
"tanh"
,
"tanh"
},
std
::
vector
<
std
::
string
>
{
"sigmoid"
,
"tanh"
,
"tanh"
},
std
::
vector
<
float
>
{},
std
::
vector
<
float
>
{},
std
::
vector
<
float
>
{},
std
::
vector
<
float
>
{},
...
@@ -101,72 +102,27 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
...
@@ -101,72 +102,27 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
false
);
false
);
auto
callback
=
[
X
,
W
,
R
,
H_t
,
C_t
](
pattern
::
Matcher
&
m
)
{
auto
callback
=
[
X
,
W
,
R
,
H_t
,
C_t
](
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
ngraph
::
runtime
::
cpu
::
rnn_utils
::
rnntype
rnn_type
=
ngraph
::
runtime
::
cpu
::
rnn_utils
::
rnntype
rnn_type
=
ngraph
::
runtime
::
cpu
::
rnn_utils
::
rnntype
::
vanilla_lstm
;
ngraph
::
runtime
::
cpu
::
rnn_utils
::
rnntype
::
vanilla_lstm
;
auto
target_lstm_node
=
m
.
get_match_root
();
auto
lstmcell_op
=
as_type_ptr
<
op
::
LSTMCell
>
(
m
.
get_match_root
());
auto
lstmcell_op
=
as_type_ptr
<
op
::
LSTMCell
>
(
m
.
get_match_root
());
auto
src_iter
=
auto
src_iter
=
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
NodeVector
{
pattern_map
[
H_t
],
pattern_map
[
C_t
]},
0
);
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
NodeVector
{
pattern_map
[
H_t
],
pattern_map
[
C_t
]},
0
);
auto
bias_iofc
=
target_lstm_node
->
get_argument
(
5
);
// we need to reorder W, R and bias from IOFC to IFCO gate order
// Note: ONNX runtime provides W, R and bias in the gate order [IOFC] but
// MKLDNN computes LSTM kernel in the [IFCO] order.
auto
get_weights_ifco_gate_order
=
[
&
](
std
::
shared_ptr
<
Node
>
weights_graph_node
)
->
std
::
shared_ptr
<
Node
>
{
// slices will be in ICFO order
std
::
vector
<
std
::
shared_ptr
<
Node
>>
gate_slices
;
size_t
dim0
=
weights_graph_node
->
get_shape
()[
0
]
/
4
;
size_t
dim1
=
weights_graph_node
->
get_shape
()[
1
];
for
(
size_t
i
=
0
;
i
<
4
;
i
++
)
{
auto
slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
weights_graph_node
,
Coordinate
{
i
*
dim0
,
0
},
Coordinate
{(
i
+
1
)
*
dim0
,
dim1
});
gate_slices
.
push_back
(
slice
);
}
auto
weights_ifco
=
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
NodeVector
{
gate_slices
[
0
],
gate_slices
[
2
],
gate_slices
[
3
],
gate_slices
[
1
]},
0
);
return
std
::
move
(
weights_ifco
);
};
auto
get_bias_ifco_gate_order
=
[
&
](
std
::
shared_ptr
<
Node
>
bias_graph_node
)
->
std
::
shared_ptr
<
Node
>
{
size_t
hidden_size
=
lstmcell_op
->
get_hidden_size
();
auto
W_ifco
=
lstmcell_op
->
get_argument
(
3
);
auto
Wb_bias
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
auto
R_ifco
=
lstmcell_op
->
get_argument
(
4
);
bias_graph_node
,
Coordinate
{
0
},
Coordinate
{
4
*
hidden_size
});
auto
bias_ifco
=
lstmcell_op
->
get_argument
(
5
);
auto
Rb_bias
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
bias_graph_node
,
Coordinate
{
4
*
hidden_size
},
Coordinate
{
2
*
4
*
hidden_size
});
auto
bias
=
std
::
make_shared
<
op
::
Add
>
(
Wb_bias
,
Rb_bias
);
// slices will be in ICFO order
// We need to reorder W, R and bias to IFCO gate order.
std
::
vector
<
std
::
shared_ptr
<
Node
>>
gate_slices
;
// Note: ie.: ONNX runtime provides W, R and bias in the gate order [IOFC] but
// MKLDNN computes LSTM kernel in the [IFCO] order.
for
(
size_t
i
=
0
;
i
<
4
;
i
++
)
if
(
lstmcell_op
->
get_weights_format
()
!=
op
::
LSTMWeightsFormat
::
IFCO
)
{
{
auto
slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
W_ifco
=
lstmcell_op
->
convert_node_format
(
W_ifco
);
bias
,
Coordinate
{
i
*
hidden_size
},
Coordinate
{(
i
+
1
)
*
hidden_size
});
R_ifco
=
lstmcell_op
->
convert_node_format
(
R_ifco
);
gate_slices
.
push_back
(
slice
);
bias_ifco
=
lstmcell_op
->
convert_node_format
(
bias_ifco
);
}
}
auto
new_bias
=
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
NodeVector
{
gate_slices
[
0
],
gate_slices
[
2
],
gate_slices
[
3
],
gate_slices
[
1
]},
0
);
return
std
::
move
(
new_bias
);
};
auto
W_iofc
=
pattern_map
[
W
];
auto
R_iofc
=
pattern_map
[
R
];
auto
W_ifco
=
get_weights_ifco_gate_order
(
W_iofc
);
auto
R_ifco
=
get_weights_ifco_gate_order
(
R_iofc
);
// here onnx bias will be of shape (2 * gates_count * hidden_size) bias of Wb and Rb are
// concatenated, we will split the bias, add and rearrange in order IFCO
auto
bias_ifco
=
get_bias_ifco_gate_order
(
bias_iofc
);
auto
W_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
auto
W_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
W_ifco
,
AxisVector
{
1
,
0
},
Shape
{
W_ifco
->
get_shape
()[
1
],
W_ifco
->
get_shape
()[
0
]});
W_ifco
,
AxisVector
{
1
,
0
},
Shape
{
W_ifco
->
get_shape
()[
1
],
W_ifco
->
get_shape
()[
0
]});
...
@@ -595,7 +551,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
...
@@ -595,7 +551,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
lstm_weights_layer_label
,
lstm_weights_layer_label
,
lstm_weights_iter_label
,
lstm_weights_iter_label
,
lstm_bias_label
](
pattern
::
RecurrentMatcher
&
m
)
{
lstm_bias_label
](
pattern
::
RecurrentMatcher
&
m
)
{
NGRAPH_DEBUG
<<
" In recurrent RNN fusion callback"
;
NGRAPH_DEBUG
<<
" In recurrent RNN fusion callback"
;
auto
        auto concat_rnn_inputs_across_timestep =
...
@@ -800,7 +755,6 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
          lstm_weights_layer_label,
          lstm_weights_iter_label,
          lstm_bias_label](pattern::RecurrentMatcher& m) {
         NGRAPH_DEBUG << " In recurrent RNN fusion callback";
         auto concat_rnn_inputs_across_timestep =
...
@@ -1161,7 +1115,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     // Replace all the users of RNN cell state {ct} across different user.
     auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe1, size_t layer) {
         // multi layerd fused rnn second output {GOE1} holds the recurrent output state tensors
         // for the last cell of all the layers, {{ht_1 | ct_1} || {ht2 |ct2} || ....{htn | ctn}}
         // we will slice the cell state output tensor {ct_*} from the fused RNN kerenel output
...
@@ -1211,7 +1164,6 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     // Replace all the users of RNN cell state {ct} across different user.
     auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe2, size_t layer) {
         // multi layerd fused rnn second output {GOE2} holds the recurrent output state tensors
         // for the last cell
         // of all the layers, { ct_1 || ct2 || ....|| ctn}
...
@@ -1302,7 +1254,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
     // Define a call back that needs to called once the DFG matches the pattern
     auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
         auto rnn_ltor_node =
             std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
...
@@ -1351,7 +1302,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
         ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
     auto construct_birnn_inputs = [&](int index) {
         auto nodes =
             NodeVector{rnn_ltor_node->get_argument(index), rnn_rtol_node->get_argument(index)};
         return std::make_shared<ngraph::op::Concat>(nodes, 0);
...
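These hunks only tighten the lambda captures and drop one stale line each; the callback shape itself is unchanged. For orientation, a minimal sketch of that shape, using only the calls visible above (`get_pattern_map`, `std::static_pointer_cast`); the rewrite step is elided, and the `return true` convention for "graph was modified" is an assumption about this ngraph version:

    // Sketch: general form of the BiDirectionalRnn fusion callback above.
    auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_map(); // pattern label -> matched node
        // Recover the two RNN nodes the pattern labels were bound to.
        auto ltor = std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
        auto rtol = std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
        // ...build the fused bidirectional node and replace the match root...
        return true; // assumed convention: true signals the graph was rewritten
    };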
src/ngraph/serializer.cpp View file @ 53a6af8d
...
@@ -433,6 +433,13 @@ static element::Type read_element_type(json j)
     return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string);
 }
+static op::LSTMWeightsFormat read_lstm_weights_format(const json& js)
+{
+    return has_key(js, "weights_format")
+               ? static_cast<op::LSTMWeightsFormat>(js.at("weights_format"))
+               : op::LSTMWeightsFormat::IFCO;
+}
 void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
 {
     ofstream out(path);
...
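The new helper gives older serialized graphs, which predate the `weights_format` attribute, a well-defined default. A hypothetical illustration (key names from the hunk above; the numeric enum value is made up):

    // Hypothetical inputs to read_lstm_weights_format():
    json old_node = {{"hidden_size", 3}};    // pre-change graph: no "weights_format" key
    json new_node = {{"weights_format", 0}}; // post-change graph: explicit format
    auto fmt_old = read_lstm_weights_format(old_node); // falls back to LSTMWeightsFormat::IFCO
    auto fmt_new = read_lstm_weights_format(new_node); // whichever enum member 0 encodes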
...
@@ -1828,24 +1835,60 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
         case OP_TYPEID::LSTMCell:
         {
             auto hidden_size = node_js.at("hidden_size").get<size_t>();
+            auto weights_format = read_lstm_weights_format(node_js);
             auto clip = node_js.at("clip").get<float>();
             auto activations = node_js.at("activations").get<vector<string>>();
-            auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
-            auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
+            auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
+            auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
             auto input_forget = node_js.at("input_forget").get<bool>();
-            node = make_shared<op::LSTMCell>(args[0],
-                                             args[1],
-                                             args[2],
-                                             args[3],
-                                             args[4],
-                                             hidden_size,
-                                             args[5],
-                                             args[6],
-                                             activations,
-                                             activation_alpha,
-                                             activation_beta,
-                                             clip,
-                                             input_forget);
+            if (args.size() == 7)
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 args[5],
+                                                 args[6],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
+            if (args.size() == 6)
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 args[5],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
+            else
+            {
+                node = make_shared<op::LSTMCell>(args[0],
+                                                 args[1],
+                                                 args[2],
+                                                 args[3],
+                                                 args[4],
+                                                 hidden_size,
+                                                 weights_format,
+                                                 activations,
+                                                 activations_alpha,
+                                                 activations_beta,
+                                                 clip,
+                                                 input_forget);
+            }
             break;
         }
         case OP_TYPEID::LSTMSequence:
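The three branches encode both signature changes this commit makes: the state tensors (`initial_hidden_state`, `initial_cell_state`) now precede `W`/`R`, and `hidden_size` moves behind the tensor inputs. A minimal construction sketch against the new order, with shapes taken from the updated tests further down (`4 * hidden_size` rows because the LSTM has four gates):

    // Sketch: LSTMCell construction under the new input order X, H_t, C_t, W, R.
    const size_t batch_size = 2, input_size = 3, hidden_size = 3;
    auto X   = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto W   = make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, input_size});
    auto R   = make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, hidden_size});
    auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
    // Output 0 is the next hidden state, output 1 the next cell state.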
...
@@ -1857,6 +1900,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
             auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
             auto input_forget = node_js.at("input_forget").get<bool>();
             auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
+            auto weights_format = read_lstm_weights_format(node_js);
             if (args.size() == 8)
             {
                 node = make_shared<op::LSTMSequence>(args[0],
...
@@ -1869,6 +1913,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                                      args[7],
                                                      hidden_size,
                                                      direction,
+                                                     weights_format,
                                                      activations_alpha,
                                                      activations_beta,
                                                      activations,
...
@@ -1886,6 +1931,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                                      args[6],
                                                      hidden_size,
                                                      direction,
+                                                     weights_format,
                                                      activations_alpha,
                                                      activations_beta,
                                                      activations,
...
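Note that the presence of the optional trailing input is inferred purely from the argument count (`args.size() == 8` in the first branch, the shorter form in the second). A simplified, standalone analogue of that arity dispatch, with made-up names:

    // Hedged sketch (not ngraph code): optional trailing inputs resolved by arity.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    static const char* lstm_sequence_variant(std::size_t num_inputs)
    {
        // 8 data inputs -> the optional extra tensor is present; otherwise it is absent.
        return num_inputs == 8 ? "with optional input" : "without optional input";
    }

    int main()
    {
        std::vector<int> args(8); // stand-in for the deserialized input nodes
        std::cout << lstm_sequence_variant(args.size()) << "\n"; // "with optional input"
    }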
...
@@ -2393,8 +2439,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
                                              args[1],
                                              args[2],
                                              args[3],
-                                             hidden_size,
                                              args[4],
+                                             hidden_size,
                                              activations,
                                              activation_alpha,
                                              activation_beta,
...
@@ -3418,8 +3464,8 @@ json JSONSerializer::serialize_node(const Node& n)
             node["hidden_size"] = tmp->get_hidden_size();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             node["linear_before_reset"] = tmp->get_linear_before_reset();
             break;
         }
...
@@ -3552,10 +3598,11 @@ json JSONSerializer::serialize_node(const Node& n)
         {
             auto tmp = static_cast<const op::LSTMCell*>(&n);
             node["hidden_size"] = tmp->get_hidden_size();
+            node["weights_format"] = tmp->get_weights_format();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             node["input_forget"] = tmp->get_input_forget();
             break;
         }
...
@@ -3564,6 +3611,7 @@ json JSONSerializer::serialize_node(const Node& n)
             auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
             node["direction"] = tmp->get_direction();
             node["hidden_size"] = tmp->get_hidden_size();
+            node["weights_format"] = tmp->get_weights_format();
             node["clip_threshold"] = tmp->get_clip_threshold();
             node["activations"] = tmp->get_activations();
             node["activations_alpha"] = tmp->get_activations_alpha();
...
@@ -3936,8 +3984,8 @@ json JSONSerializer::serialize_node(const Node& n)
             node["hidden_size"] = tmp->get_hidden_size();
             node["clip"] = tmp->get_clip();
             node["activations"] = tmp->get_activations();
-            node["activation_alpha"] = tmp->get_activation_alpha();
-            node["activation_beta"] = tmp->get_activation_beta();
+            node["activations_alpha"] = tmp->get_activations_alpha();
+            node["activations_beta"] = tmp->get_activations_beta();
             break;
         }
         case OP_TYPEID::ScalarConstantLike:
...
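Across these serializer hunks the JSON attribute keys change from singular `activation_alpha`/`activation_beta` to plural `activations_alpha`/`activations_beta`, matching the getter renames in `rnn_cell_base`. An illustrative sketch of the keys a cell node now carries (values made up; the exact set per op follows the cases above):

    // Illustrative only: post-change attribute keys for an RNN-family cell node.
    json node;
    node["hidden_size"] = 3;
    node["clip"] = 0.0f;
    node["activations"] = vector<string>{"tanh"};
    node["activations_alpha"] = vector<float>{}; // was "activation_alpha"
    node["activations_beta"] = vector<float>{};  // was "activation_beta"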
test/backend/fused_op.in.cpp View file @ 53a6af8d
This diff is collapsed.
test/cpu_fusion.cpp View file @ 53a6af8d
...
@@ -3967,12 +3967,14 @@ TEST(cpu_fusion, lstm_cell)
         const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
         const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         auto ht = make_shared<op::GetOutputElement>(lstm_cell, 0);
         auto ct = make_shared<op::GetOutputElement>(lstm_cell, 1);
         auto lstm_function =
             make_shared<Function>(NodeVector{ht, ct},
-                                  ParameterVector{X, W, R, H_t, C_t});
+                                  ParameterVector{X, H_t, C_t, W, R});
         return lstm_function;
     };
     auto lstm_function_cpu = make_function();
...
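The cell's inputs and the `ParameterVector` are reordered in lock-step: the parameter list fixes the positional calling convention of the resulting `Function`, so it has to mirror the new input order. In sketch form:

    // The ParameterVector defines the order in which callers pass tensors to f:
    auto f = make_shared<Function>(NodeVector{ht, ct}, ParameterVector{X, H_t, C_t, W, R});
    // f now expects its arguments as X, H_t, C_t, W, R -- the same order as the
    // LSTMCell inputs above.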
test/serialize.cpp View file @ 53a6af8d
...
@@ -531,10 +531,10 @@ TEST(serialize, tensor_iterator_lstm)
     auto R_body = make_shared<op::Parameter>(element::f32, Shape{4 * H, H});
     auto LSTM_cell =
         make_shared<op::LSTMCell>(make_shared<op::Reshape>(X, AxisVector{0, 1, 2}, Shape{N, I}),
-                                  W_body,
-                                  R_body,
                                   make_shared<op::Reshape>(H_t, AxisVector{0, 1, 2}, Shape{N, H}),
                                   make_shared<op::Reshape>(C_t, AxisVector{0, 1, 2}, Shape{N, H}),
+                                  W_body,
+                                  R_body,
                                   H);
     auto H_o = make_shared<op::Reshape>(LSTM_cell->output(0), AxisVector{0, 1}, Shape{N, 1, H});
     auto C_o = make_shared<op::Reshape>(LSTM_cell->output(1), AxisVector{0, 1}, Shape{N, 1, H});
...
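This test round-trips a TensorIterator body containing the updated `LSTMCell` through JSON. The write side follows the `ngraph::serialize` signature shown in the serializer hunk above; `ngraph::deserialize` as its counterpart is an assumption about this ngraph version:

    // Sketch: serialize the function to disk and read it back.
    ngraph::serialize("tensor_iterator_lstm.json", f, /*indent=*/4);
    // Assumed counterpart API; older graphs without "weights_format" still load,
    // thanks to the read_lstm_weights_format() default shown earlier.
    auto g = ngraph::deserialize("tensor_iterator_lstm.json");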
test/type_prop/gru_cell.cpp View file @ 53a6af8d
...
@@ -87,7 +87,8 @@ TEST(type_prop, gru_cell_invalid_input)
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }
     // Invalid B tensor shape.
...
test/type_prop/lstm_cell.cpp View file @ 53a6af8d
...
@@ -36,7 +36,16 @@ TEST(type_prop, lstm_cell)
     const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
     const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+    const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
+    EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size);
+    EXPECT_EQ(lstm_cell->get_clip(), 0.f);
+    EXPECT_TRUE(lstm_cell->get_activations_alpha().empty());
+    EXPECT_TRUE(lstm_cell->get_activations_beta().empty());
+    EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid");
+    EXPECT_EQ(lstm_cell->get_activations()[1], "tanh");
+    EXPECT_EQ(lstm_cell->get_activations()[2], "tanh");
+    EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO);
+    EXPECT_FALSE(lstm_cell->get_input_forget());
     EXPECT_EQ(lstm_cell->output(0).get_element_type(), element::f32);
     EXPECT_EQ(lstm_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
     EXPECT_EQ(lstm_cell->output(1).get_element_type(), element::f32);
...
@@ -60,7 +69,7 @@ TEST(type_prop, lstm_cell_invalid_input)
     auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
@@ -73,7 +82,7 @@ TEST(type_prop, lstm_cell_invalid_input)
     R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
@@ -86,12 +95,13 @@ TEST(type_prop, lstm_cell_invalid_input)
     H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }
     // Invalid C_t tensor shape.
...
@@ -99,21 +109,22 @@ TEST(type_prop, lstm_cell_invalid_input)
     C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor C_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_cell_state must have shape"));
     }
     // Invalid B tensor shape.
     C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
     auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
@@ -122,11 +133,11 @@ TEST(type_prop, lstm_cell_invalid_input)
     }
     // Invalid P tensor shape.
-    B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
+    B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
     P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
     try
     {
-        const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
+        const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
         FAIL() << "LSTMCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
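With the optional bias and peephole inputs, the new order places `B` and `P` between the weights and the trailing `hidden_size`. A valid-construction sketch mirroring the shapes this test accepts (`gates_count` is 4 for LSTM; note the bias is now the single merged `{4 * hidden_size}` vector rather than the old doubled one):

    // Sketch: fully specified LSTMCell under the new signature.
    const size_t batch_size = 2, input_size = 3, hidden_size = 3, gates_count = 4;
    auto X   = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto W   = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
    auto R   = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
    auto B   = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
    auto P   = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
    auto cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);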
test/type_prop/lstm_sequence.cpp View file @ 53a6af8d
...
@@ -28,7 +28,7 @@ TEST(type_prop, lstm_sequence)
     const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
     const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
     const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 24});
+    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
     const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
     const auto hidden_size = 3;
...
@@ -41,6 +41,20 @@ TEST(type_prop, lstm_sequence)
                                                          B,
                                                          hidden_size,
                                                          op::LSTMSequence::direction::FORWARD);
+    EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
+    EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
+    EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
+    EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
+    EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
+    EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
+    EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
+    EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
+    EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
+    EXPECT_FALSE(lstm_sequence->get_input_forget());
     EXPECT_EQ(lstm_sequence->output(0).get_element_type(), element::f32);
     EXPECT_EQ(lstm_sequence->output(0).get_shape(), (Shape{1, 1, 2, 3}));
+    EXPECT_EQ(lstm_sequence->output(1).get_element_type(), element::f32);
+    EXPECT_EQ(lstm_sequence->output(1).get_shape(), (Shape{1, 2, 3}));
+    EXPECT_EQ(lstm_sequence->output(2).get_element_type(), element::f32);
+    EXPECT_EQ(lstm_sequence->output(2).get_shape(), (Shape{1, 2, 3}));
 }
test/type_prop/rnn_cell.cpp View file @ 53a6af8d
...
@@ -28,11 +28,11 @@ TEST(type_prop, rnn_cell)
     const size_t hidden_size = 3;
     const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
-    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
     const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
     const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
+    const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+    const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
     EXPECT_EQ(rnn_cell->output(0).get_element_type(), element::f32);
     EXPECT_EQ(rnn_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
 }
...
@@ -51,7 +51,7 @@ TEST(type_prop, rnn_cell_invalid_input)
     auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
@@ -64,7 +64,7 @@ TEST(type_prop, rnn_cell_invalid_input)
     R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
@@ -77,20 +77,21 @@ TEST(type_prop, rnn_cell_invalid_input)
     H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
     {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input tensor initial_hidden_state must have shape"));
     }
     // Invalid B tensor shape.
     H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
-    auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
+    auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});
     try
     {
-        const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size, B);
+        const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);
         FAIL() << "RNNCell node was created with invalid data.";
     }
     catch (const NodeValidationFailure& error)
...
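RNNCell follows the same reordering: the hidden state moves directly after `X`, and the optional bias `B` now sits before the trailing `hidden_size`. A valid-construction sketch with the shapes this test accepts (a vanilla RNN has a single gate, so there is no gate multiplier on `W`/`R`, and the merged bias is just `{hidden_size}`):

    // Sketch: RNNCell under the new input order X, H_t, W, R[, B].
    const size_t batch_size = 2, input_size = 3, hidden_size = 3;
    auto X   = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto W   = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
    auto R   = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    auto B   = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
    auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);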