Commit 19899f4d authored by Louis Feng's avatar Louis Feng Committed by Matthew Brookhart

fixed conv+bias pattern match causing mxnet tests to fail. (#647)

parent 5ec9e25f
......@@ -644,9 +644,31 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
std::shared_ptr<Node> nn;
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias;
if (conv->get_input_shape(0).size() == 4)
{
auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
auto bias_shape = bias->get_shape();
if (bias_shape.size() > 1)
{
NGRAPH_DEBUG
<< "mpattern = " << m.match_root()->get_name()
<< "conv_bias bias shape != 1, requires reshape to match filter count.";
ngraph::AxisVector order(bias_shape.size());
std::iota(begin(order), end(order), 0);
auto bias_reshape =
std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
return conv_bias;
}
else
{
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias;
}
}
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "conv_bias fusion skipped due to input rank size != 4.";
return std::shared_ptr<Node>(nullptr);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment