Commit 606ad20b authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Fix Matcher's getters to adhere to ngraph coding guidelines (#916)

* rename getters to adhere to ngraph coding guidelines

* fix renaminb

* fix build errors
parent 4f2316e8
...@@ -52,7 +52,7 @@ static std::shared_ptr<pattern::Matcher> ...@@ -52,7 +52,7 @@ static std::shared_ptr<pattern::Matcher>
static std::shared_ptr<pattern::op::Label> static std::shared_ptr<pattern::op::Label>
get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher) get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher)
{ {
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->pattern_node()->get_argument(1)); return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
} }
//`simplify_multiply` optimizes the following 4 *base* cases //`simplify_multiply` optimizes the following 4 *base* cases
......
...@@ -53,7 +53,8 @@ void pass::CoreFusion::construct_relu() ...@@ -53,7 +53,8 @@ void pass::CoreFusion::construct_relu()
auto max = make_shared<op::Maximum>(skip_broadcast, val); auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu against " << m.match_root()->get_name(); NGRAPH_DEBUG << "In a callback for construct_relu against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero]; auto mzero = m.get_pattern_map()[zero];
...@@ -62,10 +63,10 @@ void pass::CoreFusion::construct_relu() ...@@ -62,10 +63,10 @@ void pass::CoreFusion::construct_relu()
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n"; NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return false; return false;
} }
auto mpattern = m.match_root(); auto mpattern = m.get_match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val])); auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
ngraph::replace_node(m.match_root(), cg); ngraph::replace_node(m.get_match_root(), cg);
return true; return true;
}; };
......
...@@ -61,11 +61,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -61,11 +61,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
auto callback = [op](pattern::Matcher& m) { auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = " NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op]; auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
if (r1->get_shape() != gop->get_shape()) if (r1->get_shape() != gop->get_shape())
{ {
...@@ -82,7 +82,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -82,7 +82,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false; return false;
} }
ngraph::replace_node(m.match_root(), gop); ngraph::replace_node(m.get_match_root(), gop);
return true; return true;
}; };
...@@ -101,22 +101,22 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -101,22 +101,22 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
auto callback = [op](pattern::Matcher& m) { auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = " NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op]; auto gop = pattern_map[op];
if (gop->get_shape() != m.match_root()->get_shape()) if (gop->get_shape() != m.get_match_root()->get_shape())
{ {
NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!"; NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!";
NGRAPH_DEBUG << "gop " << gop->get_name() NGRAPH_DEBUG << "gop " << gop->get_name()
<< "shape = " << vector_to_string(gop->get_shape()); << "shape = " << vector_to_string(gop->get_shape());
NGRAPH_DEBUG << "match_root " << m.match_root()->get_name() NGRAPH_DEBUG << "match_root " << m.get_match_root()->get_name()
<< "shape = " << vector_to_string(m.match_root()->get_shape()); << "shape = " << vector_to_string(m.get_match_root()->get_shape());
return false; return false;
} }
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0)); auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
Shape do_r2(r1->get_shape().size()); Shape do_r2(r1->get_shape().size());
...@@ -132,7 +132,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -132,7 +132,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2) if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{ {
NGRAPH_DEBUG << "Two reshapes were removed!"; NGRAPH_DEBUG << "Two reshapes were removed!";
ngraph::replace_node(m.match_root(), gop); ngraph::replace_node(m.get_match_root(), gop);
return true; return true;
} }
...@@ -141,7 +141,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -141,7 +141,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (perm2 == do_r1) if (perm2 == do_r1)
{ {
NGRAPH_DEBUG << "Two transposes were removed!"; NGRAPH_DEBUG << "Two transposes were removed!";
ngraph::replace_node(m.match_root(), gop); ngraph::replace_node(m.get_match_root(), gop);
return true; return true;
} }
...@@ -163,9 +163,9 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -163,9 +163,9 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = " NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root()); auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
//this also checks the rank //this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0}) if (mtranspose->get_input_order() != AxisVector{1, 0})
{ {
...@@ -190,7 +190,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -190,7 +190,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape); auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0)); auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
ngraph::replace_node(m.match_root(), tdot); ngraph::replace_node(m.get_match_root(), tdot);
return true; return true;
}; };
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace pattern namespace pattern
{ {
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; } std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label, bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) PatternMap& pattern_map)
......
...@@ -95,8 +95,8 @@ namespace ngraph ...@@ -95,8 +95,8 @@ namespace ngraph
bool process_match(graph_rewrite_callback callback = nullptr); bool process_match(graph_rewrite_callback callback = nullptr);
void reset() {} void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; } std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
std::shared_ptr<Node> match_root(); std::shared_ptr<Node> get_match_root();
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; } PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
/// \brief Low-level helper to match recurring patterns /// \brief Low-level helper to match recurring patterns
/// ///
......
...@@ -138,9 +138,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias() ...@@ -138,9 +138,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto mpattern = m.match_root(); //add auto mpattern = m.get_match_root(); //add
auto m_matmul = ngraph::pattern::Matcher::unique_match<op::MatmulBias>(mpattern); auto m_matmul = ngraph::pattern::Matcher::unique_match<op::MatmulBias>(mpattern);
auto m_broadcast = ngraph::pattern::Matcher::unique_match<op::Broadcast>(mpattern); auto m_broadcast = ngraph::pattern::Matcher::unique_match<op::Broadcast>(mpattern);
auto m_bias = m_broadcast->get_argument(0); auto m_bias = m_broadcast->get_argument(0);
...@@ -155,7 +155,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias() ...@@ -155,7 +155,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
m_matmul->get_is_arg1_transposed(), m_matmul->get_is_arg1_transposed(),
m_broadcast->get_broadcast_axes()); m_broadcast->get_broadcast_axes());
ngraph::replace_node(m.match_root(), mmb); ngraph::replace_node(m.get_match_root(), mmb);
return true; return true;
}; };
...@@ -184,11 +184,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul() ...@@ -184,11 +184,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto mpattern = m.match_root(); auto mpattern = m.get_match_root();
auto dot = m.match_root(); auto dot = m.get_match_root();
if (mpattern->get_element_type() != element::f32) if (mpattern->get_element_type() != element::f32)
{ {
...@@ -289,7 +289,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -289,7 +289,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
[variance_label, mean_label, input, eps_label, gamma_label, beta_label]( [variance_label, mean_label, input, eps_label, gamma_label, beta_label](
pattern::Matcher& m) { pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
//TODO - add assert's based on the matched node //TODO - add assert's based on the matched node
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -312,7 +312,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -312,7 +312,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing"; NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing";
return false; return false;
} }
Shape bn_output_shape{m.match_root()->get_shape()}; Shape bn_output_shape{m.get_match_root()->get_shape()};
Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()}; Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()}; Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};
...@@ -324,7 +324,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -324,7 +324,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto normalized_output = std::shared_ptr<Node>(new op::GetOutputElement(bn_node, 0)); auto normalized_output = std::shared_ptr<Node>(new op::GetOutputElement(bn_node, 0));
ngraph::replace_node(m.match_root(), normalized_output); ngraph::replace_node(m.get_match_root(), normalized_output);
return true; return true;
}; };
...@@ -433,7 +433,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -433,7 +433,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
input_order, input_order,
Shape(hoisted_reshape_output_shape.begin(), hoisted_reshape_output_shape.end())); Shape(hoisted_reshape_output_shape.begin(), hoisted_reshape_output_shape.end()));
if (!zero_padded_conv_consistency_check(m.match_root(), if (!zero_padded_conv_consistency_check(m.get_match_root(),
pad_value_op, pad_value_op,
pattern_map[pad_input], pattern_map[pad_input],
matched_pad, matched_pad,
...@@ -463,7 +463,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -463,7 +463,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
ngraph::replace_node(m.match_root(), zero_padded_conv); ngraph::replace_node(m.get_match_root(), zero_padded_conv);
return true; return true;
}; };
...@@ -499,7 +499,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -499,7 +499,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]); std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]);
const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]); const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]);
if (!zero_padded_conv_consistency_check(m.match_root(), if (!zero_padded_conv_consistency_check(m.get_match_root(),
pad_value_op, pad_value_op,
pattern_map[pad_input], pattern_map[pad_input],
matched_pad, matched_pad,
...@@ -527,7 +527,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -527,7 +527,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
ngraph::replace_node(m.match_root(), zero_padded_conv); ngraph::replace_node(m.get_match_root(), zero_padded_conv);
return true; return true;
}; };
...@@ -564,7 +564,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_ ...@@ -564,7 +564,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_
std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(pattern_map[conv_label]); std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(pattern_map[conv_label]);
const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]); const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]);
if (!zero_padded_conv_consistency_check(m.match_root(), if (!zero_padded_conv_consistency_check(m.get_match_root(),
pad_value_op, pad_value_op,
pattern_map[pad_input], pattern_map[pad_input],
matched_pad, matched_pad,
...@@ -594,7 +594,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_ ...@@ -594,7 +594,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_
padding_above, padding_above,
matched_conv->get_data_dilation_strides_forward()); matched_conv->get_data_dilation_strides_forward());
ngraph::replace_node(m.match_root(), zero_padded_conv_backprop_filters); ngraph::replace_node(m.get_match_root(), zero_padded_conv_backprop_filters);
return true; return true;
}; };
...@@ -618,24 +618,25 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid() ...@@ -618,24 +618,25 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
if (m.match_root()->get_element_type() != element::f32) if (m.get_match_root()->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false; return false;
} }
if (m.match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size()) if (m.get_match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!"; << "input= " << pattern_map[input]->get_name() << "size dont match!";
return false; return false;
} }
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]); auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.match_root(), sigmoid_node); ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true; return true;
}; };
...@@ -670,23 +671,24 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop() ...@@ -670,23 +671,24 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
if (m.match_root()->get_element_type() != element::f32) if (m.get_match_root()->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false; return false;
} }
if (m.match_root()->get_shape().size() != pattern_map[input]->get_shape().size()) if (m.get_match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
{ {
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!"; << "input= " << pattern_map[input]->get_name() << "size dont match!";
return false; return false;
} }
auto dsigmoid = auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]); std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
ngraph::replace_node(m.match_root(), dsigmoid); ngraph::replace_node(m.get_match_root(), dsigmoid);
return true; return true;
}; };
...@@ -714,35 +716,35 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() ...@@ -714,35 +716,35 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = " NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_argument(0)); auto conv = std::dynamic_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
if (conv->get_input_shape(0).size() == 4) if (conv->get_input_shape(0).size() == 4)
{ {
auto bias = m.match_root()->get_argument(1)->get_argument(0); auto bias = m.get_match_root()->get_argument(1)->get_argument(0);
auto bias_shape = bias->get_shape(); auto bias_shape = bias->get_shape();
if (bias_shape.size() > 1) if (bias_shape.size() > 1)
{ {
NGRAPH_DEBUG NGRAPH_DEBUG
<< "mpattern = " << m.match_root()->get_name() << "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias bias shape != 1, requires reshape to match filter count."; << "conv_bias bias shape != 1, requires reshape to match filter count.";
ngraph::AxisVector order(bias_shape.size()); ngraph::AxisVector order(bias_shape.size());
std::iota(begin(order), end(order), 0); std::iota(begin(order), end(order), 0);
auto bias_reshape = auto bias_reshape =
std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]}); std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape)); auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
ngraph::replace_node(m.match_root(), conv_bias); ngraph::replace_node(m.get_match_root(), conv_bias);
return true; return true;
} }
else else
{ {
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias)); auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
ngraph::replace_node(m.match_root(), conv_bias); ngraph::replace_node(m.get_match_root(), conv_bias);
return true; return true;
} }
} }
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias fusion skipped due to input rank size != 4."; << "conv_bias fusion skipped due to input rank size != 4.";
return false; return false;
}; };
...@@ -769,11 +771,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -769,11 +771,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
ngraph::pattern::graph_rewrite_callback callback = [input, gamma, beta](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = " NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>( auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node()); m.get_match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node());
//as of now, only MKLDNN supports this fusion //as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4 //and it requires input data's rank to be equal to 4
...@@ -797,7 +799,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -797,7 +799,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
return false; return false;
} }
mgoes[0] = m.match_root(); //replace relu instead of its GetOutputElement mgoes[0] = m.get_match_root(); //replace relu instead of its GetOutputElement
auto bn_relu = std::make_shared<op::BatchNormRelu>( auto bn_relu = std::make_shared<op::BatchNormRelu>(
m_bn->get_eps_value(), pattern_map[gamma], pattern_map[beta], pattern_map[input]); m_bn->get_eps_value(), pattern_map[gamma], pattern_map[beta], pattern_map[input]);
...@@ -839,11 +841,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta ...@@ -839,11 +841,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
ngraph::pattern::graph_rewrite_callback callback = ngraph::pattern::graph_rewrite_callback callback =
[input, mean, var, gamma, beta](pattern::Matcher& m) { [input, mean, var, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = " NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>( auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_inputs().at(0).get_output().get_node()); m.get_match_root()->get_inputs().at(0).get_output().get_node());
//as of now, only MKLDNN supports this fusion //as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4 //and it requires input data's rank to be equal to 4
...@@ -868,7 +870,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta ...@@ -868,7 +870,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
pattern_map[var], pattern_map[var],
m_bn->get_training_flag()); m_bn->get_training_flag());
ngraph::replace_node(m.match_root(), bn_relu); ngraph::replace_node(m.get_match_root(), bn_relu);
return true; return true;
}; };
...@@ -895,9 +897,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu() ...@@ -895,9 +897,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_relu against " NGRAPH_DEBUG << "In a callback for construct_conv_relu against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_argument(0)); auto conv = std::dynamic_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
//These checks are to make sure a MKLDNN Convolution kernel can be used. //These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false; bool data_dilated = false;
...@@ -934,7 +936,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu() ...@@ -934,7 +936,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
} }
auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionRelu(conv)); auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionRelu(conv));
ngraph::replace_node(m.match_root(), conv_relu); ngraph::replace_node(m.get_match_root(), conv_relu);
return true; return true;
}; };
......
...@@ -51,9 +51,10 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu ...@@ -51,9 +51,10 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1}); data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_weight against " << m.match_root()->get_name(); NGRAPH_DEBUG << "In a callback for construct_weight against "
<< m.get_match_root()->get_name();
auto m_cvt_lt = m.match_root()->get_argument(1); auto m_cvt_lt = m.get_match_root()->get_argument(1);
auto m_reshape_conv = m_cvt_lt->get_argument(0); auto m_reshape_conv = m_cvt_lt->get_argument(0);
std::shared_ptr<Node> m_conv_bprop; std::shared_ptr<Node> m_conv_bprop;
......
...@@ -152,15 +152,16 @@ public: ...@@ -152,15 +152,16 @@ public:
ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against " NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
assert(m.match_root()->get_arguments().size() == 2); assert(m.get_match_root()->get_arguments().size() == 2);
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern]; size_t const_node_index =
m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::Constant>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_arguments().at(const_node_index)); m.get_match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index); auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
...@@ -181,7 +182,7 @@ public: ...@@ -181,7 +182,7 @@ public:
return false; return false;
} }
ngraph::replace_node(m.match_root(), pattern_map[pattern]); ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true; return true;
}; };
...@@ -197,15 +198,16 @@ public: ...@@ -197,15 +198,16 @@ public:
auto callback = [pattern](pattern::Matcher& m) { auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_add_zero against " NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
assert(m.match_root()->get_arguments().size() == 2); assert(m.get_match_root()->get_arguments().size() == 2);
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern]; size_t const_node_index =
m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::Constant>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_arguments().at(const_node_index)); m.get_match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index); auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name() NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name(); << " , pattern = " << pattern_map[pattern]->get_name();
...@@ -226,7 +228,7 @@ public: ...@@ -226,7 +228,7 @@ public:
return false; return false;
} }
ngraph::replace_node(m.match_root(), pattern_map[pattern]); ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true; return true;
}; };
...@@ -240,14 +242,14 @@ public: ...@@ -240,14 +242,14 @@ public:
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against " NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name(); << m.get_match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root()); auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.get_match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node(); auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name(); NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = auto sum =
std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes())); std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
ngraph::replace_node(m.match_root(), sum); ngraph::replace_node(m.get_match_root(), sum);
return true; return true;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment