Commit 2941cd42 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Port MaxPool

parent e9467f4b
...@@ -2275,38 +2275,28 @@ namespace ngraph ...@@ -2275,38 +2275,28 @@ namespace ngraph
if (arg_rank == 4 && max_pool->get_window_shape().size() == 2 && if (arg_rank == 4 && max_pool->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32)
{ {
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
args[0].get_element_type()); auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
writer << "{\n"; auto result_desc = mkldnn_emitter->build_memory_descriptor(
writer.indent++; out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
<< "}, " << et << ", memory::format::nchw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " size_t max_pool_index = mkldnn_emitter->build_max_pool_forward(
<< args[0].get_name() << ");\n"; input_desc,
writer << "memory result = memory({result_desc, cpu_engine}, " result_desc,
<< out[0].get_name() << ");\n"; max_pool->get_window_movement_strides(),
max_pool->get_window_shape(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
// TODO(jmenon): Use a workspace auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
writer << "pooling_forward max_pooling = pooling_forward({" writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< "{prop_kind::forward_inference, algorithm::pooling_max, " << ", " << args[0].get_name() << ");\n";
<< "input_data_desc, result_desc, {" writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< join(max_pool->get_window_movement_strides()) << "}, {" << ", " << out[0].get_name() << ");\n";
<< join(max_pool->get_window_shape()) << "}, {"
<< join(max_pool->get_padding_below()) << "}, "
<< "{" << join(max_pool->get_padding_above())
<< "}, padding_kind::zero}, cpu_engine}, "
<< "input_data, result);\n";
writer << "stream s = stream(stream::kind::eager);\n" writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< "s.submit({max_pooling}).wait();\n"; << to_string(max_pool_index) << ");\n";
writer.indent--;
writer << "}\n";
} }
else else
{ {
......
...@@ -222,6 +222,34 @@ size_t MKLDNNEmitter::build_convolution_backward_data(const mkldnn::memory::desc ...@@ -222,6 +222,34 @@ size_t MKLDNNEmitter::build_convolution_backward_data(const mkldnn::memory::desc
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_max_pool_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::pooling_forward(
{{mkldnn::prop_kind::forward_inference,
mkldnn::algorithm::pooling_max,
input_desc,
result_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine},
*mkldnn_primitives[input_index],
*mkldnn_primitives[result_index]));
primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add( size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
namespace ngraph namespace ngraph
...@@ -86,6 +87,13 @@ namespace ngraph ...@@ -86,6 +87,13 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
size_t build_max_pool_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_elementwise_add( size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment