Commit 83206a0a authored by Nishant Patel's avatar Nishant Patel Committed by Adam Procter

Add checks for conv + bias bprop fusion (#1063)

parent ba2cbdd6
......@@ -775,30 +775,35 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
auto conv_bprop =
std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(m.get_match_root());
for (auto delta_user : pattern_map[delta]->get_users())
if (conv_bprop->get_input_shape(0).size() == 4 &&
conv_bprop->get_input_shape(1).size() == 4 &&
conv_bprop->get_input_element_type(0) == element::f32)
{
if (std::dynamic_pointer_cast<op::Sum>(delta_user))
for (auto delta_user : pattern_map[delta]->get_users())
{
auto bias_shape = delta_user->get_output_shape(0);
auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>(
pattern_map[data_batch],
conv_bprop->get_filters_shape(),
bias_shape,
pattern_map[delta],
conv_bprop->get_window_movement_strides_forward(),
conv_bprop->get_window_dilation_strides_forward(),
conv_bprop->get_padding_below_forward(),
conv_bprop->get_padding_above_forward(),
conv_bprop->get_data_dilation_strides_forward());
auto goe1 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 0);
auto goe2 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 1);
NGRAPH_DEBUG << "Replacing " << m.get_match_root()->get_name()
<< "with ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(m.get_match_root(), goe1);
NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of "
"ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(delta_user, goe2);
return true;
if (std::dynamic_pointer_cast<op::Sum>(delta_user))
{
auto bias_shape = delta_user->get_output_shape(0);
auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>(
pattern_map[data_batch],
conv_bprop->get_filters_shape(),
bias_shape,
pattern_map[delta],
conv_bprop->get_window_movement_strides_forward(),
conv_bprop->get_window_dilation_strides_forward(),
conv_bprop->get_padding_below_forward(),
conv_bprop->get_padding_above_forward(),
conv_bprop->get_data_dilation_strides_forward());
auto goe1 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 0);
auto goe2 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 1);
NGRAPH_DEBUG << "Replacing " << m.get_match_root()->get_name()
<< "with ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(m.get_match_root(), goe1);
NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of "
"ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(delta_user, goe2);
return true;
}
}
}
return false;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment