Commit bbe8da24 authored by Jayaram Bobba

Added optimal layouts for MKLDNN relu fprop and bprop

parent 6ef2d5a0
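Before diving into the diff: as the hunks below show, the relu emitter previously pinned every MKLDNN memory descriptor to plain `memory::format::nchw`, and this commit makes it use whatever format the layout pass assigned to each tensor. A minimal sketch of why that matters, using the MKLDNN 0.x C++ API; the dimensions and the blocked `nChw8c` format are illustrative, not taken from the commit:

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

int main()
{
    engine cpu_engine(engine::cpu, 0);
    memory::dims shape{2, 16, 8, 8};
    // What the emitter used to hardcode for relu inputs and outputs:
    memory::desc plain(shape, memory::data_type::f32, memory::format::nchw);
    // What a preceding convolution may actually produce under MKLDNN:
    memory::desc blocked(shape, memory::data_type::f32, memory::format::nChw8c);
    // With hardcoded nchw, a reorder is needed before and after relu.
    // With format propagation, relu consumes and produces the blocked
    // layout directly and both reorders disappear.
    return 0;
}
```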
@@ -2540,9 +2540,8 @@ namespace ngraph
            << ");\n";
     writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
            << "}, " << et << ", "
-           << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
+           << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
            << ");\n";
     writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
            << args[0].get_name() << ");\n";
     writer << "memory result = memory({result_desc, cpu_engine}, "
@@ -3059,12 +3058,27 @@ namespace ngraph
     void CPU_Emitter::EMITTER_DECL(ngraph::op::ReluBackprop)
     {
         const auto& arg_shape = args[0].get_shape();
-        const size_t arg_rank = arg_shape.size();
         const auto& result_shape = out[0].get_shape();
-        const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
-            args[0].get_element_type());
-        if (arg_rank == 4 && args[0].get_element_type() == element::f32)
+        if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
         {
+            const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
+                args[0].get_element_type());
+            auto input_format =
+                runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
+            auto delta_format =
+                runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
+            if (!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(input_format,
+                                                                    delta_format))
+            {
+                throw ngraph_error(
+                    "mkldnn emitter: ReluBackprop input and delta layouts should be "
+                    "the same");
+            }
+            auto result_format =
+                runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
             writer << "{\n";
             writer.indent++;
@@ -3072,12 +3086,17 @@ namespace ngraph
             writer.indent++;
             writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
             writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
-                   << "}, " << et << ", memory::format::nchw);\n";
+                   << "}, " << et << ", "
+                   << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
+                   << ");\n";
             writer << "memory::desc delta_data_desc = memory::desc({"
-                   << join(args[1].get_shape()) << "}, " << et
-                   << ", memory::format::nchw);\n";
+                   << join(args[1].get_shape()) << "}, " << et << ", "
+                   << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format)
+                   << ");\n";
             writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
-                   << "}, " << et << ", memory::format::nchw);\n";
+                   << "}, " << et << ", "
+                   << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
+                   << ");\n";
             writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
                    << args[0].get_name() << ");\n";
@@ -3125,12 +3144,18 @@ namespace ngraph
     void CPU_Emitter::EMITTER_DECL(ngraph::op::Relu)
     {
         const auto& arg_shape = args[0].get_shape();
-        const size_t arg_rank = arg_shape.size();
         const auto& result_shape = out[0].get_shape();
-        const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
-            args[0].get_element_type());
-        if (arg_rank == 4 && args[0].get_element_type() == element::f32)
+        if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
         {
+            const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
+                args[0].get_element_type());
+            auto input_format =
+                runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
+            auto result_format =
+                runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
             writer << "{\n";
             writer.indent++;
@@ -3138,9 +3163,13 @@ namespace ngraph
             writer.indent++;
             writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
             writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
-                   << "}, " << et << ", memory::format::nchw);\n";
+                   << "}, " << et << ", "
+                   << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
+                   << ");\n";
             writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
-                   << "}, " << et << ", memory::format::nchw);\n";
+                   << "}, " << et << ", "
+                   << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
+                   << ");\n";
             writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
                    << args[0].get_name() << ");\n";
......
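The emitter hunks above stop before the primitive creation (the diff is truncated at this point). For orientation, a hedged reconstruction of roughly what the emitted fprop block expands to at runtime; the `relu_forward` and `stream` lines follow the MKLDNN 0.x API and are an assumption, not shown in this commit, and the shape, format, and pointer names are placeholders:

```cpp
// Inside the generated function (arg0_data / out0_data are the tensor pointers):
engine cpu_engine = engine(engine::cpu, 0);
memory::desc input_data_desc =
    memory::desc({2, 16, 8, 8}, memory::data_type::f32, memory::format::nChw8c);
memory::desc result_desc =
    memory::desc({2, 16, 8, 8}, memory::data_type::f32, memory::format::nChw8c);
memory input_data = memory({input_data_desc, cpu_engine}, arg0_data);
memory result = memory({result_desc, cpu_engine}, out0_data);
// Assumed MKLDNN 0.x relu primitive setup (not part of the visible diff):
relu_forward::desc relu_fwd_desc =
    relu_forward::desc(prop_kind::forward_training, input_data_desc, 0.0f);
relu_forward::primitive_desc relu_fwd_pd =
    relu_forward::primitive_desc(relu_fwd_desc, cpu_engine);
relu_forward relu_fwd = relu_forward(relu_fwd_pd, input_data, result);
stream(stream::kind::eager).submit({relu_fwd}).wait();
```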
@@ -27,6 +27,7 @@
 #include "ngraph/descriptor/output.hpp"
 #include "ngraph/ops/avg_pool.hpp"
 #include "ngraph/ops/convolution.hpp"
+#include "ngraph/ops/relu.hpp"
 #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
@@ -159,6 +160,42 @@ namespace ngraph
                 avg_pool->set_op_annotations(op_annotations);
             }
         }
+        template <>
+        void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
+        {
+            auto relu = static_cast<op::Relu*>(node);
+            auto arg0_shape = node->get_input_shape(0);
+            auto arg0_rank = arg0_shape.size();
+            if (arg0_rank == 4 && node->get_input_element_type(0) == element::f32)
+            {
+                auto op_annotations =
+                    std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                op_annotations->set_mkldnn_op(true);
+                relu->set_op_annotations(op_annotations);
+            }
+        }
+        template <>
+        void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
+        {
+            auto relu_bprop = static_cast<op::ReluBackprop*>(node);
+            auto arg0_shape = node->get_input_shape(0);
+            auto arg0_rank = arg0_shape.size();
+            if (arg0_rank == 4 && node->get_input_element_type(0) == element::f32)
+            {
+                auto op_annotations =
+                    std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                op_annotations->set_mkldnn_op(true);
+                relu_bprop->set_op_annotations(op_annotations);
+            }
+        }
         }
     }
 }
@@ -176,6 +213,9 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
     {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
     {TI(ngraph::op::AvgPoolBackprop),
      &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
+    {TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
+    {TI(ngraph::op::ReluBackprop),
+     &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
 };
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
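The assignment pass above only tags a node with `CPUOpAnnotations`; the emitter and the layout pass (next file) then branch on `runtime::cpu::mkldnn_utils::use_mkldnn_kernel`. A hedged sketch of what that helper plausibly does; the real implementation lives in `mkldnn_utils` and may differ in detail:

```cpp
// Sketch: read back the annotation set by the CPUAssignment pass.
bool use_mkldnn_kernel(const ngraph::Node* node)
{
    auto op = dynamic_cast<const ngraph::op::Op*>(node);
    if (!op)
    {
        return false;
    }
    auto annotations = std::dynamic_pointer_cast<ngraph::runtime::cpu::CPUOpAnnotations>(
        op->get_op_annotations());
    return annotations && annotations->is_mkldnn_op();
}
```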
@@ -28,6 +28,7 @@
 #include "ngraph/ops/avg_pool.hpp"
 #include "ngraph/ops/convolution.hpp"
 #include "ngraph/ops/op.hpp"
+#include "ngraph/ops/relu.hpp"
 #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
 #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
@@ -625,6 +626,46 @@ namespace ngraph
                 set_default_layouts(external_function, node);
             }
         }
+        template <>
+        void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
+        {
+            if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
+            {
+                auto input_layout =
+                    runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
+                vector<memory::format> prim_output_formats;
+                prim_output_formats.push_back(input_layout);
+                set_output_layouts(node, prim_output_formats);
+            }
+            else
+            {
+                set_default_layouts(external_function, node);
+            }
+        }
+        template <>
+        void CPULayout::LAYOUT_DECL(ngraph::op::ReluBackprop)
+        {
+            if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
+            {
+                auto input_layout =
+                    runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
+                vector<memory::format> prim_input_formats;
+                vector<memory::format> prim_output_formats;
+                prim_input_formats.push_back(input_layout);
+                prim_input_formats.push_back(input_layout);
+                prim_output_formats.push_back(input_layout);
+                node =
+                    insert_input_conversions(external_function, node, prim_input_formats);
+                set_output_layouts(node, prim_output_formats);
+            }
+            else
+            {
+                set_default_layouts(external_function, node);
+            }
+        }
         }
     }
 }
@@ -641,6 +682,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
     {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
     {TI(ngraph::op::AvgPoolBackprop),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
+    {TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
+    {TI(ngraph::op::ReluBackprop),
+     &runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>},
 };
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
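One note on the ReluBackprop layout handler: it pushes `input_layout` for both inputs, so if the delta tensor's producer chose a different format, `insert_input_conversions` adds a conversion node, which is why the emitter's layout-mismatch check above can simply throw. At the MKLDNN level such a conversion is a reorder primitive; a hedged illustration, with dims, formats, and the `delta_ptr`/`converted_ptr` pointers as placeholders:

```cpp
// Convert a delta tensor from plain nchw to the fprop input's blocked layout.
engine cpu_engine = engine(engine::cpu, 0);
memory::desc delta_nchw({2, 16, 8, 8}, memory::data_type::f32, memory::format::nchw);
memory::desc delta_blocked({2, 16, 8, 8}, memory::data_type::f32, memory::format::nChw8c);
memory src = memory({delta_nchw, cpu_engine}, delta_ptr);
memory dst = memory({delta_blocked, cpu_engine}, converted_ptr);
// reorder(src, dst) is the MKLDNN 0.x layout-conversion primitive.
reorder delta_reorder = reorder(src, dst);
stream(stream::kind::eager).submit({delta_reorder}).wait();
```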