Commit 93560cdf authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Generalize pooling

parent 2941cd42
...@@ -2281,7 +2281,8 @@ namespace ngraph ...@@ -2281,7 +2281,8 @@ namespace ngraph
auto result_desc = mkldnn_emitter->build_memory_descriptor( auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0)); out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pool_forward( size_t max_pool_index = mkldnn_emitter->build_pooling_forward(
mkldnn::algorithm::pooling_max,
input_desc, input_desc,
result_desc, result_desc,
max_pool->get_window_movement_strides(), max_pool->get_window_movement_strides(),
......
...@@ -222,19 +222,20 @@ size_t MKLDNNEmitter::build_convolution_backward_data(const mkldnn::memory::desc ...@@ -222,19 +222,20 @@ size_t MKLDNNEmitter::build_convolution_backward_data(const mkldnn::memory::desc
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_max_pool_forward(const mkldnn::memory::desc& input_desc, size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& input_desc,
const ngraph::Strides& window_strides, const mkldnn::memory::desc& result_desc,
const ngraph::Shape& window_shape, const ngraph::Strides& window_strides,
const ngraph::Shape& padding_below, const ngraph::Shape& window_shape,
const ngraph::Shape& padding_above) const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{ {
size_t input_index = build_memory_primitive(input_desc); size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc); size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::pooling_forward( size_t primitive_index = insert_primitive(new mkldnn::pooling_forward(
{{mkldnn::prop_kind::forward_inference, {{mkldnn::prop_kind::forward_inference,
mkldnn::algorithm::pooling_max, pooling_algorithm,
input_desc, input_desc,
result_desc, result_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()), mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
......
...@@ -87,12 +87,13 @@ namespace ngraph ...@@ -87,12 +87,13 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
size_t build_max_pool_forward(const mkldnn::memory::desc& input_desc, size_t build_pooling_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& input_desc,
const ngraph::Strides& window_strides, const mkldnn::memory::desc& result_desc,
const ngraph::Shape& window_shape, const ngraph::Strides& window_strides,
const ngraph::Shape& padding_below, const ngraph::Shape& window_shape,
const ngraph::Shape& padding_above); const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_elementwise_add( size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment