Unverified Commit 88b058b7 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/mkldnn (#2760)

* Revert "add logic in replace_node for provenance propagation (#2703)" (#2731)

This reverts commit 6c8284a3.

* Migrate doc changes to r0.18 (#2738)

* Migrate doc changes

* Add TensorFlow version change

* Use only convolution direct algo for non float convs (#2752)
parent 5473a18a
...@@ -368,10 +368,9 @@ size_t ...@@ -368,10 +368,9 @@ size_t
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */ /* Specify the scales array and corresponding mask */
conv_attr.set_output_scales(0, output_scale); conv_attr.set_output_scales(0, output_scale);
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward, {{mkldnn::prop_kind::forward,
convolution_algo, mkldnn::algorithm::convolution_direct,
input_data_desc, input_data_desc,
weights_desc, weights_desc,
result_desc, result_desc,
...@@ -417,10 +416,9 @@ size_t ...@@ -417,10 +416,9 @@ size_t
conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); conv_attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
/* Specify the scales array and corresponding mask */ /* Specify the scales array and corresponding mask */
conv_attr.set_output_scales(0, output_scale); conv_attr.set_output_scales(0, output_scale);
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
size_t conv_index = insert_primitive(new mkldnn::convolution_forward( size_t conv_index = insert_primitive(new mkldnn::convolution_forward(
{{mkldnn::prop_kind::forward, {{mkldnn::prop_kind::forward,
convolution_algo, mkldnn::algorithm::convolution_direct,
input_data_desc, input_data_desc,
weights_desc, weights_desc,
bias_desc, bias_desc,
......
...@@ -1255,6 +1255,13 @@ namespace ngraph ...@@ -1255,6 +1255,13 @@ namespace ngraph
Strides window_dilation_strides_adjusted; Strides window_dilation_strides_adjusted;
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo(); mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
for (size_t s : convolution->get_window_dilation_strides()) for (size_t s : convolution->get_window_dilation_strides())
{ {
window_dilation_strides_adjusted.push_back(s - 1); window_dilation_strides_adjusted.push_back(s - 1);
......
...@@ -393,6 +393,13 @@ namespace ngraph ...@@ -393,6 +393,13 @@ namespace ngraph
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr}; std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
auto convolution_algo = mkldnn_utils::get_conv_algo(); auto convolution_algo = mkldnn_utils::get_conv_algo();
if (node->get_input_element_type(0) != element::f32 &&
convolution_algo != mkldnn::algorithm::convolution_direct)
{
convolution_algo = mkldnn::algorithm::convolution_direct;
}
if (use_bias) if (use_bias)
{ {
memory::data_type et_bias = memory::data_type et_bias =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment