Commit 83206a0a authored by Nishant Patel's avatar Nishant Patel Committed by Adam Procter

Add checks for conv + bias bprop fusion (#1063)

parent ba2cbdd6
......@@ -775,6 +775,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
auto conv_bprop =
std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(m.get_match_root());
if (conv_bprop->get_input_shape(0).size() == 4 &&
conv_bprop->get_input_shape(1).size() == 4 &&
conv_bprop->get_input_element_type(0) == element::f32)
{
for (auto delta_user : pattern_map[delta]->get_users())
{
if (std::dynamic_pointer_cast<op::Sum>(delta_user))
......@@ -801,6 +805,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
return true;
}
}
}
return false;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment