Commit 6716068c authored by gaurides's avatar gaurides Committed by Robert Kimball

Fix accuracy issue (#2652)

* Fix accuracy issue

* Style fix
parent 29dd3e3f
...@@ -1363,54 +1363,68 @@ namespace ngraph ...@@ -1363,54 +1363,68 @@ namespace ngraph
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Quantize) void CPULayout::LAYOUT_DECL(ngraph::op::Quantize)
{ {
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0); if (mkldnn_utils::use_mkldnn_kernel(node.get()))
auto tv = node->get_output_tensor_ptr(0);
auto fmt = static_cast<mkldnn::memory::format>(input_md.data.format);
if (fmt == mkldnn_blocked || fmt == mkldnn_format_undef ||
!mkldnn_utils::can_create_mkldnn_md(tv->get_element_type()))
{ {
// Cannot pass through layout information for blocked layouts at the moment auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
set_native_layouts(external_function, node); auto tv = node->get_output_tensor_ptr(0);
auto fmt = static_cast<mkldnn::memory::format>(input_md.data.format);
if (fmt == mkldnn_blocked || fmt == mkldnn_format_undef ||
!mkldnn_utils::can_create_mkldnn_md(tv->get_element_type()))
{
// Cannot pass through layout information for blocked layouts at the moment
set_native_layouts(external_function, node);
}
else
{
// mkldnn expects nhwc for int8, avoids reorder
if (fmt == mkldnn::memory::format::nchw ||
fmt == mkldnn::memory::format::nChw8c ||
fmt == mkldnn::memory::format::nChw16c)
{
fmt = mkldnn::memory::format::nhwc;
}
vector<memory::desc> o_mds;
o_mds.push_back(mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, static_cast<memory::format>(fmt)));
set_output_layouts(node, o_mds);
}
} }
else else
{ {
// mkldnn expects nhwc for int8, avoids reorder set_native_layouts(external_function, node);
if (fmt == mkldnn::memory::format::nchw ||
fmt == mkldnn::memory::format::nChw8c ||
fmt == mkldnn::memory::format::nChw16c)
{
fmt = mkldnn::memory::format::nhwc;
}
vector<memory::desc> o_mds;
o_mds.push_back(mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, static_cast<memory::format>(fmt)));
set_output_layouts(node, o_mds);
} }
} }
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Dequantize) void CPULayout::LAYOUT_DECL(ngraph::op::Dequantize)
{ {
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0); if (mkldnn_utils::use_mkldnn_kernel(node.get()))
auto tv = node->get_output_tensor_ptr(0);
auto fmt = static_cast<mkldnn::memory::format>(input_md.data.format);
if (fmt == mkldnn_blocked || fmt == mkldnn_format_undef ||
!mkldnn_utils::can_create_mkldnn_md(tv->get_element_type()))
{ {
// Cannot pass through layout information for blocked layouts at the moment auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
set_native_layouts(external_function, node); auto tv = node->get_output_tensor_ptr(0);
auto fmt = static_cast<mkldnn::memory::format>(input_md.data.format);
if (fmt == mkldnn_blocked || fmt == mkldnn_format_undef ||
!mkldnn_utils::can_create_mkldnn_md(tv->get_element_type()))
{
// Cannot pass through layout information for blocked layouts at the moment
set_native_layouts(external_function, node);
}
else
{
// reorder as default nchw layout
if (fmt == mkldnn::memory::format::nhwc)
{
fmt = mkldnn::memory::format::nchw;
}
vector<memory::desc> o_mds;
o_mds.push_back(mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, static_cast<memory::format>(fmt)));
set_output_layouts(node, o_mds);
}
} }
else else
{ {
// reorder as default nchw layout set_native_layouts(external_function, node);
if (fmt == mkldnn::memory::format::nhwc)
{
fmt = mkldnn::memory::format::nchw;
}
vector<memory::desc> o_mds;
o_mds.push_back(mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, static_cast<memory::format>(fmt)));
set_output_layouts(node, o_mds);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment