Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
bc448701
Commit
bc448701
authored
Nov 15, 2019
by
Gleb Kazantaev
Committed by
Scott Cyphers
Nov 15, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added constant folding for binary ops (#3895)
parent
1d53977a
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
215 additions
and
38 deletions
+215
-38
constant_folding_binary.cpp
src/ngraph/pass/constant_folding_binary.cpp
+215
-38
No files found.
src/ngraph/pass/constant_folding_binary.cpp
View file @
bc448701
...
...
@@ -74,7 +74,7 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
}
else
{
if
(
auto
and_
node
=
as_type_ptr
<
op
::
And
>
(
binary
))
if
(
auto
and_
v0_node
=
as_type_ptr
<
op
::
v0
::
And
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
logical_and
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
...
...
@@ -82,21 +82,21 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
and_node
->
get_autob
());
and_
v0_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
logical_
xor_node
=
as_type_ptr
<
op
::
v1
::
LogicalXor
>
(
binary
))
else
if
(
auto
logical_
and_node
=
as_type_ptr
<
op
::
v1
::
LogicalAnd
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
logical_
xor
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
runtime
::
reference
::
logical_
and
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
b
->
get_data_ptr
<
char
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
logical_
xor
_node
->
get_autob
());
logical_
and
_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
or_node
=
as_type_ptr
<
op
::
Or
>
(
binary
))
else
if
(
auto
or_node
=
as_type_ptr
<
op
::
v0
::
Or
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
logical_or
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
...
...
@@ -107,6 +107,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
or_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
logical_or_node
=
as_type_ptr
<
op
::
v1
::
LogicalOr
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
logical_or
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
b
->
get_data_ptr
<
char
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
logical_or_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
xor_node
=
as_type_ptr
<
op
::
v0
::
Xor
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
...
...
@@ -118,6 +129,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
xor_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
logical_xor_node
=
as_type_ptr
<
op
::
v1
::
LogicalXor
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
logical_xor
<
char
>
(
a
->
get_data_ptr
<
char
>
(),
b
->
get_data_ptr
<
char
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
logical_xor_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
{
NGRAPH_CHECK
(
...
...
@@ -151,7 +173,18 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
}
else
{
if
(
auto
equal_node
=
as_type_ptr
<
op
::
Equal
>
(
binary
))
if
(
auto
equal_v0_node
=
as_type_ptr
<
op
::
v0
::
Equal
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
equal
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
equal_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
equal_v1_node
=
as_type_ptr
<
op
::
v1
::
Equal
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
equal
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -159,10 +192,10 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
equal_node
->
get_autob
());
equal_
v1_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
greater_
node
=
as_type_ptr
<
op
::
Greater
>
(
binary
))
else
if
(
auto
greater_
v0_node
=
as_type_ptr
<
op
::
v0
::
Greater
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
greater
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -170,10 +203,32 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
greater_node
->
get_autob
());
greater_
v0_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
greater_eq_node
=
as_type_ptr
<
op
::
GreaterEq
>
(
binary
))
else
if
(
auto
greater_v1_node
=
as_type_ptr
<
op
::
v1
::
Greater
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
greater
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
greater_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
greater_eq_v0_node
=
as_type_ptr
<
op
::
v0
::
GreaterEq
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
greater_eq
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
greater_eq_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
greater_eq_v1_node
=
as_type_ptr
<
op
::
v1
::
GreaterEq
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
greater_eq
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -181,10 +236,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
greater_eq_node
->
get_autob
());
greater_eq_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
less_v0_node
=
as_type_ptr
<
op
::
v0
::
Less
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
less
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
less_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
less_
node
=
as_type_ptr
<
op
::
Less
>
(
binary
))
else
if
(
auto
less_
v1_node
=
as_type_ptr
<
op
::
v1
::
Less
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
less
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -192,10 +258,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
less_node
->
get_autob
());
less_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
less_eq_v0_node
=
as_type_ptr
<
op
::
v0
::
LessEq
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
less_eq
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
less_eq_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
less_eq_
node
=
as_type_ptr
<
op
::
LessEq
>
(
binary
))
else
if
(
auto
less_eq_
v1_node
=
as_type_ptr
<
op
::
v1
::
LessEqual
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
less_eq
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -203,10 +280,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
less_eq_node
->
get_autob
());
less_eq_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
not_equal_v0_node
=
as_type_ptr
<
op
::
v0
::
NotEqual
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
not_equal
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
not_equal_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
not_equal_
node
=
as_type_ptr
<
op
::
NotEqual
>
(
binary
))
else
if
(
auto
not_equal_
v1_node
=
as_type_ptr
<
op
::
v1
::
NotEqual
>
(
binary
))
{
vector
<
char
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
not_equal
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
...
...
@@ -214,7 +302,7 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
not_equal_node
->
get_autob
());
not_equal_
v1_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
...
...
@@ -249,7 +337,7 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
}
else
{
if
(
auto
add_
node
=
as_type_ptr
<
op
::
Add
>
(
binary
))
if
(
auto
add_
v0_node
=
as_type_ptr
<
op
::
v0
::
Add
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
...
...
@@ -259,26 +347,55 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
add_node
->
get_autob
());
add_
v0_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
divide_node
=
as_type_ptr
<
op
::
Divide
>
(
binary
))
else
if
(
auto
add_v1_node
=
as_type_ptr
<
op
::
v1
::
Add
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
shared_ptr
<
op
::
Divide
>
divop
=
as_type_ptr
<
op
::
Divide
>
(
binary
);
runtime
::
reference
::
add
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
add_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
divide_v0_node
=
as_type_ptr
<
op
::
v0
::
Divide
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
shared_ptr
<
op
::
v0
::
Divide
>
divop
=
as_type_ptr
<
op
::
v0
::
Divide
>
(
binary
);
bool
pythondiv
=
divop
->
is_pythondiv
();
runtime
::
reference
::
divide
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
divide_node
->
get_autob
(),
divide_
v0_
node
->
get_autob
(),
pythondiv
);
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
maximum_node
=
as_type_ptr
<
op
::
Maximum
>
(
binary
))
else
if
(
auto
divide_v1_node
=
as_type_ptr
<
op
::
v1
::
Divide
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
shared_ptr
<
op
::
v1
::
Divide
>
divop
=
as_type_ptr
<
op
::
v1
::
Divide
>
(
binary
);
bool
pythondiv
=
divop
->
is_pythondiv
();
runtime
::
reference
::
divide
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
divide_v1_node
->
get_autob
(),
pythondiv
);
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
maximum_v0_node
=
as_type_ptr
<
op
::
v0
::
Maximum
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
...
...
@@ -288,10 +405,36 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
maximum_node
->
get_autob
());
maximum_
v0_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
minimum_node
=
as_type_ptr
<
op
::
Minimum
>
(
binary
))
else
if
(
auto
maximum_v1_node
=
as_type_ptr
<
op
::
v1
::
Maximum
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
maximum
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
maximum_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
minimum_v0_node
=
as_type_ptr
<
op
::
v0
::
Minimum
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
minimum
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
minimum_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
minimum_v1_node
=
as_type_ptr
<
op
::
v1
::
Minimum
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
...
...
@@ -301,10 +444,23 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
minimum_node
->
get_autob
());
minimum_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
multiply_v0_node
=
as_type_ptr
<
op
::
v0
::
Multiply
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
runtime
::
reference
::
multiply
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
multiply_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
multiply_
node
=
as_type_ptr
<
op
::
Multiply
>
(
binary
))
else
if
(
auto
multiply_
v1_node
=
as_type_ptr
<
op
::
v1
::
Multiply
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
...
...
@@ -314,21 +470,35 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
multiply_node
->
get_autob
());
multiply_v1_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
power_v0_node
=
as_type_ptr
<
op
::
v0
::
Power
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
shared_ptr
<
op
::
v0
::
Power
>
powop
=
as_type_ptr
<
op
::
v0
::
Power
>
(
binary
);
runtime
::
reference
::
power
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
power_v0_node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
power_
node
=
as_type_ptr
<
op
::
Power
>
(
binary
))
else
if
(
auto
power_
v1_node
=
as_type_ptr
<
op
::
v1
::
Power
>
(
binary
))
{
NGRAPH_CHECK
(
element
::
from
<
Tin
>
()
==
element
::
from
<
Tout
>
(),
"Input/output types do not match"
);
vector
<
Tout
>
out_vec
(
shape_size
(
out_shape
));
shared_ptr
<
op
::
Power
>
powop
=
as_type_ptr
<
op
::
Power
>
(
binary
);
shared_ptr
<
op
::
v1
::
Power
>
powop
=
as_type_ptr
<
op
::
v1
::
Power
>
(
binary
);
runtime
::
reference
::
power
<
Tin
>
(
a
->
get_data_ptr
<
Tin
>
(),
b
->
get_data_ptr
<
Tin
>
(),
out_vec
.
data
(),
a
->
get_shape
(),
b
->
get_shape
(),
power_node
->
get_autob
());
power_
v1_
node
->
get_autob
());
return
make_shared
<
op
::
Constant
>
(
binary
->
get_element_type
(),
out_shape
,
out_vec
);
}
else
if
(
auto
subtract_node
=
as_type_ptr
<
op
::
Subtract
>
(
binary
))
...
...
@@ -375,12 +545,19 @@ shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a,
bool
is_supported_binary_op
(
std
::
shared_ptr
<
Node
>
n
)
{
return
(
is_type
<
op
::
Add
>
(
n
)
||
is_type
<
op
::
And
>
(
n
)
||
is_type
<
op
::
Divide
>
(
n
)
||
is_type
<
op
::
Equal
>
(
n
)
||
is_type
<
op
::
Greater
>
(
n
)
||
is_type
<
op
::
GreaterEq
>
(
n
)
||
is_type
<
op
::
Less
>
(
n
)
||
is_type
<
op
::
LessEq
>
(
n
)
||
is_type
<
op
::
Maximum
>
(
n
)
||
is_type
<
op
::
Minimum
>
(
n
)
||
is_type
<
op
::
Multiply
>
(
n
)
||
is_type
<
op
::
NotEqual
>
(
n
)
||
is_type
<
op
::
Or
>
(
n
)
||
is_type
<
op
::
Power
>
(
n
)
||
is_type
<
op
::
Subtract
>
(
n
)
||
is_type
<
op
::
Xor
>
(
n
));
return
(
is_type
<
op
::
v0
::
Add
>
(
n
)
||
is_type
<
op
::
v1
::
Add
>
(
n
)
||
is_type
<
op
::
v0
::
Multiply
>
(
n
)
||
is_type
<
op
::
v1
::
Multiply
>
(
n
)
||
is_type
<
op
::
v0
::
Divide
>
(
n
)
||
is_type
<
op
::
v1
::
Divide
>
(
n
)
||
is_type
<
op
::
v0
::
Power
>
(
n
)
||
is_type
<
op
::
v1
::
Power
>
(
n
)
||
is_type
<
op
::
v0
::
Equal
>
(
n
)
||
is_type
<
op
::
v1
::
Equal
>
(
n
)
||
is_type
<
op
::
v0
::
NotEqual
>
(
n
)
||
is_type
<
op
::
v1
::
NotEqual
>
(
n
)
||
is_type
<
op
::
v0
::
Greater
>
(
n
)
||
is_type
<
op
::
v1
::
Greater
>
(
n
)
||
is_type
<
op
::
v0
::
GreaterEq
>
(
n
)
||
is_type
<
op
::
v1
::
GreaterEq
>
(
n
)
||
is_type
<
op
::
v0
::
Less
>
(
n
)
||
is_type
<
op
::
v1
::
Less
>
(
n
)
||
is_type
<
op
::
v0
::
LessEq
>
(
n
)
||
is_type
<
op
::
v1
::
LessEqual
>
(
n
)
||
is_type
<
op
::
v0
::
Maximum
>
(
n
)
||
is_type
<
op
::
v1
::
Maximum
>
(
n
)
||
is_type
<
op
::
v0
::
Minimum
>
(
n
)
||
is_type
<
op
::
v1
::
Minimum
>
(
n
)
||
is_type
<
op
::
v0
::
And
>
(
n
)
||
is_type
<
op
::
v1
::
LogicalAnd
>
(
n
)
||
is_type
<
op
::
v0
::
Or
>
(
n
)
||
is_type
<
op
::
v1
::
LogicalOr
>
(
n
)
||
is_type
<
op
::
v0
::
Xor
>
(
n
)
||
is_type
<
op
::
v1
::
LogicalXor
>
(
n
)
||
is_type
<
op
::
Subtract
>
(
n
));
}
void
pass
::
ConstantFolding
::
construct_constant_binary
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment