Commit 49a8f3b2 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Use only convolution direct algo for non float convs (#2752)

parent c5d52f14
...@@ -332,10 +332,9 @@ size_t ...@@ -332,10 +332,9 @@ size_t
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */ /* Specify the scales array and corresponding mask */
conv_attr.set_output_scales(0, output_scale); conv_attr.set_output_scales(0, output_scale);
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward, {{mkldnn::prop_kind::forward,
convolution_algo, mkldnn::algorithm::convolution_direct,
input_data_desc, input_data_desc,
weights_desc, weights_desc,
result_desc, result_desc,
...@@ -377,10 +376,9 @@ size_t ...@@ -377,10 +376,9 @@ size_t
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */ /* Specify the scales array and corresponding mask */
conv_attr.set_output_scales(0, output_scale); conv_attr.set_output_scales(0, output_scale);
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward, {{mkldnn::prop_kind::forward,
convolution_algo, mkldnn::algorithm::convolution_direct,
input_data_desc, input_data_desc,
weights_desc, weights_desc,
bias_desc, bias_desc,
......
...@@ -1250,6 +1250,13 @@ namespace ngraph ...@@ -1250,6 +1250,13 @@ namespace ngraph
Strides window_dilation_strides_adjusted; Strides window_dilation_strides_adjusted;
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo(); mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
for (size_t s : convolution->get_window_dilation_strides()) for (size_t s : convolution->get_window_dilation_strides())
{ {
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
......
...@@ -393,6 +393,13 @@ namespace ngraph ...@@ -393,6 +393,13 @@ namespace ngraph
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr}; std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
auto convolution_algo = mkldnn_utils::get_conv_algo(); auto convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
if (use_bias) if (use_bias)
{ {
memory::data_type et_bias = memory::data_type et_bias =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment