Commit fc9018dc authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

update GraphRewrite API (#686)

parent 9d84c439
...@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list( ...@@ -39,11 +39,8 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , " NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
<< node->get_name(); << node->get_name();
rewritten = true; rewritten = true;
auto result = matcher->process_match(); if (matcher->process_match())
if (result)
{ {
f->replace_node(node, result);
//move onto the next node
break; break;
} }
} }
......
...@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -63,8 +63,6 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = " NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op]; auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
...@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -72,7 +70,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if (r1->get_shape() != gop->get_shape()) if (r1->get_shape() != gop->get_shape())
{ {
NGRAPH_DEBUG << "Not a no-op; Shapes are different!"; NGRAPH_DEBUG << "Not a no-op; Shapes are different!";
return nn; return false;
} }
Shape do_r1(r1->get_shape().size()); Shape do_r1(r1->get_shape().size());
...@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -81,10 +79,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
if (do_r1 != r1->get_input_order()) if (do_r1 != r1->get_input_order())
{ {
NGRAPH_DEBUG << "Not a no-op; Not in default input order!"; NGRAPH_DEBUG << "Not a no-op; Not in default input order!";
return nn; return false;
} }
return gop; ngraph::replace_node(m.match_root(), gop);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback);
...@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -105,7 +104,6 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op]; auto gop = pattern_map[op];
if (gop->get_shape() != m.match_root()->get_shape()) if (gop->get_shape() != m.match_root()->get_shape())
...@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -115,7 +113,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
<< "shape = " << vector_to_string(gop->get_shape()); << "shape = " << vector_to_string(gop->get_shape());
NGRAPH_DEBUG << "match_root " << m.match_root()->get_name() NGRAPH_DEBUG << "match_root " << m.match_root()->get_name()
<< "shape = " << vector_to_string(m.match_root()->get_shape()); << "shape = " << vector_to_string(m.match_root()->get_shape());
return nn; return false;
} }
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
...@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -134,7 +132,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2) if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{ {
NGRAPH_DEBUG << "Two reshapes were removed!"; NGRAPH_DEBUG << "Two reshapes were removed!";
return gop; ngraph::replace_node(m.match_root(), gop);
return true;
} }
auto perm1 = apply_permutation(do_r1, r1->get_input_order()); auto perm1 = apply_permutation(do_r1, r1->get_input_order());
...@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -142,10 +141,11 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (perm2 == do_r1) if (perm2 == do_r1)
{ {
NGRAPH_DEBUG << "Two transposes were removed!"; NGRAPH_DEBUG << "Two transposes were removed!";
return gop; ngraph::replace_node(m.match_root(), gop);
return true;
} }
return nn; return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
this->add_matcher(m); this->add_matcher(m);
...@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -165,21 +165,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = " NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
std::shared_ptr<Node> nn;
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
//this also checks the rank //this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0}) if (mtranspose->get_input_order() != AxisVector{1, 0})
{ {
NGRAPH_DEBUG << "Reshape isn't transpose. " NGRAPH_DEBUG << "Reshape isn't transpose. "
<< vector_to_string(mtranspose->get_input_order()); << vector_to_string(mtranspose->get_input_order());
return nn; return false;
} }
auto mdot = mtranspose->get_input_op(0); auto mdot = mtranspose->get_input_op(0);
if (mdot->get_shape().size() != 2) if (mdot->get_shape().size() != 2)
{ {
NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape()); NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape());
return nn; return false;
} }
auto arg0 = mdot->get_input_op(0); auto arg0 = mdot->get_input_op(0);
...@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -191,7 +190,8 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape); auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0)); auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
return tdot; ngraph::replace_node(m.match_root(), tdot);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(preshape, callback);
......
...@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern() ...@@ -61,19 +61,19 @@ void pass::CoreFusion::construct_relu_pattern()
pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) { pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu_pattern against " NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
shared_ptr<Node> nn;
auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero]; auto mzero = m.get_pattern_map()[zero];
if (!is_zero(mzero)) if (!is_zero(mzero))
{ {
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n"; NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return nn; return false;
} }
auto mpattern = m.match_root(); auto mpattern = m.match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val])); auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
return cg; ngraph::replace_node(m.match_root(), cg);
return true;
}; };
auto m = make_shared<pattern::Matcher>(max, callback); auto m = make_shared<pattern::Matcher>(max, callback);
......
...@@ -202,7 +202,7 @@ namespace ngraph ...@@ -202,7 +202,7 @@ namespace ngraph
return false; return false;
} }
std::shared_ptr<Node> Matcher::process_match(::ngraph::pattern::gr_callback_fn callback) bool Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
{ {
gr_callback_fn cb = m_callback; gr_callback_fn cb = m_callback;
if (callback) if (callback)
......
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
namespace pattern namespace pattern
{ {
using gr_callback_fn = std::function<std::shared_ptr<Node>(class Matcher& m)>; using gr_callback_fn = std::function<bool(class Matcher& m)>;
namespace op namespace op
{ {
...@@ -63,7 +63,7 @@ namespace ngraph ...@@ -63,7 +63,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against /// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node); bool match(const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> process_match(gr_callback_fn callback = nullptr); bool process_match(gr_callback_fn callback = nullptr);
void reset() {} void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; } std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
......
...@@ -152,7 +152,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern() ...@@ -152,7 +152,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
m_matmul->get_is_arg1_transposed(), m_matmul->get_is_arg1_transposed(),
m_broadcast->get_broadcast_axes()); m_broadcast->get_broadcast_axes());
return mmb; ngraph::replace_node(m.match_root(), mmb);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
...@@ -182,7 +183,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern() ...@@ -182,7 +183,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn;
auto mpattern = m.match_root(); auto mpattern = m.match_root();
auto dot = m.match_root(); auto dot = m.match_root();
...@@ -190,33 +190,33 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern() ...@@ -190,33 +190,33 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
if (mpattern->get_element_type() != element::f32) if (mpattern->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn; return false;
} }
if (dot->get_shape().size() != 2) if (dot->get_shape().size() != 2)
{ {
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!"; NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn; return false;
} }
if (shape_size(dot->get_shape()) == 0) if (shape_size(dot->get_shape()) == 0)
{ {
NGRAPH_DEBUG << "dot has a zero dimension"; NGRAPH_DEBUG << "dot has a zero dimension";
return nn; return false;
} }
bool transpose_w = false; bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()}; Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0)) if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{ {
return nn; return false;
} }
bool transpose_x = false; bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()}; Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1)) if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{ {
return nn; return false;
} }
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W], auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
...@@ -226,7 +226,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern() ...@@ -226,7 +226,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
shape_arg1, shape_arg1,
transpose_w, transpose_w,
transpose_x)); transpose_x));
return cg;
ngraph::replace_node(mpattern, cg);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback);
...@@ -286,7 +288,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -286,7 +288,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
std::shared_ptr<Node> nn = nullptr;
//TODO - add assert's based on the matched node //TODO - add assert's based on the matched node
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " " NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " "
...@@ -306,7 +307,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -306,7 +307,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
if (pattern_map[input]->get_shape().size() != 4) if (pattern_map[input]->get_shape().size() != 4)
{ {
NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing"; NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing";
return nn; return false;
} }
Shape bn_output_shape{m.match_root()->get_shape()}; Shape bn_output_shape{m.match_root()->get_shape()};
Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()}; Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
...@@ -320,7 +321,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -320,7 +321,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto normalized_output = std::shared_ptr<Node>(new op::GetOutputElement(bn_node, 0)); auto normalized_output = std::shared_ptr<Node>(new op::GetOutputElement(bn_node, 0));
return normalized_output; ngraph::replace_node(m.match_root(), normalized_output);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
...@@ -408,7 +410,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -408,7 +410,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback =
[pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label]( [pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label](
pattern::Matcher& m) -> std::shared_ptr<Node> { pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]); auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]);
...@@ -420,8 +422,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -420,8 +422,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]); std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]);
const auto& input_order = matched_reshape->get_input_order(); const auto& input_order = matched_reshape->get_input_order();
auto hoisted_reshape_output_shape = auto hoisted_reshape_output_shape = apply_permutation<Shape::value_type>(
apply_permutation<Shape::value_type>(pattern_map[pad_input]->get_shape(), input_order); pattern_map[pad_input]->get_shape(), input_order);
auto hoisted_reshape = std::make_shared<op::Reshape>( auto hoisted_reshape = std::make_shared<op::Reshape>(
pattern_map[pad_input], pattern_map[pad_input],
...@@ -436,7 +438,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -436,7 +438,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
input_order[0], input_order[0],
input_order[1])) input_order[1]))
{ {
return nullptr; return false;
} }
CoordinateDiff padding_below{static_cast<CoordinateDiff::value_type>( CoordinateDiff padding_below{static_cast<CoordinateDiff::value_type>(
...@@ -457,7 +459,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -457,7 +459,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
return zero_padded_conv; ngraph::replace_node(m.match_root(), zero_padded_conv);
return true;
}; };
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback)); this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
...@@ -483,8 +486,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -483,8 +486,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback =
[pad_input, pad_value, pad_label, conv_filter, conv_label]( [pad_input, pad_value, pad_label, conv_filter, conv_label](pattern::Matcher& m) {
pattern::Matcher& m) -> std::shared_ptr<Node> {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]); auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]);
...@@ -501,7 +503,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -501,7 +503,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
0, 0,
1)) 1))
{ {
return nullptr; return false;
} }
CoordinateDiff padding_below{ CoordinateDiff padding_below{
...@@ -520,7 +522,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -520,7 +522,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
return zero_padded_conv; ngraph::replace_node(m.match_root(), zero_padded_conv);
return true;
}; };
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback)); this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
...@@ -541,8 +544,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid() ...@@ -541,8 +544,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp); auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback = [input](pattern::Matcher& m) {
[input](pattern::Matcher& m) -> std::shared_ptr<Node> {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -550,18 +552,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid() ...@@ -550,18 +552,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
if (m.match_root()->get_element_type() != element::f32) if (m.match_root()->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!";
return nullptr; return false;
} }
if (m.match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size()) if (m.match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!"; << "input= " << pattern_map[input]->get_name() << "size dont match!";
return nullptr; return false;
} }
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]); auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
return sigmoid_node; ngraph::replace_node(m.match_root(), sigmoid_node);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
...@@ -593,26 +596,26 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop() ...@@ -593,26 +596,26 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
auto negtive_2 = std::make_shared<op::Negative>(multiply_2); auto negtive_2 = std::make_shared<op::Negative>(multiply_2);
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback = [input, delta](pattern::Matcher& m) {
[input, delta](pattern::Matcher& m) -> std::shared_ptr<Node> {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
if (m.match_root()->get_element_type() != element::f32) if (m.match_root()->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!";
return nullptr; return false;
} }
if (m.match_root()->get_shape().size() != pattern_map[input]->get_shape().size()) if (m.match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!"; << "input= " << pattern_map[input]->get_name() << "size dont match!";
return nullptr; return false;
} }
auto dsigmoid = auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]); std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
return dsigmoid; ngraph::replace_node(m.match_root(), dsigmoid);
return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(negtive_2, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(negtive_2, callback);
...@@ -641,7 +644,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() ...@@ -641,7 +644,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = " NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn;
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0)); auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
if (conv->get_input_shape(0).size() == 4) if (conv->get_input_shape(0).size() == 4)
...@@ -658,17 +660,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() ...@@ -658,17 +660,19 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
auto bias_reshape = auto bias_reshape =
std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]}); std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape)); auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
return conv_bias; ngraph::replace_node(m.match_root(), conv_bias);
return true;
} }
else else
{ {
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias)); auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias; ngraph::replace_node(m.match_root(), conv_bias);
return true;
} }
} }
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "conv_bias fusion skipped due to input rank size != 4."; << "conv_bias fusion skipped due to input rank size != 4.";
return std::shared_ptr<Node>(nullptr); return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
......
...@@ -169,12 +169,11 @@ public: ...@@ -169,12 +169,11 @@ public:
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "Operands' types and/or shape don't match"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return nn; return false;
} }
auto const_values = const_node->get_vector<int32_t>(); auto const_values = const_node->get_vector<int32_t>();
...@@ -184,9 +183,11 @@ public: ...@@ -184,9 +183,11 @@ public:
if (!all_ones) if (!all_ones)
{ {
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return nn; return false;
} }
return pattern_map[pattern];
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return true;
}; };
auto m = make_shared<TestMatcher>(pattern * iconst1, callback); auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
...@@ -213,14 +214,11 @@ public: ...@@ -213,14 +214,11 @@ public:
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
//ASSERT_NE(nullptr, const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "Operands' types and/or shape don't match"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return nn; return false;
} }
auto const_values = const_node->get_vector<int>(); auto const_values = const_node->get_vector<int>();
...@@ -230,10 +228,11 @@ public: ...@@ -230,10 +228,11 @@ public:
if (!all_zeros) if (!all_zeros)
{ {
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return nn; return false;
} }
return pattern_map[pattern]; ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return true;
}; };
auto m = make_shared<TestMatcher>(pattern + iconst0, callback); auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
...@@ -252,7 +251,9 @@ public: ...@@ -252,7 +251,9 @@ public:
NGRAPH_DEBUG << "reducee = " << reducee->get_name(); NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = auto sum =
std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes())); std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
return sum;
ngraph::replace_node(m.match_root(), sum);
return true;
}; };
auto m = make_shared<TestMatcher>(sum_pattern, callback); auto m = make_shared<TestMatcher>(sum_pattern, callback);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment