Commit 1574031c authored by Nishant Patel's avatar Nishant Patel Committed by Robert Kimball

Reshape bias to 1D for cpufusion of conv+bias bprop (#1151)

* Reshape bias to 1D for conv + bias bprop fusion

* Reshape goe2 back to 2D before replacing
parent cf3e2992
...@@ -782,7 +782,21 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop() ...@@ -782,7 +782,21 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
{ {
if (std::dynamic_pointer_cast<op::Sum>(delta_user)) if (std::dynamic_pointer_cast<op::Sum>(delta_user))
{ {
auto bias_shape = delta_user->get_output_shape(0); auto bias = std::dynamic_pointer_cast<op::Sum>(delta_user);
auto bias_shape = bias->get_shape();
bool flag = false;
if (bias_shape.size() > 1)
{
NGRAPH_DEBUG
<< "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias bias shape != 1, requires reshape to match filter count.";
ngraph::AxisVector order(bias_shape.size());
std::iota(begin(order), end(order), 0);
auto bias_reshape = std::make_shared<op::Reshape>(
bias, order, Shape{conv_bprop->get_filters_shape()[0]});
bias_shape = bias_reshape->get_shape();
flag = true;
}
auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>( auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>(
pattern_map[data_batch], pattern_map[data_batch],
conv_bprop->get_filters_shape(), conv_bprop->get_filters_shape(),
...@@ -800,7 +814,16 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop() ...@@ -800,7 +814,16 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
ngraph::replace_node(m.get_match_root(), goe1); ngraph::replace_node(m.get_match_root(), goe1);
NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of " NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of "
"ConvolutionBiasBackpropFiltersBias"; "ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(delta_user, goe2); if (flag)
{
auto goe2_reshape = std::make_shared<op::Reshape>(
goe2, AxisVector{0}, delta_user->get_shape());
ngraph::replace_node(delta_user, goe2_reshape);
}
else
{
ngraph::replace_node(delta_user, goe2);
}
return true; return true;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment