Commit 3762538c authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Use convolution_direct algo (#3028)

* Failing conv test case

* Opt for mkldnn::algorithm::convolution_direct if input channel is less than 8

* Comment
parent d96482a5
...@@ -937,8 +937,9 @@ namespace ngraph ...@@ -937,8 +937,9 @@ namespace ngraph
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo(); mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 && if ((node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct) convolution_algo != mkldnn::algorithm::convolution_direct) ||
convolution->get_argument(0)->get_shape()[1] <= 8)
{ {
convolution_algo = mkldnn::algorithm::convolution_direct; convolution_algo = mkldnn::algorithm::convolution_direct;
} }
......
...@@ -395,8 +395,13 @@ namespace ngraph ...@@ -395,8 +395,13 @@ namespace ngraph
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr}; std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
auto convolution_algo = mkldnn_utils::get_conv_algo(); auto convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 && // I/p channels less than 8 & convolution_algo = convolution_auto
convolution_algo != mkldnn::algorithm::convolution_direct) // forces src format to be nChw16c & the weight format to be
// OIhw16i16o which invokes mkldnn reference implementation of conv
// which crashes as it has no support for post ops
if ((node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct) ||
arg0_shape[1] <= 8)
{ {
convolution_algo = mkldnn::algorithm::convolution_direct; convolution_algo = mkldnn::algorithm::convolution_direct;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment