Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
fc9018dc
Commit
fc9018dc
authored
Mar 20, 2018
by
Nick Korovaiko
Committed by
Scott Cyphers
Mar 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update GraphRewrite API (#686)
parent
9d84c439
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
72 additions
and
70 deletions
+72
-70
graph_rewrite.cpp
src/ngraph/pass/graph_rewrite.cpp
+1
-4
reshape_elimination.cpp
src/ngraph/pass/reshape_elimination.cpp
+14
-14
core_fusion.cpp
src/ngraph/pattern/core_fusion.cpp
+4
-4
matcher.cpp
src/ngraph/pattern/matcher.cpp
+1
-1
matcher.hpp
src/ngraph/pattern/matcher.hpp
+2
-2
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+38
-34
pattern.cpp
test/pattern.cpp
+12
-11
No files found.
src/ngraph/pass/graph_rewrite.cpp
View file @
fc9018dc
...
...
@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
NGRAPH_DEBUG
<<
"Matcher "
<<
matcher
<<
" matched "
<<
node
<<
" , "
<<
node
->
get_name
();
rewritten
=
true
;
auto
result
=
matcher
->
process_match
();
if
(
result
)
if
(
matcher
->
process_match
())
{
f
->
replace_node
(
node
,
result
);
//move onto the next node
break
;
}
}
...
...
src/ngraph/pass/reshape_elimination.cpp
View file @
fc9018dc
...
...
@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
...
@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
{
NGRAPH_DEBUG
<<
"Not a no-op; Shapes are different!"
;
return
nn
;
return
false
;
}
Shape
do_r1
(
r1
->
get_shape
().
size
());
...
...
@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
do_r1
!=
r1
->
get_input_order
())
{
NGRAPH_DEBUG
<<
"Not a no-op; Not in default input order!"
;
return
nn
;
return
false
;
}
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
...
...
@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
if
(
gop
->
get_shape
()
!=
m
.
match_root
()
->
get_shape
())
...
...
@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
"shape = "
<<
vector_to_string
(
gop
->
get_shape
());
NGRAPH_DEBUG
<<
"match_root "
<<
m
.
match_root
()
->
get_name
()
<<
"shape = "
<<
vector_to_string
(
m
.
match_root
()
->
get_shape
());
return
nn
;
return
false
;
}
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
...
@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
{
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
...
...
@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
perm2
==
do_r1
)
{
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
return
nn
;
return
false
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
this
->
add_matcher
(
m
);
...
...
@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
std
::
shared_ptr
<
Node
>
nn
;
auto
mtranspose
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
//this also checks the rank
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
{
NGRAPH_DEBUG
<<
"Reshape isn't transpose. "
<<
vector_to_string
(
mtranspose
->
get_input_order
());
return
nn
;
return
false
;
}
auto
mdot
=
mtranspose
->
get_input_op
(
0
);
if
(
mdot
->
get_shape
().
size
()
!=
2
)
{
NGRAPH_DEBUG
<<
"Dot has the wrong shape. "
<<
vector_to_string
(
mdot
->
get_shape
());
return
nn
;
return
false
;
}
auto
arg0
=
mdot
->
get_input_op
(
0
);
...
...
@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
tdot
=
std
::
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
return
tdot
;
ngraph
::
replace_node
(
m
.
match_root
(),
tdot
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
...
...
src/ngraph/pattern/core_fusion.cpp
View file @
fc9018dc
...
...
@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
pattern
::
gr_callback_fn
callback
=
[
val
,
zero
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_relu_pattern against "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
shared_ptr
<
Node
>
nn
;
auto
pattern_map
=
m
.
get_pattern_map
();
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
if
(
!
is_zero
(
mzero
))
{
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
return
nn
;
return
false
;
}
auto
mpattern
=
m
.
match_root
();
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
return
cg
;
ngraph
::
replace_node
(
m
.
match_root
(),
cg
);
return
true
;
};
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
max
,
callback
);
...
...
src/ngraph/pattern/matcher.cpp
100755 → 100644
View file @
fc9018dc
...
...
@@ -202,7 +202,7 @@ namespace ngraph
return
false
;
}
std
::
shared_ptr
<
Node
>
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
bool
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
{
gr_callback_fn
cb
=
m_callback
;
if
(
callback
)
...
...
src/ngraph/pattern/matcher.hpp
View file @
fc9018dc
...
...
@@ -32,7 +32,7 @@ namespace ngraph
namespace
pattern
{
using
gr_callback_fn
=
std
::
function
<
std
::
shared_ptr
<
Node
>
(
class
Matcher
&
m
)
>
;
using
gr_callback_fn
=
std
::
function
<
bool
(
class
Matcher
&
m
)
>
;
namespace
op
{
...
...
@@ -63,7 +63,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
bool
match
(
const
std
::
shared_ptr
<
Node
>&
graph_node
);
std
::
shared_ptr
<
Node
>
process_match
(
gr_callback_fn
callback
=
nullptr
);
bool
process_match
(
gr_callback_fn
callback
=
nullptr
);
void
reset
()
{}
std
::
shared_ptr
<
Node
>
pattern_node
()
{
return
m_pattern_node
;
}
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
fc9018dc
...
...
@@ -152,7 +152,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
m_matmul
->
get_is_arg1_transposed
(),
m_broadcast
->
get_broadcast_axes
());
return
mmb
;
ngraph
::
replace_node
(
m
.
match_root
(),
mmb
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
padd
,
callback
);
...
...
@@ -182,7 +183,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_matmul_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
Node
>
nn
;
auto
mpattern
=
m
.
match_root
();
auto
dot
=
m
.
match_root
();
...
...
@@ -190,33 +190,33 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
if
(
mpattern
->
get_element_type
()
!=
element
::
f32
)
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
mpattern
->
get_name
()
<<
" type is not float!"
;
return
nn
;
return
false
;
}
if
(
dot
->
get_shape
().
size
()
!=
2
)
{
NGRAPH_DEBUG
<<
"dot = "
<<
dot
->
get_name
()
<<
" shape is not equal to 2!"
;
return
nn
;
return
false
;
}
if
(
shape_size
(
dot
->
get_shape
())
==
0
)
{
NGRAPH_DEBUG
<<
"dot has a zero dimension"
;
return
nn
;
return
false
;
}
bool
transpose_w
=
false
;
Shape
shape_arg0
{
pattern_map
[
W
]
->
get_shape
()};
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
0
),
pattern_map
[
W
],
transpose_w
,
shape_arg0
))
{
return
nn
;
return
false
;
}
bool
transpose_x
=
false
;
Shape
shape_arg1
{
pattern_map
[
x
]
->
get_shape
()};
if
(
!
init_cblas_arg
(
dot
->
get_input_op
(
1
),
pattern_map
[
x
],
transpose_x
,
shape_arg1
))
{
return
nn
;
return
false
;
}
auto
cg
=
std
::
shared_ptr
<
Node
>
(
new
op
::
MatmulBias
(
pattern_map
[
W
],
...
...
@@ -226,7 +226,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
shape_arg1
,
transpose_w
,
transpose_x
));
return
cg
;
ngraph
::
replace_node
(
mpattern
,
cg
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
pdot
,
callback
);
...
...
@@ -286,7 +288,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_bn pattern against "
<<
m
.
match_root
()
->
get_name
();
std
::
shared_ptr
<
Node
>
nn
=
nullptr
;
//TODO - add assert's based on the matched node
auto
pattern_map
=
m
.
get_pattern_map
();
NGRAPH_DEBUG
<<
"Input: "
<<
pattern_map
[
input
]
->
get_name
()
<<
" "
...
...
@@ -306,7 +307,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
if
(
pattern_map
[
input
]
->
get_shape
().
size
()
!=
4
)
{
NGRAPH_DEBUG
<<
"Input to bn doesnt not have a rank=4, so not fusing"
;
return
nn
;
return
false
;
}
Shape
bn_output_shape
{
m
.
match_root
()
->
get_shape
()};
Shape
m_bn_mean_shape
{
pattern_map
[
mean_label
]
->
get_shape
()};
...
...
@@ -320,7 +321,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto
normalized_output
=
std
::
shared_ptr
<
Node
>
(
new
op
::
GetOutputElement
(
bn_node
,
0
));
return
normalized_output
;
ngraph
::
replace_node
(
m
.
match_root
(),
normalized_output
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
add_beta
,
callback
);
...
...
@@ -408,7 +410,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
pad_input
,
pad_value
,
pad_label
,
reshape_label
,
conv_filter
,
conv_label
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
...
...
@@ -420,8 +422,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
pattern_map
[
reshape_label
]);
const
auto
&
input_order
=
matched_reshape
->
get_input_order
();
auto
hoisted_reshape_output_shape
=
apply_permutation
<
Shape
::
value_type
>
(
pattern_map
[
pad_input
]
->
get_shape
(),
input_order
);
auto
hoisted_reshape_output_shape
=
apply_permutation
<
Shape
::
value_type
>
(
pattern_map
[
pad_input
]
->
get_shape
(),
input_order
);
auto
hoisted_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
pattern_map
[
pad_input
],
...
...
@@ -436,7 +438,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
input_order
[
0
],
input_order
[
1
]))
{
return
nullptr
;
return
false
;
}
CoordinateDiff
padding_below
{
static_cast
<
CoordinateDiff
::
value_type
>
(
...
...
@@ -457,7 +459,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
padding_above
,
matched_conv
->
get_data_dilation_strides
());
return
zero_padded_conv
;
ngraph
::
replace_node
(
m
.
match_root
(),
zero_padded_conv
);
return
true
;
};
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
...
...
@@ -483,8 +486,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
auto
conv_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
conv
,
nullptr
,
NodeVector
{
conv
});
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
pad_input
,
pad_value
,
pad_label
,
conv_filter
,
conv_label
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
[
pad_input
,
pad_value
,
pad_label
,
conv_filter
,
conv_label
](
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pad_value_op
=
std
::
dynamic_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
pad_value
]);
...
...
@@ -501,7 +503,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
0
,
1
))
{
return
nullptr
;
return
false
;
}
CoordinateDiff
padding_below
{
...
...
@@ -520,7 +522,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
padding_above
,
matched_conv
->
get_data_dilation_strides
());
return
zero_padded_conv
;
ngraph
::
replace_node
(
m
.
match_root
(),
zero_padded_conv
);
return
true
;
};
this
->
add_matcher
(
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
conv_label
,
callback
));
...
...
@@ -541,8 +544,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
auto
divide_1_over_exp
=
std
::
make_shared
<
op
::
Divide
>
(
broadcast_constant
,
add_exp
);
//Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
...
...
@@ -550,18 +552,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
return
nullptr
;
return
false
;
}
if
(
m
.
match_root
()
->
get_outputs
().
size
()
!=
pattern_map
[
input
]
->
get_outputs
().
size
())
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
return
nullptr
;
return
false
;
}
auto
sigmoid_node
=
std
::
make_shared
<
op
::
Sigmoid
>
(
pattern_map
[
input
]);
return
sigmoid_node
;
ngraph
::
replace_node
(
m
.
match_root
(),
sigmoid_node
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
divide_1_over_exp
,
callback
);
...
...
@@ -593,26 +596,26 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
auto
negtive_2
=
std
::
make_shared
<
op
::
Negative
>
(
multiply_2
);
//Define a call back that needs to called once the DFG matches the pattern
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
,
delta
](
pattern
::
Matcher
&
m
)
->
std
::
shared_ptr
<
Node
>
{
ngraph
::
pattern
::
gr_callback_fn
callback
=
[
input
,
delta
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_fprop_sigmoid pattern against "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
if
(
m
.
match_root
()
->
get_element_type
()
!=
element
::
f32
)
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
" type is not float!"
;
return
nullptr
;
return
false
;
}
if
(
m
.
match_root
()
->
get_shape
().
size
()
!=
pattern_map
[
input
]
->
get_shape
().
size
())
{
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"input= "
<<
pattern_map
[
input
]
->
get_name
()
<<
"size dont match!"
;
return
nullptr
;
return
false
;
}
auto
dsigmoid
=
std
::
make_shared
<
op
::
SigmoidBackprop
>
(
pattern_map
[
input
],
pattern_map
[
delta
]);
return
dsigmoid
;
ngraph
::
replace_node
(
m
.
match_root
(),
dsigmoid
);
return
true
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
negtive_2
,
callback
);
...
...
@@ -641,7 +644,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
NGRAPH_DEBUG
<<
"In callback for construct_conv_bias against node = "
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
Node
>
nn
;
auto
conv
=
std
::
dynamic_pointer_cast
<
op
::
Convolution
>
(
m
.
match_root
()
->
get_input_op
(
0
));
if
(
conv
->
get_input_shape
(
0
).
size
()
==
4
)
...
...
@@ -658,17 +660,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
auto
bias_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
bias
,
order
,
Shape
{
conv
->
get_input_shape
(
1
)[
0
]});
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias_reshape
));
return
conv_bias
;
ngraph
::
replace_node
(
m
.
match_root
(),
conv_bias
);
return
true
;
}
else
{
auto
conv_bias
=
std
::
shared_ptr
<
Node
>
(
new
op
::
ConvolutionBias
(
conv
,
bias
));
return
conv_bias
;
ngraph
::
replace_node
(
m
.
match_root
(),
conv_bias
);
return
true
;
}
}
NGRAPH_DEBUG
<<
"mpattern = "
<<
m
.
match_root
()
->
get_name
()
<<
"conv_bias fusion skipped due to input rank size != 4."
;
return
std
::
shared_ptr
<
Node
>
(
nullptr
)
;
return
false
;
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
p_conv_bias
,
callback
);
...
...
test/pattern.cpp
View file @
fc9018dc
...
...
@@ -169,12 +169,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
auto
const_values
=
const_node
->
get_vector
<
int32_t
>
();
...
...
@@ -184,9 +183,11 @@ public:
if
(
!
all_ones
)
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 1"
;
return
nn
;
return
false
;
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
*
iconst1
,
callback
);
...
...
@@ -213,14 +214,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
//ASSERT_NE(nullptr, const_node);
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
auto
const_values
=
const_node
->
get_vector
<
int
>
();
...
...
@@ -230,10 +228,11 @@ public:
if
(
!
all_zeros
)
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 0"
;
return
nn
;
return
false
;
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
+
iconst0
,
callback
);
...
...
@@ -252,7 +251,9 @@ public:
NGRAPH_DEBUG
<<
"reducee = "
<<
reducee
->
get_name
();
auto
sum
=
std
::
shared_ptr
<
ngraph
::
Node
>
(
new
op
::
Sum
(
reducee
,
reduce
->
get_reduction_axes
()));
return
sum
;
ngraph
::
replace_node
(
m
.
match_root
(),
sum
);
return
true
;
};
auto
m
=
make_shared
<
TestMatcher
>
(
sum_pattern
,
callback
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment