Commit a6be909f authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Port AvgPool

parent 93560cdf
...@@ -2438,59 +2438,36 @@ namespace ngraph ...@@ -2438,59 +2438,36 @@ namespace ngraph
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
// TODO(jmenon): Refactor into an MKLDNN Pooling emitter that handles
// all pooling variants
// TODO(jmenon): Optimize for 1D // TODO(jmenon): Optimize for 1D
// TODO(jmenon): Remove element type restriction // TODO(jmenon): Remove element type restriction
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
args[0].get_element_type()); auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
const char* algorithm_enumerator = auto result_desc = mkldnn_emitter->build_memory_descriptor(
avg_pool->get_include_padding_in_avg_computation() out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
? "algorithm::pooling_avg_include_padding"
: "algorithm::pooling_avg_exclude_padding";
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
writer << "{\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; size_t avg_pool_index = mkldnn_emitter->build_pooling_forward(
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape) (avg_pool->get_include_padding_in_avg_computation()
<< "}, " << et << ", " ? mkldnn::algorithm::pooling_avg_include_padding
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format) : mkldnn::algorithm::pooling_avg_exclude_padding),
<< ");\n"; input_desc,
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) result_desc,
<< "}, " << et << ", " avg_pool->get_window_movement_strides(),
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format) avg_pool->get_window_shape(),
<< ");\n"; avg_pool->get_padding_below(),
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " avg_pool->get_padding_above());
<< args[0].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
// TODO(jmenon): Use a workspace auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
writer << "pooling_forward avg_pooling = pooling_forward({" writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< "{prop_kind::forward_inference, " << algorithm_enumerator << ", " << ", " << args[0].get_name() << ");\n";
<< "input_data_desc, result_desc, {" writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< join(avg_pool->get_window_movement_strides()) << "}, {" << ", " << out[0].get_name() << ");\n";
<< join(avg_pool->get_window_shape()) << "}, "
<< "{" << join(avg_pool->get_padding_below()) << "}, "
<< "{" << join(avg_pool->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine}, "
<< "input_data, result);\n";
writer << "stream s = stream(stream::kind::eager);\n" writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< "s.submit({avg_pooling}).wait();\n"; << to_string(avg_pool_index) << ");\n";
writer.indent--;
writer << "}\n";
} }
else else
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment