Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
5761f145
Commit
5761f145
authored
May 20, 2019
by
Adam Rogowiec
Committed by
arogowie-intel
May 22, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update LSTM ONNX operator to use LSTMCell fused op.
parent
ef1c5347
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
214 deletions
+50
-214
lstm.cpp
src/ngraph/frontend/onnx_import/op/lstm.cpp
+50
-214
No files found.
src/ngraph/frontend/onnx_import/op/lstm.cpp
View file @
5761f145
...
...
@@ -14,46 +14,32 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
namespace
ngraph
{
...
...
@@ -63,61 +49,6 @@ namespace ngraph
{
namespace
{
/// Element-wise addition of two nodes.
///
/// Both operands are first passed through `numpy_style_broadcast`; the
/// resulting pair is then fed to an `op::Add` node.
///
/// \param lhs Left-hand operand.
/// \param rhs Right-hand operand.
/// \return The newly created `op::Add` node.
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
                                  const std::shared_ptr<ngraph::Node>& rhs)
{
    const auto broadcasted = ngraph::op::numpy_style_broadcast({lhs, rhs});
    const auto& left = broadcasted.at(0);
    const auto& right = broadcasted.at(1);
    return std::make_shared<ngraph::op::Add>(left, right);
}
/// Element-wise subtraction of two nodes (`lhs - rhs`).
///
/// Both operands are first passed through `numpy_style_broadcast`; the
/// resulting pair is then fed to an `op::Subtract` node.
///
/// \param lhs Left-hand operand (minuend).
/// \param rhs Right-hand operand (subtrahend).
/// \return The newly created `op::Subtract` node.
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
                                  const std::shared_ptr<ngraph::Node>& rhs)
{
    const auto broadcasted = ngraph::op::numpy_style_broadcast({lhs, rhs});
    const auto& left = broadcasted.at(0);
    const auto& right = broadcasted.at(1);
    return std::make_shared<ngraph::op::Subtract>(left, right);
}
/// Element-wise multiplication of two nodes.
///
/// Both operands are first passed through `numpy_style_broadcast`; the
/// resulting pair is then fed to an `op::Multiply` node.
///
/// \param lhs Left-hand operand.
/// \param rhs Right-hand operand.
/// \return The newly created `op::Multiply` node.
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
                                  const std::shared_ptr<ngraph::Node>& rhs)
{
    const auto broadcasted = ngraph::op::numpy_style_broadcast({lhs, rhs});
    const auto& left = broadcasted.at(0);
    const auto& right = broadcasted.at(1);
    return std::make_shared<ngraph::op::Multiply>(left, right);
}
/// Clamp every element of `data` to the range [-threshold, threshold].
///
/// A threshold equal to `0.f` means "no clipping": the input node is returned
/// untouched. Otherwise two constants with the element type and shape of
/// `data` are created (filled with -threshold and +threshold respectively) and
/// the result is `Minimum(max_const, Maximum(data, min_const))`.
///
/// \param data      Node whose values should be clipped.
/// \param threshold Clipping bound; `0.f` disables clipping.
/// \return Either `data` itself or the `op::Minimum` node described above.
std::shared_ptr<ngraph::Node> clip(const std::shared_ptr<ngraph::Node>& data, float threshold)
{
    if (threshold == 0.f)
    {
        return data;
    }

    const std::size_t element_count = ngraph::shape_size(data->get_shape());

    // Builds a constant shaped like `data` with every element set to `value`.
    const auto make_filled_constant = [&](float value) -> std::shared_ptr<ngraph::Node> {
        return ngraph::op::Constant::create(data->get_element_type(),
                                            data->get_shape(),
                                            std::vector<float>(element_count, value));
    };

    // Keep the creation order: lower bound first, then upper bound.
    const std::shared_ptr<ngraph::Node> lower_bound = make_filled_constant(-threshold);
    const std::shared_ptr<ngraph::Node> upper_bound = make_filled_constant(threshold);

    const std::shared_ptr<ngraph::Node> clipped_from_below =
        std::make_shared<ngraph::op::Maximum>(data, lower_bound);
    return std::make_shared<ngraph::op::Minimum>(upper_bound, clipped_from_below);
}
// Lower-case every string of the input vector in-place (via ngraph::to_lower)
// and hand back a reference to that same vector.
//
// NOTE: the returned reference aliases the argument, which is bound to an
// rvalue. It is therefore only valid for as long as the caller's temporary
// lives -- copy the result within the same full expression.
std::vector<std::string>& to_lower_case(std::vector<std::string>&& vs)
{
    for (std::string& entry : vs)
    {
        entry = ngraph::to_lower(entry);
    }
    return vs;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum
class
LSTMInput
...
...
@@ -257,9 +188,8 @@ namespace ngraph
explicit
LSTMAttributes
(
const
Node
&
node
)
:
m_hidden_size
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"hidden_size"
)}
,
m_clip_threshold
{
node
.
get_attribute_value
<
float
>
(
"clip"
,
0.
f
)}
,
m_activations
{
to_lower_case
(
node
.
get_attribute_value
<
std
::
vector
<
std
::
string
>>
(
"activations"
,
{
"sigmoid"
,
"tanh"
,
"tanh"
}))}
,
m_activations
{
node
.
get_attribute_value
<
std
::
vector
<
std
::
string
>>
(
"activations"
,
{
"sigmoid"
,
"tanh"
,
"tanh"
})}
// Default values for activation functions are same as for corresponding
// ONNX operator.
,
m_activation_alpha
{
node
.
get_attribute_value
<
std
::
vector
<
float
>>
(
...
...
@@ -293,20 +223,17 @@ namespace ngraph
{
public
:
explicit
LSTMForward
(
std
::
shared_ptr
<
ngraph
::
Node
>
X
,
std
::
shared_ptr
<
ngraph
::
Node
>
W
,
std
::
shared_ptr
<
ngraph
::
Node
>
R
,
std
::
shared_ptr
<
ngraph
::
Node
>
B
,
std
::
shared_ptr
<
ngraph
::
Node
>
P
,
std
::
shared_ptr
<
ngraph
::
Node
>
initial_h
,
std
::
shared_ptr
<
ngraph
::
Node
>
initial_c
,
std
::
shared_ptr
<
ngraph
::
Node
>
seq_lengths
,
rnn
::
ActivationFunction
activation_f
,
rnn
::
ActivationFunction
activation_g
,
rnn
::
ActivationFunction
activation_h
,
bool
input_forget
=
false
,
float
clip_threshold
=
0.
f
)
const
std
::
shared_ptr
<
ngraph
::
Node
>&
W
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
R
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
B
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
P
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
initial_h
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
initial_c
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
seq_lengths
,
const
LSTMAttributes
&
attributes
)
:
m_X
{
X
}
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
,
,
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
,
m_W
{
reshape
::
squeeze
(
W
)}
,
m_R
{
reshape
::
squeeze
(
R
)}
,
m_B
{
reshape
::
squeeze
(
B
)}
...
...
@@ -314,11 +241,7 @@ namespace ngraph
,
m_initial_h
{
reshape
::
squeeze
(
initial_h
)}
,
m_initial_c
{
reshape
::
squeeze
(
initial_c
)}
,
m_seq_lengths
{
seq_lengths
}
,
m_activation_f
{
activation_f
}
,
m_activation_g
{
activation_g
}
,
m_activation_h
{
activation_h
}
,
m_input_forget
{
input_forget
}
,
m_clip_threshold
{
clip_threshold
}
,
m_attributes
{
attributes
}
{
}
...
...
@@ -332,7 +255,7 @@ namespace ngraph
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor for
r
peepholes. [num_directions, 3*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
...
...
@@ -340,32 +263,11 @@ namespace ngraph
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// b_W_R - Bias vectors for input, output, forget, and cell gates.
// Concatenation of `[Wb, Rb]`.
// p_[iof] - P peephole weight vector for respectively: input, output,
// and forget gates.
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
//
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
NodeVector
p_iof
=
ngraph
::
builder
::
split
(
m_P
,
3
);
const
auto
&
p_i
=
p_iof
.
at
(
0
);
const
auto
&
p_o
=
p_iof
.
at
(
1
);
const
auto
&
p_f
=
p_iof
.
at
(
2
);
NodeVector
h_list
;
NodeVector
b_W_R
=
ngraph
::
builder
::
split
(
m_B
,
2
);
std
::
shared_ptr
<
ngraph
::
Node
>
bias
=
b_W_R
.
at
(
0
)
+
b_W_R
.
at
(
1
);
NodeVector
h_list
;
std
::
shared_ptr
<
ngraph
::
Node
>
H_t
=
m_initial_h
;
std
::
shared_ptr
<
ngraph
::
Node
>
C_t
=
m_initial_c
;
...
...
@@ -393,47 +295,26 @@ namespace ngraph
std
::
int32_t
time_step
{
1
};
for
(
const
auto
&
in_x
:
in_seqs
)
{
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// Xt*(W^T) -- for [iofc] gates.
auto
Xt_W
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
in_x
,
ngraph
::
op
::
util
::
transpose
(
m_W
));
// Ht-1*(R^T) -- for [iofc] gates.
auto
Ht_R
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
H_t
,
ngraph
::
op
::
util
::
transpose
(
m_R
));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto
gates
=
add
(
Xt_W
,
add
(
Ht_R
,
bias
));
NodeVector
split_gates
=
ngraph
::
builder
::
split
(
gates
,
4
,
-
1
);
auto
i
=
split_gates
.
at
(
0
);
auto
o
=
split_gates
.
at
(
1
);
auto
f
=
split_gates
.
at
(
2
);
auto
c
=
split_gates
.
at
(
3
);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i
=
m_activation_f
(
clip
(
add
(
i
,
mul
(
p_i
,
C_t
)),
m_clip_threshold
));
if
(
m_input_forget
)
{
// Couple input with forget gate: 1 - i
f
=
sub
(
ngraph
::
op
::
Constant
::
create
(
i
->
get_element_type
(),
i
->
get_shape
(),
std
::
vector
<
float
>
(
shape_size
(
i
->
get_shape
()),
1.
f
)),
i
);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f
=
m_activation_f
(
clip
(
add
(
f
,
mul
(
p_f
,
C_t
)),
m_clip_threshold
));
}
// ft (.) Ct-1 + it (.) ct
auto
C
=
add
(
mul
(
f
,
C_t
),
mul
(
i
,
m_activation_g
(
clip
(
c
,
m_clip_threshold
))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o
=
m_activation_f
(
clip
(
add
(
o
,
mul
(
p_o
,
C
)),
m_clip_threshold
));
// ot (.) h(Ct)
auto
H
=
mul
(
o
,
m_activation_h
(
C
));
const
std
::
shared_ptr
<
ngraph
::
Node
>&
lstm_cell
=
std
::
make_shared
<
ngraph
::
op
::
LSTMCell
>
(
in_x
,
m_W
,
m_R
,
H_t
,
C_t
,
m_attributes
.
m_hidden_size
,
m_B
,
m_P
,
m_attributes
.
m_activations
,
m_attributes
.
m_activation_alpha
,
m_attributes
.
m_activation_beta
,
m_attributes
.
m_clip_threshold
,
m_attributes
.
m_input_forget
);
const
std
::
shared_ptr
<
ngraph
::
Node
>&
H
=
get_output_element
(
lstm_cell
,
0
);
const
std
::
shared_ptr
<
ngraph
::
Node
>&
C
=
get_output_element
(
lstm_cell
,
1
);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
...
...
@@ -528,41 +409,16 @@ namespace ngraph
}
std
::
shared_ptr
<
ngraph
::
Node
>
m_X
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_W
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_R
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_B
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_P
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_initial_h
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_initial_c
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_seq_lengths
;
rnn
::
ActivationFunction
m_activation_f
;
rnn
::
ActivationFunction
m_activation_g
;
rnn
::
ActivationFunction
m_activation_h
;
// For coupling input and forget gates.
bool
m_input_forget
;
// For clipping cell input in the range [-clip_threshold, clip_threshold].
float
m_clip_threshold
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_W
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_R
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_B
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_P
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_initial_h
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_initial_c
;
const
std
::
shared_ptr
<
ngraph
::
Node
>&
m_seq_lengths
;
const
LSTMAttributes
&
m_attributes
;
};
/// Build the activation function stored at position `idx` of the attributes.
///
/// The function is looked up by the name found in `m_activations.at(idx)`.
/// If `m_activation_alpha` and/or `m_activation_beta` hold an entry at `idx`,
/// that value is applied through `set_alpha` / `set_beta`; otherwise the
/// corresponding parameter is left as returned by the lookup.
///
/// \param attributes Parsed LSTM attributes.
/// \param idx        Index of the requested activation function.
/// \return The configured activation function object.
rnn::ActivationFunction get_activation_function(const LSTMAttributes& attributes,
                                                std::size_t idx)
{
    rnn::ActivationFunction activation =
        rnn::get_activation_func_by_name(attributes.m_activations.at(idx));

    const auto& alphas = attributes.m_activation_alpha;
    const auto& betas = attributes.m_activation_beta;

    // Optional parameters: only applied when a value exists for this index.
    if (idx < alphas.size())
    {
        activation.set_alpha(alphas.at(idx));
    }
    if (idx < betas.size())
    {
        activation.set_beta(betas.at(idx));
    }

    return activation;
}
}
// anonymous namespace
namespace
set_1
...
...
@@ -572,14 +428,6 @@ namespace ngraph
LSTMNgInputMap
input_map
{
node
};
LSTMAttributes
attributes
{
node
};
// Get activation functions.
const
rnn
::
ActivationFunction
&
activation_f
=
get_activation_function
(
attributes
,
0
);
const
rnn
::
ActivationFunction
&
activation_g
=
get_activation_function
(
attributes
,
1
);
const
rnn
::
ActivationFunction
&
activation_h
=
get_activation_function
(
attributes
,
2
);
NodeVector
results
;
if
(
attributes
.
m_direction
==
LSTMDirection
::
LSTM_DIRECTION_FORWARD
||
...
...
@@ -593,11 +441,7 @@ namespace ngraph
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_INIT_H
),
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_INIT_C
),
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
),
activation_f
,
activation_g
,
activation_h
,
attributes
.
m_input_forget
,
attributes
.
m_clip_threshold
);
attributes
);
results
=
lstm_fwd
.
run
(
(
attributes
.
m_direction
==
LSTMDirection
::
LSTM_DIRECTION_REVERSE
));
}
...
...
@@ -625,11 +469,7 @@ namespace ngraph
H
.
at
(
0
),
C
.
at
(
0
),
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
),
activation_f
,
activation_g
,
activation_h
,
attributes
.
m_input_forget
,
attributes
.
m_clip_threshold
);
attributes
);
LSTMForward
lstm_reversed
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_X
),
W
.
at
(
1
),
R
.
at
(
1
),
...
...
@@ -638,11 +478,7 @@ namespace ngraph
H
.
at
(
1
),
C
.
at
(
1
),
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
),
activation_f
,
activation_g
,
activation_h
,
attributes
.
m_input_forget
,
attributes
.
m_clip_threshold
);
attributes
);
NodeVector
fwd_results
{
lstm_fwd
.
run
()};
NodeVector
rev_results
{
lstm_fwd
.
run
(
true
)};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment