Commit c4db6126 authored by pthoreho's avatar pthoreho

added bprop maxpool mkldnn implementation in the cpu emitted code

 - workaround to find the workspace requirement for maxpool bprop based on the forward_op reference
parent 7fe7bf50
...@@ -2523,15 +2523,16 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitAvgPoolBackprop) ...@@ -2523,15 +2523,16 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitAvgPoolBackprop)
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop) void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop)
{ {
auto mpb = static_cast<const op::MaxPoolBackprop*>(node); auto mpb = static_cast<const op::MaxPoolBackprop*>(node);
auto max_pool_fprop_op = mpb->get_forward_op();
auto delta_shape = args[0].get_shape(); auto delta_shape = args[1].get_shape();
auto delta_rank = delta_shape.size(); auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
if (delta_rank == 4 && apb->get_window_shape().size() == 2 && if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32)
{ {
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string()); const string& et = get_mkldnn_data_type(args[1].get_element_type().c_type_string());
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
...@@ -2541,22 +2542,26 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop) ...@@ -2541,22 +2542,26 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMaxPoolBackprop)
<< et << ", memory::format::nchw);\n"; << et << ", memory::format::nchw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape) << "}, " << et writer << "memory::desc result_desc = memory::desc({" << join(out_shape) << "}, " << et
<< ", memory::format::nchw);\n"; << ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[0].get_name() writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[1].get_name()
<< ");\n"; << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name() writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
<< ");\n"; << ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy, use this to query the workspace // create a forward primitive_desc, use this to query the workspace
// TODO: we need to develop global context to keep the mapping of fprop and bprop corresponding // TODO: we need to develop global context to keep the mapping of fprop and bprop corresponding
// mkldnn kernels and use it to query the workspace requirement during bprop // mkldnn kernels and use it to query the workspace requirement during bprop
// Note:
// input_data_desc of MaxpoolBackprop : will be same as maxpool(fprop) result_desc
// result_desc of MaxpoolBackprop : will be same as maxpool(fprop) input_data_desc
writer << "pooling_forward::primitive_desc pool_fwd_pd = pooling_forward::primitive_desc(" writer << "pooling_forward::primitive_desc pool_fwd_pd = pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_max, " << "{prop_kind::forward, algorithm::pooling_max, "
<< "result_desc, input_data_desc, {" << join(mpb->get_window_movement_strides()) << "result_desc, input_data_desc, {" << join(max_pool_fprop_op->get_window_movement_strides())
<< "}, {" << join(mpb->get_window_shape()) << "}, " << "}, {" << join(max_pool_fprop_op->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, " << "{" << join(max_pool_fprop_op->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, " << "{" << join(max_pool_fprop_op->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n"; << "padding_kind::zero}, cpu_engine);\n";
// query the workspace from the forward primitive desc // query the workspace from the forward primitive desc
writer << "memory max_pool_workspace_memory = " writer << "memory max_pool_workspace_memory = "
"memory(pool_fwd_pd.workspace_primitive_desc());\n"; "memory(pool_fwd_pd.workspace_primitive_desc());\n";
......
...@@ -40,12 +40,12 @@ namespace ngraph ...@@ -40,12 +40,12 @@ namespace ngraph
const std::unordered_set<std::type_index> s_op_registry{ const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::AvgPool), TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop), TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm)};
TI(ngraph::op::Convolution), TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData), TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters), TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::MaxPool), TI(ngraph::op::MaxPool),
T1(ngraph::op::MaxPoolBackprop)}; TI(ngraph::op::MaxPoolBackprop),
TI(ngraph::op::BatchNorm)};
bool IsMKLDNNOp(ngraph::Node& op) bool IsMKLDNNOp(ngraph::Node& op)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment