Commit 1bc38b12 authored by Jayaram Bobba's avatar Jayaram Bobba

Optimize batchnorm backprop by propagating input layouts instead of delta layouts

parent d3ea93e2
......@@ -1037,19 +1037,19 @@ namespace ngraph
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto delta_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(delta_layout); // input
prim_input_formats.push_back(input_layout); // input
prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(delta_layout); // delta
prim_output_formats.push_back(delta_layout); // dinput
prim_input_formats.push_back(input_layout); // delta
prim_output_formats.push_back(input_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta
node =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment