Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
fc9018dc
Commit
fc9018dc
authored
Mar 20, 2018
by
Nick Korovaiko
Committed by
Scott Cyphers
Mar 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update GraphRewrite API (#686)
parent
9d84c439
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
72 additions
and
70 deletions
+72
-70
graph_rewrite.cpp
src/ngraph/pass/graph_rewrite.cpp
+1
-4
reshape_elimination.cpp
src/ngraph/pass/reshape_elimination.cpp
+14
-14
core_fusion.cpp
src/ngraph/pattern/core_fusion.cpp
+4
-4
matcher.cpp
src/ngraph/pattern/matcher.cpp
+1
-1
matcher.hpp
src/ngraph/pattern/matcher.hpp
+2
-2
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+38
-34
pattern.cpp
test/pattern.cpp
+12
-11
No files found.
src/ngraph/pass/graph_rewrite.cpp
View file @
fc9018dc
...
@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
...
@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
NGRAPH_DEBUG
<<
"Matcher "
<<
matcher
<<
" matched "
<<
node
<<
" , "
NGRAPH_DEBUG
<<
"Matcher "
<<
matcher
<<
" matched "
<<
node
<<
" , "
<<
node
->
get_name
();
<<
node
->
get_name
();
rewritten
=
true
;
rewritten
=
true
;
auto
result
=
matcher
->
process_match
();
if
(
matcher
->
process_match
())
if
(
result
)
{
{
f
->
replace_node
(
node
,
result
);
//move onto the next node
break
;
break
;
}
}
}
}
...
...
src/ngraph/pass/reshape_elimination.cpp
View file @
fc9018dc
...
@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
auto
gop
=
pattern_map
[
op
];
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Not a no-op; Shapes are different!"
;
NGRAPH_DEBUG
<<
"Not a no-op; Shapes are different!"
;
return
nn
;
return
false
;
}
}
Shape
do_r1
(
r1
->
get_shape
().
size
());
Shape
do_r1
(
r1
->
get_shape
().
size
());
...
@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
do_r1
!=
r1
->
get_input_order
())
if
(
do_r1
!=
r1
->
get_input_order
())
{
{
NGRAPH_DEBUG
<<
"Not a no-op; Not in default input order!"
;
NGRAPH_DEBUG
<<
"Not a no-op; Not in default input order!"
;
return
nn
;
return
false
;
}
}
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
...
@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
auto
gop
=
pattern_map
[
op
];
if
(
gop
->
get_shape
()
!=
m
.
match_root
()
->
get_shape
())
if
(
gop
->
get_shape
()
!=
m
.
match_root
()
->
get_shape
())
...
@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
"shape = "
<<
vector_to_string
(
gop
->
get_shape
());
<<
"shape = "
<<
vector_to_string
(
gop
->
get_shape
());
NGRAPH_DEBUG
<<
"match_root "
<<
m
.
match_root
()
->
get_name
()
NGRAPH_DEBUG
<<
"match_root "
<<
m
.
match_root
()
->
get_name
()
<<
"shape = "
<<
vector_to_string
(
m
.
match_root
()
->
get_shape
());
<<
"shape = "
<<
vector_to_string
(
m
.
match_root
()
->
get_shape
());
return
nn
;
return
false
;
}
}
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
{
{
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
}
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
...
@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
perm2
==
do_r1
)
if
(
perm2
==
do_r1
)
{
{
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
}
return
nn
;
return
false
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
this
->
add_matcher
(
m
);
this
->
add_matcher
(
m
);
...
@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
...
@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
std
::
shared_ptr
<
Node
>
nn
;
auto
mtranspose
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
mtranspose
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
//this also checks the rank
//this also checks the rank
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
{
{
NGRAPH_DEBUG
<<
"Reshape isn't transpose. "
NGRAPH_DEBUG
<<
"Reshape isn't transpose. "
<<
vector_to_string
(
mtranspose
->
get_input_order
());
<<
vector_to_string
(
mtranspose
->
get_input_order
());
return
nn
;
return
false
;
}
}
auto
mdot
=
mtranspose
->
get_input_op
(
0
);
auto
mdot
=
mtranspose
->
get_input_op
(
0
);
if
(
mdot
->
get_shape
().
size
()
!=
2
)
if
(
mdot
->
get_shape
().
size
()
!=
2
)
{
{
NGRAPH_DEBUG
<<
"Dot has the wrong shape. "
<<
vector_to_string
(
mdot
->
get_shape
());
NGRAPH_DEBUG
<<
"Dot has the wrong shape. "
<<
vector_to_string
(
mdot
->
get_shape
());
return
nn
;
return
false
;
}
}
auto
arg0
=
mdot
->
get_input_op
(
0
);
auto
arg0
=
mdot
->
get_input_op
(
0
);
...
@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
...
@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
tdot
=
std
::
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
auto
tdot
=
std
::
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
return
tdot
;
ngraph
::
replace_node
(
m
.
match_root
(),
tdot
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
...
...
src/ngraph/pattern/core_fusion.cpp
View file @
fc9018dc
...
@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
...
@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
pattern
::
gr_callback_fn
callback
=
[
val
,
zero
](
pattern
::
Matcher
&
m
)
{
pattern
::
gr_callback_fn
callback
=
[
val
,
zero
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_relu_pattern against "
NGRAPH_DEBUG
<<
"In a callback for construct_relu_pattern against "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
shared_ptr
<
Node
>
nn
;
auto
pattern_map
=
m
.
get_pattern_map
();
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
if
(
!
is_zero
(
mzero
))
if
(
!
is_zero
(
mzero
))
{
{
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
return
nn
;
return
false
;
}
}
auto
mpattern
=
m
.
match_root
();
auto
mpattern
=
m
.
match_root
();
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
return
cg
;
ngraph
::
replace_node
(
m
.
match_root
(),
cg
);
return
true
;
};
};
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
max
,
callback
);
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
max
,
callback
);
...
...
src/ngraph/pattern/matcher.cpp
100755 → 100644
View file @
fc9018dc
...
@@ -202,7 +202,7 @@ namespace ngraph
...
@@ -202,7 +202,7 @@ namespace ngraph
return
false
;
return
false
;
}
}
std
::
shared_ptr
<
Node
>
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
bool
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
{
{
gr_callback_fn
cb
=
m_callback
;
gr_callback_fn
cb
=
m_callback
;
if
(
callback
)
if
(
callback
)
...
...
src/ngraph/pattern/matcher.hpp
View file @
fc9018dc
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
namespace
pattern
namespace
pattern
{
{
using
gr_callback_fn
=
std
::
function
<
std
::
shared_ptr
<
Node
>
(
class
Matcher
&
m
)
>
;
using
gr_callback_fn
=
std
::
function
<
bool
(
class
Matcher
&
m
)
>
;
namespace
op
namespace
op
{
{
...
@@ -63,7 +63,7 @@ namespace ngraph
...
@@ -63,7 +63,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
/// \param graph_node is an input graph to be matched against
bool
match
(
const
std
::
shared_ptr
<
Node
>&
graph_node
);
bool
match
(
const
std
::
shared_ptr
<
Node
>&
graph_node
);
std
::
shared_ptr
<
Node
>
process_match
(
gr_callback_fn
callback
=
nullptr
);
bool
process_match
(
gr_callback_fn
callback
=
nullptr
);
void
reset
()
{}
void
reset
()
{}
std
::
shared_ptr
<
Node
>
pattern_node
()
{
return
m_pattern_node
;
}
std
::
shared_ptr
<
Node
>
pattern_node
()
{
return
m_pattern_node
;
}
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
fc9018dc
...
@@ -152,7 +152,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
...
@@ -152,7 +152,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
m_matmul
->
get_is_arg1_transposed
(),
m_matmul
->
get_is_arg1_transposed
(),
m_broadcast
->
get_broadcast_axes
());
m_broadcast
->
get_broadcast_axes
());
return
mmb
;
ngraph
::
replace_node
(
m
.
match_root
(),
mmb
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
padd
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
padd
,
callback
);
...
@@ -182,7 +183,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
...
@@ -182,7 +183,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_matmul_pattern against node = "
NGRAPH_DEBUG
<<
"In callback for construct_matmul_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
Node
>
nn
;
auto
mpattern
=
m
.
match_root
();
auto
mpattern
=
m
.
match_root
();
auto
dot
=
m
.
match_root
();
auto
dot
=
m
.
match_root
();
...
@@ -190,33 +190,33 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
...
@@ -190,33 +190,33 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
if
(
mpattern
->
get_element_type
()
!=
element
::
f32
)
if
(
mpattern
->
get_element_type
()
!=
element
::
f32
)
{
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
mpattern
->
get_name
()
<<
" type is not float!"
;
NGRAPH_DEBUG
<<
"mpattern = "
<<
mpattern
->
get_name
()
<<
" type is not float!"
;
return
nn
;
return
false
;
}
}
if
(
dot
->
get_shape
().
size
()
!=
2
)
if
(
dot
->
get_shape
().
size
()
!=
2
)
{
{
NGRAPH_DEBUG
<<
"dot = "
<<
dot
->
get_name
()
<<
" shape is not equal to 2!"
;
NGRAPH_DEBUG
<<
"dot = "
<<
dot
->
get_name
()
<<
" shape is not equal to 2!"
;
return
nn
;
return
false
;
}
}
if
(
shape_size
(
dot
->
get_shape
())
==
0
)
if
(
shape_size
(
dot
->
get_shape
())
==
0
)
{
{
NGRAPH_DEBUG
<<
"dot has a zero dimension"
;
NGRAPH_DEBUG
<<
"dot has a zero dimension"
;
return
nn
;
return
false
;
}
}
bool
transpose_w
=
false
;
bool
transpose_w
=
false
;
Shape
shape_arg0
{
pattern_map
[
W
]
->
get_shape
()};
Shape
shape_arg0
{
pattern_map
[
W
]
->
get_shape
()};
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
0
),
pattern_map
[
W
],
transpose_w
,
shape_arg0
))
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
0
),
pattern_map
[
W
],
transpose_w
,
shape_arg0
))
{
{
return
nn
;
return
false
;
}
}
bool
transpose_x
=
false
;
bool
transpose_x
=
false
;
Shape
shape_arg1
{
pattern_map
[
x
]
->
get_shape
()};
Shape
shape_arg1
{
pattern_map
[
x
]
->
get_shape
()};
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
1
),
pattern_map
[
x
],
transpose_x
,
shape_arg1
))
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
1
),
pattern_map
[
x
],
transpose_x
,
shape_arg1
))
{
{
return
nn
;
return
false
;
}
}
auto
cg
=
std
::
shared_ptr
<
Node
>
(
new
op
::
MatmulBias
(
pattern_map
[
W
],
auto
cg
=
std
::
shared_ptr
<
Node
>
(
new
op
::
MatmulBias
(
pattern_map
[
W
],
...
@@ -226,7 +226,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
...
@@ -226,7 +226,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
shape_arg1
,
shape_arg1
,
transpose_w
,
transpose_w
,
transpose_x
));
transpose_x
));
return
cg
;
ngraph
::
replace_node
(
mpattern
,
cg
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
pdot
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
pdot
,
callback
);
...
@@ -286,7 +288,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
...
@@ -286,7 +288,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_bn pattern against "
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_bn pattern against "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
std
::
shared_ptr
<
Node
>
nn
=
nullptr
;
//TODO - add assert's based on the matched node
//TODO - add assert's based on the matched node
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
NGRAPH_DEBUG
<<
"Input: "
<<
pattern_map
[
input
]
->
get_name
()
<<
" "
NGRAPH_DEBUG
<<
"Input: "
<<
pattern_map
[
input
]
->
get_name
()
<<
" "
...
@@ -306,7 +307,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
...
@@ -306,7 +307,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
{
{
NGRAPH_DEBUG
<<
"Input to bn doesnt not have a rank=4, so not fusing"
;
NGRAPH_DEBUG
<<
"Input to bn doesnt not have a rank=4, so not fusing"
;
return
nn
;
return
false
;
}
}
Shape
bn_output_shape
{
m
.
match_root
()
->
get_shape
()};
Shape
bn_output_shape
{
m
.
match_root
()
->
get_shape
()};
Shape
m_bn_mean_shape
{
pattern_map
[
mean_label
]
->
get_shape
()};
Shape
m_bn_mean_shape
{
pattern_map
[
mean_label
]
->
get_shape
()};
...
@@ -320,7 +321,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
...
@@ -320,7 +321,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto
normalized_output
=
std
::
shared_ptr
<
Node
>
(
new
op
::
GetOutputElement
(
bn_node
,
0
));
auto
normalized_output
=
std
::
shared_ptr
<
Node
>
(
new
op
::
GetOutputElement
(
bn_node
,
0
));
return
normalized_output
;
ngraph
::
replace_node
(
m
.
match_root
(),
normalized_output
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
add_beta
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
add_beta
,
callback
);
...
@@ -408,7 +410,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
...
@@ -408,7 +410,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
ngraph
::
pattern
::
gr_callback_fn
callback
=
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
pad_input
,
pad_value
,
pad_label
,
reshape_label
,
conv_filter
,
conv_label
](
[
pad_input
,
pad_value
,
pad_label
,
reshape_label
,
conv_filter
,
conv_label
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
...
@@ -420,8 +422,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
...
@@ -420,8 +422,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
pattern_map
[
reshape_label
]);
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
pattern_map
[
reshape_label
]);
const
auto
&
input_order
=
matched_reshape
->
get_input_order
();
const
auto
&
input_order
=
matched_reshape
->
get_input_order
();
auto
hoisted_reshape_output_shape
=
auto
hoisted_reshape_output_shape
=
apply_permutation
<
Shape
::
value_type
>
(
apply_permutation
<
Shape
::
value_type
>
(
pattern_map
[
pad_input
]
->
get_shape
(),
input_order
);
pattern_map
[
pad_input
]
->
get_shape
(),
input_order
);
auto
hoisted_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
auto
hoisted_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
pattern_map
[
pad_input
],
pattern_map
[
pad_input
],
...
@@ -436,7 +438,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
...
@@ -436,7 +438,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
input_order
[
0
],
input_order
[
0
],
input_order
[
1
]))
input_order
[
1
]))
{
{
return
nullptr
;
return
false
;
}
}
CoordinateDiff
padding_below
{
static_cast
<
CoordinateDiff
::
value_type
>
(
CoordinateDiff
padding_below
{
static_cast
<
CoordinateDiff
::
value_type
>
(
...
@@ -457,7 +459,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
...
@@ -457,7 +459,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
padding_above
,
padding_above
,
matched_conv
->
get_data_dilation_strides
());
matched_conv
->
get_data_dilation_strides
());
return
zero_padded_conv
;
ngraph
::
replace_node
(
m
.
match_root
(),
zero_padded_conv
);
return
true
;
};
};
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
...
@@ -483,8 +486,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
...
@@ -483,8 +486,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
auto
conv_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
conv
,
nullptr
,
NodeVector
{
conv
});
auto
conv_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
conv
,
nullptr
,
NodeVector
{
conv
});
ngraph
::
pattern
::
gr_callback_fn
callback
=
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
pad_input
,
pad_value
,
pad_label
,
conv_filter
,
conv_label
](
[
pad_input
,
pad_value
,
pad_label
,
conv_filter
,
conv_label
](
pattern
::
Matcher
&
m
)
{
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
...
@@ -501,7 +503,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
...
@@ -501,7 +503,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
0
,
0
,
1
))
1
))
{
{
return
nullptr
;
return
false
;
}
}
CoordinateDiff
padding_below
{
CoordinateDiff
padding_below
{
...
@@ -520,7 +522,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
...
@@ -520,7 +522,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
padding_above
,
padding_above
,
matched_conv
->
get_data_dilation_strides
());
matched_conv
->
get_data_dilation_strides
());
return
zero_padded_conv
;
ngraph
::
replace_node
(
m
.
match_root
(),
zero_padded_conv
);
return
true
;
};
};
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
...
@@ -541,8 +544,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
...
@@ -541,8 +544,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
auto
divide_1_over_exp
=
std
::
make_shared
<
op
::
Divide
>
(
broadcast_constant
,
add_exp
);
auto
divide_1_over_exp
=
std
::
make_shared
<
op
::
Divide
>
(
broadcast_constant
,
add_exp
);
//Define a call back that needs to called once the DFG matches the pattern
//Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
gr_callback_fn
callback
=
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
](
pattern
::
Matcher
&
m
)
{
[
input
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
...
@@ -550,18 +552,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
...
@@ -550,18 +552,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
{
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
return
nullptr
;
return
false
;
}
}
if
(
m
.
match_root
()
->
get_outputs
().
size
()
!=
pattern_map
[
input
]
->
get_outputs
().
size
())
if
(
m
.
match_root
()
->
get_outputs
().
size
()
!=
pattern_map
[
input
]
->
get_outputs
().
size
())
{
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
return
nullptr
;
return
false
;
}
}
auto
sigmoid_node
=
std
::
make_shared
<
op
::
Sigmoid
>
(
pattern_map
[
input
]);
auto
sigmoid_node
=
std
::
make_shared
<
op
::
Sigmoid
>
(
pattern_map
[
input
]);
return
sigmoid_node
;
ngraph
::
replace_node
(
m
.
match_root
(),
sigmoid_node
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
divide_1_over_exp
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
divide_1_over_exp
,
callback
);
...
@@ -593,26 +596,26 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
...
@@ -593,26 +596,26 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
auto
negtive_2
=
std
::
make_shared
<
op
::
Negative
>
(
multiply_2
);
auto
negtive_2
=
std
::
make_shared
<
op
::
Negative
>
(
multiply_2
);
//Define a call back that needs to called once the DFG matches the pattern
//Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
gr_callback_fn
callback
=
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
,
delta
](
pattern
::
Matcher
&
m
)
{
[
input
,
delta
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
{
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
return
nullptr
;
return
false
;
}
}
if
(
m
.
match_root
()
->
get_shape
().
size
()
!=
pattern_map
[
input
]
->
get_shape
().
size
())
if
(
m
.
match_root
()
->
get_shape
().
size
()
!=
pattern_map
[
input
]
->
get_shape
().
size
())
{
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
return
nullptr
;
return
false
;
}
}
auto
dsigmoid
=
auto
dsigmoid
=
std
::
make_shared
<
op
::
SigmoidBackprop
>
(
pattern_map
[
input
],
pattern_map
[
delta
]);
std
::
make_shared
<
op
::
SigmoidBackprop
>
(
pattern_map
[
input
],
pattern_map
[
delta
]);
return
dsigmoid
;
ngraph
::
replace_node
(
m
.
match_root
(),
dsigmoid
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
negtive_2
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
negtive_2
,
callback
);
...
@@ -641,7 +644,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
...
@@ -641,7 +644,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
NGRAPH_DEBUG
<<
"In callback for construct_conv_bias against node = "
NGRAPH_DEBUG
<<
"In callback for construct_conv_bias against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
Node
>
nn
;
auto
conv
=
std
::
dynamic_pointer_cast
<
op
::
Convolution
>
(
m
.
match_root
()
->
get_input_op
(
0
));
auto
conv
=
std
::
dynamic_pointer_cast
<
op
::
Convolution
>
(
m
.
match_root
()
->
get_input_op
(
0
));
if
(
conv
->
get_input_shape
(
0
).
size
()
==
4
)
if
(
conv
->
get_input_shape
(
0
).
size
()
==
4
)
...
@@ -658,17 +660,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
...
@@ -658,17 +660,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
auto
bias_reshape
=
auto
bias_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
bias
,
order
,
Shape
{
conv
->
get_input_shape
(
1
)[
0
]});
std
::
make_shared
<
op
::
Reshape
>
(
bias
,
order
,
Shape
{
conv
->
get_input_shape
(
1
)[
0
]});
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias_reshape
));
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias_reshape
));
return
conv_bias
;
ngraph
::
replace_node
(
m
.
match_root
(),
conv_bias
);
return
true
;
}
}
else
else
{
{
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias
));
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias
));
return
conv_bias
;
ngraph
::
replace_node
(
m
.
match_root
(),
conv_bias
);
return
true
;
}
}
}
}
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"conv_bias fusion skipped due to input rank size != 4."
;
<<
"conv_bias fusion skipped due to input rank size != 4."
;
return
std
::
shared_ptr
<
Node
>
(
nullptr
)
;
return
false
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
p_conv_bias
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
p_conv_bias
,
callback
);
...
...
test/pattern.cpp
View file @
fc9018dc
...
@@ -169,12 +169,11 @@ public:
...
@@ -169,12 +169,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
}
auto
const_values
=
const_node
->
get_vector
<
int32_t
>
();
auto
const_values
=
const_node
->
get_vector
<
int32_t
>
();
...
@@ -184,9 +183,11 @@ public:
...
@@ -184,9 +183,11 @@ public:
if
(
!
all_ones
)
if
(
!
all_ones
)
{
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 1"
;
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 1"
;
return
nn
;
return
false
;
}
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
*
iconst1
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
*
iconst1
,
callback
);
...
@@ -213,14 +214,11 @@ public:
...
@@ -213,14 +214,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
//ASSERT_NE(nullptr, const_node);
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
}
auto
const_values
=
const_node
->
get_vector
<
int
>
();
auto
const_values
=
const_node
->
get_vector
<
int
>
();
...
@@ -230,10 +228,11 @@ public:
...
@@ -230,10 +228,11 @@ public:
if
(
!
all_zeros
)
if
(
!
all_zeros
)
{
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 0"
;
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 0"
;
return
nn
;
return
false
;
}
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
+
iconst0
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
+
iconst0
,
callback
);
...
@@ -252,7 +251,9 @@ public:
...
@@ -252,7 +251,9 @@ public:
NGRAPH_DEBUG
<<
"reducee = "
<<
reducee
->
get_name
();
NGRAPH_DEBUG
<<
"reducee = "
<<
reducee
->
get_name
();
auto
sum
=
auto
sum
=
std
::
shared_ptr
<
ngraph
::
Node
>
(
new
op
::
Sum
(
reducee
,
reduce
->
get_reduction_axes
()));
std
::
shared_ptr
<
ngraph
::
Node
>
(
new
op
::
Sum
(
reducee
,
reduce
->
get_reduction_axes
()));
return
sum
;
ngraph
::
replace_node
(
m
.
match_root
(),
sum
);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
sum_pattern
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
sum_pattern
,
callback
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment