Commit 6ef2d5a0 authored by Jayaram Bobba

Added MKLDNN optimal layouts to avg_pool fprop and bprop

parent 2522ae5e
......@@ -2507,8 +2507,6 @@ namespace ngraph
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
// TODO(jmenon): Refactor into an MKLDNN Pooling emitter that handles
......@@ -2517,8 +2515,7 @@ namespace ngraph
// TODO(jmenon): Optimize for 1D
// TODO(jmenon): Remove element type restriction
if (arg_rank == 4 && avg_pool->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32)
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
......@@ -2528,14 +2525,23 @@ namespace ngraph
? "algorithm::pooling_avg_include_padding"
: "algorithm::pooling_avg_exclude_padding";
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
writer << "{\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
<< "}, " << et << ", memory::format::nchw);\n";
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", memory::format::nchw);\n";
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
......@@ -2603,23 +2609,30 @@ namespace ngraph
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto delta_shape = args[0].get_shape();
auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape();
if (delta_rank == 4 && apb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32)
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
writer << "{\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape)
<< "}, " << et << ", memory::format::nchw);\n";
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape)
<< "}, " << et << ", memory::format::nchw);\n";
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
<< ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
......
......@@ -25,6 +25,7 @@
#include <mkldnn.hpp>
#include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -120,6 +121,44 @@ namespace ngraph
convolution->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPool)
{
    // Tag AvgPool nodes that MKLDNN can handle so later passes (layout
    // selection, emission) use the MKLDNN kernel instead of the reference one.
    auto avg_pool = static_cast<op::AvgPool*>(node);

    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();

    // MKLDNN is only used for 4D (presumably NCHW) f32 inputs with a 2D
    // pooling window; all other configurations fall back to the default kernel.
    if (arg0_rank == 4 && avg_pool->get_window_shape().size() == 2 &&
        node->get_input_element_type(0) == element::f32)
    {
        auto op_annotations =
            std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
        op_annotations->set_mkldnn_op(true);
        avg_pool->set_op_annotations(op_annotations);
    }
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPoolBackprop)
{
    // Tag AvgPoolBackprop nodes that MKLDNN can handle, mirroring the
    // forward AvgPool assignment above.
    auto avg_pool = static_cast<op::AvgPoolBackprop*>(node);

    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();

    // Same restriction as the forward op: 4D f32 delta with a 2D window.
    if (arg0_rank == 4 && avg_pool->get_window_shape().size() == 2 &&
        node->get_input_element_type(0) == element::f32)
    {
        auto op_annotations =
            std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
        op_annotations->set_mkldnn_op(true);
        avg_pool->set_op_annotations(op_annotations);
    }
}
}
}
}
......@@ -134,6 +173,9 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -25,6 +25,7 @@
#include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
......@@ -461,6 +462,169 @@ namespace ngraph
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::AvgPool)
{
    // Pick MKLDNN-optimal memory layouts for AvgPool by querying a pooling
    // primitive descriptor; non-MKLDNN nodes keep default (row-major) layouts.
    if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node.get());

        auto arg0_shape = node->get_input_shape(0);
        auto result_shape = node->get_output_shape(0);
        auto filter_shape = avg_pool->get_window_shape();
        auto filter_strides = avg_pool->get_window_movement_strides();
        auto padding_below = avg_pool->get_padding_below();
        auto padding_above = avg_pool->get_padding_above();

        memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
            node->get_input_element_type(0));

        algorithm algorithm_enumerator =
            avg_pool->get_include_padding_in_avg_computation()
                ? algorithm::pooling_avg_include_padding
                : algorithm::pooling_avg_exclude_padding;

        memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
        memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
        memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
        memory::dims mkldnn_filter_strides(filter_strides.begin(),
                                           filter_strides.end());
        memory::dims mkldnn_padding_below(padding_below.begin(),
                                          padding_below.end());
        memory::dims mkldnn_padding_above(padding_above.begin(),
                                          padding_above.end());

        // Keep the input in whatever layout the producer selected and let
        // MKLDNN choose the output layout (memory::format::any).
        auto input_layout =
            runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);

        auto input_desc = memory::desc(mkldnn_arg0_shape, et, input_layout);
        auto result_desc =
            memory::desc(mkldnn_result_shape, et, memory::format::any);

        vector<memory::format> prim_input_formats;
        vector<memory::format> prim_output_formats;
        try
        {
            auto prim_desc = pooling_forward::primitive_desc(
                {prop_kind::forward_inference,
                 algorithm_enumerator,
                 input_desc,
                 result_desc,
                 mkldnn_filter_strides,
                 mkldnn_filter_shape,
                 mkldnn_padding_below,
                 mkldnn_padding_above,
                 padding_kind::zero},
                runtime::cpu::mkldnn_utils::global_cpu_engine);
            prim_input_formats.push_back(input_layout);
            // Query the layout MKLDNN chose for the destination.
            prim_output_formats.push_back(static_cast<memory::format>(
                prim_desc.dst_primitive_desc().desc().data.format));
        }
        catch (const mkldnn::error& e)
        {
            // TODO (jbobba): Check with MKLDNN folks if this is necessary
            // No primitive exists for this input layout; surface the failure.
            // (Any nchw fallback would have to happen before this throw.)
            throw ngraph_error("MKLDNN Unsupported pooling layout" +
                              to_string(input_layout));
        }

        node =
            insert_input_conversions(external_function, node, prim_input_formats);
        set_output_layouts(node, prim_output_formats);
    }
    else
    {
        set_default_layouts(external_function, node);
    }
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::AvgPoolBackprop)
{
    // Pick MKLDNN-optimal layouts for AvgPoolBackprop. MKLDNN requires a
    // forward primitive descriptor as a hint when creating the backward one.
    if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        auto avg_pool = static_cast<const ngraph::op::AvgPoolBackprop*>(node.get());

        // arg0 is the delta (diff_dst); the result is diff_src.
        auto arg0_shape = node->get_input_shape(0);
        auto result_shape = node->get_output_shape(0);
        auto filter_shape = avg_pool->get_window_shape();
        auto filter_strides = avg_pool->get_window_movement_strides();
        auto padding_below = avg_pool->get_padding_below();
        auto padding_above = avg_pool->get_padding_above();

        memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
            node->get_input_element_type(0));

        algorithm algorithm_enumerator =
            avg_pool->get_include_padding_in_avg_computation()
                ? algorithm::pooling_avg_include_padding
                : algorithm::pooling_avg_exclude_padding;

        memory::dims mkldnn_arg0_shape(arg0_shape.begin(), arg0_shape.end());
        memory::dims mkldnn_result_shape(result_shape.begin(), result_shape.end());
        memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
        memory::dims mkldnn_filter_strides(filter_strides.begin(),
                                           filter_strides.end());
        memory::dims mkldnn_padding_below(padding_below.begin(),
                                          padding_below.end());
        memory::dims mkldnn_padding_above(padding_above.begin(),
                                          padding_above.end());

        auto input_layout =
            runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);

        auto input_desc = memory::desc(mkldnn_arg0_shape, et, input_layout);
        // Let MKLDNN choose the diff_src layout.
        auto result_desc =
            memory::desc(mkldnn_result_shape, et, memory::format::any);

        vector<memory::format> prim_input_formats;
        vector<memory::format> prim_output_formats;
        try
        {
            // Forward descriptor is only a hint for the backward primitive;
            // note src/dst are swapped relative to fprop (result_desc is src).
            auto fwd_prim_desc = pooling_forward::primitive_desc(
                {prop_kind::forward_inference,
                 algorithm_enumerator,
                 result_desc,
                 input_desc,
                 mkldnn_filter_strides,
                 mkldnn_filter_shape,
                 mkldnn_padding_below,
                 mkldnn_padding_above,
                 padding_kind::zero},
                runtime::cpu::mkldnn_utils::global_cpu_engine);
            auto prim_desc = pooling_backward::primitive_desc(
                {algorithm_enumerator,
                 result_desc,
                 input_desc,
                 mkldnn_filter_strides,
                 mkldnn_filter_shape,
                 mkldnn_padding_below,
                 mkldnn_padding_above,
                 padding_kind::zero},
                runtime::cpu::mkldnn_utils::global_cpu_engine,
                fwd_prim_desc);
            prim_input_formats.push_back(input_layout);
            prim_output_formats.push_back(static_cast<memory::format>(
                prim_desc.diff_src_primitive_desc().desc().data.format));
        }
        catch (const mkldnn::error& e)
        {
            // TODO (jbobba): Check with MKLDNN folks if this is necessary
            // No primitive exists for this input layout; surface the failure.
            // (Any nchw fallback would have to happen before this throw.)
            throw ngraph_error("MKLDNN Unsupported pooling layout" +
                              to_string(input_layout));
        }

        node =
            insert_input_conversions(external_function, node, prim_input_formats);
        set_output_layouts(node, prim_output_formats);
    }
    else
    {
        set_default_layouts(external_function, node);
    }
}
}
}
}
......@@ -474,6 +638,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPoolBackprop>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -6100,8 +6100,9 @@ TEST(${BACKEND_NAME}, convolution_outlining)
EXPECT_EQ(vector<float>{expected_result}, read_vector<float>(result));
}
TEST(${BACKEND_NAME}, convolution_layout)
TEST(${BACKEND_NAME}, mkldnn_layouts)
{
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
Shape shape_a{1, 16, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{32, 16, 1, 1};
......@@ -6114,7 +6115,9 @@ TEST(${BACKEND_NAME}, convolution_layout)
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto f = make_shared<Function>(conv1, op::Parameters{A, B});
Shape pool_shape{1, 1};
auto pool1 = make_shared<op::AvgPool>(conv1, pool_shape);
auto f = make_shared<Function>(pool1, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment