Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
656dfa55
Commit
656dfa55
authored
6 years ago
by
Nick Korovaiko
Committed by
Adam Procter
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
enable cse for reduction ops (#1030)
* enable cse for reduction ops * reduction tests
parent
7d6a0d1c
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
0 deletions
+45
-0
cse.cpp
src/ngraph/pass/cse.cpp
+13
-0
cse.cpp
test/cse.cpp
+32
-0
No files found.
src/ngraph/pass/cse.cpp
View file @
656dfa55
...
...
@@ -77,6 +77,17 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
(
a
->
get_argument
(
1
)
==
b
->
get_argument
(
0
)
&&
a
->
get_argument
(
0
)
==
b
->
get_argument
(
1
));
}
static
bool
cse_reduction
(
std
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_reduction for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
auto
ar_a
=
std
::
dynamic_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
a
);
auto
ar_b
=
std
::
dynamic_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
b
);
return
ar_a
->
get_argument
(
0
)
==
ar_b
->
get_argument
(
0
)
&&
ar_a
->
get_reduction_axes
()
==
ar_b
->
get_reduction_axes
();
}
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>
initialize_ops_to_cse_handlers
()
...
...
@@ -110,6 +121,8 @@ static std::unordered_map<std::type_index,
{
TI
(
op
::
Power
),
cse_binarywise
},
//{TI(op::Remainder), cse_binarywise},
{
TI
(
op
::
Subtract
),
cse_binarywise
},
{
TI
(
op
::
Sum
),
cse_reduction
},
{
TI
(
op
::
Product
),
cse_reduction
},
});
}
...
...
This diff is collapsed.
Click to expand it.
test/cse.cpp
View file @
656dfa55
...
...
@@ -188,3 +188,35 @@ TEST(CSE, abs_add_abs_add_negative)
ASSERT_EQ
(
oadd4
->
get_argument
(
1
),
D
);
ASSERT_EQ
(
oadd3
->
get_argument
(
0
),
oadd4
->
get_argument
(
0
));
}
template
<
typename
T
>
static
void
execute_cse_reduction_test
()
{
Shape
zero_shape
{
0
};
auto
A
=
std
::
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
3
,
5
});
auto
a_reduction_op
=
std
::
make_shared
<
T
>
(
A
,
AxisSet
{
0
,
1
});
auto
a_reduction_op2
=
std
::
make_shared
<
T
>
(
A
,
AxisSet
{
0
,
1
});
auto
a_reduction_op3
=
std
::
make_shared
<
T
>
(
A
,
AxisSet
{
0
});
auto
sub_aa
=
a_reduction_op
-
a_reduction_op2
;
auto
B
=
std
::
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
3
,
5
});
auto
b_reduction_op
=
std
::
make_shared
<
T
>
(
B
,
AxisSet
{
0
,
1
});
auto
sub_ab
=
a_reduction_op
-
b_reduction_op
;
auto
f
=
std
::
make_shared
<
Function
>
(
NodeVector
{
sub_aa
,
sub_ab
,
a_reduction_op3
},
op
::
ParameterVector
{
A
,
B
});
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
ngraph
::
pass
::
CommonSubexpressionElimination
>
();
pass_manager
.
run_passes
(
f
);
ASSERT_EQ
(
sub_aa
->
get_argument
(
0
),
sub_aa
->
get_argument
(
1
));
ASSERT_NE
(
sub_ab
->
get_argument
(
0
),
sub_ab
->
get_argument
(
1
));
ASSERT_NE
(
f
->
get_results
().
at
(
2
)
->
get_argument
(
0
),
sub_aa
->
get_argument
(
0
));
}
TEST
(
CSE
,
reduction_ops
)
{
execute_cse_reduction_test
<
op
::
Sum
>
();
execute_cse_reduction_test
<
op
::
Product
>
();
}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment