Commit 6e6c8af4
Authored Feb 26, 2019 by Adam Rogowiec, committed by Michał Karzyński on Feb 26, 2019
[ONNX] Enhance LSTM support. (#2408)
Parent: 25c9152f
Showing 11 changed files with 610 additions and 179 deletions.
src/ngraph/frontend/onnx_import/CMakeLists.txt (+2, -0)
src/ngraph/frontend/onnx_import/core/node.cpp (+9, -0)
src/ngraph/frontend/onnx_import/op/lstm.cpp (+360, -157)
src/ngraph/frontend/onnx_import/op/matmul.cpp (+1, -1)
src/ngraph/frontend/onnx_import/op/supported_ops.md (+1, -1)
src/ngraph/frontend/onnx_import/utils/reshape.cpp (+7, -10)
src/ngraph/frontend/onnx_import/utils/reshape.hpp (+8, -10)
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.cpp (+71, -0)
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.hpp (+66, -0)
test/models/onnx/lstm_fwd_with_clip.onnx (new binary file)
test/onnx_import.in.cpp (+85, -0)
src/ngraph/frontend/onnx_import/CMakeLists.txt

@@ -177,6 +177,8 @@ add_library(onnx_import STATIC
     utils/reduction.hpp
     utils/reshape.cpp
     utils/reshape.hpp
+    utils/rnn/activation_functions.cpp
+    utils/rnn/activation_functions.hpp
     utils/variadic.hpp
 )

 set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
src/ngraph/frontend/onnx_import/core/node.cpp

@@ -258,6 +258,15 @@ namespace ngraph
                 name, std::move(default_value));
         }

+        template <>
+        std::vector<std::string> Node::get_attribute_value(const std::string& name,
+                                                           std::vector<std::string> default_value) const
+        {
+            return m_pimpl->template get_attribute_value<std::vector<std::string>>(
+                name, std::move(default_value));
+        }
+
         template <>
         std::vector<Tensor> Node::get_attribute_value(const std::string& name,
                                                       std::vector<Tensor> default_value) const
src/ngraph/frontend/onnx_import/op/lstm.cpp

@@ -14,6 +14,8 @@
 // limitations under the License.
 //*****************************************************************************

+#include <algorithm>
+#include <cmath>
 #include <cstddef>
 #include <cstdint>
 #include <functional>
@@ -24,21 +26,30 @@
 #include <unordered_map>
 #include <vector>

+#include "core/null_node.hpp"
 #include "exceptions.hpp"
 #include "lstm.hpp"
+#include "ngraph/axis_set.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/concat.hpp"
+#include "ngraph/op/constant.hpp"
 #include "ngraph/op/dot.hpp"
+#include "ngraph/op/maximum.hpp"
+#include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/reverse.hpp"
 #include "ngraph/op/sigmoid.hpp"
+#include "ngraph/op/subtract.hpp"
 #include "ngraph/op/tanh.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/type/element_type.hpp"
+#include "ngraph/util.hpp"
 #include "utils/broadcasting.hpp"
 #include "utils/common.hpp"
 #include "utils/reshape.hpp"
+#include "utils/rnn/activation_functions.hpp"

 namespace ngraph
 {
@@ -55,6 +66,13 @@ namespace ngraph
                     return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
                 }

+                std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
+                                                  const std::shared_ptr<ngraph::Node>& rhs)
+                {
+                    auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
+                    return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
+                }
+
                 std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
                                                   const std::shared_ptr<ngraph::Node>& rhs)
                 {
@@ -62,16 +80,38 @@ namespace ngraph
                     return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
                 }

-                // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ACTIVATION FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-                std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
+                std::shared_ptr<ngraph::Node> clip(const std::shared_ptr<ngraph::Node>& data,
+                                                   float threshold)
                 {
-                    return std::make_shared<ngraph::op::Sigmoid>(arg);
+                    if (threshold == 0.f)
+                    {
+                        return data;
+                    }
+
+                    float min_val = -threshold;
+                    float max_val = threshold;
+                    std::size_t size = ngraph::shape_size(data->get_shape());
+
+                    const std::shared_ptr<ngraph::Node> min_val_node = ngraph::op::Constant::create(
+                        data->get_element_type(), data->get_shape(), std::vector<float>(size, min_val));
+                    const std::shared_ptr<ngraph::Node> max_val_node = ngraph::op::Constant::create(
+                        data->get_element_type(), data->get_shape(), std::vector<float>(size, max_val));
+
+                    return std::make_shared<ngraph::op::Minimum>(
+                        max_val_node, std::make_shared<ngraph::op::Maximum>(data, min_val_node));
                 }

-                std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
+                // Modify input vector in-place and return reference to modified vector.
+                std::vector<std::string>& to_lower_case(std::vector<std::string>&& vs)
                 {
-                    return std::make_shared<ngraph::op::Tanh>(arg);
+                    std::transform(std::begin(vs), std::end(vs), std::begin(vs),
+                                   [](std::string& s) { return ngraph::to_lower(s); });
+                    return vs;
                 }

                 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
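A small usage sketch of the new clip helper (hypothetical call site, shapes invented for illustration; only the signature from the hunk above is assumed):

    // Hypothetical illustration - not part of the commit.
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4});
    auto clipped = clip(data, 3.f);   // every element is limited to the range [-3, 3]
    auto untouched = clip(data, 0.f); // a threshold of 0.f disables clipping and returns `data` as-is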
@@ -88,22 +128,6 @@ namespace ngraph
                     LSTM_INPUT_P
                 };

-                std::string to_str(const LSTMInput& in)
-                {
-                    switch (in)
-                    {
-                    case LSTMInput::LSTM_INPUT_X: return "X";
-                    case LSTMInput::LSTM_INPUT_W: return "W";
-                    case LSTMInput::LSTM_INPUT_R: return "R";
-                    case LSTMInput::LSTM_INPUT_B: return "B";
-                    case LSTMInput::LSTM_INPUT_SEQ_LENGTHS: return "sequence_lens";
-                    case LSTMInput::LSTM_INPUT_INIT_H: return "initial_h";
-                    case LSTMInput::LSTM_INPUT_INIT_C: return "initial_c";
-                    case LSTMInput::LSTM_INPUT_P: return "P";
-                    default: return "Unrecognized input value!";
-                    }
-                }
-
                 struct LSTMNgInputMap
                 {
                     using container_type = std::map<LSTMInput, std::shared_ptr<ngraph::Node>>;
@@ -134,7 +158,7 @@ namespace ngraph
                         // ------ Optional inputs ------
                         // The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
-                        if (ng_inputs.size() >= 4)
+                        if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
                         {
                             m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
                         }
@@ -146,21 +170,20 @@ namespace ngraph
                                 {0.f});
                         }
                         // The lengths of the sequences in a batch. Shape [batch_size]
-                        if (ng_inputs.size() >= 5)
+                        if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
                         {
                             m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
                         }
                         else
                         {
                             m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] =
-                                common::make_constant_node<std::int32_t>(
-                                    element::i32,
-                                    {batch_size},
-                                    {static_cast<std::int32_t>(
-                                        m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0))});
+                                ngraph::op::Constant::create(
+                                    element::i32,
+                                    Shape{batch_size},
+                                    std::vector<std::int32_t>(
+                                        batch_size,
+                                        m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0)));
                         }
                         // The initial value of the hidden. Shape [num_directions, batch_size, hidden_size]
-                        if (ng_inputs.size() >= 6)
+                        if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
                         {
                             m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5);
                         }
@@ -170,7 +193,7 @@ namespace ngraph
                                 element::f32, {num_directions, batch_size, hidden_size}, {0.f});
                         }
                         // The initial value of the cell. Shape [num_directions, batch_size, hidden_size]
-                        if (ng_inputs.size() >= 7)
+                        if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
                         {
                             m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6);
                         }
@@ -180,7 +203,7 @@ namespace ngraph
                                 element::f32, {num_directions, batch_size, hidden_size}, {0.f});
                         }
                         // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
-                        if (ng_inputs.size() >= 8)
+                        if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
                         {
                             m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
                         }
@@ -197,8 +220,6 @@ namespace ngraph
                     {
                         return m_map.at(key);
                     }
-                    iterator begin() { return m_map.begin(); }
-                    iterator end() { return m_map.end(); }
                     container_type m_map;
                 };
@@ -208,20 +229,248 @@ namespace ngraph
                 {
                     LSTM_DIRECTION_FORWARD,
                     LSTM_DIRECTION_REVERSE,
-                    LSTM_DIRECTION_BIDIRECTIONAL
+                    LSTM_DIRECTION_BIDIRECTIONAL,
+                    LSTM_DIRECTION_UNKNOWN,
                 };

+                LSTMDirection getLSTMDirection(const std::string& s)
+                {
+                    if (s == "forward")
+                    {
+                        return LSTMDirection::LSTM_DIRECTION_FORWARD;
+                    }
+                    if (s == "reverse")
+                    {
+                        return LSTMDirection::LSTM_DIRECTION_REVERSE;
+                    }
+                    if (s == "bidirectional")
+                    {
+                        return LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL;
+                    }
+                    return LSTMDirection::LSTM_DIRECTION_UNKNOWN;
+                }
+
                 struct LSTMAttributes
                 {
                     explicit LSTMAttributes(const Node& node)
-                        : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
+                        : m_direction{LSTMDirection::LSTM_DIRECTION_FORWARD}
+                        , m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
+                        , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
+                        , m_activations{to_lower_case(node.get_attribute_value<std::vector<std::string>>(
+                              "activations", {"sigmoid", "tanh", "tanh"}))}
+                        , m_input_forget{static_cast<bool>(
+                              node.get_attribute_value<std::int64_t>("input_forget", 0))}
                     {
+                        m_clip_threshold = std::abs(m_clip_threshold);
+                        std::string direction{ngraph::to_lower(
+                            node.get_attribute_value<std::string>("direction", {"forward"}))};
+
+                        ASSERT_VALID_ARGUMENT(node,
+                                              getLSTMDirection(direction) !=
+                                                  LSTMDirection::LSTM_DIRECTION_UNKNOWN)
+                            << "Provided attribute \"direction\" value is incorrect: " << direction;
+                        m_direction = getLSTMDirection(direction);
                     }
-                    // Currently only LSTM_DIRECTION_FORWARD is supported.
                     LSTMDirection m_direction;
                     std::int64_t m_hidden_size;
+                    float m_clip_threshold;
+                    std::vector<std::string> m_activations;
+                    bool m_input_forget;
                 };
+
+                class LSTMForward
+                {
+                public:
+                    explicit LSTMForward(std::shared_ptr<ngraph::Node> X,
+                                         std::shared_ptr<ngraph::Node> W,
+                                         std::shared_ptr<ngraph::Node> R,
+                                         std::shared_ptr<ngraph::Node> B,
+                                         std::shared_ptr<ngraph::Node> P,
+                                         std::shared_ptr<ngraph::Node> initial_h,
+                                         std::shared_ptr<ngraph::Node> initial_c,
+                                         std::shared_ptr<ngraph::Node> seq_lengths,
+                                         rnn::ActivationFunction activation_f,
+                                         rnn::ActivationFunction activation_g,
+                                         rnn::ActivationFunction activation_h,
+                                         bool input_forget = false,
+                                         float clip_threshold = 0.f)
+                        : m_X{X}
+                        // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
+                        , m_W{reshape::squeeze(W)}
+                        , m_R{reshape::squeeze(R)}
+                        , m_B{reshape::squeeze(B)}
+                        , m_P{reshape::squeeze(P)}
+                        , m_initial_h{reshape::squeeze(initial_h)}
+                        , m_initial_c{reshape::squeeze(initial_c)}
+                        , m_seq_lengths{seq_lengths}
+                        , m_activation_f{activation_f}
+                        , m_activation_g{activation_g}
+                        , m_activation_h{activation_h}
+                        , m_input_forget{input_forget}
+                        , m_clip_threshold{clip_threshold}
+                    {
+                    }
+
+                    NodeVector run(bool reverse = false)
+                    {
+                        // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
+                        // The names used below are analogous to the one used in ONNX documentation.
+                        //
+                        // ------ INPUTS ------
+                        // X - The input tensor. [seq_length, batch_size, input_size]
+                        // W - The weight tensor. [num_directions, 4*hidden_size, input_size]
+                        // R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
+                        // B - The bias tensor for input gate. [num_directions, 8*hidden_size]
+                        // P - The weight tensor forr peepholes. [num_directions, 3*hidde_size]
+                        // ------ ACRONYMS ------
+                        // i - input gate
+                        // o - output gate
+                        // f - forget gate
+                        // c - cell gate
+                        // t - time step (t-1 means previous time step)
+                        // ------ VARIABLE NAMES ------
+                        // W       - W parameter weight matrix for input, output, forget, and
+                        //           cell gates.
+                        // R       - R recurrence weight matrix for input, output, forget, and
+                        //           cell gates.
+                        // Wb      - W bias vectors for input, output, forget, and cell gates.
+                        // Rb      - R bias vectors for input, output, forget, and cell gates.
+                        // b_W_R   - Bias vectors for input, output, forget, and cell gates.
+                        //           Concatenation of `[Wb, Rb]`.
+                        // p_[iof] - P peephole weight vector for respectively: input, output,
+                        //           and forget gates.
+                        // H_t     - Hidden state vector at current time step.
+                        // C_t     - Cell state vector at current time step.
+                        // h_list  - The list of hidden states at all processed time steps.
+                        //
+                        // Xt_W    - Input sequence multiplied by weights tensor at current time
+                        //           step.
+                        // Ht_R    - Hidden state multiplied by weights tensor at current time step.
+
+                        NodeVector p_iof = reshape::split(m_P, 3);
+                        const auto& p_i = p_iof.at(0);
+                        const auto& p_o = p_iof.at(1);
+                        const auto& p_f = p_iof.at(2);
+                        NodeVector h_list;
+
+                        NodeVector b_W_R = reshape::split(m_B, 2);
+                        std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
+                        std::shared_ptr<ngraph::Node> H_t = m_initial_h;
+                        std::shared_ptr<ngraph::Node> C_t = m_initial_c;
+
+                        if (reverse)
+                        {
+                            m_X = std::make_shared<ngraph::op::Reverse>(m_X, AxisSet{0});
+                        }
+
+                        NodeVector in_seqs{};
+                        if (m_X->get_shape().at(0) != 1)
+                        {
+                            in_seqs = reshape::split(m_X, m_X->get_shape().at(0));
+                        }
+                        else
+                        {
+                            in_seqs = NodeVector{m_X};
+                        }
+
+                        for (auto& in_x : in_seqs)
+                        {
+                            // remove first empty dim, after above split.
+                            in_x = reshape::squeeze(in_x);
+                        }
+
+                        for (const auto& in_x : in_seqs)
+                        {
+                            // (.) - Denotes element-wise multiplication.
+                            // *   - Denotes dot product.
+
+                            // Xt*(W^T) -- for [iofc] gates.
+                            auto Xt_W = std::make_shared<ngraph::op::Dot>(in_x, reshape::transpose(m_W));
+                            // Ht-1*(R^T)  -- for [iofc] gates.
+                            auto Ht_R = std::make_shared<ngraph::op::Dot>(H_t, reshape::transpose(m_R));
+                            // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb  -- for [iofc] gates.
+                            auto gates = add(Xt_W, add(Ht_R, bias));
+
+                            NodeVector split_gates = reshape::split(gates, 4, -1);
+                            auto i = split_gates.at(0);
+                            auto o = split_gates.at(1);
+                            auto f = split_gates.at(2);
+                            auto c = split_gates.at(3);
+
+                            // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
+                            i = m_activation_f(clip(add(i, mul(p_i, C_t)), m_clip_threshold));
+                            if (m_input_forget)
+                            {
+                                // Couple input with forget gate: 1 - i
+                                f = sub(ngraph::op::Constant::create(
+                                            i->get_element_type(),
+                                            i->get_shape(),
+                                            std::vector<float>(shape_size(i->get_shape()), 1.f)),
+                                        i);
+                            }
+                            else
+                            {
+                                // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
+                                f = m_activation_f(clip(add(f, mul(p_f, C_t)), m_clip_threshold));
+                            }
+                            // ft (.) Ct-1 + it (.) ct
+                            auto C = add(mul(f, C_t), mul(i, m_activation_g(clip(c, m_clip_threshold))));
+                            // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
+                            o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold));
+                            // ot (.) h(Ct)
+                            auto H = mul(o, m_activation_h(C));
+                            h_list.push_back(H);
+                            H_t = H;
+                            C_t = C;
+                        }
+
+                        // The tensor that concats all the intermediate output values of the hidden.
+                        // It has shape [seq_length, batch_size, hidden_size]
+                        NodeVector exp_h_list;
+                        for (const auto& ht : h_list)
+                        {
+                            // Expand tensors with empty outermost dim, so we can later concatenate them.
+                            exp_h_list.push_back(reshape::expand_dims(ht));
+                        }
+
+                        std::shared_ptr<ngraph::Node> Y{
+                            std::make_shared<ngraph::op::Concat>(exp_h_list, 0)};
+
+                        // Get back the original order of the output data.
+                        if (reverse)
+                        {
+                            Y = std::make_shared<ngraph::op::Reverse>(Y, AxisSet{0});
+                        }
+
+                        // Expand Y so that it has expected shape:
+                        // [seq_length, num_directions, batch_size, hidden_size]
+                        Y = reshape::expand_dims(Y, 1);
+
+                        // expand C_t so that it has expected shape:
+                        // [num_directions, batch_size, hidden_size]
+                        auto Y_c = reshape::expand_dims(C_t);
+                        return {Y, exp_h_list.back(), Y_c};
+                    }
+
+                private:
+                    std::shared_ptr<ngraph::Node> m_X;
+                    std::shared_ptr<ngraph::Node> m_W;
+                    std::shared_ptr<ngraph::Node> m_R;
+                    std::shared_ptr<ngraph::Node> m_B;
+                    std::shared_ptr<ngraph::Node> m_P;
+                    std::shared_ptr<ngraph::Node> m_initial_h;
+                    std::shared_ptr<ngraph::Node> m_initial_c;
+                    std::shared_ptr<ngraph::Node> m_seq_lengths;
+                    rnn::ActivationFunction m_activation_f;
+                    rnn::ActivationFunction m_activation_g;
+                    rnn::ActivationFunction m_activation_h;
+                    // For coupling input and forget gates.
+                    bool m_input_forget;
+                    // For clipping cell input in the range [-clip_threshold, clip_threshold].
+                    float m_clip_threshold;
+                };
             } // anonymous namespace
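As a reading aid, the per-time-step recurrence built by LSTMForward::run above (using the notation from its comments, where f/g/h are the configured activations and (.) is element-wise multiplication) can be summarized as:

    it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)   (or 1 - it when input_forget is set)
    ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    Ct = ft (.) Ct-1 + it (.) ct
    ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    Ht = ot (.) h(Ct)

with every pre-activation optionally clipped to [-clip_threshold, clip_threshold].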
@@ -233,131 +482,85 @@ namespace ngraph
                     LSTMNgInputMap input_map{node};
                     LSTMAttributes attributes{node};

-                    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD)
-                    {
-                        // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
-                        for (auto& ng_in : input_map)
-                        {
-                            if (ng_in.first != LSTMInput::LSTM_INPUT_X &&
-                                ng_in.first != LSTMInput::LSTM_INPUT_SEQ_LENGTHS)
-                            {
-                                ASSERT_VALID_ARGUMENT(node, ng_in.second->get_shape().at(0) == 1)
-                                    << "Input: { " << to_str(ng_in.first) << " } first axis has size different "
-                                       "from 1, while direction attribute set to 'forward'.";
-                                ng_in.second = reshape::squeeze(ng_in.second);
-                            }
-                        }
-                    }
-
-                    // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
-                    // The names used below are analogous to the one used in ONNX documentation.
-                    //
-                    // ------ INPUTS ------
-                    // X - The input tensor. [seq_length, batch_size, input_size]
-                    // W - The weight tensor. [num_directions, 4*hidden_size, input_size]
-                    // R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
-                    // B - The bias tensor for input gate. [num_directions, 8*hidden_size]
-                    // P - The weight tensor forr peepholes. [num_directions, 3*hidde_size]
-                    // ------ ACRONYMS ------
-                    // i - input gate
-                    // o - output gate
-                    // f - forget gate
-                    // c - cell gate
-                    // t - time step (t-1 means previous time step)
-                    // ------ VARIABLE NAMES ------
-                    // W       - W parameter weight matrix for input, output, forget, and
-                    //           cell gates.
-                    // R       - R recurrence weight matrix for input, output, forget, and
-                    //           cell gates.
-                    // Wb      - W bias vectors for input, output, forget, and cell gates.
-                    // Rb      - R bias vectors for input, output, forget, and cell gates.
-                    // b_W_R   - Bias vectors for input, output, forget, and cell gates.
-                    //           Concatenation of `[Wb, Rb]`.
-                    // p_[iof] - P peephole weight vector for respectively: input, output,
-                    //           and forget gates.
-                    // H_t     - Hidden state vector at current time step.
-                    // C_t     - Cell state vector at current time step.
-                    // h_list  - The list of hidden states at all processed time steps.
-                    //
-                    // Xt_W    - Input sequence multiplied by weights tensor at current time
-                    //           step.
-                    // Ht_R    - Hidden state multiplied by weights tensor at current time step.
-
-                    NodeVector p_iof = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 3);
-                    const auto& p_i = p_iof.at(0);
-                    const auto& p_o = p_iof.at(1);
-                    const auto& p_f = p_iof.at(2);
-                    std::shared_ptr<ngraph::Node> H_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_H)};
-                    std::shared_ptr<ngraph::Node> C_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_C)};
-                    NodeVector h_list;
-
-                    NodeVector b_W_R = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2);
-                    std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
-                    NodeVector in_seqs =
-                        reshape::split(input_map.at(LSTMInput::LSTM_INPUT_X),
-                                       input_map.at(LSTMInput::LSTM_INPUT_X)->get_shape().at(0));
-
-                    for (auto& in_x : in_seqs)
-                    {
-                        // remove first empty dim, after above split.
-                        in_x = reshape::squeeze(in_x);
-                    }
-
-                    for (const auto& in_x : in_seqs)
-                    {
-                        // (.) - Denotes element-wise multiplication.
-                        // *   - Denotes dot product.
-
-                        // Xt*(W^T) -- for [iofc] gates.
-                        auto Xt_W = std::make_shared<ngraph::op::Dot>(
-                            in_x, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_W)));
-                        // Ht-1*(R^T)  -- for [iofc] gates.
-                        auto Ht_R = std::make_shared<ngraph::op::Dot>(
-                            H_t, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_R)));
-                        // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb  -- for [iofc] gates.
-                        auto gates = add(Xt_W, add(Ht_R, bias));
-
-                        NodeVector split_gates = reshape::split(gates, 4, -1);
-                        auto i = split_gates.at(0);
-                        auto o = split_gates.at(1);
-                        auto f = split_gates.at(2);
-                        auto c = split_gates.at(3);
-
-                        // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-                        i = sigmoid(add(i, mul(p_i, C_t)));
-                        // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-                        f = sigmoid(add(f, mul(p_f, C_t)));
-                        // ft (.) Ct-1 + it (.) ct
-                        auto C = add(mul(f, C_t), mul(i, tanh(c)));
-                        // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-                        o = sigmoid(add(o, mul(p_o, C)));
-                        // ot (.) h(Ct)
-                        auto H = mul(o, tanh(C));
-                        h_list.push_back(H);
-                        H_t = H;
-                        C_t = C;
-                    }
-
-                    // The tensor that concats all the intermediate output values of the hidden.
-                    // It has shape [seq_length, batch_size, hidden_size]
-                    NodeVector exp_h_list;
-                    for (const auto& ht : h_list)
-                    {
-                        // Expand tensors with empty outermost dim, so we can later concatenate them.
-                        exp_h_list.push_back(reshape::add_empty_axes(ht));
-                    }
-
-                    std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(exp_h_list, 0)};
-
-                    // Expand Y so that it has expected shape:
-                    // [seq_length, num_directions, batch_size, hidden_size]
-                    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD)
-                    {
-                        Shape shape{Y->get_shape()};
-                        shape.insert(std::next(std::begin(shape)), 1);
-                        Y = std::make_shared<ngraph::op::Reshape>(
-                            Y, reshape::get_default_axis_vector(Y->get_shape().size()), shape);
-                    }
-                    return {Y, exp_h_list.back()};
+                    rnn::ActivationFunction activation_f =
+                        rnn::get_activation_func_by_name(attributes.m_activations.at(0));
+                    rnn::ActivationFunction activation_g =
+                        rnn::get_activation_func_by_name(attributes.m_activations.at(1));
+                    rnn::ActivationFunction activation_h =
+                        rnn::get_activation_func_by_name(attributes.m_activations.at(2));
+
+                    NodeVector results;
+
+                    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
+                        attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE)
+                    {
+                        LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
+                                             input_map.at(LSTMInput::LSTM_INPUT_W),
+                                             input_map.at(LSTMInput::LSTM_INPUT_R),
+                                             input_map.at(LSTMInput::LSTM_INPUT_B),
+                                             input_map.at(LSTMInput::LSTM_INPUT_P),
+                                             input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
+                                             input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
+                                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
+                                             activation_f,
+                                             activation_g,
+                                             activation_h,
+                                             attributes.m_input_forget,
+                                             attributes.m_clip_threshold);
+                        results = lstm_fwd.run(
+                            (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
+                    }
+                    if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL)
+                    {
+                        // In bidirectional mode weights are stacked together, so we must split them.
+                        NodeVector W{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
+                        NodeVector R{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
+                        NodeVector B{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
+                        NodeVector P{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
+                        NodeVector H{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
+                        NodeVector C{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};
+
+                        LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
+                                             W.at(0),
+                                             R.at(0),
+                                             B.at(0),
+                                             P.at(0),
+                                             H.at(0),
+                                             C.at(0),
+                                             input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
+                                             activation_f,
+                                             activation_g,
+                                             activation_h,
+                                             attributes.m_input_forget,
+                                             attributes.m_clip_threshold);
+                        LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
+                                                  W.at(1),
+                                                  R.at(1),
+                                                  B.at(1),
+                                                  P.at(1),
+                                                  H.at(1),
+                                                  C.at(1),
+                                                  input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
+                                                  activation_f,
+                                                  activation_g,
+                                                  activation_h,
+                                                  attributes.m_input_forget,
+                                                  attributes.m_clip_threshold);
+
+                        NodeVector fwd_results{lstm_fwd.run()};
+                        NodeVector rev_results{lstm_fwd.run(true)};
+
+                        // Stack together respective outputs from both forward and reverse passess.
+                        std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(
+                            NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
+                        std::shared_ptr<ngraph::Node> Y_h{std::make_shared<ngraph::op::Concat>(
+                            NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
+                        std::shared_ptr<ngraph::Node> Y_c{std::make_shared<ngraph::op::Concat>(
+                            NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
+                        results = NodeVector{Y, Y_h, Y_c};
+                    }
+
+                    return results;
                 }
             } // namespace set_1
src/ngraph/frontend/onnx_import/op/matmul.cpp

@@ -138,7 +138,7 @@ namespace ngraph
                         // Expand sub_dot result with single empty outermost axis, in order to
                         // later concatenate sub_dots at this axis.
-                        small_dots.at(g) = reshape::add_empty_axes(sub_dot);
+                        small_dots.at(g) = reshape::expand_dims(sub_dot);
                     }

                     // Concatenate sub_dots on groups axis.
src/ngraph/frontend/onnx_import/op/supported_ops.md

@@ -112,7 +112,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
 |------|-----------------|--------|--------|---------|
 | Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. |
 | Pad | 1-2- | 273 | 416 | Not fully supported. |
-| LSTM | 1-7- | | 430 | Not fully supported. |
+| LSTM | 1-7- | | 476 | Mixed sequences length not supported yet. |
 | MaxUnpool | (9) | 286, 289 | 447 | |
 | LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
 | Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. |
src/ngraph/frontend/onnx_import/utils/reshape.cpp

@@ -221,17 +221,14 @@ namespace ngraph
                     node, get_default_axis_vector(node->get_shape().size()), shape);
             }

-            std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
-                                                         std::size_t outermost_axes_count,
-                                                         std::size_t innermost_axes_count)
+            std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
+                                                      std::size_t axis)
             {
-                // Add outermost empty dimensions.
-                Shape output_shape(outermost_axes_count, 1);
-                output_shape.insert(std::end(output_shape),
-                                    std::begin(node->get_shape()),
-                                    std::end(node->get_shape()));
-                // Add innermost empty dimensions.
-                output_shape.insert(std::end(output_shape), innermost_axes_count, 1);
+                Shape output_shape(node->get_shape());
+                // Add empty axis at specified position.
+                auto empty_axis_it = std::begin(output_shape);
+                std::advance(empty_axis_it, axis);
+                output_shape.insert(empty_axis_it, 1);
                 return std::make_shared<ngraph::op::Reshape>(
                     node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
             }
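A quick sketch of the renamed helper's behaviour (hypothetical shapes, not part of the commit):

    // For a node of shape {2, 3}:
    auto a = reshape::expand_dims(node);     // default axis = 0 -> shape {1, 2, 3}
    auto b = reshape::expand_dims(node, 1);  //         axis = 1 -> shape {2, 1, 3}

whereas the removed add_empty_axes could only prepend and append axes of size one.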
src/ngraph/frontend/onnx_import/utils/reshape.hpp

@@ -127,19 +127,17 @@ namespace ngraph
                 return reshape(node, get_default_axis_vector(node->get_shape().size()), shape);
             }

-            /// \brief Expands node tensor shape with empty axes.
+            /// \brief Expands node tensor shape with empty axis at
+            ///        specified position.
             ///
             /// \param[in] node The node to be expanded.
-            /// \param[in] outermost_axes_count The number of added outermost axes.
-            ///                                 At the front of the shape.
-            /// \param[in] innermost_axes_count The number of added innermost axes.
-            ///                                 At the end of the shape.
+            /// \param[in] axis The position in the expanded axes where the
+            ///                 new axis is placed.
             ///
-            /// \return The node with added empty axes.
+            /// \return The node with added empty axis.
             ///
-            std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
-                                                         std::size_t outermost_axes_count = 1,
-                                                         std::size_t innermost_axes_count = 0);
+            std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
+                                                      std::size_t axis = 0);

             /// \brief Split node on specified axis into multiple parts.
             ///
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.cpp (new file, mode 100644)

//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <functional>
#include <iterator>
#include <unordered_map>

#include "activation_functions.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace rnn
        {
            namespace detail
            {
                std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
                {
                    return std::make_shared<ngraph::op::Sigmoid>(arg);
                }

                std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
                {
                    return std::make_shared<ngraph::op::Tanh>(arg);
                }

                std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg)
                {
                    return std::make_shared<ngraph::op::Relu>(arg);
                }
            } // namespace detail

            ActivationFunction get_activation_func_by_name(const std::string& func_name)
            {
                using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;

                static ActivationFunctionMap func_map{
                    {"sigmoid", std::bind(detail::sigmoid, std::placeholders::_1)},
                    {"tanh", std::bind(detail::tanh, std::placeholders::_1)},
                    {"relu", std::bind(detail::relu, std::placeholders::_1)}};

                auto func_it = func_map.find(func_name);
                if (func_it == std::end(func_map))
                {
                    throw error::UnknownActivationFunction(func_name);
                }
                return func_it->second;
            }
        } //namespace rnn
    } // namespace onnx_import
} // namespace ngraph
src/ngraph/frontend/onnx_import/utils/rnn/activation_functions.hpp (new file, mode 100644)

//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <memory>
#include <string>

#include "ngraph/except.hpp"
#include "ngraph/node.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace rnn
        {
            namespace error
            {
                struct UnknownActivationFunction : ngraph_error
                {
                    UnknownActivationFunction(const std::string& func_name)
                        : ngraph_error{"Unknown activation function: " + func_name}
                    {
                    }
                };
            }

            namespace detail
            {
                std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg);
                std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg);
                std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg);
            }

            using ActivationFunction =
                std::function<std::shared_ptr<ngraph::Node>(const std::shared_ptr<ngraph::Node>&)>;

            /// \brief      Gets the activation function by name.
            ///
            /// \param[in]  func_name  The function name
            ///
            /// \throws     UnknownActivationFunction When provided func_name is unknown.
            ///
            /// \return     The activation function object.
            ///
            ActivationFunction get_activation_func_by_name(const std::string& func_name);
        } //namespace rnn
    } // namespace onnx_import
} // namespace ngraph
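A minimal sketch of how the importer consumes this helper (mirroring the lstm.cpp changes above; the Parameter node is a made-up stand-in):

    // Hypothetical illustration - not part of the commit.
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{1, 4});
    rnn::ActivationFunction f = rnn::get_activation_func_by_name("sigmoid");
    auto activated = f(arg); // builds an ngraph::op::Sigmoid node on top of `arg`
    // get_activation_func_by_name("gelu") would throw rnn::error::UnknownActivationFunction.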
test/models/onnx/lstm_fwd_with_clip.onnx (new file, mode 100644)

File added (binary ONNX model).
test/onnx_import.in.cpp

@@ -1864,6 +1864,91 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
     EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
 }

+TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_with_clip.onnx"));
+
+    Inputs inputs{};
+    // X
+    inputs.emplace_back(std::vector<float>{-0.455351, -0.276391, -0.185934, -0.269585});
+
+    // W
+    inputs.emplace_back(std::vector<float>{
+        -0.494659f, 0.0453352f, -0.487793f, 0.417264f,
+        -0.0175329f, 0.489074f, -0.446013f, 0.414029f,
+        -0.0091708f, -0.255364f, -0.106952f, -0.266717f,
+        -0.0888852f, -0.428709f, -0.283349f, 0.208792f});
+
+    // R
+    inputs.emplace_back(std::vector<float>{
+        0.146626f, -0.0620289f, -0.0815302f, 0.100482f,
+        -0.219535f, -0.306635f, -0.28515f, -0.314112f,
+        -0.228172f, 0.405972f, 0.31576f, 0.281487f,
+        -0.394864f, 0.42111f, -0.386624f, -0.390225f});
+
+    // B
+    inputs.emplace_back(std::vector<float>{
+        0.381619f, 0.0323954f, -0.14449f, 0.420804f,
+        -0.258721f, 0.45056f, -0.250755f, 0.0967895f,
+        0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
+
+    // P
+    inputs.emplace_back(std::vector<float>{0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f});
+
+    Outputs expected_output{};
+    // Y_data
+    expected_output.emplace_back(
+        std::vector<float>{-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f});
+    // Y_h_data
+    expected_output.emplace_back(std::vector<float>{-0.03516197f, 0.03875681f});
+    // Y_c_data
+    expected_output.emplace_back(std::vector<float>{-0.07415761f, 0.07395997f});
+
+    Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
+    EXPECT_TRUE(outputs.size() == expected_output.size());
+    for (std::size_t i{0}; i < expected_output.size(); ++i)
+    {
+        // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
+        // The discrepancies may occur at most on 7th decimal position.
+        EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
+    }
+}
+
 TEST(onnx_${BACKEND_NAME}, model_missing_input)
 {
     onnx_import::register_operator(