Commit 8a5b29a1 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Changes to layout selection heuristic for cases like XLA frontend where fprop…

Changes to layout selection heuristic for cases like XLA frontend where fprop layouts are not propagated to bprop graph (#769)
parent 27029eae
......@@ -221,3 +221,12 @@ bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format
}
return false;
}
bool runtime::cpu::mkldnn_utils::is_mkldnn_blocked_data_format(mkldnn::memory::format fmt)
{
if (fmt == memory::format::nChw8c || fmt == memory::format::nChw16c)
{
return true;
}
return false;
}
......@@ -45,6 +45,7 @@ namespace ngraph
bool compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
}
}
}
......
......@@ -988,14 +988,21 @@ namespace ngraph
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
auto kernel_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
if (!runtime::cpu::mkldnn_utils::is_mkldnn_blocked_data_format(
kernel_layout))
{
// Propagate delta layout
kernel_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 1);
}
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(input_layout);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_input_formats.push_back(kernel_layout);
prim_input_formats.push_back(kernel_layout);
prim_output_formats.push_back(kernel_layout);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
......@@ -1051,19 +1058,27 @@ namespace ngraph
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
auto kernel_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
if (!runtime::cpu::mkldnn_utils::is_mkldnn_blocked_data_format(
kernel_layout))
{
// Propagate delta layout
kernel_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 5);
}
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x); // gamma
prim_input_formats.push_back(memory::format::x); // beta
prim_input_formats.push_back(input_layout); // input
prim_input_formats.push_back(kernel_layout); // input
prim_input_formats.push_back(memory::format::x); // mean
prim_input_formats.push_back(memory::format::x); // variance
prim_input_formats.push_back(input_layout); // delta
prim_output_formats.push_back(input_layout); // dinput
prim_input_formats.push_back(kernel_layout); // delta
prim_output_formats.push_back(kernel_layout); // dinput
prim_output_formats.push_back(memory::format::x); // dgamma
prim_output_formats.push_back(memory::format::x); // dbeta
node =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment