submodule / ngraph / Commits / fcf59b2a

Unverified commit fcf59b2a, authored Sep 05, 2019 by Scott Cyphers, committed by GitHub on Sep 05, 2019.

    Merge branch 'master' into cyphers/typename

Parents: adf849e5, 8e7d10df

Showing 43 changed files with 1168 additions and 378 deletions (+1168 / -378)
.ci/onnx/jenkins/Jenkinsfile                                      +2    -2
python/pyngraph/ops/softmax.cpp                                   +2    -4
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.cpp       +0    -44
src/ngraph/CMakeLists.txt                                         +6    -0
src/ngraph/builder/matmul_factory.cpp                             +48   -62
src/ngraph/builder/matmul_factory.hpp                             +88   -0
src/ngraph/builder/reshape.cpp                                    +48   -0
src/ngraph/builder/reshape.hpp                                    +36   -1
src/ngraph/frontend/onnx_import/CMakeLists.txt                    +0    -2
src/ngraph/frontend/onnx_import/op/lstm.cpp                       +12   -11
src/ngraph/frontend/onnx_import/op/matmul.cpp                     +15   -2
src/ngraph/frontend/onnx_import/op/matmul_integer.cpp             +15   -2
src/ngraph/frontend/onnx_import/op/qlinear_matmul.cpp             +15   -2
src/ngraph/frontend/onnx_import/utils/matmul_factory.hpp          +0    -102
src/ngraph/frontend/onnx_import/utils/reshape.cpp                 +0    -54
src/ngraph/frontend/onnx_import/utils/reshape.hpp                 +0    -36
src/ngraph/ngraph.hpp                                             +1    -0
src/ngraph/node.hpp                                               +2    -0
src/ngraph/op/fused/matmul.cpp                                    +89   -0
src/ngraph/op/fused/matmul.hpp                                    +58   -0
src/ngraph/op/fused_op_tbl.hpp                                    +1    -0
src/ngraph/op/softmax.cpp                                         +92   -6
src/ngraph/op/softmax.hpp                                         +72   -25
src/ngraph/op/topk.hpp                                            +4    -4
src/ngraph/pass/opset1_upgrade.cpp                                +101  -0
src/ngraph/pass/opset1_upgrade.hpp                                +38   -0
src/ngraph/runtime/cpu/builder/topk.cpp                           +19   -6
src/ngraph/runtime/cpu/cpu_emitter.hpp                            +0    -1
src/ngraph/runtime/cpu/cpu_external_function.cpp                  +1    -0
...ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.cpp      +0    -0
...ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp      +9    -0
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp                  +3    -0
src/ngraph/runtime/interpreter/int_executable.hpp                 +14   -2
src/ngraph/runtime/plaidml/unit_test.manifest                     +11   -0
src/ngraph/runtime/reference/topk.hpp                             +55   -4
src/ngraph/serializer.cpp                                         +42   -5
test/CMakeLists.txt                                               +2    -0
test/backend/topk.in.cpp                                          +0    -0
test/cpu_test.cpp                                                 +23   -0
test/opset_pass/softmax_opset_pass.cpp                            +117  -0
test/serialize.cpp                                                +16   -0
test/type_prop/matmul.cpp                                         +110  -0
test/type_prop/top_k.cpp                                          +1    -1
.ci/onnx/jenkins/Jenkinsfile (+2 -2)

@@ -164,8 +164,8 @@ def notifyByEmail() {
                     <tr><td>Pull Request Title:</td> <td>$CHANGE_TITLE</td></tr>
                     <tr><td>Pull Request:</td> <td><a href=$CHANGE_URL>$CHANGE_ID</a> </td></tr>
                     <tr><td>Branch:</td> <td>$CHANGE_BRANCH</td></tr>
-                    <tr><td>Commit Hash:</td> <td>$GIT_COMMIT_SUBJECT</td></tr>
-                    <tr><td>Commit Subject:</td> <td>$GIT_COMMIT_HASH</td></tr>
+                    <tr><td>Commit Hash:</td> <td>$GIT_COMMIT_HASH</td></tr>
+                    <tr><td>Commit Subject:</td> <td>$GIT_COMMIT_SUBJECT</td></tr>
                     <tr><td>Jenkins Build:</td> <td> <a href=$RUN_DISPLAY_URL> ${BUILD_NUMBER} </a> </td></tr>
                     <tr><td>nGraph-ONNX Branch:</td> <td>${ONNX_BRANCH}</td></tr>
                 </table>
python/pyngraph/ops/softmax.cpp (+2 -4)

@@ -24,10 +24,8 @@ namespace py = pybind11;
 void regclass_pyngraph_op_Softmax(py::module m)
 {
-    py::class_<ngraph::op::Softmax,
-               std::shared_ptr<ngraph::op::Softmax>,
-               ngraph::op::util::UnaryElementwiseArithmetic>
-        softmax(m, "Softmax");
+    py::class_<ngraph::op::Softmax, std::shared_ptr<ngraph::op::Softmax>, ngraph::op::Op> softmax(
+        m, "Softmax");
     softmax.doc() = "ngraph.impl.op.Softmax wraps ngraph::op::Softmax";
     softmax.def(py::init<const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&>());
 }
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.cpp (+0 -44)

@@ -437,50 +437,6 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
         }
     }

-    if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
-    {
-        // TODO: Remove this when MLIR has float point cmp support
-        if (!node->input(0).get_element_type().is_integral())
-        {
-            return false;
-        }
-        else
-        {
-            return true;
-        }
-    }
-
-    if (TI(ngraph::op::Maximum) == TI(*node) || TI(ngraph::op::Minimum) == TI(*node))
-    {
-        // TODO: Remove this when MLIR has float point cmp support
-        if (!node->input(0).get_element_type().is_integral())
-        {
-            return false;
-        }
-        else
-        {
-            return true;
-        }
-    }
-
-    if (TI(ngraph::op::Greater) == TI(*node) || TI(ngraph::op::Less) == TI(*node))
-    {
-        // TODO: Remove this when MLIR has float point cmp support
-        if (!node->input(0).get_element_type().is_integral())
-        {
-            return false;
-        }
-        else
-        {
-            return true;
-        }
-    }
-
-    if (TI(ngraph::op::Negative) == TI(*node))
-    {
-        return true;
-    }
-
     if (TI(ngraph::op::Convolution) == TI(*node))
     {
         // No padding for now
src/ngraph/CMakeLists.txt (+6 -0)

@@ -27,6 +27,8 @@ set (SRC
     builder/dequantize_builder.cpp
    builder/dequantize_builder.hpp
    builder/make_constant.hpp
+   builder/matmul_factory.cpp
+   builder/matmul_factory.hpp
    builder/norm.cpp
    builder/norm.hpp
    builder/numpy_transpose.cpp
@@ -329,6 +331,8 @@ set (SRC
    op/fused/gru_cell.hpp
    op/fused/lstm_cell.cpp
    op/fused/lstm_cell.hpp
+   op/fused/matmul.cpp
+   op/fused/matmul.hpp
    op/fused/mvn.cpp
    op/fused/mvn.hpp
    op/fused/normalize_l2.cpp
@@ -437,6 +441,8 @@ set (SRC
    pass/nop_elimination.hpp
    pass/pass.cpp
    pass/pass.hpp
+   pass/opset1_upgrade.cpp
+   pass/opset1_upgrade.hpp
    pass/pass_config.cpp
    pass/pass_config.hpp
    pass/propagate_cacheability.cpp
src/ngraph/frontend/onnx_import/utils/matmul_factory.cpp → src/ngraph/builder/matmul_factory.cpp (+48 -62)

@@ -18,17 +18,18 @@
 #include <iterator>
 #include <memory>

-#include "matmul_factory.hpp"
 #include "ngraph/builder/make_constant.hpp"
+#include "ngraph/builder/matmul_factory.hpp"
+#include "ngraph/builder/reshape.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/quantized_dot.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/util/broadcasting.hpp"
-#include "utils/reshape.hpp"

-using namespace ngraph::onnx_import::matmul;
+using namespace ngraph;
+using namespace std;

 /// \brief Slice the sub matrix from the input tensor.
 ///
@@ -37,59 +38,49 @@
 ///
 /// \return The node representing sub matrix.
 ///
-static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph::Node>& node,
-                                                     std::size_t idx)
+static Output<Node> get_sub_matrix(const Output<Node>& node, size_t idx)
 {
-    const ngraph::Shape& shape{node->get_shape()};
+    const Shape& shape{node.get_shape()};
     if (shape.size() < 3)
     {
-        return node;
+        return node.get_node_shared_ptr();
     }
     // Below bounds defines the sub_matrix through ranges for each input node axis.
-    ngraph::Coordinate lower_bounds(shape.size());
-    ngraph::Coordinate upper_bounds = shape;
+    Coordinate lower_bounds(shape.size());
+    Coordinate upper_bounds = shape;
     // We assume `node` tensor is of rank equal 3, thus we slice the sub-matrix lying in the last
     // two dimensions at index `idx` of first axis.
     lower_bounds.at(0) = idx;
     upper_bounds.at(0) = idx + 1;

-    auto sub_matrix = std::shared_ptr<ngraph::Node>{
-        std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds)};
+    auto sub_matrix = Output<Node>{make_shared<op::Slice>(node, lower_bounds, upper_bounds)};
     // Remove first single entry dim.
-    return ngraph::onnx_import::reshape::squeeze(sub_matrix);
+    return builder::squeeze(sub_matrix);
 }

-std::shared_ptr<ngraph::Node> MatmulFactory::get_left()
+Output<Node> builder::MatmulFactory::get_left()
 {
     return m_inputs.at(0);
 }

-std::shared_ptr<ngraph::Node> MatmulFactory::get_right()
+Output<Node> builder::MatmulFactory::get_right()
 {
     return m_inputs.at(1);
 }

-ngraph::NodeVector MatmulFactory::make_matmul_op()
+NodeVector builder::MatmulFactory::make_matmul_op()
 {
     auto left = get_left();
     auto right = get_right();

-    std::size_t left_rank{left->get_shape().size()};
-    std::size_t right_rank{right->get_shape().size()};
-
-    if (left_rank == 0 || right_rank == 0)
-    {
-        NGRAPH_WARN << (m_onnx_node) << " "
-                    << "ONNX standard doesn't allow scalar operands, however nGraph "
-                       "accepts them. Consider use of element-wise multiplication instead "
-                       "to conform with ONNX standard.";
-    }
+    size_t left_rank{left.get_shape().size()};
+    size_t right_rank{right.get_shape().size()};

     // First (easy) case that is already internally handled by Ngraph Dot operator.
     // Multiply two tensors where both of them has rank lower equal 2.
     if (left_rank <= 2 && right_rank <= 2)
     {
-        return NodeVector{make_dot(left, right)};
+        return {make_dot(left, right).get_node_shared_ptr()};
     }

     // Second case:
@@ -98,37 +89,37 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
     // Broadcast input arguments only if both of them are not vectors.
     if (left_rank > 1 && right_rank > 1)
     {
-        const NodeVector& broadcasted_nodes =
-            ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
+        const NodeVector& broadcasted_nodes = op::numpy_style_broadcast_for_matmul_operation(
+            left.get_node_shared_ptr(), right.get_node_shared_ptr());

         left = broadcasted_nodes.at(0);
         right = broadcasted_nodes.at(1);
     }
-    const auto& left_shape = left->get_shape();
-    const auto& right_shape = right->get_shape();
+    const auto& left_shape = left.get_shape();
+    const auto& right_shape = right.get_shape();

     // Collapse both tensors _stack of matrices_ axes (all except the last two).
     // This will make easier further dot product calculations.
     if (left_shape.size() > 3)
     {
-        left = onnx_import::reshape::collapse(left, 0, left_shape.size() - 3);
+        left = builder::collapse(left, 0, left_shape.size() - 3);
     }
     if (right_shape.size() > 3)
     {
-        right = onnx_import::reshape::collapse(right, 0, right_shape.size() - 3);
+        right = builder::collapse(right, 0, right_shape.size() - 3);
     }

     // Perform multiple small dot products
-    std::size_t groups = left->get_shape().at(0);
+    size_t groups = left.get_shape().at(0);
     // If we haven't broadcast earlier this means that one of the inputs is a vector,
     // thus the number of groups is defined by the shape of the bigger tensor.
-    if (right->get_shape().size() > left->get_shape().size())
+    if (right.get_shape().size() > left.get_shape().size())
     {
-        groups = right->get_shape().at(0);
+        groups = right.get_shape().at(0);
     }
     NodeVector small_dots(groups);

-    for (std::size_t g = 0; g < groups; ++g)
+    for (size_t g = 0; g < groups; ++g)
     {
         const auto sliced_left = get_sub_matrix(left, g);
         const auto sliced_right = get_sub_matrix(right, g);
@@ -136,11 +127,11 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
         // Expand sub_dot result with single empty outermost axis, in order to
         // later concatenate sub_dots at this axis.
-        small_dots.at(g) = onnx_import::reshape::expand_dims(sub_dot);
+        small_dots.at(g) = builder::expand_dims(sub_dot);
     }

     // Concatenate sub_dots on groups axis.
-    auto result = std::make_shared<ngraph::op::Concat>(small_dots, 0);
+    auto result = make_shared<op::Concat>(small_dots, 0);

     if (left_shape.size() <= 3 && right_shape.size() <= 3)
     {
@@ -150,39 +141,35 @@ ngraph::NodeVector MatmulFactory::make_matmul_op()
     else
     {
         const Shape& shape{result->get_shape()};
-        Shape result_shape(std::next(std::begin(shape)), std::end(shape));
-        result_shape.insert(std::begin(result_shape),
-                            std::begin(left_shape),
-                            std::next(std::begin(left_shape), left_shape.size() - 2));
-        return {std::make_shared<ngraph::op::Reshape>(
-            result, ngraph::get_default_order(shape.size()), result_shape)};
+        Shape result_shape(next(begin(shape)), end(shape));
+        result_shape.insert(begin(result_shape),
+                            begin(left_shape),
+                            next(begin(left_shape), left_shape.size() - 2));
+        return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)};
     }
 }

-std::shared_ptr<ngraph::Node> MatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
-                                                      const std::shared_ptr<ngraph::Node>& right)
+Output<Node> builder::MatmulFactory::make_dot(const Output<Node>& left, const Output<Node>& right)
 {
-    return std::make_shared<ngraph::op::Dot>(left, right);
+    return make_shared<op::Dot>(left, right);
 }

-std::shared_ptr<ngraph::Node> QLinearMatmulFactory::get_right()
+Output<Node> builder::QLinearMatmulFactory::get_right()
 {
     return m_inputs.at(3);
 }

-std::shared_ptr<ngraph::Node>
-    QLinearMatmulFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
-                                   const std::shared_ptr<ngraph::Node>& right)
+Output<Node> builder::QLinearMatmulFactory::make_dot(const Output<Node>& left,
+                                                     const Output<Node>& right)
 {
     ngraph::element::Type output_type;

-    if (left->get_element_type() == ngraph::element::u8 &&
-        right->get_element_type() == ngraph::element::i8)
+    if (left.get_element_type() == ngraph::element::u8 &&
+        right.get_element_type() == ngraph::element::i8)
     {
         output_type = ngraph::element::i8;
     }
-    else if (left->get_element_type() == ngraph::element::u8 &&
-             right->get_element_type() == ngraph::element::u8)
+    else if (left.get_element_type() == ngraph::element::u8 &&
+             right.get_element_type() == ngraph::element::u8)
     {
         output_type = ngraph::element::u8;
     }
@@ -202,15 +189,14 @@ std::shared_ptr<ngraph::Node>
                                     ngraph::AxisSet{});
 }

-std::shared_ptr<ngraph::Node>
-    MatmulIntegerFactory::make_dot(const std::shared_ptr<ngraph::Node>& left,
-                                   const std::shared_ptr<ngraph::Node>& right)
+Output<Node> builder::MatmulIntegerFactory::make_dot(const Output<Node>& left,
+                                                     const Output<Node>& right)
 {
     auto num_inputs = m_inputs.size();
     auto scale_one = ngraph::builder::make_constant(ngraph::element::f32, Shape{}, 1);
     auto output_zero_point = ngraph::builder::make_constant(ngraph::element::i32, Shape{}, 0);
-    auto left_zero_point = ngraph::builder::make_constant(left->get_element_type(), Shape{}, 0);
-    auto right_zero_point = ngraph::builder::make_constant(right->get_element_type(), Shape{}, 0);
+    auto left_zero_point = ngraph::builder::make_constant(left.get_element_type(), Shape{}, 0);
+    auto right_zero_point = ngraph::builder::make_constant(right.get_element_type(), Shape{}, 0);
     if (num_inputs == 2)
     {
         return std::make_shared<ngraph::op::QuantizedDot>(left,
@@ -228,10 +214,10 @@ std::shared_ptr<ngraph::Node>
                                     ngraph::AxisSet{});
     }
-    left_zero_point = m_inputs.at(2);
+    left_zero_point = m_inputs.at(2).get_node_shared_ptr();
     if (num_inputs == 4)
     {
-        right_zero_point = m_inputs.at(3);
+        right_zero_point = m_inputs.at(3).get_node_shared_ptr();
     }
     return std::make_shared<ngraph::op::QuantizedDot>(left,
src/ngraph/builder/matmul_factory.hpp (new file, 0 → 100644, +88)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
namespace ngraph
{
    namespace builder
    {
        /// \brief Factory class which generates an nGraph sub-graph performing MatMul operation.
        ///
        /// This default implementation `MatmulFactory` creates a `MatMul` operation for
        /// floating-point data.
        /// Subclasses: `QLinearMatmulFactory` and `MatmulIntegerFactory` implement quantized
        /// versions.
        class MatmulFactory
        {
        public:
            explicit MatmulFactory(const OutputVector& inputs)
                : m_inputs(inputs)
            {
            }

            virtual ~MatmulFactory() = default;

            /// \brief Create a sub-graph representing an ONNX MatMul operation.
            ///
            /// \return NodeVector containing the sub-graph output node.
            virtual NodeVector make_matmul_op();

        protected:
            /// \return Output representing the left operand.
            virtual Output<Node> get_left();

            /// \return Output representing the right operand.
            virtual Output<Node> get_right();

            /// \return Output representing the nGraph Dot operation used to construct MatMul.
            virtual Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right);

            const OutputVector m_inputs;
        };

        /// \brief Factory class which generates an nGraph sub-graph based on an ONNX QLinearMatMul
        /// operation.
        class QLinearMatmulFactory : public MatmulFactory
        {
        public:
            explicit QLinearMatmulFactory(const OutputVector& inputs)
                : MatmulFactory(inputs)
            {
            }

        protected:
            Output<Node> get_right() override;
            Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right) override;
        };

        /// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMulInteger
        /// operation.
        class MatmulIntegerFactory : public MatmulFactory
        {
        public:
            explicit MatmulIntegerFactory(const OutputVector& inputs)
                : MatmulFactory(inputs)
            {
            }

        protected:
            Output<Node> make_dot(const Output<Node>& left, const Output<Node>& right) override;
        };
    } // namespace builder
} // namespace ngraph
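For orientation, here is a minimal usage sketch of the relocated factory. It is not part of the commit; the constant shapes and values are purely illustrative, and it only assumes the public interface declared above.

// Hypothetical usage of ngraph::builder::MatmulFactory (illustrative only).
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/op/constant.hpp"

ngraph::NodeVector batched_matmul_example()
{
    using namespace ngraph;
    // Two stacks of two 2x2 matrices. make_matmul_op() slices each stack,
    // multiplies the sub-matrices with op::Dot and concatenates the results
    // along the leading (batch) axis.
    auto a = op::Constant::create(element::f32, Shape{2, 2, 2}, std::vector<float>(8, 1.0f));
    auto b = op::Constant::create(element::f32, Shape{2, 2, 2}, std::vector<float>(8, 2.0f));
    builder::MatmulFactory factory(OutputVector{a, b});
    return factory.make_matmul_op();
}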
src/ngraph/builder/reshape.cpp (+48 -0)

@@ -120,3 +120,51 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, const Output<Node>&
     // result := DynReshape(value, flattened_dims)
     return make_shared<op::DynReshape>(value, flattened_dims);
 }
+
+shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes)
+{
+    if (axes.empty())
+    {
+        return value.get_node_shared_ptr();
+    }
+
+    Shape in_shape{value.get_shape()};
+    for (size_t idx = 0; idx < axes.size(); ++idx)
+    {
+        in_shape.at(idx) = 0;
+    }
+    Shape output_shape;
+    for (auto axis : in_shape)
+    {
+        if (axis != 0)
+        {
+            output_shape.push_back(axis);
+        }
+    }
+    return builder::reshape(value, output_shape);
+}
+
+shared_ptr<Node> builder::collapse(const Output<Node>& value,
+                                   const size_t start_axis,
+                                   const size_t end_axis)
+{
+    auto shape = value.get_shape();
+    size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
+                                            next(begin(shape), end_axis + 1),
+                                            1UL,
+                                            multiplies<size_t>());
+
+    Shape output_shape{collapsed_axis_size};
+    output_shape.insert(end(output_shape), next(begin(shape), end_axis + 1), end(shape));
+    return builder::reshape(value, output_shape);
+}
+
+shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis)
+{
+    Shape output_shape(value.get_shape());
+    // Add empty axis at specified position.
+    auto empty_axis_it = begin(output_shape);
+    advance(empty_axis_it, axis);
+    output_shape.insert(empty_axis_it, 1);
+    return make_shared<op::Reshape>(
+        value, get_default_order(value.get_shape().size()), output_shape);
+}
src/ngraph/builder/reshape.hpp (+36 -1)

@@ -55,7 +55,7 @@ namespace ngraph
         /// \brief Flatten a value into a 2D matrix, with a static dividing axis.
         ///
         /// \param value The tensor to be flattened.
-        /// \param axis The axis dividing shape.
+        /// \param axis  The axis dividing shape.
         ///
         /// \return The new value will be a 2D matrix representing the flattened input node.
         std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
@@ -68,5 +68,40 @@ namespace ngraph
         ///
         /// \return The new value will be a 2D matrix representing the flattened input node.
         std::shared_ptr<Node> flatten(const Output<Node>& value, const Output<Node>& axis);
+
+        /// \brief Remove empty axes from input tensor.
+        ///
+        /// \param[in] value The value to be squeezed.
+        /// \param[in] axes  The vector defining indexes of axes to be removed.
+        ///
+        /// \return The squeezed node.
+        ///
+        std::shared_ptr<Node> squeeze(const Output<Node>& value,
+                                      std::vector<std::size_t> axes = {0});
+
+        /// \brief Collapse specified axes into single one.
+        ///
+        /// \note Collapsed axes create a continuous range starting from outermost axis.
+        ///
+        /// \param[in] value      The value to be reshaped.
+        /// \param[in] start_axis The start axis index.
+        /// \param[in] end_axis   The end axis (inclusive) index.
+        ///
+        /// \return The node with collapsed specified axes.
+        ///
+        std::shared_ptr<Node> collapse(const Output<Node>& value,
+                                       const std::size_t start_axis,
+                                       const std::size_t end_axis);
+
+        /// \brief Expands node tensor shape with empty axis at
+        ///        specified position.
+        ///
+        /// \param[in] value The value to be expanded.
+        /// \param[in] axis  The position in the expanded axes where the
+        ///                  new axis is placed.
+        ///
+        /// \return The node with added empty axis.
+        ///
+        std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
     } // namespace builder
 } // namespace ngraph
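As a rough illustration of the shape semantics of the three new builder helpers, a small sketch follows. It is not code from the commit; the parameter shape is arbitrary and the resulting shapes are stated under the assumption that the declarations above behave as their doc comments describe.

// Shape walk-through for builder::squeeze / collapse / expand_dims (illustrative only).
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/parameter.hpp"

void reshape_builder_examples()
{
    using namespace ngraph;
    auto t = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 4});

    auto squeezed = builder::squeeze(t);        // drops the leading size-1 axis -> Shape{3, 4}
    auto expanded = builder::expand_dims(t, 1); // inserts a size-1 axis at 1    -> Shape{1, 1, 3, 4}
    auto merged = builder::collapse(t, 0, 1);   // merges axes 0..1 into one     -> Shape{3, 4}
}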
src/ngraph/frontend/onnx_import/CMakeLists.txt (+0 -2)

@@ -204,8 +204,6 @@ add_library(onnx_import STATIC
         utils/common.hpp
         utils/convpool.cpp
         utils/convpool.hpp
-        utils/matmul_factory.cpp
-        utils/matmul_factory.hpp
         utils/pooling_factory.cpp
         utils/pooling_factory.hpp
         utils/reduction.cpp
src/ngraph/frontend/onnx_import/op/lstm.cpp (+12 -11)

@@ -26,6 +26,7 @@
 #include "exceptions.hpp"
 #include "lstm.hpp"
 #include "ngraph/axis_set.hpp"
+#include "ngraph/builder/reshape.hpp"
 #include "ngraph/builder/split.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/constant.hpp"
@@ -243,12 +244,12 @@ namespace ngraph
                     const LSTMAttributes& attributes)
                     : m_X{X} // Since we have forward LSTM we can squeeze `num_directions` axis
                              // from inputs.
-                    , m_W(reshape::squeeze(W))
-                    , m_R(reshape::squeeze(R))
-                    , m_B(reshape::squeeze(B))
-                    , m_P(reshape::squeeze(P))
-                    , m_initial_h(reshape::squeeze(initial_h))
-                    , m_initial_c(reshape::squeeze(initial_c))
+                    , m_W(builder::squeeze(W))
+                    , m_R(builder::squeeze(R))
+                    , m_B(builder::squeeze(B))
+                    , m_P(builder::squeeze(P))
+                    , m_initial_h(builder::squeeze(initial_h))
+                    , m_initial_c(builder::squeeze(initial_c))
                     , m_seq_lengths(seq_lengths)
                     , m_attributes(attributes)
                 {
@@ -300,7 +301,7 @@ namespace ngraph
                     for (auto& in_x : in_seqs)
                     {
                         // remove first empty dim, after above split.
-                        in_x = reshape::squeeze(in_x);
+                        in_x = builder::squeeze(in_x);
                     }
                     std::int32_t time_step{1};
@@ -331,7 +332,7 @@ namespace ngraph
                         // This results in zeroing out values in batches with sequence shorter
                         // than current time_step.
                         h_list.push_back(
-                            get_masked_node(reshape::expand_dims(H), time_step, 1));
+                            get_masked_node(builder::expand_dims(H), time_step, 1));
                         // Reference implementation in ONNX Runtime doesn't mask values of Y_h
                         // and Y_c outputs, thus here we make sure that only appropriate batches
                         // (in respect to its sequence length) are updated. Those batches which
@@ -354,12 +355,12 @@ namespace ngraph
                     // Expand Y so that it has expected shape:
                     // [seq_length, num_directions, batch_size, hidden_size]
-                    Y = reshape::expand_dims(Y, 1);
+                    Y = builder::expand_dims(Y, 1);
                     // expand H_t and C_t so that it has expected shape:
                     // [num_directions, batch_size, hidden_size]
-                    auto Y_h = reshape::expand_dims(H_t);
-                    auto Y_c = reshape::expand_dims(C_t);
+                    auto Y_h = builder::expand_dims(H_t);
+                    auto Y_c = builder::expand_dims(C_t);
                     return {Y, Y_h, Y_c};
                 }
src/ngraph/frontend/onnx_import/op/matmul.cpp (+15 -2)

@@ -15,7 +15,7 @@
 //*****************************************************************************

 #include "matmul.hpp"
-#include "frontend/onnx_import/utils/matmul_factory.hpp"
+#include "ngraph/builder/matmul_factory.hpp"

 namespace ngraph
 {
@@ -27,7 +27,20 @@ namespace ngraph
             {
                 NodeVector matmul(const Node& node)
                 {
-                    auto factory = matmul::MatmulFactory(node);
+                    auto ng_inputs = node.get_ng_inputs();
+                    auto factory = builder::MatmulFactory(
+                        (OutputVector(std::begin(ng_inputs), std::end(ng_inputs))));
+                    std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
+                    std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
+
+                    if (left_rank == 0 || right_rank == 0)
+                    {
+                        NGRAPH_WARN
+                            << (node) << " "
+                            << "ONNX standard doesn't allow scalar operands, however nGraph "
+                               "accepts them. Consider use of element-wise multiplication instead "
+                               "to conform with ONNX standard.";
+                    }
                     return factory.make_matmul_op();
                 }
             } // namespace set_1
src/ngraph/frontend/onnx_import/op/matmul_integer.cpp (+15 -2)

@@ -15,7 +15,7 @@
 //*****************************************************************************

 #include "matmul_integer.hpp"
-#include "frontend/onnx_import/utils/matmul_factory.hpp"
+#include "ngraph/builder/matmul_factory.hpp"

 namespace ngraph
 {
@@ -27,7 +27,20 @@ namespace ngraph
             {
                 NodeVector matmul_integer(const Node& node)
                 {
-                    auto factory = matmul::MatmulIntegerFactory(node);
+                    auto ng_inputs = node.get_ng_inputs();
+                    auto factory = builder::MatmulIntegerFactory(
+                        OutputVector(std::begin(ng_inputs), std::end(ng_inputs)));
+                    std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
+                    std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
+
+                    if (left_rank == 0 || right_rank == 0)
+                    {
+                        NGRAPH_WARN
+                            << (node) << " "
+                            << "ONNX standard doesn't allow scalar operands, however nGraph "
+                               "accepts them. Consider use of element-wise multiplication instead "
+                               "to conform with ONNX standard.";
+                    }
                     return factory.make_matmul_op();
                 }
             } // namespace set_1
src/ngraph/frontend/onnx_import/op/qlinear_matmul.cpp (+15 -2)

@@ -15,7 +15,7 @@
 //*****************************************************************************

 #include "qlinear_matmul.hpp"
-#include "frontend/onnx_import/utils/matmul_factory.hpp"
+#include "ngraph/builder/matmul_factory.hpp"

 namespace ngraph
 {
@@ -27,7 +27,20 @@ namespace ngraph
             {
                 NodeVector qlinear_matmul(const Node& node)
                 {
-                    auto factory = matmul::QLinearMatmulFactory(node);
+                    auto ng_inputs = node.get_ng_inputs();
+                    auto factory = builder::QLinearMatmulFactory(
+                        (OutputVector(std::begin(ng_inputs), std::end(ng_inputs))));
+                    std::size_t left_rank{ng_inputs.at(0)->get_shape().size()};
+                    std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
+
+                    if (left_rank == 0 || right_rank == 0)
+                    {
+                        NGRAPH_WARN
+                            << (node) << " "
+                            << "ONNX standard doesn't allow scalar operands, however nGraph "
+                               "accepts them. Consider use of element-wise multiplication instead "
+                               "to conform with ONNX standard.";
+                    }
                     return factory.make_matmul_op();
                 }
             } // namespace set_1
src/ngraph/frontend/onnx_import/utils/matmul_factory.hpp (deleted, 100644 → 0, -102; shown at adf849e5)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace matmul
        {
            /// \brief Factory class which generates an nGraph sub-graph based on an ONNX MatMul
            /// operation.
            ///
            /// \note
            /// The sub-graph is needed to adjust nGraph's Dot operation semantics to semantics
            /// expected by ONNX, which are modeled on NumPy's "stacks of arrays" approach.
            /// Differences are apparent with matrices of rank > 2.
            ///
            /// This default implementation `MatmulFactory` creates a `MatMul` operation for
            /// floating-point data. Subclasses: `QLinearMatmulFactory` and `MatmulIntegerFactory`
            /// implement quantized versions.
            class MatmulFactory
            {
            public:
                explicit MatmulFactory(const Node& node)
                    : m_onnx_node(node)
                    , m_inputs(node.get_ng_inputs())
                {
                }

                virtual ~MatmulFactory() = default;

                /// \brief Create a sub-graph representing an ONNX MatMul operation.
                ///
                /// \return NodeVector containing the sub-graph output node.
                virtual NodeVector make_matmul_op();

                /// \return Node representing the left operand.
                virtual std::shared_ptr<ngraph::Node> get_left();

                /// \return Node representing the right operand.
                virtual std::shared_ptr<ngraph::Node> get_right();

                /// \return Node representing the nGraph Dot operation used to construct MatMul.
                virtual std::shared_ptr<ngraph::Node>
                    make_dot(const std::shared_ptr<ngraph::Node>& left,
                             const std::shared_ptr<ngraph::Node>& right);

            protected:
                const Node& m_onnx_node;
                const NodeVector m_inputs;
            };

            /// \brief Factory class which generates an nGraph sub-graph based on an ONNX
            /// QLinearMatMul operation.
            class QLinearMatmulFactory : public MatmulFactory
            {
            public:
                explicit QLinearMatmulFactory(const Node& node)
                    : MatmulFactory(node)
                {
                }

                std::shared_ptr<ngraph::Node> get_right() override;
                std::shared_ptr<ngraph::Node>
                    make_dot(const std::shared_ptr<ngraph::Node>& left,
                             const std::shared_ptr<ngraph::Node>& right) override;
            };

            /// \brief Factory class which generates an nGraph sub-graph based on an ONNX
            /// MatMulInteger operation.
            class MatmulIntegerFactory : public MatmulFactory
            {
            public:
                explicit MatmulIntegerFactory(const Node& node)
                    : MatmulFactory(node)
                {
                }

                std::shared_ptr<ngraph::Node>
                    make_dot(const std::shared_ptr<ngraph::Node>& left,
                             const std::shared_ptr<ngraph::Node>& right) override;
            };
        }
    }
}
src/ngraph/frontend/onnx_import/utils/reshape.cpp (+0 -54)

@@ -85,60 +85,6 @@ namespace ngraph
                 return inferred_dims;
             }

-            std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
-                                                  std::vector<std::size_t> axes)
-            {
-                if (axes.empty())
-                {
-                    return node;
-                }
-
-                Shape in_shape{node->get_shape()};
-                for (std::size_t idx = 0; idx < axes.size(); ++idx)
-                {
-                    in_shape.at(idx) = 0;
-                }
-                Shape output_shape;
-                for (auto axis : in_shape)
-                {
-                    if (axis != 0)
-                    {
-                        output_shape.push_back(axis);
-                    }
-                }
-                return ngraph::builder::reshape(node, output_shape);
-            }
-
-            std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
-                                                   const std::size_t start_axis,
-                                                   const std::size_t end_axis)
-            {
-                auto shape = node->get_shape();
-                std::size_t collapsed_axis_size =
-                    std::accumulate(std::next(std::begin(shape), start_axis),
-                                    std::next(std::begin(shape), end_axis + 1),
-                                    1UL,
-                                    std::multiplies<std::size_t>());
-
-                Shape output_shape{collapsed_axis_size};
-                output_shape.insert(std::end(output_shape),
-                                    std::next(std::begin(shape), end_axis + 1),
-                                    std::end(shape));
-                return ngraph::builder::reshape(node, output_shape);
-            }
-
-            std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
-                                                      std::size_t axis)
-            {
-                Shape output_shape(node->get_shape());
-                // Add empty axis at specified position.
-                auto empty_axis_it = std::begin(output_shape);
-                std::advance(empty_axis_it, axis);
-                output_shape.insert(empty_axis_it, 1);
-                return std::make_shared<ngraph::op::Reshape>(
-                    node, ngraph::get_default_order(node->get_shape().size()), output_shape);
-            }
-
             std::shared_ptr<ngraph::Node>
                 interpret_as_scalar(const std::shared_ptr<ngraph::Node>& node)
             {
src/ngraph/frontend/onnx_import/utils/reshape.hpp (+0 -36)

@@ -50,42 +50,6 @@ namespace ngraph
                 const std::vector<std::size_t>& input_shape,
                 const std::vector<std::size_t>& output_shape);

-            /// \brief Remove empty axes from input tensor.
-            ///
-            /// \param[in] node The node to be squeezed.
-            /// \param[in] axes The vector defining indexes of axes to be removed.
-            ///
-            /// \return The squeezed node.
-            ///
-            std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
-                                                  std::vector<std::size_t> axes = {0});
-
-            /// \brief Collapse specified axes into single one.
-            ///
-            /// \note Collapsed axes create a continuous range starting from outermost axis.
-            ///
-            /// \param[in] node       The node to be reshaped.
-            /// \param[in] start_axis The start axis index.
-            /// \param[in] end_axis   The end axis (inclusive) index.
-            ///
-            /// \return The node with collapsed specified axes.
-            ///
-            std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
-                                                   const std::size_t start_axis,
-                                                   const std::size_t end_axis);
-
-            /// \brief Expands node tensor shape with empty axis at
-            ///        specified position.
-            ///
-            /// \param[in] node The node to be expanded.
-            /// \param[in] axis The position in the expanded axes where the
-            ///                 new axis is placed.
-            ///
-            /// \return The node with added empty axis.
-            ///
-            std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
-                                                      std::size_t axis = 0);
-
             /// \brief Handle a node which represents a scalar value.
             ///
             /// \note Some ONNX nodes, which should provide scalar values are given as
src/ngraph/ngraph.hpp (+1 -0)

@@ -135,6 +135,7 @@ namespace ngraph
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
+#include "ngraph/op/fused/matmul.hpp"
 #include "ngraph/op/fused/mvn.hpp"
 #include "ngraph/op/fused/normalize_l2.hpp"
 #include "ngraph/op/fused/prelu.hpp"
src/ngraph/node.hpp (+2 -0)

@@ -439,6 +439,8 @@ namespace ngraph
         /// Get all the nodes that uses the current node
         NodeVector get_users(bool check_is_used = false) const;

+        /// \return Version of this node
+        virtual size_t get_version() const { return 0; }
         virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
         /// Use instance ids for comparison instead of memory addresses to improve determinism
         bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
src/ngraph/op/fused/matmul.cpp (new file, 0 → 100644, +89)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <numeric>
#include "matmul.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
using namespace ngraph;

const string op::MatMul::type_name{"MatMul"};

op::MatMul::MatMul(const Output<Node>& A,
                   const Output<Node>& B,
                   const bool& transpose_a,
                   const bool& transpose_b)
    : FusedOp(OutputVector{A, B})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

NodeVector op::MatMul::decompose_op() const
{
    auto A = input_value(0);
    auto B = input_value(1);

    // Specification is expecting that A & B have at least 2 dimenstions.
    // Missing dimensions are padded with 1.
    int a_rank = A.get_shape().size();
    if (a_rank < 2)
    {
        A = a_rank == 0 ? make_shared<op::Reshape>(A, AxisVector{}, Shape{1, 1})
                        : make_shared<op::Reshape>(A, AxisVector{1}, Shape{1, A.get_shape()[0]});
        a_rank = 2;
    }

    int b_rank = B.get_shape().size();
    if (b_rank < 2)
    {
        B = b_rank == 0 ? make_shared<op::Reshape>(B, AxisVector{}, Shape{1, 1})
                        : make_shared<op::Reshape>(B, AxisVector{1}, Shape{1, B.get_shape()[0]});
        b_rank = 2;
    }

    if (m_transpose_a)
    {
        vector<size_t> axes_order(a_rank);
        // generate default axes_order.
        iota(axes_order.begin(), axes_order.end(), 0);
        // transpose the last 2 spatial dims
        swap(axes_order[a_rank - 1], axes_order[a_rank - 2]);
        A = builder::reorder_axes(A, axes_order);
    }

    if (m_transpose_b)
    {
        vector<size_t> axes_order(b_rank);
        iota(axes_order.begin(), axes_order.end(), 0);
        swap(axes_order[b_rank - 1], axes_order[b_rank - 2]);
        B = builder::reorder_axes(B, axes_order);
    }

    builder::MatmulFactory factory({A, B});
    return factory.make_matmul_op();
}

shared_ptr<Node> op::MatMul::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}
src/ngraph/op/fused/matmul.hpp (new file, 0 → 100644, +58)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Operator performing Matrix Multiplication.
        class MatMul : public ngraph::op::util::FusedOp
        {
        public:
            NGRAPH_API
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
            MatMul() = default;
            /// \brief Constructs an ScaleShift operation.
            ///
            /// \param A Matrix A
            /// \param B Matrix B
            /// \param transpose_a If matrix A should be transposed.
            /// \param transpose_b If matrix B should be transposed.
            MatMul(const Output<Node>& A,
                   const Output<Node>& B,
                   const bool& transpose_a = 0,
                   const bool& transpose_b = 0);

            virtual NodeVector decompose_op() const override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            bool get_transpose_a() const { return m_transpose_a; }
            bool get_transpose_b() const { return m_transpose_b; }
        private:
            const bool m_transpose_a;
            const bool m_transpose_b;
        };
    } // namespace op
} // namespace ngraph
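A small sketch of constructing the new fused op follows. It is illustrative only (the function name, parameter shapes and transpose flags are not from the commit) and assumes only the declaration above plus the standard ngraph Function/Parameter types.

// Hypothetical construction of the fused op::MatMul (illustrative only).
#include "ngraph/ngraph.hpp"

std::shared_ptr<ngraph::Function> make_matmul_function()
{
    using namespace ngraph;
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{3, 2});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
    // transpose_a = true: A is treated as 2x3 before the product, so the
    // decomposed Reshape/Dot sub-graph yields a 2x4 result.
    auto matmul = std::make_shared<op::MatMul>(A, B, true, false);
    return std::make_shared<Function>(NodeVector{matmul}, ParameterVector{A, B});
}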
src/ngraph/op/fused_op_tbl.hpp (+1 -0)

@@ -38,6 +38,7 @@ NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
 NGRAPH_OP(GRUCell, ngraph::op)
 NGRAPH_OP(HardSigmoid, ngraph::op)
 NGRAPH_OP(LSTMCell, ngraph::op)
+NGRAPH_OP(MatMul, ngraph::op)
 NGRAPH_OP(MVN, ngraph::op)
 NGRAPH_OP(NormalizeL2, ngraph::op)
 NGRAPH_OP(PRelu, ngraph::op)
src/ngraph/op/softmax.cpp (+92 -6)

@@ -30,22 +30,36 @@ using namespace ngraph;
 constexpr NodeTypeInfo op::Softmax::type_info;

-op::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
-    : UnaryElementwiseArithmetic(arg)
+op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
+    : Op({arg})
     , m_axes(axes)
 {
     constructor_validate_and_infer_types();

+    const PartialShape& input_shape = get_input_partial_shape(0);
+    NODE_VALIDATION_CHECK(this,
+                          input_shape.rank().is_static(),
+                          "Input node rank must be static (input_shape=",
+                          input_shape,
+                          ").");
     for (auto axis : m_axes)
     {
         NODE_VALIDATION_CHECK(this,
-                              axis < get_shape().size(),
+                              axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
                               "Reduction axis (",
                               axis,
                               ") is out of bounds (argument shape: ",
-                              get_shape(),
+                              input_shape,
                               ").");
     }

+    if (input_shape.is_static())
+    {
+        set_output_type(0, get_input_element_type(0), input_shape.to_shape());
+    }
+    else
+    {
+        set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+    }
+
     // empty axes == all axes
     if (m_axes.size() == 0)
@@ -57,13 +71,13 @@ op::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
     }
 }

-shared_ptr<Node> op::Softmax::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::Softmax::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
     return make_shared<Softmax>(new_args.at(0), m_axes);
 }

-void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
+void op::v0::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
 {
     auto delta = deltas.at(0);
@@ -90,3 +104,75 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
     auto x = input_value(0);
     adjoints.add_delta(x, adjoint);
 }
+
+// *** SOFTMAX OP SET V1 ***
+const string op::v1::Softmax::type_name{"Softmax"};
+
+op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
+    : Op({arg})
+    , m_axis(axis)
+{
+    constructor_validate_and_infer_types();
+
+    const PartialShape& input_shape = get_input_partial_shape(0);
+    NODE_VALIDATION_CHECK(this,
+                          input_shape.rank().is_static(),
+                          "Input node rank must be static (input_shape=",
+                          input_shape,
+                          ").");
+    NODE_VALIDATION_CHECK(this,
+                          axis >= 0 && axis < static_cast<size_t>(input_shape.rank()),
+                          "Reduction axis (",
+                          axis,
+                          ") is out of bounds (argument shape: ",
+                          input_shape,
+                          ").");
+
+    if (input_shape.is_static())
+        set_output_type(0, get_input_element_type(0), input_shape.to_shape());
+    else
+        set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+}
+
+shared_ptr<Node> op::v1::Softmax::copy_with_new_args(const NodeVector& new_args) const
+{
+    check_new_args_count(this, new_args);
+    return make_shared<op::v1::Softmax>(new_args.at(0), m_axis);
+}
+
+void op::v1::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
+{
+    throw ngraph_error("op::v1::Softmax::generate_adjoints function is not implemented yet");
+
+    /* This might work, but as of this writing we have no way to test it, so we are being careful
+    auto delta = deltas.at(0);
+
+    auto z = delta * shared_from_this();
+    std::vector<size_t> axes(get_shape().size() - m_axis);
+    std::iota(std::begin(axes), std::end(axes), m_axis);
+    AxisSet axes_set{axes};
+    auto zsum = make_shared<op::Sum>(z, axes_set);
+
+    Shape shape;
+    for (size_t i = 0; i < get_shape().size(); ++i)
+    {
+        if (axes_set.find(i) == axes_set.end())
+        {
+            shape.push_back(get_shape()[i]);
+        }
+        else
+        {
+            shape.push_back(1);
+        }
+    }
+    auto order = ngraph::get_default_order(zsum->get_shape());
+    auto zreshape = make_shared<op::Reshape>(zsum, order, shape);
+
+    auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape);
+
+    auto x = input(0).get_source_output();
+    adjoints.add_delta(x, adjoint);
+    */
+}
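To make the v0/v1 difference concrete, a brief construction sketch follows. It is illustrative only and not part of the commit; it assumes the v0/v1 namespaces declared in softmax.hpp below.

// v0::Softmax takes an AxisSet; v1::Softmax takes a single axis (illustrative only).
#include "ngraph/ngraph.hpp"

void softmax_versions_example()
{
    using namespace ngraph;
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});

    // v0 (default opset): normalize over an explicit set of axes.
    auto softmax_v0 = std::make_shared<op::v0::Softmax>(data, AxisSet{1, 2});

    // v1 (opset1): normalize over a single axis; get_version() reports 1.
    auto softmax_v1 = std::make_shared<op::v1::Softmax>(data, 1);
}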
src/ngraph/op/softmax.hpp (+72 -25)

@@ -16,15 +16,13 @@
 #pragma once

-#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
+#include "ngraph/op/op.hpp"

 namespace ngraph
 {
     namespace op
     {
-        /// \brief Softmax operation.
-        ///
-        class Softmax : public util::UnaryElementwiseArithmetic
+        namespace v0
+        {
+            class Softmax : public Op
         {
         public:
             NGRAPH_API
@@ -32,26 +30,75 @@ namespace ngraph
             const NodeTypeInfo& get_type_info() const override { return type_info; }
             Softmax() = default;
             /// \brief Constructs a softmax operation.
             ///
             /// \param arg Node that produces the first input tensor.<br>
             /// `[d0, ...]`
             /// \param axes The axis positions (0-based) on which to calculate the softmax.
             ///
             /// Output `[d0, ...]`
             ///
             Softmax(const Output<Node>& arg, const AxisSet& axes);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

             const AxisSet& get_axes() const { return m_axes; }
             void set_axes(const AxisSet& axes) { m_axes = axes; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;

         private:
             AxisSet m_axes;
         };
+        }
+
+        namespace v1
+        {
+            class Softmax : public Op
+            {
+            public:
+                NGRAPH_API
+                static const std::string type_name;
+                const std::string& description() const override { return type_name; }
+                Softmax()
+                    : m_axis(0)
+                {
+                }
+                /// \brief Constructs a softmax operation.
+                ///
+                /// \param arg Node that produces the first input tensor.<br>
+                /// `[d0, ...]`
+                /// \param axis The axis position (0-based) on which to calculate the softmax.
+                ///
+                /// Output `[d0, ...]`
+                ///
+                Softmax(const Output<Node>& arg, const size_t axis);
+
+                size_t get_version() const override { return 1; }
+                virtual std::shared_ptr<Node>
+                    copy_with_new_args(const NodeVector& new_args) const override;
+
+                size_t get_axis() const { return m_axis; }
+                void set_axis(const size_t axis) { m_axis = axis; }
+            protected:
+                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
+                                               const NodeVector& deltas) override;
+
+            private:
+                size_t m_axis;
+            };
+        }
+
+        // default opset version
+        using v0::Softmax;
     }
 }
src/ngraph/op/topk.hpp
@@ -54,13 +54,13 @@ namespace ngraph
            ///                           supported
            /// \param k Number of top indices to compute. Compute all indices if k = 0
            /// \param compute_max Compute top k max or top k min?
-           /// \param sort SortType for sorting results, default - NONE
+           /// \param sort SortType for sorting results, default - SORT_VALUES
            TopK(const Output<Node>& arg,
                 size_t top_k_axis,
                 const element::Type& index_element_type,
                 size_t k = 0,
                 bool compute_max = true,
-                SortType sort = SortType::NONE);
+                SortType sort = SortType::SORT_VALUES);

            /// \brief Constructs a TopK operation.
            ///
            /// \param arg The input tensor
@@ -69,13 +69,13 @@ namespace ngraph
            /// \param index_element_type produce indices. Currently, only int64 or int32 are
            ///                           supported
            /// \param compute_max Compute top k max or top k min?
-           /// \param sort SortType for sorting results, default - NONE
+           /// \param sort SortType for sorting results, default - SORT_VALUES
            TopK(const Output<Node>& arg,
                 const Output<Node>& k,
                 size_t top_k_axis,
                 const element::Type& index_element_type,
                 bool compute_max = true,
-                SortType sort = SortType::NONE);
+                SortType sort = SortType::SORT_VALUES);

            void validate_and_infer_types() override;
...
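The practical effect of the new default, as a sketch (not part of the diff; it relies only on the constructor above): a TopK built without an explicit SortType now returns its k results ordered by value instead of in unspecified order.

    // Hypothetical usage: top-3 along axis 1 with i32 indices, relying on the new SORT_VALUES default.
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 5});
    auto topk = std::make_shared<ngraph::op::TopK>(data, 1, ngraph::element::i32, 3);
    // Equivalent to passing ngraph::op::TopK::SortType::SORT_VALUES explicitly;
    // pass SortType::NONE to get the previous (unsorted) behaviour back.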
src/ngraph/pass/opset1_upgrade.cpp (new file)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/softmax.hpp"
using namespace std;
using namespace ngraph;

#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP

#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP

static OP_TYPEID get_typeid(shared_ptr<Node> node)
{
    OP_TYPEID type_id;
    auto it = typeid_map.find(node->description());
    if (it != typeid_map.end())
    {
        type_id = it->second;
    }
    else
    {
        throw unsupported_op("Unsupported op '" + node->description() + "'");
    }
    return type_id;
}
// END mapping to OP_TYPEID

bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
    bool modified = false;
    size_t op_version = node->get_version();

    if (op_version == 1)
    {
        return modified;
    }

    NGRAPH_CHECK(op_version == 0,
                 "Op version 1 transformation pass failed for ",
                 *node,
                 ", only op version 0 operations expected. Op version ",
                 op_version,
                 " found.");

// Not all enumeration values explicitly handled in switch
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
    switch (get_typeid(node))
    {
    case OP_TYPEID::Softmax:
    {
        auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
        AxisSet axes = tmp->get_axes();

        NGRAPH_CHECK(
            axes.size() == 1,
            "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
            *node);

        auto replacement_node =
            make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
        replace_node(node, replacement_node);
        modified = true;
        break;
    }
    default: break;
    }
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
    return modified;
}
src/ngraph/pass/opset1_upgrade.hpp (new file)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
    namespace pass
    {
        class Opset1Upgrade : public NodePass
        {
        public:
            ///
            /// \brief Constructor for the Opset 1 transformation pass.
            ///
            /// \details This transformation pass iterates over all nodes in a graph
            /// and updates opset version 0 ops to their opset version 1 equivalents.
            /// All ops in the final graph have opset version 1.
            Opset1Upgrade() = default;
            bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
        };
    }
}
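A minimal sketch of running the pass (it mirrors the test in test/opset_pass/softmax_opset_pass.cpp further down this page; the graph-building calls are taken from those tests, not from this header):

    // Build a small graph with an opset-0 Softmax over a single axis, then upgrade it in place.
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4});
    auto sm = std::make_shared<ngraph::op::v0::Softmax>(arg, ngraph::AxisSet{2});
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sm}, ngraph::ParameterVector{arg});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::Opset1Upgrade>();
    manager.run_passes(f); // the v0 Softmax is replaced by op::v1::Softmax with axis = 2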
src/ngraph/runtime/cpu/builder/topk.cpp
@@ -52,6 +52,7 @@
             auto out_shape = out[0].get_shape();
             auto k = topk->get_k();
             auto compute_max = topk->get_compute_max();
+            auto sort = topk->get_sort();
             auto element_type = args[0].get_element_type();
             if (element_type == element::f32)
@@ -64,6 +65,7 @@
                     axis,
                     k,
                     compute_max,
+                    sort,
                     arg_buffer_index,
                     out_indices_buffer_index,
                     out_values_buffer_index](CPURuntimeContext* ctx,
@@ -76,7 +78,8 @@
                         out_shape,
                         axis,
                         k,
-                        compute_max);
+                        compute_max,
+                        sort);
                 };
             }
             else

The same two-line change (capture `sort` in the functor and forward it to the reference kernel) repeats in every other element-type/index-type branch of this builder, in the hunks at
@@ -87,6 +90,7 @@ / @@ -99,7 +103,8 @@,
@@ -113,6 +118,7 @@ / @@ -125,7 +131,8 @@,
@@ -136,6 +143,7 @@ / @@ -148,7 +156,8 @@,
@@ -162,6 +171,7 @@ / @@ -174,7 +184,8 @@, and
@@ -185,6 +196,7 @@ / @@ -197,7 +209,8 @@.
src/ngraph/runtime/cpu/cpu_emitter.hpp
@@ -142,7 +142,6 @@
                 class SigmoidBackprop;
                 class SigmoidMultiply;
                 class SigmoidMultiplyBackprop;
-                class Softmax;
                 class Result;
                 class And;
                 class Or;
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
@@ -1244,6 +1244,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     REGISTER_KNOBBED_PASS_WITH_ARGS(
         CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map());
     REGISTER_KNOBBED_PASS(CPUPostLayoutOptimizations, true, runtime::cpu::pass);
+    REGISTER_KNOBBED_PASS(CPUConvertLayoutConstantFolding, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUMemoryOptimization, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(GetOutputElementElimination, false, ngraph::pass);
     REGISTER_KNOBBED_PASS_WITH_ARGS(
...
src/ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.cpp
(diff collapsed)
src/ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp
@@ -27,6 +27,7 @@
         namespace pass
         {
             class CPUPostLayoutOptimizations;
+            class CPUConvertLayoutConstantFolding;
         }
     }
 }
@@ -47,3 +48,11 @@ public:
     void construct_slice_convertLayout_fusion();
     void construct_reshape_convertLayout_fusion();
 };
+
+class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding
+    : public ngraph::pass::FunctionPass
+{
+public:
+    CPUConvertLayoutConstantFolding() {}
+    virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
+};
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
@@ -89,6 +89,7 @@
 #include "ngraph/op/fused/group_conv_transpose.hpp"
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
+#include "ngraph/op/fused/matmul.hpp"
 #include "ngraph/op/fused/mvn.hpp"
 #include "ngraph/op/fused/normalize_l2.hpp"
 #include "ngraph/op/fused/rnn_cell.hpp"
@@ -2072,6 +2073,7 @@ shared_ptr<runtime::Executable>
     case OP_TYPEID::GRUCell:
     case OP_TYPEID::HardSigmoid:
     case OP_TYPEID::LSTMCell:
+    case OP_TYPEID::MatMul:
     case OP_TYPEID::MVN:
     case OP_TYPEID::NormalizeL2:
     case OP_TYPEID::PRelu:
@@ -2197,6 +2199,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
     case OP_TYPEID::GroupConvolutionTranspose:
     case OP_TYPEID::GRUCell:
     case OP_TYPEID::LSTMCell:
+    case OP_TYPEID::MatMul:
     case OP_TYPEID::MVN:
     case OP_TYPEID::NormalizeL2:
     case OP_TYPEID::PRelu:
...
src/ngraph/runtime/interpreter/int_executable.hpp
@@ -243,6 +243,16 @@ private:
     {
         const Node& node = *node_wrapper.get_node();

+        size_t op_version = node.get_version();
+        bool is_op_version_supported = op_version == 0;
+        NGRAPH_CHECK(is_op_version_supported,
+                     "Unsupported operator version ",
+                     op_version,
+                     " in ",
+                     node,
+                     ".\n",
+                     "INTERPRETER backend currently only supports op in version 0.");
+
         // We want to check that every OP_TYPEID enumeration is included in the list.
         // These GCC flags enable compile-time checking so that if an enumeration
         // is not in the list an error is generated.
@@ -1724,7 +1734,8 @@ private:
                                        node.get_output_shape(0),
                                        topk->get_top_k_axis(),
                                        topk->get_k(),
-                                       topk->get_compute_max());
+                                       topk->get_compute_max(),
+                                       topk->get_sort());
         }
         else if (node.get_output_element_type(0) == element::i32)
         {
@@ -1735,7 +1746,8 @@ private:
                                        node.get_output_shape(0),
                                        topk->get_top_k_axis(),
                                        topk->get_k(),
-                                       topk->get_compute_max());
+                                       topk->get_compute_max(),
+                                       topk->get_sort());
         }
         else
         {
...
src/ngraph/runtime/plaidml/unit_test.manifest
@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK
 topk_int64 # No plans to implement TopK
 topk_5d_max_partial # No plans to implement TopK
 topk_1d_i32_max_all # No plans to implement TopK
+topk_resnet50 # No plans to implement TopK
+topk_max_sort_none # No plans to implement TopK
+topk_min_sort_none # No plans to implement TopK
+topk_max_sort_value # No plans to implement TopK
+topk_min_sort_value # No plans to implement TopK
+topk_max_sort_index # No plans to implement TopK
+topk_min_sort_index # No plans to implement TopK
 topk_2d_max_one_with_equal_values # No plans to implement TopK
 model_top_k # No plans to implement TopK
@@ -254,6 +261,10 @@ dot_2x0_0
 auto_bcast_binary_elementwise
 max_pool_2d_1channel_1image_overpadded

+# passes locally, fails in CI
+numeric_float_nan
+fake_quantize_with_clip_across_channels
+
 # axes input param not supported
 lrn_across_h
 lrn_across_hw
...
src/ngraph/runtime/reference/topk.hpp
@@ -21,6 +21,7 @@
 #include <numeric>

 #include "ngraph/coordinate_transform.hpp"
+#include "ngraph/op/topk.hpp"

 namespace ngraph
 {
@@ -46,15 +47,28 @@ namespace ngraph
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
                 return a > b;
             }

             template <typename T, typename U>
             inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
             {
                 return a < b;
             }

+            template <typename T, typename U>
+            inline bool sort_indices_descending(const std::tuple<T, U>& a,
+                                                const std::tuple<T, U>& b)
+            {
+                return std::get<1>(a) < std::get<1>(b);
+            }
+
+            template <typename T, typename U>
+            inline bool sort_indices_ascending(const std::tuple<T, U>& a,
+                                               const std::tuple<T, U>& b)
+            {
+                return std::get<1>(a) > std::get<1>(b);
+            }
+
             template <typename T, typename U>
             void topk(const T* arg,
                       U* out_indices,
@@ -63,7 +77,8 @@ namespace ngraph
                       const Shape& out_shape,
                       size_t axis,
                       size_t k,
-                      bool compute_max)
+                      bool compute_max,
+                      op::TopK::SortType sort = op::TopK::SortType::NONE)
             {
                 using namespace std;
                 // reorder source axis visit order and make "axis" inner most
@@ -103,13 +118,49 @@ namespace ngraph
                 // Sort the temp vector
                 if (compute_max)
                 {
-                    sort(workspace.begin(), workspace.end(), compare_max<T, U>);
+                    nth_element(workspace.begin(),
+                                workspace.begin() + k,
+                                workspace.end(),
+                                compare_max<T, U>);
                 }
                 else
                 {
-                    sort(workspace.begin(), workspace.end(), compare_min<T, U>);
+                    nth_element(workspace.begin(),
+                                workspace.begin() + k,
+                                workspace.end(),
+                                compare_min<T, U>);
                 }
                 // Write temp vector to output
+                if (compute_max)
+                {
+                    switch (sort)
+                    {
+                    case op::TopK::SortType::NONE: break;
+                    case op::TopK::SortType::SORT_INDICES:
+                        std::sort(workspace.begin(),
+                                  workspace.begin() + k,
+                                  sort_indices_descending<T, U>);
+                        break;
+                    case op::TopK::SortType::SORT_VALUES:
+                        std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
+                        break;
+                    }
+                }
+                else
+                {
+                    switch (sort)
+                    {
+                    case op::TopK::SortType::NONE: break;
+                    case op::TopK::SortType::SORT_INDICES:
+                        std::sort(workspace.begin(),
+                                  workspace.begin() + k,
+                                  sort_indices_ascending<T, U>);
+                        break;
+                    case op::TopK::SortType::SORT_VALUES:
+                        std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
+                        break;
+                    }
+                }
                 for (size_t j = 0; j < k; j++)
                 {
                     tuple<T, U> entry = workspace[j];
...
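The two new comparators order the (value, index) workspace tuples by their index member, so SORT_INDICES returns the top-k entries in original-index order while SORT_VALUES keeps them in value order. A standalone sketch of the same idea (plain C++, hypothetical names, not tied to nGraph types):

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <tuple>
    #include <vector>

    int main()
    {
        // (value, original index) pairs, as the reference kernel stores them in its workspace.
        std::vector<std::tuple<float, int>> ws{{7.f, 0}, {9.f, 1}, {3.f, 2}, {8.f, 3}};
        const std::size_t k = 2;
        // Partition so the k largest values come first (order among them unspecified)...
        std::nth_element(ws.begin(), ws.begin() + k, ws.end(), std::greater<>());
        // ...then order just those k entries: by value (descending) for SORT_VALUES,
        std::sort(ws.begin(), ws.begin() + k, std::greater<>());
        // or by original index (ascending) for SORT_INDICES:
        // std::sort(ws.begin(), ws.begin() + k,
        //           [](auto& a, auto& b) { return std::get<1>(a) < std::get<1>(b); });
        for (std::size_t j = 0; j < k; ++j)
            std::printf("value=%g index=%d\n", std::get<0>(ws[j]), std::get<1>(ws[j]));
    }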
src/ngraph/serializer.cpp
@@ -80,6 +80,7 @@
 #include "ngraph/op/fused/gru_cell.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
 #include "ngraph/op/fused/lstm_cell.hpp"
+#include "ngraph/op/fused/matmul.hpp"
 #include "ngraph/op/fused/mvn.hpp"
 #include "ngraph/op/fused/normalize_l2.hpp"
 #include "ngraph/op/fused/prelu.hpp"
@@ -554,7 +555,7 @@ json JSONSerializer::serialize_function(const Function& f)
 template <typename T>
 T get_value(json js, const string& key)
 {
-    T rc;
+    T rc = {};
     auto it = js.find(key);
     if (it != js.end())
     {
@@ -719,15 +720,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
     string node_name = node_js.at("name").get<string>();
     string node_op = node_js.at("op").get<string>();
     string friendly_name = get_value<string>(node_js, "friendly_name");
+    size_t op_version = get_value<size_t>(node_js, "op_version");
     vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
     vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
     OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 // #pragma GCC diagnostic error "-Wimplicit-fallthrough"
 #endif
     switch (get_typeid(node_op))
     {
     case OP_TYPEID::Abs:
@@ -1399,6 +1403,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
             input_forget);
         break;
     }
+    case OP_TYPEID::MatMul:
+    {
+        bool transpose_a = node_js.at("transpose_a").get<bool>();
+        bool transpose_b = node_js.at("transpose_b").get<bool>();
+        node = make_shared<op::MatMul>(args[0], args[1], transpose_a, transpose_b);
+        break;
+    }
     case OP_TYPEID::Max:
     {
         auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
@@ -1831,8 +1842,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
     }
     case OP_TYPEID::Softmax:
     {
-        auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
-        node = make_shared<op::Softmax>(args[0], softmax_axes);
+        if (op_version == 0)
+        {
+            auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
+            node = make_shared<op::Softmax>(args[0], softmax_axes);
+        }
+        if (op_version == 1)
+        {
+            size_t softmax_axis = node_js.at("softmax_axis");
+            node = make_shared<op::v1::Softmax>(args[0], softmax_axis);
+        }
         break;
     }
     case OP_TYPEID::SpaceToDepth:
@@ -2028,6 +2047,9 @@ json JSONSerializer::serialize_node(const Node& n)
     m_nodes_serialized.insert(&n);
     json node;
     node["name"] = n.get_name();
+    auto op_version = n.get_version();
+    node["op_version"] = op_version;
+
     if (n.get_name() != n.get_friendly_name())
     {
         node["friendly_name"] = n.get_friendly_name();
@@ -2543,6 +2565,13 @@ json JSONSerializer::serialize_node(const Node& n)
         node["input_forget"] = tmp->get_input_forget();
         break;
     }
+    case OP_TYPEID::MatMul:
+    {
+        auto tmp = dynamic_cast<const op::MatMul*>(&n);
+        node["transpose_a"] = tmp->get_transpose_a();
+        node["transpose_b"] = tmp->get_transpose_b();
+        break;
+    }
     case OP_TYPEID::Max:
     {
         auto tmp = dynamic_cast<const op::Max*>(&n);
@@ -2881,8 +2910,16 @@ json JSONSerializer::serialize_node(const Node& n)
     }
     case OP_TYPEID::Softmax:
     {
-        auto tmp = dynamic_cast<const op::Softmax*>(&n);
-        node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
+        if (op_version == 0)
+        {
+            auto tmp = dynamic_cast<const op::v0::Softmax*>(&n);
+            node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
+        }
+        if (op_version == 1)
+        {
+            auto tmp = dynamic_cast<const op::v1::Softmax*>(&n);
+            node["softmax_axis"] = tmp->get_axis();
+        }
         break;
     }
     case OP_TYPEID::Tan: { break;
...
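What the versioned serialization means in practice, as a sketch (the field names are the ones written by the code above; everything else mirrors TEST(serialize, opset1_softmax) further down rather than being new API):

    // Round-trip a v1 Softmax through the JSON serializer.
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{10});
    auto sm = std::make_shared<ngraph::op::v1::Softmax>(arg, 0);
    auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sm},
                                                ngraph::ParameterVector{arg});
    std::string js = ngraph::serialize(f); // the node carries "op_version": 1 and "softmax_axis": 0
    auto g = ngraph::deserialize(js);      // rebuilt as op::v1::Softmax, not op::v0::Softmax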
test/CMakeLists.txt
@@ -64,6 +64,7 @@ set(SRC
     node_input_output.cpp
     nop_elimination.cpp
     op.cpp
+    opset_pass/softmax_opset_pass.cpp
     partial_shape.cpp
     pass.cpp
     pass_liveness.cpp
@@ -113,6 +114,7 @@ set(SRC
     type_prop/index_reduction.cpp
     type_prop/lrn.cpp
     type_prop/lstm_cell.cpp
+    type_prop/matmul.cpp
     type_prop/max_pool.cpp
     type_prop/mvn.cpp
     type_prop/normalize.cpp
...
test/backend/topk.in.cpp
(diff collapsed)
test/cpu_test.cpp
@@ -1059,6 +1059,29 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items)
     unset_environment("NGRAPH_CPU_CONCURRENCY");
 }

+TEST(cpu_test, constant_convertlayout)
+{
+    Shape data_shape{1, 64, 56, 56};
+    auto data = make_shared<op::Parameter>(element::f32, data_shape);
+    Shape weights_shape{64, 64, 3, 3};
+    test::Uniform<float> rng(-100.0f, 100.0f);
+    vector<float> values_in(shape_size(weights_shape));
+    rng.initialize(values_in);
+    auto weights = make_shared<op::Constant>(element::f32, weights_shape, values_in);
+    Shape bias_shape{64};
+    auto bias = make_shared<op::Parameter>(element::f32, bias_shape);
+
+    auto conv = std::make_shared<op::Convolution>(data, weights, Strides{1, 1}, Strides{1, 1});
+    auto convbias = make_shared<op::ConvolutionBias>(conv, bias);
+    auto f = make_shared<Function>(convbias, ParameterVector{data, bias});
+
+    auto backend = runtime::Backend::create("CPU");
+    auto handle = backend->compile(f);
+
+    size_t convert_layout = count_ops_of_type<runtime::cpu::op::ConvertLayout>(f);
+    ASSERT_EQ(convert_layout, 1);
+}
+
 TEST(cpu_test, constant_reshape)
 {
     Shape shape_in{2, 4};
...
test/opset_pass/softmax_opset_pass.cpp (new file)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;

TEST(serialize, opset1_softmax_pass_axis)
{
    const size_t axis = 2;
    const AxisSet axes{axis};
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto softmax_s0 = make_shared<op::v0::Softmax>(arg, axes);
    auto result = make_shared<op::Result>(softmax_s0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    auto softmax_s1_result = f->get_results().at(0);
    auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr();
    auto softmax_s1_node = static_pointer_cast<op::v1::Softmax>(node);

    EXPECT_EQ(softmax_s1_node->get_axis(), axis);
    EXPECT_EQ(softmax_s1_node->description(), "Softmax");
    EXPECT_EQ(softmax_s1_node->get_version(), 1);
}

TEST(serialize, opset1_softmax_pass_axis_exception)
{
    const AxisSet axes{1, 2};
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto softmax_s0 = make_shared<op::v0::Softmax>(arg, axes);
    auto result = make_shared<op::Result>(softmax_s0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();

    try
    {
        pass_manager.run_passes(f);
        FAIL() << "Exception after Opset1Upgrade pass was not thrown.";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string(
                "Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis."));
    }
    catch (...)
    {
        FAIL() << "Softmax pass failed for unexpected reason";
    }
}

namespace fake_v2
{
    class FakeSoftmax : public op::v0::Softmax
    {
    public:
        FakeSoftmax(const Output<Node>& arg, const AxisSet& axes)
            : Softmax{arg, axes}
        {
        }
        size_t get_version() const override { return 2; }
    };
}

TEST(serialize, opset1_softmax_pass_incorrect_op_version)
{
    const AxisSet axes{2};
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto softmax_s2 = make_shared<fake_v2::FakeSoftmax>(arg, axes);
    auto result = make_shared<op::Result>(softmax_s2);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();

    try
    {
        pass_manager.run_passes(f);
        FAIL() << "Opset 1 transformation pass failed for";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Op version 1 transformation pass failed for"));
    }
    catch (...)
    {
        FAIL() << "Softmax pass failed for unexpected reason";
    }
}
test/serialize.cpp
@@ -340,3 +340,19 @@ TEST(serialize, non_zero_node_output)
     EXPECT_EQ(topk_out.get_index(), 1);
     EXPECT_EQ(topk_out.get_node()->description(), "TopK");
 }
+
+TEST(serialize, opset1_softmax)
+{
+    auto arg = make_shared<op::Parameter>(element::f32, Shape{10});
+    auto softmax = make_shared<op::v1::Softmax>(arg, 0);
+    auto result = make_shared<op::Result>(softmax);
+    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
+    string s = serialize(f);
+
+    shared_ptr<Function> g = deserialize(s);
+    auto g_result = g->get_results().at(0);
+    auto g_softmax = g_result->input(0).get_source_output().get_node_shared_ptr();
+
+    EXPECT_EQ(g_softmax->description(), "Softmax");
+    EXPECT_EQ(g_softmax->get_version(), 1);
+}
test/type_prop/matmul.cpp (new file)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;

TEST(type_prop, matmul_2D_same)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2});

    auto matmul = make_shared<op::MatMul>(A, B);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{2, 2}));
}

TEST(type_prop, matmul_4D_same)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 3});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 3});

    auto matmul = make_shared<op::MatMul>(A, B);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 3}));
}

TEST(type_prop, matmul_2D)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
    auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});

    auto matmul = make_shared<op::MatMul>(A, B);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}

TEST(type_prop, matmul_4D)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 6});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 4});

    auto matmul = make_shared<op::MatMul>(A, B);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}

TEST(type_prop, matmul_2D_transpose_a)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{6, 3});
    auto B = make_shared<op::Parameter>(element::f32, Shape{6, 4});

    auto matmul = make_shared<op::MatMul>(A, B, 1);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}

TEST(type_prop, matmul_4D_transpose_a)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 3});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 6, 4});

    auto matmul = make_shared<op::MatMul>(A, B, 1);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}

TEST(type_prop, matmul_2D_transpose_b)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
    auto B = make_shared<op::Parameter>(element::f32, Shape{4, 6});

    auto matmul = make_shared<op::MatMul>(A, B, 0, 1);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{3, 4}));
}

TEST(type_prop, matmul_4D_transpose_b)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3, 6});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4, 6});

    auto matmul = make_shared<op::MatMul>(A, B, 0, 1);

    ASSERT_EQ(matmul->get_element_type(), element::f32);
    ASSERT_EQ(matmul->get_shape(), (Shape{2, 2, 3, 4}));
}
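For context, a sketch of the fused MatMul op these tests exercise (the constructor arguments are inferred from the tests above and the serializer changes earlier on this page; treat the exact signature as an assumption):

    // MatMul(A, B, transpose_a, transpose_b): batched matrix product with optional input transposes.
    auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{6, 3});
    auto B = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{6, 4});
    auto mm = std::make_shared<ngraph::op::MatMul>(A, B, /*transpose_a=*/true, /*transpose_b=*/false);
    // Output shape is {3, 4}: A is read as 3x6 after the transpose and B as 6x4,
    // matching the matmul_2D_transpose_a test above.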
test/type_prop/top_k.cpp
@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
     ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
     ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
     ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
-    ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE);
+    ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::SORT_VALUES);
 }

 TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
...