Commit 2c2de707 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Hack to help MKLDNN avoid ref convolution (#669)

parent 4646bcae
......@@ -2197,6 +2197,11 @@ namespace ngraph
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto weights_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
// HACK to help MKLDNN pick the right implementation
if (weights_format == mkldnn::memory::format::nchw)
{
weights_format = mkldnn::memory::format::oihw;
}
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
......@@ -2341,8 +2346,15 @@ namespace ngraph
}
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
// HACK to help MKLDNN pick the right implementation
auto weights_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
if (weights_format == mkldnn::memory::format::nchw)
{
weights_format = mkldnn::memory::format::oihw;
}
auto weights_desc =
mkldnn_emitter->build_memory_descriptor(args[0], weights_format);
auto delta_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment