Commit 1bc38b12 authored by Jayaram Bobba

Optimize batchnorm backprop by propagating input layouts instead of delta layouts

parent d3ea93e2
...@@ -1037,19 +1037,19 @@ namespace ngraph ...@@ -1037,19 +1037,19 @@ namespace ngraph
{ {
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get())) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{ {
auto delta_layout = auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5); runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats; vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats; vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(delta_layout); // input prim_input_formats.push_back(input_layout); // input
prim_input_formats.push_back(memory::format::x); // mean prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(delta_layout); // delta prim_input_formats.push_back(input_layout); // delta
prim_output_formats.push_back(delta_layout); // dinput prim_output_formats.push_back(input_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta prim_output_formats.push_back(memory::format::x); // dbeta
node = node =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment