Commit e2f55f83 authored by Nishant Patel, committed by Scott Cyphers

Add a compile-time flag to support all ETs in cpu backend (#3350)

* Refactor cpu/builder/dot.cpp

* Support only float and i64

* Refactor softmax and reduction to support only float and i64

* Add a new file for customized kernels

* Add compile time flag NGRAPH_CPU_LARGE_BINARY

* Add compilation flag checks in reduction.hpp

* change dot to only support float and i64

* Revert "Refactor cpu/builder/dot.cpp"

This reverts commit 0d53f27fde64872aff096f12ee9b79e5db7a7fee.

* style

* Consolidate macros

* Refactor slice ops, reshape & pad

* cleanup

* Gather op

* Concat, Convert and broadcast op

* Tile and reverse_sequence op

* Scatter_add op

* ETs other than fp and i64 go through ref

* tests passing

* Consolidate macros

* Address feedback

* Fall back to reference for pad and reshape

* add scatter_add reference fallback

* Undo concat change

* Undo tile and update slice as they don't have reference implementations

* Remove update slice condition

* Gather op type check

* change routine name

* Build-time configurability for datatypes that require optimized kernels on the CPU backend (#3592); see the sketch just below the commit message

* Inline function

* Change condition

* VS compiler workaround

* Add comment for VS workaround

* More fixes for VS

* More wrapping of nested macros for VS

* More wrapper

* variable name change and more wrapping of macros

* test

* Style and refactor

* Wrap macros

* Add a separate macro for SELECT_KERNEL

* Change SELECT_KERNEL_3ARGS

* Unwrap couple of macros

* Syntax

* Add a new macro for fixed number of args for VS compiler

* Comment/Fake ectx

* Add detailed comment for the workaround for VS compiler

* Comment all unused ectx variables

* Templated calls to reference kernels

* const args

* Change softmax ref definition to take double

* Hardcode softmax kernel to take double ..testing

* Fix softmax
parent b39d0aab
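The hunks below wire build-time datatype selection through CMake and the CPU builders: NGRAPH_CPU_OPTIMIZED_DATATYPES (a semicolon-separated list, or "common"/"all", e.g. configured with -DNGRAPH_CPU_OPTIMIZED_DATATYPES="f32;i64") is expanded into one NGRAPH_CPU_OPTIMIZE_<type> compile definition per datatype, and the builders consult an is_optimized_et() check before binding an optimized kernel, otherwise taking a reference-kernel fallback. The helper itself lives in a header whose diff is collapsed on this page; the snippet below is only a minimal sketch of that pattern, not the actual implementation.

// Minimal sketch (not the real helper, whose diff is collapsed on this page): how the
// NGRAPH_CPU_OPTIMIZE_<type> compile definitions generated by CMake could back an
// is_optimized_et() check used throughout the CPU builders.
#include "ngraph/type/element_type.hpp"

inline bool is_optimized_et(const ngraph::element::Type& et)
{
#if defined(NGRAPH_CPU_OPTIMIZE_f32)
    if (et == ngraph::element::f32) return true;
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_f64)
    if (et == ngraph::element::f64) return true;
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i64)
    if (et == ngraph::element::i64) return true;
#endif
    // ...one such branch per entry in NGRAPH_CPU_ALL_DATATYPES...
    // Datatype not selected at build time: builders fall back to reference kernels.
    return false;
}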
......@@ -156,9 +156,42 @@ if (NGRAPH_MLIR_ENABLE)
)
endif()
set(NGRAPH_CPU_ALL_DATATYPES
boolean
f32
f64
i8
i16
i32
i64
u8
u16
u32
u64
)
set(NGRAPH_CPU_COMMON_DATATYPES
f32
i64
)
if (NGRAPH_CPU_ENABLE)
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES "common"
CACHE STRING "Semicolon-separated list of datatypes to optimize for, or \"common\" or \"all\".")
if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "all")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_ALL_DATATYPES})
endif()
if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "common")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_COMMON_DATATYPES})
endif()
list(REMOVE_DUPLICATES NGRAPH_CPU_OPTIMIZED_DATATYPES)
add_library(cpu_backend ${LIBRARY_TYPE} ${SRC})
if (NGRAPH_CPU_STATIC_LIB_ENABLE)
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_STATIC_LIB_ENABLE")
......@@ -209,6 +242,10 @@ if (NGRAPH_CPU_ENABLE)
endif()
target_compile_definitions(cpu_backend PRIVATE CPU_BACKEND_DLL_EXPORTS)
foreach(t ${NGRAPH_CPU_OPTIMIZED_DATATYPES})
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_OPTIMIZE_${t}")
endforeach()
add_dependencies(cpu_backend libmkldnn ext_eigen)
target_link_libraries(cpu_backend PUBLIC ngraph libmkldnn libmkl libeigen libtbb)
......
......@@ -57,8 +57,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -80,8 +80,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmax<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -107,8 +107,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -131,8 +131,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -158,8 +158,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -181,8 +181,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmax<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......
......@@ -57,8 +57,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -80,8 +80,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmin<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -107,8 +107,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -131,8 +131,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -158,8 +158,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -181,8 +181,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmin<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......
......@@ -155,7 +155,7 @@ namespace ngraph
}
}
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
broadcast->get_input_element_type(0),
out_rank,
runtime::cpu::kernel::broadcast)
......
......@@ -149,7 +149,7 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::concat<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
out[0].get_element_type(),
out[0].get_shape().size(),
runtime::cpu::kernel::concat)
......
......@@ -67,7 +67,9 @@ namespace ngraph
return;
}
if (arg0_shape.empty() || arg1_shape.empty())
if ((arg0_shape.empty() || arg1_shape.empty()) &&
is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
auto first = (arg0_shape.empty() ? args[0] : args[1]);
auto second = (arg0_shape.empty() ? args[1] : args[0]);
......@@ -78,8 +80,7 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar)
SELECT_ETS(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar);
auto element_count = shape_size(second.get_shape());
......@@ -101,12 +102,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_1d_1d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd);
auto functor = [&,
kernel,
......@@ -130,12 +132,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_2d_1d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd);
auto functor = [&,
kernel,
......@@ -159,12 +162,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 2) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_1d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd);
auto functor = [&,
kernel,
......
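The SELECT_ETS / SELECT_KERNEL / SELECT_*_RANK* macros used above are defined in a header whose diff is collapsed on this page. As a rough, hypothetical illustration of the pattern only (not the real macro bodies, and assuming the usual builder headers and namespaces are in scope), a call such as SELECT_ETS(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar) expands into a per-type dispatch of roughly this shape, with each branch compiled in only when the matching datatype was selected at build time:

// Hypothetical expansion of SELECT_ETS(kernel, et, runtime::cpu::kernel::dot_scalar).
// The CPU kernels take void* arguments, so every instantiation shares one signature
// and fits the same std::function type.
std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
#if defined(NGRAPH_CPU_OPTIMIZE_f32)
if (et == element::f32)
{
    kernel = runtime::cpu::kernel::dot_scalar<float>;
}
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i64)
if (et == element::i64)
{
    kernel = runtime::cpu::kernel::dot_scalar<int64_t>;
}
#endif
// ...remaining optimized datatypes...
// The is_optimized_et() guards added in the builders above ensure this dispatch is
// only reached for datatypes that actually have a compiled branch; other element
// types take the reference-kernel path instead.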
......@@ -57,16 +57,17 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 5)
params_shape.size() <= 3 && out_shape.size() <= 5 &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i64)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i64);
return [&,
kernel,
......@@ -117,16 +118,17 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 5)
params_shape.size() <= 3 && out_shape.size() <= 5 &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i32)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i32);
return [&,
kernel,
......@@ -188,6 +190,10 @@ namespace ngraph
{
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::f64)
{
functor = prepare_functor<double>(node, args, out, external_function);
......@@ -204,10 +210,6 @@ namespace ngraph
{
functor = prepare_functor<int32_t>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::u8)
{
functor = prepare_functor<uint8_t>(node, args, out, external_function);
......
......@@ -50,14 +50,15 @@ namespace ngraph
auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT)
if (pad_mode == ngraph::op::PadMode::CONSTANT &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice);
auto functor = [&,
kernel,
......@@ -122,14 +123,15 @@ namespace ngraph
auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT)
if (pad_mode == ngraph::op::PadMode::CONSTANT &&
is_optimized_et(pad->get_input_element_type(0)))
{
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
pad->get_input_element_type(0),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice)
SELECT_ETS_AND_RANK7(kernel,
pad->get_input_element_type(0),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice);
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
......
......@@ -41,10 +41,10 @@
return; \
} \
\
if (reduction_axes.size() == arg_rank) \
if (reduction_axes.size() == arg_rank && is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_all<float, 2>)> kernel; \
SELECT_KERNEL_BY_RANK( \
SELECT_ETS_AND_RANK7( \
kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_all); \
auto functor = [&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
......@@ -58,16 +58,16 @@
return; \
} \
\
if (reduction_axes.size() == 1) \
if (reduction_axes.size() == 1 && is_optimized_et(args[0].get_element_type())) \
{ \
if (*reduction_axes.begin() == arg_rank - 1) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_innermost_1rd<float, 2>)> \
kernel; \
SELECT_KERNEL_BY_RANK(kernel, \
result_element_type, \
arg_rank, \
runtime::cpu::kernel::reduce_##K##_innermost_1rd); \
SELECT_ETS_AND_RANK7(kernel, \
result_element_type, \
arg_rank, \
runtime::cpu::kernel::reduce_##K##_innermost_1rd); \
auto functor = \
[&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
......@@ -82,7 +82,7 @@
} \
\
std::function<decltype(runtime::cpu::kernel::reduce_##K##_1rd<float, 2>)> kernel; \
SELECT_KERNEL_BY_RANK( \
SELECT_ETS_AND_RANK7( \
kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_1rd); \
auto functor = [&, \
kernel, \
......@@ -102,10 +102,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 3) \
if (reduction_axes.size() == 2 && arg_rank == 3 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_3d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......@@ -124,10 +125,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 4) \
if (reduction_axes.size() == 2 && arg_rank == 4 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_4d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......@@ -146,10 +148,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 5) \
if (reduction_axes.size() == 2 && arg_rank == 5 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_5d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......
......@@ -43,6 +43,7 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto out_shape = out[0].get_shape();
auto strides = replace_slice->get_strides();
auto lower_bounds = replace_slice->get_lower_bounds();
......@@ -71,15 +72,15 @@ namespace ngraph
return;
}
if (strided)
if (strided && is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::strided_replace_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice);
auto functor = [&,
kernel,
......@@ -104,14 +105,14 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::replace_slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::replace_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::replace_slice);
auto functor = [&,
kernel,
......@@ -132,6 +133,34 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_replace_slice<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_replace_slice);
auto functor = [&,
kernel,
arg1_shape,
out_shape,
lower_bounds,
upper_bounds,
strides,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out_buffer_index],
arg1_shape,
lower_bounds,
upper_bounds,
strides,
out_shape);
};
functors.emplace_back(functor);
}
}
void register_builders_replace_slice_cpp() { REGISTER_OP_BUILDER(ReplaceSlice); }
......
......@@ -81,25 +81,25 @@ namespace ngraph
return;
}
if (arg_rank == 1)
if (arg_rank == 1 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d);
}
else if (arg_rank == 2)
else if (arg_rank == 2 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d);
}
else if (arg_rank == 3)
else if (arg_rank == 3 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d);
}
else if (arg_rank == 4)
else if (arg_rank == 4 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d);
}
else
{
......
......@@ -47,7 +47,7 @@ namespace ngraph
if (args[1].get_element_type() == element::i32)
{
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::reverse_sequence_sli32)
......
......@@ -61,18 +61,18 @@ namespace ngraph
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (is_int64)
if (is_int64 && is_optimized_et(args[0].get_element_type()))
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64);
auto functor = [&,
kernel,
......@@ -100,18 +100,18 @@ namespace ngraph
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32);
auto functor = [&,
kernel,
......@@ -139,6 +139,67 @@ namespace ngraph
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else if (is_int64)
{
std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i64<float>)>
kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_scatter_add_i64);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i32<float>)>
kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_scatter_add_i32);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
functors.emplace_back(functor);
}
}
void register_builders_scatter_add_cpp() { REGISTER_OP_BUILDER(ScatterAdd); }
......
......@@ -133,15 +133,15 @@ namespace ngraph
}
else
{
if (is_strided(strides))
if (is_strided(strides) && is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::strided_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::strided_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::strided_slice);
auto functor = [&,
kernel,
......@@ -164,14 +164,14 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::slice);
auto functor = [&,
kernel,
......@@ -190,6 +190,31 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_slice<float>)> kernel;
SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_slice);
auto functor = [&,
kernel,
arg_shape,
out_shape,
lower_bounds,
upper_bounds,
strides,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
lower_bounds,
upper_bounds,
strides,
out_shape);
};
functors.emplace_back(functor);
}
}
}
......
......@@ -75,17 +75,18 @@ namespace ngraph
ctx, softmax_index, deps, cpu::mkldnn_utils::OpType::SOFTMAX);
};
functors.emplace_back(functor);
return;
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
if (axes.size() == arg_shape.size())
{
std::function<decltype(runtime::cpu::kernel::softmax_all<float, 1>)> kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_all)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_all);
auto functor = [&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
......@@ -95,6 +96,7 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else if (axes.size() == 1)
{
......@@ -104,11 +106,10 @@ namespace ngraph
runtime::cpu::kernel::softmax_innermost_1rd<float, 1>)>
kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(
kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_innermost_1rd)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_innermost_1rd);
auto functor =
[&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
......@@ -119,16 +120,17 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else
{
std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)>
kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_1rd)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_1rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -140,15 +142,16 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
}
else if (arg_shape.size() == 3 && axes.size() == 2)
{
std::function<decltype(runtime::cpu::kernel::softmax_3d_2rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_3d_2rd)
SELECT_ETS(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_3d_2rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -160,14 +163,15 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else if (arg_shape.size() == 4 && axes.size() == 3)
{
std::function<decltype(runtime::cpu::kernel::softmax_4d_3rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_4d_3rd)
SELECT_ETS(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_4d_3rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -179,28 +183,22 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
}
else if (softmax->get_element_type() == element::f32)
{
NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
<< " over " << axes;
auto functor = [&, arg_shape, axes, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
runtime::reference::softmax<float>(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
arg_shape,
axes);
};
functors.emplace_back(functor);
}
else
{
NGRAPH_ERR << "Unsupported Softmax " << arg_shape << " over " << axes
<< " in cpu buiilder";
throw ngraph_error("Unsupported Softmax");
return;
}
}
NGRAPH_WARN << "Falling back to reference kernel for softmax " << arg_shape
<< " over " << axes;
std::function<decltype(runtime::cpu::kernel::ref_softmax<float>)> kernel;
SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_softmax);
auto functor = [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
axes);
};
functors.emplace_back(functor);
}
void register_builders_softmax_cpp() { REGISTER_OP_BUILDER(Softmax); }
......
......@@ -60,8 +60,8 @@ namespace ngraph
else
{
std::function<decltype(runtime::cpu::kernel::tile<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(
kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile)
SELECT_KERNEL_ET_RANK(
kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile);
auto functor =
[&, kernel, arg_shape, out_shape, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
......
......@@ -66,7 +66,7 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::strided_update_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_update_slice)
......@@ -98,7 +98,7 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::update_slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::update_slice)
......
This diff is collapsed.
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -100,6 +101,26 @@ namespace ngraph
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in1;
}
template <typename ElementType>
void ref_replace_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& slice_strides,
const Shape& output_shape)
{
reference::replace_slice<ElementType>(static_cast<const ElementType*>(input0),
static_cast<const ElementType*>(input1),
static_cast<ElementType*>(output),
input0_shape,
lower_bounds,
upper_bounds,
slice_strides,
output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -168,6 +169,46 @@ namespace ngraph
updates_shape,
arena);
}
template <typename ElementType>
void ref_scatter_add_i32(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
const Shape& output_shape)
{
reference::scatter_add<ElementType, int32_t>(static_cast<ElementType*>(inputs),
static_cast<int32_t*>(indices),
static_cast<ElementType*>(updates),
static_cast<ElementType*>(output),
inputs_shape,
indices_shape,
updates_shape,
output_shape);
}
template <typename ElementType>
void ref_scatter_add_i64(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
const Shape& output_shape)
{
reference::scatter_add<ElementType, int64_t>(static_cast<ElementType*>(inputs),
static_cast<int64_t*>(indices),
static_cast<ElementType*>(updates),
static_cast<ElementType*>(output),
inputs_shape,
indices_shape,
updates_shape,
output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -88,6 +89,24 @@ namespace ngraph
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in.stridedSlice(start_indices, stop_indices, strides);
}
template <typename ElementType>
void ref_slice(void* input,
void* output,
const Shape& input_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& slice_strides,
const Shape& output_shape)
{
reference::slice<ElementType>(static_cast<const ElementType*>(input),
static_cast<ElementType*>(output),
input_shape,
lower_bounds,
upper_bounds,
slice_strides,
output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -164,6 +165,15 @@ namespace ngraph
{
softmax<ElementType, 4, 3>(input, output, input_shape, softmax_axes, arena);
}
template <typename ElementType>
void ref_softmax(void* input, void* output, const Shape& shape, const AxisSet& axes)
{
reference::softmax<ElementType>(static_cast<const ElementType*>(input),
static_cast<ElementType*>(output),
shape,
axes);
}
}
}
}
......
This diff is collapsed.