Commit c38c76a7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

DEX MaxPoolWithIndices (#1299)

* dex max_pool_with_indices

* maxpoolwithindices (#1300)
parent b1239af4
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -102,7 +103,95 @@ namespace ngraph ...@@ -102,7 +103,95 @@ namespace ngraph
} }
} }
template <>
void Builder::BUILDER_DECL(ngraph::op::MaxPoolWithIndices)
{
if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("MaxPoolWithIndices isn't supported");
}
auto max_pool = static_cast<const ngraph::op::MaxPoolWithIndices*>(node);
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto& out1_tensor = tensor_data[out[1].get_name()];
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_forward(
mkldnn::algorithm::pooling_max,
input_desc,
result_desc,
max_pool->get_window_movement_strides(),
max_pool->get_window_shape(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
auto functor = [&, max_pool_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out0_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out1_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, max_pool_index);
};
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("MaxPoolWithIndicesBackprop isn't supported");
}
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& arg2_tensor = tensor_data[args[2].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_backward(
mkldnn::algorithm::pooling_max,
diff_dst_desc,
diff_src_desc,
mpb->get_window_movement_strides(),
mpb->get_window_shape(),
mpb->get_padding_below(),
mpb->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
auto functor = [&, max_pool_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg1_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg2_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, max_pool_index);
};
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(MaxPool); REGISTER_OP_BUILDER(MaxPool);
REGISTER_OP_BUILDER(MaxPoolWithIndices);
REGISTER_OP_BUILDER(MaxPoolWithIndicesBackprop);
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment