Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6e6c8af4
Commit
6e6c8af4
authored
Feb 26, 2019
by
Adam Rogowiec
Committed by
Michał Karzyński
Feb 26, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Enhance LSTM support. (#2408)
parent
25c9152f
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
536 additions
and
105 deletions
+536
-105
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+2
-0
node.cpp
src/ngraph/frontend/onnx_import/core/node.cpp
+9
-0
lstm.cpp
src/ngraph/frontend/onnx_import/op/lstm.cpp
+287
-84
matmul.cpp
src/ngraph/frontend/onnx_import/op/matmul.cpp
+1
-1
supported_ops.md
src/ngraph/frontend/onnx_import/op/supported_ops.md
+1
-1
reshape.cpp
src/ngraph/frontend/onnx_import/utils/reshape.cpp
+7
-10
reshape.hpp
src/ngraph/frontend/onnx_import/utils/reshape.hpp
+7
-9
activation_functions.cpp
...h/frontend/onnx_import/utils/rnn/activation_functions.cpp
+71
-0
activation_functions.hpp
...h/frontend/onnx_import/utils/rnn/activation_functions.hpp
+66
-0
lstm_fwd_with_clip.onnx
test/models/onnx/lstm_fwd_with_clip.onnx
+0
-0
onnx_import.in.cpp
test/onnx_import.in.cpp
+85
-0
No files found.
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
6e6c8af4
...
...
@@ -177,6 +177,8 @@ add_library(onnx_import STATIC
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp
)
set
(
ONNX_IMPORT_INCLUDE_DIR
${
CMAKE_CURRENT_SOURCE_DIR
}
CACHE INTERNAL
""
)
...
...
src/ngraph/frontend/onnx_import/core/node.cpp
View file @
6e6c8af4
...
...
@@ -258,6 +258,15 @@ namespace ngraph
name
,
std
::
move
(
default_value
));
}
/// Explicit specialization of Node::get_attribute_value for a list of strings.
/// Forwards the lookup to the pimpl object; `default_value` is moved into that call.
template <>
std::vector<std::string>
    Node::get_attribute_value(const std::string& name,
                              std::vector<std::string> default_value) const
{
    using StringList = std::vector<std::string>;
    return m_pimpl->template get_attribute_value<StringList>(name, std::move(default_value));
}
template
<>
std
::
vector
<
Tensor
>
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
std
::
vector
<
Tensor
>
default_value
)
const
...
...
src/ngraph/frontend/onnx_import/op/lstm.cpp
View file @
6e6c8af4
...
...
@@ -14,6 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
...
...
@@ -24,21 +26,30 @@
#include <unordered_map>
#include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/broadcasting.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
namespace
ngraph
{
...
...
@@ -55,6 +66,13 @@ namespace ngraph
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
/// Builds an element-wise subtraction node `lhs - rhs`.
/// Both operands are first passed through numpy_style_broadcast_for_binary_operation;
/// the two nodes it returns become the arguments of the Subtract op.
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
                                  const std::shared_ptr<ngraph::Node>& rhs)
{
    auto broadcasted = numpy_style_broadcast_for_binary_operation(lhs, rhs);
    const auto& minuend = broadcasted.at(0);
    const auto& subtrahend = broadcasted.at(1);
    return std::make_shared<ngraph::op::Subtract>(minuend, subtrahend);
}
std
::
shared_ptr
<
ngraph
::
Node
>
mul
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
lhs
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
rhs
)
{
...
...
@@ -62,16 +80,38 @@ namespace ngraph
return
{
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ACTIVATION FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
std
::
shared_ptr
<
ngraph
::
Node
>
sigmoid
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg
)
std
::
shared_ptr
<
ngraph
::
Node
>
clip
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
data
,
float
threshold
)
{
if
(
threshold
==
0.
f
)
{
return
std
::
make_shared
<
ngraph
::
op
::
Sigmoid
>
(
arg
);
return
data
;
}
float
min_val
=
-
threshold
;
float
max_val
=
threshold
;
std
::
size_t
size
=
ngraph
::
shape_size
(
data
->
get_shape
());
const
std
::
shared_ptr
<
ngraph
::
Node
>
min_val_node
=
ngraph
::
op
::
Constant
::
create
(
data
->
get_element_type
(),
data
->
get_shape
(),
std
::
vector
<
float
>
(
size
,
min_val
));
const
std
::
shared_ptr
<
ngraph
::
Node
>
max_val_node
=
ngraph
::
op
::
Constant
::
create
(
data
->
get_element_type
(),
data
->
get_shape
(),
std
::
vector
<
float
>
(
size
,
max_val
));
return
std
::
make_shared
<
ngraph
::
op
::
Minimum
>
(
max_val_node
,
std
::
make_shared
<
ngraph
::
op
::
Maximum
>
(
data
,
min_val_node
));
}
std
::
shared_ptr
<
ngraph
::
Node
>
tanh
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg
)
// Modify input vector in-place and return reference to modified vector.
std
::
vector
<
std
::
string
>&
to_lower_case
(
std
::
vector
<
std
::
string
>&&
vs
)
{
return
std
::
make_shared
<
ngraph
::
op
::
Tanh
>
(
arg
);
std
::
transform
(
std
::
begin
(
vs
),
std
::
end
(
vs
),
std
::
begin
(
vs
),
[](
std
::
string
&
s
)
{
return
ngraph
::
to_lower
(
s
);
});
return
vs
;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...
@@ -88,22 +128,6 @@ namespace ngraph
LSTM_INPUT_P
};
std
::
string
to_str
(
const
LSTMInput
&
in
)
{
switch
(
in
)
{
case
LSTMInput
:
:
LSTM_INPUT_X
:
return
"X"
;
case
LSTMInput
:
:
LSTM_INPUT_W
:
return
"W"
;
case
LSTMInput
:
:
LSTM_INPUT_R
:
return
"R"
;
case
LSTMInput
:
:
LSTM_INPUT_B
:
return
"B"
;
case
LSTMInput
:
:
LSTM_INPUT_SEQ_LENGTHS
:
return
"sequence_lens"
;
case
LSTMInput
:
:
LSTM_INPUT_INIT_H
:
return
"initial_h"
;
case
LSTMInput
:
:
LSTM_INPUT_INIT_C
:
return
"initial_c"
;
case
LSTMInput
:
:
LSTM_INPUT_P
:
return
"P"
;
default
:
return
"Unrecognized input value!"
;
}
}
struct
LSTMNgInputMap
{
using
container_type
=
std
::
map
<
LSTMInput
,
std
::
shared_ptr
<
ngraph
::
Node
>>
;
...
...
@@ -134,7 +158,7 @@ namespace ngraph
// ------ Optional inputs ------
// The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
if
(
ng_inputs
.
size
()
>
=
4
)
if
(
ng_inputs
.
size
()
>
3
&&
!
ng_inputs
.
at
(
3
)
->
is_null
()
)
{
m_map
[
LSTMInput
::
LSTM_INPUT_B
]
=
ng_inputs
.
at
(
3
);
}
...
...
@@ -146,21 +170,20 @@ namespace ngraph
{
0.
f
});
}
// The lengths of the sequences in a batch. Shape [batch_size]
if
(
ng_inputs
.
size
()
>
=
5
)
if
(
ng_inputs
.
size
()
>
4
&&
!
ng_inputs
.
at
(
4
)
->
is_null
()
)
{
m_map
[
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
]
=
ng_inputs
.
at
(
4
);
}
else
{
m_map
[
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
]
=
common
::
make_constant_node
<
std
::
int32_t
>
(
m_map
[
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
]
=
ngraph
::
op
::
Constant
::
create
(
element
::
i32
,
{
batch_size
},
{
static_cast
<
std
::
int32_t
>
(
m_map
[
LSTMInput
::
LSTM_INPUT_X
]
->
get_shape
().
at
(
0
))}
);
Shape
{
batch_size
},
std
::
vector
<
std
::
int32_t
>
(
batch_size
,
m_map
[
LSTMInput
::
LSTM_INPUT_X
]
->
get_shape
().
at
(
0
))
);
}
// The initial value of the hidden. Shape [num_directions, batch_size, hidden_size]
if
(
ng_inputs
.
size
()
>
=
6
)
if
(
ng_inputs
.
size
()
>
5
&&
!
ng_inputs
.
at
(
5
)
->
is_null
()
)
{
m_map
[
LSTMInput
::
LSTM_INPUT_INIT_H
]
=
ng_inputs
.
at
(
5
);
}
...
...
@@ -170,7 +193,7 @@ namespace ngraph
element
::
f32
,
{
num_directions
,
batch_size
,
hidden_size
},
{
0.
f
});
}
// The initial value of the cell. Shape [num_directions, batch_size, hidden_size]
if
(
ng_inputs
.
size
()
>
=
7
)
if
(
ng_inputs
.
size
()
>
6
&&
!
ng_inputs
.
at
(
6
)
->
is_null
()
)
{
m_map
[
LSTMInput
::
LSTM_INPUT_INIT_C
]
=
ng_inputs
.
at
(
6
);
}
...
...
@@ -180,7 +203,7 @@ namespace ngraph
element
::
f32
,
{
num_directions
,
batch_size
,
hidden_size
},
{
0.
f
});
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidden_size]
if
(
ng_inputs
.
size
()
>
=
8
)
if
(
ng_inputs
.
size
()
>
7
&&
!
ng_inputs
.
at
(
7
)
->
is_null
()
)
{
m_map
[
LSTMInput
::
LSTM_INPUT_P
]
=
ng_inputs
.
at
(
7
);
}
...
...
@@ -197,8 +220,6 @@ namespace ngraph
{
return
m_map
.
at
(
key
);
}
iterator
begin
()
{
return
m_map
.
begin
();
}
iterator
end
()
{
return
m_map
.
end
();
}
container_type
m_map
;
};
...
...
@@ -208,48 +229,91 @@ namespace ngraph
{
LSTM_DIRECTION_FORWARD
,
LSTM_DIRECTION_REVERSE
,
LSTM_DIRECTION_BIDIRECTIONAL
LSTM_DIRECTION_BIDIRECTIONAL
,
LSTM_DIRECTION_UNKNOWN
,
};
/// Maps the textual value of the "direction" attribute onto an LSTMDirection.
///
/// \param[in] s  Direction name; compared verbatim (no case folding is done here).
///
/// \return The matching enumerator, or LSTM_DIRECTION_UNKNOWN when `s` is none of
///         "forward", "reverse" or "bidirectional".
LSTMDirection getLSTMDirection(const std::string& s)
{
    struct NamedDirection
    {
        const char* name;
        LSTMDirection direction;
    };

    const NamedDirection known_directions[] = {
        {"forward", LSTMDirection::LSTM_DIRECTION_FORWARD},
        {"reverse", LSTMDirection::LSTM_DIRECTION_REVERSE},
        {"bidirectional", LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL}};

    for (const auto& candidate : known_directions)
    {
        if (s == candidate.name)
        {
            return candidate.direction;
        }
    }
    return LSTMDirection::LSTM_DIRECTION_UNKNOWN;
}
struct
LSTMAttributes
{
explicit
LSTMAttributes
(
const
Node
&
node
)
:
m_direction
{
LSTMDirection
::
LSTM_DIRECTION_FORWARD
}
,
m_hidden_size
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"hidden_size"
)}
{
:
m_hidden_size
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"hidden_size"
)}
,
m_clip_threshold
{
node
.
get_attribute_value
<
float
>
(
"clip"
,
0.
f
)}
,
m_activations
{
to_lower_case
(
node
.
get_attribute_value
<
std
::
vector
<
std
::
string
>>
(
"activations"
,
{
"sigmoid"
,
"tanh"
,
"tanh"
}))}
,
m_input_forget
{
static_cast
<
bool
>
(
node
.
get_attribute_value
<
std
::
int64_t
>
(
"input_forget"
,
0
))}
{
m_clip_threshold
=
std
::
abs
(
m_clip_threshold
);
std
::
string
direction
{
ngraph
::
to_lower
(
node
.
get_attribute_value
<
std
::
string
>
(
"direction"
,
{
"forward"
}))};
ASSERT_VALID_ARGUMENT
(
node
,
getLSTMDirection
(
direction
)
!=
LSTMDirection
::
LSTM_DIRECTION_UNKNOWN
)
<<
"Provided attribute
\"
direction
\"
value is incorrect: "
<<
direction
;
m_direction
=
getLSTMDirection
(
direction
);
}
// Currently only LSTM_DIRECTION_FORWARD is supported.
LSTMDirection
m_direction
;
std
::
int64_t
m_hidden_size
;
float
m_clip_threshold
;
std
::
vector
<
std
::
string
>
m_activations
;
bool
m_input_forget
;
};
}
// anonymous namespace
namespace
set_1
{
NodeVector
lstm
(
const
Node
&
node
)
{
LSTMNgInputMap
input_map
{
node
};
LSTMAttributes
attributes
{
node
};
if
(
attributes
.
m_direction
==
LSTMDirection
::
LSTM_DIRECTION_FORWARD
)
{
class
LSTMForward
{
public
:
explicit
LSTMForward
(
std
::
shared_ptr
<
ngraph
::
Node
>
X
,
std
::
shared_ptr
<
ngraph
::
Node
>
W
,
std
::
shared_ptr
<
ngraph
::
Node
>
R
,
std
::
shared_ptr
<
ngraph
::
Node
>
B
,
std
::
shared_ptr
<
ngraph
::
Node
>
P
,
std
::
shared_ptr
<
ngraph
::
Node
>
initial_h
,
std
::
shared_ptr
<
ngraph
::
Node
>
initial_c
,
std
::
shared_ptr
<
ngraph
::
Node
>
seq_lengths
,
rnn
::
ActivationFunction
activation_f
,
rnn
::
ActivationFunction
activation_g
,
rnn
::
ActivationFunction
activation_h
,
bool
input_forget
=
false
,
float
clip_threshold
=
0.
f
)
:
m_X
{
X
}
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
for
(
auto
&
ng_in
:
input_map
)
,
m_W
{
reshape
::
squeeze
(
W
)}
,
m_R
{
reshape
::
squeeze
(
R
)}
,
m_B
{
reshape
::
squeeze
(
B
)}
,
m_P
{
reshape
::
squeeze
(
P
)}
,
m_initial_h
{
reshape
::
squeeze
(
initial_h
)}
,
m_initial_c
{
reshape
::
squeeze
(
initial_c
)}
,
m_seq_lengths
{
seq_lengths
}
,
m_activation_f
{
activation_f
}
,
m_activation_g
{
activation_g
}
,
m_activation_h
{
activation_h
}
,
m_input_forget
{
input_forget
}
,
m_clip_threshold
{
clip_threshold
}
{
if
(
ng_in
.
first
!=
LSTMInput
::
LSTM_INPUT_X
&&
ng_in
.
first
!=
LSTMInput
::
LSTM_INPUT_SEQ_LENGTHS
)
{
ASSERT_VALID_ARGUMENT
(
node
,
ng_in
.
second
->
get_shape
().
at
(
0
)
==
1
)
<<
"Input: { "
<<
to_str
(
ng_in
.
first
)
<<
" } first axis has size different "
"from 1, while direction attribute set to 'forward'."
;
ng_in
.
second
=
reshape
::
squeeze
(
ng_in
.
second
);
}
}
}
NodeVector
run
(
bool
reverse
=
false
)
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
...
...
@@ -284,19 +348,32 @@ namespace ngraph
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
NodeVector
p_iof
=
reshape
::
split
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_P
)
,
3
);
NodeVector
p_iof
=
reshape
::
split
(
m_P
,
3
);
const
auto
&
p_i
=
p_iof
.
at
(
0
);
const
auto
&
p_o
=
p_iof
.
at
(
1
);
const
auto
&
p_f
=
p_iof
.
at
(
2
);
std
::
shared_ptr
<
ngraph
::
Node
>
H_t
{
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_INIT_H
)};
std
::
shared_ptr
<
ngraph
::
Node
>
C_t
{
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_INIT_C
)};
NodeVector
h_list
;
NodeVector
b_W_R
=
reshape
::
split
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_B
)
,
2
);
NodeVector
b_W_R
=
reshape
::
split
(
m_B
,
2
);
std
::
shared_ptr
<
ngraph
::
Node
>
bias
=
b_W_R
.
at
(
0
)
+
b_W_R
.
at
(
1
);
NodeVector
in_seqs
=
reshape
::
split
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_X
),
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_X
)
->
get_shape
().
at
(
0
));
std
::
shared_ptr
<
ngraph
::
Node
>
H_t
=
m_initial_h
;
std
::
shared_ptr
<
ngraph
::
Node
>
C_t
=
m_initial_c
;
if
(
reverse
)
{
m_X
=
std
::
make_shared
<
ngraph
::
op
::
Reverse
>
(
m_X
,
AxisSet
{
0
});
}
NodeVector
in_seqs
{};
if
(
m_X
->
get_shape
().
at
(
0
)
!=
1
)
{
in_seqs
=
reshape
::
split
(
m_X
,
m_X
->
get_shape
().
at
(
0
));
}
else
{
in_seqs
=
NodeVector
{
m_X
};
}
for
(
auto
&
in_x
:
in_seqs
)
{
// remove first empty dim, after above split.
...
...
@@ -309,11 +386,11 @@ namespace ngraph
// * - Denotes dot product.
// Xt*(W^T) -- for [iofc] gates.
auto
Xt_W
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
in_x
,
reshape
::
transpose
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_W
)
));
auto
Xt_W
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
in_x
,
reshape
::
transpose
(
m_W
));
// Ht-1*(R^T) -- for [iofc] gates.
auto
Ht_R
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
H_t
,
reshape
::
transpose
(
input_map
.
at
(
LSTMInput
::
LSTM_INPUT_R
)
));
auto
Ht_R
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
H_t
,
reshape
::
transpose
(
m_R
));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto
gates
=
add
(
Xt_W
,
add
(
Ht_R
,
bias
));
...
...
@@ -324,15 +401,28 @@ namespace ngraph
auto
c
=
split_gates
.
at
(
3
);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i
=
sigmoid
(
add
(
i
,
mul
(
p_i
,
C_t
)));
i
=
m_activation_f
(
clip
(
add
(
i
,
mul
(
p_i
,
C_t
)),
m_clip_threshold
));
if
(
m_input_forget
)
{
// Couple input with forget gate: 1 - i
f
=
sub
(
ngraph
::
op
::
Constant
::
create
(
i
->
get_element_type
(),
i
->
get_shape
(),
std
::
vector
<
float
>
(
shape_size
(
i
->
get_shape
()),
1.
f
)),
i
);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f
=
sigmoid
(
add
(
f
,
mul
(
p_f
,
C_t
)));
f
=
m_activation_f
(
clip
(
add
(
f
,
mul
(
p_f
,
C_t
)),
m_clip_threshold
));
}
// ft (.) Ct-1 + it (.) ct
auto
C
=
add
(
mul
(
f
,
C_t
),
mul
(
i
,
tanh
(
c
)));
auto
C
=
add
(
mul
(
f
,
C_t
),
mul
(
i
,
m_activation_g
(
clip
(
c
,
m_clip_threshold
))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o
=
sigmoid
(
add
(
o
,
mul
(
p_o
,
C
)
));
o
=
m_activation_f
(
clip
(
add
(
o
,
mul
(
p_o
,
C
)),
m_clip_threshold
));
// ot (.) h(Ct)
auto
H
=
mul
(
o
,
tan
h
(
C
));
auto
H
=
mul
(
o
,
m_activation_
h
(
C
));
h_list
.
push_back
(
H
);
H_t
=
H
;
C_t
=
C
;
...
...
@@ -343,21 +433,134 @@ namespace ngraph
for
(
const
auto
&
ht
:
h_list
)
{
// Expand tensors with empty outermost dim, so we can later concatenate them.
exp_h_list
.
push_back
(
reshape
::
add_empty_axe
s
(
ht
));
exp_h_list
.
push_back
(
reshape
::
expand_dim
s
(
ht
));
}
std
::
shared_ptr
<
ngraph
::
Node
>
Y
{
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
exp_h_list
,
0
)};
// Get back the original order of the output data.
if
(
reverse
)
{
Y
=
std
::
make_shared
<
ngraph
::
op
::
Reverse
>
(
Y
,
AxisSet
{
0
});
}
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
if
(
attributes
.
m_direction
==
LSTMDirection
::
LSTM_DIRECTION_FORWARD
)
Y
=
reshape
::
expand_dims
(
Y
,
1
);
// expand C_t so that it has expected shape:
// [num_directions, batch_size, hidden_size]
auto
Y_c
=
reshape
::
expand_dims
(
C_t
);
return
{
Y
,
exp_h_list
.
back
(),
Y_c
};
}
private
:
std
::
shared_ptr
<
ngraph
::
Node
>
m_X
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_W
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_R
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_B
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_P
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_initial_h
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_initial_c
;
std
::
shared_ptr
<
ngraph
::
Node
>
m_seq_lengths
;
rnn
::
ActivationFunction
m_activation_f
;
rnn
::
ActivationFunction
m_activation_g
;
rnn
::
ActivationFunction
m_activation_h
;
// For coupling input and forget gates.
bool
m_input_forget
;
// For clipping cell input in the range [-clip_threshold, clip_threshold].
float
m_clip_threshold
;
};
}
// anonymous namespace
namespace
set_1
{
/// \brief Builds the nGraph sub-graph for an ONNX LSTM node.
///
/// Inputs and attributes are parsed with LSTMNgInputMap / LSTMAttributes, the three
/// activation functions are resolved by name from the "activations" attribute, and the
/// recurrence itself is delegated to LSTMForward.
///
/// \param[in] node  The ONNX LSTM node.
///
/// \return The nodes produced by LSTMForward::run() for forward/reverse direction, or
///         {Y, Y_h, Y_c} concatenated from both passes for bidirectional direction.
///         The vector stays empty if the direction matches none of the handled values.
NodeVector lstm(const Node& node)
{
    LSTMNgInputMap input_map{node};
    LSTMAttributes attributes{node};

    rnn::ActivationFunction activation_f =
        rnn::get_activation_func_by_name(attributes.m_activations.at(0));
    rnn::ActivationFunction activation_g =
        rnn::get_activation_func_by_name(attributes.m_activations.at(1));
    rnn::ActivationFunction activation_h =
        rnn::get_activation_func_by_name(attributes.m_activations.at(2));

    NodeVector results;

    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
        attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE)
    {
        LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
                             input_map.at(LSTMInput::LSTM_INPUT_W),
                             input_map.at(LSTMInput::LSTM_INPUT_R),
                             input_map.at(LSTMInput::LSTM_INPUT_B),
                             input_map.at(LSTMInput::LSTM_INPUT_P),
                             input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
                             input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
                             activation_f,
                             activation_g,
                             activation_h,
                             attributes.m_input_forget,
                             attributes.m_clip_threshold);
        // A single pass; the flag tells run() whether to process the sequence reversed.
        results = lstm_fwd.run(
            (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
    }
    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL)
    {
        // In bidirectional mode weights are stacked together, so we must split them.
        NodeVector W{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
        NodeVector R{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
        NodeVector B{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
        NodeVector P{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
        NodeVector H{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
        NodeVector C{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};

        // First halves drive the forward pass.
        LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
                             W.at(0),
                             R.at(0),
                             B.at(0),
                             P.at(0),
                             H.at(0),
                             C.at(0),
                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
                             activation_f,
                             activation_g,
                             activation_h,
                             attributes.m_input_forget,
                             attributes.m_clip_threshold);
        // Second halves drive the reverse pass.
        LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
                                  W.at(1),
                                  R.at(1),
                                  B.at(1),
                                  P.at(1),
                                  H.at(1),
                                  C.at(1),
                                  input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
                                  activation_f,
                                  activation_g,
                                  activation_h,
                                  attributes.m_input_forget,
                                  attributes.m_clip_threshold);

        NodeVector fwd_results{lstm_fwd.run()};
        // The reverse pass must run on `lstm_reversed` (second half of the weights);
        // running `lstm_fwd` again would reuse the forward weights and state.
        NodeVector rev_results{lstm_reversed.run(true)};

        // Stack together respective outputs from both forward and reverse passes.
        std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(
            NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};

        std::shared_ptr<ngraph::Node> Y_h{std::make_shared<ngraph::op::Concat>(
            NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};

        std::shared_ptr<ngraph::Node> Y_c{std::make_shared<ngraph::op::Concat>(
            NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};

        results = NodeVector{Y, Y_h, Y_c};
    }

    return results;
}
}
// namespace set_1
...
...
src/ngraph/frontend/onnx_import/op/matmul.cpp
View file @
6e6c8af4
...
...
@@ -138,7 +138,7 @@ namespace ngraph
// Expand sub_dot result with single empty outermost axis, in order to
// later concatenate sub_dots at this axis.
small_dots
.
at
(
g
)
=
reshape
::
add_empty_axe
s
(
sub_dot
);
small_dots
.
at
(
g
)
=
reshape
::
expand_dim
s
(
sub_dot
);
}
// Concatenate sub_dots on groups axis.
...
...
src/ngraph/frontend/onnx_import/op/supported_ops.md
View file @
6e6c8af4
...
...
@@ -112,7 +112,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|-----------------|--------|--------|---------|
| Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. |
| Pad | 1-2- | 273 | 416 | Not fully supported. |
| LSTM | 1-7- | | 4
30 | Not fully supported
. |
| LSTM | 1-7- | | 4
76 | Mixed sequences length not supported yet
. |
| MaxUnpool | (9) | 286, 289 | 447 | |
| LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
| Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. |
...
...
src/ngraph/frontend/onnx_import/utils/reshape.cpp
View file @
6e6c8af4
...
...
@@ -221,17 +221,14 @@ namespace ngraph
node
,
get_default_axis_vector
(
node
->
get_shape
().
size
()),
shape
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
add_empty_axes
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
std
::
size_t
outermost_axes_count
,
std
::
size_t
innermost_axes_count
)
std
::
shared_ptr
<
ngraph
::
Node
>
expand_dims
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
std
::
size_t
axis
)
{
// Add outermost empty dimensions.
Shape
output_shape
(
outermost_axes_count
,
1
);
output_shape
.
insert
(
std
::
end
(
output_shape
),
std
::
begin
(
node
->
get_shape
()),
std
::
end
(
node
->
get_shape
()));
// Add innermost empty dimensions.
output_shape
.
insert
(
std
::
end
(
output_shape
),
innermost_axes_count
,
1
);
Shape
output_shape
(
node
->
get_shape
());
// Add empty axis at specified position.
auto
empty_axis_it
=
std
::
begin
(
output_shape
);
std
::
advance
(
empty_axis_it
,
axis
);
output_shape
.
insert
(
empty_axis_it
,
1
);
return
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
node
,
reshape
::
get_default_axis_vector
(
node
->
get_shape
().
size
()),
output_shape
);
}
...
...
src/ngraph/frontend/onnx_import/utils/reshape.hpp
View file @
6e6c8af4
...
...
@@ -127,19 +127,17 @@ namespace ngraph
return
reshape
(
node
,
get_default_axis_vector
(
node
->
get_shape
().
size
()),
shape
);
}
/// \brief Expands node tensor shape with empty axes.
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] node The node to be expanded.
/// \param[in] outermost_axes_count The number of added outermost axes.
/// At the front of the shape.
/// \param[in] innermost_axes_count The number of added innermost axes.
/// At the end of the shape.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return The node with added empty ax
e
s.
/// \return The node with added empty ax
i
s.
///
std
::
shared_ptr
<
ngraph
::
Node
>
add_empty_axes
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
std
::
size_t
outermost_axes_count
=
1
,
std
::
size_t
innermost_axes_count
=
0
);
std
::
shared_ptr
<
ngraph
::
Node
>
expand_dims
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
std
::
size_t
axis
=
0
);
/// \brief Split node on specified axis into multiple parts.
///
...
...
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.cpp
0 → 100644
View file @
6e6c8af4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
rnn
{
namespace
detail
{
/// Creates an ngraph Sigmoid op node taking `arg` as its input.
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
{
    std::shared_ptr<ngraph::Node> activated = std::make_shared<ngraph::op::Sigmoid>(arg);
    return activated;
}
/// Creates an ngraph Tanh op node taking `arg` as its input.
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
{
    std::shared_ptr<ngraph::Node> activated = std::make_shared<ngraph::op::Tanh>(arg);
    return activated;
}
/// Creates an ngraph Relu op node taking `arg` as its input.
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg)
{
    std::shared_ptr<ngraph::Node> activated = std::make_shared<ngraph::op::Relu>(arg);
    return activated;
}
}
// namespace detail
/// Looks up an activation function by its name.
///
/// Recognised names are "sigmoid", "tanh" and "relu"; the comparison is exact,
/// so callers are expected to pass the name in that spelling.
///
/// \param[in] func_name  Name of the requested activation function.
///
/// \throws error::UnknownActivationFunction when `func_name` is not in the table.
///
/// \return A copy of the callable registered under `func_name`.
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
    using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;

    // The table is never modified after construction, so keep it const. The free
    // functions convert to ActivationFunction directly; no std::bind wrapper is needed.
    static const ActivationFunctionMap func_map{
        {"sigmoid", detail::sigmoid},
        {"tanh", detail::tanh},
        {"relu", detail::relu}};

    const auto func_it = func_map.find(func_name);
    if (func_it == std::end(func_map))
    {
        throw error::UnknownActivationFunction(func_name);
    }
    return func_it->second;
}
}
//namespace rnn
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.hpp
0 → 100644
View file @
6e6c8af4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <string>
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
rnn
{
namespace
error
{
/// \brief Error raised when an activation function name is not recognised.
struct UnknownActivationFunction : ngraph_error
{
    /// \param[in] func_name  The unrecognised name; appended to the error message.
    ///
    /// Marked explicit so a plain std::string never converts implicitly
    /// into an exception object.
    explicit UnknownActivationFunction(const std::string& func_name)
        : ngraph_error{"Unknown activation function: " + func_name}
    {
    }
};
}
namespace
detail
{
std
::
shared_ptr
<
ngraph
::
Node
>
sigmoid
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg
);
std
::
shared_ptr
<
ngraph
::
Node
>
tanh
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg
);
std
::
shared_ptr
<
ngraph
::
Node
>
relu
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg
);
}
using
ActivationFunction
=
std
::
function
<
std
::
shared_ptr
<
ngraph
::
Node
>
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
)
>
;
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction
get_activation_func_by_name
(
const
std
::
string
&
func_name
);
}
//namespace rnn
}
// namespace onnx_import
}
// namespace ngraph
test/models/onnx/lstm_fwd_with_clip.onnx
0 → 100644
View file @
6e6c8af4
File added
test/onnx_import.in.cpp
View file @
6e6c8af4
...
...
@@ -1864,6 +1864,91 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE
(
test
::
all_close
(
expected_indices_output
,
indices_output
));
}
// Imports onnx/lstm_fwd_with_clip.onnx, feeds X, W, R, B and P, and compares the three
// outputs (Y, Y_h, Y_c) against precomputed reference values.
TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
{
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_with_clip.onnx"));

    Inputs inputs{};
    // X
    inputs.emplace_back(std::vector<float>{-0.455351f, -0.276391f, -0.185934f, -0.269585f});

    // W
    inputs.emplace_back(std::vector<float>{-0.494659f,
                                           0.0453352f,
                                           -0.487793f,
                                           0.417264f,
                                           -0.0175329f,
                                           0.489074f,
                                           -0.446013f,
                                           0.414029f,
                                           -0.0091708f,
                                           -0.255364f,
                                           -0.106952f,
                                           -0.266717f,
                                           -0.0888852f,
                                           -0.428709f,
                                           -0.283349f,
                                           0.208792f});

    // R
    inputs.emplace_back(std::vector<float>{0.146626f,
                                           -0.0620289f,
                                           -0.0815302f,
                                           0.100482f,
                                           -0.219535f,
                                           -0.306635f,
                                           -0.28515f,
                                           -0.314112f,
                                           -0.228172f,
                                           0.405972f,
                                           0.31576f,
                                           0.281487f,
                                           -0.394864f,
                                           0.42111f,
                                           -0.386624f,
                                           -0.390225f});

    // B
    inputs.emplace_back(std::vector<float>{0.381619f,
                                           0.0323954f,
                                           -0.14449f,
                                           0.420804f,
                                           -0.258721f,
                                           0.45056f,
                                           -0.250755f,
                                           0.0967895f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f});

    // P
    inputs.emplace_back(
        std::vector<float>{0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f});

    Outputs expected_output{};
    // Y_data
    expected_output.emplace_back(
        std::vector<float>{-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f});
    // Y_h_data
    expected_output.emplace_back(std::vector<float>{-0.03516197f, 0.03875681f});
    // Y_c_data
    expected_output.emplace_back(std::vector<float>{-0.07415761f, 0.07395997f});

    Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};

    // Fatal on mismatch: the loop below indexes `outputs` with at(), which would
    // throw std::out_of_range if fewer outputs than expected were produced.
    ASSERT_EQ(outputs.size(), expected_output.size());
    for (std::size_t i{0}; i < expected_output.size(); ++i)
    {
        // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
        // The discrepancies may occur at most on 7th decimal position.
        EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
    }
}
TEST
(
onnx_
$
{
BACKEND_NAME
},
model_missing_input
)
{
onnx_import
::
register_operator
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment