Commit 2b0cd2fd authored by Jayaram Bobba's avatar Jayaram Bobba

Added MKLDNN avg pool backprop kernel

parent bce27b39
......@@ -91,7 +91,7 @@ static const string& get_mkldnn_data_type(const string& type)
void runtime::cpu::CPU_Emitter::EmitMKLDNNPreamble(codegen::CodeWriter& writer)
{
writer << "// MKLDNN Preamble\n";
writer << "#include <mkldnn.hpp>;\n";
writer << "#include <mkldnn.hpp>\n";
writer << "using namespace mkldnn;\n\n";
}
......@@ -2370,18 +2370,57 @@ void runtime::cpu::CPU_Emitter::EmitAvgPoolBackprop(
auto apb = static_cast<const op::AvgPoolBackprop*>(n);
auto delta_shape = args[0].get_shape();
auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape();
if (delta_rank == 4 && apb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32)
{
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
writer << "kernel::avg_pool_backprop<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(delta_shape) << "},\n";
writer << " {" << join(out_shape) << "},\n";
writer << " {" << join(apb->get_window_shape()) << "},\n";
writer << " {" << join(apb->get_window_movement_strides()) << "},\n";
writer << " {" << join(apb->get_padding_below()) << "},\n";
writer << " {" << join(apb->get_padding_above()) << "}\n";
writer << " );\n";
writer << "{\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[0].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name() << ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy
writer << "pooling_forward::primitive_desc fwd_pd = pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_avg_exclude_padding, "
<< "result_desc, input_data_desc, {" << join(apb->get_window_movement_strides())
<< "}, {" << join(apb->get_window_shape()) << "}, "
<< "{" << join(apb->get_padding_below()) << "}, "
<< "{" << join(apb->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
writer << "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(algorithm::pooling_avg_exclude_padding, "
<< "result_desc, input_data_desc, {" << join(apb->get_window_movement_strides())
<< "}, {" << join(apb->get_window_shape()) << "}, "
<< "{" << join(apb->get_padding_below()) << "}, "
<< "{" << join(apb->get_padding_above()) << "}, "
<< "padding_kind::zero), cpu_engine, fwd_pd), "
<< "input_data, result);\n";
writer << "auto s = stream(stream::kind::eager);\n"
<< "s.submit({avg_pooling}).wait();\n";
writer.indent--;
writer << "}\n";
} else {
writer << "kernel::avg_pool_backprop<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(delta_shape) << "},\n";
writer << " {" << join(out_shape) << "},\n";
writer << " {" << join(apb->get_window_shape()) << "},\n";
writer << " {" << join(apb->get_window_movement_strides()) << "},\n";
writer << " {" << join(apb->get_padding_below()) << "},\n";
writer << " {" << join(apb->get_padding_above()) << "}\n";
writer << " );\n";
}
}
void runtime::cpu::CPU_Emitter::EmitMaxPoolBackprop(
......
......@@ -239,7 +239,8 @@ void runtime::cpu::CPU_ExternalFunction::compile()
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (dynamic_cast<op::Convolution*>(node.get()) ||
dynamic_cast<op::AvgPool*>(node.get()) || dynamic_cast<op::MaxPool*>(node.get()))
dynamic_cast<op::AvgPool*>(node.get()) || dynamic_cast<op::MaxPool*>(node.get()) ||
dynamic_cast<op::AvgPoolBackprop*>(node.get()))
{
include_mkldnn_headers = true;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment