Commit 64356bc5 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Check size of scratchpad before allocating memory and creating scratchpad. (#3638)

* Check size requirement before creating scratchpad.

* Check max scratchpad size before allocating scratchpad_buffer.

* Add the same checks for CODEGEN.

* Fix unused-parameter warning.

* Fix a typo.

* Address PR feedback.

* Fix a bug.

* Fix style error.
parent 20b4ce27
......@@ -40,7 +40,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto sum_pd = mkldnn_emitter->get_elementwise_add_desc(node);
QUERY_SCRATCHPAD(sum, sum_pd);
size_t scratchpad_size = QUERY_SCRATCHPAD(sum, sum_pd);
// Add needs 4 primitives: input0, input1, result, and sum.
size_t add_index = mkldnn_emitter->reserve_primitive_space(4);
......@@ -55,6 +55,7 @@ namespace ngraph
auto functor = [&,
sum_pd,
add_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -76,7 +77,7 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD);
ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD, scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -55,32 +55,40 @@ namespace ngraph
auto avg_pool_desc =
mkldnn_emitter->get_avg_pooling_forward_desc<ngraph::op::AvgPool>(node,
false);
QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc);
// AvgPool needs 3 primitives: input, result, and pooling_forward.
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
auto functor =
[&, avg_pool_desc, avg_pool_index, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
avg_pool_desc,
deps,
avg_pool_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOL);
};
auto functor = [&,
avg_pool_desc,
avg_pool_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
avg_pool_desc,
deps,
avg_pool_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
avg_pool_index,
deps,
cpu::mkldnn_utils::OpType::AVGPOOL,
scratchpad_size);
};
functors.emplace_back(functor);
}
else
......@@ -144,7 +152,8 @@ namespace ngraph
auto avg_pool_desc =
mkldnn_emitter->get_avg_pooling_backward_desc<ngraph::op::AvgPoolBackprop>(
node);
QUERY_SCRATCHPAD_2ARGS(avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(
avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
// AvgPoolBackprop needs 3 primitives: input, result, and pooling_backward.
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
......@@ -154,6 +163,7 @@ namespace ngraph
avg_pool_desc,
avg_pool_fwd_desc,
avg_pool_index,
scratchpad_size,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -173,7 +183,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP);
ctx,
avg_pool_index,
deps,
cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -84,7 +84,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto batchnorm_desc =
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, true);
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
......@@ -101,6 +102,7 @@ namespace ngraph
training,
ops,
batchnorm_index,
scratchpad_size,
stacked_weights,
weight_sizes,
arg0_buffer_index,
......@@ -140,7 +142,11 @@ namespace ngraph
ctx, deps[4], ctx->buffer_data[out2_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM3ARGS);
ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORM3ARGS,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -155,7 +161,8 @@ namespace ngraph
auto batchnorm_desc =
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, false);
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
auto weights_shape = Shape{2, args[0].get_size()};
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
......@@ -172,6 +179,7 @@ namespace ngraph
training,
ops,
batchnorm_index,
scratchpad_size,
stacked_weights,
weight_sizes,
arg0_buffer_index,
......@@ -211,7 +219,11 @@ namespace ngraph
ctx, deps[4], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM5ARGS);
ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORM5ARGS,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -515,7 +527,8 @@ namespace ngraph
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto eps = batchnorm->get_eps_value();
(void)eps; // Use depends on mkl-dnn version
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
size_t scratchpad_size =
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
auto functor = [&,
batchnorm_desc,
......@@ -523,6 +536,7 @@ namespace ngraph
weights_desc,
dweights_desc,
batchnorm_index,
scratchpad_size,
stacked_weights,
stacked_dweights,
weight_sizes,
......@@ -573,7 +587,8 @@ namespace ngraph
ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);
cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP,
scratchpad_size);
memcpy(ctx->buffer_data[out1_buffer_index],
stacked_dweights.get(),
......
......@@ -44,7 +44,7 @@ namespace ngraph
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto bounded_relu_desc = mkldnn_emitter->get_bounded_relu_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc);
// BoundedRelu needs 3 primitives: input, result, and eltwise_forward.
auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3);
......@@ -53,6 +53,7 @@ namespace ngraph
auto functor = [&,
bounded_relu_desc,
bounded_relu_index,
scratchpad_size,
input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -71,7 +72,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, bounded_relu_index, deps, cpu::mkldnn_utils::OpType::BOUNDEDRELU);
ctx,
bounded_relu_index,
deps,
cpu::mkldnn_utils::OpType::BOUNDEDRELU,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -101,7 +101,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto concat_pd =
mkldnn_emitter->get_concat_desc<ngraph::op::Concat>(node, nargs);
QUERY_SCRATCHPAD(concat, concat_pd);
size_t scratchpad_size = QUERY_SCRATCHPAD(concat, concat_pd);
std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < nargs; i++)
......@@ -115,6 +115,7 @@ namespace ngraph
auto functor = [&,
concat_pd,
scratchpad_size,
inputs_data_desc,
arg_buffer_indices,
nargs,
......@@ -140,7 +141,11 @@ namespace ngraph
ctx, deps[nargs], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, concat_index, deps, cpu::mkldnn_utils::OpType::CONCAT);
ctx,
concat_index,
deps,
cpu::mkldnn_utils::OpType::CONCAT,
scratchpad_size);
};
functors.emplace_back(functor);
......
......@@ -43,6 +43,8 @@ namespace ngraph
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t scratchpad_size = 0;
#if MKLDNN_VERSION_MAJOR < 1
if (input_desc.data.format == mkldnn_nchw &&
result_desc.data.format == mkldnn_goihw)
......@@ -129,32 +131,41 @@ namespace ngraph
mkldnn::memory::format_tag::goihw);
}
mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc);
scratchpad_size = mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc);
#endif
// ConvertLayout needs 3 primitives: input, result, and reorder.
size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
auto functor =
[&, input_desc, result_desc, reorder_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_reorder(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
input_desc,
result_desc,
deps,
reorder_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
auto functor = [&,
input_desc,
result_desc,
reorder_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_reorder(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
input_desc,
result_desc,
deps,
reorder_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, reorder_index, deps, cpu::mkldnn_utils::OpType::CONVERTLAYOUT);
};
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
reorder_index,
deps,
cpu::mkldnn_utils::OpType::CONVERTLAYOUT,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -44,7 +44,7 @@ namespace ngraph
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto leaky_relu_desc = mkldnn_emitter->get_leaky_relu_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, leaky_relu_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, leaky_relu_desc);
// CPULeakyRelu needs 3 primitives: input, result, and eltwise_forward.
auto leaky_relu_index = mkldnn_emitter->reserve_primitive_space(3);
......@@ -53,6 +53,7 @@ namespace ngraph
auto functor = [&,
leaky_relu_desc,
leaky_relu_index,
scratchpad_size,
input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -71,7 +72,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, leaky_relu_index, deps, cpu::mkldnn_utils::OpType::LEAKYRELU);
ctx,
leaky_relu_index,
deps,
cpu::mkldnn_utils::OpType::LEAKYRELU,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -44,14 +44,19 @@ namespace ngraph
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node);
QUERY_SCRATCHPAD(lrn_forward, lrn_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(lrn_forward, lrn_desc);
// LRN needs 3 primitives: input, result, and lrn_forward.
auto lrn_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index);
functor = [&, lrn_desc, lrn_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
functor = [&,
lrn_desc,
lrn_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_lrn_forward(ctx->mkldnn_memories,
......@@ -67,7 +72,7 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, lrn_index, deps, cpu::mkldnn_utils::OpType::LRN);
ctx, lrn_index, deps, cpu::mkldnn_utils::OpType::LRN, scratchpad_size);
};
}
else
......
......@@ -123,7 +123,7 @@ namespace ngraph
};
functors.emplace_back(functor);
#else
mkldnn_emitter->query_scratchpad_rnn_forward(lstm_desc);
size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(lstm_desc);
auto src_iter_c_buffer_index =
external_function->get_buffer_index(args[2].get_name());
......@@ -146,6 +146,7 @@ namespace ngraph
auto functor = [&,
lstm_desc,
lstm_index,
scratchpad_size,
src_layer_buffer_index,
src_iter_buffer_index,
src_iter_c_buffer_index,
......@@ -188,7 +189,7 @@ namespace ngraph
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, lstm_index, deps, cpu::mkldnn_utils::OpType::LSTM);
ctx, lstm_index, deps, cpu::mkldnn_utils::OpType::LSTM, scratchpad_size);
};
functors.emplace_back(functor);
#endif
......
......@@ -54,32 +54,40 @@ namespace ngraph
auto max_pool_desc =
mkldnn_emitter->get_max_pooling_forward_desc<ngraph::op::MaxPool>(node,
false);
QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
// MaxPool needs 3 primitives: input, result, and pooling_forward.
size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
auto functor =
[&, max_pool_desc, max_pool_index, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
max_pool_desc,
deps,
max_pool_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, max_pool_index, deps, cpu::mkldnn_utils::OpType::MAXPOOL);
};
auto functor = [&,
max_pool_desc,
max_pool_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
max_pool_desc,
deps,
max_pool_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
max_pool_index,
deps,
cpu::mkldnn_utils::OpType::MAXPOOL,
scratchpad_size);
};
functors.emplace_back(functor);
}
else
......@@ -141,7 +149,8 @@ namespace ngraph
mkldnn_emitter->get_max_pooling_backward_desc<ngraph::op::MaxPoolBackprop>(
node);
auto fprop_src_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
QUERY_SCRATCHPAD_2ARGS(max_pooling_backward, fwd_pool_desc, bwd_pool_desc);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(max_pooling_backward, fwd_pool_desc, bwd_pool_desc);
// MaxPoolBackprop forward needs 4 primitives: fprop_src, diff_src, workspace,
// and pooling_forward.
......@@ -150,21 +159,25 @@ namespace ngraph
mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
auto& fdeps = mkldnn_emitter->get_primitive_deps(fwd_pool_index);
auto functor_fprop =
[&, fwd_pool_index, arg_fwd_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[0], ctx->buffer_data[arg_fwd_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[2], ctx->mkldnn_workspaces[fdeps[3]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
fwd_pool_index,
fdeps,
cpu::mkldnn_utils::OpType::MAXPOOLBACKPROPFORWARD);
};
auto functor_fprop = [&,
fwd_pool_index,
arg_fwd_buffer_index,
scratchpad_size,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[0], ctx->buffer_data[arg_fwd_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[2], ctx->mkldnn_workspaces[fdeps[3]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
fwd_pool_index,
fdeps,
cpu::mkldnn_utils::OpType::MAXPOOLBACKPROPFORWARD,
scratchpad_size);
};
// MaxPoolBackprop backward needs 4 primitives: diff_dst, workspace, diff_src,
// and pooling_backward.
......@@ -267,7 +280,7 @@ namespace ngraph
mkldnn_emitter
->get_max_pooling_with_indices_forward_desc<ngraph::op::MaxPoolWithIndices>(
node);
QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
// MaxPoolWithIndices needs 4 primitives: src, dst, workspace, and pooling_forward.
size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(4);
......@@ -276,6 +289,7 @@ namespace ngraph
auto functor = [&,
max_pool_desc,
max_pool_index,
scratchpad_size,
arg0_buffer_index,
out0_buffer_index,
out1_buffer_index](CPURuntimeContext* ctx,
......@@ -298,7 +312,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out1_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, max_pool_index, deps, cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICES);
ctx,
max_pool_index,
deps,
cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICES,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -326,7 +344,7 @@ namespace ngraph
mkldnn_emitter
->get_max_pooling_backward_desc<ngraph::op::MaxPoolWithIndicesBackprop>(
node);
QUERY_SCRATCHPAD_2ARGS(
size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(
max_pooling_with_indices_backward, fwd_pool_desc, bwd_pool_desc);
// MaxPoolWithIndicesBackprop needs 4 primitives: diff_dst, fprop_workspace,
......@@ -338,6 +356,7 @@ namespace ngraph
bwd_pool_desc,
fwd_pool_desc,
max_pool_index,
scratchpad_size,
arg1_buffer_index,
arg2_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -364,7 +383,8 @@ namespace ngraph
ctx,
max_pool_index,
deps,
cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICESBACKPROP);
cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICESBACKPROP,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -53,7 +53,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
auto scale_const_op =
as_type_ptr<ngraph::op::Constant>(dequantize->get_argument(1));
......@@ -73,6 +74,7 @@ namespace ngraph
result_desc,
scales_size,
dequantize_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -101,7 +103,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, dequantize_index, deps, cpu::mkldnn_utils::OpType::DEQUANTIZE);
ctx,
dequantize_index,
deps,
cpu::mkldnn_utils::OpType::DEQUANTIZE,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -118,6 +124,7 @@ namespace ngraph
result_desc,
scales,
dequantize_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -138,7 +145,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, dequantize_index, deps, cpu::mkldnn_utils::OpType::DEQUANTIZE);
ctx,
dequantize_index,
deps,
cpu::mkldnn_utils::OpType::DEQUANTIZE,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -325,7 +336,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
auto scale_const_op =
as_type_ptr<ngraph::op::Constant>(quantize->get_argument(1));
......@@ -344,6 +356,7 @@ namespace ngraph
result_desc,
scales_size,
quantize_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -379,7 +392,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, quantize_index, deps, cpu::mkldnn_utils::OpType::QUANTIZE);
ctx,
quantize_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZE,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -396,6 +413,7 @@ namespace ngraph
result_desc,
scales,
quantize_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -416,7 +434,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, quantize_index, deps, cpu::mkldnn_utils::OpType::QUANTIZE);
ctx,
quantize_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZE,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -66,7 +66,8 @@ namespace ngraph
auto conv_attr =
mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolution>(node);
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init();
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
......@@ -76,6 +77,7 @@ namespace ngraph
conv_attr,
deps,
conv_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
......@@ -113,7 +115,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, conv_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTION);
ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTION,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -345,7 +351,8 @@ namespace ngraph
mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionRelu>(
node);
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init();
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
......@@ -356,6 +363,7 @@ namespace ngraph
conv_attr,
deps,
conv_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
......@@ -392,7 +400,8 @@ namespace ngraph
ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONRELU);
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONRELU,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -430,7 +439,8 @@ namespace ngraph
mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBias>(
node);
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
......@@ -441,6 +451,7 @@ namespace ngraph
conv_attr,
deps,
conv_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
......@@ -480,7 +491,8 @@ namespace ngraph
ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIAS);
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIAS,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -525,7 +537,8 @@ namespace ngraph
mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBiasAdd>(
node);
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
......@@ -537,6 +550,7 @@ namespace ngraph
conv_attr,
deps,
conv_index,
scratchpad_size,
arg3_size,
arg0_buffer_index,
arg1_buffer_index,
......@@ -609,7 +623,8 @@ namespace ngraph
ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASADD);
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASADD,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -650,7 +665,8 @@ namespace ngraph
ngraph::op::QuantizedConvolutionBiasSignedAdd>(node);
auto conv_attr = mkldnn_emitter->get_convolution_forward_attr<
ngraph::op::QuantizedConvolutionBiasSignedAdd>(node);
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(conv_index);
......@@ -662,6 +678,7 @@ namespace ngraph
conv_attr,
deps,
conv_index,
scratchpad_size,
arg3_size,
arg0_buffer_index,
arg1_buffer_index,
......@@ -734,7 +751,8 @@ namespace ngraph
ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASSIGNEDADD);
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASSIGNEDADD,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -63,7 +63,7 @@ namespace ngraph
auto ip_attr =
mkldnn_emitter
->get_inner_product_forward_attr<ngraph::op::QuantizedDotBias>(node);
QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
......@@ -74,6 +74,7 @@ namespace ngraph
ip_attr,
deps,
ip_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
......@@ -108,7 +109,11 @@ namespace ngraph
ctx, deps[3], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, ip_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDDOTBIAS);
ctx,
ip_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDDOTBIAS,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -56,7 +56,7 @@ namespace ngraph
auto ip_attr =
mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedMatmul>(
node);
QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
......@@ -66,6 +66,7 @@ namespace ngraph
ip_attr,
deps,
ip_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
......@@ -95,7 +96,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, ip_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDMATMUL);
ctx,
ip_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDMATMUL,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -41,14 +41,19 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto relu_desc = mkldnn_emitter->get_relu_forward_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, relu_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, relu_desc);
// Relu needs 3 primitives: input, result, and eltwise_forward.
size_t relu_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, relu_desc, relu_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
auto functor = [&,
relu_desc,
relu_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_relu_forward(ctx->mkldnn_memories,
......@@ -63,8 +68,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, relu_index, deps, cpu::mkldnn_utils::OpType::RELU);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
relu_index,
deps,
cpu::mkldnn_utils::OpType::RELU,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -89,7 +97,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto bwd_desc = mkldnn_emitter->get_relu_backward_desc(node);
auto fwd_desc = mkldnn_emitter->get_relu_forward_desc(node);
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
// ReluBackprop needs 4 primitives: input, delta, result, and eltwise_backward.
size_t relu_index = mkldnn_emitter->reserve_primitive_space(4);
......@@ -99,6 +108,7 @@ namespace ngraph
bwd_desc,
fwd_desc,
relu_index,
scratchpad_size,
arg_fwd_buffer_index,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -121,7 +131,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, relu_index, deps, cpu::mkldnn_utils::OpType::RELUBACKPROP);
ctx,
relu_index,
deps,
cpu::mkldnn_utils::OpType::RELUBACKPROP,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -109,7 +109,7 @@ namespace ngraph
};
functors.emplace_back(functor);
#else
mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc);
size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc);
auto src_iter_c_buffer_index =
external_function->get_buffer_index(args[2].get_name());
......@@ -132,6 +132,7 @@ namespace ngraph
auto functor = [&,
rnn_desc,
rnn_index,
scratchpad_size,
src_layer_buffer_index,
src_iter_buffer_index,
src_iter_c_buffer_index,
......@@ -174,7 +175,7 @@ namespace ngraph
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN);
ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN, scratchpad_size);
};
functors.emplace_back(functor);
#endif
......
......@@ -43,32 +43,39 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto sigmoid_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, false);
QUERY_SCRATCHPAD(eltwise_forward, sigmoid_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, sigmoid_desc);
// Sigmoid needs 3 primitives: input, result, and eltwise_forward.
auto sigmoid_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index);
auto functor =
[&, sigmoid_desc, sigmoid_index, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_sigmoid_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
sigmoid_desc,
deps,
sigmoid_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, sigmoid_index, deps, cpu::mkldnn_utils::OpType::SIGMOID);
};
auto functor = [&,
sigmoid_desc,
sigmoid_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_sigmoid_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
sigmoid_desc,
deps,
sigmoid_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
sigmoid_index,
deps,
cpu::mkldnn_utils::OpType::SIGMOID,
scratchpad_size);
};
functors.emplace_back(functor);
}
......@@ -88,7 +95,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto fwd_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, true);
auto bwd_desc = mkldnn_emitter->get_sigmoid_backward_desc(node);
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
// SigmoidBackprop needs 4 primitives: input, delta, result, and eltwise_backward.
size_t sigmoid_index = mkldnn_emitter->reserve_primitive_space(4);
......@@ -98,6 +106,7 @@ namespace ngraph
bwd_desc,
fwd_desc,
sigmoid_index,
scratchpad_size,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -120,7 +129,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, sigmoid_index, deps, cpu::mkldnn_utils::OpType::SIGMOIDBACKPROP);
ctx,
sigmoid_index,
deps,
cpu::mkldnn_utils::OpType::SIGMOIDBACKPROP,
scratchpad_size);
};
functors.emplace_back(functor);
}
......
......@@ -93,7 +93,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
QUERY_SCRATCHPAD_4ARGS(slice, input_desc, result_desc, lower_bounds, out_shape);
size_t scratchpad_size = QUERY_SCRATCHPAD_4ARGS(
slice, input_desc, result_desc, lower_bounds, out_shape);
// Slice needs 3 primitives: input, result, and reorder.
auto slice_index = mkldnn_emitter->reserve_primitive_space(3);
......@@ -105,6 +106,7 @@ namespace ngraph
lower_bounds,
out_shape,
slice_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
......@@ -125,8 +127,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, slice_index, deps, cpu::mkldnn_utils::OpType::SLICE);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
slice_index,
deps,
cpu::mkldnn_utils::OpType::SLICE,
scratchpad_size);
};
functors.emplace_back(functor);
......
......@@ -48,32 +48,40 @@ namespace ngraph
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto softmax_desc = mkldnn_emitter->get_softmax_forward_desc(node);
QUERY_SCRATCHPAD(softmax_forward, softmax_desc);
size_t scratchpad_size = QUERY_SCRATCHPAD(softmax_forward, softmax_desc);
// Softmax needs 3 primitives: input, result, and softmax_forward.
size_t softmax_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(softmax_index);
auto functor =
[&, softmax_desc, softmax_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_softmax_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
softmax_desc,
deps,
softmax_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, softmax_index, deps, cpu::mkldnn_utils::OpType::SOFTMAX);
};
auto functor = [&,
softmax_desc,
softmax_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_softmax_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
softmax_desc,
deps,
softmax_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
softmax_index,
deps,
cpu::mkldnn_utils::OpType::SOFTMAX,
scratchpad_size);
};
functors.emplace_back(functor);
return;
}
......
......@@ -217,7 +217,14 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context(Allocator* allocator)
std::vector<mkldnn::memory*>(mkldnn_emitter->get_mkldnn_memories().size());
ctx->mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>(
mkldnn_emitter->get_mkldnn_scratchpad_mds().size());
ctx->scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment);
if (scratchpad_size > 0)
{
ctx->scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment, allocator);
}
else
{
ctx->scratchpad_buffer = nullptr;
}
}
else
{
......
This diff is collapsed.
......@@ -492,7 +492,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
// Build mkldnn primitives for codegen.
pass_manager.register_pass<runtime::cpu::pass::MKLDNNPrimitiveBuildPass>(
m_desc_filename, *m_mkldnn_emitter, m_node_primitive_string_deps_index_map);
m_desc_filename, *m_mkldnn_emitter, m_node_primitive_string_deps_index_size_map);
unordered_map<Node*, Node*> node_function_map;
string common_function_string;
......@@ -694,9 +694,16 @@ using namespace ngraph::runtime;
writer << "mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>("
<< to_string(m_mkldnn_emitter->get_mkldnn_scratchpad_mds().size()) << ");\n";
writer << "size_t scratchpad_size = " << m_mkldnn_emitter->get_max_scratchpad_size() << ";\n";
writer << "if (scratchpad_size > 0)\n";
writer.block_begin();
writer << "size_t alignment = 4096;\n";
writer << "scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment);\n";
writer.block_end();
writer << "else\n";
writer.block_begin();
writer << "scratchpad_buffer = nullptr;\n";
writer.block_end();
writer.block_end();
writer << "\n";
set<string> output_names;
......
......@@ -137,13 +137,14 @@ namespace ngraph
return m_mkldnn_emitter;
}
// Return the tuple including the string to create mkldnn primitive, the deps and
// the index in CODEGEN
const std::tuple<std::string, std::vector<size_t>, size_t>&
// Return the tuple including the string to create mkldnn primitive, the deps, the
// index and
// the scratchpad size in CODEGEN
const std::tuple<std::string, std::vector<size_t>, size_t, size_t>&
get_primitive_build_tuple(const Node* node) const
{
auto it = m_node_primitive_string_deps_index_map.find(node);
NGRAPH_CHECK(it != m_node_primitive_string_deps_index_map.end(),
auto it = m_node_primitive_string_deps_index_size_map.find(node);
NGRAPH_CHECK(it != m_node_primitive_string_deps_index_size_map.end(),
"Primitive build tuple not found for node ",
node->description());
......@@ -351,9 +352,9 @@ namespace ngraph
#endif
/// Map each node with mkldnn implementation to its mkldnn primitive creating
/// string, deps, and mkldnn primitive index.
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>>
m_node_primitive_string_deps_index_map;
/// string, deps, mkldnn primitive index, and mkldnn scratchpad size.
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t, size_t>>
m_node_primitive_string_deps_index_size_map;
/// Name of the file to store descriptors for mkldnn_primitives
const std::string m_desc_filename = "desc_file";
};
......
This diff is collapsed.
......@@ -1214,15 +1214,6 @@ namespace ngraph
attr);
}
mkldnn::memory::format_tag query_convolution_forward_weight_format_tag(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& filter_strides,
const ngraph::Strides& window_dilation_strides_adjusted,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
template <typename OP>
mkldnn::lstm_forward::desc
get_rnn_forward_desc(const ngraph::Node* node,
......@@ -1407,54 +1398,50 @@ namespace ngraph
mkldnn_primitives[ip_idx] = prim;
}
void query_scratchpad_sum(const mkldnn::sum::primitive_desc);
void query_scratchpad_concat(const mkldnn::concat::primitive_desc);
void query_scratchpad_pooling_forward(const mkldnn::pooling_forward::desc& desc);
void query_scratchpad_avg_pooling_backward(
size_t query_scratchpad_sum(const mkldnn::sum::primitive_desc);
size_t query_scratchpad_concat(const mkldnn::concat::primitive_desc);
size_t query_scratchpad_pooling_forward(const mkldnn::pooling_forward::desc& desc);
size_t query_scratchpad_avg_pooling_backward(
const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_max_pooling_backward(
size_t query_scratchpad_max_pooling_backward(
const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_max_pooling_with_indices_backward(
size_t query_scratchpad_max_pooling_with_indices_backward(
const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_batchnorm_forward(
size_t query_scratchpad_batchnorm_forward(
const mkldnn::batch_normalization_forward::desc& desc,
const mkldnn::post_ops& pops);
void query_scratchpad_batchnorm_backward(
size_t query_scratchpad_batchnorm_backward(
const mkldnn::batch_normalization_backward::desc& desc,
const mkldnn::memory::desc& input_desc,
float epsilon);
void query_scratchpad_convolution_forward(
size_t query_scratchpad_convolution_forward(
const mkldnn::convolution_forward::desc& desc, mkldnn::primitive_attr& attr);
void query_scratchpad_convolution_backward_data(
size_t query_scratchpad_convolution_backward_data(
const mkldnn::convolution_forward::desc& fwd_desc,
const mkldnn::convolution_backward_data::desc& bwd_desc);
void query_scratchpad_convolution_backward_weights(
size_t query_scratchpad_convolution_backward_weights(
const mkldnn::convolution_forward::desc& fwd_desc,
const mkldnn::convolution_backward_weights::desc& bwd_desc);
void query_scratchpad_deconvolution_forward(
size_t query_scratchpad_deconvolution_forward(
const mkldnn::deconvolution_forward::desc& desc);
void query_scratchpad_eltwise_forward(const mkldnn::eltwise_forward::desc& desc);
void query_scratchpad_eltwise_backward(
size_t query_scratchpad_eltwise_forward(const mkldnn::eltwise_forward::desc& desc);
size_t query_scratchpad_eltwise_backward(
const mkldnn::eltwise_forward::desc& fwd_desc,
const mkldnn::eltwise_backward::desc& bwd_desc);
void query_scratchpad_quantize(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc);
void query_scratchpad_dequantize(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc);
void query_scratchpad_ip_forward(const mkldnn::inner_product_forward::desc& desc,
mkldnn::primitive_attr& attr);
void query_scratchpad_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
void query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc);
void query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc);
void query_scratchpad_slice(mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc,
const ngraph::Coordinate& lower_bounds,
const ngraph::Shape& result_shape);
void query_scratchpad_softmax_forward(const mkldnn::softmax_forward::desc& desc);
size_t query_scratchpad_ip_forward(const mkldnn::inner_product_forward::desc& desc,
mkldnn::primitive_attr& attr);
size_t query_scratchpad_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc);
size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc);
size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc,
const ngraph::Coordinate& lower_bounds,
const ngraph::Shape& result_shape);
size_t query_scratchpad_softmax_forward(const mkldnn::softmax_forward::desc& desc);
#else
// TODO(jmenon): Get rid of TensorViewWrappers at some point
......@@ -1701,15 +1688,6 @@ namespace ngraph
mkldnn_primitives[ip_idx] = prim;
}
mkldnn::memory::format query_convolution_forward_weight_format(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& filter_strides,
const ngraph::Strides& window_dilation_strides_adjusted,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
......
......@@ -37,8 +37,10 @@ extern "C" void
ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index,
std::vector<size_t>& /* deps */,
OpType /* type */)
OpType /* type */,
size_t scratchpad_size)
{
(void)scratchpad_size;
mkldnn::stream s(mkldnn::stream::kind::eager);
try
{
......@@ -58,8 +60,12 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeCon
memory->set_data_handle(ptr);
}
extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(
CPURuntimeContext* ctx, size_t primitive_index, std::vector<size_t>& deps, OpType type)
extern "C" void
ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index,
std::vector<size_t>& deps,
OpType type,
size_t scratchpad_size)
{
std::unordered_map<int, mkldnn::memory> exec_args;
size_t nargs;
......@@ -201,10 +207,13 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(
break;
}
mkldnn::memory scratchpad(*ctx->mkldnn_scratchpad_mds[primitive_index],
executor::global_cpu_engine,
ctx->scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
if (scratchpad_size)
{
mkldnn::memory scratchpad(*ctx->mkldnn_scratchpad_mds[primitive_index],
executor::global_cpu_engine,
ctx->scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
}
mkldnn::stream s(executor::global_cpu_engine);
try
......
......@@ -82,7 +82,8 @@ namespace ngraph
extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index,
std::vector<size_t>& deps,
OpType type);
OpType type,
size_t scratchpad_size = 0);
}
}
}
......
......@@ -48,10 +48,10 @@
#define SET_ROUND_MODE attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#define QUERY_SCRATCHPAD(op_name, x)
#define QUERY_SCRATCHPAD_2ARGS(op_name, x, y)
#define QUERY_SCRATCHPAD_3ARGS(op_name, x, y, z)
#define QUERY_SCRATCHPAD_4ARGS(op_name, x, y, z, u)
#define QUERY_SCRATCHPAD(op_name, x) 0
#define QUERY_SCRATCHPAD_2ARGS(op_name, x, y) 0
#define QUERY_SCRATCHPAD_3ARGS(op_name, x, y, z) 0
#define QUERY_SCRATCHPAD_4ARGS(op_name, x, y, z, u) 0
#define MKLDNN_ERROR_MESSAGE e.message
......@@ -85,7 +85,8 @@
#define GET_SIZE \
mkldnn::memory::desc scratchpad_md = pd.scratchpad_desc(); \
size_t size = scratchpad_md.get_size(); \
m_max_scratchpad_size = size > m_max_scratchpad_size ? size : m_max_scratchpad_size;
m_max_scratchpad_size = size > m_max_scratchpad_size ? size : m_max_scratchpad_size; \
return size;
#define MKLDNN_ERROR_MESSAGE std::string(e.message)
......
......@@ -529,7 +529,7 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::ScatterAdd)
{
(void)external_function;
auto update_slice = static_cast<ngraph::op::ScatterAdd*>(node);
auto scatter_add = static_cast<ngraph::op::ScatterAdd*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -538,7 +538,7 @@ namespace ngraph
// Safe to overwrite input
op_annotations->add_in_place_oi_pair({0, 0, true});
}
update_slice->set_op_annotations(op_annotations);
scatter_add->set_op_annotations(op_annotations);
}
template <>
......
......@@ -30,6 +30,7 @@
std::string & construct_string, \
std::vector<size_t> & deps, \
size_t & index, \
size_t & scratchpad_size, \
std::ofstream & desc_file)
namespace mkldnn
......@@ -55,6 +56,7 @@ namespace ngraph
std::string&,
std::vector<size_t>&,
size_t&,
size_t&,
std::ofstream&)>;
using PrimitiveBuildStringConstructOpMap =
std::unordered_map<std::type_index, PrimitiveBuildStringConstructFunction>;
......@@ -69,20 +71,23 @@ namespace ngraph
ngraph::runtime::cpu::MKLDNNEmitter& m_mkldnn_emitter;
/// External map to store each node with mkldnn implementation and its mkldnn
/// creation string, deps, and mkldnn primitive index.
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>>&
m_node_primitive_string_deps_index_map;
/// creation string, deps, mkldnn primitive index, and mkldnn primitive
/// scratchpad size.
std::map<const Node*,
std::tuple<std::string, std::vector<size_t>, size_t, size_t>>&
m_node_primitive_string_deps_index_size_map;
public:
MKLDNNPrimitiveBuildPass(
std::string filename,
ngraph::runtime::cpu::MKLDNNEmitter& mkldnn_emitter,
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>>&
node_primitive_string_deps_index_map)
std::map<const Node*,
std::tuple<std::string, std::vector<size_t>, size_t, size_t>>&
node_primitive_string_deps_index_size_map)
: m_desc_filename(filename)
, m_mkldnn_emitter(mkldnn_emitter)
, m_node_primitive_string_deps_index_map(
node_primitive_string_deps_index_map)
, m_node_primitive_string_deps_index_size_map(
node_primitive_string_deps_index_size_map)
{
}
......@@ -95,6 +100,7 @@ namespace ngraph
std::string& /* construct_string */,
std::vector<size_t>& /* deps */,
size_t& /* index */,
size_t& /* scratchpad size */,
std::ofstream& /* desc_file */)
{
throw std::runtime_error("Unimplemented op '" + node->description() +
......
......@@ -102,7 +102,7 @@ struct CPURuntimeContextCG
}
void mkldnn_invoke_primitive(size_t primitive_index, std::vector<size_t>& deps,
OpType type)
OpType type, size_t scratchpad_size)
{
std::unordered_map<int, mkldnn::memory> exec_args;
size_t nargs;
......@@ -257,10 +257,13 @@ struct CPURuntimeContextCG
break;
}
mkldnn::memory scratchpad(*mkldnn_scratchpad_mds[primitive_index],
global_cpu_engine,
scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
if (scratchpad_size)
{
mkldnn::memory scratchpad(*mkldnn_scratchpad_mds[primitive_index],
global_cpu_engine,
scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
}
mkldnn::stream s(global_cpu_engine);
try
......@@ -270,7 +273,7 @@ struct CPURuntimeContextCG
}
catch (const mkldnn::error& e)
{
throw std::runtime_error("Could not run mkdnn primitive " + *e.message);
throw std::runtime_error("Could not run mkdnn primitive " + std::string(e.message));
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment