Added support for MKLDNN Max pooling in the CPU emitter code

parent 2fe7f0f3
...@@ -2438,20 +2438,69 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop) ...@@ -2438,20 +2438,69 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop)
{ {
auto mpb = static_cast<const op::MaxPoolBackprop*>(node); auto mpb = static_cast<const op::MaxPoolBackprop*>(node);
auto delta_shape = args[1].get_shape(); auto delta_shape = args[0].get_shape();
auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
writer << "kernel::max_pool_backprop<" << out[0].get_type() << ">(" << args[0].get_name() if (delta_rank == 4 && apb->get_window_shape().size() == 2 &&
<< ",\n"; args[0].get_element_type() == element::f32)
writer << " " << args[1].get_name() << ",\n"; {
writer << " " << out[0].get_name() << ",\n"; const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
writer << " {" << join(delta_shape) << "},\n";
writer << " {" << join(out_shape) << "},\n"; writer << "{\n";
writer << " {" << join(mpb->get_window_shape()) << "},\n"; writer.indent++;
writer << " {" << join(mpb->get_window_movement_strides()) << "},\n";
writer << " {" << join(mpb->get_padding_below()) << "},\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << " {" << join(mpb->get_padding_above()) << "}\n"; writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape) << "}, "
writer << " );\n"; << et << ", memory::format::nchw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[0].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
<< ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy, use this to query the workspace
// TODO: we need to develop a global context that keeps the mapping between corresponding fprop and bprop
// mkldnn kernels, and use it to query the workspace requirement during bprop
writer << "pooling_forward::primitive_desc pool_fwd_pd = pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_max, "
<< "result_desc, input_data_desc, {" << join(mpb->get_window_movement_strides())
<< "}, {" << join(mpb->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
// query the workspace from the forward primitive desc
writer << "memory max_pool_workspace_memory = memory(pool_fwd_pd.workspace_primitive_desc());\n";
writer << "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(algorithm::pooling_max, "
<< "result_desc, input_data_desc, {" << join(mpb->get_window_movement_strides())
<< "}, {" << join(mpb->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero), cpu_engine, pool_fwd_pd), "
<< "input_data, max_pool_workspace_memory, result);\n";
writer << "auto s = stream(stream::kind::eager);\n"
<< "s.submit({avg_pooling}).wait();\n";
writer.indent--;
writer << "}\n";
}
else
{
writer << "kernel::max_pool_backprop<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(delta_shape) << "},\n";
writer << " {" << join(out_shape) << "},\n";
writer << " {" << join(mpb->get_window_shape()) << "},\n";
writer << " {" << join(mpb->get_window_movement_strides()) << "},\n";
writer << " {" << join(mpb->get_padding_below()) << "},\n";
writer << " {" << join(mpb->get_padding_above()) << "}\n";
writer << " );\n";
}
} }
//------------------------------------------------------------------------------------------------ //------------------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment