Commit 8e798add authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

allow multiple thread safe calls to compiled function in DEX mode (#2717)

* Implement thread safe calls.

Create MKLDNN primitives on first iteration for quantized_concat.

* Add buffer_data to CPURuntimeContext.

* Fix bugs.

* Modify unit test.

* Swap vectors of mkldnn primitive pointers for CODEGEN.

* Fix a bug.

* Address PR feedback.

* Rename variables.

* Update Gather, GatherND, and DeconvolutionBias.

* Fix style error.

Disable cpu thread_safe_calls test on Windows.
parent 79f27b2e
...@@ -44,19 +44,30 @@ namespace ngraph ...@@ -44,19 +44,30 @@ namespace ngraph
size_t add_index = mkldnn_emitter->reserve_primitive_space(4); size_t add_index = mkldnn_emitter->reserve_primitive_space(4);
auto& deps = mkldnn_emitter->get_primitive_deps(add_index); auto& deps = mkldnn_emitter->get_primitive_deps(add_index);
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index =
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto arg1_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto functor = [&, sum_pd, add_index](CPURuntimeContext* ctx, auto functor = [&,
sum_pd,
add_index,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_elementwise_add(sum_pd, add_index); mkldnn_emitter->build_elementwise_add(
ctx->mkldnn_primitives, sum_pd, deps, add_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor); ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg1_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, add_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, add_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -33,8 +33,8 @@ namespace ngraph ...@@ -33,8 +33,8 @@ namespace ngraph
static int call_seq = 0; static int call_seq = 0;
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto count = static_cast<int>(out[0].get_size()); auto count = static_cast<int>(out[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum(); auto data_type = args[0].get_element_type().get_type_enum();
...@@ -48,10 +48,12 @@ namespace ngraph ...@@ -48,10 +48,12 @@ namespace ngraph
node->get_friendly_name().c_str(), node->get_friendly_name().c_str(),
count); count);
auto functor = [&, count, data_type](CPURuntimeContext* ctx, auto functor = [&, count, data_type, arg_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
get_distributed_interface()->all_reduce( get_distributed_interface()->all_reduce(ctx->buffer_data[arg_buffer_index],
arg_tensor, out_tensor, data_type, count); ctx->buffer_data[out_buffer_index],
data_type,
count);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -37,8 +37,8 @@ namespace ngraph ...@@ -37,8 +37,8 @@ namespace ngraph
const ngraph::op::ArgMax* argmax = static_cast<const ngraph::op::ArgMax*>(node); const ngraph::op::ArgMax* argmax = static_cast<const ngraph::op::ArgMax*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (out[0].get_element_type() != element::i64 && if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32) out[0].get_element_type() != element::i32)
{ {
...@@ -60,9 +60,20 @@ namespace ngraph ...@@ -60,9 +60,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax); kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -72,9 +83,20 @@ namespace ngraph ...@@ -72,9 +83,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax); kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
...@@ -88,9 +110,20 @@ namespace ngraph ...@@ -88,9 +110,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax); kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -101,9 +134,20 @@ namespace ngraph ...@@ -101,9 +134,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax); kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
...@@ -117,9 +161,20 @@ namespace ngraph ...@@ -117,9 +161,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax); kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -129,9 +184,20 @@ namespace ngraph ...@@ -129,9 +184,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax); kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
......
...@@ -37,8 +37,8 @@ namespace ngraph ...@@ -37,8 +37,8 @@ namespace ngraph
const ngraph::op::ArgMin* argmin = static_cast<const ngraph::op::ArgMin*>(node); const ngraph::op::ArgMin* argmin = static_cast<const ngraph::op::ArgMin*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (out[0].get_element_type() != element::i64 && if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32) out[0].get_element_type() != element::i32)
{ {
...@@ -60,9 +60,20 @@ namespace ngraph ...@@ -60,9 +60,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin); kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -72,9 +83,20 @@ namespace ngraph ...@@ -72,9 +83,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin); kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
...@@ -88,9 +110,20 @@ namespace ngraph ...@@ -88,9 +110,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin); kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -101,9 +134,20 @@ namespace ngraph ...@@ -101,9 +134,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin); kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
...@@ -117,9 +161,20 @@ namespace ngraph ...@@ -117,9 +161,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin); kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
else else
...@@ -129,9 +184,20 @@ namespace ngraph ...@@ -129,9 +184,20 @@ namespace ngraph
SELECT_RANK2( SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin); kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena); in_shape,
out_shape,
axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
in_shape,
out_shape,
axis,
ectx->arena);
}; };
} }
} }
......
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto window_shape = avg_pool->get_window_shape(); auto window_shape = avg_pool->get_window_shape();
auto window_movement_strides = avg_pool->get_window_movement_strides(); auto window_movement_strides = avg_pool->get_window_movement_strides();
...@@ -59,14 +59,18 @@ namespace ngraph ...@@ -59,14 +59,18 @@ namespace ngraph
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
auto functor = [&, avg_pool_desc, avg_pool_index](CPURuntimeContext* ctx, auto functor =
CPUExecutionContext* ectx) { [&, avg_pool_desc, avg_pool_index, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_forward(avg_pool_desc, avg_pool_index); mkldnn_emitter->build_pooling_forward(
ctx->mkldnn_primitives, avg_pool_desc, deps, avg_pool_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, avg_pool_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, avg_pool_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -86,10 +90,12 @@ namespace ngraph ...@@ -86,10 +90,12 @@ namespace ngraph
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation](CPURuntimeContext* ctx, include_padding_in_avg_computation,
arg0_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
kernel(arg0_tensor, kernel(ctx->buffer_data[arg0_buffer_index],
out_tensor, ctx->buffer_data[out_buffer_index],
arg0_shape, arg0_shape,
out_shape, out_shape,
window_shape, window_shape,
...@@ -112,8 +118,8 @@ namespace ngraph ...@@ -112,8 +118,8 @@ namespace ngraph
auto delta_shape = args[0].get_shape(); auto delta_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& delta_tensor = external_function->get_tensor_data(args[0].get_name()); auto delta_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto window_shape = apb->get_window_shape(); auto window_shape = apb->get_window_shape();
auto window_movement_strides = apb->get_window_movement_strides(); auto window_movement_strides = apb->get_window_movement_strides();
...@@ -135,15 +141,25 @@ namespace ngraph ...@@ -135,15 +141,25 @@ namespace ngraph
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
auto functor = [&, avg_pool_desc, avg_pool_fwd_desc, avg_pool_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { avg_pool_desc,
avg_pool_fwd_desc,
avg_pool_index,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_backward( mkldnn_emitter->build_pooling_backward(ctx->mkldnn_primitives,
avg_pool_desc, avg_pool_fwd_desc, avg_pool_index); avg_pool_desc,
avg_pool_fwd_desc,
deps,
avg_pool_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], delta_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[delta_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, avg_pool_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, avg_pool_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -162,10 +178,12 @@ namespace ngraph ...@@ -162,10 +178,12 @@ namespace ngraph
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation](CPURuntimeContext* ctx, include_padding_in_avg_computation,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
kernel(delta_tensor, kernel(ctx->buffer_data[delta_buffer_index],
out_tensor, ctx->buffer_data[out_buffer_index],
delta_shape, delta_shape,
out_shape, out_shape,
window_shape, window_shape,
......
...@@ -34,8 +34,8 @@ namespace ngraph ...@@ -34,8 +34,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& input_tensor = external_function->get_tensor_data(args[0].get_name()); auto input_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
size_t count = out[0].get_size(); size_t count = out[0].get_size();
auto alpha = static_cast<const ngraph::op::BoundedRelu*>(node)->get_alpha(); auto alpha = static_cast<const ngraph::op::BoundedRelu*>(node)->get_alpha();
...@@ -48,15 +48,23 @@ namespace ngraph ...@@ -48,15 +48,23 @@ namespace ngraph
auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3); auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index); auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);
auto functor = [&, bounded_relu_desc, bounded_relu_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { bounded_relu_desc,
bounded_relu_index,
input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_bounded_relu(bounded_relu_desc, mkldnn_emitter->build_bounded_relu(ctx->mkldnn_primitives,
bounded_relu_desc,
deps,
bounded_relu_index); bounded_relu_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], input_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[input_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, bounded_relu_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, bounded_relu_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -68,9 +76,13 @@ namespace ngraph ...@@ -68,9 +76,13 @@ namespace ngraph
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::bounded_relu); kernel, out[0].get_element_type(), runtime::cpu::kernel::bounded_relu);
auto functor = [&, kernel, alpha, count](CPURuntimeContext* ctx, auto functor = [&, kernel, alpha, count, input_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(input_tensor, out_tensor, alpha, count, ectx->arena); kernel(ctx->buffer_data[input_buffer_index],
ctx->buffer_data[out_buffer_index],
alpha,
count,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -42,6 +42,9 @@ namespace ngraph ...@@ -42,6 +42,9 @@ namespace ngraph
auto arg_shape = broadcast->get_argument(0)->get_shape(); auto arg_shape = broadcast->get_argument(0)->get_shape();
out_shape = broadcast->get_shape(); out_shape = broadcast->get_shape();
// TODO(jmenon): Shape transformations, rank reduction etc. needs to be general
// and not in any one builder. Move this to the Halide analysis phase.
// Transform output shape - ex. [4, 1, 2, 2] -> [4, 1, 4] // Transform output shape - ex. [4, 1, 2, 2] -> [4, 1, 4]
// if we're not broadcasting along axes 2 and 3 // if we're not broadcasting along axes 2 and 3
...@@ -92,7 +95,9 @@ namespace ngraph ...@@ -92,7 +95,9 @@ namespace ngraph
else else
{ {
broadcast_axes.erase(i); broadcast_axes.erase(i);
// TODO(jmenon): This needs to be rewritten
// when it gets moved to the analysis pass
// that doesn't use AxisSet
auto new_bcast_axes = AxisSet{}; auto new_bcast_axes = AxisSet{};
for (auto axis : broadcast_axes) for (auto axis : broadcast_axes)
{ {
...@@ -188,8 +193,8 @@ namespace ngraph ...@@ -188,8 +193,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel; std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel;
Shape expanded_input_shape, out_shape; Shape expanded_input_shape, out_shape;
...@@ -199,17 +204,28 @@ namespace ngraph ...@@ -199,17 +204,28 @@ namespace ngraph
CPUKernelFunctor functor; CPUKernelFunctor functor;
if (kernel) if (kernel)
{ {
functor = [&, kernel, expanded_input_shape, out_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel( expanded_input_shape,
arg_tensor, out_tensor, expanded_input_shape, out_shape, ectx->arena); out_shape,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
expanded_input_shape,
out_shape,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
else else
{ {
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { functor = [&, size, arg_buffer_index, out_buffer_index](
memcpy(out_tensor, arg_tensor, size); CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(ctx->buffer_data[out_buffer_index],
ctx->buffer_data[arg_buffer_index],
size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -31,12 +31,13 @@ namespace ngraph ...@@ -31,12 +31,13 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto count = static_cast<int>(args[0].get_size()); auto count = static_cast<int>(args[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum(); auto data_type = args[0].get_element_type().get_type_enum();
auto functor = [&, count, data_type](CPURuntimeContext* ctx, auto functor = [&, count, data_type, arg_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
get_distributed_interface()->broadcast(arg_tensor, data_type, count); get_distributed_interface()->broadcast(
ctx->buffer_data[arg_buffer_index], data_type, count);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
vector<reference_wrapper<void*>> arg_tensors; vector<size_t> arg_buffer_indices;
vector<Shape> arg_shapes; vector<Shape> arg_shapes;
vector<size_t> arg_sizes; vector<size_t> arg_sizes;
auto element_size = concat->get_input_element_type(0).size(); auto element_size = concat->get_input_element_type(0).size();
...@@ -45,15 +45,15 @@ namespace ngraph ...@@ -45,15 +45,15 @@ namespace ngraph
{ {
if (shape_size(arg.get_shape())) if (shape_size(arg.get_shape()))
{ {
arg_tensors.emplace_back( arg_buffer_indices.emplace_back(
external_function->get_tensor_data(arg.get_name())); external_function->get_buffer_index(arg.get_name()));
arg_shapes.emplace_back(arg.get_shape()); arg_shapes.emplace_back(arg.get_shape());
arg_sizes.emplace_back(shape_size(arg.get_shape()) * element_size); arg_sizes.emplace_back(shape_size(arg.get_shape()) * element_size);
} }
} }
auto nargs = args.size(); auto nargs = args.size();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
if (auto op_annotations = concat->get_op_annotations()) if (auto op_annotations = concat->get_op_annotations())
...@@ -63,7 +63,8 @@ namespace ngraph ...@@ -63,7 +63,8 @@ namespace ngraph
{ {
auto out_size = shape_size(out_shape) * element_size; auto out_size = shape_size(out_shape) * element_size;
auto functor = [&, arg_tensors, nargs, out_size, arg_sizes]( auto functor =
[&, arg_buffer_indices, nargs, out_size, arg_sizes, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
auto offset = 0; auto offset = 0;
for (size_t i = 0; i < nargs; i++) for (size_t i = 0; i < nargs; i++)
...@@ -71,12 +72,17 @@ namespace ngraph ...@@ -71,12 +72,17 @@ namespace ngraph
// if the argument pointer does not fall within the concat output buffer // if the argument pointer does not fall within the concat output buffer
// (caused by propagate_in_place_output or propagate_in_place_input), we need to copy the data; // (caused by propagate_in_place_output or propagate_in_place_input), we need to copy the data;
// otherwise, we can skip the copy. // otherwise, we can skip the copy.
if (arg_tensors[i] < out_tensor || if (ctx->buffer_data[arg_buffer_indices[i]] <
arg_tensors[i] >= ctx->buffer_data[out_buffer_index] ||
reinterpret_cast<char*>(out_tensor) + out_size) ctx->buffer_data[arg_buffer_indices[i]] >=
reinterpret_cast<char*>(
ctx->buffer_data[out_buffer_index]) +
out_size)
{ {
memcpy(reinterpret_cast<char*>(out_tensor) + offset, memcpy(reinterpret_cast<char*>(
arg_tensors[i], ctx->buffer_data[out_buffer_index]) +
offset,
ctx->buffer_data[arg_buffer_indices[i]],
arg_sizes[i]); arg_sizes[i]);
} }
offset += arg_sizes[i]; offset += arg_sizes[i];
...@@ -92,7 +98,8 @@ namespace ngraph ...@@ -92,7 +98,8 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto concat_pd = mkldnn_emitter->get_concat_desc(node, nargs); auto concat_pd =
mkldnn_emitter->get_concat_desc<ngraph::op::Concat>(node, nargs);
std::vector<mkldnn::memory::desc> inputs_data_desc; std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < nargs; i++) for (size_t i = 0; i < nargs; i++)
{ {
...@@ -102,19 +109,29 @@ namespace ngraph ...@@ -102,19 +109,29 @@ namespace ngraph
auto concat_index = mkldnn_emitter->reserve_primitive_space(nargs + 2); auto concat_index = mkldnn_emitter->reserve_primitive_space(nargs + 2);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index); auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
auto functor = auto functor = [&,
[&, concat_pd, inputs_data_desc, arg_tensors, nargs, concat_index]( concat_pd,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { inputs_data_desc,
arg_buffer_indices,
nargs,
concat_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_concat( mkldnn_emitter->build_concat(ctx->mkldnn_primitives,
concat_pd, inputs_data_desc, concat_index); concat_pd,
inputs_data_desc,
deps,
concat_index);
} }
for (size_t i = 0; i < nargs; i++) for (size_t i = 0; i < nargs; i++)
{ {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[i], arg_tensors[i]); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[i], ctx->buffer_data[arg_buffer_indices[i]]);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[nargs], out_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[nargs], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index);
}; };
...@@ -129,9 +146,24 @@ namespace ngraph ...@@ -129,9 +146,24 @@ namespace ngraph
out[0].get_shape().size(), out[0].get_shape().size(),
runtime::cpu::kernel::concat); runtime::cpu::kernel::concat);
auto functor = [&, kernel, arg_tensors, arg_shapes, out_shape, axis]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensors, arg_shapes, out_tensor, out_shape, axis); arg_buffer_indices,
arg_shapes,
out_shape,
axis,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
std::vector<void*> arg_tensors;
for (auto& arg_buffer_index : arg_buffer_indices)
{
arg_tensors.push_back(ctx->buffer_data[arg_buffer_index]);
}
kernel(arg_tensors,
arg_shapes,
ctx->buffer_data[out_buffer_index],
out_shape,
axis);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -32,8 +32,8 @@ namespace ngraph ...@@ -32,8 +32,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto element_count = out[0].get_size(); auto element_count = out[0].get_size();
...@@ -101,11 +101,14 @@ namespace ngraph ...@@ -101,11 +101,14 @@ namespace ngraph
throw ngraph_error("Cannot convert from an invalid input element type"); throw ngraph_error("Cannot convert from an invalid input element type");
} }
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx, auto functor = [&, kernel, element_count, arg_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (arg_tensor != out_tensor) if (ctx->buffer_data[arg_buffer_index] != ctx->buffer_data[out_buffer_index])
{ {
kernel(arg_tensor, out_tensor, element_count, ectx->arena); kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
element_count,
ectx->arena);
} }
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -35,8 +35,8 @@ namespace ngraph ...@@ -35,8 +35,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
...@@ -84,14 +84,21 @@ namespace ngraph ...@@ -84,14 +84,21 @@ namespace ngraph
// ConvertLayout needs 3 primitives: input, result, and reorder. // ConvertLayout needs 3 primitives: input, result, and reorder.
size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3); size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index); auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
auto functor = [&, input_desc, result_desc, reorder_index]( auto functor =
[&, input_desc, result_desc, reorder_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_reorder(input_desc, result_desc, reorder_index); mkldnn_emitter->build_reorder(ctx->mkldnn_primitives,
input_desc,
result_desc,
deps,
reorder_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, reorder_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, reorder_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
This diff is collapsed.
...@@ -36,9 +36,9 @@ namespace ngraph ...@@ -36,9 +36,9 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (out[0].get_element_type() != element::f32 && if (out[0].get_element_type() != element::f32 &&
out[0].get_element_type() != element::f64) out[0].get_element_type() != element::f64)
...@@ -54,39 +54,54 @@ namespace ngraph ...@@ -54,39 +54,54 @@ namespace ngraph
{ {
if (index_element_type == element::f32) if (index_element_type == element::f32)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, float>( ngraph::runtime::reference::embedding<float, float>(
static_cast<float*>(arg0_tensor), static_cast<float*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<float*>(arg1_tensor), static_cast<float*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
} }
else if (index_element_type == element::i32) else if (index_element_type == element::i32)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, int>( ngraph::runtime::reference::embedding<float, int>(
static_cast<int*>(arg0_tensor), static_cast<int*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<float*>(arg1_tensor), static_cast<float*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
} }
else if (index_element_type == element::i64) else if (index_element_type == element::i64)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, int64_t>( ngraph::runtime::reference::embedding<float, int64_t>(
static_cast<int64_t*>(arg0_tensor), static_cast<int64_t*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<float*>(arg1_tensor), static_cast<float*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
...@@ -101,39 +116,54 @@ namespace ngraph ...@@ -101,39 +116,54 @@ namespace ngraph
{ {
if (index_element_type == element::f32) if (index_element_type == element::f32)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, float>( ngraph::runtime::reference::embedding<int, float>(
static_cast<float*>(arg0_tensor), static_cast<float*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<int*>(arg1_tensor), static_cast<int*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<int*>(out_tensor), static_cast<int*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
} }
else if (index_element_type == element::i32) else if (index_element_type == element::i32)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, int>( ngraph::runtime::reference::embedding<int, int>(
static_cast<int*>(arg0_tensor), static_cast<int*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<int*>(arg1_tensor), static_cast<int*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<int*>(out_tensor), static_cast<int*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
} }
else if (index_element_type == element::i64) else if (index_element_type == element::i64)
{ {
functor = [&, in_shape, element_count](CPURuntimeContext* ctx, functor = [&,
in_shape,
element_count,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, int64_t>( ngraph::runtime::reference::embedding<int, int64_t>(
static_cast<int64_t*>(arg0_tensor), static_cast<int64_t*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<int*>(arg1_tensor), static_cast<int*>(ctx->buffer_data[arg1_buffer_index]),
static_cast<int*>(out_tensor), static_cast<int*>(ctx->buffer_data[out_buffer_index]),
element_count, element_count,
in_shape); in_shape);
}; };
......
...@@ -32,8 +32,8 @@ namespace ngraph ...@@ -32,8 +32,8 @@ namespace ngraph
{ {
auto element_type = args[0].get_element_type(); auto element_type = args[0].get_element_type();
auto element_count = out[0].get_size(); auto element_count = out[0].get_size();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name()); auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
if (element_type == element::f32 || element_type == element::f64) if (element_type == element::f32 || element_type == element::f64)
...@@ -47,9 +47,12 @@ namespace ngraph ...@@ -47,9 +47,12 @@ namespace ngraph
{ {
kernel = runtime::cpu::kernel::erf<double>; kernel = runtime::cpu::kernel::erf<double>;
} }
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx, auto functor = [&, kernel, element_count, arg0_buffer_index, out0_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg0_tensor, out0_tensor, element_count, ectx->arena); kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[out0_buffer_index],
element_count,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
...@@ -58,9 +61,11 @@ namespace ngraph ...@@ -58,9 +61,11 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::reference_erf<float>)> kernel; std::function<decltype(runtime::cpu::kernel::reference_erf<float>)> kernel;
SELECT_KERNEL( SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::reference_erf); kernel, args[0].get_element_type(), runtime::cpu::kernel::reference_erf);
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx, auto functor = [&, kernel, element_count, arg0_buffer_index, out0_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg0_tensor, out0_tensor, element_count); kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[out0_buffer_index],
element_count);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -36,9 +36,9 @@ namespace ngraph ...@@ -36,9 +36,9 @@ namespace ngraph
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node); const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& params_tensor = external_function->get_tensor_data(args[0].get_name()); auto params_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& indices_tensor = external_function->get_tensor_data(args[1].get_name()); auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 && if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32) args[1].get_element_type() != element::i32)
{ {
...@@ -54,12 +54,19 @@ namespace ngraph ...@@ -54,12 +54,19 @@ namespace ngraph
{ {
if (is_int64) if (is_int64)
{ {
functor = [&, params_shape, indices_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int64_t>( ngraph::runtime::reference::gather<float, int64_t>(
static_cast<float*>(params_tensor), static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(indices_tensor), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -68,12 +75,19 @@ namespace ngraph ...@@ -68,12 +75,19 @@ namespace ngraph
} }
else else
{ {
functor = [&, params_shape, indices_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int32_t>( ngraph::runtime::reference::gather<float, int32_t>(
static_cast<float*>(params_tensor), static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(indices_tensor), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -85,12 +99,19 @@ namespace ngraph ...@@ -85,12 +99,19 @@ namespace ngraph
{ {
if (is_int64) if (is_int64)
{ {
functor = [&, params_shape, indices_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int64_t>( ngraph::runtime::reference::gather<double, int64_t>(
static_cast<double*>(params_tensor), static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(indices_tensor), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(out_tensor), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -99,12 +120,19 @@ namespace ngraph ...@@ -99,12 +120,19 @@ namespace ngraph
} }
else else
{ {
functor = [&, params_shape, indices_shape, out_shape, axis]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int32_t>( ngraph::runtime::reference::gather<double, int32_t>(
static_cast<double*>(params_tensor), static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(indices_tensor), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(out_tensor), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
......
...@@ -35,9 +35,9 @@ namespace ngraph ...@@ -35,9 +35,9 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& params_tensor = external_function->get_tensor_data(args[0].get_name()); auto params_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& indices_tensor = external_function->get_tensor_data(args[1].get_name()); auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 && if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32) args[1].get_element_type() != element::i32)
{ {
...@@ -52,12 +52,18 @@ namespace ngraph ...@@ -52,12 +52,18 @@ namespace ngraph
{ {
if (is_int64) if (is_int64)
{ {
functor = [&, params_shape, indices_shape, out_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather_nd<float, int64_t>( ngraph::runtime::reference::gather_nd<float, int64_t>(
static_cast<float*>(params_tensor), static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(indices_tensor), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape); out_shape);
...@@ -65,12 +71,18 @@ namespace ngraph ...@@ -65,12 +71,18 @@ namespace ngraph
} }
else else
{ {
functor = [&, params_shape, indices_shape, out_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather_nd<float, int32_t>( ngraph::runtime::reference::gather_nd<float, int32_t>(
static_cast<float*>(params_tensor), static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(indices_tensor), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(out_tensor), static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape); out_shape);
...@@ -81,12 +93,18 @@ namespace ngraph ...@@ -81,12 +93,18 @@ namespace ngraph
{ {
if (is_int64) if (is_int64)
{ {
functor = [&, params_shape, indices_shape, out_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather_nd<double, int64_t>( ngraph::runtime::reference::gather_nd<double, int64_t>(
static_cast<double*>(params_tensor), static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(indices_tensor), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(out_tensor), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape); out_shape);
...@@ -94,12 +112,18 @@ namespace ngraph ...@@ -94,12 +112,18 @@ namespace ngraph
} }
else else
{ {
functor = [&, params_shape, indices_shape, out_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { params_shape,
indices_shape,
out_shape,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather_nd<double, int32_t>( ngraph::runtime::reference::gather_nd<double, int32_t>(
static_cast<double*>(params_tensor), static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(indices_tensor), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(out_tensor), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape); out_shape);
......
...@@ -36,10 +36,11 @@ namespace ngraph ...@@ -36,10 +36,11 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node); auto goe = static_cast<const ngraph::op::GetOutputElement*>(node);
size_t n = goe->get_n(); size_t n = goe->get_n();
auto& arg_tensor = external_function->get_tensor_data(args[n].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[n].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto functor = [&, n](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { auto functor = [&, n, arg_buffer_index, out_buffer_index](
if (arg_tensor != out_tensor) CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->buffer_data[arg_buffer_index] != ctx->buffer_data[out_buffer_index])
{ {
throw ngraph_error("GOE's input and out must be equal"); throw ngraph_error("GOE's input and out must be equal");
} }
......
...@@ -52,7 +52,7 @@ namespace ngraph ...@@ -52,7 +52,7 @@ namespace ngraph
auto& halide_functions = external_function->get_halide_functions(); auto& halide_functions = external_function->get_halide_functions();
auto& subgraph_params = external_function->get_subgraph_params(); auto& subgraph_params = external_function->get_subgraph_params();
auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes(); auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes();
auto& subgraph_param_ptrs = external_function->get_subgraph_param_ptrs(); auto& subgraph_param_indices = external_function->get_subgraph_param_indices();
for (const auto& op : hs->get_ops()) for (const auto& op : hs->get_ops())
{ {
...@@ -73,8 +73,8 @@ namespace ngraph ...@@ -73,8 +73,8 @@ namespace ngraph
subgraph_params[tensor_name] = Halide::ImageParam(Halide::Float(32), 1); subgraph_params[tensor_name] = Halide::ImageParam(Halide::Float(32), 1);
subgraph_param_sizes[tensor_name] = subgraph_param_sizes[tensor_name] =
shape_size(input.get_output().get_tensor_ptr()->get_shape()); shape_size(input.get_output().get_tensor_ptr()->get_shape());
subgraph_param_ptrs.emplace( subgraph_param_indices.emplace(
tensor_name, external_function->get_tensor_data(tensor_name)); tensor_name, external_function->get_buffer_index(tensor_name));
inputs.emplace_back(subgraph_params[tensor_name]); inputs.emplace_back(subgraph_params[tensor_name]);
} }
} }
...@@ -84,19 +84,22 @@ namespace ngraph ...@@ -84,19 +84,22 @@ namespace ngraph
auto out_tensor_name = hs->get_ops().back()->get_output_tensor_ptr()->get_name(); auto out_tensor_name = hs->get_ops().back()->get_output_tensor_ptr()->get_name();
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& terminal_func = halide_functions[out_tensor_name]; auto& terminal_func = halide_functions[out_tensor_name];
auto out_size = out[0].get_size(); auto out_size = out[0].get_size();
auto functor = [&, out_size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { auto functor = [&, out_size, out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
for (auto& param : subgraph_params) for (auto& param : subgraph_params)
{ {
Halide::Buffer<float> param_buffer( Halide::Buffer<float> param_buffer(
static_cast<float*>(subgraph_param_ptrs.at(param.first).get()), static_cast<float*>(
ctx->buffer_data[subgraph_param_indices.at(param.first)]),
subgraph_param_sizes.at(param.first)); subgraph_param_sizes.at(param.first));
param.second.set(param_buffer); param.second.set(param_buffer);
} }
Halide::Buffer<float> out_buffer(static_cast<float*>(out_tensor), out_size); Halide::Buffer<float> out_buffer(
static_cast<float*>(ctx->buffer_data[out_buffer_index]), out_size);
terminal_func.realize(out_buffer); terminal_func.realize(out_buffer);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -34,8 +34,8 @@ namespace ngraph ...@@ -34,8 +34,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& input_tensor = external_function->get_tensor_data(args[0].get_name()); auto input_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
size_t count = out[0].get_size(); size_t count = out[0].get_size();
auto alpha = static_cast<const ngraph::op::LeakyRelu*>(node)->get_alpha(); auto alpha = static_cast<const ngraph::op::LeakyRelu*>(node)->get_alpha();
...@@ -48,14 +48,21 @@ namespace ngraph ...@@ -48,14 +48,21 @@ namespace ngraph
auto leaky_relu_index = mkldnn_emitter->reserve_primitive_space(3); auto leaky_relu_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(leaky_relu_index); auto& deps = mkldnn_emitter->get_primitive_deps(leaky_relu_index);
auto functor = [&, leaky_relu_desc, leaky_relu_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { leaky_relu_desc,
leaky_relu_index,
input_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_leaky_relu(leaky_relu_desc, leaky_relu_index); mkldnn_emitter->build_leaky_relu(
ctx->mkldnn_primitives, leaky_relu_desc, deps, leaky_relu_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], input_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[input_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, leaky_relu_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, leaky_relu_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -67,9 +74,13 @@ namespace ngraph ...@@ -67,9 +74,13 @@ namespace ngraph
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::leaky_relu); kernel, out[0].get_element_type(), runtime::cpu::kernel::leaky_relu);
auto functor = [&, kernel, alpha, count](CPURuntimeContext* ctx, auto functor = [&, kernel, alpha, count, input_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(input_tensor, out_tensor, alpha, count, ectx->arena); kernel(ctx->buffer_data[input_buffer_index],
ctx->buffer_data[out_buffer_index],
alpha,
count,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -59,7 +59,7 @@ namespace ngraph ...@@ -59,7 +59,7 @@ namespace ngraph
auto& halide_functions = external_function->get_halide_functions(); auto& halide_functions = external_function->get_halide_functions();
auto& subgraph_params = external_function->get_subgraph_params(); auto& subgraph_params = external_function->get_subgraph_params();
auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes(); auto& subgraph_param_sizes = external_function->get_subgraph_param_sizes();
auto& subgraph_param_ptrs = external_function->get_subgraph_param_ptrs(); auto& subgraph_param_indices = external_function->get_subgraph_param_indices();
std::set<std::string> param_names; std::set<std::string> param_names;
for (const auto& op : hs->get_node_list()) for (const auto& op : hs->get_node_list())
...@@ -85,8 +85,8 @@ namespace ngraph ...@@ -85,8 +85,8 @@ namespace ngraph
Halide::ImageParam(Halide::Float(32), 1, tensor_name); Halide::ImageParam(Halide::Float(32), 1, tensor_name);
subgraph_param_sizes[tensor_name] = subgraph_param_sizes[tensor_name] =
shape_size(input.get_output().get_tensor_ptr()->get_shape()); shape_size(input.get_output().get_tensor_ptr()->get_shape());
subgraph_param_ptrs.emplace( subgraph_param_indices.emplace(
tensor_name, external_function->get_tensor_data(tensor_name)); tensor_name, external_function->get_buffer_index(tensor_name));
inputs.emplace_back(subgraph_params[tensor_name]); inputs.emplace_back(subgraph_params[tensor_name]);
} }
else else
...@@ -107,7 +107,7 @@ namespace ngraph ...@@ -107,7 +107,7 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
std::vector<std::tuple<void*&, size_t>> buffers_data; std::vector<std::tuple<size_t, size_t>> buffers_data;
std::vector<Halide::Expr> results; std::vector<Halide::Expr> results;
auto output_nodes = hs->get_kernel_outputs(); auto output_nodes = hs->get_kernel_outputs();
...@@ -117,9 +117,9 @@ namespace ngraph ...@@ -117,9 +117,9 @@ namespace ngraph
auto result_func = auto result_func =
halide_functions[output_nodes.at(i)->get_output_tensor_ptr()->get_name()]; halide_functions[output_nodes.at(i)->get_output_tensor_ptr()->get_name()];
results.push_back((result_func(x) + 0)); results.push_back((result_func(x) + 0));
auto& out_tensor = external_function->get_tensor_data(out[i].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[i].get_name());
buffers_data.push_back( buffers_data.push_back(
std::tuple<void*&, size_t>(out_tensor, out[i].get_size())); std::tuple<size_t, size_t>(out_buffer_index, out[i].get_size()));
} }
Halide::Func terminal_func; Halide::Func terminal_func;
...@@ -131,7 +131,7 @@ namespace ngraph ...@@ -131,7 +131,7 @@ namespace ngraph
for (auto& param : param_names) for (auto& param : param_names)
{ {
Halide::Buffer<float> param_buffer( Halide::Buffer<float> param_buffer(
static_cast<float*>(subgraph_param_ptrs.at(param).get()), static_cast<float*>(ctx->buffer_data[subgraph_param_indices.at(param)]),
subgraph_param_sizes.at(param)); subgraph_param_sizes.at(param));
subgraph_params[param].set(param_buffer); subgraph_params[param].set(param_buffer);
} }
...@@ -139,7 +139,8 @@ namespace ngraph ...@@ -139,7 +139,8 @@ namespace ngraph
for (auto tuple : buffers_data) for (auto tuple : buffers_data)
{ {
buffers.push_back(Halide::Buffer<float>( buffers.push_back(Halide::Buffer<float>(
static_cast<float*>(std::get<0>(tuple)), std::get<1>(tuple))); static_cast<float*>(ctx->buffer_data[std::get<0>(tuple)]),
std::get<1>(tuple)));
} }
Halide::Realization r(buffers); Halide::Realization r(buffers);
terminal_func.realize(r); terminal_func.realize(r);
......
...@@ -37,8 +37,8 @@ namespace ngraph ...@@ -37,8 +37,8 @@ namespace ngraph
const ngraph::op::LRN* lrn = static_cast<const ngraph::op::LRN*>(node); const ngraph::op::LRN* lrn = static_cast<const ngraph::op::LRN*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -48,14 +48,17 @@ namespace ngraph ...@@ -48,14 +48,17 @@ namespace ngraph
auto lrn_index = mkldnn_emitter->reserve_primitive_space(3); auto lrn_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index); auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index);
functor = [&, lrn_desc, lrn_index](CPURuntimeContext* ctx, functor = [&, lrn_desc, lrn_index, arg_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_lrn_forward(lrn_desc, lrn_index); mkldnn_emitter->build_lrn_forward(
ctx->mkldnn_primitives, lrn_desc, deps, lrn_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, lrn_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, lrn_index);
}; };
} }
...@@ -70,10 +73,18 @@ namespace ngraph ...@@ -70,10 +73,18 @@ namespace ngraph
auto element_type = lrn->get_element_type(); auto element_type = lrn->get_element_type();
if (element_type == element::f32) if (element_type == element::f32)
{ {
functor = [&, alpha, beta, bias, arg_shape, nsize]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { alpha,
ngraph::runtime::reference::lrn<float>(static_cast<float*>(arg_tensor), beta,
static_cast<float*>(out_tensor), bias,
arg_shape,
nsize,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::lrn<float>(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
arg_shape, arg_shape,
alpha, alpha,
beta, beta,
...@@ -83,11 +94,18 @@ namespace ngraph ...@@ -83,11 +94,18 @@ namespace ngraph
} }
else if (element_type == element::f64) else if (element_type == element::f64)
{ {
functor = [&, alpha, beta, bias, arg_shape, nsize]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { alpha,
beta,
bias,
arg_shape,
nsize,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::lrn<double>( ngraph::runtime::reference::lrn<double>(
static_cast<double*>(arg_tensor), static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
static_cast<double*>(out_tensor), static_cast<double*>(ctx->buffer_data[out_buffer_index]),
arg_shape, arg_shape,
alpha, alpha,
beta, beta,
......
...@@ -45,13 +45,18 @@ namespace ngraph ...@@ -45,13 +45,18 @@ namespace ngraph
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& src_layer_tensor = external_function->get_tensor_data(args[0].get_name()); auto src_layer_buffer_index =
auto& src_iter_tensor = external_function->get_tensor_data(args[1].get_name()); external_function->get_buffer_index(args[0].get_name());
auto& weights_layer_tensor = external_function->get_tensor_data(args[2].get_name()); auto src_iter_buffer_index =
auto& weights_iter_tensor = external_function->get_tensor_data(args[3].get_name()); external_function->get_buffer_index(args[1].get_name());
auto& bias_tensor = external_function->get_tensor_data(args[4].get_name()); auto weights_layer_buffer_index =
auto& dst_layer_tensor = external_function->get_tensor_data(out[0].get_name()); external_function->get_buffer_index(args[2].get_name());
auto& dst_iter_tensor = external_function->get_tensor_data(out[1].get_name()); auto weights_iter_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto bias_buffer_index = external_function->get_buffer_index(args[4].get_name());
auto dst_layer_buffer_index =
external_function->get_buffer_index(out[0].get_name());
auto dst_iter_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lstm_desc = auto lstm_desc =
...@@ -63,20 +68,39 @@ namespace ngraph ...@@ -63,20 +68,39 @@ namespace ngraph
mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */); mkldnn_emitter->reserve_primitive_space(9, true /* new workspace */);
auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index); auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);
auto functor = [&, lstm_desc, lstm_index](CPURuntimeContext* ctx, auto functor = [&,
lstm_desc,
lstm_index,
src_layer_buffer_index,
src_iter_buffer_index,
weights_layer_buffer_index,
weights_iter_buffer_index,
bias_buffer_index,
dst_layer_buffer_index,
dst_iter_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_rnn_forward(lstm_desc, lstm_index); mkldnn_emitter->build_rnn_forward(ctx->mkldnn_primitives,
ctx->mkldnn_workspaces = mkldnn_emitter->get_mkldnn_workspaces().data(); ctx->mkldnn_workspaces,
lstm_desc,
deps,
lstm_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], src_layer_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], src_iter_tensor); ctx, deps[0], ctx->buffer_data[src_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], weights_layer_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[3], weights_iter_tensor); ctx, deps[1], ctx->buffer_data[src_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[4], bias_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[5], dst_layer_tensor); ctx, deps[2], ctx->buffer_data[weights_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], dst_iter_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[weights_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[bias_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[dst_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[6], ctx->buffer_data[dst_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr( cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[7], ctx->mkldnn_workspaces[deps[8]]); ctx, deps[7], ctx->mkldnn_workspaces[deps[8]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, lstm_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, lstm_index);
......
This diff is collapsed.
...@@ -39,17 +39,22 @@ namespace ngraph ...@@ -39,17 +39,22 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (arg_rank == 0) if (arg_rank == 0)
{ {
std::function<decltype(runtime::cpu::kernel::one_hot_rank_0<float>)> kernel; std::function<decltype(runtime::cpu::kernel::one_hot_rank_0<float>)> kernel;
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_0); kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_0);
auto functor = [&, kernel, out_shape, one_hot_axis](CPURuntimeContext* ctx, auto functor =
CPUExecutionContext* ectx) { [&, kernel, out_shape, one_hot_axis, arg_buffer_index, out_buffer_index](
kernel(arg_tensor, out_tensor, out_shape, one_hot_axis, ectx->arena); CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
out_shape,
one_hot_axis,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -59,10 +64,16 @@ namespace ngraph ...@@ -59,10 +64,16 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::one_hot_rank_1<float>)> kernel; std::function<decltype(runtime::cpu::kernel::one_hot_rank_1<float>)> kernel;
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_1); kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_1);
auto functor = [&, kernel, arg_shape, out_shape, one_hot_axis]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, arg_shape,
out_tensor, out_shape,
one_hot_axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape, arg_shape,
out_shape, out_shape,
one_hot_axis, one_hot_axis,
...@@ -78,9 +89,19 @@ namespace ngraph ...@@ -78,9 +89,19 @@ namespace ngraph
SELECT_KERNEL(kernel, SELECT_KERNEL(kernel,
out[0].get_element_type(), out[0].get_element_type(),
runtime::cpu::kernel::one_hot_rank_2_or_more); runtime::cpu::kernel::one_hot_rank_2_or_more);
auto functor = [&, kernel, arg_shape, out_shape, one_hot_axis]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, arg_shape, out_shape, one_hot_axis); arg_shape,
out_shape,
one_hot_axis,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
out_shape,
one_hot_axis);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -38,9 +38,9 @@ namespace ngraph ...@@ -38,9 +38,9 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& padding_value = external_function->get_tensor_data(args[1].get_name()); auto padding_value_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto pad = static_cast<const ngraph::op::Pad*>(node); auto pad = static_cast<const ngraph::op::Pad*>(node);
...@@ -59,11 +59,19 @@ namespace ngraph ...@@ -59,11 +59,19 @@ namespace ngraph
arg_shape.size(), arg_shape.size(),
runtime::cpu::kernel::pad_and_slice); runtime::cpu::kernel::pad_and_slice);
auto functor = [&, kernel, arg_shape, out_shape, padding_below, padding_above]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, arg_shape,
out_tensor, out_shape,
padding_value, padding_below,
padding_above,
arg_buffer_index,
padding_value_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
ctx->buffer_data[padding_value_index],
arg_shape, arg_shape,
out_shape, out_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()), CoordinateDiff(padding_below.begin(), padding_below.end()),
...@@ -79,12 +87,20 @@ namespace ngraph ...@@ -79,12 +87,20 @@ namespace ngraph
SELECT_KERNEL( SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::pad_ref); kernel, args[0].get_element_type(), runtime::cpu::kernel::pad_ref);
auto functor = auto functor = [&,
[&, kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode]( kernel,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { arg_shape,
kernel(arg_tensor, out_shape,
padding_value, padding_below,
out_tensor, padding_above,
pad_mode,
arg_buffer_index,
padding_value_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[padding_value_index],
ctx->buffer_data[out_buffer_index],
arg_shape, arg_shape,
out_shape, out_shape,
padding_below, padding_below,
......
...@@ -36,8 +36,8 @@ namespace ngraph ...@@ -36,8 +36,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto qavg_pool_desc = auto qavg_pool_desc =
...@@ -47,14 +47,18 @@ namespace ngraph ...@@ -47,14 +47,18 @@ namespace ngraph
size_t qavg_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t qavg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(qavg_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(qavg_pool_index);
auto functor = [&, qavg_pool_desc, qavg_pool_index](CPURuntimeContext* ctx, auto functor =
CPUExecutionContext* ectx) { [&, qavg_pool_desc, qavg_pool_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_forward(qavg_pool_desc, qavg_pool_index); mkldnn_emitter->build_pooling_forward(
ctx->mkldnn_primitives, qavg_pool_desc, deps, qavg_pool_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, qavg_pool_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, qavg_pool_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -35,41 +35,55 @@ namespace ngraph ...@@ -35,41 +35,55 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
vector<reference_wrapper<void*>> arg_tensors; vector<size_t> arg_buffer_indices;
for (auto& arg : args) for (auto& arg : args)
{ {
if (shape_size(arg.get_shape())) if (shape_size(arg.get_shape()))
{ {
arg_tensors.emplace_back( arg_buffer_indices.emplace_back(
external_function->get_tensor_data(arg.get_name())); external_function->get_buffer_index(arg.get_name()));
} }
} }
auto nargs = args.size(); auto nargs = args.size();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto concat_pd =
mkldnn_emitter->get_concat_desc<ngraph::op::QuantizedConcat>(node, nargs);
std::vector<mkldnn::memory::desc> inputs_data_desc; std::vector<mkldnn::memory::desc> inputs_data_desc;
for (size_t i = 0; i < args.size(); i++) for (size_t i = 0; i < args.size(); i++)
{ {
inputs_data_desc.push_back(mkldnn_utils::get_input_mkldnn_md(node, i)); inputs_data_desc.push_back(mkldnn_utils::get_input_mkldnn_md(node, i));
} }
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); // Concat needs number of inputs plus 2 primitives; those two are for result and concat.
auto concat_index = mkldnn_emitter->reserve_primitive_space(nargs + 2);
size_t concat_dim = (static_cast<const ngraph::op::QuantizedConcat*>(node))
->get_concatenation_axis();
auto concat_index =
mkldnn_emitter->build_concat(inputs_data_desc, result_desc, concat_dim);
auto& deps = mkldnn_emitter->get_primitive_deps(concat_index); auto& deps = mkldnn_emitter->get_primitive_deps(concat_index);
auto functor = [&, arg_tensors, nargs, concat_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { concat_pd,
inputs_data_desc,
arg_buffer_indices,
nargs,
concat_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_concat(ctx->mkldnn_primitives,
concat_pd,
inputs_data_desc,
deps,
concat_index);
}
for (size_t i = 0; i < nargs; i++) for (size_t i = 0; i < nargs; i++)
{ {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[i], arg_tensors[i]); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[i], ctx->buffer_data[arg_buffer_indices[i]]);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[nargs], out_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[nargs], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, concat_index);
}; };
......
...@@ -43,11 +43,15 @@ namespace ngraph ...@@ -43,11 +43,15 @@ namespace ngraph
"Unsupported data types for QuantizedDot MKLDNN kernel."); "Unsupported data types for QuantizedDot MKLDNN kernel.");
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index =
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); external_function->get_buffer_index(args[0].get_name());
auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name()); auto arg1_buffer_index =
auto& arg3_tensor = external_function->get_tensor_data(args[3].get_name()); external_function->get_buffer_index(args[1].get_name());
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name()); auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto arg3_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[3].get_shape()); auto scales_size = shape_size(args[3].get_shape());
...@@ -61,21 +65,42 @@ namespace ngraph ...@@ -61,21 +65,42 @@ namespace ngraph
size_t ip_index = mkldnn_emitter->inner_product_forward_init(true); size_t ip_index = mkldnn_emitter->inner_product_forward_init(true);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index); auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto functor = [&, scales_size, ip_desc, ip_attr, deps, ip_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) mutable { scales_size,
ip_desc,
ip_attr,
deps,
ip_index,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
arg3_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
vector<float> dyn_scales; vector<float> dyn_scales;
dyn_scales.assign(static_cast<float*>(arg3_tensor), dyn_scales.assign(
static_cast<float*>(arg3_tensor) + scales_size); static_cast<float*>(ctx->buffer_data[arg3_buffer_index]),
static_cast<float*>(ctx->buffer_data[arg3_buffer_index]) +
scales_size);
ip_attr.set_output_scales(0, dyn_scales); ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<true>( mkldnn_emitter->build_inner_product_forward<true>(
ip_desc, ip_attr, executor::global_cpu_engine, ip_index); ctx->mkldnn_primitives,
ip_desc,
ip_attr,
executor::global_cpu_engine,
deps,
ip_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor); ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], arg2_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[3], out0_tensor); ctx, deps[1], ctx->buffer_data[arg1_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[arg2_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -92,10 +117,13 @@ namespace ngraph ...@@ -92,10 +117,13 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index =
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); external_function->get_buffer_index(args[0].get_name());
auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name()); auto arg1_buffer_index =
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name()); external_function->get_buffer_index(args[1].get_name());
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto scales_size = shape_size(args[2].get_shape()); auto scales_size = shape_size(args[2].get_shape());
...@@ -109,20 +137,39 @@ namespace ngraph ...@@ -109,20 +137,39 @@ namespace ngraph
size_t ip_index = mkldnn_emitter->inner_product_forward_init(false); size_t ip_index = mkldnn_emitter->inner_product_forward_init(false);
auto& deps = mkldnn_emitter->get_primitive_deps(ip_index); auto& deps = mkldnn_emitter->get_primitive_deps(ip_index);
auto functor = [&, scales_size, ip_desc, ip_attr, deps, ip_index]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) mutable { scales_size,
ip_desc,
ip_attr,
deps,
ip_index,
arg0_buffer_index,
arg1_buffer_index,
arg2_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) mutable {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
vector<float> dyn_scales; vector<float> dyn_scales;
dyn_scales.assign(static_cast<float*>(arg2_tensor), dyn_scales.assign(
static_cast<float*>(arg2_tensor) + scales_size); static_cast<float*>(ctx->buffer_data[arg2_buffer_index]),
static_cast<float*>(ctx->buffer_data[arg2_buffer_index]) +
scales_size);
ip_attr.set_output_scales(0, dyn_scales); ip_attr.set_output_scales(0, dyn_scales);
mkldnn_emitter->build_inner_product_forward<false>( mkldnn_emitter->build_inner_product_forward<false>(
ip_desc, ip_attr, executor::global_cpu_engine, ip_index); ctx->mkldnn_primitives,
ip_desc,
ip_attr,
executor::global_cpu_engine,
deps,
ip_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg0_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], arg1_tensor); ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out0_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[arg1_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[out0_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, ip_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -36,8 +36,8 @@ namespace ngraph ...@@ -36,8 +36,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto qmax_pool_desc = auto qmax_pool_desc =
...@@ -47,14 +47,18 @@ namespace ngraph ...@@ -47,14 +47,18 @@ namespace ngraph
size_t qmax_pool_index = mkldnn_emitter->reserve_primitive_space(3); size_t qmax_pool_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(qmax_pool_index); auto& deps = mkldnn_emitter->get_primitive_deps(qmax_pool_index);
auto functor = [&, qmax_pool_desc, qmax_pool_index](CPURuntimeContext* ctx, auto functor =
CPUExecutionContext* ectx) { [&, qmax_pool_desc, qmax_pool_index, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_pooling_forward(qmax_pool_desc, qmax_pool_index); mkldnn_emitter->build_pooling_forward(
ctx->mkldnn_primitives, qmax_pool_desc, deps, qmax_pool_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, qmax_pool_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, qmax_pool_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -37,17 +37,19 @@ namespace ngraph ...@@ -37,17 +37,19 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto reduce = static_cast<const ngraph::op::Any*>(node); auto reduce = static_cast<const ngraph::op::Any*>(node);
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto reduction_axes = reduce->get_reduction_axes(); auto reduction_axes = reduce->get_reduction_axes();
auto functor = [&, arg0_shape, out_shape, reduction_axes]( auto functor =
[&, arg0_shape, out_shape, reduction_axes, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::reference::any(static_cast<char*>(arg0_tensor), runtime::reference::any(
static_cast<char*>(out_tensor), static_cast<char*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<char*>(ctx->buffer_data[out_buffer_index]),
arg0_shape, arg0_shape,
out_shape, out_shape,
reduction_axes); reduction_axes);
...@@ -60,17 +62,19 @@ namespace ngraph ...@@ -60,17 +62,19 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto reduce = static_cast<const ngraph::op::All*>(node); auto reduce = static_cast<const ngraph::op::All*>(node);
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto reduction_axes = reduce->get_reduction_axes(); auto reduction_axes = reduce->get_reduction_axes();
auto functor = [&, arg0_shape, out_shape, reduction_axes]( auto functor =
[&, arg0_shape, out_shape, reduction_axes, arg0_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::reference::all(static_cast<char*>(arg0_tensor), runtime::reference::all(
static_cast<char*>(out_tensor), static_cast<char*>(ctx->buffer_data[arg0_buffer_index]),
static_cast<char*>(ctx->buffer_data[out_buffer_index]),
arg0_shape, arg0_shape,
out_shape, out_shape,
reduction_axes); reduction_axes);
......
...@@ -36,8 +36,8 @@ namespace ngraph ...@@ -36,8 +36,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto relu_desc = mkldnn_emitter->get_relu_forward_desc(node); auto relu_desc = mkldnn_emitter->get_relu_forward_desc(node);
...@@ -45,14 +45,17 @@ namespace ngraph ...@@ -45,14 +45,17 @@ namespace ngraph
size_t relu_index = mkldnn_emitter->reserve_primitive_space(3); size_t relu_index = mkldnn_emitter->reserve_primitive_space(3);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index); auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, relu_desc, relu_index](CPURuntimeContext* ctx, auto functor = [&, relu_desc, relu_index, arg_buffer_index, out_buffer_index](
CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_relu_forward(relu_desc, relu_index); mkldnn_emitter->build_relu_forward(
ctx->mkldnn_primitives, relu_desc, deps, relu_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor); ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -68,9 +71,9 @@ namespace ngraph ...@@ -68,9 +71,9 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_fwd_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_fwd_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& delta_tensor = external_function->get_tensor_data(args[1].get_name()); auto delta_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
size_t count = out[0].get_size(); size_t count = out[0].get_size();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
...@@ -82,15 +85,25 @@ namespace ngraph ...@@ -82,15 +85,25 @@ namespace ngraph
size_t relu_index = mkldnn_emitter->reserve_primitive_space(4); size_t relu_index = mkldnn_emitter->reserve_primitive_space(4);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index); auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, bwd_desc, fwd_desc, relu_index](CPURuntimeContext* ctx, auto functor = [&,
bwd_desc,
fwd_desc,
relu_index,
arg_fwd_buffer_index,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
if (ctx->first_iteration) if (ctx->first_iteration)
{ {
mkldnn_emitter->build_relu_backward(bwd_desc, fwd_desc, relu_index); mkldnn_emitter->build_relu_backward(
ctx->mkldnn_primitives, bwd_desc, fwd_desc, deps, relu_index);
} }
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_fwd_tensor); cpu::mkldnn_utils::set_memory_ptr(
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], delta_tensor); ctx, deps[0], ctx->buffer_data[arg_fwd_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor); cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[delta_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index); cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -102,9 +115,18 @@ namespace ngraph ...@@ -102,9 +115,18 @@ namespace ngraph
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::relu_backprop); kernel, out[0].get_element_type(), runtime::cpu::kernel::relu_backprop);
auto functor = [&, kernel, count](CPURuntimeContext* ctx, auto functor = [&,
kernel,
count,
arg_fwd_buffer_index,
delta_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) { CPUExecutionContext* ectx) {
kernel(arg_fwd_tensor, delta_tensor, out_tensor, count, ectx->arena); kernel(ctx->buffer_data[arg_fwd_buffer_index],
ctx->buffer_data[delta_buffer_index],
ctx->buffer_data[out_buffer_index],
count,
ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -34,10 +34,10 @@ namespace ngraph ...@@ -34,10 +34,10 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node); auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
...@@ -61,8 +61,11 @@ namespace ngraph ...@@ -61,8 +61,11 @@ namespace ngraph
if (!arg0_shape.size()) if (!arg0_shape.size())
{ {
size_t size = args[0].get_element_type().size(); size_t size = args[0].get_element_type().size();
auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { auto functor = [&, size, arg1_buffer_index, out_buffer_index](
memcpy(out_tensor, arg1_tensor, size); CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(ctx->buffer_data[out_buffer_index],
ctx->buffer_data[arg1_buffer_index],
size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
return; return;
...@@ -78,12 +81,20 @@ namespace ngraph ...@@ -78,12 +81,20 @@ namespace ngraph
arg0_shape.size(), arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice); runtime::cpu::kernel::strided_replace_slice);
auto functor = auto functor = [&,
[&, kernel, arg0_shape, arg1_shape, lower_bounds, upper_bounds, strides]( kernel,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { arg0_shape,
kernel(arg0_tensor, arg1_shape,
arg1_tensor, lower_bounds,
out_tensor, upper_bounds,
strides,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out_buffer_index],
arg0_shape, arg0_shape,
arg1_shape, arg1_shape,
lower_bounds, lower_bounds,
...@@ -102,11 +113,18 @@ namespace ngraph ...@@ -102,11 +113,18 @@ namespace ngraph
arg0_shape.size(), arg0_shape.size(),
runtime::cpu::kernel::replace_slice); runtime::cpu::kernel::replace_slice);
auto functor = [&, kernel, arg0_shape, arg1_shape, lower_bounds]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg0_tensor, arg0_shape,
arg1_tensor, arg1_shape,
out_tensor, lower_bounds,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out_buffer_index],
arg0_shape, arg0_shape,
arg1_shape, arg1_shape,
lower_bounds, lower_bounds,
......
...@@ -167,8 +167,8 @@ namespace ngraph ...@@ -167,8 +167,8 @@ namespace ngraph
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel; std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel; std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
...@@ -188,10 +188,16 @@ namespace ngraph ...@@ -188,10 +188,16 @@ namespace ngraph
CPUKernelFunctor functor; CPUKernelFunctor functor;
if (kernel) if (kernel)
{ {
functor = [&, kernel, arg_shape, input_order, result_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, arg_shape,
out_tensor, input_order,
result_shape,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape, arg_shape,
input_order, input_order,
result_shape, result_shape,
...@@ -200,10 +206,16 @@ namespace ngraph ...@@ -200,10 +206,16 @@ namespace ngraph
} }
else if (ref_kernel) else if (ref_kernel)
{ {
functor = [&, ref_kernel, arg_shape, input_order, result_shape]( functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { ref_kernel,
ref_kernel(arg_tensor, arg_shape,
out_tensor, input_order,
result_shape,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ref_kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape, arg_shape,
input_order, input_order,
result_shape, result_shape,
...@@ -212,17 +224,24 @@ namespace ngraph ...@@ -212,17 +224,24 @@ namespace ngraph
} }
else if (skip_reshape) else if (skip_reshape)
{ {
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { functor = [&, size, arg_buffer_index, out_buffer_index](
if (out_tensor != arg_tensor) CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (ctx->buffer_data[out_buffer_index] !=
ctx->buffer_data[arg_buffer_index])
{ {
memcpy(out_tensor, arg_tensor, size); memcpy(ctx->buffer_data[out_buffer_index],
ctx->buffer_data[arg_buffer_index],
size);
} }
}; };
} }
else else
{ {
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) { functor = [&, size, arg_buffer_index, out_buffer_index](
memcpy(out_tensor, arg_tensor, size); CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(ctx->buffer_data[out_buffer_index],
ctx->buffer_data[arg_buffer_index],
size);
}; };
} }
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -34,8 +34,8 @@ namespace ngraph ...@@ -34,8 +34,8 @@ namespace ngraph
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
...@@ -45,9 +45,19 @@ namespace ngraph ...@@ -45,9 +45,19 @@ namespace ngraph
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::reverse); SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::reverse);
auto functor = [&, kernel, arg_shape, result_shape, reversed_axes]( auto functor = [&,
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { kernel,
kernel(arg_tensor, out_tensor, arg_shape, result_shape, reversed_axes); arg_shape,
result_shape,
reversed_axes,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
result_shape,
reversed_axes);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment