Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
fc9018dc
Commit
fc9018dc
authored
Mar 20, 2018
by
Nick Korovaiko
Committed by
Scott Cyphers
Mar 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update GraphRewrite API (#686)
parent
9d84c439
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
34 additions
and
36 deletions
+34
-36
graph_rewrite.cpp
src/ngraph/pass/graph_rewrite.cpp
+1
-4
reshape_elimination.cpp
src/ngraph/pass/reshape_elimination.cpp
+14
-14
core_fusion.cpp
src/ngraph/pattern/core_fusion.cpp
+4
-4
matcher.cpp
src/ngraph/pattern/matcher.cpp
+1
-1
matcher.hpp
src/ngraph/pattern/matcher.hpp
+2
-2
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+0
-0
pattern.cpp
test/pattern.cpp
+12
-11
No files found.
src/ngraph/pass/graph_rewrite.cpp
View file @
fc9018dc
...
@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
...
@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
NGRAPH_DEBUG
<<
"Matcher "
<<
matcher
<<
" matched "
<<
node
<<
" , "
NGRAPH_DEBUG
<<
"Matcher "
<<
matcher
<<
" matched "
<<
node
<<
" , "
<<
node
->
get_name
();
<<
node
->
get_name
();
rewritten
=
true
;
rewritten
=
true
;
auto
result
=
matcher
->
process_match
();
if
(
matcher
->
process_match
())
if
(
result
)
{
{
f
->
replace_node
(
node
,
result
);
//move onto the next node
break
;
break
;
}
}
}
}
...
...
src/ngraph/pass/reshape_elimination.cpp
View file @
fc9018dc
...
@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
NGRAPH_DEBUG
<<
"In callback for construct_identity_reshape_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
auto
gop
=
pattern_map
[
op
];
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
r1
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
if
(
r1
->
get_shape
()
!=
gop
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Not a no-op; Shapes are different!"
;
NGRAPH_DEBUG
<<
"Not a no-op; Shapes are different!"
;
return
nn
;
return
false
;
}
}
Shape
do_r1
(
r1
->
get_shape
().
size
());
Shape
do_r1
(
r1
->
get_shape
().
size
());
...
@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
...
@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if
(
do_r1
!=
r1
->
get_input_order
())
if
(
do_r1
!=
r1
->
get_input_order
())
{
{
NGRAPH_DEBUG
<<
"Not a no-op; Not in default input order!"
;
NGRAPH_DEBUG
<<
"Not a no-op; Not in default input order!"
;
return
nn
;
return
false
;
}
}
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape1
,
callback
);
...
@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
auto
pattern_map
=
m
.
get_pattern_map
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
;
auto
gop
=
pattern_map
[
op
];
auto
gop
=
pattern_map
[
op
];
if
(
gop
->
get_shape
()
!=
m
.
match_root
()
->
get_shape
())
if
(
gop
->
get_shape
()
!=
m
.
match_root
()
->
get_shape
())
...
@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<<
"shape = "
<<
vector_to_string
(
gop
->
get_shape
());
<<
"shape = "
<<
vector_to_string
(
gop
->
get_shape
());
NGRAPH_DEBUG
<<
"match_root "
<<
m
.
match_root
()
->
get_name
()
NGRAPH_DEBUG
<<
"match_root "
<<
m
.
match_root
()
->
get_name
()
<<
"shape = "
<<
vector_to_string
(
m
.
match_root
()
->
get_shape
());
<<
"shape = "
<<
vector_to_string
(
m
.
match_root
()
->
get_shape
());
return
nn
;
return
false
;
}
}
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
r2
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
...
@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
if
(
r1
->
get_input_order
()
==
do_r1
&&
r2
->
get_input_order
()
==
do_r2
)
{
{
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
NGRAPH_DEBUG
<<
"Two reshapes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
}
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
auto
perm1
=
apply_permutation
(
do_r1
,
r1
->
get_input_order
());
...
@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
...
@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if
(
perm2
==
do_r1
)
if
(
perm2
==
do_r1
)
{
{
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
NGRAPH_DEBUG
<<
"Two transposes were removed!"
;
return
gop
;
ngraph
::
replace_node
(
m
.
match_root
(),
gop
);
return
true
;
}
}
return
nn
;
return
false
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
reshape2
,
callback
);
this
->
add_matcher
(
m
);
this
->
add_matcher
(
m
);
...
@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
...
@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
NGRAPH_DEBUG
<<
"In callback for construct_dot_transpose_pattern against node = "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
std
::
shared_ptr
<
Node
>
nn
;
auto
mtranspose
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
auto
mtranspose
=
std
::
dynamic_pointer_cast
<
op
::
Reshape
>
(
m
.
match_root
());
//this also checks the rank
//this also checks the rank
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
if
(
mtranspose
->
get_input_order
()
!=
AxisVector
{
1
,
0
})
{
{
NGRAPH_DEBUG
<<
"Reshape isn't transpose. "
NGRAPH_DEBUG
<<
"Reshape isn't transpose. "
<<
vector_to_string
(
mtranspose
->
get_input_order
());
<<
vector_to_string
(
mtranspose
->
get_input_order
());
return
nn
;
return
false
;
}
}
auto
mdot
=
mtranspose
->
get_input_op
(
0
);
auto
mdot
=
mtranspose
->
get_input_op
(
0
);
if
(
mdot
->
get_shape
().
size
()
!=
2
)
if
(
mdot
->
get_shape
().
size
()
!=
2
)
{
{
NGRAPH_DEBUG
<<
"Dot has the wrong shape. "
<<
vector_to_string
(
mdot
->
get_shape
());
NGRAPH_DEBUG
<<
"Dot has the wrong shape. "
<<
vector_to_string
(
mdot
->
get_shape
());
return
nn
;
return
false
;
}
}
auto
arg0
=
mdot
->
get_input_op
(
0
);
auto
arg0
=
mdot
->
get_input_op
(
0
);
...
@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
...
@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
reshape1
=
std
::
make_shared
<
op
::
Reshape
>
(
arg1
,
AxisVector
{
1
,
0
},
reshape1_shape
);
auto
tdot
=
std
::
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
auto
tdot
=
std
::
shared_ptr
<
Node
>
(
new
op
::
Dot
(
reshape1
,
reshape0
));
return
tdot
;
ngraph
::
replace_node
(
m
.
match_root
(),
tdot
);
return
true
;
};
};
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
auto
m
=
std
::
make_shared
<
ngraph
::
pattern
::
Matcher
>
(
preshape
,
callback
);
...
...
src/ngraph/pattern/core_fusion.cpp
View file @
fc9018dc
...
@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
...
@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
pattern
::
gr_callback_fn
callback
=
[
val
,
zero
](
pattern
::
Matcher
&
m
)
{
pattern
::
gr_callback_fn
callback
=
[
val
,
zero
](
pattern
::
Matcher
&
m
)
{
NGRAPH_DEBUG
<<
"In a callback for construct_relu_pattern against "
NGRAPH_DEBUG
<<
"In a callback for construct_relu_pattern against "
<<
m
.
match_root
()
->
get_name
();
<<
m
.
match_root
()
->
get_name
();
auto
pattern_map
=
m
.
get_pattern_map
();
shared_ptr
<
Node
>
nn
;
auto
pattern_map
=
m
.
get_pattern_map
();
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
auto
mzero
=
m
.
get_pattern_map
()[
zero
];
if
(
!
is_zero
(
mzero
))
if
(
!
is_zero
(
mzero
))
{
{
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
NGRAPH_DEBUG
<<
"zero constant = "
<<
mzero
->
get_name
()
<<
" not equal to 0
\n
"
;
return
nn
;
return
false
;
}
}
auto
mpattern
=
m
.
match_root
();
auto
mpattern
=
m
.
match_root
();
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
auto
cg
=
shared_ptr
<
Node
>
(
new
op
::
Relu
(
pattern_map
[
val
]));
return
cg
;
ngraph
::
replace_node
(
m
.
match_root
(),
cg
);
return
true
;
};
};
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
max
,
callback
);
auto
m
=
make_shared
<
pattern
::
Matcher
>
(
max
,
callback
);
...
...
src/ngraph/pattern/matcher.cpp
100755 → 100644
View file @
fc9018dc
...
@@ -202,7 +202,7 @@ namespace ngraph
...
@@ -202,7 +202,7 @@ namespace ngraph
return
false
;
return
false
;
}
}
std
::
shared_ptr
<
Node
>
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
bool
Matcher
::
process_match
(
::
ngraph
::
pattern
::
gr_callback_fn
callback
)
{
{
gr_callback_fn
cb
=
m_callback
;
gr_callback_fn
cb
=
m_callback
;
if
(
callback
)
if
(
callback
)
...
...
src/ngraph/pattern/matcher.hpp
View file @
fc9018dc
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
namespace
pattern
namespace
pattern
{
{
using
gr_callback_fn
=
std
::
function
<
std
::
shared_ptr
<
Node
>
(
class
Matcher
&
m
)
>
;
using
gr_callback_fn
=
std
::
function
<
bool
(
class
Matcher
&
m
)
>
;
namespace
op
namespace
op
{
{
...
@@ -63,7 +63,7 @@ namespace ngraph
...
@@ -63,7 +63,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
/// \param graph_node is an input graph to be matched against
bool
match
(
const
std
::
shared_ptr
<
Node
>&
graph_node
);
bool
match
(
const
std
::
shared_ptr
<
Node
>&
graph_node
);
std
::
shared_ptr
<
Node
>
process_match
(
gr_callback_fn
callback
=
nullptr
);
bool
process_match
(
gr_callback_fn
callback
=
nullptr
);
void
reset
()
{}
void
reset
()
{}
std
::
shared_ptr
<
Node
>
pattern_node
()
{
return
m_pattern_node
;
}
std
::
shared_ptr
<
Node
>
pattern_node
()
{
return
m_pattern_node
;
}
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
fc9018dc
This diff is collapsed.
Click to expand it.
test/pattern.cpp
View file @
fc9018dc
...
@@ -169,12 +169,11 @@ public:
...
@@ -169,12 +169,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
}
auto
const_values
=
const_node
->
get_vector
<
int32_t
>
();
auto
const_values
=
const_node
->
get_vector
<
int32_t
>
();
...
@@ -184,9 +183,11 @@ public:
...
@@ -184,9 +183,11 @@ public:
if
(
!
all_ones
)
if
(
!
all_ones
)
{
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 1"
;
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 1"
;
return
nn
;
return
false
;
}
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
*
iconst1
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
*
iconst1
,
callback
);
...
@@ -213,14 +214,11 @@ public:
...
@@ -213,14 +214,11 @@ public:
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
NGRAPH_DEBUG
<<
"second_node = "
<<
second_node
->
get_name
()
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
<<
" , pattern = "
<<
pattern_map
[
pattern
]
->
get_name
();
//ASSERT_NE(nullptr, const_node);
std
::
shared_ptr
<
ngraph
::
Node
>
nn
=
nullptr
;
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
if
(
pattern_map
[
pattern
]
->
get_element_type
()
!=
const_node
->
get_element_type
()
||
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
pattern_map
[
pattern
]
->
get_shape
()
!=
const_node
->
get_shape
())
{
{
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
NGRAPH_DEBUG
<<
"Operands' types and/or shape don't match"
;
return
nn
;
return
false
;
}
}
auto
const_values
=
const_node
->
get_vector
<
int
>
();
auto
const_values
=
const_node
->
get_vector
<
int
>
();
...
@@ -230,10 +228,11 @@ public:
...
@@ -230,10 +228,11 @@ public:
if
(
!
all_zeros
)
if
(
!
all_zeros
)
{
{
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 0"
;
NGRAPH_DEBUG
<<
"Constant vector's values aren't equal to 0"
;
return
nn
;
return
false
;
}
}
return
pattern_map
[
pattern
];
ngraph
::
replace_node
(
m
.
match_root
(),
pattern_map
[
pattern
]);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
+
iconst0
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
pattern
+
iconst0
,
callback
);
...
@@ -252,7 +251,9 @@ public:
...
@@ -252,7 +251,9 @@ public:
NGRAPH_DEBUG
<<
"reducee = "
<<
reducee
->
get_name
();
NGRAPH_DEBUG
<<
"reducee = "
<<
reducee
->
get_name
();
auto
sum
=
auto
sum
=
std
::
shared_ptr
<
ngraph
::
Node
>
(
new
op
::
Sum
(
reducee
,
reduce
->
get_reduction_axes
()));
std
::
shared_ptr
<
ngraph
::
Node
>
(
new
op
::
Sum
(
reducee
,
reduce
->
get_reduction_axes
()));
return
sum
;
ngraph
::
replace_node
(
m
.
match_root
(),
sum
);
return
true
;
};
};
auto
m
=
make_shared
<
TestMatcher
>
(
sum_pattern
,
callback
);
auto
m
=
make_shared
<
TestMatcher
>
(
sum_pattern
,
callback
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment