Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
13fc556e
Commit
13fc556e
authored
Mar 04, 2019
by
Robert Kimball
Committed by
Scott Cyphers
Mar 04, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add ngraph and std namespaces to c++ files (#2549)
* add ngraph and std namespaces to c++ files * style
parent
ef3378c1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
394 additions
and
407 deletions
+394
-407
algebraic_simplification.cpp
src/ngraph/pass/algebraic_simplification.cpp
+73
-89
constant_folding.cpp
src/ngraph/pass/constant_folding.cpp
+7
-7
core_fusion.cpp
src/ngraph/pass/core_fusion.cpp
+22
-25
cse.cpp
src/ngraph/pass/cse.cpp
+31
-36
graph_rewrite.cpp
src/ngraph/pass/graph_rewrite.cpp
+19
-16
like_replacement.cpp
src/ngraph/pass/like_replacement.cpp
+15
-14
liveness.cpp
src/ngraph/pass/liveness.cpp
+1
-1
manager.cpp
src/ngraph/pass/manager.cpp
+5
-5
manager_state.cpp
src/ngraph/pass/manager_state.cpp
+1
-1
memory_layout.cpp
src/ngraph/pass/memory_layout.cpp
+1
-1
memory_visualize.cpp
src/ngraph/pass/memory_visualize.cpp
+1
-1
nop_elimination.cpp
src/ngraph/pass/nop_elimination.cpp
+38
-39
pass.cpp
src/ngraph/pass/pass.cpp
+5
-2
pass_config.cpp
src/ngraph/pass/pass_config.cpp
+14
-13
prefix_reshape_elimination.cpp
src/ngraph/pass/prefix_reshape_elimination.cpp
+14
-11
propagate_cacheability.cpp
src/ngraph/pass/propagate_cacheability.cpp
+5
-4
reshape_elimination.cpp
src/ngraph/pass/reshape_elimination.cpp
+37
-35
reshape_sinking.cpp
src/ngraph/pass/reshape_sinking.cpp
+82
-84
visualize_tree.cpp
src/ngraph/pass/visualize_tree.cpp
+13
-13
zero_dim_tensor_elimination.cpp
src/ngraph/pass/zero_dim_tensor_elimination.cpp
+10
-10
No files found.
src/ngraph/pass/algebraic_simplification.cpp
View file @
13fc556e
...
...
@@ -39,29 +39,26 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x)
std::
type_index(typeid(x))
#define TI(x) type_index(typeid(x))
extern
template
ngraph
::
Shape
ngraph
::
apply_permutation
<
ngraph
::
Shape
>
(
ngraph
::
Shape
input
,
ngraph
::
AxisVector
order
);
extern
template
Shape
ngraph
::
apply_permutation
<
Shape
>
(
Shape
input
,
AxisVector
order
);
template
<
typename
T
>
static
s
td
::
s
hared_ptr
<
pattern
::
Matcher
>
create_binary_matcher
(
s
td
::
s
hared_ptr
<
pattern
::
op
::
Label
>
label
,
s
td
::
s
hared_ptr
<
pattern
::
op
::
Label
>
const_label
)
static
shared_ptr
<
pattern
::
Matcher
>
create_binary_matcher
(
shared_ptr
<
pattern
::
op
::
Label
>
label
,
shared_ptr
<
pattern
::
op
::
Label
>
const_label
)
{
auto
bcst
=
std
::
make_shared
<
pattern
::
op
::
Skip
>
(
const_label
,
pattern
::
has_class
<
op
::
Broadcast
>
());
auto
bcst_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
bcst
,
nullptr
,
NodeVector
{
bcst
});
auto
matcher
=
std
::
make_shared
<
pattern
::
Matcher
>
(
std
::
make_shared
<
T
>
(
label
,
bcst_label
),
nullptr
);
auto
bcst
=
make_shared
<
pattern
::
op
::
Skip
>
(
const_label
,
pattern
::
has_class
<
op
::
Broadcast
>
());
auto
bcst_label
=
make_shared
<
pattern
::
op
::
Label
>
(
bcst
,
nullptr
,
NodeVector
{
bcst
});
auto
matcher
=
make_shared
<
pattern
::
Matcher
>
(
make_shared
<
T
>
(
label
,
bcst_label
),
nullptr
);
return
matcher
;
}
static
std
::
shared_ptr
<
pattern
::
op
::
Label
>
get_broadcast_label
(
std
::
shared_ptr
<
pattern
::
Matcher
>
matcher
)
static
shared_ptr
<
pattern
::
op
::
Label
>
get_broadcast_label
(
shared_ptr
<
pattern
::
Matcher
>
matcher
)
{
return
std
::
dynamic_pointer_cast
<
pattern
::
op
::
Label
>
(
matcher
->
get_pattern
()
->
get_argument
(
1
));
return
dynamic_pointer_cast
<
pattern
::
op
::
Label
>
(
matcher
->
get_pattern
()
->
get_argument
(
1
));
}
//`simplify_concat` identifies slices-concat sequences
...
...
@@ -75,23 +72,21 @@ static std::shared_ptr<pattern::op::Label>
// +-------+ | +----------+ | +-----------+
// +----+slice(0..n/2)---+
// +----------+
static
bool
simplify_concat
(
s
td
::
s
hared_ptr
<
Node
>
n
)
static
bool
simplify_concat
(
shared_ptr
<
Node
>
n
)
{
NGRAPH_DEBUG
<<
"In simplify_concat for "
<<
n
->
get_name
();
s
td
::
s
hared_ptr
<
Node
>
branch_tip
;
shared_ptr
<
Node
>
branch_tip
;
auto
ltip
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i32
,
Shape
{
2
,
1
});
auto
ltip
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i32
,
Shape
{
2
,
1
});
auto
pslice
=
std
::
make_shared
<
op
::
Slice
>
(
ltip
,
Coordinate
{
0
,
0
},
Coordinate
{
2
,
1
},
Strides
{
1
,
1
});
auto
pslice
=
make_shared
<
op
::
Slice
>
(
ltip
,
Coordinate
{
0
,
0
},
Coordinate
{
2
,
1
},
Strides
{
1
,
1
});
auto
lslice
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
pslice
,
nullptr
,
NodeVector
{
pslice
});
auto
lslice
=
make_shared
<
pattern
::
op
::
Label
>
(
pslice
,
nullptr
,
NodeVector
{
pslice
});
auto
skip_reshape
=
std
::
make_shared
<
pattern
::
op
::
Skip
>
(
lslice
,
pattern
::
has_class
<
op
::
Reshape
>
());
auto
skip_reshape
=
make_shared
<
pattern
::
op
::
Skip
>
(
lslice
,
pattern
::
has_class
<
op
::
Reshape
>
());
auto
matcher
=
std
::
make_shared
<
pattern
::
Matcher
>
(
skip_reshape
,
nullptr
);
auto
matcher
=
make_shared
<
pattern
::
Matcher
>
(
skip_reshape
,
nullptr
);
Coordinate
prev_lower_bounds
;
Shape
prev_slice_shape
;
...
...
@@ -104,7 +99,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
return
false
;
}
auto
slice
=
st
d
::
st
atic_pointer_cast
<
op
::
Slice
>
(
matcher
->
get_pattern_map
()[
lslice
]);
auto
slice
=
static_pointer_cast
<
op
::
Slice
>
(
matcher
->
get_pattern_map
()[
lslice
]);
if
(
branch_tip
)
{
if
(
branch_tip
!=
matcher
->
get_pattern_map
()[
ltip
])
...
...
@@ -153,9 +148,9 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
//check that no other node uses slices and reshapes
if
(
auto
rcarg
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
carg
))
if
(
auto
rcarg
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
carg
))
{
auto
default_shape
=
ngraph
::
get_default_order
(
rcarg
->
get_argument
(
0
)
->
get_shape
());
auto
default_shape
=
get_default_order
(
rcarg
->
get_argument
(
0
)
->
get_shape
());
if
(
default_shape
!=
rcarg
->
get_input_order
())
{
NGRAPH_DEBUG
<<
carg
->
get_name
()
<<
" reshape also does transposes"
;
...
...
@@ -170,11 +165,11 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
}
auto
concat
=
st
d
::
st
atic_pointer_cast
<
op
::
Concat
>
(
n
);
auto
concat
=
static_pointer_cast
<
op
::
Concat
>
(
n
);
size_t
concat_axis
=
concat
->
get_concatenation_axis
();
auto
slice_shape
=
branch_tip
->
get_users
(
true
).
at
(
0
)
->
get_shape
();
size_t
slice_axis
=
std
::
numeric_limits
<
size_t
>::
max
();
size_t
slice_axis
=
numeric_limits
<
size_t
>::
max
();
auto
btip_shape
=
branch_tip
->
get_shape
();
...
...
@@ -191,7 +186,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
{
if
(
btip_shape
[
i
]
!=
slice_shape
[
i
])
{
if
(
slice_axis
!=
std
::
numeric_limits
<
size_t
>::
max
())
if
(
slice_axis
!=
numeric_limits
<
size_t
>::
max
())
{
// multi-axis slice + concat do not cancel
return
false
;
...
...
@@ -200,19 +195,18 @@ static bool simplify_concat(std::shared_ptr<Node> n)
}
}
if
(
slice_axis
==
std
::
numeric_limits
<
size_t
>::
max
())
if
(
slice_axis
==
numeric_limits
<
size_t
>::
max
())
{
return
false
;
}
auto
replacement
=
branch_tip
;
if
(
btip_shape
!=
n
->
get_shape
())
{
auto
default_order
=
ngraph
::
get_default_order
(
btip_shape
);
auto
default_order
=
get_default_order
(
btip_shape
);
if
(
concat_axis
==
slice_axis
)
{
// logical reshape only
replacement
=
std
::
make_shared
<
op
::
Reshape
>
(
branch_tip
,
default_order
,
concat
->
get_shape
());
replacement
=
make_shared
<
op
::
Reshape
>
(
branch_tip
,
default_order
,
concat
->
get_shape
());
}
else
{
...
...
@@ -221,30 +215,29 @@ static bool simplify_concat(std::shared_ptr<Node> n)
if
(
btip_shape
.
size
()
>=
transposed_shape
.
size
())
{
AxisVector
order
=
ngraph
::
get_default_order
(
btip_shape
);
AxisVector
order
=
get_default_order
(
btip_shape
);
auto
ax
=
order
[
slice_axis
];
order
[
slice_axis
]
=
order
[
concat_axis
];
order
[
concat_axis
]
=
ax
;
replacement
=
std
::
make_shared
<
op
::
Reshape
>
(
branch_tip
,
order
,
transposed_shape
);
replacement
=
make_shared
<
op
::
Reshape
>
(
branch_tip
,
order
,
transposed_shape
);
}
else
if
(
btip_shape
.
size
()
<
transposed_shape
.
size
())
{
// intermediate logical reshape
AxisVector
order
=
ngraph
::
get_default_order
(
transposed_shape
);
AxisVector
order
=
get_default_order
(
transposed_shape
);
auto
ax
=
order
[
slice_axis
];
order
[
slice_axis
]
=
order
[
concat_axis
];
order
[
concat_axis
]
=
ax
;
auto
output_shape
=
ngraph
::
apply_permutation
(
transposed_shape
,
order
);
auto
output_shape
=
apply_permutation
(
transposed_shape
,
order
);
auto
logical_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
branch_tip
,
default_order
,
output_shape
);
make_shared
<
op
::
Reshape
>
(
branch_tip
,
default_order
,
output_shape
);
// transpose to final concatenated shape
replacement
=
std
::
make_shared
<
op
::
Reshape
>
(
logical_reshape
,
order
,
transposed_shape
);
replacement
=
make_shared
<
op
::
Reshape
>
(
logical_reshape
,
order
,
transposed_shape
);
}
}
}
ngraph
::
replace_node
(
n
,
replacement
);
replace_node
(
n
,
replacement
);
return
true
;
}
...
...
@@ -255,15 +248,13 @@ static bool simplify_concat(std::shared_ptr<Node> n)
//a * broadcast(0) -> broadcast(0)
//a * 1 -> a
//a * broadcast(1) -> a
static
bool
simplify_multiply
(
s
td
::
s
hared_ptr
<
Node
>
n
)
static
bool
simplify_multiply
(
shared_ptr
<
Node
>
n
)
{
NGRAPH_DEBUG
<<
"In simplify_multiply for "
<<
n
->
get_name
();
auto
iconst
=
ngraph
::
make_zero
(
element
::
i32
,
Shape
{});
auto
label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iconst
);
auto
const_label_zero
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
ngraph
::
is_zero
,
NodeVector
{
iconst
});
auto
const_label_one
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
ngraph
::
is_one
,
NodeVector
{
iconst
});
auto
iconst
=
make_zero
(
element
::
i32
,
Shape
{});
auto
label
=
make_shared
<
pattern
::
op
::
Label
>
(
iconst
);
auto
const_label_zero
=
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
is_zero
,
NodeVector
{
iconst
});
auto
const_label_one
=
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
is_one
,
NodeVector
{
iconst
});
auto
matcher_const_zero
=
create_binary_matcher
<
op
::
Multiply
>
(
label
,
const_label_zero
);
auto
matcher_const_one
=
create_binary_matcher
<
op
::
Multiply
>
(
label
,
const_label_one
);
...
...
@@ -273,7 +264,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
auto
bcst_label
=
get_broadcast_label
(
matcher_const_zero
);
auto
bcst_or_cnst
=
matcher_const_zero
->
get_pattern_map
()[
bcst_label
];
NGRAPH_DEBUG
<<
" Replacing "
<<
n
->
get_name
()
<<
" with "
<<
bcst_or_cnst
->
get_name
();
ngraph
::
replace_node
(
n
,
bcst_or_cnst
);
replace_node
(
n
,
bcst_or_cnst
);
return
true
;
}
...
...
@@ -281,7 +272,7 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
{
auto
x
=
matcher_const_one
->
get_pattern_map
()[
label
];
NGRAPH_DEBUG
<<
" Replacing "
<<
n
->
get_name
()
<<
" with "
<<
x
->
get_name
();
ngraph
::
replace_node
(
n
,
x
);
replace_node
(
n
,
x
);
return
true
;
}
...
...
@@ -293,12 +284,12 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
//
//a + 0 -> a
//a + broadcast(0) -> a
static
bool
simplify_add
(
s
td
::
s
hared_ptr
<
Node
>
n
)
static
bool
simplify_add
(
shared_ptr
<
Node
>
n
)
{
NGRAPH_DEBUG
<<
"In simplify_add for "
<<
n
->
get_name
();
auto
iconst
=
ngraph
::
make_zero
(
element
::
i32
,
Shape
{});
auto
label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iconst
);
auto
const_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
nullptr
,
NodeVector
{
iconst
});
auto
iconst
=
make_zero
(
element
::
i32
,
Shape
{});
auto
label
=
make_shared
<
pattern
::
op
::
Label
>
(
iconst
);
auto
const_label
=
make_shared
<
pattern
::
op
::
Label
>
(
iconst
,
nullptr
,
NodeVector
{
iconst
});
auto
matcher
=
create_binary_matcher
<
op
::
Add
>
(
label
,
const_label
);
if
(
matcher
->
match
(
n
))
...
...
@@ -309,10 +300,10 @@ static bool simplify_add(std::shared_ptr<Node> n)
NGRAPH_DEBUG
<<
"Node "
<<
n
->
get_name
()
<<
" matched
\"
arg + 0
\"
\n
"
<<
" arg : "
<<
x
->
get_name
()
<<
" , const : "
<<
cnst
->
get_name
();
if
(
ngraph
::
is_zero
(
cnst
))
if
(
is_zero
(
cnst
))
{
NGRAPH_DEBUG
<<
" Replacing "
<<
n
->
get_name
()
<<
" with "
<<
x
->
get_name
();
ngraph
::
replace_node
(
n
,
x
);
replace_node
(
n
,
x
);
return
true
;
}
else
...
...
@@ -324,16 +315,16 @@ static bool simplify_add(std::shared_ptr<Node> n)
}
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static
bool
simplify_log
(
s
td
::
s
hared_ptr
<
Node
>
n
)
static
bool
simplify_log
(
shared_ptr
<
Node
>
n
)
{
if
(
auto
div
=
std
::
dynamic_pointer_cast
<
op
::
Divide
>
(
n
->
get_argument
(
0
)))
if
(
auto
div
=
dynamic_pointer_cast
<
op
::
Divide
>
(
n
->
get_argument
(
0
)))
{
if
(
auto
exp
=
std
::
dynamic_pointer_cast
<
op
::
Exp
>
(
div
->
get_argument
(
0
)))
if
(
auto
exp
=
dynamic_pointer_cast
<
op
::
Exp
>
(
div
->
get_argument
(
0
)))
{
auto
denom
=
div
->
get_argument
(
1
);
auto
diff
=
std
::
make_shared
<
op
::
Subtract
>
(
exp
->
get_argument
(
0
),
std
::
make_shared
<
op
::
Log
>
(
denom
));
ngraph
::
replace_node
(
n
,
diff
);
auto
diff
=
make_shared
<
op
::
Subtract
>
(
exp
->
get_argument
(
0
),
make_shared
<
op
::
Log
>
(
denom
));
replace_node
(
n
,
diff
);
return
true
;
}
}
...
...
@@ -353,16 +344,15 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
}
template
<
typename
T
>
static
s
td
::
s
hared_ptr
<
Node
>
multiply_by
(
element
::
Type
type
,
size_t
multiplier
,
s
td
::
s
hared_ptr
<
op
::
Constant
>
cnst
)
static
shared_ptr
<
Node
>
multiply_by
(
element
::
Type
type
,
size_t
multiplier
,
shared_ptr
<
op
::
Constant
>
cnst
)
{
T
sum_cnst
=
static_cast
<
T
>
(
cnst
->
get_vector
<
T
>
().
at
(
0
)
*
multiplier
);
return
op
::
Constant
::
create
<
T
>
(
type
,
Shape
{},
{
sum_cnst
});
}
template
<
typename
T
>
static
std
::
shared_ptr
<
Node
>
pow_by
(
element
::
Type
type
,
size_t
multiplier
,
std
::
shared_ptr
<
op
::
Constant
>
cnst
)
static
shared_ptr
<
Node
>
pow_by
(
element
::
Type
type
,
size_t
multiplier
,
shared_ptr
<
op
::
Constant
>
cnst
)
{
T
prod
=
static_cast
<
T
>
(
1
);
T
val
=
cnst
->
get_vector
<
T
>
().
at
(
0
);
...
...
@@ -373,7 +363,7 @@ static std::shared_ptr<Node>
return
op
::
Constant
::
create
<
T
>
(
type
,
Shape
{},
{
prod
});
}
static
s
td
::
shared_ptr
<
Node
>
get_sum_constant
(
std
::
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
static
s
hared_ptr
<
Node
>
get_sum_constant
(
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
{
if
(
cnst
->
get_element_type
()
==
element
::
i32
)
{
...
...
@@ -395,8 +385,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
return
nullptr
;
}
static
std
::
shared_ptr
<
Node
>
get_prod_constant
(
std
::
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
static
shared_ptr
<
Node
>
get_prod_constant
(
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
{
if
(
cnst
->
get_element_type
()
==
element
::
i32
)
{
...
...
@@ -423,21 +412,20 @@ static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cns
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
//product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template
<
typename
T
,
std
::
shared_ptr
<
Node
>
(
*
F
)(
std
::
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
>
static
bool
simplify_reduction
(
std
::
shared_ptr
<
Node
>
n
)
template
<
typename
T
,
shared_ptr
<
Node
>
(
*
F
)(
shared_ptr
<
op
::
Constant
>
cnst
,
size_t
multiplier
)
>
static
bool
simplify_reduction
(
shared_ptr
<
Node
>
n
)
{
NGRAPH_DEBUG
<<
"In simplify_reduction for "
<<
n
->
get_name
();
auto
reduction
=
st
d
::
st
atic_pointer_cast
<
T
>
(
n
);
auto
reduction
=
static_pointer_cast
<
T
>
(
n
);
auto
broadcast
=
std
::
dynamic_pointer_cast
<
op
::
Broadcast
>
(
n
->
get_argument
(
0
));
auto
broadcast
=
dynamic_pointer_cast
<
op
::
Broadcast
>
(
n
->
get_argument
(
0
));
if
(
!
broadcast
)
{
NGRAPH_DEBUG
<<
n
->
get_name
()
<<
" isn't Broadcast"
;
return
false
;
}
auto
cnst
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
broadcast
->
get_argument
(
0
));
auto
cnst
=
dynamic_pointer_cast
<
op
::
Constant
>
(
broadcast
->
get_argument
(
0
));
if
(
!
cnst
||
cnst
->
get_shape
().
size
()
>
0
/*not a scalar*/
)
{
NGRAPH_DEBUG
<<
broadcast
->
get_argument
(
0
)
->
get_name
()
<<
" isn't a scalar constant"
;
...
...
@@ -456,39 +444,35 @@ static bool simplify_reduction(std::shared_ptr<Node> n)
if
(
reduction
->
get_shape
().
size
()
>
0
)
{
ngraph
::
AxisSet
axes
{};
AxisSet
axes
{};
for
(
size_t
i
=
0
;
i
<
reduction
->
get_shape
().
size
();
i
++
)
{
axes
.
insert
(
i
);
}
reduction_cnst
=
std
::
make_shared
<
op
::
Broadcast
>
(
reduction_cnst
,
reduction
->
get_shape
(),
axes
);
reduction_cnst
=
make_shared
<
op
::
Broadcast
>
(
reduction_cnst
,
reduction
->
get_shape
(),
axes
);
}
ngraph
::
replace_node
(
n
,
reduction_cnst
);
replace_node
(
n
,
reduction_cnst
);
return
true
;
}
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>>
initialize_ops_to_simplifiers
()
static
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
)
>>
initialize_ops_to_simplifiers
()
{
return
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>>
(
return
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
)
>>
(
{{
TI
(
op
::
Add
),
simplify_add
},
{
TI
(
op
::
Multiply
),
simplify_multiply
},
{
TI
(
op
::
Concat
),
simplify_concat
},
{
TI
(
op
::
Sum
),
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>
{
simplify_reduction
<
op
::
Sum
,
get_sum_constant
>
}},
function
<
bool
(
shared_ptr
<
Node
>
)
>
{
simplify_reduction
<
op
::
Sum
,
get_sum_constant
>
}},
{
TI
(
op
::
Product
),
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>
{
simplify_reduction
<
op
::
Product
,
get_prod_constant
>
}},
function
<
bool
(
shared_ptr
<
Node
>
)
>
{
simplify_reduction
<
op
::
Product
,
get_prod_constant
>
}},
{
TI
(
op
::
Log
),
simplify_log
}});
}
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>>
ops_to_simplifiers
=
initialize_ops_to_simplifiers
();
static
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
)
>>
ops_to_simplifiers
=
initialize_ops_to_simplifiers
();
bool
ngraph
::
pass
::
AlgebraicSimplification
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
pass
::
AlgebraicSimplification
::
run_on_function
(
shared_ptr
<
Function
>
f
)
{
bool
replaced
=
false
;
for
(
auto
n
:
f
->
get_ordered_ops
())
...
...
src/ngraph/pass/constant_folding.cpp
View file @
13fc556e
...
...
@@ -89,7 +89,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
return
make_shared
<
op
::
Constant
>
(
constant
->
get_element_type
(),
out_shape
,
out_vec
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_pad
()
void
pass
::
ConstantFolding
::
construct_constant_pad
()
{
auto
is_constant
=
pattern
::
has_class
<
op
::
Constant
>
();
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
6
},
is_constant
);
...
...
@@ -142,7 +142,7 @@ void ngraph::pass::ConstantFolding::construct_constant_pad()
this
->
add_matcher
(
pad_matcher
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_reshape
()
void
pass
::
ConstantFolding
::
construct_constant_reshape
()
{
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
,
4
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
@@ -207,7 +207,7 @@ shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> consta
return
make_shared
<
op
::
Constant
>
(
constant
->
get_element_type
(),
out_shape
,
out_vec
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_broadcast
()
void
pass
::
ConstantFolding
::
construct_constant_broadcast
()
{
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
@@ -324,7 +324,7 @@ bool is_supported_binary_op(std::shared_ptr<Node> n)
std
::
dynamic_pointer_cast
<
op
::
Minimum
>
(
n
));
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_binary
()
void
pass
::
ConstantFolding
::
construct_constant_binary
()
{
auto
a
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
,
4
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
@@ -418,7 +418,7 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
return
make_shared
<
op
::
Constant
>
(
constant
->
get_element_type
(),
out_shape
,
out_vec
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_unary
()
void
pass
::
ConstantFolding
::
construct_constant_unary
()
{
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
,
4
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
@@ -493,7 +493,7 @@ shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> const
return
make_shared
<
op
::
Constant
>
(
dequant
->
get_element_type
(),
out_shape
,
out_vec
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_dequantize
()
void
pass
::
ConstantFolding
::
construct_constant_dequantize
()
{
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
u8
,
Shape
{
2
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
@@ -567,7 +567,7 @@ shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constan
return
make_shared
<
op
::
Constant
>
(
quant
->
get_element_type
(),
out_shape
,
out_vec
);
}
void
ngraph
::
pass
::
ConstantFolding
::
construct_constant_quantize
()
void
pass
::
ConstantFolding
::
construct_constant_quantize
()
{
auto
constant_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
},
pattern
::
has_class
<
op
::
Constant
>
());
...
...
src/ngraph/pass/core_fusion.cpp
View file @
13fc556e
...
...
@@ -69,7 +69,7 @@ void pass::CoreFusion::construct_relu()
auto
pattern_map
=
m
.
get_pattern_map
();
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
if
(
!
ngraph
::
is_zero
(
mzero
))
if
(
!
is_zero
(
mzero
))
{
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
return
false
;
...
...
@@ -77,7 +77,7 @@ void pass::CoreFusion::construct_relu()
auto
mpattern
=
m
.
get_match_root
();
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
ngraph
::
replace_node
(
m
.
get_match_root
(),
cg
);
replace_node
(
m
.
get_match_root
(),
cg
);
return
true
;
};
...
...
@@ -100,7 +100,7 @@ void pass::CoreFusion::construct_sigmoid()
auto
divide_1_over_exp
=
std
::
make_shared
<
op
::
Divide
>
(
skip_broadcast
,
add_exp
);
// Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
[
input
,
constant
](
pattern
::
Matcher
&
m
)
{
pattern
::
graph_rewrite_callback
callback
=
[
input
,
constant
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
<<
m
.
get_match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
...
...
@@ -125,12 +125,11 @@ void pass::CoreFusion::construct_sigmoid()
return
false
;
}
auto
sigmoid_node
=
std
::
make_shared
<
op
::
Sigmoid
>
(
pattern_map
[
input
]);
ngraph
::
replace_node
(
m
.
get_match_root
(),
sigmoid_node
);
replace_node
(
m
.
get_match_root
(),
sigmoid_node
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
divide_1_over_exp
,
callback
,
"CoreFusion.Sigmoid"
);
auto
m
=
std
::
make_shared
<
pattern
::
Matcher
>
(
divide_1_over_exp
,
callback
,
"CoreFusion.Sigmoid"
);
this
->
add_matcher
(
m
);
}
...
...
@@ -159,7 +158,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
auto
negative_2
=
std
::
make_shared
<
op
::
Negative
>
(
multiply_2
);
// Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
[
input
,
delta
](
pattern
::
Matcher
&
m
)
{
pattern
::
graph_rewrite_callback
callback
=
[
input
,
delta
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_bprop_sigmoid pattern against "
<<
m
.
get_match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
...
...
@@ -178,12 +177,11 @@ void pass::CoreFusion::construct_sigmoid_bprop()
}
auto
dsigmoid
=
std
::
make_shared
<
op
::
SigmoidBackprop
>
(
pattern_map
[
input
],
pattern_map
[
delta
]);
ngraph
::
replace_node
(
m
.
get_match_root
(),
dsigmoid
);
replace_node
(
m
.
get_match_root
(),
dsigmoid
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
negative_2
,
callback
,
"CoreFusion.SigmoidBprop"
);
auto
m
=
std
::
make_shared
<
pattern
::
Matcher
>
(
negative_2
,
callback
,
"CoreFusion.SigmoidBprop"
);
this
->
add_matcher
(
m
);
}
...
...
@@ -212,7 +210,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
auto
shape_r
=
Shape
{
1
,
2
,
2
,
2
};
auto
bn
=
std
::
make_shared
<
op
::
BatchNormInference
>
(
eps
,
gamma
,
beta
,
pconv
,
mean
,
var
);
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
[
input
,
filters
,
mean
,
var
,
gamma
,
beta
](
pattern
::
graph_rewrite_callback
callback
=
[
input
,
filters
,
mean
,
var
,
gamma
,
beta
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for folded batch norm against node = "
<<
m
.
get_match_root
()
->
get_name
();
...
...
@@ -258,13 +256,13 @@ void pass::CoreFusion::construct_folded_batch_norm()
m_conv
->
get_data_dilation_strides
());
auto
conv_bias
=
conv
+
std
::
make_shared
<
op
::
Broadcast
>
(
new_biases
,
conv
->
get_shape
(),
AxisSet
{
0
,
2
,
3
});
ngraph
::
replace_node
(
m
.
get_match_root
(),
conv_bias
);
replace_node
(
m
.
get_match_root
(),
conv_bias
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
bn
,
callback
,
"CoreFusion.FoldedBatchNorm"
);
auto
m
=
std
::
make_shared
<
pattern
::
Matcher
>
(
bn
,
callback
,
"CoreFusion.FoldedBatchNorm"
);
this
->
add_matcher
(
m
);
}
...
...
@@ -293,7 +291,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
auto
multiply
=
std
::
make_shared
<
op
::
Multiply
>
(
conv_label
,
A_label
);
auto
add
=
std
::
make_shared
<
op
::
Add
>
(
multiply
,
B_label
);
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
pattern
::
graph_rewrite_callback
callback
=
[
input
,
filters
,
conv_label
,
A_label
,
B_label
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for conv affine folding against node = "
<<
m
.
get_match_root
()
->
get_name
();
...
...
@@ -345,7 +343,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
if
(
bcast
->
get_argument
(
0
)
->
get_shape
().
size
()
==
2
)
{
Shape
bshape
{
bcast
->
get_argument
(
0
)
->
get_shape
()[
1
]};
return
static_pointer_cast
<
ngraph
::
Node
>
(
std
::
make_shared
<
op
::
Reshape
>
(
return
static_pointer_cast
<
Node
>
(
std
::
make_shared
<
op
::
Reshape
>
(
bcast
->
get_argument
(
0
),
AxisVector
{
0
,
1
},
bshape
));
}
throw
ngraph_error
(
"Unexpected shape for bcast input"
);
...
...
@@ -369,14 +367,13 @@ void pass::CoreFusion::construct_conv_affine_folding()
conv_m
->
get_padding_above
(),
conv_m
->
get_data_dilation_strides
());
auto
convbias_n
=
conv_n
+
B_m
;
ngraph
::
replace_node
(
m
.
get_match_root
(),
convbias_n
);
replace_node
(
m
.
get_match_root
(),
convbias_n
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
add
,
callback
,
"CoreFusion.ConvAffineFolding"
);
auto
m
=
std
::
make_shared
<
pattern
::
Matcher
>
(
add
,
callback
,
"CoreFusion.ConvAffineFolding"
);
this
->
add_matcher
(
m
);
}
...
...
@@ -440,7 +437,7 @@ static size_t shape_to_index(Shape shape)
}
}
void
ngraph
::
pass
::
CoreFusion
::
construct_reshape_broadcast
()
void
pass
::
CoreFusion
::
construct_reshape_broadcast
()
{
Shape
input_shape
{
10
};
auto
input
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
input_shape
);
...
...
@@ -473,7 +470,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
if
(
d
!=
1
&&
d
!=
dim
)
{
NGRAPH_DEBUG
<<
"Input is reshaped in a way we can't directly broadcast ( shape = "
<<
ngraph
::
vector_to_string
(
reshape1_m
->
get_shape
())
<<
")"
;
<<
vector_to_string
(
reshape1_m
->
get_shape
())
<<
")"
;
return
false
;
}
...
...
@@ -502,7 +499,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
auto
new_broadcast
=
make_shared
<
op
::
Broadcast
>
(
input_m
,
broadcast_m
->
get_shape
(),
new_axes
);
ngraph
::
replace_node
(
m
.
get_match_root
(),
new_broadcast
);
replace_node
(
m
.
get_match_root
(),
new_broadcast
);
return
true
;
};
...
...
@@ -520,7 +517,7 @@ void ngraph::pass::CoreFusion::construct_reshape_broadcast()
void
pass
::
CoreFusion
::
construct_optimized_strided_conv
()
{
Shape
win_size_1
{
1
,
1
,
1
,
1
};
auto
is_bc
=
ngraph
::
pattern
::
has_class
<
op
::
Broadcast
>
();
auto
is_bc
=
pattern
::
has_class
<
op
::
Broadcast
>
();
auto
data_stride3
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
1
,
128
,
128
});
auto
weights_stride3
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
win_size_1
);
...
...
@@ -689,7 +686,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
new_relu_two_convs
,
sconv
->
get_argument
(
1
),
stride_1
,
stride_1
);
NGRAPH_DEBUG
<<
"Replacing "
<<
sconv
->
get_name
()
<<
" with "
<<
sconv_28w1s1
->
get_name
();
ngraph
::
replace_node
(
sconv
,
sconv_28w1s1
);
replace_node
(
sconv
,
sconv_28w1s1
);
}
return
true
;
};
...
...
@@ -699,7 +696,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
this
->
add_matcher
(
m
);
}
void
ngraph
::
pass
::
CoreFusion
::
construct_reshape_softmax_reshape
()
void
pass
::
CoreFusion
::
construct_reshape_softmax_reshape
()
{
Shape
input_shape
{
10
,
20
};
AxisVector
io
{
1
,
0
};
...
...
@@ -738,7 +735,7 @@ void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape()
}
auto
new_softmax
=
make_shared
<
op
::
Softmax
>
(
input_m
,
new_axes
);
ngraph
::
replace_node
(
m
.
get_match_root
(),
new_softmax
);
replace_node
(
m
.
get_match_root
(),
new_softmax
);
return
true
;
};
...
...
src/ngraph/pass/cse.cpp
View file @
13fc556e
...
...
@@ -59,11 +59,12 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x)
std::
type_index(typeid(x))
#define TI(x) type_index(typeid(x))
static
bool
cse_constant
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_constant
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_constant for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
...
...
@@ -72,44 +73,44 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
return
false
;
}
auto
ca
=
st
d
::
st
atic_pointer_cast
<
op
::
Constant
>
(
a
);
auto
cb
=
st
d
::
st
atic_pointer_cast
<
op
::
Constant
>
(
b
);
auto
ca
=
static_pointer_cast
<
op
::
Constant
>
(
a
);
auto
cb
=
static_pointer_cast
<
op
::
Constant
>
(
b
);
size_t
size
=
shape_size
(
a
->
get_shape
())
*
a
->
get_element_type
().
size
();
return
!
memcmp
(
ca
->
get_data_ptr
(),
cb
->
get_data_ptr
(),
size
);
}
static
bool
cse_reshape
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_reshape
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_reshape for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
auto
reshape_a
=
st
d
::
st
atic_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
a
);
auto
reshape_b
=
st
d
::
st
atic_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
b
);
auto
reshape_a
=
static_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
a
);
auto
reshape_b
=
static_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
b
);
return
(
a
->
get_argument
(
0
)
==
b
->
get_argument
(
0
))
&&
(
reshape_a
->
get_input_order
()
==
reshape_b
->
get_input_order
())
&&
(
reshape_a
->
get_output_shape
()
==
reshape_b
->
get_output_shape
());
}
static
bool
cse_broadcast
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_broadcast
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_broadcast for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
auto
broadcast_a
=
st
d
::
st
atic_pointer_cast
<
ngraph
::
op
::
Broadcast
>
(
a
);
auto
broadcast_b
=
st
d
::
st
atic_pointer_cast
<
ngraph
::
op
::
Broadcast
>
(
b
);
auto
broadcast_a
=
static_pointer_cast
<
ngraph
::
op
::
Broadcast
>
(
a
);
auto
broadcast_b
=
static_pointer_cast
<
ngraph
::
op
::
Broadcast
>
(
b
);
return
(
a
->
get_argument
(
0
)
==
b
->
get_argument
(
0
))
&&
(
broadcast_a
->
get_broadcast_axes
()
==
broadcast_b
->
get_broadcast_axes
())
&&
(
broadcast_a
->
get_broadcast_shape
()
==
broadcast_b
->
get_broadcast_shape
());
}
static
bool
cse_unarywise
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_unarywise
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_unarywise for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
return
a
->
get_argument
(
0
)
==
b
->
get_argument
(
0
);
}
static
bool
cse_binarywise
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_binarywise
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_binary for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
...
...
@@ -117,23 +118,21 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
(
a
->
get_argument
(
1
)
==
b
->
get_argument
(
0
)
&&
a
->
get_argument
(
0
)
==
b
->
get_argument
(
1
));
}
static
bool
cse_reduction
(
s
td
::
shared_ptr
<
Node
>
a
,
std
::
shared_ptr
<
Node
>
b
)
static
bool
cse_reduction
(
s
hared_ptr
<
Node
>
a
,
shared_ptr
<
Node
>
b
)
{
NGRAPH_DEBUG
<<
"In cse_reduction for "
<<
a
->
get_name
()
<<
" and "
<<
b
->
get_name
();
auto
ar_a
=
st
d
::
st
atic_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
a
);
auto
ar_b
=
st
d
::
st
atic_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
b
);
auto
ar_a
=
static_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
a
);
auto
ar_b
=
static_pointer_cast
<
op
::
util
::
ArithmeticReduction
>
(
b
);
return
ar_a
->
get_argument
(
0
)
==
ar_b
->
get_argument
(
0
)
&&
ar_a
->
get_reduction_axes
()
==
ar_b
->
get_reduction_axes
();
}
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>
static
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>
)
>>
initialize_ops_to_cse_handlers
()
{
return
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>
(
return
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>
)
>>
(
{{
TI
(
op
::
Abs
),
cse_unarywise
},
{
TI
(
op
::
Acos
),
cse_unarywise
},
{
TI
(
op
::
Asin
),
cse_unarywise
},
...
...
@@ -168,23 +167,21 @@ static std::unordered_map<std::type_index,
{
TI
(
op
::
Broadcast
),
cse_broadcast
}});
}
static
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>
static
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>
)
>>
ops_to_cse_handlers
=
initialize_ops_to_cse_handlers
();
class
NodeKey
{
public
:
NodeKey
(
std
::
shared_ptr
<
Node
>
n
,
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>&
NodeKey
(
shared_ptr
<
Node
>
n
,
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>
)
>>&
backend_handlers
)
:
m_node
(
n
)
,
m_backend_handlers
(
backend_handlers
)
{
}
s
td
::
s
hared_ptr
<
Node
>
get_node
()
const
{
return
m_node
;
}
shared_ptr
<
Node
>
get_node
()
const
{
return
m_node
;
}
bool
operator
==
(
const
NodeKey
&
other
)
const
{
Node
&
p_this
=
*
m_node
.
get
();
...
...
@@ -215,9 +212,8 @@ public:
}
private
:
std
::
shared_ptr
<
Node
>
m_node
;
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>
)
>>&
shared_ptr
<
Node
>
m_node
;
unordered_map
<
type_index
,
function
<
bool
(
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>
)
>>&
m_backend_handlers
;
};
...
...
@@ -226,15 +222,15 @@ namespace std
template
<>
struct
hash
<
NodeKey
>
{
s
td
::
s
ize_t
operator
()(
const
NodeKey
&
k
)
const
size_t
operator
()(
const
NodeKey
&
k
)
const
{
Node
&
p_this
=
*
k
.
get_node
().
get
();
auto
ti
=
TI
(
p_this
);
std
::
hash
<
std
::
type_index
>
type_hash_compute
{};
hash
<
type_index
>
type_hash_compute
{};
auto
type_hash
=
type_hash_compute
(
ti
);
std
::
vector
<
size_t
>
arg_ids
;
vector
<
size_t
>
arg_ids
;
arg_ids
.
push_back
(
type_hash
);
...
...
@@ -244,7 +240,7 @@ namespace std
// specify how to compute hash for each op?
if
(
p_this
.
is_commutative
())
{
s
td
::
s
ort
(
begin
(
cargs
),
end
(
cargs
));
sort
(
begin
(
cargs
),
end
(
cargs
));
}
for
(
auto
arg
:
cargs
)
...
...
@@ -258,11 +254,10 @@ namespace std
};
}
bool
ngraph
::
pass
::
CommonSubexpressionElimination
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
ngraph
::
pass
::
CommonSubexpressionElimination
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
f
)
{
bool
replaced
=
false
;
std
::
unordered_map
<
NodeKey
,
std
::
shared_ptr
<
Node
>>
expressions
{};
unordered_map
<
NodeKey
,
shared_ptr
<
Node
>>
expressions
{};
for
(
auto
n
:
f
->
get_ordered_ops
())
{
...
...
@@ -279,7 +274,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
}
else
{
expressions
.
insert
(
std
::
make_pair
(
n_key
,
n
));
expressions
.
insert
(
make_pair
(
n_key
,
n
));
}
}
...
...
src/ngraph/pass/graph_rewrite.cpp
View file @
13fc556e
...
...
@@ -24,6 +24,9 @@
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
using
namespace
std
;
using
namespace
ngraph
;
// GraphRewrite algorithm:
// GraphRewrite processes an input graph in an topological order(i.e. args before users)
// Given the following graph: Abs2
...
...
@@ -56,16 +59,16 @@
// c) there's no linear order of fusions which will give
// the correct final fusion. i.e. the same fusion needs to occur before and after some other fusion
bool
ngraph
::
pass
::
GraphRewrite
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
pass
::
GraphRewrite
::
run_on_function
(
shared_ptr
<
Function
>
f
)
{
bool
rewritten
=
false
;
const
size_t
NUM_TRIES
=
10
;
size_t
tries
=
NUM_TRIES
;
std
::
vector
<
std
::
shared_ptr
<
pattern
::
Matcher
>>
original_matchers
{
m_matchers
};
vector
<
shared_ptr
<
pattern
::
Matcher
>>
original_matchers
{
m_matchers
};
do
{
rewritten
=
false
;
std
::
vector
<
std
::
shared_ptr
<
pattern
::
Matcher
>>
matchers
{
m_matchers
};
vector
<
shared_ptr
<
pattern
::
Matcher
>>
matchers
{
m_matchers
};
m_matchers
.
clear
();
for
(
auto
node
:
f
->
get_ordered_ops
())
{
...
...
@@ -92,31 +95,31 @@ bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Functio
return
(
NUM_TRIES
-
tries
)
>
1
;
//this means a graph was transformed
}
static
const
std
::
vector
<
std
::
regex
>
initialize_fusion_regexes
()
static
const
vector
<
regex
>
initialize_fusion_regexes
()
{
const
char
*
cnsf
=
std
::
getenv
(
"NGRAPH_DISABLED_FUSIONS"
);
std
::
vector
<
std
::
regex
>
regexes
;
const
char
*
cnsf
=
getenv
(
"NGRAPH_DISABLED_FUSIONS"
);
vector
<
regex
>
regexes
;
if
(
cnsf
)
{
const
st
d
::
st
ring
nsf
=
cnsf
;
const
auto
sregexes
=
ngraph
::
split
(
nsf
,
';'
);
const
string
nsf
=
cnsf
;
const
auto
sregexes
=
split
(
nsf
,
';'
);
std
::
transform
(
sregexes
.
begin
(),
sregexes
.
end
(),
std
::
back_inserter
(
regexes
),
[](
const
std
::
string
&
c
)
->
std
::
regex
{
return
std
::
regex
(
c
);
});
transform
(
sregexes
.
begin
(),
sregexes
.
end
(),
back_inserter
(
regexes
),
[](
const
string
&
c
)
->
regex
{
return
regex
(
c
);
});
}
return
regexes
;
}
bool
ngraph
::
pass
::
GraphRewrite
::
is_enabled
(
std
::
shared_ptr
<
pattern
::
Matcher
>
m
)
bool
pass
::
GraphRewrite
::
is_enabled
(
shared_ptr
<
pattern
::
Matcher
>
m
)
{
//note, regexes are static to avoid re-initialization
static
const
auto
regexes
=
initialize_fusion_regexes
();
for
(
const
auto
&
regex
:
regexes
)
{
if
(
std
::
regex_match
(
m
->
get_name
(),
regex
))
if
(
regex_match
(
m
->
get_name
(),
regex
))
{
NGRAPH_DEBUG
<<
"Disabling matcher "
<<
m
->
get_name
();
return
false
;
...
...
@@ -126,7 +129,7 @@ bool ngraph::pass::GraphRewrite::is_enabled(std::shared_ptr<pattern::Matcher> m)
return
true
;
}
void
ngraph
::
pass
::
GraphRewrite
::
add_matcher
(
std
::
shared_ptr
<
pattern
::
Matcher
>
m
)
void
pass
::
GraphRewrite
::
add_matcher
(
shared_ptr
<
pattern
::
Matcher
>
m
)
{
if
(
is_enabled
(
m
))
{
...
...
@@ -134,7 +137,7 @@ void ngraph::pass::GraphRewrite::add_matcher(std::shared_ptr<pattern::Matcher> m
}
}
bool
ngraph
::
pass
::
RecurrentGraphRewrite
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
pass
::
RecurrentGraphRewrite
::
run_on_function
(
shared_ptr
<
Function
>
f
)
{
bool
changed
=
false
;
size_t
i
=
0
;
...
...
src/ngraph/pass/like_replacement.cpp
View file @
13fc556e
...
...
@@ -30,27 +30,28 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
#define TI(x) std::type_index(typeid(x))
using
namespace
std
;
using
namespace
ngraph
;
#define HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node)
#define TI(x) type_index(typeid(x))
#define HANDLER_DECL(x) static bool x(const shared_ptr<Node>& node)
HANDLER_DECL
(
replace_broadcast_like
)
{
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto
broadcast_like
=
std
::
static_pointer_cast
<
ngraph
::
op
::
BroadcastLike
>
(
node
);
ngraph
::
replace_node
(
node
,
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
broadcast_like
->
get_argument
(
0
),
broadcast_like
->
get_broadcast_shape
(),
broadcast_like
->
get_broadcast_axes
()));
auto
broadcast_like
=
static_pointer_cast
<
op
::
BroadcastLike
>
(
node
);
replace_node
(
node
,
make_shared
<
op
::
Broadcast
>
(
broadcast_like
->
get_argument
(
0
),
broadcast_like
->
get_broadcast_shape
(),
broadcast_like
->
get_broadcast_axes
()));
return
true
;
}
static
const
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
)
>>
dispatcher
{{
TI
(
ngraph
::
op
::
BroadcastLike
),
&
replace_broadcast_like
}};
static
const
unordered_map
<
type_index
,
function
<
bool
(
const
shared_ptr
<
Node
>&
)
>>
dispatcher
{
{
TI
(
op
::
BroadcastLike
),
&
replace_broadcast_like
}};
bool
ngraph
::
pass
::
LikeReplacement
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
function
)
bool
pass
::
LikeReplacement
::
run_on_function
(
shared_ptr
<
Function
>
function
)
{
bool
clobbered
=
false
;
...
...
@@ -66,10 +67,10 @@ bool ngraph::pass::LikeReplacement::run_on_function(std::shared_ptr<ngraph::Func
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto
sclb
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
ScalarConstantLikeBase
>
(
n
);
auto
sclb
=
dynamic_pointer_cast
<
op
::
ScalarConstantLikeBase
>
(
n
);
if
(
sclb
!=
nullptr
)
{
ngraph
::
replace_node
(
sclb
,
sclb
->
as_constant
());
replace_node
(
sclb
,
sclb
->
as_constant
());
clobbered
=
true
;
}
}
...
...
src/ngraph/pass/liveness.cpp
View file @
13fc556e
...
...
@@ -33,7 +33,7 @@
using
namespace
std
;
using
namespace
ngraph
;
bool
pass
::
Liveness
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
function
)
bool
pass
::
Liveness
::
run_on_function
(
shared_ptr
<
Function
>
function
)
{
list
<
shared_ptr
<
Node
>>
ops
=
function
->
get_ordered_ops
();
...
...
src/ngraph/pass/manager.cpp
View file @
13fc556e
...
...
@@ -35,7 +35,7 @@
using
namespace
std
;
using
namespace
ngraph
;
ngraph
::
pass
::
Manager
::
Manager
()
pass
::
Manager
::
Manager
()
{
static
const
auto
nevt
=
std
::
getenv
(
"NGRAPH_ENABLE_VISUALIZE_TRACING"
);
if
(
nevt
)
...
...
@@ -49,15 +49,15 @@ ngraph::pass::Manager::Manager()
}
}
ngraph
::
pass
::
Manager
::~
Manager
()
pass
::
Manager
::~
Manager
()
{
}
void
ngraph
::
pass
::
Manager
::
initialize_default_passes
()
void
pass
::
Manager
::
initialize_default_passes
()
{
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
,
bool
transitive
)
void
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
,
bool
transitive
)
{
bool
profile_enabled
=
getenv
(
"NGRAPH_PROFILE_PASS_ENABLE"
)
!=
nullptr
;
...
...
@@ -167,7 +167,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitiv
}
}
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
Manager
::
get_state
()
pass
::
ManagerState
&
pass
::
Manager
::
get_state
()
{
return
m_state
;
}
src/ngraph/pass/manager_state.cpp
View file @
13fc556e
...
...
@@ -25,7 +25,7 @@
using
namespace
std
;
using
namespace
ngraph
;
const
vector
<
shared_ptr
<
Function
>>&
ngraph
::
pass
::
ManagerState
::
get_functions
()
const
vector
<
shared_ptr
<
Function
>>&
pass
::
ManagerState
::
get_functions
()
{
return
m_function_list
;
}
src/ngraph/pass/memory_layout.cpp
View file @
13fc556e
...
...
@@ -40,7 +40,7 @@ pass::MemoryLayout::MemoryLayout(size_t alignment, bool disable_memory_sharing)
}
}
bool
pass
::
MemoryLayout
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
function
)
bool
pass
::
MemoryLayout
::
run_on_function
(
shared_ptr
<
Function
>
function
)
{
MemoryManager
mm
(
m_alignment
,
m_disable_memory_sharing
);
for
(
shared_ptr
<
Node
>
node
:
function
->
get_ordered_ops
())
...
...
src/ngraph/pass/memory_visualize.cpp
View file @
13fc556e
...
...
@@ -34,7 +34,7 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}
bool
pass
::
MemoryVisualize
::
run_on_module
(
vector
<
shared_ptr
<
ngraph
::
Function
>>&
functions
)
bool
pass
::
MemoryVisualize
::
run_on_module
(
vector
<
shared_ptr
<
Function
>>&
functions
)
{
ofstream
file
(
m_filename
);
{
...
...
src/ngraph/pass/nop_elimination.cpp
View file @
13fc556e
...
...
@@ -30,94 +30,93 @@
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
#define TI(x) std::type_index(typeid(x))
using
namespace
std
;
using
namespace
ngraph
;
#define
HANDLER_DECL(x) static bool x(const std::shared_ptr<ngraph::Node>& node
)
#define
TI(x) std::type_index(typeid(x)
)
HANDLER_DECL
(
eliminate_pad
)
static
bool
eliminate_pad
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
auto
pad
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Pad
>
(
node
);
auto
pad
=
std
::
static_pointer_cast
<
op
::
Pad
>
(
node
);
if
(
pad
->
get_input_shape
(
0
)
==
pad
->
get_output_shape
(
0
))
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
return
false
;
}
HANDLER_DECL
(
eliminate_sum
)
static
bool
eliminate_sum
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
auto
sum
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Sum
>
(
node
);
auto
sum
=
std
::
static_pointer_cast
<
op
::
Sum
>
(
node
);
if
(
sum
->
get_reduction_axes
().
empty
())
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
return
false
;
}
HANDLER_DECL
(
eliminate_convert
)
static
bool
eliminate_convert
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
auto
convert
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Convert
>
(
node
);
auto
convert
=
std
::
static_pointer_cast
<
op
::
Convert
>
(
node
);
if
(
convert
->
get_convert_element_type
()
==
convert
->
get_argument
(
0
)
->
get_element_type
())
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
return
false
;
}
HANDLER_DECL
(
eliminate_slic
e
)
static
bool
eliminate_slice
(
const
std
::
shared_ptr
<
Node
>&
nod
e
)
{
auto
slice
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Slice
>
(
node
);
auto
slice
=
std
::
static_pointer_cast
<
op
::
Slice
>
(
node
);
if
(
slice
->
get_input_shape
(
0
)
==
slice
->
get_output_shape
(
0
))
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
return
false
;
}
HANDLER_DECL
(
replace_broadcast_lik
e
)
static
bool
replace_broadcast_like
(
const
std
::
shared_ptr
<
Node
>&
nod
e
)
{
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto
broadcast_like
=
std
::
static_pointer_cast
<
ngraph
::
op
::
BroadcastLike
>
(
node
);
ngraph
::
replace_node
(
node
,
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
broadcast_like
->
get_argument
(
0
),
broadcast_like
->
get_broadcast_shape
(),
broadcast_like
->
get_broadcast_axes
()));
auto
broadcast_like
=
std
::
static_pointer_cast
<
op
::
BroadcastLike
>
(
node
);
replace_node
(
node
,
std
::
make_shared
<
op
::
Broadcast
>
(
broadcast_like
->
get_argument
(
0
),
broadcast_like
->
get_broadcast_shape
(),
broadcast_like
->
get_broadcast_axes
()));
return
true
;
}
HANDLER_DECL
(
eliminate_broadcast
)
static
bool
eliminate_broadcast
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
auto
broadcast
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Broadcast
>
(
node
);
auto
broadcast
=
std
::
static_pointer_cast
<
op
::
Broadcast
>
(
node
);
if
(
broadcast
->
get_input_shape
(
0
)
==
broadcast
->
get_output_shape
(
0
))
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
return
false
;
}
HANDLER_DECL
(
eliminate_stop_gradient
)
static
bool
eliminate_stop_gradient
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
ngraph
::
replace_node
(
node
,
node
->
get_argument
(
0
));
replace_node
(
node
,
node
->
get_argument
(
0
));
return
true
;
}
static
const
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
)
>>
dispatcher
{{
TI
(
ngraph
::
op
::
Pad
),
&
eliminate_pad
},
{
TI
(
ngraph
::
op
::
Sum
),
&
eliminate_sum
},
{
TI
(
ngraph
::
op
::
Convert
),
&
eliminate_convert
},
{
TI
(
ngraph
::
op
::
Slice
),
&
eliminate_slice
},
{
TI
(
ngraph
::
op
::
StopGradient
),
&
eliminate_stop_gradient
},
{
TI
(
ngraph
::
op
::
BroadcastLike
),
&
replace_broadcast_like
},
{
TI
(
ngraph
::
op
::
Broadcast
),
&
eliminate_broadcast
}};
bool
ngraph
::
pass
::
NopElimination
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
function
)
static
const
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
bool
(
const
std
::
shared_ptr
<
Node
>&
)
>>
dispatcher
{{
TI
(
op
::
Pad
),
&
eliminate_pad
},
{
TI
(
op
::
Sum
),
&
eliminate_sum
},
{
TI
(
op
::
Convert
),
&
eliminate_convert
},
{
TI
(
op
::
Slice
),
&
eliminate_slice
},
{
TI
(
op
::
StopGradient
),
&
eliminate_stop_gradient
},
{
TI
(
op
::
BroadcastLike
),
&
replace_broadcast_like
},
{
TI
(
op
::
Broadcast
),
&
eliminate_broadcast
}};
bool
pass
::
NopElimination
::
run_on_function
(
std
::
shared_ptr
<
Function
>
function
)
{
bool
clobbered
=
false
;
...
...
@@ -133,10 +132,10 @@ bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Funct
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto
sclb
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
ScalarConstantLikeBase
>
(
n
);
auto
sclb
=
std
::
dynamic_pointer_cast
<
op
::
ScalarConstantLikeBase
>
(
n
);
if
(
sclb
!=
nullptr
)
{
ngraph
::
replace_node
(
sclb
,
sclb
->
as_constant
());
replace_node
(
sclb
,
sclb
->
as_constant
());
clobbered
=
true
;
}
}
...
...
src/ngraph/pass/pass.cpp
View file @
13fc556e
...
...
@@ -17,12 +17,15 @@
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/manager.hpp"
ngraph
::
pass
::
ManagerState
&
ngraph
::
pass
::
PassBase
::
get_state
()
using
namespace
std
;
using
namespace
ngraph
;
pass
::
ManagerState
&
pass
::
PassBase
::
get_state
()
{
return
*
m_state
;
}
void
ngraph
::
pass
::
PassBase
::
set_state
(
ManagerState
&
state
)
void
pass
::
PassBase
::
set_state
(
ManagerState
&
state
)
{
m_state
=
&
state
;
}
src/ngraph/pass/pass_config.cpp
View file @
13fc556e
...
...
@@ -19,10 +19,11 @@
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
// TODO: Add file-based configuration support
ngraph
::
pass
::
PassConfig
::
PassConfig
(
ngraph
::
pass
::
CompilationMode
mode
)
pass
::
PassConfig
::
PassConfig
(
pass
::
CompilationMode
mode
)
:
m_compilation_mode
(
mode
)
{
/**
...
...
@@ -32,15 +33,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would
* set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims
**/
const
char
*
env_str
=
std
::
getenv
(
"NGRAPH_PASS_ENABLES"
);
const
char
*
env_str
=
getenv
(
"NGRAPH_PASS_ENABLES"
);
if
(
env_str
)
{
st
d
::
st
ringstream
ss
;
stringstream
ss
;
ss
<<
env_str
;
while
(
ss
.
good
())
{
st
d
::
st
ring
substr
;
std
::
getline
(
ss
,
substr
,
';'
);
string
substr
;
getline
(
ss
,
substr
,
';'
);
auto
split_str
=
split
(
substr
,
':'
,
false
);
switch
(
split_str
.
size
())
{
...
...
@@ -58,15 +59,15 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
* would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on
* "UseDefaultLayouts"
**/
env_str
=
std
::
getenv
(
"NGRAPH_PASS_ATTRIBUTES"
);
env_str
=
getenv
(
"NGRAPH_PASS_ATTRIBUTES"
);
if
(
env_str
)
{
st
d
::
st
ringstream
ss
;
stringstream
ss
;
ss
<<
env_str
;
while
(
ss
.
good
())
{
st
d
::
st
ring
substr
;
std
::
getline
(
ss
,
substr
,
';'
);
string
substr
;
getline
(
ss
,
substr
,
';'
);
auto
split_str
=
split
(
substr
,
'='
,
false
);
switch
(
split_str
.
size
())
{
...
...
@@ -80,12 +81,12 @@ ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
}
}
void
ngraph
::
pass
::
PassConfig
::
set_pass_enable
(
std
::
string
name
,
bool
enable
)
void
pass
::
PassConfig
::
set_pass_enable
(
string
name
,
bool
enable
)
{
m_pass_enables
[
name
]
=
enable
;
}
bool
ngraph
::
pass
::
PassConfig
::
get_pass_enable
(
std
::
string
name
)
bool
pass
::
PassConfig
::
get_pass_enable
(
string
name
)
{
if
(
m_pass_enables
.
find
(
name
)
==
m_pass_enables
.
end
())
{
...
...
@@ -94,12 +95,12 @@ bool ngraph::pass::PassConfig::get_pass_enable(std::string name)
return
m_pass_enables
[
name
];
}
void
ngraph
::
pass
::
PassConfig
::
set_pass_attribute
(
std
::
string
name
,
bool
enable
)
void
pass
::
PassConfig
::
set_pass_attribute
(
string
name
,
bool
enable
)
{
m_pass_attributes
[
name
]
=
enable
;
}
bool
ngraph
::
pass
::
PassConfig
::
get_pass_attribute
(
std
::
string
name
)
bool
pass
::
PassConfig
::
get_pass_attribute
(
string
name
)
{
if
(
m_pass_attributes
.
find
(
name
)
==
m_pass_attributes
.
end
())
{
...
...
src/ngraph/pass/prefix_reshape_elimination.cpp
View file @
13fc556e
...
...
@@ -24,14 +24,17 @@
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
ngraph
::
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
using
namespace
std
;
using
namespace
ngraph
;
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
{
auto
src_op
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i8
,
Shape
{},
[](
s
td
::
s
hared_ptr
<
Node
>
)
{
return
true
;
});
auto
reshape_op
=
std
::
make_shared
<
pattern
::
op
::
Any
>
(
auto
src_op
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
)
{
return
true
;
});
auto
reshape_op
=
make_shared
<
pattern
::
op
::
Any
>
(
element
::
i8
,
Shape
{},
[](
s
td
::
s
hared_ptr
<
Node
>
node
)
{
[](
shared_ptr
<
Node
>
node
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
node
.
get
());
if
(
!
reshape
)
{
...
...
@@ -46,14 +49,14 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
// Make sure that logical dimension sizes match.
const
Shape
&
src_shape
=
reshape
->
get_input_shape
(
0
);
for
(
s
td
::
s
ize_t
idx
=
0
;
idx
<
reshape
->
get_output_shape
().
size
();
++
idx
)
for
(
size_t
idx
=
0
;
idx
<
reshape
->
get_output_shape
().
size
();
++
idx
)
{
s
td
::
s
ize_t
src_size
=
1
;
size_t
src_size
=
1
;
if
(
idx
<
src_shape
.
size
())
{
src_size
=
src_shape
.
at
(
src_shape
.
size
()
-
1
-
idx
);
}
s
td
::
s
ize_t
dest_size
=
size_t
dest_size
=
reshape
->
get_output_shape
().
at
(
reshape
->
get_output_shape
().
size
()
-
1
-
idx
);
if
(
dest_size
!=
src_size
)
{
...
...
@@ -64,10 +67,10 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
return
true
;
},
NodeVector
{
src_op
});
auto
target_op
=
std
::
make_shared
<
pattern
::
op
::
AnyOf
>
(
auto
target_op
=
make_shared
<
pattern
::
op
::
AnyOf
>
(
element
::
i8
,
Shape
{},
[](
s
td
::
s
hared_ptr
<
Node
>
node
)
{
[](
shared_ptr
<
Node
>
node
)
{
return
pattern
::
has_class
<
op
::
Reshape
>
()(
node
)
||
pattern
::
has_class
<
op
::
util
::
UnaryElementwiseArithmetic
>
()(
node
)
||
pattern
::
has_class
<
op
::
util
::
BinaryElementwiseArithmetic
>
()(
node
);
...
...
@@ -78,5 +81,5 @@ ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
replace_node
(
m
.
get_matched_nodes
().
at
(
1
),
m
.
get_matched_nodes
().
at
(
2
));
return
true
;
};
add_matcher
(
std
::
make_shared
<
pattern
::
Matcher
>
(
target_op
,
callback
));
add_matcher
(
make_shared
<
pattern
::
Matcher
>
(
target_op
,
callback
));
}
src/ngraph/pass/propagate_cacheability.cpp
View file @
13fc556e
...
...
@@ -22,15 +22,16 @@
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
using
namespace
std
;
using
namespace
ngraph
;
bool
ngraph
::
pass
::
PropagateCacheability
::
run_on_function
(
std
::
shared_ptr
<
Function
>
function
)
bool
pass
::
PropagateCacheability
::
run_on_function
(
shared_ptr
<
Function
>
function
)
{
for
(
auto
&
node
:
function
->
get_ordered_ops
())
{
if
(
node
->
is_op
())
{
auto
op
=
st
d
::
st
atic_pointer_cast
<
op
::
Op
>
(
node
);
auto
op
=
static_pointer_cast
<
op
::
Op
>
(
node
);
NGRAPH_DEBUG
<<
"propagate cacheability: node is "
<<
node
->
get_name
();
auto
op_annotations
=
op
->
get_op_annotations
();
if
(
!
op_annotations
)
...
...
@@ -41,7 +42,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
}
if
(
node
->
is_parameter
())
{
auto
parameter
=
st
d
::
st
atic_pointer_cast
<
op
::
Parameter
>
(
node
);
auto
parameter
=
static_pointer_cast
<
op
::
Parameter
>
(
node
);
op_annotations
->
set_cacheable
(
parameter
->
get_cacheable
());
NGRAPH_DEBUG
<<
"propagate cacheability: cacheability is "
<<
parameter
->
get_cacheable
();
...
...
@@ -54,7 +55,7 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
NGRAPH_DEBUG
<<
"propagate cacheability: arg is "
<<
arg
->
get_name
();
if
(
arg
->
is_op
())
{
auto
arg_op
=
st
d
::
st
atic_pointer_cast
<
op
::
Op
>
(
arg
);
auto
arg_op
=
static_pointer_cast
<
op
::
Op
>
(
arg
);
auto
arg_op_annotations
=
arg_op
->
get_op_annotations
();
NGRAPH_ASSERT
(
arg_op_annotations
);
if
(
!
arg_op_annotations
->
is_cacheable
())
...
...
src/ngraph/pass/reshape_elimination.cpp
View file @
13fc556e
...
...
@@ -33,17 +33,19 @@
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
extern
template
ngraph
::
AxisVector
ngraph
::
apply_permutation
<
ngraph
::
AxisVector
>
(
ngraph
::
AxisVector
input
,
ngraph
::
AxisVector
order
);
using
namespace
std
;
using
namespace
ngraph
;
void
ngraph
::
pass
::
ReshapeElimination
::
construct_identity_reshape_pattern
()
extern
template
AxisVector
ngraph
::
apply_permutation
<
AxisVector
>
(
AxisVector
input
,
AxisVector
order
);
void
pass
::
ReshapeElimination
::
construct_identity_reshape_pattern
()
{
Shape
shape_op
{
3
};
Shape
shape_r1
{
1
,
3
};
auto
op
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
shape_op
);
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
op
,
AxisVector
{
0
},
shape_r1
);
auto
op
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
shape_op
);
auto
reshape1
=
make_shared
<
op
::
Reshape
>
(
op
,
AxisVector
{
0
},
shape_r1
);
auto
callback
=
[
op
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
...
...
@@ -51,7 +53,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
auto
pattern_map
=
m
.
get_pattern_map
();
auto
gop
=
pattern_map
[
op
];
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
auto
r1
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
{
...
...
@@ -59,7 +61,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return
false
;
}
auto
do_r1
=
ngraph
::
get_default_order
(
r1
->
get_shape
());
auto
do_r1
=
get_default_order
(
r1
->
get_shape
());
if
(
do_r1
!=
r1
->
get_input_order
())
{
...
...
@@ -67,22 +69,22 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return
false
;
}
ngraph
::
replace_node
(
m
.
get_match_root
(),
gop
);
replace_node
(
m
.
get_match_root
(),
gop
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
reshape1
,
callback
);
this
->
add_matcher
(
m
);
}
void
ngraph
::
pass
::
ReshapeElimination
::
construct_reshapex2_pattern
()
void
pass
::
ReshapeElimination
::
construct_reshapex2_pattern
()
{
Shape
shape_op
{
3
};
Shape
shape_r1
{
1
,
3
};
auto
op
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
shape_op
);
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
op
,
AxisVector
{
0
},
shape_r1
);
auto
reshape2
=
std
::
make_shared
<
op
::
Reshape
>
(
reshape1
,
AxisVector
{
0
,
1
},
shape_op
);
auto
op
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
shape_op
);
auto
reshape1
=
make_shared
<
op
::
Reshape
>
(
op
,
AxisVector
{
0
},
shape_r1
);
auto
reshape2
=
make_shared
<
op
::
Reshape
>
(
reshape1
,
AxisVector
{
0
,
1
},
shape_op
);
auto
callback
=
[
op
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for construct_reshapex2_pattern against node = "
...
...
@@ -101,11 +103,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
return
false
;
}
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
r2
->
get_argument
(
0
));
auto
r2
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
auto
r1
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
r2
->
get_argument
(
0
));
auto
do_r2
=
ngraph
::
get_default_order
(
r1
->
get_shape
());
auto
do_r1
=
ngraph
::
get_default_order
(
gop
->
get_shape
());
auto
do_r2
=
get_default_order
(
r1
->
get_shape
());
auto
do_r1
=
get_default_order
(
gop
->
get_shape
());
NGRAPH_DEBUG
<<
"r1's i/o = "
<<
vector_to_string
(
r1
->
get_input_order
())
<<
"do_r1 = "
<<
vector_to_string
(
do_r1
);
...
...
@@ -115,40 +117,40 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
{
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
ngraph
::
replace_node
(
m
.
get_match_root
(),
gop
);
replace_node
(
m
.
get_match_root
(),
gop
);
return
true
;
}
auto
perm1
=
ngraph
::
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
auto
perm2
=
ngraph
::
apply_permutation
(
perm1
,
r2
->
get_input_order
());
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
auto
perm2
=
apply_permutation
(
perm1
,
r2
->
get_input_order
());
if
(
perm2
==
do_r1
)
{
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
ngraph
::
replace_node
(
m
.
get_match_root
(),
gop
);
replace_node
(
m
.
get_match_root
(),
gop
);
return
true
;
}
return
false
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
reshape2
,
callback
);
this
->
add_matcher
(
m
);
}
void
ngraph
::
pass
::
ReshapeElimination
::
construct_dot_transpose_pattern
()
void
pass
::
ReshapeElimination
::
construct_dot_transpose_pattern
()
{
// dot(A,B).T = dot (B.T, A.T)
auto
dot_pred
=
[](
s
td
::
s
hared_ptr
<
Node
>
n
)
{
return
static_cast
<
bool
>
(
std
::
dynamic_pointer_cast
<
op
::
Dot
>
(
n
));
auto
dot_pred
=
[](
shared_ptr
<
Node
>
n
)
{
return
static_cast
<
bool
>
(
dynamic_pointer_cast
<
op
::
Dot
>
(
n
));
};
auto
pdot
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
,
1
},
dot_pred
);
auto
preshape
=
std
::
make_shared
<
op
::
Reshape
>
(
pdot
,
AxisVector
{
1
,
0
},
Shape
{
1
,
2
});
auto
pdot
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
2
,
1
},
dot_pred
);
auto
preshape
=
make_shared
<
op
::
Reshape
>
(
pdot
,
AxisVector
{
1
,
0
},
Shape
{
1
,
2
});
ngraph
::
pattern
::
graph_rewrite_callback
callback
=
[](
pattern
::
Matcher
&
m
)
{
pattern
::
graph_rewrite_callback
callback
=
[](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
<<
m
.
get_match_root
()
->
get_name
();
auto
mtranspose
=
st
d
::
st
atic_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
auto
mtranspose
=
static_pointer_cast
<
op
::
Reshape
>
(
m
.
get_match_root
());
// this also checks the rank
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
{
...
...
@@ -171,7 +173,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return
false
;
}
auto
reshape0_shape
=
Shape
{
arg0
->
get_shape
().
at
(
1
),
arg0
->
get_shape
().
at
(
0
)};
auto
reshape0
=
std
::
make_shared
<
op
::
Reshape
>
(
arg0
,
AxisVector
{
1
,
0
},
reshape0_shape
);
auto
reshape0
=
make_shared
<
op
::
Reshape
>
(
arg0
,
AxisVector
{
1
,
0
},
reshape0_shape
);
auto
arg1
=
mdot
->
get_argument
(
1
);
if
(
arg1
->
get_shape
().
size
()
!=
2
)
...
...
@@ -180,13 +182,13 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
return
false
;
}
auto
reshape1_shape
=
Shape
{
arg1
->
get_shape
().
at
(
1
),
arg1
->
get_shape
().
at
(
0
)};
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
reshape1
=
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
tdot
=
s
td
::
s
hared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
ngraph
::
replace_node
(
m
.
get_match_root
(),
tdot
);
auto
tdot
=
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
replace_node
(
m
.
get_match_root
(),
tdot
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
preshape
,
callback
);
this
->
add_matcher
(
m
);
}
src/ngraph/pass/reshape_sinking.cpp
View file @
13fc556e
...
...
@@ -39,14 +39,15 @@
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
ReshapeMap
=
std
::
unordered_map
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
op
::
Reshape
>>
;
using
ReshapeMap
=
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
op
::
Reshape
>>
;
static
st
d
::
string
describe_reshape
(
std
::
shared_ptr
<
Node
>
node
)
static
st
ring
describe_reshape
(
shared_ptr
<
Node
>
node
)
{
st
d
::
st
ringstream
ss
;
auto
reshape
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
node
);
stringstream
ss
;
auto
reshape
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
node
);
ss
<<
reshape
->
get_name
()
<<
" ( axis order = "
<<
ngraph
::
vector_to_string
(
reshape
->
get_input_order
())
<<
" , shape = "
<<
vector_to_string
(
reshape
->
get_shape
())
<<
" ) "
...
...
@@ -55,25 +56,24 @@ static std::string describe_reshape(std::shared_ptr<Node> node)
return
ss
.
str
();
}
static
s
td
::
shared_ptr
<
op
::
Reshape
>
combine_reshapes
(
std
::
shared_ptr
<
op
::
Reshape
>
r1
,
std
::
shared_ptr
<
op
::
Reshape
>
r2
)
static
s
hared_ptr
<
op
::
Reshape
>
combine_reshapes
(
shared_ptr
<
op
::
Reshape
>
r1
,
shared_ptr
<
op
::
Reshape
>
r2
)
{
auto
default_order
=
ngraph
::
get_default_order
(
r1
->
get_shape
());
auto
perm_r1
=
apply_permutation
(
default_order
,
r1
->
get_input_order
());
auto
perm_r2
=
apply_permutation
(
perm_r1
,
r2
->
get_input_order
());
auto
rreshape
=
std
::
make_shared
<
op
::
Reshape
>
(
r2
->
get_argument
(
0
),
perm_r2
,
r2
->
get_shape
());
auto
rreshape
=
make_shared
<
op
::
Reshape
>
(
r2
->
get_argument
(
0
),
perm_r2
,
r2
->
get_shape
());
return
rreshape
;
}
static
void
insert_reshape
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
reshape
,
size_t
input_index
)
static
void
insert_reshape
(
shared_ptr
<
Node
>
target
,
shared_ptr
<
Node
>
reshape
,
size_t
input_index
)
{
auto
arg
=
target
->
get_inputs
().
at
(
input_index
).
get_output
().
get_node
();
auto
new_reshape
=
reshape
->
copy_with_new_args
({
arg
});
target
->
get_inputs
().
at
(
input_index
).
replace_output
(
new_reshape
->
get_outputs
().
at
(
0
));
}
static
void
delete_reshape
(
s
td
::
s
hared_ptr
<
Node
>
reshape
)
static
void
delete_reshape
(
shared_ptr
<
Node
>
reshape
)
{
NGRAPH_DEBUG
<<
"Removing reshape "
<<
reshape
->
get_name
();
if
(
!
reshape
->
get_users
().
empty
())
...
...
@@ -82,22 +82,22 @@ static void delete_reshape(std::shared_ptr<Node> reshape)
}
}
static
void
mark_reshape_for_deletion
(
s
td
::
s
hared_ptr
<
Node
>
reshape
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
static
void
mark_reshape_for_deletion
(
shared_ptr
<
Node
>
reshape
,
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
NGRAPH_DEBUG
<<
"Marking reshape "
<<
reshape
->
get_name
()
<<
" for deletion"
;
reshapes_to_delete
.
insert
(
reshape
);
}
static
s
td
::
shared_ptr
<
op
::
Reshape
>
create_default_reshape
(
std
::
shared_ptr
<
Node
>
n
)
static
s
hared_ptr
<
op
::
Reshape
>
create_default_reshape
(
shared_ptr
<
Node
>
n
)
{
auto
default_order
=
ngraph
::
get_default_order
(
n
->
get_shape
());
auto
default_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
n
,
default_order
,
n
->
get_shape
());
auto
default_reshape
=
make_shared
<
op
::
Reshape
>
(
n
,
default_order
,
n
->
get_shape
());
return
default_reshape
;
}
//compute an axis order that converts the given axis order to default
static
AxisSet
get_quantization_axes_in_default_order
(
s
td
::
s
hared_ptr
<
op
::
Reshape
>
arg_reshape
,
static
AxisSet
get_quantization_axes_in_default_order
(
shared_ptr
<
op
::
Reshape
>
arg_reshape
,
const
AxisSet
&
old_axis_set
)
{
auto
perm_to_def
=
ngraph
::
get_permutation_to_default_order
(
arg_reshape
->
get_input_order
());
...
...
@@ -112,7 +112,7 @@ static AxisSet get_quantization_axes_in_default_order(std::shared_ptr<op::Reshap
struct
Swimmer
{
descriptor
::
Input
*
input
;
s
td
::
s
hared_ptr
<
op
::
Reshape
>
reshape
;
shared_ptr
<
op
::
Reshape
>
reshape
;
};
//Swim is used to push/"swim" reshapes towards paramaters.
...
...
@@ -121,10 +121,10 @@ struct Swimmer
//we prefer nchw since a lot of ngraph ops require this format,
//so keeping things in nchw allows us to eliminate as many reshapes
//as possible
void
swim
(
descriptor
::
Input
*
input
,
s
td
::
s
hared_ptr
<
op
::
Reshape
>
reshape
)
void
swim
(
descriptor
::
Input
*
input
,
shared_ptr
<
op
::
Reshape
>
reshape
)
{
Swimmer
sw
{
input
,
reshape
};
std
::
list
<
Swimmer
>
work_queue
;
list
<
Swimmer
>
work_queue
;
work_queue
.
push_back
(
sw
);
//TODO: if we support more ops (especially, with >1 args)
...
...
@@ -135,21 +135,21 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
work_queue
.
pop_front
();
auto
n
=
csw
.
input
->
get_output
().
get_node
();
NGRAPH_DEBUG
<<
"Processing (swimming) "
<<
n
->
get_name
();
if
(
auto
unary
=
std
::
dynamic_pointer_cast
<
op
::
util
::
UnaryElementwiseArithmetic
>
(
n
))
if
(
auto
unary
=
dynamic_pointer_cast
<
op
::
util
::
UnaryElementwiseArithmetic
>
(
n
))
{
Swimmer
nsw
{
&
unary
->
get_inputs
().
at
(
0
),
csw
.
reshape
};
work_queue
.
push_back
(
nsw
);
NGRAPH_DEBUG
<<
"Propagating reshape "
<<
describe_reshape
(
csw
.
reshape
)
<<
" for "
<<
n
->
get_name
()
<<
" to "
<<
unary
->
get_argument
(
0
);
}
else
if
(
std
::
dynamic_pointer_cast
<
op
::
Broadcast
>
(
n
))
else
if
(
dynamic_pointer_cast
<
op
::
Broadcast
>
(
n
))
{
auto
old_broadcast
=
st
d
::
st
atic_pointer_cast
<
op
::
Broadcast
>
(
n
);
auto
old_broadcast
=
static_pointer_cast
<
op
::
Broadcast
>
(
n
);
auto
broadcast_axes
=
old_broadcast
->
get_broadcast_axes
();
auto
broadcast_reshape
=
csw
.
reshape
;
bool
in_order
=
true
;
AxisSet
new_broadcast_axes
;
std
::
vector
<
size_t
>
new_source_axes
;
vector
<
size_t
>
new_source_axes
;
auto
input_order
=
broadcast_reshape
->
get_input_order
();
for
(
size_t
i
=
0
;
i
<
input_order
.
size
();
i
++
)
{
...
...
@@ -171,8 +171,8 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
if
(
!
in_order
)
{
AxisVector
new_source_axes_sorted
{
new_source_axes
};
s
td
::
s
ort
(
new_source_axes_sorted
.
begin
(),
new_source_axes_sorted
.
end
());
std
::
map
<
size_t
,
size_t
>
old_new_source_axes
;
sort
(
new_source_axes_sorted
.
begin
(),
new_source_axes_sorted
.
end
());
map
<
size_t
,
size_t
>
old_new_source_axes
;
for
(
size_t
i
=
0
;
new_source_axes_sorted
.
size
();
i
++
)
{
old_new_source_axes
.
insert
({
new_source_axes
.
at
(
i
),
i
});
...
...
@@ -186,11 +186,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
auto
new_arg_shape
=
ngraph
::
apply_permutation
(
broadcast_input
->
get_shape
(),
new_source_axis_order
);
broadcast_input
=
std
::
make_shared
<
op
::
Reshape
>
(
broadcast_input
,
new_source_axis_order
,
new_arg_shape
);
broadcast_input
=
make_shared
<
op
::
Reshape
>
(
broadcast_input
,
new_source_axis_order
,
new_arg_shape
);
}
auto
new_broadcast
=
std
::
make_shared
<
op
::
Broadcast
>
(
auto
new_broadcast
=
make_shared
<
op
::
Broadcast
>
(
broadcast_input
,
broadcast_reshape
->
get_shape
(),
new_broadcast_axes
);
csw
.
input
->
replace_output
(
new_broadcast
->
get_outputs
().
at
(
0
));
}
...
...
@@ -210,11 +210,11 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
//We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can
static
void
convert_binary_to_default_order
(
s
td
::
s
hared_ptr
<
Node
>
binary
,
shared_ptr
<
Node
>
binary
,
descriptor
::
Input
&
input
,
s
td
::
s
hared_ptr
<
Node
>
right
,
std
::
unordered_map
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
op
::
Reshape
>>&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
shared_ptr
<
Node
>
right
,
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
op
::
Reshape
>>&
reorders
,
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
left
=
input
.
get_output
().
get_node
();
auto
perm_to_def
=
...
...
@@ -222,7 +222,7 @@ static void convert_binary_to_default_order(
auto
new_shape
=
apply_permutation
(
left
->
get_shape
(),
perm_to_def
);
NGRAPH_DEBUG
<<
"right = "
<<
ngraph
::
vector_to_string
(
right
->
get_shape
())
<<
", "
<<
right
->
get_name
();
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
left
,
perm_to_def
,
new_shape
);
auto
new_reshape
=
make_shared
<
op
::
Reshape
>
(
left
,
perm_to_def
,
new_shape
);
NGRAPH_DEBUG
<<
"left : About to swim "
<<
describe_reshape
(
new_reshape
)
<<
" up to "
<<
left
->
get_name
();
//this should now insert and swim reshape on right
...
...
@@ -231,9 +231,9 @@ static void convert_binary_to_default_order(
reorders
[
binary
]
=
reorders
.
at
(
right
);
}
static
void
materialize_shapes
(
s
td
::
s
hared_ptr
<
Node
>
n
,
static
void
materialize_shapes
(
shared_ptr
<
Node
>
n
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
//skip multiple output nodes and deal with GOEs exclusively
if
(
n
->
get_outputs
().
size
()
>
1
)
...
...
@@ -257,9 +257,9 @@ static void materialize_shapes(std::shared_ptr<Node> n,
reorders
[
n
]
=
create_default_reshape
(
n
);
}
static
void
sink_reshape
(
s
td
::
s
hared_ptr
<
op
::
Reshape
>
reshape
,
static
void
sink_reshape
(
shared_ptr
<
op
::
Reshape
>
reshape
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
orig_reshape
=
reorders
.
at
(
reshape
->
get_argument
(
0
));
if
(
!
reshape
->
get_is_transpose
())
...
...
@@ -286,18 +286,18 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
}
}
static
void
sink_unary
(
s
td
::
s
hared_ptr
<
op
::
util
::
UnaryElementwiseArithmetic
>
n
,
static
void
sink_unary
(
shared_ptr
<
op
::
util
::
UnaryElementwiseArithmetic
>
n
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
arg_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
n
]
=
reorders
[
n
->
get_argument
(
0
)];
}
static
void
sink_binary
(
s
td
::
s
hared_ptr
<
op
::
util
::
BinaryElementwiseArithmetic
>
binary
,
static
void
sink_binary
(
shared_ptr
<
op
::
util
::
BinaryElementwiseArithmetic
>
binary
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
left
=
binary
->
get_argument
(
0
);
auto
right
=
binary
->
get_argument
(
1
);
...
...
@@ -333,9 +333,9 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b
}
}
static
void
sink_slice
(
s
td
::
s
hared_ptr
<
op
::
Slice
>
n
,
static
void
sink_slice
(
shared_ptr
<
op
::
Slice
>
n
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
auto
order
=
arg_reshape
->
get_input_order
();
...
...
@@ -346,25 +346,23 @@ static void sink_slice(std::shared_ptr<op::Slice> n,
auto
def_order
=
ngraph
::
get_permutation_to_default_order
(
order
);
auto
input_shape
=
ngraph
::
apply_permutation
(
arg_reshape
->
get_shape
(),
def_order
);
auto
dummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
auto
new_lower
=
ngraph
::
apply_permutation
(
n
->
get_lower_bounds
(),
def_order
);
auto
new_upper
=
ngraph
::
apply_permutation
(
n
->
get_upper_bounds
(),
def_order
);
auto
new_strides
=
ngraph
::
apply_permutation
(
n
->
get_strides
(),
def_order
);
auto
new_slice
=
std
::
make_shared
<
op
::
Slice
>
(
dummy_correct_shape
,
new_lower
,
new_upper
,
new_strides
);
auto
new_slice
=
make_shared
<
op
::
Slice
>
(
dummy_correct_shape
,
new_lower
,
new_upper
,
new_strides
);
ngraph
::
replace_node
(
dummy_correct_shape
,
n
->
get_argument
(
0
));
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_slice
->
get_name
();
ngraph
::
replace_node
(
n
,
new_slice
);
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
new_slice
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_shared
<
op
::
Reshape
>
(
new_slice
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_slice
]
=
new_reshape
;
}
static
void
sink_pad
(
std
::
shared_ptr
<
op
::
Pad
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
static
void
sink_pad
(
shared_ptr
<
op
::
Pad
>
n
,
ReshapeMap
&
reorders
,
set
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
auto
order
=
arg_reshape
->
get_input_order
();
...
...
@@ -374,41 +372,41 @@ static void sink_pad(std::shared_ptr<op::Pad> n,
auto
def_order
=
ngraph
::
get_permutation_to_default_order
(
order
);
auto
input_shape
=
ngraph
::
apply_permutation
(
arg_reshape
->
get_shape
(),
def_order
);
auto
dummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
auto
new_lower
=
ngraph
::
apply_permutation
(
n
->
get_padding_below
(),
def_order
);
auto
new_upper
=
ngraph
::
apply_permutation
(
n
->
get_padding_above
(),
def_order
);
auto
new_interior
=
ngraph
::
apply_permutation
(
n
->
get_padding_interior
(),
def_order
);
auto
new_pad
=
std
::
make_shared
<
op
::
Pad
>
(
auto
new_pad
=
make_shared
<
op
::
Pad
>
(
dummy_correct_shape
,
n
->
get_argument
(
1
),
new_lower
,
new_upper
,
new_interior
);
ngraph
::
replace_node
(
dummy_correct_shape
,
n
->
get_argument
(
0
));
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_pad
->
get_name
();
ngraph
::
replace_node
(
n
,
new_pad
);
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
new_pad
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_shared
<
op
::
Reshape
>
(
new_pad
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_pad
]
=
new_reshape
;
}
static
void
sink_quantize
(
s
td
::
s
hared_ptr
<
op
::
Quantize
>
quantize
,
static
void
sink_quantize
(
shared_ptr
<
op
::
Quantize
>
quantize
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
quantize
->
get_argument
(
0
));
AxisSet
axes_in_def_order
=
get_quantization_axes_in_default_order
(
arg_reshape
,
quantize
->
get_axes
());
auto
new_quantize
=
std
::
make_shared
<
op
::
Quantize
>
(
quantize
->
get_argument
(
0
),
quantize
->
get_argument
(
1
),
quantize
->
get_argument
(
2
),
quantize
->
get_element_type
(),
axes_in_def_order
,
quantize
->
get_round_mode
());
auto
new_quantize
=
make_shared
<
op
::
Quantize
>
(
quantize
->
get_argument
(
0
),
quantize
->
get_argument
(
1
),
quantize
->
get_argument
(
2
),
quantize
->
get_element_type
(),
axes_in_def_order
,
quantize
->
get_round_mode
());
ngraph
::
replace_node
(
quantize
,
new_quantize
);
reorders
[
new_quantize
]
=
arg_reshape
;
}
static
void
sink_concat
(
s
td
::
s
hared_ptr
<
op
::
Concat
>
n
,
static
void
sink_concat
(
shared_ptr
<
op
::
Concat
>
n
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
auto
order
=
arg_reshape
->
get_input_order
();
...
...
@@ -418,7 +416,7 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto
def_order
=
ngraph
::
get_permutation_to_default_order
(
order
);
auto
input_shape
=
ngraph
::
apply_permutation
(
arg_reshape
->
get_shape
(),
def_order
);
auto
dummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
NodeVector
new_args
;
new_args
.
push_back
(
dummy_correct_shape
);
...
...
@@ -436,12 +434,12 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
auto
iinput_shape
=
ngraph
::
apply_permutation
(
iarg_reshape
->
get_shape
(),
def_order
);
auto
idummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iarg_reshape
->
get_element_type
(),
iinput_shape
);
make_shared
<
pattern
::
op
::
Label
>
(
iarg_reshape
->
get_element_type
(),
iinput_shape
);
new_args
.
push_back
(
idummy_correct_shape
);
}
auto
new_axis
=
order
.
at
(
n
->
get_concatenation_axis
());
auto
new_concat
=
std
::
make_shared
<
op
::
Concat
>
(
new_args
,
new_axis
);
auto
new_concat
=
make_shared
<
op
::
Concat
>
(
new_args
,
new_axis
);
//put back the original arguments
for
(
size_t
i
=
0
;
i
<
new_concat
->
get_input_size
();
i
++
)
{
...
...
@@ -450,23 +448,23 @@ static void sink_concat(std::shared_ptr<op::Concat> n,
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_concat
->
get_name
();
ngraph
::
replace_node
(
n
,
new_concat
);
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
new_concat
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_shared
<
op
::
Reshape
>
(
new_concat
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_concat
]
=
new_reshape
;
}
static
void
sink_dequantize
(
s
td
::
s
hared_ptr
<
op
::
Dequantize
>
dequantize
,
static
void
sink_dequantize
(
shared_ptr
<
op
::
Dequantize
>
dequantize
,
ReshapeMap
&
reorders
,
s
td
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
s
et
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
dequantize
->
get_argument
(
0
));
AxisSet
axes_in_def_order
=
get_quantization_axes_in_default_order
(
arg_reshape
,
dequantize
->
get_axes
());
auto
new_dequantize
=
std
::
make_shared
<
op
::
Dequantize
>
(
dequantize
->
get_argument
(
0
),
dequantize
->
get_argument
(
1
),
dequantize
->
get_argument
(
2
),
dequantize
->
get_element_type
(),
axes_in_def_order
);
auto
new_dequantize
=
make_shared
<
op
::
Dequantize
>
(
dequantize
->
get_argument
(
0
),
dequantize
->
get_argument
(
1
),
dequantize
->
get_argument
(
2
),
dequantize
->
get_element_type
(),
axes_in_def_order
);
ngraph
::
replace_node
(
dequantize
,
new_dequantize
);
reorders
[
new_dequantize
]
=
arg_reshape
;
...
...
@@ -481,11 +479,11 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool
ngraph
::
pass
::
ReshapeSinking
::
run_on_function
(
s
td
::
s
hared_ptr
<
ngraph
::
Function
>
f
)
bool
ngraph
::
pass
::
ReshapeSinking
::
run_on_function
(
shared_ptr
<
ngraph
::
Function
>
f
)
{
ReshapeMap
reorders
;
NodeVector
results
;
s
td
::
set
<
std
::
shared_ptr
<
Node
>>
reshapes_to_delete
;
s
et
<
shared_ptr
<
Node
>>
reshapes_to_delete
;
//STEP 1 : Sink or Swim reshapes away for op clusters
for
(
auto
n
:
f
->
get_ordered_ops
())
...
...
@@ -497,31 +495,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
results
.
push_back
(
n
);
}
if
(
auto
reshape
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
n
))
if
(
auto
reshape
=
dynamic_pointer_cast
<
op
::
Reshape
>
(
n
))
{
sink_reshape
(
reshape
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
unary
=
std
::
dynamic_pointer_cast
<
op
::
util
::
UnaryElementwiseArithmetic
>
(
n
))
else
if
(
auto
unary
=
dynamic_pointer_cast
<
op
::
util
::
UnaryElementwiseArithmetic
>
(
n
))
{
sink_unary
(
unary
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
binary
=
std
::
dynamic_pointer_cast
<
op
::
util
::
BinaryElementwiseArithmetic
>
(
n
))
else
if
(
auto
binary
=
dynamic_pointer_cast
<
op
::
util
::
BinaryElementwiseArithmetic
>
(
n
))
{
sink_binary
(
binary
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
goe
=
std
::
dynamic_pointer_cast
<
op
::
GetOutputElement
>
(
n
))
else
if
(
auto
goe
=
dynamic_pointer_cast
<
op
::
GetOutputElement
>
(
n
))
{
reorders
[
goe
]
=
create_default_reshape
(
goe
);
}
else
if
(
auto
quantize
=
std
::
dynamic_pointer_cast
<
op
::
Quantize
>
(
n
))
else
if
(
auto
quantize
=
dynamic_pointer_cast
<
op
::
Quantize
>
(
n
))
{
sink_quantize
(
quantize
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
dequantize
=
std
::
dynamic_pointer_cast
<
op
::
Dequantize
>
(
n
))
else
if
(
auto
dequantize
=
dynamic_pointer_cast
<
op
::
Dequantize
>
(
n
))
{
sink_dequantize
(
dequantize
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
slice
=
std
::
dynamic_pointer_cast
<
op
::
Slice
>
(
n
))
else
if
(
auto
slice
=
dynamic_pointer_cast
<
op
::
Slice
>
(
n
))
{
// A heuristic. If Reshape has multiple slice users, if sunk
// it will be replicated by the number of its users
...
...
@@ -542,11 +540,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
}
}
else
if
(
auto
pad
=
std
::
dynamic_pointer_cast
<
op
::
Pad
>
(
n
))
else
if
(
auto
pad
=
dynamic_pointer_cast
<
op
::
Pad
>
(
n
))
{
sink_pad
(
pad
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
concat
=
std
::
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
else
if
(
auto
concat
=
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
{
sink_concat
(
concat
,
reorders
,
reshapes_to_delete
);
}
...
...
src/ngraph/pass/visualize_tree.cpp
View file @
13fc556e
...
...
@@ -27,9 +27,9 @@
using
namespace
ngraph
;
using
namespace
std
;
#define TI(x)
std::
type_index(typeid(x))
#define TI(x) type_index(typeid(x))
bool
pass
::
VisualizeTree
::
run_on_module
(
vector
<
shared_ptr
<
ngraph
::
Function
>>&
functions
)
bool
pass
::
VisualizeTree
::
run_on_module
(
vector
<
shared_ptr
<
Function
>>&
functions
)
{
for
(
shared_ptr
<
Function
>
f
:
functions
)
{
...
...
@@ -42,10 +42,10 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
m_ss
<<
add_attributes
(
node
);
m_ss
<<
" "
<<
arg
->
get_name
()
<<
" -> "
<<
node
->
get_name
();
if
(
std
::
getenv
(
"NGRAPH_VISUALIZE_EDGE_LABELS"
)
!=
nullptr
)
if
(
getenv
(
"NGRAPH_VISUALIZE_EDGE_LABELS"
)
!=
nullptr
)
{
size_t
output
=
0
;
if
(
auto
goe
=
std
::
dynamic_pointer_cast
<
op
::
GetOutputElement
>
(
node
))
if
(
auto
goe
=
dynamic_pointer_cast
<
op
::
GetOutputElement
>
(
node
))
{
output
=
goe
->
get_n
();
}
...
...
@@ -71,7 +71,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm)
{
}
st
d
::
st
ring
pass
::
VisualizeTree
::
add_attributes
(
shared_ptr
<
Node
>
node
)
string
pass
::
VisualizeTree
::
add_attributes
(
shared_ptr
<
Node
>
node
)
{
string
rc
;
if
(
m_nodes_with_attributes
.
find
(
node
)
==
m_nodes_with_attributes
.
end
())
...
...
@@ -82,7 +82,7 @@ std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
return
rc
;
}
st
d
::
st
ring
pass
::
VisualizeTree
::
get_attributes
(
shared_ptr
<
Node
>
node
)
string
pass
::
VisualizeTree
::
get_attributes
(
shared_ptr
<
Node
>
node
)
{
vector
<
string
>
attributes
;
if
(
node
->
is_parameter
()
||
node
->
is_output
())
...
...
@@ -110,22 +110,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
stringstream
label
;
label
<<
"label=
\"
"
<<
node
->
get_friendly_name
();
static
const
char
*
nvtos
=
std
::
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES"
);
static
const
char
*
nvtos
=
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES"
);
if
(
nvtos
!=
nullptr
)
{
// The shapes of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
label
<<
" "
<<
(
node
->
get_outputs
().
size
()
!=
1
?
st
d
::
st
ring
(
"[skipped]"
)
label
<<
" "
<<
(
node
->
get_outputs
().
size
()
!=
1
?
string
(
"[skipped]"
)
:
vector_to_string
(
node
->
get_shape
()));
}
static
const
char
*
nvtot
=
std
::
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES"
);
static
const
char
*
nvtot
=
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES"
);
if
(
nvtot
!=
nullptr
)
{
// The types of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
label
<<
" "
<<
((
node
->
get_outputs
().
size
()
!=
1
)
?
st
d
::
st
ring
(
"[skipped]"
)
<<
((
node
->
get_outputs
().
size
()
!=
1
)
?
string
(
"[skipped]"
)
:
node
->
get_element_type
().
c_type_string
());
}
...
...
@@ -150,9 +150,9 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return
ss
.
str
();
}
st
d
::
st
ring
pass
::
VisualizeTree
::
get_file_ext
()
string
pass
::
VisualizeTree
::
get_file_ext
()
{
const
char
*
format
=
std
::
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT"
);
const
char
*
format
=
getenv
(
"NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT"
);
if
(
!
format
)
{
format
=
"png"
;
...
...
@@ -163,7 +163,7 @@ std::string pass::VisualizeTree::get_file_ext()
format
+=
1
;
}
return
st
d
::
st
ring
(
format
);
return
string
(
format
);
}
void
pass
::
VisualizeTree
::
render
()
const
...
...
src/ngraph/pass/zero_dim_tensor_elimination.cpp
View file @
13fc556e
...
...
@@ -30,9 +30,10 @@
#include "ngraph/op/sum.hpp"
#include "zero_dim_tensor_elimination.hpp"
using
namespace
std
;
using
namespace
ngraph
;
static
bool
has_zero_dim
(
s
td
::
s
hared_ptr
<
Node
>
node
)
static
bool
has_zero_dim
(
shared_ptr
<
Node
>
node
)
{
if
(
node
->
get_output_size
()
!=
1
)
{
...
...
@@ -40,12 +41,12 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
}
const
auto
&
shape
=
node
->
get_shape
();
return
std
::
find
(
shape
.
begin
(),
shape
.
end
(),
0
)
!=
shape
.
end
();
return
find
(
shape
.
begin
(),
shape
.
end
(),
0
)
!=
shape
.
end
();
}
static
bool
verify_no_internal_zero_length_ops
(
s
td
::
shared_ptr
<
ngraph
::
Function
>
f
)
static
bool
verify_no_internal_zero_length_ops
(
s
hared_ptr
<
Function
>
f
)
{
s
td
::
set
<
std
::
shared_ptr
<
Node
>>
zero_length_nodes
;
s
et
<
shared_ptr
<
Node
>>
zero_length_nodes
;
for
(
auto
n
:
f
->
get_ordered_ops
())
{
if
(
n
->
is_output
()
||
n
->
is_parameter
()
||
n
->
get_outputs
().
size
()
>
1
)
...
...
@@ -76,10 +77,10 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
return
zero_length_nodes
.
size
()
>
0
;
}
bool
ngraph
::
pass
::
ZeroDimTensorElimination
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
pass
::
ZeroDimTensorElimination
::
run_on_function
(
shared_ptr
<
Function
>
f
)
{
bool
replaced
=
false
;
auto
cvals
=
std
::
vector
<
std
::
string
>
(
0
);
auto
cvals
=
vector
<
string
>
(
0
);
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
for
(
auto
n
:
f
->
get_ordered_ops
())
...
...
@@ -98,8 +99,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{
// we don't have to create constants every time but this is the easiest
// and it's CSE's job to eliminate the same ones
auto
constant
=
std
::
make_shared
<
op
::
Constant
>
(
n
->
get_element_type
(),
n
->
get_shape
(),
cvals
);
auto
constant
=
make_shared
<
op
::
Constant
>
(
n
->
get_element_type
(),
n
->
get_shape
(),
cvals
);
replace_node
(
n
,
constant
);
NGRAPH_DEBUG
<<
" Replacing "
<<
n
->
get_name
()
<<
" with "
<<
constant
->
get_name
();
replaced
=
true
;
...
...
@@ -111,7 +111,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue
;
}
if
(
auto
concat
=
std
::
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
if
(
auto
concat
=
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
{
NodeVector
non_zero_dim_args
;
for
(
auto
arg
:
concat
->
get_arguments
())
...
...
@@ -127,7 +127,7 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
auto
new_concat
=
concat
->
copy_with_new_args
(
non_zero_dim_args
);
NGRAPH_DEBUG
<<
" Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_concat
->
get_name
();
ngraph
::
replace_node
(
concat
,
new_concat
);
replace_node
(
concat
,
new_concat
);
continue
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment