Commit 3762538c authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Use convolution_direct algo (#3028)

* Failing conv test case

* Opt for mkldnn::algorithm::convolution_direct if input channel is less than 8

* Comment
parent d96482a5
......@@ -937,8 +937,9 @@ namespace ngraph
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
if ((node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct) ||
convolution->get_argument(0)->get_shape()[1] <= 8)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
......
......@@ -395,8 +395,13 @@ namespace ngraph
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
auto convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
// I/p channels less than 8 & convolution_algo = convolution_auto
// forces src format to be nChw16c & the weight format to be
// OIhw16i16o which invokes mkldnn reference implementation of conv
// which crashes as it has no support for post ops
if ((node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct) ||
arg0_shape[1] <= 8)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment