Commit eb83f267 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Standalone codegen. Ops {Q}MaxPool and {Q}AvgPool (#2867)

parent 10b43d55
......@@ -2548,17 +2548,16 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
size_t max_pool_index = external_function->get_primitive_index(node);
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
size_t max_pool_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, max_pool_index, deps);
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index) << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(max_pool_index)
<< ");\n";
}
else
{
......@@ -2581,16 +2580,16 @@ namespace ngraph
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
size_t qmax_pool_index = external_function->get_primitive_index(node);
auto& deps = mkldnn_emitter->get_primitive_deps(qmax_pool_index);
size_t max_pool_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, max_pool_index, deps);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(qmax_pool_index) << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(max_pool_index)
<< ");\n";
}
else
{
......@@ -2603,15 +2602,16 @@ namespace ngraph
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
size_t qavg_pool_index = external_function->get_primitive_index(node);
auto& deps = mkldnn_emitter->get_primitive_deps(qavg_pool_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(qavg_pool_index) << ");\n";
size_t avg_pool_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, avg_pool_index, deps);
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(avg_pool_index)
<< ");\n";
}
else
{
......@@ -2746,17 +2746,16 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
size_t avg_pool_index = external_function->get_primitive_index(node);
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
size_t avg_pool_index;
std::vector<std::size_t> deps;
emit_build_primitives(external_function, node, writer, avg_pool_index, deps);
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(avg_pool_index) << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cg_ctx->set_memory_ptr(" << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cg_ctx->mkldnn_invoke_primitive(" << to_string(avg_pool_index)
<< ");\n";
}
else
{
......
......@@ -276,39 +276,6 @@ size_t MKLDNNEmitter::build_dequantization(const ngraph::Node* node,
return dequantize_index;
}
size_t MKLDNNEmitter::build_quantized_max_pool(const ngraph::Node* node)
{
auto qmax_pool = static_cast<const ngraph::op::QuantizedMaxPool*>(node);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t qmax_pool_index = this->build_pooling_forward(mkldnn::algorithm::pooling_max,
input_desc,
result_desc,
qmax_pool->get_window_movement_strides(),
qmax_pool->get_window_shape(),
qmax_pool->get_padding_below(),
qmax_pool->get_padding_above());
return qmax_pool_index;
}
size_t MKLDNNEmitter::build_quantized_avg_pool(const ngraph::Node* node)
{
auto qavg_pool = static_cast<const ngraph::op::QuantizedAvgPool*>(node);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t qavg_pool_index =
this->build_pooling_forward((qavg_pool->get_include_padding_in_avg_computation()
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding),
input_desc,
result_desc,
qavg_pool->get_window_movement_strides(),
qavg_pool->get_window_shape(),
qavg_pool->get_padding_below(),
qavg_pool->get_padding_above());
return qavg_pool_index;
}
mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any,
......@@ -659,38 +626,6 @@ void MKLDNNEmitter::build_convolution_backward_data(
*mkldnn_primitives[result_index]);
}
size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::pooling_forward(
{{mkldnn::prop_kind::forward_inference,
pooling_algorithm,
input_desc,
result_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
executor::global_cpu_engine},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
NGRAPH_CHECK(m_primitive_deps.find(primitive_index) == m_primitive_deps.end(),
"Dependencies already created for node");
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
void MKLDNNEmitter::build_pooling_forward(std::vector<mkldnn::primitive*>& mkldnn_primitives,
const mkldnn::pooling_forward::desc& pool_desc,
const std::vector<size_t>& deps,
......
......@@ -421,14 +421,6 @@ namespace ngraph
}
}
size_t build_pooling_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
template <typename OP>
mkldnn::pooling_forward::desc get_max_pooling_forward_desc(const ngraph::Node* node,
bool training)
......@@ -822,10 +814,6 @@ namespace ngraph
const std::vector<size_t>& deps,
size_t bounded_relu_index);
size_t build_quantized_max_pool(const ngraph::Node* node);
size_t build_quantized_avg_pool(const ngraph::Node* node);
size_t build_dequantization(const ngraph::Node* node,
const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment