Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
38a389d6
Commit
38a389d6
authored
5 years ago
by
Amy Zhuang
Committed by
Scott Cyphers
5 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use Eigen kernel for more cases for Gather and ScatterAdd. (#3268)
parent
0818dabc
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
7 deletions
+15
-7
gather.cpp
src/ngraph/runtime/cpu/builder/gather.cpp
+2
-2
scatter_add.cpp
src/ngraph/runtime/cpu/builder/scatter_add.cpp
+2
-2
cpu_builder.hpp
src/ngraph/runtime/cpu/cpu_builder.hpp
+8
-0
cpu_emitter.cpp
src/ngraph/runtime/cpu/cpu_emitter.cpp
+2
-2
backend_scatter.in.cpp
test/backend_scatter.in.cpp
+1
-1
No files found.
src/ngraph/runtime/cpu/builder/gather.cpp
View file @
38a389d6
...
@@ -57,7 +57,7 @@ namespace ngraph
...
@@ -57,7 +57,7 @@ namespace ngraph
args
[
0
].
get_element_type
()
==
element
::
f64
||
args
[
0
].
get_element_type
()
==
element
::
f64
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
params_shape
.
size
()
<=
3
&&
out_shape
.
size
()
<=
3
)
params_shape
.
size
()
<=
3
&&
out_shape
.
size
()
<=
5
)
{
{
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
gather_i64
<
float
,
2
,
2
>
)
>
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
gather_i64
<
float
,
2
,
2
>
)
>
kernel
;
kernel
;
...
@@ -117,7 +117,7 @@ namespace ngraph
...
@@ -117,7 +117,7 @@ namespace ngraph
args
[
0
].
get_element_type
()
==
element
::
f64
||
args
[
0
].
get_element_type
()
==
element
::
f64
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
params_shape
.
size
()
<=
3
&&
out_shape
.
size
()
<=
3
)
params_shape
.
size
()
<=
3
&&
out_shape
.
size
()
<=
5
)
{
{
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
gather_i32
<
float
,
2
,
2
>
)
>
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
gather_i32
<
float
,
2
,
2
>
)
>
kernel
;
kernel
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/cpu/builder/scatter_add.cpp
View file @
38a389d6
...
@@ -62,7 +62,7 @@ namespace ngraph
...
@@ -62,7 +62,7 @@ namespace ngraph
if
(
is_int64
)
if
(
is_int64
)
{
{
if
(
inputs_shape
.
size
()
<=
3
&&
updates_shape
.
size
()
<=
3
)
if
(
inputs_shape
.
size
()
<=
3
&&
updates_shape
.
size
()
<=
5
)
{
{
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
scatter_add_i64
<
float
,
2
,
2
>
)
>
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
scatter_add_i64
<
float
,
2
,
2
>
)
>
kernel
;
kernel
;
...
@@ -101,7 +101,7 @@ namespace ngraph
...
@@ -101,7 +101,7 @@ namespace ngraph
}
}
else
else
{
{
if
(
inputs_shape
.
size
()
<=
3
&&
updates_shape
.
size
()
<=
3
)
if
(
inputs_shape
.
size
()
<=
3
&&
updates_shape
.
size
()
<=
5
)
{
{
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
scatter_add_i32
<
float
,
2
,
2
>
)
>
std
::
function
<
decltype
(
runtime
::
cpu
::
kernel
::
scatter_add_i32
<
float
,
2
,
2
>
)
>
kernel
;
kernel
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/cpu/cpu_builder.hpp
View file @
38a389d6
...
@@ -227,6 +227,14 @@
...
@@ -227,6 +227,14 @@
{ \
{ \
SELECT_RANK1(KV, ET, R1, 3, K); \
SELECT_RANK1(KV, ET, R1, 3, K); \
} \
} \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else \
else \
{ \
{ \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/cpu/cpu_emitter.cpp
View file @
38a389d6
...
@@ -1833,7 +1833,7 @@ namespace ngraph
...
@@ -1833,7 +1833,7 @@ namespace ngraph
args
[
0
].
get_element_type
()
==
element
::
f32
||
args
[
0
].
get_element_type
()
==
element
::
f32
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_shape
().
size
()
<=
3
&&
out
[
0
].
get_shape
().
size
()
<=
3
)
args
[
0
].
get_shape
().
size
()
<=
3
&&
out
[
0
].
get_shape
().
size
()
<=
5
)
{
{
writer
<<
"cpu::kernel::gather<"
<<
args
[
0
].
get_type
()
<<
", "
writer
<<
"cpu::kernel::gather<"
<<
args
[
0
].
get_type
()
<<
", "
<<
args
[
1
].
get_element_type
().
c_type_string
()
<<
", "
<<
args
[
1
].
get_element_type
().
c_type_string
()
<<
", "
...
@@ -1897,7 +1897,7 @@ namespace ngraph
...
@@ -1897,7 +1897,7 @@ namespace ngraph
args
[
0
].
get_element_type
()
==
element
::
f32
||
args
[
0
].
get_element_type
()
==
element
::
f32
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
u8
||
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_element_type
()
==
element
::
i8
)
&&
args
[
0
].
get_shape
().
size
()
<=
3
&&
args
[
2
].
get_shape
().
size
()
<=
3
)
args
[
0
].
get_shape
().
size
()
<=
3
&&
args
[
2
].
get_shape
().
size
()
<=
5
)
{
{
writer
<<
"cpu::kernel::scatter_add<"
<<
args
[
0
].
get_type
()
<<
", "
writer
<<
"cpu::kernel::scatter_add<"
<<
args
[
0
].
get_type
()
<<
", "
<<
args
[
1
].
get_element_type
().
c_type_string
()
<<
", "
<<
args
[
1
].
get_element_type
().
c_type_string
()
<<
", "
...
...
This diff is collapsed.
Click to expand it.
test/backend_scatter.in.cpp
View file @
38a389d6
...
@@ -88,6 +88,7 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
...
@@ -88,6 +88,7 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
read_vector<float>(result),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
MIN_FLOAT_TOLERANCE_BITS));
}
}
#endif
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
scatter_add_3d_indices
)
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
scatter_add_3d_indices
)
{
{
...
@@ -123,7 +124,6 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
...
@@ -123,7 +124,6 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
read_vector
<
float
>
(
result
),
read_vector
<
float
>
(
result
),
MIN_FLOAT_TOLERANCE_BITS
));
MIN_FLOAT_TOLERANCE_BITS
));
}
}
#endif
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
scatter_add_2d_indices
)
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
scatter_add_2d_indices
)
{
{
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment