Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
56976f0c
Unverified
Commit
56976f0c
authored
Oct 15, 2019
by
Michał Karzyński
Committed by
GitHub
Oct 15, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add support for ONNX 1.5 version of TopK (#3684)
parent
d4d169f3
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
243 additions
and
15 deletions
+243
-15
topk.cpp
src/ngraph/frontend/onnx_import/op/topk.cpp
+37
-13
topk.hpp
src/ngraph/frontend/onnx_import/op/topk.hpp
+10
-2
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+1
-0
unit_test.manifest
src/ngraph/runtime/cpu/unit_test.manifest
+3
-0
unit_test.manifest
src/ngraph/runtime/interpreter/unit_test.manifest
+3
-0
unit_test.manifest
src/ngraph/runtime/plaidml/unit_test.manifest
+2
-0
top_k_opset_10.prototxt
test/models/onnx/top_k_opset_10.prototxt
+76
-0
top_k_opset_10_const_k.prototxt
test/models/onnx/top_k_opset_10_const_k.prototxt
+82
-0
onnx_import.in.cpp
test/onnx/onnx_import.in.cpp
+29
-0
No files found.
src/ngraph/frontend/onnx_import/op/topk.cpp
View file @
56976f0c
...
...
@@ -17,7 +17,6 @@
#include <cstdint>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/topk.hpp"
...
...
@@ -25,6 +24,25 @@
#include "topk.hpp"
#include "utils/common.hpp"
static
std
::
int64_t
get_axis
(
const
ngraph
::
onnx_import
::
Node
&
node
)
{
// Parse node attribute value for axis (adjust for negative value if needed).
std
::
int64_t
axis
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
)};
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
data_rank
=
data
->
get_shape
().
size
();
return
ngraph
::
onnx_import
::
common
::
validate_axis
(
node
,
axis
,
data_rank
);
}
static
ngraph
::
NodeVector
get_outputs
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
top_k
)
{
std
::
shared_ptr
<
ngraph
::
Node
>
indices
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
0
);
std
::
shared_ptr
<
ngraph
::
Node
>
values
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
1
);
return
{
values
,
indices
};
}
namespace
ngraph
{
namespace
onnx_import
...
...
@@ -37,23 +55,29 @@ namespace ngraph
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
std
::
int64_t
k
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"k"
)};
auto
num_dimensions
=
data
->
get_shape
().
size
();
std
::
int64_t
axis
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
)};
std
::
int64_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
num_dimensions
);
auto
axis
=
get_axis
(
node
);
std
::
shared_ptr
<
ngraph
::
Node
>
top_k
=
std
::
make_shared
<
ngraph
::
op
::
TopK
>
(
data
,
valid_axis
,
element
::
i64
,
k
);
std
::
shared_ptr
<
ngraph
::
Node
>
indices
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
0
);
std
::
shared_ptr
<
ngraph
::
Node
>
values
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
1
);
std
::
make_shared
<
ngraph
::
op
::
TopK
>
(
data
,
axis
,
element
::
i64
,
k
);
return
{
values
,
indices
}
;
return
get_outputs
(
top_k
)
;
}
}
}
// namespace set_1
namespace
set_10
{
NodeVector
topk
(
const
Node
&
node
)
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
k
=
node
.
get_ng_inputs
().
at
(
1
);
auto
axis
=
get_axis
(
node
);
std
::
shared_ptr
<
ngraph
::
Node
>
top_k
=
std
::
make_shared
<
ngraph
::
op
::
TopK
>
(
data
,
k
,
axis
,
element
::
i64
);
return
get_outputs
(
top_k
);
}
}
}
// namespace op
...
...
src/ngraph/frontend/onnx_import/op/topk.hpp
View file @
56976f0c
...
...
@@ -31,10 +31,18 @@ namespace ngraph
///
/// \param node The ONNX node object representing this operation.
/// \return The vector containing Ngraph nodes producing output of ONNX TopK
/// operation(both values and indices).
/// operation
(both values and indices).
NodeVector
topk
(
const
Node
&
node
);
}
}
// namespace set_1
/// \brief Performs TopK operation from ONNX version 1.5
///
/// \details ONNX op set 10 added support for K as a dynamic input, not a static
/// attribute.
namespace
set_10
{
NodeVector
topk
(
const
Node
&
node
);
}
}
// namespace op
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
56976f0c
...
...
@@ -347,6 +347,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Tanh"
,
1
,
tanh
);
REGISTER_OPERATOR
(
"ThresholdedRelu"
,
1
,
thresholded_relu
);
REGISTER_OPERATOR
(
"TopK"
,
1
,
topk
);
REGISTER_OPERATOR
(
"TopK"
,
10
,
topk
);
REGISTER_OPERATOR
(
"Transpose"
,
1
,
transpose
);
REGISTER_OPERATOR
(
"Unsqueeze"
,
1
,
unsqueeze
);
REGISTER_OPERATOR
(
"Where"
,
1
,
where
);
...
...
src/ngraph/runtime/cpu/unit_test.manifest
View file @
56976f0c
...
...
@@ -21,3 +21,6 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
# ONNX TopK with dynamic K
top_k_opset_10
src/ngraph/runtime/interpreter/unit_test.manifest
View file @
56976f0c
...
...
@@ -13,3 +13,6 @@ fake_quantize_with_clip_across_channels
# casting not supported on interpreter
convert_float32_bf16
convert_bf16_float32
# ONNX TopK with dynamic K
top_k_opset_10
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
56976f0c
...
...
@@ -48,6 +48,8 @@ topk_max_sort_index # No plans to implement TopK
topk_min_sort_index # No plans to implement TopK
topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK
top_k_opset_10 # No plans to implement TopK
top_k_opset_10_const_k # No plans to implement TopK
# unsupported op: `Erf`
erf
...
...
test/models/onnx/top_k_opset_10.prototxt
0 → 100644
View file @
56976f0c
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
}
name: "test_top_k"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
test/models/onnx/top_k_opset_10_const_k.prototxt
0 → 100644
View file @
56976f0c
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
}
name: "test_top_k"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
initializer {
dims: 1
data_type: 7
int64_data: 3
name: "k"
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
test/onnx/onnx_import.in.cpp
View file @
56976f0c
...
...
@@ -1309,6 +1309,35 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_top_k)
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
top_k_opset_10
)
{
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/top_k_opset_10.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
test_case
.
add_input
<
int64_t
>
({
3
});
test_case
.
add_expected_output
<
float
>
(
Shape
{
3
,
3
},
{
3
,
2
,
1
,
7
,
6
,
5
,
11
,
10
,
9
});
// values
test_case
.
add_expected_output
<
std
::
int64_t
>
(
Shape
{
3
,
3
},
{
3
,
2
,
1
,
3
,
2
,
1
,
3
,
2
,
1
});
// indices
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
top_k_opset_10_const_k
)
{
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/top_k_opset_10_const_k.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
test_case
.
add_expected_output
<
float
>
(
Shape
{
3
,
3
},
{
3
,
2
,
1
,
7
,
6
,
5
,
11
,
10
,
9
});
// values
test_case
.
add_expected_output
<
std
::
int64_t
>
(
Shape
{
3
,
3
},
{
3
,
2
,
1
,
3
,
2
,
1
,
3
,
2
,
1
});
// indices
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
model_sinh
)
{
auto
function
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment