Commit c1f3beea authored by Amy Zhuang's avatar Amy Zhuang Committed by Sang Ik Lee

Migrate PR#3638 from master. (#3735)

* Check size requirement before creating scratchpad.

* Check max scratchpad size before allocating scratchpad_buffer.

* Add the same checks for CODEGEN.

* Fix unused-parameter warning.

* Fix a typo.

* Address PR feedback.

* Fix a bug.

* Fix a typo.
parent be738d04
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto sum_pd = mkldnn_emitter->get_elementwise_add_desc(node); auto sum_pd = mkldnn_emitter->get_elementwise_add_desc(node);
QUERY_SCRATCHPAD(sum, sum_pd); size_t scratchpad_size = QUERY_SCRATCHPAD(sum, sum_pd);
// Add needs 4 primitives: input0, input1, result, and sum. // Add needs 4 primitives: input0, input1, result, and sum.
size_t add_index = mkldnn_emitter->reserve_primitive_space(4); size_t add_index = mkldnn_emitter->reserve_primitive_space(4);
...@@ -55,6 +55,7 @@ namespace ngraph ...@@ -55,6 +55,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
sum_pd, sum_pd,
add_index, add_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -76,7 +77,7 @@ namespace ngraph ...@@ -76,7 +77,7 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]); ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD); ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD, scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -55,15 +55,19 @@ namespace ngraph ...@@ -55,15 +55,19 @@ namespace ngraph
auto avg_pool_desc = auto avg_pool_desc =
mkldnn_emitter->get_avg_pooling_forward_desc<ngraph::op::AvgPool>(node, mkldnn_emitter->get_avg_pooling_forward_desc<ngraph::op::AvgPool>(node,
false); false);
QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc);
// AvgPool needs 3 primitives: input, result, and pooling_forward. // AvgPool needs 3 primitives: input, result, and pooling_forward.
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
auto functor = auto functor = [&,
[&, avg_pool_desc, avg_pool_index, arg0_buffer_index, out_buffer_index]( avg_pool_desc,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { avg_pool_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories, mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
...@@ -79,7 +83,11 @@ namespace ngraph ...@@ -79,7 +83,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOL); ctx,
avg_pool_index,
deps,
cpu::mkldnn_utils::OpType::AVGPOOL,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -145,7 +153,8 @@ namespace ngraph ...@@ -145,7 +153,8 @@ namespace ngraph
auto avg_pool_desc = auto avg_pool_desc =
mkldnn_emitter->get_avg_pooling_backward_desc<ngraph::op::AvgPoolBackprop>( mkldnn_emitter->get_avg_pooling_backward_desc<ngraph::op::AvgPoolBackprop>(
node); node);
QUERY_SCRATCHPAD_2ARGS(avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc); size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(
avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
// AvgPoolBackprop needs 3 primitives: input, result, and pooling_backward. // AvgPoolBackprop needs 3 primitives: input, result, and pooling_backward.
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
...@@ -155,6 +164,7 @@ namespace ngraph ...@@ -155,6 +164,7 @@ namespace ngraph
avg_pool_desc, avg_pool_desc,
avg_pool_fwd_desc, avg_pool_fwd_desc,
avg_pool_index, avg_pool_index,
scratchpad_size,
delta_buffer_index, delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -174,7 +184,11 @@ namespace ngraph ...@@ -174,7 +184,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP); ctx,
avg_pool_index,
deps,
cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -84,6 +84,7 @@ namespace ngraph ...@@ -84,6 +84,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto batchnorm_desc = auto batchnorm_desc =
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, true); mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, true);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops); QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
...@@ -101,6 +102,7 @@ namespace ngraph ...@@ -101,6 +102,7 @@ namespace ngraph
training, training,
ops, ops,
batchnorm_index, batchnorm_index,
scratchpad_size,
stacked_weights, stacked_weights,
weight_sizes, weight_sizes,
arg0_buffer_index, arg0_buffer_index,
...@@ -140,7 +142,11 @@ namespace ngraph ...@@ -140,7 +142,11 @@ namespace ngraph
ctx, deps[4], ctx->buffer_data[out2_buffer_index]); ctx, deps[4], ctx->buffer_data[out2_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM3ARGS); ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORM3ARGS,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -155,6 +161,7 @@ namespace ngraph ...@@ -155,6 +161,7 @@ namespace ngraph
auto batchnorm_desc = auto batchnorm_desc =
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, false); mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, false);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops); QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
...@@ -172,6 +179,7 @@ namespace ngraph ...@@ -172,6 +179,7 @@ namespace ngraph
training, training,
ops, ops,
batchnorm_index, batchnorm_index,
scratchpad_size,
stacked_weights, stacked_weights,
weight_sizes, weight_sizes,
arg0_buffer_index, arg0_buffer_index,
...@@ -211,7 +219,11 @@ namespace ngraph ...@@ -211,7 +219,11 @@ namespace ngraph
ctx, deps[4], ctx->buffer_data[out0_buffer_index]); ctx, deps[4], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM5ARGS); ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORM5ARGS,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -444,6 +456,7 @@ namespace ngraph ...@@ -444,6 +456,7 @@ namespace ngraph
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node); static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto eps = batchnorm->get_eps_value(); auto eps = batchnorm->get_eps_value();
(void)eps; // Use depends on mkl-dnn version (void)eps; // Use depends on mkl-dnn version
size_t scratchpad_size =
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps); QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
auto functor = [&, auto functor = [&,
...@@ -452,6 +465,7 @@ namespace ngraph ...@@ -452,6 +465,7 @@ namespace ngraph
weights_desc, weights_desc,
dweights_desc, dweights_desc,
batchnorm_index, batchnorm_index,
scratchpad_size,
stacked_weights, stacked_weights,
stacked_dweights, stacked_dweights,
weight_sizes, weight_sizes,
...@@ -499,7 +513,11 @@ namespace ngraph ...@@ -499,7 +513,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get()); cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP); ctx,
batchnorm_index,
deps,
cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP,
scratchpad_size);
memcpy(ctx->buffer_data[out1_buffer_index], memcpy(ctx->buffer_data[out1_buffer_index],
stacked_dweights.get(), stacked_dweights.get(),
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto bounded_relu_desc = mkldnn_emitter->get_bounded_relu_desc(node); auto bounded_relu_desc = mkldnn_emitter->get_bounded_relu_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc);
// BoundedRelu needs 3 primitives: input, result, and eltwise_forward. // BoundedRelu needs 3 primitives: input, result, and eltwise_forward.
auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3); auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3);
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
bounded_relu_desc, bounded_relu_desc,
bounded_relu_index, bounded_relu_index,
scratchpad_size,
input_buffer_index, input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -71,7 +72,11 @@ namespace ngraph ...@@ -71,7 +72,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, bounded_relu_index, deps, cpu::mkldnn_utils::OpType::BOUNDEDRELU); ctx,
bounded_relu_index,
deps,
cpu::mkldnn_utils::OpType::BOUNDEDRELU,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -101,7 +101,7 @@ namespace ngraph ...@@ -101,7 +101,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto concat_pd = auto concat_pd =
mkldnn_emitter->get_concat_desc<ngraph::op::Concat>(node, nargs); mkldnn_emitter->get_concat_desc<ngraph::op::Concat>(node, nargs);
QUERY_SCRATCHPAD(concat, concat_pd); size_t scratchpad_size = QUERY_SCRATCHPAD(concat, concat_pd);
std::vector<mkldnn::memory::desc> inputs_data_desc; std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < nargs; i++) for (size_t i = 0; i < nargs; i++)
...@@ -115,6 +115,7 @@ namespace ngraph ...@@ -115,6 +115,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
concat_pd, concat_pd,
scratchpad_size,
inputs_data_desc, inputs_data_desc,
arg_buffer_indices, arg_buffer_indices,
nargs, nargs,
...@@ -140,7 +141,11 @@ namespace ngraph ...@@ -140,7 +141,11 @@ namespace ngraph
ctx, deps[nargs], ctx->buffer_data[out_buffer_index]); ctx, deps[nargs], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, concat_index, deps, cpu::mkldnn_utils::OpType::CONCAT); ctx,
concat_index,
deps,
cpu::mkldnn_utils::OpType::CONCAT,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -43,6 +43,8 @@ namespace ngraph ...@@ -43,6 +43,8 @@ namespace ngraph
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t scratchpad_size = 0;
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
if (input_desc.data.format == mkldnn_nchw && if (input_desc.data.format == mkldnn_nchw &&
result_desc.data.format == mkldnn_goihw) result_desc.data.format == mkldnn_goihw)
...@@ -131,14 +133,19 @@ namespace ngraph ...@@ -131,14 +133,19 @@ namespace ngraph
mkldnn::memory::format_tag::goihw); mkldnn::memory::format_tag::goihw);
} }
mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc); scratchpad_size = mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc);
#endif #endif
// ConvertLayout needs 3 primitives: input, result, and reorder. // ConvertLayout needs 3 primitives: input, result, and reorder.
size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3); size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index); auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
auto functor = auto functor = [&,
[&, input_desc, result_desc, reorder_index, arg_buffer_index, out_buffer_index]( input_desc,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { result_desc,
reorder_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_reorder(ctx->mkldnn_memories, mkldnn_emitter->build_reorder(ctx->mkldnn_memories,
...@@ -155,7 +162,11 @@ namespace ngraph ...@@ -155,7 +162,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, reorder_index, deps, cpu::mkldnn_utils::OpType::CONVERTLAYOUT); ctx,
reorder_index,
deps,
cpu::mkldnn_utils::OpType::CONVERTLAYOUT,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto leaky_relu_desc = mkldnn_emitter->get_leaky_relu_desc(node); auto leaky_relu_desc = mkldnn_emitter->get_leaky_relu_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, leaky_relu_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, leaky_relu_desc);
// CPULeakyRelu needs 3 primitives: input, result, and eltwise_forward. // CPULeakyRelu needs 3 primitives: input, result, and eltwise_forward.
auto leaky_relu_index = mkldnn_emitter->reserve_primitive_space(3); auto leaky_relu_index = mkldnn_emitter->reserve_primitive_space(3);
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
leaky_relu_desc, leaky_relu_desc,
leaky_relu_index, leaky_relu_index,
scratchpad_size,
input_buffer_index, input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -71,7 +72,11 @@ namespace ngraph ...@@ -71,7 +72,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, leaky_relu_index, deps, cpu::mkldnn_utils::OpType::LEAKYRELU); ctx,
leaky_relu_index,
deps,
cpu::mkldnn_utils::OpType::LEAKYRELU,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -44,14 +44,19 @@ namespace ngraph ...@@ -44,14 +44,19 @@ namespace ngraph
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node); auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node);
QUERY_SCRATCHPAD(lrn_forward, lrn_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(lrn_forward, lrn_desc);
// LRN needs 3 primitives: input, result, and lrn_forward. // LRN needs 3 primitives: input, result, and lrn_forward.
auto lrn_index = mkldnn_emitter->reserve_primitive_space(3); auto lrn_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index); auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index);
functor = [&, lrn_desc, lrn_index, arg_buffer_index, out_buffer_index]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { lrn_desc,
lrn_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_lrn_forward(ctx->mkldnn_memories, mkldnn_emitter->build_lrn_forward(ctx->mkldnn_memories,
...@@ -67,7 +72,7 @@ namespace ngraph ...@@ -67,7 +72,7 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, lrn_index, deps, cpu::mkldnn_utils::OpType::LRN); ctx, lrn_index, deps, cpu::mkldnn_utils::OpType::LRN, scratchpad_size);
}; };
} }
else else
......
...@@ -123,7 +123,7 @@ namespace ngraph ...@@ -123,7 +123,7 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
#else #else
mkldnn_emitter->query_scratchpad_rnn_forward(lstm_desc); size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(lstm_desc);
auto src_iter_c_buffer_index = auto src_iter_c_buffer_index =
external_function->get_buffer_index(args[2].get_name()); external_function->get_buffer_index(args[2].get_name());
...@@ -146,6 +146,7 @@ namespace ngraph ...@@ -146,6 +146,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
lstm_desc, lstm_desc,
lstm_index, lstm_index,
scratchpad_size,
src_layer_buffer_index, src_layer_buffer_index,
src_iter_buffer_index, src_iter_buffer_index,
src_iter_c_buffer_index, src_iter_c_buffer_index,
...@@ -188,7 +189,7 @@ namespace ngraph ...@@ -188,7 +189,7 @@ namespace ngraph
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]); ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, lstm_index, deps, cpu::mkldnn_utils::OpType::LSTM); ctx, lstm_index, deps, cpu::mkldnn_utils::OpType::LSTM, scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
#endif #endif
......
...@@ -54,15 +54,19 @@ namespace ngraph ...@@ -54,15 +54,19 @@ namespace ngraph
auto max_pool_desc = auto max_pool_desc =
mkldnn_emitter->get_max_pooling_forward_desc<ngraph::op::MaxPool>(node, mkldnn_emitter->get_max_pooling_forward_desc<ngraph::op::MaxPool>(node,
false); false);
QUERY_SCRATCHPAD(pooling_forward, max_pool_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
// MaxPool needs 3 primitives: input, result, and pooling_forward. // MaxPool needs 3 primitives: input, result, and pooling_forward.
size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
auto functor = auto functor = [&,
[&, max_pool_desc, max_pool_index, arg0_buffer_index, out_buffer_index]( max_pool_desc,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { max_pool_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories, mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
...@@ -78,7 +82,11 @@ namespace ngraph ...@@ -78,7 +82,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, max_pool_index, deps, cpu::mkldnn_utils::OpType::MAXPOOL); ctx,
max_pool_index,
deps,
cpu::mkldnn_utils::OpType::MAXPOOL,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -142,6 +150,7 @@ namespace ngraph ...@@ -142,6 +150,7 @@ namespace ngraph
mkldnn_emitter->get_max_pooling_backward_desc<ngraph::op::MaxPoolBackprop>( mkldnn_emitter->get_max_pooling_backward_desc<ngraph::op::MaxPoolBackprop>(
node); node);
auto fprop_src_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto fprop_src_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(max_pooling_backward, fwd_pool_desc, bwd_pool_desc); QUERY_SCRATCHPAD_2ARGS(max_pooling_backward, fwd_pool_desc, bwd_pool_desc);
// MaxPoolBackprop forward needs 4 primitives: fprop_src, diff_src, workspace, // MaxPoolBackprop forward needs 4 primitives: fprop_src, diff_src, workspace,
...@@ -151,9 +160,12 @@ namespace ngraph ...@@ -151,9 +160,12 @@ namespace ngraph
mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */); mkldnn_emitter->reserve_primitive_space(4, true /* new workspace */);
auto& fdeps = mkldnn_emitter->get_primitive_deps(fwd_pool_index); auto& fdeps = mkldnn_emitter->get_primitive_deps(fwd_pool_index);
auto functor_fprop = auto functor_fprop = [&,
[&, fwd_pool_index, arg_fwd_buffer_index, out_buffer_index]( fwd_pool_index,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { arg_fwd_buffer_index,
scratchpad_size,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
ctx, fdeps[0], ctx->buffer_data[arg_fwd_buffer_index]); ctx, fdeps[0], ctx->buffer_data[arg_fwd_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
...@@ -164,7 +176,8 @@ namespace ngraph ...@@ -164,7 +176,8 @@ namespace ngraph
ctx, ctx,
fwd_pool_index, fwd_pool_index,
fdeps, fdeps,
cpu::mkldnn_utils::OpType::MAXPOOLBACKPROPFORWARD); cpu::mkldnn_utils::OpType::MAXPOOLBACKPROPFORWARD,
scratchpad_size);
}; };
// MaxPoolBackprop backward needs 4 primitives: diff_dst, workspace, diff_src, // MaxPoolBackprop backward needs 4 primitives: diff_dst, workspace, diff_src,
...@@ -268,7 +281,7 @@ namespace ngraph ...@@ -268,7 +281,7 @@ namespace ngraph
mkldnn_emitter mkldnn_emitter
->get_max_pooling_with_indices_forward_desc<ngraph::op::MaxPoolWithIndices>( ->get_max_pooling_with_indices_forward_desc<ngraph::op::MaxPoolWithIndices>(
node); node);
QUERY_SCRATCHPAD(pooling_forward, max_pool_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, max_pool_desc);
// MaxPoolWithIndices needs 4 primitives: src, dst, workspace, and pooling_forward. // MaxPoolWithIndices needs 4 primitives: src, dst, workspace, and pooling_forward.
size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(4); size_t max_pool_index = mkldnn_emitter->reserve_primitive_space(4);
...@@ -277,6 +290,7 @@ namespace ngraph ...@@ -277,6 +290,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
max_pool_desc, max_pool_desc,
max_pool_index, max_pool_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
out0_buffer_index, out0_buffer_index,
out1_buffer_index](CPURuntimeContext* ctx, out1_buffer_index](CPURuntimeContext* ctx,
...@@ -299,7 +313,11 @@ namespace ngraph ...@@ -299,7 +313,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out1_buffer_index]); ctx, deps[2], ctx->buffer_data[out1_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, max_pool_index, deps, cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICES); ctx,
max_pool_index,
deps,
cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICES,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -327,7 +345,7 @@ namespace ngraph ...@@ -327,7 +345,7 @@ namespace ngraph
mkldnn_emitter mkldnn_emitter
->get_max_pooling_backward_desc<ngraph::op::MaxPoolWithIndicesBackprop>( ->get_max_pooling_backward_desc<ngraph::op::MaxPoolWithIndicesBackprop>(
node); node);
QUERY_SCRATCHPAD_2ARGS( size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(
max_pooling_with_indices_backward, fwd_pool_desc, bwd_pool_desc); max_pooling_with_indices_backward, fwd_pool_desc, bwd_pool_desc);
// MaxPoolWithIndicesBackprop needs 4 primitives: diff_dst, fprop_workspace, // MaxPoolWithIndicesBackprop needs 4 primitives: diff_dst, fprop_workspace,
...@@ -339,6 +357,7 @@ namespace ngraph ...@@ -339,6 +357,7 @@ namespace ngraph
bwd_pool_desc, bwd_pool_desc,
fwd_pool_desc, fwd_pool_desc,
max_pool_index, max_pool_index,
scratchpad_size,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -365,7 +384,8 @@ namespace ngraph ...@@ -365,7 +384,8 @@ namespace ngraph
ctx, ctx,
max_pool_index, max_pool_index,
deps, deps,
cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICESBACKPROP); cpu::mkldnn_utils::OpType::MAXPOOLWITHINDICESBACKPROP,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc); QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
auto scale_const_op = std::dynamic_pointer_cast<ngraph::op::Constant>( auto scale_const_op = std::dynamic_pointer_cast<ngraph::op::Constant>(
...@@ -73,6 +74,7 @@ namespace ngraph ...@@ -73,6 +74,7 @@ namespace ngraph
result_desc, result_desc,
scales_size, scales_size,
dequantize_index, dequantize_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -101,7 +103,11 @@ namespace ngraph ...@@ -101,7 +103,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, dequantize_index, deps, cpu::mkldnn_utils::OpType::DEQUANTIZE); ctx,
dequantize_index,
deps,
cpu::mkldnn_utils::OpType::DEQUANTIZE,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -118,6 +124,7 @@ namespace ngraph ...@@ -118,6 +124,7 @@ namespace ngraph
result_desc, result_desc,
scales, scales,
dequantize_index, dequantize_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -138,7 +145,11 @@ namespace ngraph ...@@ -138,7 +145,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, dequantize_index, deps, cpu::mkldnn_utils::OpType::DEQUANTIZE); ctx,
dequantize_index,
deps,
cpu::mkldnn_utils::OpType::DEQUANTIZE,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -325,6 +336,7 @@ namespace ngraph ...@@ -325,6 +336,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc); QUERY_SCRATCHPAD_2ARGS(reorder, input_desc, result_desc);
auto scale_const_op = auto scale_const_op =
...@@ -344,6 +356,7 @@ namespace ngraph ...@@ -344,6 +356,7 @@ namespace ngraph
result_desc, result_desc,
scales_size, scales_size,
quantize_index, quantize_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -379,7 +392,11 @@ namespace ngraph ...@@ -379,7 +392,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, quantize_index, deps, cpu::mkldnn_utils::OpType::QUANTIZE); ctx,
quantize_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZE,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -396,6 +413,7 @@ namespace ngraph ...@@ -396,6 +413,7 @@ namespace ngraph
result_desc, result_desc,
scales, scales,
quantize_index, quantize_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -416,7 +434,11 @@ namespace ngraph ...@@ -416,7 +434,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, quantize_index, deps, cpu::mkldnn_utils::OpType::QUANTIZE); ctx,
quantize_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZE,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -66,6 +66,7 @@ namespace ngraph ...@@ -66,6 +66,7 @@ namespace ngraph
auto conv_attr = auto conv_attr =
mkldnn_emitter mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolution>(node); ->get_convolution_forward_attr<ngraph::op::QuantizedConvolution>(node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr); QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(); size_t conv_index = mkldnn_emitter->convolution_forward_init();
...@@ -76,6 +77,7 @@ namespace ngraph ...@@ -76,6 +77,7 @@ namespace ngraph
conv_attr, conv_attr,
deps, deps,
conv_index, conv_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
...@@ -113,7 +115,11 @@ namespace ngraph ...@@ -113,7 +115,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out0_buffer_index]); ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, conv_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTION); ctx,
conv_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTION,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -345,6 +351,7 @@ namespace ngraph ...@@ -345,6 +351,7 @@ namespace ngraph
mkldnn_emitter mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionRelu>( ->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionRelu>(
node); node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr); QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(); size_t conv_index = mkldnn_emitter->convolution_forward_init();
...@@ -356,6 +363,7 @@ namespace ngraph ...@@ -356,6 +363,7 @@ namespace ngraph
conv_attr, conv_attr,
deps, deps,
conv_index, conv_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
...@@ -392,7 +400,8 @@ namespace ngraph ...@@ -392,7 +400,8 @@ namespace ngraph
ctx, ctx,
conv_index, conv_index,
deps, deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONRELU); cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONRELU,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -430,6 +439,7 @@ namespace ngraph ...@@ -430,6 +439,7 @@ namespace ngraph
mkldnn_emitter mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBias>( ->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBias>(
node); node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr); QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true); size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
...@@ -441,6 +451,7 @@ namespace ngraph ...@@ -441,6 +451,7 @@ namespace ngraph
conv_attr, conv_attr,
deps, deps,
conv_index, conv_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
...@@ -480,7 +491,8 @@ namespace ngraph ...@@ -480,7 +491,8 @@ namespace ngraph
ctx, ctx,
conv_index, conv_index,
deps, deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIAS); cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIAS,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -525,6 +537,7 @@ namespace ngraph ...@@ -525,6 +537,7 @@ namespace ngraph
mkldnn_emitter mkldnn_emitter
->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBiasAdd>( ->get_convolution_forward_attr<ngraph::op::QuantizedConvolutionBiasAdd>(
node); node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr); QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true); size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
...@@ -537,6 +550,7 @@ namespace ngraph ...@@ -537,6 +550,7 @@ namespace ngraph
conv_attr, conv_attr,
deps, deps,
conv_index, conv_index,
scratchpad_size,
arg3_size, arg3_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
...@@ -609,7 +623,8 @@ namespace ngraph ...@@ -609,7 +623,8 @@ namespace ngraph
ctx, ctx,
conv_index, conv_index,
deps, deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASADD); cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASADD,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -650,6 +665,7 @@ namespace ngraph ...@@ -650,6 +665,7 @@ namespace ngraph
ngraph::op::QuantizedConvolutionBiasSignedAdd>(node); ngraph::op::QuantizedConvolutionBiasSignedAdd>(node);
auto conv_attr = mkldnn_emitter->get_convolution_forward_attr< auto conv_attr = mkldnn_emitter->get_convolution_forward_attr<
ngraph::op::QuantizedConvolutionBiasSignedAdd>(node); ngraph::op::QuantizedConvolutionBiasSignedAdd>(node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr); QUERY_SCRATCHPAD_2ARGS(convolution_forward, conv_desc, conv_attr);
size_t conv_index = mkldnn_emitter->convolution_forward_init(true); size_t conv_index = mkldnn_emitter->convolution_forward_init(true);
...@@ -662,6 +678,7 @@ namespace ngraph ...@@ -662,6 +678,7 @@ namespace ngraph
conv_attr, conv_attr,
deps, deps,
conv_index, conv_index,
scratchpad_size,
arg3_size, arg3_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
...@@ -734,7 +751,8 @@ namespace ngraph ...@@ -734,7 +751,8 @@ namespace ngraph
ctx, ctx,
conv_index, conv_index,
deps, deps,
cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASSIGNEDADD); cpu::mkldnn_utils::OpType::QUANTIZEDCONVOLUTIONBIASSIGNEDADD,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -63,7 +63,7 @@ namespace ngraph ...@@ -63,7 +63,7 @@ namespace ngraph
auto ip_attr = auto ip_attr =
mkldnn_emitter mkldnn_emitter
->get_inner_product_forward_attr<ngraph::op::QuantizedDotBias>(node); ->get_inner_product_forward_attr<ngraph::op::QuantizedDotBias>(node);
QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr); size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(true); size_t ip_index = mkldnn_emitter->inner_product_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index); auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
...@@ -74,6 +74,7 @@ namespace ngraph ...@@ -74,6 +74,7 @@ namespace ngraph
ip_attr, ip_attr,
deps, deps,
ip_index, ip_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
...@@ -108,7 +109,11 @@ namespace ngraph ...@@ -108,7 +109,11 @@ namespace ngraph
ctx, deps[3], ctx->buffer_data[out0_buffer_index]); ctx, deps[3], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, ip_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDDOTBIAS); ctx,
ip_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDDOTBIAS,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -56,7 +56,7 @@ namespace ngraph ...@@ -56,7 +56,7 @@ namespace ngraph
auto ip_attr = auto ip_attr =
mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedMatmul>( mkldnn_emitter->get_inner_product_forward_attr<ngraph::op::QuantizedMatmul>(
node); node);
QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr); size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(ip_forward, ip_desc, ip_attr);
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false); size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index); auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
...@@ -66,6 +66,7 @@ namespace ngraph ...@@ -66,6 +66,7 @@ namespace ngraph
ip_attr, ip_attr,
deps, deps,
ip_index, ip_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
arg2_buffer_index, arg2_buffer_index,
...@@ -95,7 +96,11 @@ namespace ngraph ...@@ -95,7 +96,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out0_buffer_index]); ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, ip_index, deps, cpu::mkldnn_utils::OpType::QUANTIZEDMATMUL); ctx,
ip_index,
deps,
cpu::mkldnn_utils::OpType::QUANTIZEDMATMUL,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -41,14 +41,19 @@ namespace ngraph ...@@ -41,14 +41,19 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto relu_desc = mkldnn_emitter->get_relu_forward_desc(node); auto relu_desc = mkldnn_emitter->get_relu_forward_desc(node);
QUERY_SCRATCHPAD(eltwise_forward, relu_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, relu_desc);
// Relu needs 3 primitives: input, result, and eltwise_forward. // Relu needs 3 primitives: input, result, and eltwise_forward.
size_t relu_index = mkldnn_emitter->reserve_primitive_space(3); size_t relu_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index); auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, relu_desc, relu_index, arg_buffer_index, out_buffer_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { relu_desc,
relu_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_relu_forward(ctx->mkldnn_memories, mkldnn_emitter->build_relu_forward(ctx->mkldnn_memories,
...@@ -63,8 +68,11 @@ namespace ngraph ...@@ -63,8 +68,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
ctx, relu_index, deps, cpu::mkldnn_utils::OpType::RELU); relu_index,
deps,
cpu::mkldnn_utils::OpType::RELU,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -89,6 +97,7 @@ namespace ngraph ...@@ -89,6 +97,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto bwd_desc = mkldnn_emitter->get_relu_backward_desc(node); auto bwd_desc = mkldnn_emitter->get_relu_backward_desc(node);
auto fwd_desc = mkldnn_emitter->get_relu_forward_desc(node); auto fwd_desc = mkldnn_emitter->get_relu_forward_desc(node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc); QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
// ReluBackprop needs 4 primitives: input, delta, result, and eltwise_backward. // ReluBackprop needs 4 primitives: input, delta, result, and eltwise_backward.
...@@ -99,6 +108,7 @@ namespace ngraph ...@@ -99,6 +108,7 @@ namespace ngraph
bwd_desc, bwd_desc,
fwd_desc, fwd_desc,
relu_index, relu_index,
scratchpad_size,
arg_fwd_buffer_index, arg_fwd_buffer_index,
delta_buffer_index, delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -121,7 +131,11 @@ namespace ngraph ...@@ -121,7 +131,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]); ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, relu_index, deps, cpu::mkldnn_utils::OpType::RELUBACKPROP); ctx,
relu_index,
deps,
cpu::mkldnn_utils::OpType::RELUBACKPROP,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -109,7 +109,7 @@ namespace ngraph ...@@ -109,7 +109,7 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
#else #else
mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc); size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc);
auto src_iter_c_buffer_index = auto src_iter_c_buffer_index =
external_function->get_buffer_index(args[2].get_name()); external_function->get_buffer_index(args[2].get_name());
...@@ -132,6 +132,7 @@ namespace ngraph ...@@ -132,6 +132,7 @@ namespace ngraph
auto functor = [&, auto functor = [&,
rnn_desc, rnn_desc,
rnn_index, rnn_index,
scratchpad_size,
src_layer_buffer_index, src_layer_buffer_index,
src_iter_buffer_index, src_iter_buffer_index,
src_iter_c_buffer_index, src_iter_c_buffer_index,
...@@ -174,7 +175,7 @@ namespace ngraph ...@@ -174,7 +175,7 @@ namespace ngraph
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]); ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN); ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN, scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
#endif #endif
......
...@@ -43,15 +43,19 @@ namespace ngraph ...@@ -43,15 +43,19 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto sigmoid_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, false); auto sigmoid_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, false);
QUERY_SCRATCHPAD(eltwise_forward, sigmoid_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, sigmoid_desc);
// Sigmoid needs 3 primitives: input, result, and eltwise_forward. // Sigmoid needs 3 primitives: input, result, and eltwise_forward.
auto sigmoid_index = mkldnn_emitter->reserve_primitive_space(3); auto sigmoid_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index); auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index);
auto functor = auto functor = [&,
[&, sigmoid_desc, sigmoid_index, arg0_buffer_index, out_buffer_index]( sigmoid_desc,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { sigmoid_index,
scratchpad_size,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_sigmoid_forward(ctx->mkldnn_memories, mkldnn_emitter->build_sigmoid_forward(ctx->mkldnn_memories,
...@@ -66,8 +70,11 @@ namespace ngraph ...@@ -66,8 +70,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
ctx, sigmoid_index, deps, cpu::mkldnn_utils::OpType::SIGMOID); sigmoid_index,
deps,
cpu::mkldnn_utils::OpType::SIGMOID,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -88,6 +95,7 @@ namespace ngraph ...@@ -88,6 +95,7 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto fwd_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, true); auto fwd_desc = mkldnn_emitter->get_sigmoid_forward_desc(node, true);
auto bwd_desc = mkldnn_emitter->get_sigmoid_backward_desc(node); auto bwd_desc = mkldnn_emitter->get_sigmoid_backward_desc(node);
size_t scratchpad_size =
QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc); QUERY_SCRATCHPAD_2ARGS(eltwise_backward, fwd_desc, bwd_desc);
// SigmoidBackprop needs 4 primitives: input, delta, result, and eltwise_backward. // SigmoidBackprop needs 4 primitives: input, delta, result, and eltwise_backward.
...@@ -98,6 +106,7 @@ namespace ngraph ...@@ -98,6 +106,7 @@ namespace ngraph
bwd_desc, bwd_desc,
fwd_desc, fwd_desc,
sigmoid_index, sigmoid_index,
scratchpad_size,
arg0_buffer_index, arg0_buffer_index,
arg1_buffer_index, arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -120,7 +129,11 @@ namespace ngraph ...@@ -120,7 +129,11 @@ namespace ngraph
ctx, deps[2], ctx->buffer_data[out_buffer_index]); ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, sigmoid_index, deps, cpu::mkldnn_utils::OpType::SIGMOIDBACKPROP); ctx,
sigmoid_index,
deps,
cpu::mkldnn_utils::OpType::SIGMOIDBACKPROP,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -93,7 +93,8 @@ namespace ngraph ...@@ -93,7 +93,8 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
QUERY_SCRATCHPAD_4ARGS(slice, input_desc, result_desc, lower_bounds, out_shape); size_t scratchpad_size = QUERY_SCRATCHPAD_4ARGS(
slice, input_desc, result_desc, lower_bounds, out_shape);
// Slice needs 3 primitives: input, result, and reorder. // Slice needs 3 primitives: input, result, and reorder.
auto slice_index = mkldnn_emitter->reserve_primitive_space(3); auto slice_index = mkldnn_emitter->reserve_primitive_space(3);
...@@ -105,6 +106,7 @@ namespace ngraph ...@@ -105,6 +106,7 @@ namespace ngraph
lower_bounds, lower_bounds,
out_shape, out_shape,
slice_index, slice_index,
scratchpad_size,
arg_buffer_index, arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
...@@ -125,8 +127,11 @@ namespace ngraph ...@@ -125,8 +127,11 @@ namespace ngraph
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx,
ctx, slice_index, deps, cpu::mkldnn_utils::OpType::SLICE); slice_index,
deps,
cpu::mkldnn_utils::OpType::SLICE,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -48,15 +48,19 @@ namespace ngraph ...@@ -48,15 +48,19 @@ namespace ngraph
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto softmax_desc = mkldnn_emitter->get_softmax_forward_desc(node); auto softmax_desc = mkldnn_emitter->get_softmax_forward_desc(node);
QUERY_SCRATCHPAD(softmax_forward, softmax_desc); size_t scratchpad_size = QUERY_SCRATCHPAD(softmax_forward, softmax_desc);
// Softmax needs 3 primitives: input, result, and softmax_forward. // Softmax needs 3 primitives: input, result, and softmax_forward.
size_t softmax_index = mkldnn_emitter->reserve_primitive_space(3); size_t softmax_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(softmax_index); auto& deps = mkldnn_emitter->get_primitive_deps(softmax_index);
auto functor = auto functor = [&,
[&, softmax_desc, softmax_index, arg_buffer_index, out_buffer_index]( softmax_desc,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { softmax_index,
scratchpad_size,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_softmax_forward(ctx->mkldnn_memories, mkldnn_emitter->build_softmax_forward(ctx->mkldnn_memories,
...@@ -72,7 +76,11 @@ namespace ngraph ...@@ -72,7 +76,11 @@ namespace ngraph
ctx, deps[1], ctx->buffer_data[out_buffer_index]); ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, softmax_index, deps, cpu::mkldnn_utils::OpType::SOFTMAX); ctx,
softmax_index,
deps,
cpu::mkldnn_utils::OpType::SOFTMAX,
scratchpad_size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -217,7 +217,14 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context(Allocator* allocator) ...@@ -217,7 +217,14 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context(Allocator* allocator)
std::vector<mkldnn::memory*>(mkldnn_emitter->get_mkldnn_memories().size()); std::vector<mkldnn::memory*>(mkldnn_emitter->get_mkldnn_memories().size());
ctx->mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>( ctx->mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>(
mkldnn_emitter->get_mkldnn_scratchpad_mds().size()); mkldnn_emitter->get_mkldnn_scratchpad_mds().size());
ctx->scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment); if (scratchpad_size > 0)
{
ctx->scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment, allocator);
}
else
{
ctx->scratchpad_buffer = nullptr;
}
} }
else else
{ {
......
This diff is collapsed.
...@@ -486,7 +486,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_ ...@@ -486,7 +486,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
// Build mkldnn primitives for codegen. // Build mkldnn primitives for codegen.
pass_manager.register_pass<runtime::cpu::pass::MKLDNNPrimitiveBuildPass>( pass_manager.register_pass<runtime::cpu::pass::MKLDNNPrimitiveBuildPass>(
m_desc_filename, *m_mkldnn_emitter, m_node_primitive_string_deps_index_map); m_desc_filename, *m_mkldnn_emitter, m_node_primitive_string_deps_index_size_map);
unordered_map<Node*, Node*> node_function_map; unordered_map<Node*, Node*> node_function_map;
string common_function_string; string common_function_string;
...@@ -685,9 +685,16 @@ using namespace ngraph::runtime; ...@@ -685,9 +685,16 @@ using namespace ngraph::runtime;
writer << "mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>(" writer << "mkldnn_scratchpad_mds = std::vector<mkldnn::memory::desc*>("
<< to_string(m_mkldnn_emitter->get_mkldnn_scratchpad_mds().size()) << ");\n"; << to_string(m_mkldnn_emitter->get_mkldnn_scratchpad_mds().size()) << ");\n";
writer << "size_t scratchpad_size = " << m_mkldnn_emitter->get_max_scratchpad_size() << ";\n"; writer << "size_t scratchpad_size = " << m_mkldnn_emitter->get_max_scratchpad_size() << ";\n";
writer << "if (scratchpad_size > 0)\n";
writer.block_begin();
writer << "size_t alignment = 4096;\n"; writer << "size_t alignment = 4096;\n";
writer << "scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment);\n"; writer << "scratchpad_buffer = new AlignedBuffer(scratchpad_size, alignment);\n";
writer.block_end(); writer.block_end();
writer << "else\n";
writer.block_begin();
writer << "scratchpad_buffer = nullptr;\n";
writer.block_end();
writer.block_end();
writer << "\n"; writer << "\n";
set<string> output_names; set<string> output_names;
......
...@@ -137,13 +137,14 @@ namespace ngraph ...@@ -137,13 +137,14 @@ namespace ngraph
return m_mkldnn_emitter; return m_mkldnn_emitter;
} }
// Return the tuple including the string to create mkldnn primitive, the deps and // Return the tuple including the string to create mkldnn primitive, the deps, the
// the index in CODEGEN // index and
const std::tuple<std::string, std::vector<size_t>, size_t>& // the scratchpad size in CODEGEN
const std::tuple<std::string, std::vector<size_t>, size_t, size_t>&
get_primitive_build_tuple(const Node* node) const get_primitive_build_tuple(const Node* node) const
{ {
auto it = m_node_primitive_string_deps_index_map.find(node); auto it = m_node_primitive_string_deps_index_size_map.find(node);
NGRAPH_CHECK(it != m_node_primitive_string_deps_index_map.end(), NGRAPH_CHECK(it != m_node_primitive_string_deps_index_size_map.end(),
"Primitive build tuple not found for node ", "Primitive build tuple not found for node ",
node->description()); node->description());
...@@ -349,9 +350,9 @@ namespace ngraph ...@@ -349,9 +350,9 @@ namespace ngraph
#endif #endif
/// Map each node with mkldnn implementation to its mkldnn primitive creating /// Map each node with mkldnn implementation to its mkldnn primitive creating
/// string, deps, and mkldnn primitive index. /// string, deps, mkldnn primitive index, and mkldnn scratchpad size.
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>> std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t, size_t>>
m_node_primitive_string_deps_index_map; m_node_primitive_string_deps_index_size_map;
/// Name of the file to store descriptors for mkldnn_primitives /// Name of the file to store descriptors for mkldnn_primitives
const std::string m_desc_filename = "desc_file"; const std::string m_desc_filename = "desc_file";
}; };
......
This diff is collapsed.
...@@ -1215,15 +1215,6 @@ namespace ngraph ...@@ -1215,15 +1215,6 @@ namespace ngraph
attr); attr);
} }
mkldnn::memory::format_tag query_convolution_forward_weight_format_tag(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& filter_strides,
const ngraph::Strides& window_dilation_strides_adjusted,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
template <typename OP> template <typename OP>
mkldnn::lstm_forward::desc mkldnn::lstm_forward::desc
get_rnn_forward_desc(const ngraph::Node* node, get_rnn_forward_desc(const ngraph::Node* node,
...@@ -1408,54 +1399,50 @@ namespace ngraph ...@@ -1408,54 +1399,50 @@ namespace ngraph
mkldnn_primitives[ip_idx] = prim; mkldnn_primitives[ip_idx] = prim;
} }
void query_scratchpad_sum(const mkldnn::sum::primitive_desc); size_t query_scratchpad_sum(const mkldnn::sum::primitive_desc);
void query_scratchpad_concat(const mkldnn::concat::primitive_desc); size_t query_scratchpad_concat(const mkldnn::concat::primitive_desc);
void query_scratchpad_pooling_forward(const mkldnn::pooling_forward::desc& desc); size_t query_scratchpad_pooling_forward(const mkldnn::pooling_forward::desc& desc);
void query_scratchpad_avg_pooling_backward( size_t query_scratchpad_avg_pooling_backward(
const mkldnn::pooling_forward::desc& fwd_desc, const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc); const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_max_pooling_backward( size_t query_scratchpad_max_pooling_backward(
const mkldnn::pooling_forward::desc& fwd_desc, const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc); const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_max_pooling_with_indices_backward( size_t query_scratchpad_max_pooling_with_indices_backward(
const mkldnn::pooling_forward::desc& fwd_desc, const mkldnn::pooling_forward::desc& fwd_desc,
const mkldnn::pooling_backward::desc& bwd_desc); const mkldnn::pooling_backward::desc& bwd_desc);
void query_scratchpad_batchnorm_forward( size_t query_scratchpad_batchnorm_forward(
const mkldnn::batch_normalization_forward::desc& desc, const mkldnn::batch_normalization_forward::desc& desc,
const mkldnn::post_ops& pops); const mkldnn::post_ops& pops);
void query_scratchpad_batchnorm_backward( size_t query_scratchpad_batchnorm_backward(
const mkldnn::batch_normalization_backward::desc& desc, const mkldnn::batch_normalization_backward::desc& desc,
const mkldnn::memory::desc& input_desc, const mkldnn::memory::desc& input_desc,
float epsilon); float epsilon);
void query_scratchpad_convolution_forward( size_t query_scratchpad_convolution_forward(
const mkldnn::convolution_forward::desc& desc, mkldnn::primitive_attr& attr); const mkldnn::convolution_forward::desc& desc, mkldnn::primitive_attr& attr);
void query_scratchpad_convolution_backward_data( size_t query_scratchpad_convolution_backward_data(
const mkldnn::convolution_forward::desc& fwd_desc, const mkldnn::convolution_forward::desc& fwd_desc,
const mkldnn::convolution_backward_data::desc& bwd_desc); const mkldnn::convolution_backward_data::desc& bwd_desc);
void query_scratchpad_convolution_backward_weights( size_t query_scratchpad_convolution_backward_weights(
const mkldnn::convolution_forward::desc& fwd_desc, const mkldnn::convolution_forward::desc& fwd_desc,
const mkldnn::convolution_backward_weights::desc& bwd_desc); const mkldnn::convolution_backward_weights::desc& bwd_desc);
void query_scratchpad_deconvolution_forward( size_t query_scratchpad_deconvolution_forward(
const mkldnn::deconvolution_forward::desc& desc); const mkldnn::deconvolution_forward::desc& desc);
void query_scratchpad_eltwise_forward(const mkldnn::eltwise_forward::desc& desc); size_t query_scratchpad_eltwise_forward(const mkldnn::eltwise_forward::desc& desc);
void query_scratchpad_eltwise_backward( size_t query_scratchpad_eltwise_backward(
const mkldnn::eltwise_forward::desc& fwd_desc, const mkldnn::eltwise_forward::desc& fwd_desc,
const mkldnn::eltwise_backward::desc& bwd_desc); const mkldnn::eltwise_backward::desc& bwd_desc);
void query_scratchpad_quantize(const mkldnn::memory::desc& input_desc, size_t query_scratchpad_ip_forward(const mkldnn::inner_product_forward::desc& desc,
const mkldnn::memory::desc& output_desc);
void query_scratchpad_dequantize(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc);
void query_scratchpad_ip_forward(const mkldnn::inner_product_forward::desc& desc,
mkldnn::primitive_attr& attr); mkldnn::primitive_attr& attr);
void query_scratchpad_reorder(const mkldnn::memory::desc& input_desc, size_t query_scratchpad_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
void query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc); size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc);
void query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc); size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc);
void query_scratchpad_slice(mkldnn::memory::desc& input_desc, size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc, const mkldnn::memory::desc& output_desc,
const ngraph::Coordinate& lower_bounds, const ngraph::Coordinate& lower_bounds,
const ngraph::Shape& result_shape); const ngraph::Shape& result_shape);
void query_scratchpad_softmax_forward(const mkldnn::softmax_forward::desc& desc); size_t query_scratchpad_softmax_forward(const mkldnn::softmax_forward::desc& desc);
#else #else
// TODO(jmenon): Get rid of TensorViewWrappers at some point // TODO(jmenon): Get rid of TensorViewWrappers at some point
...@@ -1702,15 +1689,6 @@ namespace ngraph ...@@ -1702,15 +1689,6 @@ namespace ngraph
mkldnn_primitives[ip_idx] = prim; mkldnn_primitives[ip_idx] = prim;
} }
mkldnn::memory::format query_convolution_forward_weight_format(
const mkldnn::memory::desc& input_data_desc,
const mkldnn::memory::desc& weights_desc_any,
const mkldnn::memory::desc& result_desc,
const ngraph::Strides& filter_strides,
const ngraph::Strides& window_dilation_strides_adjusted,
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories, void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives, std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds, std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
......
...@@ -33,9 +33,14 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeCon ...@@ -33,9 +33,14 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeCon
primitive->set_data_handle(ptr); primitive->set_data_handle(ptr);
} }
extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive( extern "C" void
CPURuntimeContext* ctx, size_t primitive_index, std::vector<size_t>& deps, OpType type) ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index,
std::vector<size_t>& deps,
OpType type,
size_t scratchpad_size)
{ {
(void)scratchpad_size;
mkldnn::stream s(mkldnn::stream::kind::eager); mkldnn::stream s(mkldnn::stream::kind::eager);
try try
{ {
...@@ -55,8 +60,12 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeCon ...@@ -55,8 +60,12 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::set_memory_ptr(CPURuntimeCon
memory->set_data_handle(ptr); memory->set_data_handle(ptr);
} }
extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive( extern "C" void
CPURuntimeContext* ctx, size_t primitive_index, std::vector<size_t>& deps, OpType type) ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index,
std::vector<size_t>& deps,
OpType type,
size_t scratchpad_size)
{ {
std::unordered_map<int, mkldnn::memory> exec_args; std::unordered_map<int, mkldnn::memory> exec_args;
size_t nargs; size_t nargs;
...@@ -198,10 +207,13 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive( ...@@ -198,10 +207,13 @@ extern "C" void ngraph::runtime::cpu::mkldnn_utils::mkldnn_invoke_primitive(
break; break;
} }
if (scratchpad_size)
{
mkldnn::memory scratchpad(*ctx->mkldnn_scratchpad_mds[primitive_index], mkldnn::memory scratchpad(*ctx->mkldnn_scratchpad_mds[primitive_index],
executor::global_cpu_engine, executor::global_cpu_engine,
ctx->scratchpad_buffer->get_ptr()); ctx->scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad}); exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
}
mkldnn::stream s(executor::global_cpu_engine); mkldnn::stream s(executor::global_cpu_engine);
try try
......
...@@ -82,7 +82,8 @@ namespace ngraph ...@@ -82,7 +82,8 @@ namespace ngraph
extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx, extern "C" void mkldnn_invoke_primitive(CPURuntimeContext* ctx,
size_t primitive_index, size_t primitive_index,
std::vector<size_t>& deps, std::vector<size_t>& deps,
OpType type); OpType type,
size_t scratchpad_size = 0);
} }
} }
} }
......
...@@ -48,10 +48,10 @@ ...@@ -48,10 +48,10 @@
#define SET_ROUND_MODE attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); #define SET_ROUND_MODE attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#define QUERY_SCRATCHPAD(op_name, x) #define QUERY_SCRATCHPAD(op_name, x) 0
#define QUERY_SCRATCHPAD_2ARGS(op_name, x, y) #define QUERY_SCRATCHPAD_2ARGS(op_name, x, y) 0
#define QUERY_SCRATCHPAD_3ARGS(op_name, x, y, z) #define QUERY_SCRATCHPAD_3ARGS(op_name, x, y, z) 0
#define QUERY_SCRATCHPAD_4ARGS(op_name, x, y, z, u) #define QUERY_SCRATCHPAD_4ARGS(op_name, x, y, z, u) 0
#define MKLDNN_ERROR_MESSAGE e.message #define MKLDNN_ERROR_MESSAGE e.message
...@@ -85,7 +85,8 @@ ...@@ -85,7 +85,8 @@
#define GET_SIZE \ #define GET_SIZE \
mkldnn::memory::desc scratchpad_md = pd.scratchpad_desc(); \ mkldnn::memory::desc scratchpad_md = pd.scratchpad_desc(); \
size_t size = scratchpad_md.get_size(); \ size_t size = scratchpad_md.get_size(); \
m_max_scratchpad_size = size > m_max_scratchpad_size ? size : m_max_scratchpad_size; m_max_scratchpad_size = size > m_max_scratchpad_size ? size : m_max_scratchpad_size; \
return size;
#define MKLDNN_ERROR_MESSAGE std::string(e.message) #define MKLDNN_ERROR_MESSAGE std::string(e.message)
......
...@@ -502,7 +502,8 @@ namespace ngraph ...@@ -502,7 +502,8 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ScatterAdd) void CPUAssignment::ASSIGN_DECL(ngraph::op::ScatterAdd)
{ {
auto update_slice = static_cast<ngraph::op::ScatterAdd*>(node); (void)external_function;
auto scatter_add = static_cast<ngraph::op::ScatterAdd*>(node);
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
...@@ -511,7 +512,7 @@ namespace ngraph ...@@ -511,7 +512,7 @@ namespace ngraph
// Safe to overwrite input // Safe to overwrite input
op_annotations->add_in_place_oi_pair({0, 0, true}); op_annotations->add_in_place_oi_pair({0, 0, true});
} }
update_slice->set_op_annotations(op_annotations); scatter_add->set_op_annotations(op_annotations);
} }
template <> template <>
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
std::string & construct_string, \ std::string & construct_string, \
std::vector<size_t> & deps, \ std::vector<size_t> & deps, \
size_t & index, \ size_t & index, \
size_t & scratchpad_size, \
std::ofstream & desc_file) std::ofstream & desc_file)
namespace mkldnn namespace mkldnn
...@@ -55,6 +56,7 @@ namespace ngraph ...@@ -55,6 +56,7 @@ namespace ngraph
std::string&, std::string&,
std::vector<size_t>&, std::vector<size_t>&,
size_t&, size_t&,
size_t&,
std::ofstream&)>; std::ofstream&)>;
using PrimitiveBuildStringConstructOpMap = using PrimitiveBuildStringConstructOpMap =
std::unordered_map<std::type_index, PrimitiveBuildStringConstructFunction>; std::unordered_map<std::type_index, PrimitiveBuildStringConstructFunction>;
...@@ -69,20 +71,23 @@ namespace ngraph ...@@ -69,20 +71,23 @@ namespace ngraph
ngraph::runtime::cpu::MKLDNNEmitter& m_mkldnn_emitter; ngraph::runtime::cpu::MKLDNNEmitter& m_mkldnn_emitter;
/// External map to store each node with mkldnn implementation and its mkldnn /// External map to store each node with mkldnn implementation and its mkldnn
/// creation string, deps, and mkldnn primitive index. /// creation string, deps, mkldnn primitive index, and mkldnn primitive
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>>& /// scratchpad size.
m_node_primitive_string_deps_index_map; std::map<const Node*,
std::tuple<std::string, std::vector<size_t>, size_t, size_t>>&
m_node_primitive_string_deps_index_size_map;
public: public:
MKLDNNPrimitiveBuildPass( MKLDNNPrimitiveBuildPass(
std::string filename, std::string filename,
ngraph::runtime::cpu::MKLDNNEmitter& mkldnn_emitter, ngraph::runtime::cpu::MKLDNNEmitter& mkldnn_emitter,
std::map<const Node*, std::tuple<std::string, std::vector<size_t>, size_t>>& std::map<const Node*,
node_primitive_string_deps_index_map) std::tuple<std::string, std::vector<size_t>, size_t, size_t>>&
node_primitive_string_deps_index_size_map)
: m_desc_filename(filename) : m_desc_filename(filename)
, m_mkldnn_emitter(mkldnn_emitter) , m_mkldnn_emitter(mkldnn_emitter)
, m_node_primitive_string_deps_index_map( , m_node_primitive_string_deps_index_size_map(
node_primitive_string_deps_index_map) node_primitive_string_deps_index_size_map)
{ {
} }
...@@ -95,6 +100,7 @@ namespace ngraph ...@@ -95,6 +100,7 @@ namespace ngraph
std::string& construct_string, std::string& construct_string,
std::vector<size_t>& deps, std::vector<size_t>& deps,
size_t& index, size_t& index,
size_t& scratchpad_size,
std::ofstream& desc_file) std::ofstream& desc_file)
{ {
throw std::runtime_error("Unimplemented op '" + node->description() + throw std::runtime_error("Unimplemented op '" + node->description() +
......
...@@ -97,7 +97,7 @@ struct CPURuntimeContextCG ...@@ -97,7 +97,7 @@ struct CPURuntimeContextCG
} }
void mkldnn_invoke_primitive(size_t primitive_index, std::vector<size_t>& deps, void mkldnn_invoke_primitive(size_t primitive_index, std::vector<size_t>& deps,
OpType type) OpType type, size_t scratchpad_size)
{ {
std::unordered_map<int, mkldnn::memory> exec_args; std::unordered_map<int, mkldnn::memory> exec_args;
size_t nargs; size_t nargs;
...@@ -252,10 +252,13 @@ struct CPURuntimeContextCG ...@@ -252,10 +252,13 @@ struct CPURuntimeContextCG
break; break;
} }
if (scratchpad_size)
{
mkldnn::memory scratchpad(*mkldnn_scratchpad_mds[primitive_index], mkldnn::memory scratchpad(*mkldnn_scratchpad_mds[primitive_index],
global_cpu_engine, global_cpu_engine,
scratchpad_buffer->get_ptr()); scratchpad_buffer->get_ptr());
exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad}); exec_args.insert({MKLDNN_ARG_SCRATCHPAD, scratchpad});
}
mkldnn::stream s(global_cpu_engine); mkldnn::stream s(global_cpu_engine);
try try
...@@ -265,7 +268,7 @@ struct CPURuntimeContextCG ...@@ -265,7 +268,7 @@ struct CPURuntimeContextCG
} }
catch (const mkldnn::error& e) catch (const mkldnn::error& e)
{ {
throw std::runtime_error("Could not run mkdnn primitive " + *e.message); throw std::runtime_error("Could not run mkdnn primitive " + std::string(e.message));
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment