Commit 1574031c authored by Nishant Patel's avatar Nishant Patel Committed by Robert Kimball

Reshape bias to 1D for cpufusion of conv+bias bprop (#1151)

* Reshape bias to 1D for conv + bias bprop fusion

* Reshape goe2 back to 2D before replacing
parent cf3e2992
......@@ -782,7 +782,21 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
{
if (std::dynamic_pointer_cast<op::Sum>(delta_user))
{
auto bias_shape = delta_user->get_output_shape(0);
auto bias = std::dynamic_pointer_cast<op::Sum>(delta_user);
auto bias_shape = bias->get_shape();
bool flag = false;
if (bias_shape.size() > 1)
{
NGRAPH_DEBUG
<< "mpattern = " << m.get_match_root()->get_name()
<< "conv_bias bias shape != 1, requires reshape to match filter count.";
ngraph::AxisVector order(bias_shape.size());
std::iota(begin(order), end(order), 0);
auto bias_reshape = std::make_shared<op::Reshape>(
bias, order, Shape{conv_bprop->get_filters_shape()[0]});
bias_shape = bias_reshape->get_shape();
flag = true;
}
auto conv_bias_bprop = std::make_shared<op::ConvolutionBiasBackpropFiltersBias>(
pattern_map[data_batch],
conv_bprop->get_filters_shape(),
......@@ -800,7 +814,16 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
ngraph::replace_node(m.get_match_root(), goe1);
NGRAPH_DEBUG << "Replacing bias and adding it as a second o/p of "
"ConvolutionBiasBackpropFiltersBias";
ngraph::replace_node(delta_user, goe2);
if (flag)
{
auto goe2_reshape = std::make_shared<op::Reshape>(
goe2, AxisVector{0}, delta_user->get_shape());
ngraph::replace_node(delta_user, goe2_reshape);
}
else
{
ngraph::replace_node(delta_user, goe2);
}
return true;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment