Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
a6c2f23b
Commit
a6c2f23b
authored
6 years ago
by
Adam Procter
Committed by
Scott Cyphers
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add ConstantFolding for Gather (#3342)
* Add CF for Gather * Style
parent
6b90c1bd
master
v0.29.0-rc.0
v0.28.0-rc.1
v0.28.0-rc.0
v0.27.1-rc.3
v0.27.1-rc.2
v0.27.1-rc.1
v0.27.1-rc.0
v0.27.0-rc.1
v0.27.0-rc.0
v0.26.1-rc.0
v0.26.0
v0.26.0-rc.8
v0.26.0-rc.7
v0.26.0-rc.6
v0.26.0-rc.5
v0.26.0-rc.4
v0.26.0-rc.3
v0.26.0-rc.2
v0.26.0-rc.0
v0.25.1-rc.11
v0.25.1-rc.10
v0.25.1-rc.9
v0.25.1-rc.8
v0.25.1-rc.7
v0.25.1-rc.6
v0.25.1-rc.5
v0.25.1-rc.4
v0.25.1-rc.3
v0.25.1-rc.2
v0.25.1-rc.1
v0.25.1-rc.0
v0.25.0
v0.25.0-rc.3
v0.25.0-rc.2
v0.25.0-rc.1
v0.25.0-rc.0
No related merge requests found
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
180 additions
and
4 deletions
+180
-4
constant_folding.cpp
src/ngraph/pass/constant_folding.cpp
+142
-0
constant_folding.hpp
src/ngraph/pass/constant_folding.hpp
+4
-0
gather.hpp
src/ngraph/runtime/reference/gather.hpp
+5
-4
constant_folding.cpp
test/constant_folding.cpp
+29
-0
No files found.
src/ngraph/pass/constant_folding.cpp
View file @
a6c2f23b
...
...
@@ -35,6 +35,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
...
...
@@ -70,6 +71,7 @@
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
...
...
@@ -1863,6 +1865,146 @@ void pass::ConstantFolding::construct_constant_concat()
this
->
add_matcher
(
concat_matcher
,
constant_concat_callback
,
all_pass_property_off
);
}
// "Inner" helper for fold_constant_gather, which has to switch on the indices
// element type.
template
<
typename
T
,
typename
U
>
static
shared_ptr
<
op
::
Constant
>
fold_constant_gather_helper
(
const
shared_ptr
<
op
::
Constant
>&
data
,
const
shared_ptr
<
op
::
Constant
>&
indices
,
const
shared_ptr
<
op
::
Gather
>&
gather
)
{
std
::
vector
<
T
>
result_vec
(
shape_size
(
gather
->
get_shape
()));
runtime
::
reference
::
gather
<
T
,
U
>
(
data
->
get_data_ptr
<
T
>
(),
indices
->
get_data_ptr
<
U
>
(),
result_vec
.
data
(),
data
->
get_shape
(),
indices
->
get_shape
(),
gather
->
get_shape
(),
gather
->
get_axis
());
return
make_shared
<
op
::
Constant
>
(
gather
->
get_output_element_type
(
0
),
gather
->
get_output_shape
(
0
),
result_vec
);
}
template
<
typename
T
>
static
shared_ptr
<
op
::
Constant
>
fold_constant_gather
(
const
shared_ptr
<
op
::
Constant
>&
data
,
const
shared_ptr
<
op
::
Constant
>&
indices
,
const
shared_ptr
<
op
::
Gather
>&
gather
)
{
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
indices_type
.
get_type_enum
())
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
break
;
case
element
:
:
Type_t
::
dynamic
:
NGRAPH_CHECK
(
false
,
"Encountered 'dynamic' element type in constant_gather_callback"
);
break
;
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
f16
:
case
element
:
:
Type_t
::
f32
:
case
element
:
:
Type_t
::
f64
:
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
u8
:
case
element
:
:
Type_t
::
u16
:
case
element
:
:
Type_t
::
u32
:
case
element
:
:
Type_t
::
u64
:
NGRAPH_CHECK
(
false
,
"Encountered unsupported indices element type in constant_gather_callback: "
,
indices_type
);
break
;
case
element
:
:
Type_t
::
i32
:
return
fold_constant_gather_helper
<
T
,
int32_t
>
(
data
,
indices
,
gather
);
case
element
:
:
Type_t
::
i64
:
return
fold_constant_gather_helper
<
T
,
int64_t
>
(
data
,
indices
,
gather
);
}
NGRAPH_UNREACHABLE
(
"Unhandled switch case"
);
}
void
pass
::
ConstantFolding
::
construct_constant_gather
()
{
auto
data_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
10
,
20
,
30
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
indices_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
5
},
pattern
::
has_class
<
op
::
Constant
>
());
size_t
gather_axis
=
1
;
auto
gather_op
=
make_shared
<
op
::
Gather
>
(
data_label
,
indices_label
,
gather_axis
);
auto
constant_gather_callback
=
[
data_label
,
indices_label
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for constant_gather_callback against node = "
<<
m
.
get_match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
data
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
data_label
]);
auto
indices
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
indices_label
]);
auto
gather
=
static_pointer_cast
<
op
::
Gather
>
(
m
.
get_match_root
());
std
::
shared_ptr
<
Node
>
replacement
;
auto
data_type
=
data
->
get_output_element_type
(
0
);
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
data_type
.
get_type_enum
())
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
break
;
case
element
:
:
Type_t
::
dynamic
:
NGRAPH_CHECK
(
false
,
"Encountered 'dynamic' element type in constant_gather_callback"
);
break
;
case
element
:
:
Type_t
::
boolean
:
replacement
=
fold_constant_gather
<
char
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
bf16
:
replacement
=
fold_constant_gather
<
bfloat16
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
f16
:
replacement
=
fold_constant_gather
<
float16
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
f32
:
replacement
=
fold_constant_gather
<
float
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
f64
:
replacement
=
fold_constant_gather
<
double
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
i8
:
replacement
=
fold_constant_gather
<
int8_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
i16
:
replacement
=
fold_constant_gather
<
int16_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
i32
:
replacement
=
fold_constant_gather
<
int32_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
i64
:
replacement
=
fold_constant_gather
<
int64_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
u8
:
replacement
=
fold_constant_gather
<
uint8_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
u16
:
replacement
=
fold_constant_gather
<
uint16_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
u32
:
replacement
=
fold_constant_gather
<
uint32_t
>
(
data
,
indices
,
gather
);
break
;
case
element
:
:
Type_t
::
u64
:
replacement
=
fold_constant_gather
<
uint64_t
>
(
data
,
indices
,
gather
);
break
;
}
replace_node
(
m
.
get_match_root
(),
replacement
);
return
true
;
};
auto
gather_matcher
=
make_shared
<
pattern
::
Matcher
>
(
gather_op
,
"ConstantFolding.ConstantGather"
);
this
->
add_matcher
(
gather_matcher
,
constant_gather_callback
,
PassProperty
::
REQUIRE_STATIC_SHAPE
);
}
template
<
class
T
>
shared_ptr
<
op
::
Constant
>
fold_constant_slice
(
shared_ptr
<
op
::
Constant
>
constant
,
shared_ptr
<
op
::
Slice
>
slice
)
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pass/constant_folding.hpp
View file @
a6c2f23b
...
...
@@ -45,6 +45,7 @@ public:
PRODUCT
,
SUM
,
CONCAT
,
GATHER
,
SLICE
,
DYN_SLICE
,
DYN_RESHAPE
,
...
...
@@ -68,6 +69,7 @@ public:
construct_constant_product
();
construct_constant_sum
();
construct_constant_concat
();
construct_constant_gather
();
construct_constant_slice
();
construct_constant_dyn_slice
();
construct_constant_dyn_reshape
();
...
...
@@ -98,6 +100,7 @@ public:
case
CFTransformations
:
:
PRODUCT
:
construct_constant_product
();
break
;
case
CFTransformations
:
:
SUM
:
construct_constant_sum
();
break
;
case
CFTransformations
:
:
CONCAT
:
construct_constant_concat
();
break
;
case
CFTransformations
:
:
GATHER
:
construct_constant_gather
();
break
;
case
CFTransformations
:
:
SLICE
:
construct_constant_slice
();
break
;
case
CFTransformations
:
:
DYN_SLICE
:
construct_constant_dyn_slice
();
break
;
case
CFTransformations
:
:
DYN_RESHAPE
:
construct_constant_dyn_reshape
();
break
;
...
...
@@ -120,6 +123,7 @@ private:
void
construct_constant_product
();
void
construct_constant_sum
();
void
construct_constant_concat
();
void
construct_constant_gather
();
void
construct_constant_slice
();
void
construct_constant_dyn_slice
();
void
construct_constant_dyn_reshape
();
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/reference/gather.hpp
View file @
a6c2f23b
...
...
@@ -43,8 +43,8 @@ namespace ngraph
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices'', out')
template
<
typename
T
,
typename
U
>
void
gather
(
T
*
params
,
U
*
indices
,
void
gather
(
const
T
*
params
,
const
U
*
indices
,
T
*
out
,
const
Shape
&
params_shape
,
const
Shape
&
indices_shape
,
...
...
@@ -148,13 +148,14 @@ namespace ngraph
auto
out_outer_coord_iter
=
out_outer_transform
.
begin
();
for
(
const
Coordinate
&
params_outer_coord
:
params_outer_transform
)
{
T
*
params_prime
=
&
params
[
params_outer_transform
.
index
(
params_outer_coord
)];
const
T
*
params_prime
=
&
params
[
params_outer_transform
.
index
(
params_outer_coord
)];
T
*
out_outer
=
&
out
[
out_outer_transform
.
index
(
*
out_outer_coord_iter
)];
auto
out_inner_coord_iter
=
out_inner_transform
.
begin
();
for
(
const
Coordinate
&
indices_outer_coord
:
indices_outer_transform
)
{
U
*
indices_prime
=
const
U
*
indices_prime
=
&
indices
[
indices_outer_transform
.
index
(
indices_outer_coord
)];
T
*
out_prime
=
&
out_outer
[
out_inner_transform
.
index
(
*
out_inner_coord_iter
)];
gather_nd
<
T
,
U
>
(
params_prime
,
...
...
This diff is collapsed.
Click to expand it.
test/constant_folding.cpp
View file @
a6c2f23b
...
...
@@ -739,6 +739,35 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE
(
test
::
all_close_f
(
values_out
,
values_expected
,
MIN_FLOAT_TOLERANCE_BITS
));
}
TEST
(
constant_folding
,
const_gather
)
{
auto
constant_data
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
2
,
5
},
vector
<
float
>
{
1.0
f
,
2.0
f
,
3.0
f
,
4.0
f
,
5.0
f
,
6.0
f
,
7.0
f
,
8.0
f
,
9.0
f
,
10.0
f
});
auto
constant_indices
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
4
},
vector
<
int64_t
>
{
0
,
3
,
2
,
2
});
size_t
gather_axis
=
1
;
auto
gather
=
make_shared
<
op
::
Gather
>
(
constant_data
,
constant_indices
,
gather_axis
);
auto
f
=
make_shared
<
Function
>
(
gather
,
ParameterVector
{});
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
ConstantFolding
>
();
pass_manager
.
run_passes
(
f
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
Gather
>
(
f
),
0
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
Constant
>
(
f
),
1
);
auto
new_const
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
f
->
get_results
().
at
(
0
)
->
get_argument
(
0
));
ASSERT_TRUE
(
new_const
);
auto
values_out
=
new_const
->
get_vector
<
float
>
();
vector
<
float
>
values_expected
{
1.0
f
,
4.0
f
,
3.0
f
,
3.0
f
,
6.0
f
,
9.0
f
,
8.0
f
,
8.0
f
};
ASSERT_TRUE
(
test
::
all_close_f
(
values_out
,
values_expected
,
MIN_FLOAT_TOLERANCE_BITS
));
}
TEST
(
constant_folding
,
const_slice
)
{
Shape
shape_in
{
16
};
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment