Commit 1fdf2d98 authored by Pruthvi, committed by Robert Kimball

Added DEX support for (MaxPool + AvgPool) backprop ops for the CPU backend (#1302)

* Added DEX support for the MaxPoolBackprop op for the CPU backend

* Added DEX execution support for the AvgPoolBackprop op
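
For context: DEX is the CPU backend's direct-execution mode, in which each op's builder runs once at compile time and enqueues a std::function callback (a "functor") that the backend invokes on every execution, instead of emitting C++ source for codegen. A minimal sketch of that pattern, with hypothetical names (example_builder, Functor) that do not appear in the diff:

    // Illustrative sketch only: the DEX builder/functor contract that both
    // builders in this commit follow. Builders run once; functors run per call.
    #include <functional>
    #include <vector>

    struct CPURuntimeContext; // opaque per-call state (tensor pointers, workspaces, ...)

    using Functor = std::function<void(CPURuntimeContext*)>;

    void example_builder(std::vector<Functor>& functors)
    {
        int captured_parameter = 42; // stands in for shapes/strides captured by value
        functors.emplace_back([captured_parameter](CPURuntimeContext* ctx) {
            (void)ctx;
            (void)captured_parameter; // the op's runtime work would go here
        });
    }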
parent 2a0e43ef
@@ -108,7 +108,85 @@ namespace ngraph
                }
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::AvgPoolBackprop)
            {
                auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);

                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();

                auto delta_shape = args[0].get_shape();
                auto out_shape = out[0].get_shape();

                auto& delta_tensor = tensor_data[args[0].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];

                auto window_shape = apb->get_window_shape();
                auto window_movement_strides = apb->get_window_movement_strides();
                auto padding_below = apb->get_padding_below();
                auto padding_above = apb->get_padding_above();
                auto include_padding_in_avg_computation =
                    apb->get_include_padding_in_avg_computation();

                if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
                    // MKL-DNN path: build the backward-pooling primitive once at
                    // compile time; the functor only rebinds memory pointers and
                    // invokes the cached primitive on each call.
                    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
                    auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
                        args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
                    auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
                        out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));

                    size_t avg_pool_index = mkldnn_emitter->build_pooling_backward(
                        (apb->get_include_padding_in_avg_computation()
                             ? mkldnn::algorithm::pooling_avg_include_padding
                             : mkldnn::algorithm::pooling_avg_exclude_padding),
                        diff_dst_desc,
                        diff_src_desc,
                        apb->get_window_movement_strides(),
                        apb->get_window_shape(),
                        apb->get_padding_below(),
                        apb->get_padding_above());

                    auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
                    auto functor = [&, avg_pool_index](CPURuntimeContext* ctx) {
                        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], delta_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
                        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, avg_pool_index);
                    };
                    functors.emplace_back(functor);
                }
                else
                {
                    // Reference path: pick the kernel instantiation that matches the
                    // output element type, then capture all pooling parameters by value.
                    std::function<decltype(runtime::cpu::kernel::avg_pool_backprop<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::avg_pool_backprop);

                    auto functor = [&,
                                    kernel,
                                    delta_shape,
                                    out_shape,
                                    window_shape,
                                    window_movement_strides,
                                    padding_below,
                                    padding_above,
                                    include_padding_in_avg_computation](CPURuntimeContext* ctx) {
                        kernel(delta_tensor,
                               out_tensor,
                               delta_shape,
                               out_shape,
                               window_shape,
                               window_movement_strides,
                               padding_below,
                               padding_above,
                               include_padding_in_avg_computation);
                    };
                    functors.emplace_back(functor);
                }
            }

            REGISTER_OP_BUILDER(AvgPool);
            REGISTER_OP_BUILDER(AvgPoolBackprop);
        }
    }
}
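
Both builders' fallback paths rely on SELECT_KERNEL to bind `kernel` to the instantiation of the templated kernel that matches the output element type; the real macro covers the full set of nGraph element types. The sketch below shows the same dispatch idea reduced to two types, using a hypothetical select_kernel helper rather than the actual macro:

    // Hedged sketch of element-type dispatch in the spirit of SELECT_KERNEL;
    // not the macro from the nGraph CPU backend.
    #include <functional>
    #include <stdexcept>
    #include <string>

    template <typename T>
    void my_kernel(void* /*in*/, void* /*out*/)
    {
        // typed work would happen here
    }

    using KernelFn = std::function<void(void*, void*)>;

    KernelFn select_kernel(const std::string& element_type) // hypothetical helper
    {
        if (element_type == "f32") return my_kernel<float>;
        if (element_type == "f64") return my_kernel<double>;
        throw std::runtime_error("unsupported element type: " + element_type);
    }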
@@ -102,6 +102,95 @@ namespace ngraph
                    functors.emplace_back(functor);
                }
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::MaxPoolBackprop)
            {
                auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);

                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();

                auto arg_fwd_shape = args[0].get_shape();
                auto delta_shape = args[1].get_shape();
                auto out_shape = out[0].get_shape();

                auto& arg_fwd_tensor = tensor_data[args[0].get_name()];
                auto& delta_tensor = tensor_data[args[1].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];

                auto window_shape = mpb->get_window_shape();
                auto window_movement_strides = mpb->get_window_movement_strides();
                auto padding_below = mpb->get_padding_below();
                auto padding_above = mpb->get_padding_above();

                if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
                    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
                    auto fprop_src_desc = mkldnn_emitter->build_memory_descriptor(
                        args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
                    auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
                        args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
                    auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
                        out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));

                    // build_max_pooling_backward sets up two primitives: a forward
                    // pass at (max_pool_index - 1) that recomputes the argmax
                    // workspace, and the backward pass at max_pool_index that
                    // consumes it. Both are enqueued as functors, in that order.
                    size_t max_pool_index = mkldnn_emitter->build_max_pooling_backward(
                        mkldnn::algorithm::pooling_max,
                        fprop_src_desc,
                        diff_dst_desc,
                        diff_src_desc,
                        mpb->get_window_movement_strides(),
                        mpb->get_window_shape(),
                        mpb->get_padding_below(),
                        mpb->get_padding_above());

                    auto& fdeps = mkldnn_emitter->get_primitive_deps(max_pool_index - 1);
                    auto functor_fprop = [&, max_pool_index](CPURuntimeContext* ctx) {
                        cpu::mkldnn_utils::set_memory_ptr(ctx, fdeps[0], arg_fwd_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(ctx, fdeps[1], out_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(
                            ctx, fdeps[2], ctx->mkldnn_workspaces[fdeps[3]]);
                        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, max_pool_index - 1);
                    };
                    functors.emplace_back(functor_fprop);

                    auto& bdeps = mkldnn_emitter->get_primitive_deps(max_pool_index);
                    auto functor_bprop = [&, max_pool_index](CPURuntimeContext* ctx) {
                        cpu::mkldnn_utils::set_memory_ptr(ctx, bdeps[0], delta_tensor);
                        cpu::mkldnn_utils::set_memory_ptr(
                            ctx, bdeps[1], ctx->mkldnn_workspaces[bdeps[3]]);
                        cpu::mkldnn_utils::set_memory_ptr(ctx, bdeps[2], out_tensor);
                        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, max_pool_index);
                    };
                    functors.emplace_back(functor_bprop);
                }
                else
                {
                    std::function<decltype(runtime::cpu::kernel::max_pool_backprop<float>)> kernel;

                    SELECT_KERNEL(
                        kernel, out[0].get_element_type(), runtime::cpu::kernel::max_pool_backprop);

                    auto functor = [&,
                                    kernel,
                                    arg_fwd_shape,
                                    delta_shape,
                                    out_shape,
                                    window_shape,
                                    window_movement_strides,
                                    padding_below,
                                    padding_above](CPURuntimeContext* ctx) {
                        // The kernel's out_shape parameter equals the forward input
                        // shape (the gradient has the shape of the forward input),
                        // so arg_fwd_shape is passed for it.
                        kernel(arg_fwd_tensor,
                               delta_tensor,
                               out_tensor,
                               delta_shape,
                               arg_fwd_shape,
                               window_shape,
                               window_movement_strides,
                               padding_below,
                               padding_above);
                    };
                    functors.emplace_back(functor);
                }
            }

            template <>
            void Builder::BUILDER_DECL(ngraph::op::MaxPoolWithIndices)

@@ -190,6 +279,7 @@ namespace ngraph
            }

            REGISTER_OP_BUILDER(MaxPool);
            REGISTER_OP_BUILDER(MaxPoolBackprop);
            REGISTER_OP_BUILDER(MaxPoolWithIndices);
            REGISTER_OP_BUILDER(MaxPoolWithIndicesBackprop);
        }
@@ -48,6 +48,29 @@ namespace ngraph
                       padding_above,
                       include_padding_in_avg_computation);
                }

                template <typename ElementType>
                void avg_pool_backprop(void* delta,
                                       void* out,
                                       const Shape& delta_shape,
                                       const Shape& out_shape,
                                       const Shape& window_shape,
                                       const Strides& window_movement_strides,
                                       const Shape& padding_below,
                                       const Shape& padding_above,
                                       bool include_padding_in_avg_computation)
                {
                    // Thin typed wrapper: casts the void* tensor pointers selected
                    // by SELECT_KERNEL and forwards to the reference implementation.
                    reference::avg_pool_backprop<ElementType>(
                        static_cast<const ElementType*>(delta),
                        static_cast<ElementType*>(out),
                        delta_shape,
                        out_shape,
                        window_shape,
                        window_movement_strides,
                        padding_below,
                        padding_above,
                        include_padding_in_avg_computation);
                }
            }
        }
    }
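
What the reference kernel computes, in words: each element of delta is distributed uniformly over the input positions covered by its pooling window, scaled by one over the window size (or, when padding is excluded from the average, one over the count of non-padding positions). A self-contained 1-D illustration under simplifying assumptions (no padding, stride equal to window size); this is not the nGraph reference code:

    // Minimal 1-D sketch of avg-pool backprop; illustrative only.
    #include <cstddef>
    #include <vector>

    std::vector<float> avg_pool_backprop_1d(const std::vector<float>& delta,
                                            size_t input_size,
                                            size_t window)
    {
        std::vector<float> input_grad(input_size, 0.0f);
        for (size_t o = 0; o < delta.size(); ++o)
        {
            // Each delta element is shared equally by every input in its window.
            for (size_t k = 0; k < window; ++k)
            {
                input_grad[o * window + k] += delta[o] / static_cast<float>(window);
            }
        }
        return input_grad;
    }
    // e.g. delta = {4, 8}, input_size = 4, window = 2 -> input_grad = {2, 2, 4, 4}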
@@ -46,6 +46,29 @@ namespace ngraph
                       padding_below,
                       padding_above);
                }

                template <typename ElementType>
                void max_pool_backprop(void* arg_forward,
                                       void* delta,
                                       void* out,
                                       const Shape& delta_shape,
                                       const Shape& out_shape, // same as arg_forward_shape
                                       const Shape& window_shape,
                                       const Strides& window_movement_strides,
                                       const Shape& padding_below,
                                       const Shape& padding_above)
                {
                    // Thin typed wrapper over the reference kernel; the forward
                    // input is needed to locate each window's maximum.
                    reference::max_pool_backprop<ElementType>(
                        static_cast<const ElementType*>(arg_forward),
                        static_cast<const ElementType*>(delta),
                        static_cast<ElementType*>(out),
                        delta_shape,
                        out_shape,
                        window_shape,
                        window_movement_strides,
                        padding_below,
                        padding_above);
                }
            }
        }
    }
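
By contrast, max-pool backprop is sparse: each delta element flows entirely to the input position that held the window maximum in the forward pass, which is why the kernel needs arg_forward, and why the MKL-DNN path above re-runs a forward primitive to rebuild its argmax workspace. A 1-D illustration under the same simplifying assumptions (no padding, stride equal to window size); again, not the nGraph reference code:

    // Minimal 1-D sketch of max-pool backprop; illustrative only.
    #include <cstddef>
    #include <vector>

    std::vector<float> max_pool_backprop_1d(const std::vector<float>& arg_forward,
                                            const std::vector<float>& delta,
                                            size_t window)
    {
        std::vector<float> input_grad(arg_forward.size(), 0.0f);
        for (size_t o = 0; o < delta.size(); ++o)
        {
            // Find the argmax of this window in the forward input.
            size_t argmax = o * window;
            for (size_t k = 1; k < window; ++k)
            {
                if (arg_forward[o * window + k] > arg_forward[argmax])
                {
                    argmax = o * window + k;
                }
            }
            input_grad[argmax] += delta[o]; // only the max position gets gradient
        }
        return input_grad;
    }
    // e.g. arg_forward = {1, 3, 5, 2}, delta = {10, 20}, window = 2
    //   -> input_grad = {0, 10, 20, 0}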