Commit e2f55f83 authored by Nishant Patel, committed by Scott Cyphers

Add a compile-time flag to support all ETs in the CPU backend (#3350)

* Refactor cpu/builder/dot.cpp

* Support only float and i64

* Refactor softmax and reduction to support only float and i64

* Add a new file for customized kernels

* Add compile time flag NGRAPH_CPU_LARGE_BINARY

* Add compilation flag checks in reduction.hpp

* change dot to only support float and i64

* Revert "Refactor cpu/builder/dot.cpp"

This reverts commit 0d53f27fde64872aff096f12ee9b79e5db7a7fee.

* style

* Consolidate macros

* Refactor slice ops, reshape & pad

* cleanup

* Gather op

* Concat, Convert and broadcast op

* Tile and reverse_sequence op

* Scatter_add op

* ETs other than fp and i64 go through ref

* tests passing

* Consolidate macros

* Address feedback

* Fall back to reference for pad and reshape

* add scatter_add reference fallback

* Undo concat change

* Undo tile and update slice as they don't have a reference implementation

* Remove update slice condition

* Gather op type check

* change routine name

* Build-time configurability for datatypes that require optimized kernels on CPU backend (#3592)

* Inline function

* Change condition

* VS compiler workaround

* Add comment for VS workaround

* More fixes for VS

* More wrapping of nested macros for VS

* More wrapper

* variable name change and more wrapping of macros

* test

* Style and refactor

* Wrap macros

* Add a separate macro for SELECT_KERNEL

* Change SELECT_KERNEL_3ARGS

* Unwrap couple of macros

* Syntax

* Add a new macro for fixed number of args for VS compiler

* Comment/Fake ectx

* Add detailed comment for the workaround for VS compiler

* Comment all unused ectx variables

* Templated calls to reference kernels

* const args

* Change softmax ref definition to take double

* Hardcode softmax kernel to take double ..testing

* Fix softmax
parent b39d0aab
@@ -156,9 +156,42 @@ if (NGRAPH_MLIR_ENABLE)
     )
 endif()
+set(NGRAPH_CPU_ALL_DATATYPES
+    boolean
+    f32
+    f64
+    i8
+    i16
+    i32
+    i64
+    u8
+    u16
+    u32
+    u64
+    )
+set(NGRAPH_CPU_COMMON_DATATYPES
+    f32
+    i64
+    )
 if (NGRAPH_CPU_ENABLE)
     set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
+    set(NGRAPH_CPU_OPTIMIZED_DATATYPES "common"
+        CACHE STRING "Semicolon-separated list of datatypes to optimize for, or \"common\" or \"all\".")
+    if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "all")
+        set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_ALL_DATATYPES})
+    endif()
+    if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "common")
+        set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_COMMON_DATATYPES})
+    endif()
+    list(REMOVE_DUPLICATES NGRAPH_CPU_OPTIMIZED_DATATYPES)
     add_library(cpu_backend ${LIBRARY_TYPE} ${SRC})
     if (NGRAPH_CPU_STATIC_LIB_ENABLE)
         target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_STATIC_LIB_ENABLE")
@@ -209,6 +242,10 @@ if (NGRAPH_CPU_ENABLE)
     endif()
     target_compile_definitions(cpu_backend PRIVATE CPU_BACKEND_DLL_EXPORTS)
+    foreach(t ${NGRAPH_CPU_OPTIMIZED_DATATYPES})
+        target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_OPTIMIZE_${t}")
+    endforeach()
     add_dependencies(cpu_backend libmkldnn ext_eigen)
     target_link_libraries(cpu_backend PUBLIC ngraph libmkldnn libmkl libeigen libtbb)
......
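Note: the foreach loop above turns each entry of NGRAPH_CPU_OPTIMIZED_DATATYPES into a per-type compile definition (NGRAPH_CPU_OPTIMIZE_f32, NGRAPH_CPU_OPTIMIZE_i64, and so on). The is_optimized_et helper used throughout the builders below is defined in a header whose diff is collapsed here; the sketch that follows is only an assumed illustration of how such a check could be assembled from those definitions, not code from this commit.

// Illustrative sketch only: how the NGRAPH_CPU_OPTIMIZE_<type> definitions
// generated by the CMake change above could back a check like is_optimized_et().
// The real helper is defined in a part of this diff that is collapsed.
#include "ngraph/type/element_type.hpp"

inline bool is_optimized_et(const ngraph::element::Type& et)
{
    return false
#ifdef NGRAPH_CPU_OPTIMIZE_f32
           || et == ngraph::element::f32
#endif
#ifdef NGRAPH_CPU_OPTIMIZE_f64
           || et == ngraph::element::f64
#endif
#ifdef NGRAPH_CPU_OPTIMIZE_i64
           || et == ngraph::element::i64
#endif
#ifdef NGRAPH_CPU_OPTIMIZE_u8
           || et == ngraph::element::u8
#endif
        ;
}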
@@ -57,8 +57,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmax<float, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
+SELECT_KERNEL_RANK(
+    kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
 functor = [&,
            kernel,
@@ -80,8 +80,8 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::argmax<float, int, 1>)> kernel;
-    SELECT_RANK2(
-        kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax)
+    SELECT_KERNEL_RANK(
+        kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax);
     functor = [&,
                kernel,
@@ -107,8 +107,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmax<double, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
+SELECT_KERNEL_RANK(
+    kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
 functor = [&,
            kernel,
@@ -131,8 +131,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmax<double, int, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax)
+SELECT_KERNEL_RANK(
+    kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax);
 functor = [&,
            kernel,
@@ -158,8 +158,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmax<int, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
+SELECT_KERNEL_RANK(
+    kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
 functor = [&,
            kernel,
@@ -181,8 +181,8 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::argmax<int, int, 1>)> kernel;
-    SELECT_RANK2(
-        kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax)
+    SELECT_KERNEL_RANK(
+        kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax);
     functor = [&,
                kernel,
......
@@ -57,8 +57,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmin<float, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
+SELECT_KERNEL_RANK(
+    kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
 functor = [&,
            kernel,
@@ -80,8 +80,8 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::argmin<float, int, 1>)> kernel;
-    SELECT_RANK2(
-        kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin)
+    SELECT_KERNEL_RANK(
+        kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin);
     functor = [&,
                kernel,
@@ -107,8 +107,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmin<double, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
+SELECT_KERNEL_RANK(
+    kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
 functor = [&,
            kernel,
@@ -131,8 +131,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmin<double, int, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin)
+SELECT_KERNEL_RANK(
+    kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin);
 functor = [&,
            kernel,
@@ -158,8 +158,8 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::argmin<int, int64_t, 1>)>
     kernel;
-SELECT_RANK2(
-    kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
+SELECT_KERNEL_RANK(
+    kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
 functor = [&,
            kernel,
@@ -181,8 +181,8 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::argmin<int, int, 1>)> kernel;
-    SELECT_RANK2(
-        kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin)
+    SELECT_KERNEL_RANK(
+        kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin);
     functor = [&,
                kernel,
......
@@ -155,7 +155,7 @@ namespace ngraph
     }
 }
-SELECT_KERNEL_BY_RANK(kernel,
+SELECT_KERNEL_ET_RANK(kernel,
     broadcast->get_input_element_type(0),
     out_rank,
     runtime::cpu::kernel::broadcast)
......
@@ -149,7 +149,7 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::concat<float, 1>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_KERNEL_ET_RANK(kernel,
         out[0].get_element_type(),
         out[0].get_shape().size(),
         runtime::cpu::kernel::concat)
......
@@ -67,7 +67,9 @@ namespace ngraph
 return;
 }
-if (arg0_shape.empty() || arg1_shape.empty())
+if ((arg0_shape.empty() || arg1_shape.empty()) &&
+    is_optimized_et(args[0].get_element_type()) &&
+    is_optimized_et(args[1].get_element_type()))
 {
     auto first = (arg0_shape.empty() ? args[0] : args[1]);
     auto second = (arg0_shape.empty() ? args[1] : args[0]);
@@ -78,8 +80,7 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
-SELECT_KERNEL(
-    kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar)
+SELECT_ETS(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar);
 auto element_count = shape_size(second.get_shape());
@@ -101,12 +102,13 @@ namespace ngraph
 }
 if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
-    reduction_axes_count == 1)
+    reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
+    is_optimized_et(args[1].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::dot_1d_1d_1rd<float>)> kernel;
-    SELECT_KERNEL(
-        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd)
+    SELECT_ETS(
+        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd);
     auto functor = [&,
                     kernel,
@@ -130,12 +132,13 @@ namespace ngraph
 }
 if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
-    reduction_axes_count == 1)
+    reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
+    is_optimized_et(args[1].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::dot_2d_1d_1rd<float>)> kernel;
-    SELECT_KERNEL(
-        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd)
+    SELECT_ETS(
+        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd);
     auto functor = [&,
                     kernel,
@@ -159,12 +162,13 @@ namespace ngraph
 }
 if ((arg0_shape.size() == 1) && (arg1_shape.size() == 2) &&
-    reduction_axes_count == 1)
+    reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
+    is_optimized_et(args[1].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::dot_1d_2d_1rd<float>)> kernel;
-    SELECT_KERNEL(
-        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd)
+    SELECT_ETS(
+        kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd);
     auto functor = [&,
                     kernel,
......
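Note: in the dot builder above, the optimized Eigen paths are now entered only when is_optimized_et passes, and the kernel is picked with SELECT_ETS instead of SELECT_KERNEL. The selection macros themselves live in a collapsed header; as a rough, assumed illustration only, a call such as SELECT_ETS(kernel, et, fn) is expected to boil down to an element-type dispatch along the lines of the fragment below, with a branch compiled in only for the configured datatypes.

// Assumed expansion sketch, not taken from the commit: roughly what
// SELECT_ETS(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar)
// could expand to inside the builder. All kernel instantiations share a
// void*-based signature, so one std::function type can hold any of them.
std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
const auto& et = out[0].get_element_type();
if (false)
{
}
#ifdef NGRAPH_CPU_OPTIMIZE_f32
else if (et == element::f32)
{
    kernel = runtime::cpu::kernel::dot_scalar<float>;
}
#endif
#ifdef NGRAPH_CPU_OPTIMIZE_i64
else if (et == element::i64)
{
    kernel = runtime::cpu::kernel::dot_scalar<int64_t>;
}
#endif
else
{
    throw ngraph_error("Unsupported element type in SELECT_ETS");
}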
@@ -57,16 +57,17 @@ namespace ngraph
     args[0].get_element_type() == element::f64 ||
     args[0].get_element_type() == element::u8 ||
     args[0].get_element_type() == element::i8) &&
-    params_shape.size() <= 3 && out_shape.size() <= 5)
+    params_shape.size() <= 3 && out_shape.size() <= 5 &&
+    is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
         kernel;
-    SELECT_KERNEL_BY_2RANKS(kernel,
+    SELECT_RANK35_ET4(kernel,
         args[0].get_element_type(),
         params_shape.size(),
         out_shape.size(),
-        runtime::cpu::kernel::gather_i64)
+        runtime::cpu::kernel::gather_i64);
     return [&,
             kernel,
@@ -117,16 +118,17 @@ namespace ngraph
     args[0].get_element_type() == element::f64 ||
     args[0].get_element_type() == element::u8 ||
     args[0].get_element_type() == element::i8) &&
-    params_shape.size() <= 3 && out_shape.size() <= 5)
+    params_shape.size() <= 3 && out_shape.size() <= 5 &&
+    is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
         kernel;
-    SELECT_KERNEL_BY_2RANKS(kernel,
+    SELECT_RANK35_ET4(kernel,
         args[0].get_element_type(),
         params_shape.size(),
         out_shape.size(),
-        runtime::cpu::kernel::gather_i32)
+        runtime::cpu::kernel::gather_i32);
     return [&,
             kernel,
@@ -188,6 +190,10 @@ namespace ngraph
 {
     functor = prepare_functor<float>(node, args, out, external_function);
 }
+else if (element_type == element::i64)
+{
+    functor = prepare_functor<int64_t>(node, args, out, external_function);
+}
 else if (element_type == element::f64)
 {
     functor = prepare_functor<double>(node, args, out, external_function);
@@ -204,10 +210,6 @@ namespace ngraph
 {
     functor = prepare_functor<int32_t>(node, args, out, external_function);
 }
-else if (element_type == element::i64)
-{
-    functor = prepare_functor<int64_t>(node, args, out, external_function);
-}
 else if (element_type == element::u8)
 {
     functor = prepare_functor<uint8_t>(node, args, out, external_function);
......
@@ -50,14 +50,15 @@ namespace ngraph
 auto padding_above = pad->get_padding_above();
 auto pad_mode = pad->get_pad_mode();
-if (pad_mode == ngraph::op::PadMode::CONSTANT)
+if (pad_mode == ngraph::op::PadMode::CONSTANT &&
+    is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_ETS_AND_RANK7(kernel,
         args[0].get_element_type(),
         arg_shape.size(),
-        runtime::cpu::kernel::pad_and_slice)
+        runtime::cpu::kernel::pad_and_slice);
     auto functor = [&,
                     kernel,
@@ -122,14 +123,15 @@ namespace ngraph
 auto padding_above = pad->get_padding_above();
 auto pad_mode = pad->get_pad_mode();
-if (pad_mode == ngraph::op::PadMode::CONSTANT)
+if (pad_mode == ngraph::op::PadMode::CONSTANT &&
+    is_optimized_et(pad->get_input_element_type(0)))
 {
     std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_ETS_AND_RANK7(kernel,
         pad->get_input_element_type(0),
         arg_shape.size(),
-        runtime::cpu::kernel::pad_and_slice)
+        runtime::cpu::kernel::pad_and_slice);
     auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
         const std::vector<void*>& inputs, std::vector<void*>& outputs) {
......
@@ -41,10 +41,10 @@
 return; \
 } \
 \
-if (reduction_axes.size() == arg_rank) \
+if (reduction_axes.size() == arg_rank && is_optimized_et(args[0].get_element_type())) \
 { \
     std::function<decltype(runtime::cpu::kernel::reduce_##K##_all<float, 2>)> kernel; \
-    SELECT_KERNEL_BY_RANK( \
+    SELECT_ETS_AND_RANK7( \
         kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_all); \
     auto functor = [&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
         CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
@@ -58,16 +58,16 @@
 return; \
 } \
 \
-if (reduction_axes.size() == 1) \
+if (reduction_axes.size() == 1 && is_optimized_et(args[0].get_element_type())) \
 { \
     if (*reduction_axes.begin() == arg_rank - 1) \
     { \
         std::function<decltype(runtime::cpu::kernel::reduce_##K##_innermost_1rd<float, 2>)> \
             kernel; \
-        SELECT_KERNEL_BY_RANK(kernel, \
+        SELECT_ETS_AND_RANK7(kernel, \
             result_element_type, \
             arg_rank, \
             runtime::cpu::kernel::reduce_##K##_innermost_1rd); \
         auto functor = \
             [&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
                 CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
@@ -82,7 +82,7 @@
 } \
 \
 std::function<decltype(runtime::cpu::kernel::reduce_##K##_1rd<float, 2>)> kernel; \
-SELECT_KERNEL_BY_RANK( \
+SELECT_ETS_AND_RANK7( \
     kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_1rd); \
 auto functor = [&, \
                 kernel, \
@@ -102,10 +102,11 @@
 return; \
 } \
 \
-if (reduction_axes.size() == 2 && arg_rank == 3) \
+if (reduction_axes.size() == 2 && arg_rank == 3 && \
+    is_optimized_et(args[0].get_element_type())) \
 { \
     std::function<decltype(runtime::cpu::kernel::reduce_##K##_3d_2rd<float>)> kernel; \
-    SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
+    SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
     auto functor = [&, \
                     kernel, \
                     arg_shape, \
@@ -124,10 +125,11 @@
 return; \
 } \
 \
-if (reduction_axes.size() == 2 && arg_rank == 4) \
+if (reduction_axes.size() == 2 && arg_rank == 4 && \
+    is_optimized_et(args[0].get_element_type())) \
 { \
     std::function<decltype(runtime::cpu::kernel::reduce_##K##_4d_2rd<float>)> kernel; \
-    SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
+    SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
     auto functor = [&, \
                     kernel, \
                     arg_shape, \
@@ -146,10 +148,11 @@
 return; \
 } \
 \
-if (reduction_axes.size() == 2 && arg_rank == 5) \
+if (reduction_axes.size() == 2 && arg_rank == 5 && \
+    is_optimized_et(args[0].get_element_type())) \
 { \
     std::function<decltype(runtime::cpu::kernel::reduce_##K##_5d_2rd<float>)> kernel; \
-    SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
+    SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
     auto functor = [&, \
                     kernel, \
                     arg_shape, \
......
@@ -43,6 +43,7 @@ namespace ngraph
 auto arg0_shape = args[0].get_shape();
 auto arg1_shape = args[1].get_shape();
+auto out_shape = out[0].get_shape();
 auto strides = replace_slice->get_strides();
 auto lower_bounds = replace_slice->get_lower_bounds();
@@ -71,15 +72,15 @@ namespace ngraph
 return;
 }
-if (strided)
+if (strided && is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::strided_replace_slice<float, 2>)>
         kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_ETS_AND_RANK7(kernel,
         args[0].get_element_type(),
         arg0_shape.size(),
-        runtime::cpu::kernel::strided_replace_slice)
+        runtime::cpu::kernel::strided_replace_slice);
     auto functor = [&,
                     kernel,
@@ -104,14 +105,14 @@ namespace ngraph
     };
     functors.emplace_back(functor);
 }
-else
+else if (is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::replace_slice<float, 2>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_ETS_AND_RANK7(kernel,
         args[0].get_element_type(),
         arg0_shape.size(),
-        runtime::cpu::kernel::replace_slice)
+        runtime::cpu::kernel::replace_slice);
     auto functor = [&,
                     kernel,
@@ -132,6 +133,34 @@ namespace ngraph
     };
     functors.emplace_back(functor);
 }
+else
+{
+    std::function<decltype(runtime::cpu::kernel::ref_replace_slice<float>)> kernel;
+    SELECT_KERNEL(kernel,
+        args[0].get_element_type(),
+        runtime::cpu::kernel::ref_replace_slice);
+    auto functor = [&,
+                    kernel,
+                    arg1_shape,
+                    out_shape,
+                    lower_bounds,
+                    upper_bounds,
+                    strides,
+                    arg0_buffer_index,
+                    arg1_buffer_index,
+                    out_buffer_index](CPURuntimeContext* ctx,
+                                      CPUExecutionContext* /*ectx*/) {
+        kernel(ctx->buffer_data[arg0_buffer_index],
+               ctx->buffer_data[arg1_buffer_index],
+               ctx->buffer_data[out_buffer_index],
+               arg1_shape,
+               lower_bounds,
+               upper_bounds,
+               strides,
+               out_shape);
+    };
+    functors.emplace_back(functor);
+}
 }
 void register_builders_replace_slice_cpp() { REGISTER_OP_BUILDER(ReplaceSlice); }
......
@@ -81,25 +81,25 @@ namespace ngraph
 return;
 }
-if (arg_rank == 1)
+if (arg_rank == 1 && is_optimized_et(result_element_type))
 {
-    SELECT_KERNEL_BY_RANK(
-        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d)
+    SELECT_ETS_AND_RANK7(
+        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d);
 }
-else if (arg_rank == 2)
+else if (arg_rank == 2 && is_optimized_et(result_element_type))
 {
-    SELECT_KERNEL_BY_RANK(
-        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d)
+    SELECT_ETS_AND_RANK7(
+        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d);
 }
-else if (arg_rank == 3)
+else if (arg_rank == 3 && is_optimized_et(result_element_type))
 {
-    SELECT_KERNEL_BY_RANK(
-        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d)
+    SELECT_ETS_AND_RANK7(
+        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d);
 }
-else if (arg_rank == 4)
+else if (arg_rank == 4 && is_optimized_et(result_element_type))
 {
-    SELECT_KERNEL_BY_RANK(
-        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d)
+    SELECT_ETS_AND_RANK7(
+        kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d);
 }
 else
 {
......
@@ -47,7 +47,7 @@ namespace ngraph
 if (args[1].get_element_type() == element::i32)
 {
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_KERNEL_ET_RANK(kernel,
         args[0].get_element_type(),
         arg_shape.size(),
         runtime::cpu::kernel::reverse_sequence_sli32)
......
@@ -61,18 +61,18 @@ namespace ngraph
 auto out_shape = out[0].get_shape();
 auto element_type = args[0].get_element_type();
-if (is_int64)
+if (is_int64 && is_optimized_et(args[0].get_element_type()))
 {
     if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
     {
         std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
             kernel;
-        SELECT_KERNEL_BY_2RANKS(kernel,
+        SELECT_RANK35_ET4(kernel,
            args[0].get_element_type(),
            inputs_shape.size(),
            updates_shape.size(),
-           runtime::cpu::kernel::scatter_add_i64)
+           runtime::cpu::kernel::scatter_add_i64);
         auto functor = [&,
                         kernel,
@@ -100,18 +100,18 @@ namespace ngraph
         throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
     }
 }
-else
+else if (is_optimized_et(args[0].get_element_type()))
 {
     if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
     {
         std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
             kernel;
-        SELECT_KERNEL_BY_2RANKS(kernel,
+        SELECT_RANK35_ET4(kernel,
            args[0].get_element_type(),
            inputs_shape.size(),
            updates_shape.size(),
-           runtime::cpu::kernel::scatter_add_i32)
+           runtime::cpu::kernel::scatter_add_i32);
         auto functor = [&,
                         kernel,
@@ -139,6 +139,67 @@ namespace ngraph
         throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
     }
 }
+else if (is_int64)
+{
+    std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i64<float>)>
+        kernel;
+    SELECT_KERNEL(kernel,
+        args[0].get_element_type(),
+        runtime::cpu::kernel::ref_scatter_add_i64);
+    auto functor = [&,
+                    kernel,
+                    inputs_shape,
+                    indices_shape,
+                    updates_shape,
+                    out_shape,
+                    inputs_buffer_index,
+                    indices_buffer_index,
+                    updates_buffer_index,
+                    out_buffer_index](CPURuntimeContext* ctx,
+                                      CPUExecutionContext* /*ectx*/) {
+        kernel(ctx->buffer_data[inputs_buffer_index],
+               ctx->buffer_data[indices_buffer_index],
+               ctx->buffer_data[updates_buffer_index],
+               ctx->buffer_data[out_buffer_index],
+               inputs_shape,
+               indices_shape,
+               updates_shape,
+               out_shape);
+    };
+    functors.emplace_back(functor);
+}
+else
+{
+    std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i32<float>)>
+        kernel;
+    SELECT_KERNEL(kernel,
+        args[0].get_element_type(),
+        runtime::cpu::kernel::ref_scatter_add_i32);
+    auto functor = [&,
+                    kernel,
+                    inputs_shape,
+                    indices_shape,
+                    updates_shape,
+                    out_shape,
+                    inputs_buffer_index,
+                    indices_buffer_index,
+                    updates_buffer_index,
+                    out_buffer_index](CPURuntimeContext* ctx,
+                                      CPUExecutionContext* /*ectx*/) {
+        kernel(ctx->buffer_data[inputs_buffer_index],
+               ctx->buffer_data[indices_buffer_index],
+               ctx->buffer_data[updates_buffer_index],
+               ctx->buffer_data[out_buffer_index],
+               inputs_shape,
+               indices_shape,
+               updates_shape,
+               out_shape);
+    };
+    functors.emplace_back(functor);
+}
 }
 void register_builders_scatter_add_cpp() { REGISTER_OP_BUILDER(ScatterAdd); }
......
@@ -133,15 +133,15 @@ namespace ngraph
 }
 else
 {
-    if (is_strided(strides))
+    if (is_strided(strides) && is_optimized_et(args[0].get_element_type()))
     {
         std::function<decltype(runtime::cpu::kernel::strided_slice<float, 2>)>
             kernel;
-        SELECT_KERNEL_BY_RANK(kernel,
+        SELECT_ETS_AND_RANK7(kernel,
             args[0].get_element_type(),
             arg_shape.size(),
-            runtime::cpu::kernel::strided_slice)
+            runtime::cpu::kernel::strided_slice);
         auto functor = [&,
                         kernel,
@@ -164,14 +164,14 @@ namespace ngraph
     };
     functors.emplace_back(functor);
 }
-else
+else if (is_optimized_et(args[0].get_element_type()))
 {
     std::function<decltype(runtime::cpu::kernel::slice<float, 2>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_ETS_AND_RANK7(kernel,
         args[0].get_element_type(),
         arg_shape.size(),
-        runtime::cpu::kernel::slice)
+        runtime::cpu::kernel::slice);
     auto functor = [&,
                     kernel,
@@ -190,6 +190,31 @@ namespace ngraph
     };
     functors.emplace_back(functor);
 }
+else
+{
+    std::function<decltype(runtime::cpu::kernel::ref_slice<float>)> kernel;
+    SELECT_KERNEL(
+        kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_slice);
+    auto functor = [&,
+                    kernel,
+                    arg_shape,
+                    out_shape,
+                    lower_bounds,
+                    upper_bounds,
+                    strides,
+                    arg_buffer_index,
+                    out_buffer_index](CPURuntimeContext* ctx,
+                                      CPUExecutionContext* /*ectx*/) {
+        kernel(ctx->buffer_data[arg_buffer_index],
+               ctx->buffer_data[out_buffer_index],
+               arg_shape,
+               lower_bounds,
+               upper_bounds,
+               strides,
+               out_shape);
+    };
+    functors.emplace_back(functor);
+}
 }
 }
......
@@ -75,17 +75,18 @@ namespace ngraph
     ctx, softmax_index, deps, cpu::mkldnn_utils::OpType::SOFTMAX);
 };
 functors.emplace_back(functor);
+return;
 }
-else
+else if (is_optimized_et(args[0].get_element_type()))
 {
     if (axes.size() == arg_shape.size())
     {
         std::function<decltype(runtime::cpu::kernel::softmax_all<float, 1>)> kernel;
-        PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
+        SELECT_ETS_AND_RANK7(kernel,
             args[0].get_element_type(),
             args[0].get_shape().size(),
-            runtime::cpu::kernel::softmax_all)
+            runtime::cpu::kernel::softmax_all);
         auto functor = [&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
             CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
@@ -95,6 +96,7 @@ namespace ngraph
             ectx->arena);
         };
         functors.emplace_back(functor);
+        return;
     }
     else if (axes.size() == 1)
     {
@@ -104,11 +106,10 @@ namespace ngraph
             runtime::cpu::kernel::softmax_innermost_1rd<float, 1>)>
             kernel;
-        PARTIAL_SELECT_KERNEL_BY_RANK(
-            kernel,
-            args[0].get_element_type(),
-            args[0].get_shape().size(),
-            runtime::cpu::kernel::softmax_innermost_1rd)
+        SELECT_ETS_AND_RANK7(kernel,
+            args[0].get_element_type(),
+            args[0].get_shape().size(),
+            runtime::cpu::kernel::softmax_innermost_1rd);
         auto functor =
             [&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
@@ -119,16 +120,17 @@ namespace ngraph
             ectx->arena);
         };
         functors.emplace_back(functor);
+        return;
     }
     else
     {
         std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)>
             kernel;
-        PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
+        SELECT_ETS_AND_RANK7(kernel,
            args[0].get_element_type(),
            args[0].get_shape().size(),
-           runtime::cpu::kernel::softmax_1rd)
+           runtime::cpu::kernel::softmax_1rd);
         auto functor =
             [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
@@ -140,15 +142,16 @@ namespace ngraph
             ectx->arena);
         };
         functors.emplace_back(functor);
+        return;
     }
 }
 else if (arg_shape.size() == 3 && axes.size() == 2)
 {
     std::function<decltype(runtime::cpu::kernel::softmax_3d_2rd<float>)> kernel;
-    SELECT_KERNEL(kernel,
+    SELECT_ETS(kernel,
         args[0].get_element_type(),
-        runtime::cpu::kernel::softmax_3d_2rd)
+        runtime::cpu::kernel::softmax_3d_2rd);
     auto functor =
         [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
@@ -160,14 +163,15 @@ namespace ngraph
         ectx->arena);
     };
     functors.emplace_back(functor);
+    return;
 }
 else if (arg_shape.size() == 4 && axes.size() == 3)
 {
     std::function<decltype(runtime::cpu::kernel::softmax_4d_3rd<float>)> kernel;
-    SELECT_KERNEL(kernel,
+    SELECT_ETS(kernel,
         args[0].get_element_type(),
-        runtime::cpu::kernel::softmax_4d_3rd)
+        runtime::cpu::kernel::softmax_4d_3rd);
     auto functor =
         [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
@@ -179,28 +183,22 @@ namespace ngraph
         ectx->arena);
     };
     functors.emplace_back(functor);
+    return;
 }
-else if (softmax->get_element_type() == element::f32)
-{
-    NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
-                << " over " << axes;
-    auto functor = [&, arg_shape, axes, arg_buffer_index, out_buffer_index](
-        CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
-        runtime::reference::softmax<float>(
-            static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
-            static_cast<float*>(ctx->buffer_data[out_buffer_index]),
-            arg_shape,
-            axes);
-    };
-    functors.emplace_back(functor);
-}
-else
-{
-    NGRAPH_ERR << "Unsupported Softmax " << arg_shape << " over " << axes
-               << " in cpu buiilder";
-    throw ngraph_error("Unsupported Softmax");
-}
 }
+NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
+            << " over " << axes;
+std::function<decltype(runtime::cpu::kernel::ref_softmax<float>)> kernel;
+SELECT_KERNEL(
+    kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_softmax);
+auto functor = [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
+    CPURuntimeContext* ctx, CPUExecutionContext* /*ectx*/) {
+    kernel(ctx->buffer_data[arg_buffer_index],
+           ctx->buffer_data[out_buffer_index],
+           arg_shape,
+           axes);
+};
+functors.emplace_back(functor);
 }
 void register_builders_softmax_cpp() { REGISTER_OP_BUILDER(Softmax); }
......
@@ -60,8 +60,8 @@ namespace ngraph
 else
 {
     std::function<decltype(runtime::cpu::kernel::tile<float, 2>)> kernel;
-    SELECT_KERNEL_BY_RANK(
-        kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile)
+    SELECT_KERNEL_ET_RANK(
+        kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile);
     auto functor =
         [&, kernel, arg_shape, out_shape, arg_buffer_index, out_buffer_index](
             CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
......
@@ -66,7 +66,7 @@ namespace ngraph
 std::function<decltype(runtime::cpu::kernel::strided_update_slice<float, 2>)>
     kernel;
-SELECT_KERNEL_BY_RANK(kernel,
+SELECT_KERNEL_ET_RANK(kernel,
     args[0].get_element_type(),
     arg0_shape.size(),
     runtime::cpu::kernel::strided_update_slice)
@@ -98,7 +98,7 @@ namespace ngraph
 {
     std::function<decltype(runtime::cpu::kernel::update_slice<float, 2>)> kernel;
-    SELECT_KERNEL_BY_RANK(kernel,
+    SELECT_KERNEL_ET_RANK(kernel,
         args[0].get_element_type(),
         arg0_shape.size(),
         runtime::cpu::kernel::update_slice)
......
This diff is collapsed.
@@ -21,6 +21,7 @@
 #include "ngraph/coordinate.hpp"
 #include "ngraph/runtime/cpu/cpu_executor.hpp"
+#include "ngraph/runtime/reference/replace_slice.hpp"
 #include "ngraph/shape.hpp"
 namespace ngraph
@@ -100,6 +101,26 @@ namespace ngraph
     .device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
         arena)) = in1;
 }
+template <typename ElementType>
+void ref_replace_slice(void* input0,
+                       void* input1,
+                       void* output,
+                       const Shape& input0_shape,
+                       const Coordinate& lower_bounds,
+                       const Coordinate& upper_bounds,
+                       const Strides& slice_strides,
+                       const Shape& output_shape)
+{
+    reference::replace_slice<ElementType>(static_cast<const ElementType*>(input0),
+                                          static_cast<const ElementType*>(input1),
+                                          static_cast<ElementType*>(output),
+                                          input0_shape,
+                                          lower_bounds,
+                                          upper_bounds,
+                                          slice_strides,
+                                          output_shape);
+}
 }
 }
 }
......
@@ -21,6 +21,7 @@
 #include "ngraph/coordinate.hpp"
 #include "ngraph/runtime/cpu/cpu_executor.hpp"
+#include "ngraph/runtime/reference/scatter_add.hpp"
 #include "ngraph/shape.hpp"
 namespace ngraph
@@ -168,6 +169,46 @@ namespace ngraph
     updates_shape,
     arena);
 }
+template <typename ElementType>
+void ref_scatter_add_i32(void* inputs,
+                         void* indices,
+                         void* updates,
+                         void* output,
+                         const Shape& inputs_shape,
+                         const Shape& indices_shape,
+                         const Shape& updates_shape,
+                         const Shape& output_shape)
+{
+    reference::scatter_add<ElementType, int32_t>(static_cast<ElementType*>(inputs),
+                                                 static_cast<int32_t*>(indices),
+                                                 static_cast<ElementType*>(updates),
+                                                 static_cast<ElementType*>(output),
+                                                 inputs_shape,
+                                                 indices_shape,
+                                                 updates_shape,
+                                                 output_shape);
+}
+template <typename ElementType>
+void ref_scatter_add_i64(void* inputs,
+                         void* indices,
+                         void* updates,
+                         void* output,
+                         const Shape& inputs_shape,
+                         const Shape& indices_shape,
+                         const Shape& updates_shape,
+                         const Shape& output_shape)
+{
+    reference::scatter_add<ElementType, int64_t>(static_cast<ElementType*>(inputs),
+                                                 static_cast<int64_t*>(indices),
+                                                 static_cast<ElementType*>(updates),
+                                                 static_cast<ElementType*>(output),
+                                                 inputs_shape,
+                                                 indices_shape,
+                                                 updates_shape,
+                                                 output_shape);
+}
 }
 }
 }
......
@@ -21,6 +21,7 @@
 #include "ngraph/coordinate.hpp"
 #include "ngraph/runtime/cpu/cpu_executor.hpp"
+#include "ngraph/runtime/reference/slice.hpp"
 #include "ngraph/shape.hpp"
 namespace ngraph
@@ -88,6 +89,24 @@ namespace ngraph
 out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
     in.stridedSlice(start_indices, stop_indices, strides);
 }
+template <typename ElementType>
+void ref_slice(void* input,
+               void* output,
+               const Shape& input_shape,
+               const Coordinate& lower_bounds,
+               const Coordinate& upper_bounds,
+               const Strides& slice_strides,
+               const Shape& output_shape)
+{
+    reference::slice<ElementType>(static_cast<const ElementType*>(input),
+                                  static_cast<ElementType*>(output),
+                                  input_shape,
+                                  lower_bounds,
+                                  upper_bounds,
+                                  slice_strides,
+                                  output_shape);
+}
 }
 }
 }
......
@@ -21,6 +21,7 @@
 #include "ngraph/axis_set.hpp"
 #include "ngraph/runtime/cpu/cpu_executor.hpp"
+#include "ngraph/runtime/reference/softmax.hpp"
 #include "ngraph/shape.hpp"
 namespace ngraph
@@ -164,6 +165,15 @@ namespace ngraph
 {
     softmax<ElementType, 4, 3>(input, output, input_shape, softmax_axes, arena);
 }
+template <typename ElementType>
+void ref_softmax(void* input, void* output, const Shape& shape, const AxisSet& axes)
+{
+    reference::softmax<ElementType>(static_cast<const ElementType*>(input),
+                                    static_cast<ElementType*>(output),
+                                    shape,
+                                    axes);
+}
 }
 }
 }
......
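Note: the ref_* wrappers added in these kernel headers (ref_replace_slice, ref_scatter_add_i32/_i64, ref_slice, ref_softmax) all follow one pattern: a void* entry point that casts to the element type and forwards to the generic implementation under ngraph/runtime/reference, so element types without optimized kernels still execute correctly. Below is a hedged usage sketch for the softmax wrapper only; the data values and the header path are assumptions for illustration, not part of the commit.

// Sketch only: calling the new ref_softmax wrapper directly. In the backend
// the call is routed through SELECT_KERNEL in the Softmax builder; the values
// and the include path below are assumed for illustration.
#include <vector>
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/kernel/softmax.hpp"
#include "ngraph/shape.hpp"

void softmax_reference_fallback_example()
{
    ngraph::Shape shape{2, 3};
    ngraph::AxisSet axes{1};              // softmax over the last axis
    std::vector<double> input{1, 2, 3, 4, 5, 6};
    std::vector<double> output(ngraph::shape_size(shape));

    // The template argument would normally be chosen by the element-type
    // dispatch; double is hard-coded here to keep the sketch short.
    ngraph::runtime::cpu::kernel::ref_softmax<double>(
        input.data(), output.data(), shape, axes);
}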
This diff is collapsed.