Commit a1880375 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: MKLDNN Average Pooling

parent e8138511
......@@ -2145,16 +2145,59 @@ void runtime::cpu::CPU_Emitter::EmitAvgPool(codegen::CodeWriter& writer,
auto avg_pool = static_cast<const op::AvgPool*>(n);
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
writer << "kernel::avg_pool<" << out[0].get_type() << ">(" << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {" << join(avg_pool->get_window_shape()) << "},\n";
writer << " {" << join(avg_pool->get_window_movement_strides()) << "},\n";
writer << " {" << join(avg_pool->get_padding_below()) << "},\n";
writer << " {" << join(avg_pool->get_padding_above()) << "});\n";
// TODO(jmenon): Refactor into an MKLDNN Pooling emitter that handles
// all pooling variants
// TODO(jmenon): Optimize for 1D
// TODO(jmenon): Remove element type restriction
if (arg_rank == 4 && avg_pool->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32)
{
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
writer << "{\n";
writer.indent++;
writer << "auto input_data_desc = memory::desc({" << join(arg_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "auto result_desc = memory::desc({" << join(result_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "auto input_data = memory({input_data_desc, cpu_engine}, " << args[0].get_name()
<< ");\n";
writer << "auto result = memory({result_desc, cpu_engine}, " << out[0].get_name() << ");\n";
// TODO(jmenon): Use a workspace
writer << "auto avg_pooling = pooling_forward({"
<< "{prop_kind::forward_inference, algorithm::pooling_avg, "
<< "input_data_desc, result_desc, {" << join(avg_pool->get_window_movement_strides())
<< "}, {" << join(avg_pool->get_window_shape()) << "}, "
<< "{" << join(avg_pool->get_padding_below()) << "}, "
<< "{" << join(avg_pool->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine}, "
<< "input_data, result);\n";
writer << "auto s = stream(stream::kind::eager);\n"
<< "s.submit({avg_pooling}).wait();\n";
writer.indent--;
writer << "}\n";
}
else
{
writer << "kernel::avg_pool<" << out[0].get_type() << ">(" << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {" << join(avg_pool->get_window_shape()) << "},\n";
writer << " {" << join(avg_pool->get_window_movement_strides()) << "},\n";
writer << " {" << join(avg_pool->get_padding_below()) << "},\n";
writer << " {" << join(avg_pool->get_padding_above()) << "});\n";
}
}
void runtime::cpu::CPU_Emitter::EmitPad(codegen::CodeWriter& writer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment