Commit 83206a0a authored by Nishant Patel's avatar Nishant Patel Committed by Adam Procter

Add checks for conv + bias bprop fusion (#1063)

parent ba2cbdd6
...@@ -775,30 +775,35 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop() ...@@ -775,30 +775,35 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
auto conv_bprop = auto conv_bprop =
std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(m.get_match_root()); std::dynamic_pointer_cast<op::ConvolutionBackpropFilters>(m.get_match_root());
for (auto delta_user : pattern_map[delta]->get_users()) if (conv_bprop->get_input_shape(0).size() == 4 &&
conv_bprop->get_input_shape(1).size() == 4 &&
conv_bprop->get_input_element_type(0) == element::f32)
{ {
if (std::dynamic_pointer_cast<op::Sum>(delta_user)) for (auto delta_user : pattern_map[delta]->get_users())
{ {
auto bias_shape = delta_user->get_output_shape(0); if (std::dynamic_pointer_cast<op::Sum>(delta_user))
auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>( {
pattern_map[data_batch], auto bias_shape = delta_user->get_output_shape(0);
conv_bprop->get_filters_shape(), auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>(
bias_shape, pattern_map[data_batch],
pattern_map[delta], conv_bprop->get_filters_shape(),
conv_bprop->get_window_movement_strides_forward(), bias_shape,
conv_bprop->get_window_dilation_strides_forward(), pattern_map[delta],
conv_bprop->get_padding_below_forward(), conv_bprop->get_window_movement_strides_forward(),
conv_bprop->get_padding_above_forward(), conv_bprop->get_window_dilation_strides_forward(),
conv_bprop->get_data_dilation_strides_forward()); conv_bprop->get_padding_below_forward(),
auto goe1 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 0); conv_bprop->get_padding_above_forward(),
auto goe2 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 1); conv_bprop->get_data_dilation_strides_forward());
NGRAPH_DEBUG << "Replacing " << m.get_match_root()->get_name() auto goe1 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 0);
<< "with ConvolutionBiasBackpropFiltersBias"; auto goe2 = std::make_shared<op::GetOutputElement>(conv_bias_bprop, 1);
ngraph::replace_node(m.get_match_root(), goe1); NGRAPH_DEBUG << "Replacing " << m.get_match_root()->get_name()
NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of " << "with ConvolutionBiasBackpropFiltersBias";
"ConvolutionBiasBackpropFiltersBias"; ngraph::replace_node(m.get_match_root(), goe1);
ngraph::replace_node(delta_user, goe2); NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of "
return true; "ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(delta_user, goe2);
return true;
}
} }
} }
return false; return false;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment