Commit fc9018dc authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

update GraphRewrite API (#686)

parent 9d84c439
......@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
<< node->get_name();
rewritten = true;
auto result = matcher->process_match();
if (result)
if (matcher->process_match())
{
f->replace_node(node, result);
//move onto the next node
break;
}
}
......
......@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
......@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if (r1->get_shape() != gop->get_shape())
{
NGRAPH_DEBUG << "Not a no-op; Shapes are different!";
return nn;
return false;
}
Shape do_r1(r1->get_shape().size());
......@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if (do_r1 != r1->get_input_order())
{
NGRAPH_DEBUG << "Not a no-op; Not in default input order!";
return nn;
return false;
}
return gop;
ngraph::replace_node(m.match_root(), gop);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback);
......@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op];
if (gop->get_shape() != m.match_root()->get_shape())
......@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<< "shape = " << vector_to_string(gop->get_shape());
NGRAPH_DEBUG << "match_root " << m.match_root()->get_name()
<< "shape = " << vector_to_string(m.match_root()->get_shape());
return nn;
return false;
}
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
......@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{
NGRAPH_DEBUG << "Two reshapes were removed!";
return gop;
ngraph::replace_node(m.match_root(), gop);
return true;
}
auto perm1 = apply_permutation(do_r1, r1->get_input_order());
......@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (perm2 == do_r1)
{
NGRAPH_DEBUG << "Two transposes were removed!";
return gop;
ngraph::replace_node(m.match_root(), gop);
return true;
}
return nn;
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
this->add_matcher(m);
......@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name();
std::shared_ptr<Node> nn;
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
//this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0})
{
NGRAPH_DEBUG << "Reshape isn't transpose. "
<< vector_to_string(mtranspose->get_input_order());
return nn;
return false;
}
auto mdot = mtranspose->get_input_op(0);
if (mdot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape());
return nn;
return false;
}
auto arg0 = mdot->get_input_op(0);
......@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
return tdot;
ngraph::replace_node(m.match_root(), tdot);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback);
......
......@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
shared_ptr<Node> nn;
auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero];
if (!is_zero(mzero))
{
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return nn;
return false;
}
auto mpattern = m.match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
return cg;
ngraph::replace_node(m.match_root(), cg);
return true;
};
auto m = make_shared<pattern::Matcher>(max, callback);
......
......@@ -202,7 +202,7 @@ namespace ngraph
return false;
}
std::shared_ptr<Node> Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
bool Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
{
gr_callback_fn cb = m_callback;
if (callback)
......
......@@ -32,7 +32,7 @@ namespace ngraph
namespace pattern
{
using gr_callback_fn = std::function<std::shared_ptr<Node>(class Matcher& m)>;
using gr_callback_fn = std::function<bool(class Matcher& m)>;
namespace op
{
......@@ -63,7 +63,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> process_match(gr_callback_fn callback = nullptr);
bool process_match(gr_callback_fn callback = nullptr);
void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
......
This diff is collapsed.
......@@ -169,12 +169,11 @@ public:
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return nn;
return false;
}
auto const_values = const_node->get_vector<int32_t>();
......@@ -184,9 +183,11 @@ public:
if (!all_ones)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return nn;
return false;
}
return pattern_map[pattern];
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return true;
};
auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
......@@ -213,14 +214,11 @@ public:
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
//ASSERT_NE(nullptr, const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return nn;
return false;
}
auto const_values = const_node->get_vector<int>();
......@@ -230,10 +228,11 @@ public:
if (!all_zeros)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return nn;
return false;
}
return pattern_map[pattern];
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return true;
};
auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
......@@ -252,7 +251,9 @@ public:
NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum =
std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
return sum;
ngraph::replace_node(m.match_root(), sum);
return true;
};
auto m = make_shared<TestMatcher>(sum_pattern, callback);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment