Commit e2f55f83 authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Add a compile-time flag to support all ETs in the CPU backend (#3350)

* Refactor cpu/builder/dot.cpp

* Support only float and i64

* Refactor softmax and reduction to support only float and i64

* Add a new file for customized kernels

* Add compile time flag NGRAPH_CPU_LARGE_BINARY

* Add compilation flag checks in reduction.hpp

* change dot to only support float and i64

* Revert "Refactor cpu/builder/dot.cpp"

This reverts commit 0d53f27fde64872aff096f12ee9b79e5db7a7fee.

* style

* Consolidate macros

* Refactor slice ops, reshape & pad

* cleanup

* Gather op

* Concat, Convert and broadcast op

* Tile and reverse_sequence op

* Scatter_add op

* ETs other than fp and i64 go through ref

* tests passing

* Consolidate macros

* Address feedback

* Fall back to reference for pad and reshape

* add scatter_add reference fallback

* Undo concat change

* Undo tile and update slice as they dont have reference implementation

* Remove update slice condition

* Gather op type check

* change routine name

* Build-time configurability for datatypes that require optimized kernels on the CPU backend (#3592)

* Inline function

* Change condition

* VS compiler workaround

* Add comment for VS workaround

* More fixes for VS

* More wrapping of nested macros for VS

* More wrapper

* variable name change and more wrapping of macros

* test

* Style and refactor

* Wrap macros

* Add a separate macro for SELECT_KERNEL

* Change SELECT_KERNEL_3ARGS

* Unwrap couple of macros

* Syntax

* Add a new macro for fixed number of args for VS compiler

* Comment/Fake ectx

* Add detailed comment for the workaround for VS compiler

* Comment all unused ectx variables

* Templated calls to reference kernels

* const args

* Change softmax ref definition to take double

* Hardcode softmax kernel to take double ..testing

* Fix softmax
parent b39d0aab
......@@ -156,9 +156,42 @@ if (NGRAPH_MLIR_ENABLE)
)
endif()
set(NGRAPH_CPU_ALL_DATATYPES
boolean
f32
f64
i8
i16
i32
i64
u8
u16
u32
u64
)
set(NGRAPH_CPU_COMMON_DATATYPES
f32
i64
)
if (NGRAPH_CPU_ENABLE)
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES "common"
CACHE STRING "Semicolon-separated list of datatypes to optimize for, or \"common\" or \"all\".")
if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "all")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_ALL_DATATYPES})
endif()
if (NGRAPH_CPU_OPTIMIZED_DATATYPES STREQUAL "common")
set(NGRAPH_CPU_OPTIMIZED_DATATYPES ${NGRAPH_CPU_COMMON_DATATYPES})
endif()
list(REMOVE_DUPLICATES NGRAPH_CPU_OPTIMIZED_DATATYPES)
add_library(cpu_backend ${LIBRARY_TYPE} ${SRC})
if (NGRAPH_CPU_STATIC_LIB_ENABLE)
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_STATIC_LIB_ENABLE")
......@@ -209,6 +242,10 @@ if (NGRAPH_CPU_ENABLE)
endif()
target_compile_definitions(cpu_backend PRIVATE CPU_BACKEND_DLL_EXPORTS)
foreach(t ${NGRAPH_CPU_OPTIMIZED_DATATYPES})
target_compile_definitions(cpu_backend PRIVATE "NGRAPH_CPU_OPTIMIZE_${t}")
endforeach()
add_dependencies(cpu_backend libmkldnn ext_eigen)
target_link_libraries(cpu_backend PUBLIC ngraph libmkldnn libmkl libeigen libtbb)
......
......@@ -57,8 +57,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -80,8 +80,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmax<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -107,8 +107,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -131,8 +131,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -158,8 +158,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmax<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......@@ -181,8 +181,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmax<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax)
SELECT_KERNEL_RANK(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&,
kernel,
......
......@@ -57,8 +57,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -80,8 +80,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmin<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -107,8 +107,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -131,8 +131,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -158,8 +158,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::argmin<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......@@ -181,8 +181,8 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::argmin<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin)
SELECT_KERNEL_RANK(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&,
kernel,
......
......@@ -155,7 +155,7 @@ namespace ngraph
}
}
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
broadcast->get_input_element_type(0),
out_rank,
runtime::cpu::kernel::broadcast)
......
......@@ -149,7 +149,7 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::concat<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
out[0].get_element_type(),
out[0].get_shape().size(),
runtime::cpu::kernel::concat)
......
......@@ -67,7 +67,9 @@ namespace ngraph
return;
}
if (arg0_shape.empty() || arg1_shape.empty())
if ((arg0_shape.empty() || arg1_shape.empty()) &&
is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
auto first = (arg0_shape.empty() ? args[0] : args[1]);
auto second = (arg0_shape.empty() ? args[1] : args[0]);
......@@ -78,8 +80,7 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar)
SELECT_ETS(kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_scalar);
auto element_count = shape_size(second.get_shape());
......@@ -101,12 +102,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_1d_1d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_1d_1rd);
auto functor = [&,
kernel,
......@@ -130,12 +132,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_2d_1d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_2d_1d_1rd);
auto functor = [&,
kernel,
......@@ -159,12 +162,13 @@ namespace ngraph
}
if ((arg0_shape.size() == 1) && (arg1_shape.size() == 2) &&
reduction_axes_count == 1)
reduction_axes_count == 1 && is_optimized_et(args[0].get_element_type()) &&
is_optimized_et(args[1].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::dot_1d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd)
SELECT_ETS(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_1d_2d_1rd);
auto functor = [&,
kernel,
......
......@@ -57,16 +57,17 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 5)
params_shape.size() <= 3 && out_shape.size() <= 5 &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i64)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i64);
return [&,
kernel,
......@@ -117,16 +118,17 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 5)
params_shape.size() <= 3 && out_shape.size() <= 5 &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i32)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
params_shape.size(),
out_shape.size(),
runtime::cpu::kernel::gather_i32);
return [&,
kernel,
......@@ -188,6 +190,10 @@ namespace ngraph
{
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::f64)
{
functor = prepare_functor<double>(node, args, out, external_function);
......@@ -204,10 +210,6 @@ namespace ngraph
{
functor = prepare_functor<int32_t>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::u8)
{
functor = prepare_functor<uint8_t>(node, args, out, external_function);
......
......@@ -50,14 +50,15 @@ namespace ngraph
auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT)
if (pad_mode == ngraph::op::PadMode::CONSTANT &&
is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice);
auto functor = [&,
kernel,
......@@ -122,14 +123,15 @@ namespace ngraph
auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT)
if (pad_mode == ngraph::op::PadMode::CONSTANT &&
is_optimized_et(pad->get_input_element_type(0)))
{
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
pad->get_input_element_type(0),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice)
SELECT_ETS_AND_RANK7(kernel,
pad->get_input_element_type(0),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice);
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
......
......@@ -41,10 +41,10 @@
return; \
} \
\
if (reduction_axes.size() == arg_rank) \
if (reduction_axes.size() == arg_rank && is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_all<float, 2>)> kernel; \
SELECT_KERNEL_BY_RANK( \
SELECT_ETS_AND_RANK7( \
kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_all); \
auto functor = [&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
......@@ -58,16 +58,16 @@
return; \
} \
\
if (reduction_axes.size() == 1) \
if (reduction_axes.size() == 1 && is_optimized_et(args[0].get_element_type())) \
{ \
if (*reduction_axes.begin() == arg_rank - 1) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_innermost_1rd<float, 2>)> \
kernel; \
SELECT_KERNEL_BY_RANK(kernel, \
result_element_type, \
arg_rank, \
runtime::cpu::kernel::reduce_##K##_innermost_1rd); \
SELECT_ETS_AND_RANK7(kernel, \
result_element_type, \
arg_rank, \
runtime::cpu::kernel::reduce_##K##_innermost_1rd); \
auto functor = \
[&, kernel, arg_shape, result_shape, arg_buffer_index, out_buffer_index]( \
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { \
......@@ -82,7 +82,7 @@
} \
\
std::function<decltype(runtime::cpu::kernel::reduce_##K##_1rd<float, 2>)> kernel; \
SELECT_KERNEL_BY_RANK( \
SELECT_ETS_AND_RANK7( \
kernel, result_element_type, arg_rank, runtime::cpu::kernel::reduce_##K##_1rd); \
auto functor = [&, \
kernel, \
......@@ -102,10 +102,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 3) \
if (reduction_axes.size() == 2 && arg_rank == 3 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_3d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_3d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......@@ -124,10 +125,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 4) \
if (reduction_axes.size() == 2 && arg_rank == 4 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_4d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_4d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......@@ -146,10 +148,11 @@
return; \
} \
\
if (reduction_axes.size() == 2 && arg_rank == 5) \
if (reduction_axes.size() == 2 && arg_rank == 5 && \
is_optimized_et(args[0].get_element_type())) \
{ \
std::function<decltype(runtime::cpu::kernel::reduce_##K##_5d_2rd<float>)> kernel; \
SELECT_KERNEL(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
SELECT_ETS(kernel, result_element_type, runtime::cpu::kernel::reduce_##K##_5d_2rd); \
auto functor = [&, \
kernel, \
arg_shape, \
......
......@@ -43,6 +43,7 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto out_shape = out[0].get_shape();
auto strides = replace_slice->get_strides();
auto lower_bounds = replace_slice->get_lower_bounds();
......@@ -71,15 +72,15 @@ namespace ngraph
return;
}
if (strided)
if (strided && is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::strided_replace_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice);
auto functor = [&,
kernel,
......@@ -104,14 +105,14 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::replace_slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::replace_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::replace_slice);
auto functor = [&,
kernel,
......@@ -132,6 +133,34 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_replace_slice<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_replace_slice);
auto functor = [&,
kernel,
arg1_shape,
out_shape,
lower_bounds,
upper_bounds,
strides,
arg0_buffer_index,
arg1_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out_buffer_index],
arg1_shape,
lower_bounds,
upper_bounds,
strides,
out_shape);
};
functors.emplace_back(functor);
}
}
void register_builders_replace_slice_cpp() { REGISTER_OP_BUILDER(ReplaceSlice); }
......
......@@ -81,25 +81,25 @@ namespace ngraph
return;
}
if (arg_rank == 1)
if (arg_rank == 1 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_1d);
}
else if (arg_rank == 2)
else if (arg_rank == 2 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_2d);
}
else if (arg_rank == 3)
else if (arg_rank == 3 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_3d);
}
else if (arg_rank == 4)
else if (arg_rank == 4 && is_optimized_et(result_element_type))
{
SELECT_KERNEL_BY_RANK(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d)
SELECT_ETS_AND_RANK7(
kernel, result_element_type, result_rank, runtime::cpu::kernel::reshape_4d);
}
else
{
......
......@@ -47,7 +47,7 @@ namespace ngraph
if (args[1].get_element_type() == element::i32)
{
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::reverse_sequence_sli32)
......
......@@ -61,18 +61,18 @@ namespace ngraph
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (is_int64)
if (is_int64 && is_optimized_et(args[0].get_element_type()))
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64);
auto functor = [&,
kernel,
......@@ -100,18 +100,18 @@ namespace ngraph
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32)
SELECT_RANK35_ET4(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32);
auto functor = [&,
kernel,
......@@ -139,6 +139,67 @@ namespace ngraph
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else if (is_int64)
{
std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i64<float>)>
kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_scatter_add_i64);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_scatter_add_i32<float>)>
kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::ref_scatter_add_i32);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
functors.emplace_back(functor);
}
}
void register_builders_scatter_add_cpp() { REGISTER_OP_BUILDER(ScatterAdd); }
......
......@@ -133,15 +133,15 @@ namespace ngraph
}
else
{
if (is_strided(strides))
if (is_strided(strides) && is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::strided_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::strided_slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::strided_slice);
auto functor = [&,
kernel,
......@@ -164,14 +164,14 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
std::function<decltype(runtime::cpu::kernel::slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::slice)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::slice);
auto functor = [&,
kernel,
......@@ -190,6 +190,31 @@ namespace ngraph
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::ref_slice<float>)> kernel;
SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_slice);
auto functor = [&,
kernel,
arg_shape,
out_shape,
lower_bounds,
upper_bounds,
strides,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
lower_bounds,
upper_bounds,
strides,
out_shape);
};
functors.emplace_back(functor);
}
}
}
......
......@@ -75,17 +75,18 @@ namespace ngraph
ctx, softmax_index, deps, cpu::mkldnn_utils::OpType::SOFTMAX);
};
functors.emplace_back(functor);
return;
}
else
else if (is_optimized_et(args[0].get_element_type()))
{
if (axes.size() == arg_shape.size())
{
std::function<decltype(runtime::cpu::kernel::softmax_all<float, 1>)> kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_all)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_all);
auto functor = [&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
......@@ -95,6 +96,7 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else if (axes.size() == 1)
{
......@@ -104,11 +106,10 @@ namespace ngraph
runtime::cpu::kernel::softmax_innermost_1rd<float, 1>)>
kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(
kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_innermost_1rd)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_innermost_1rd);
auto functor =
[&, kernel, arg_shape, arg_buffer_index, out_buffer_index](
......@@ -119,16 +120,17 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else
{
std::function<decltype(runtime::cpu::kernel::softmax_1rd<float, 1>)>
kernel;
PARTIAL_SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_1rd)
SELECT_ETS_AND_RANK7(kernel,
args[0].get_element_type(),
args[0].get_shape().size(),
runtime::cpu::kernel::softmax_1rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -140,15 +142,16 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
}
else if (arg_shape.size() == 3 && axes.size() == 2)
{
std::function<decltype(runtime::cpu::kernel::softmax_3d_2rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_3d_2rd)
SELECT_ETS(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_3d_2rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -160,14 +163,15 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
return;
}
else if (arg_shape.size() == 4 && axes.size() == 3)
{
std::function<decltype(runtime::cpu::kernel::softmax_4d_3rd<float>)> kernel;
SELECT_KERNEL(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_4d_3rd)
SELECT_ETS(kernel,
args[0].get_element_type(),
runtime::cpu::kernel::softmax_4d_3rd);
auto functor =
[&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
......@@ -179,28 +183,22 @@ namespace ngraph
ectx->arena);
};
functors.emplace_back(functor);
}
else if (softmax->get_element_type() == element::f32)
{
NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
<< " over " << axes;
auto functor = [&, arg_shape, axes, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /* ectx */) {
runtime::reference::softmax<float>(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
arg_shape,
axes);
};
functors.emplace_back(functor);
}
else
{
NGRAPH_ERR << "Unsupported Softmax " << arg_shape << " over " << axes
<< " in cpu buiilder";
throw ngraph_error("Unsupported Softmax");
return;
}
}
NGRAPH_WARN << "Falling back to refernce kernel for softmax " << arg_shape
<< " over " << axes;
std::function<decltype(runtime::cpu::kernel::ref_softmax<float>)> kernel;
SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::ref_softmax);
auto functor = [&, kernel, arg_shape, axes, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* /*ectx*/) {
kernel(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
arg_shape,
axes);
};
functors.emplace_back(functor);
}
void register_builders_softmax_cpp() { REGISTER_OP_BUILDER(Softmax); }
......
......@@ -60,8 +60,8 @@ namespace ngraph
else
{
std::function<decltype(runtime::cpu::kernel::tile<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(
kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile)
SELECT_KERNEL_ET_RANK(
kernel, out[0].get_element_type(), arg_rank, runtime::cpu::kernel::tile);
auto functor =
[&, kernel, arg_shape, out_shape, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
......
......@@ -66,7 +66,7 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::strided_update_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_update_slice)
......@@ -98,7 +98,7 @@ namespace ngraph
{
std::function<decltype(runtime::cpu::kernel::update_slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
SELECT_KERNEL_ET_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::update_slice)
......
......@@ -22,6 +22,7 @@
#include "ngraph/node.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/kernel_selectors.hpp"
#define BUILDER_DECL(op_name) \
build<op_name>(CPU_ExternalFunction * external_function, \
......@@ -29,283 +30,6 @@
const std::vector<TensorViewWrapper>& args, \
const std::vector<TensorViewWrapper>& out)
// Per-type kernel macro
// Instantiates the single-type kernel template K for the runtime element type
// ET and stores the resulting callable in KV.
//   KV - a std::function (or compatible) lvalue that receives the kernel
//   ET - a runtime element::Type value
//   K  - a kernel function template taking one element-type parameter
// An unsupported element type now throws (consistent with
// SELECT_KERNEL_BY_RANK) instead of silently leaving KV empty, which would
// otherwise surface later as a bad-function-call crash.
#define SELECT_KERNEL(KV, ET, K)                                                                   \
    if (ET == element::boolean)                                                                    \
    {                                                                                              \
        KV = K<char>;                                                                              \
    }                                                                                              \
    else if (ET == element::f32)                                                                   \
    {                                                                                              \
        KV = K<float>;                                                                             \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        KV = K<double>;                                                                            \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        KV = K<int8_t>;                                                                            \
    }                                                                                              \
    else if (ET == element::i16)                                                                   \
    {                                                                                              \
        KV = K<int16_t>;                                                                           \
    }                                                                                              \
    else if (ET == element::i32)                                                                   \
    {                                                                                              \
        KV = K<int32_t>;                                                                           \
    }                                                                                              \
    else if (ET == element::i64)                                                                   \
    {                                                                                              \
        KV = K<int64_t>;                                                                           \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        KV = K<uint8_t>;                                                                           \
    }                                                                                              \
    else if (ET == element::u16)                                                                   \
    {                                                                                              \
        KV = K<uint16_t>;                                                                          \
    }                                                                                              \
    else if (ET == element::u32)                                                                   \
    {                                                                                              \
        KV = K<uint32_t>;                                                                          \
    }                                                                                              \
    else if (ET == element::u64)                                                                   \
    {                                                                                              \
        KV = K<uint64_t>;                                                                          \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);  \
    }
// Instantiates the three-type kernel template K (input, input, output all of
// the same element type) for the runtime element type ET and stores the
// resulting callable in KV.
//   KV - a std::function (or compatible) lvalue that receives the kernel
//   ET - a runtime element::Type value
//   K  - a kernel function template taking three element-type parameters
// An unsupported element type now throws (consistent with
// SELECT_KERNEL_BY_RANK) instead of silently leaving KV empty.
#define SELECT_KERNEL_3ARGS(KV, ET, K)                                                             \
    if (ET == element::boolean)                                                                    \
    {                                                                                              \
        KV = K<char, char, char>;                                                                  \
    }                                                                                              \
    else if (ET == element::f32)                                                                   \
    {                                                                                              \
        KV = K<float, float, float>;                                                               \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        KV = K<double, double, double>;                                                            \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        KV = K<int8_t, int8_t, int8_t>;                                                            \
    }                                                                                              \
    else if (ET == element::i16)                                                                   \
    {                                                                                              \
        KV = K<int16_t, int16_t, int16_t>;                                                         \
    }                                                                                              \
    else if (ET == element::i32)                                                                   \
    {                                                                                              \
        KV = K<int32_t, int32_t, int32_t>;                                                         \
    }                                                                                              \
    else if (ET == element::i64)                                                                   \
    {                                                                                              \
        KV = K<int64_t, int64_t, int64_t>;                                                         \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        KV = K<uint8_t, uint8_t, uint8_t>;                                                         \
    }                                                                                              \
    else if (ET == element::u16)                                                                   \
    {                                                                                              \
        KV = K<uint16_t, uint16_t, uint16_t>;                                                      \
    }                                                                                              \
    else if (ET == element::u32)                                                                   \
    {                                                                                              \
        KV = K<uint32_t, uint32_t, uint32_t>;                                                      \
    }                                                                                              \
    else if (ET == element::u64)                                                                   \
    {                                                                                              \
        KV = K<uint64_t, uint64_t, uint64_t>;                                                      \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);  \
    }
// Instantiates kernel template K for element type ET at runtime rank R
// (supported ranks are 1 through 7) and stores the result in KV.
// Any other rank is a hard error.
#define SELECT_RANK(KV, ET, R, K)                                                                  \
    switch (R)                                                                                     \
    {                                                                                              \
    case 1: KV = K<ET, 1>; break;                                                                  \
    case 2: KV = K<ET, 2>; break;                                                                  \
    case 3: KV = K<ET, 3>; break;                                                                  \
    case 4: KV = K<ET, 4>; break;                                                                  \
    case 5: KV = K<ET, 5>; break;                                                                  \
    case 6: KV = K<ET, 6>; break;                                                                  \
    case 7: KV = K<ET, 7>; break;                                                                  \
    default: throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);      \
    }
// Instantiates kernel template K<IT, OT, R> for a distinct input type IT and
// output type OT at runtime rank R (1 through 7), storing the result in KV.
// Any other rank is a hard error.
#define SELECT_RANK2(KV, IT, OT, R, K)                                                             \
    if (R == 1)                                                                                    \
        KV = K<IT, OT, 1>;                                                                         \
    else if (R == 2)                                                                               \
        KV = K<IT, OT, 2>;                                                                         \
    else if (R == 3)                                                                               \
        KV = K<IT, OT, 3>;                                                                         \
    else if (R == 4)                                                                               \
        KV = K<IT, OT, 4>;                                                                         \
    else if (R == 5)                                                                               \
        KV = K<IT, OT, 5>;                                                                         \
    else if (R == 6)                                                                               \
        KV = K<IT, OT, 6>;                                                                         \
    else if (R == 7)                                                                               \
        KV = K<IT, OT, 7>;                                                                         \
    else                                                                                           \
        throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);
// Per-type and rank kernel macro: dispatches the runtime element::Type ET to a concrete
// C++ type and then delegates to SELECT_RANK for the rank dispatch, so KV ends up as
// K<concrete_type, rank>. Throws ngraph_error for an unsupported element type.
#define SELECT_KERNEL_BY_RANK(KV, ET, R, K)                                                        \
    if (ET == element::boolean)                                                                    \
    {                                                                                              \
        SELECT_RANK(KV, char, R, K);                                                               \
    }                                                                                              \
    else if (ET == element::f32)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, float, R, K);                                                              \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, double, R, K);                                                             \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        SELECT_RANK(KV, int8_t, R, K);                                                             \
    }                                                                                              \
    else if (ET == element::i16)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, int16_t, R, K);                                                            \
    }                                                                                              \
    else if (ET == element::i32)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, int32_t, R, K);                                                            \
    }                                                                                              \
    else if (ET == element::i64)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, int64_t, R, K);                                                            \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        SELECT_RANK(KV, uint8_t, R, K);                                                            \
    }                                                                                              \
    else if (ET == element::u16)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, uint16_t, R, K);                                                           \
    }                                                                                              \
    else if (ET == element::u32)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, uint32_t, R, K);                                                           \
    }                                                                                              \
    else if (ET == element::u64)                                                                   \
    {                                                                                              \
        SELECT_RANK(KV, uint64_t, R, K);                                                           \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
    }
// Assigns KV = K<ET, R1, R2> for the first rank R1 in 1..3 (the second rank R2 has already
// been fixed by the caller, see SELECT_2RANKS); throws ngraph_error otherwise.
#define SELECT_RANK1(KV, ET, R1, R2, K)                                                            \
    if (R1 == 1)                                                                                   \
        KV = K<ET, 1, R2>;                                                                         \
    else if (R1 == 2)                                                                              \
        KV = K<ET, 2, R2>;                                                                         \
    else if (R1 == 3)                                                                              \
        KV = K<ET, 3, R2>;                                                                         \
    else                                                                                           \
        throw ngraph_error("Unsupported first rank " + std::to_string(R1) + " for kernel " #K);
// Dispatches the second runtime rank R2 (1..5) to a compile-time constant and delegates
// the first rank to SELECT_RANK1, yielding KV = K<ET, R1, R2>; throws for other R2.
#define SELECT_2RANKS(KV, ET, R1, R2, K)                                                           \
    if (R2 == 1)                                                                                   \
    {                                                                                              \
        SELECT_RANK1(KV, ET, R1, 1, K);                                                            \
    }                                                                                              \
    else if (R2 == 2)                                                                              \
    {                                                                                              \
        SELECT_RANK1(KV, ET, R1, 2, K);                                                            \
    }                                                                                              \
    else if (R2 == 3)                                                                              \
    {                                                                                              \
        SELECT_RANK1(KV, ET, R1, 3, K);                                                            \
    }                                                                                              \
    else if (R2 == 4)                                                                              \
    {                                                                                              \
        SELECT_RANK1(KV, ET, R1, 4, K);                                                            \
    }                                                                                              \
    else if (R2 == 5)                                                                              \
    {                                                                                              \
        SELECT_RANK1(KV, ET, R1, 5, K);                                                            \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K);  \
    }
// Per-type and two-ranks kernel macro: only a subset of element types (f32, f64, u8, i8)
// is instantiated here — this bounds template-instantiation count for two-rank kernels.
// Throws ngraph_error for any other element type.
#define SELECT_KERNEL_BY_2RANKS(KV, ET, R1, R2, K)                                                 \
    if (ET == element::f32)                                                                        \
    {                                                                                              \
        SELECT_2RANKS(KV, float, R1, R2, K);                                                       \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        SELECT_2RANKS(KV, double, R1, R2, K);                                                      \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        SELECT_2RANKS(KV, uint8_t, R1, R2, K);                                                     \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        SELECT_2RANKS(KV, int8_t, R1, R2, K);                                                      \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
    }
// Helper macros for a partial set of element types and ranks.
// Useful for keeping compilation time and memory usage reasonable
// when the computed expression is complex.
// Like SELECT_RANK but only instantiates ranks 1..6; throws for anything larger.
#define PARTIAL_SELECT_RANK(KV, ET, R, K)                                                          \
    if (R == 1)                                                                                    \
        KV = K<ET, 1>;                                                                             \
    else if (R == 2)                                                                               \
        KV = K<ET, 2>;                                                                             \
    else if (R == 3)                                                                               \
        KV = K<ET, 3>;                                                                             \
    else if (R == 4)                                                                               \
        KV = K<ET, 4>;                                                                             \
    else if (R == 5)                                                                               \
        KV = K<ET, 5>;                                                                             \
    else if (R == 6)                                                                               \
        KV = K<ET, 6>;                                                                             \
    else                                                                                           \
        throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);
// Partial per-type and rank kernel macro: f32/f64/i8/u8 only, ranks 1..6 (via
// PARTIAL_SELECT_RANK). Throws ngraph_error for any other element type.
#define PARTIAL_SELECT_KERNEL_BY_RANK(KV, ET, R, K)                                                \
    if (ET == element::f32)                                                                        \
    {                                                                                              \
        PARTIAL_SELECT_RANK(KV, float, R, K);                                                      \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        PARTIAL_SELECT_RANK(KV, double, R, K);                                                     \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        PARTIAL_SELECT_RANK(KV, int8_t, R, K);                                                     \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        PARTIAL_SELECT_RANK(KV, uint8_t, R, K);                                                    \
    }                                                                                              \
    else                                                                                           \
    {                                                                                              \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
    }
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
(void)node; \
auto& functors = external_function->get_functors(); \
......
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -100,6 +101,26 @@ namespace ngraph
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in1;
}
/// Fallback that forwards to the (non-optimized) reference replace_slice kernel.
/// The void* arguments are reinterpreted as ElementType buffers; shapes/bounds/strides
/// are passed through unchanged.
template <typename ElementType>
void ref_replace_slice(void* input0,
                       void* input1,
                       void* output,
                       const Shape& input0_shape,
                       const Coordinate& lower_bounds,
                       const Coordinate& upper_bounds,
                       const Strides& slice_strides,
                       const Shape& output_shape)
{
    // Give the untyped buffers their element-typed views before delegating.
    const ElementType* arg0 = static_cast<const ElementType*>(input0);
    const ElementType* arg1 = static_cast<const ElementType*>(input1);
    ElementType* out = static_cast<ElementType*>(output);
    reference::replace_slice<ElementType>(
        arg0, arg1, out, input0_shape, lower_bounds, upper_bounds, slice_strides, output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -168,6 +169,46 @@ namespace ngraph
updates_shape,
arena);
}
/// Fallback that forwards to the (non-optimized) reference scatter_add kernel with
/// 32-bit indices. The void* arguments are reinterpreted as typed buffers.
template <typename ElementType>
void ref_scatter_add_i32(void* inputs,
                         void* indices,
                         void* updates,
                         void* output,
                         const Shape& inputs_shape,
                         const Shape& indices_shape,
                         const Shape& updates_shape,
                         const Shape& output_shape)
{
    // Typed views over the untyped buffers; indices are int32_t here.
    ElementType* in = static_cast<ElementType*>(inputs);
    int32_t* idx = static_cast<int32_t*>(indices);
    ElementType* upd = static_cast<ElementType*>(updates);
    ElementType* out = static_cast<ElementType*>(output);
    reference::scatter_add<ElementType, int32_t>(
        in, idx, upd, out, inputs_shape, indices_shape, updates_shape, output_shape);
}
/// Fallback that forwards to the (non-optimized) reference scatter_add kernel with
/// 64-bit indices. The void* arguments are reinterpreted as typed buffers.
template <typename ElementType>
void ref_scatter_add_i64(void* inputs,
                         void* indices,
                         void* updates,
                         void* output,
                         const Shape& inputs_shape,
                         const Shape& indices_shape,
                         const Shape& updates_shape,
                         const Shape& output_shape)
{
    // Typed views over the untyped buffers; indices are int64_t here.
    ElementType* in = static_cast<ElementType*>(inputs);
    int64_t* idx = static_cast<int64_t*>(indices);
    ElementType* upd = static_cast<ElementType*>(updates);
    ElementType* out = static_cast<ElementType*>(output);
    reference::scatter_add<ElementType, int64_t>(
        in, idx, upd, out, inputs_shape, indices_shape, updates_shape, output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -88,6 +89,24 @@ namespace ngraph
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in.stridedSlice(start_indices, stop_indices, strides);
}
/// Fallback that forwards to the (non-optimized) reference slice kernel.
/// The void* buffers are reinterpreted as ElementType arrays.
template <typename ElementType>
void ref_slice(void* input,
               void* output,
               const Shape& input_shape,
               const Coordinate& lower_bounds,
               const Coordinate& upper_bounds,
               const Strides& slice_strides,
               const Shape& output_shape)
{
    // Typed views over the untyped buffers before delegating.
    const ElementType* src = static_cast<const ElementType*>(input);
    ElementType* dst = static_cast<ElementType*>(output);
    reference::slice<ElementType>(
        src, dst, input_shape, lower_bounds, upper_bounds, slice_strides, output_shape);
}
}
}
}
......
......@@ -21,6 +21,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -164,6 +165,15 @@ namespace ngraph
{
softmax<ElementType, 4, 3>(input, output, input_shape, softmax_axes, arena);
}
/// Fallback that forwards to the (non-optimized) reference softmax kernel
/// over the given axes. The void* buffers are reinterpreted as ElementType arrays.
template <typename ElementType>
void ref_softmax(void* input, void* output, const Shape& shape, const AxisSet& axes)
{
    const ElementType* src = static_cast<const ElementType*>(input);
    ElementType* dst = static_cast<ElementType*>(output);
    reference::softmax<ElementType>(src, dst, shape, axes);
}
}
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
// The VS compiler treats __VA_ARGS__ as a single token argument rather than expanding it,
// so the macro below is a workaround to that (forces one more rescan of the argument).
#define EXPAND_MACRO(S) S // VS compiler workaround
// Selector Macros for builders to instantiate and pick kernels.
// Convention: KV receives the selected kernel; ET is a runtime element::Type; K is the
// kernel template; R/R1/R2 are runtime ranks mapped to compile-time template arguments.
// All element types, ranks. Use for small/simple kernels
#define SELECT_KERNEL(KV, ET, K) EXPAND_ET11_FIXED_ARGS(K, KV, ET, KERNEL_CT)
#define SELECT_KERNEL_3ARGS(KV, ET, K) EXPAND_ET11_FIXED_ARGS(K, KV, ET, KERNEL_CT_CT_CT)
#define SELECT_KERNEL_RANK(KV, CIT, COT, R, K) EXPAND_RANK7(K, KV, R, KERNEL_CIT_COT_R, CIT, COT)
#define SELECT_KERNEL_ET_RANK(KV, ET, R, K) EXPAND_ET11_AND_RANK7(K, KV, ET, R, KERNEL_CT_R)
// Subset of element types and ranks. Use for more complex/larger kernels
#define SELECT_RANK35_ET4(KV, ET, R1, R2, K)                                                       \
    EXPAND_RANK35_AND_ET4(K, KV, R1, R2, ET, KERNEL_CT_R1_R2)
// Configurable at build using NGRAPH_CPU_OPTIMIZED_DATATYPES
#define SELECT_ETS(KV, ET, K) EXPAND_ETS(K, KV, ET, KERNEL_CT)
#define SELECT_ETS_AND_RANK7(KV, ET, R, K) EXPAND_ETS_AND_RANK7(K, KV, ET, R, KERNEL_CT_R)
// Macros for instantiating templated kernels. Each receives the concrete type(s)/rank(s)
// last so the expanders can append them via __VA_ARGS__.
#define KERNEL_CT(K, KV, CT) KV = K<CT>
#define KERNEL_CT_CT_CT(K, KV, CT) KV = K<CT, CT, CT>
#define KERNEL_CT_R(K, KV, CT, R) KV = K<CT, R>
#define KERNEL_CIT_COT_R(K, KV, CIT, COT, R) KV = K<CIT, COT, R>
#define KERNEL_CT_R1_R2(K, KV, R1, R2, CT) KV = K<CT, R1, R2>
// Helper macros that compose an element-type expander with a rank expander: the outer
// macro dispatches the type (or first rank), then invokes the inner expander with the
// remaining runtime dimension, accumulating arguments via __VA_ARGS__.
#define EXPAND_ET11_AND_RANK7(K, KV, ET, R, S, ...)                                                \
    EXPAND_ET11(K, KV, ET, EXPAND_RANK7, R, S, ##__VA_ARGS__)
#define EXPAND_RANK5_AND_ET4(K, KV, R, ET, S, ...)                                                 \
    EXPAND_RANK5(K, KV, R, EXPAND_ET4, ET, S, ##__VA_ARGS__)
#define EXPAND_RANK35_AND_ET4(K, KV, R1, R2, ET, S)                                                \
    EXPAND_RANK3(K, KV, R1, EXPAND_RANK5_AND_ET4, R2, ET, S)
#define EXPAND_ETS_AND_RANK7(K, KV, ET, R, S, ...)                                                 \
    EXPAND_ETS(K, KV, ET, EXPAND_RANK7, R, S, ##__VA_ARGS__)
// Expander Macros that instantiate kernels for various element types and ranks.
// EXPAND_ET4: dispatches ET over the reduced set f32/f64/i8/u8, invoking S with the
// matching concrete C++ type appended; throws ngraph_error for any other element type.
#define EXPAND_ET4(K, KV, ET, S, ...)                                                              \
    if (ET == element::f32)                                                                        \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, float));                                              \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, double));                                             \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, int8_t));                                             \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, uint8_t));                                            \
    }                                                                                              \
    else                                                                                           \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);
// Dispatches ET over all 11 element types, invoking the continuation macro S with the
// matching concrete C++ type appended after any accumulated __VA_ARGS__; throws
// ngraph_error for an unsupported element type.
#define EXPAND_ET11(K, KV, ET, S, ...)                                                             \
    if (ET == element::boolean)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, char));                                               \
    }                                                                                              \
    else if (ET == element::f32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, float));                                              \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, double));                                             \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, int8_t));                                             \
    }                                                                                              \
    else if (ET == element::i16)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, int16_t));                                            \
    }                                                                                              \
    else if (ET == element::i32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, int32_t));                                            \
    }                                                                                              \
    else if (ET == element::i64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, int64_t));                                            \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, uint8_t));                                            \
    }                                                                                              \
    else if (ET == element::u16)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, uint16_t));                                           \
    }                                                                                              \
    else if (ET == element::u32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, uint32_t));                                           \
    }                                                                                              \
    else if (ET == element::u64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, uint64_t));                                           \
    }                                                                                              \
    else                                                                                           \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);
// Workaround since the VS compiler doesn't work well with variadic macros.
// EXPAND_ET11 takes variable arguments, and since SELECT_KERNEL & SELECT_KERNEL_3ARGS
// call into that macro without variable args, the VS compiler expands them as
// KERNEL_CT(K, KV, ,) thus giving a syntax error. The VS compiler doesn't deal well with
// ignoring the trailing comma. Hence we have replicated EXPAND_ET11 and added
// EXPAND_ET11_FIXED_ARGS for calls that don't have variable args.
// Fixed-arity twin of EXPAND_ET11 (no __VA_ARGS__) — see the VS-compiler workaround note
// above. Dispatches ET over all 11 element types and invokes S(K, KV, concrete_type);
// throws ngraph_error for an unsupported element type.
#define EXPAND_ET11_FIXED_ARGS(K, KV, ET, S)                                                       \
    if (ET == element::boolean)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, char));                                                              \
    }                                                                                              \
    else if (ET == element::f32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, float));                                                             \
    }                                                                                              \
    else if (ET == element::f64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, double));                                                            \
    }                                                                                              \
    else if (ET == element::i8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, int8_t));                                                            \
    }                                                                                              \
    else if (ET == element::i16)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, int16_t));                                                           \
    }                                                                                              \
    else if (ET == element::i32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, int32_t));                                                           \
    }                                                                                              \
    else if (ET == element::i64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, int64_t));                                                           \
    }                                                                                              \
    else if (ET == element::u8)                                                                    \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, uint8_t));                                                           \
    }                                                                                              \
    else if (ET == element::u16)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, uint16_t));                                                          \
    }                                                                                              \
    else if (ET == element::u32)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, uint32_t));                                                          \
    }                                                                                              \
    else if (ET == element::u64)                                                                   \
    {                                                                                              \
        EXPAND_MACRO(S(K, KV, uint64_t));                                                          \
    }                                                                                              \
    else                                                                                           \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);
// Expand only selected datatypes. Named macros (e.g., F32_SELECT) are expanded based on
// build-flags: when a type is disabled, its <T>_EN constant is 0 (branch compiles away)
// and <T>_SELECT expands to nothing, so no kernel is instantiated for that type.
// A disabled-but-matching ET therefore falls through to the trailing throw.
#define EXPAND_ETS(K, KV, ET, S, ...)                                                              \
    if (BOOLEAN_EN && ET == element::boolean)                                                      \
    {                                                                                              \
        BOOLEAN_SELECT(S, K, KV, ##__VA_ARGS__, char);                                             \
    }                                                                                              \
    else if (F32_EN && ET == element::f32)                                                         \
    {                                                                                              \
        F32_SELECT(S, K, KV, ##__VA_ARGS__, float);                                                \
    }                                                                                              \
    else if (F64_EN && ET == element::f64)                                                         \
    {                                                                                              \
        F64_SELECT(S, K, KV, ##__VA_ARGS__, double);                                               \
    }                                                                                              \
    else if (I8_EN && ET == element::i8)                                                           \
    {                                                                                              \
        I8_SELECT(S, K, KV, ##__VA_ARGS__, int8_t);                                                \
    }                                                                                              \
    else if (I16_EN && ET == element::i16)                                                         \
    {                                                                                              \
        I16_SELECT(S, K, KV, ##__VA_ARGS__, int16_t);                                              \
    }                                                                                              \
    else if (I32_EN && ET == element::i32)                                                         \
    {                                                                                              \
        I32_SELECT(S, K, KV, ##__VA_ARGS__, int32_t);                                              \
    }                                                                                              \
    else if (I64_EN && ET == element::i64)                                                         \
    {                                                                                              \
        I64_SELECT(S, K, KV, ##__VA_ARGS__, int64_t);                                              \
    }                                                                                              \
    else if (U8_EN && ET == element::u8)                                                           \
    {                                                                                              \
        U8_SELECT(S, K, KV, ##__VA_ARGS__, uint8_t);                                               \
    }                                                                                              \
    else if (U16_EN && ET == element::u16)                                                         \
    {                                                                                              \
        U16_SELECT(S, K, KV, ##__VA_ARGS__, uint16_t);                                             \
    }                                                                                              \
    else if (U32_EN && ET == element::u32)                                                         \
    {                                                                                              \
        U32_SELECT(S, K, KV, ##__VA_ARGS__, uint32_t);                                             \
    }                                                                                              \
    else if (U64_EN && ET == element::u64)                                                         \
    {                                                                                              \
        U64_SELECT(S, K, KV, ##__VA_ARGS__, uint64_t);                                             \
    }                                                                                              \
    else                                                                                           \
        throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K);
// Dispatches runtime rank R (1..3) to a compile-time constant appended to S's arguments;
// throws ngraph_error for higher ranks.
#define EXPAND_RANK3(K, KV, R, S, ...)                                                             \
    switch (R)                                                                                     \
    {                                                                                              \
    case 1: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 1)); break;                                       \
    case 2: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 2)); break;                                       \
    case 3: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 3)); break;                                       \
    default: throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);      \
    }
// Dispatches runtime rank R (1..5) to a compile-time constant appended to S's arguments;
// throws ngraph_error for higher ranks.
#define EXPAND_RANK5(K, KV, R, S, ...)                                                             \
    switch (R)                                                                                     \
    {                                                                                              \
    case 1: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 1)); break;                                       \
    case 2: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 2)); break;                                       \
    case 3: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 3)); break;                                       \
    case 4: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 4)); break;                                       \
    case 5: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 5)); break;                                       \
    default: throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);      \
    }
// Dispatches runtime rank R (1..7) to a compile-time constant appended to S's arguments;
// throws ngraph_error for higher ranks.
#define EXPAND_RANK7(K, KV, R, S, ...)                                                             \
    switch (R)                                                                                     \
    {                                                                                              \
    case 1: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 1)); break;                                       \
    case 2: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 2)); break;                                       \
    case 3: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 3)); break;                                       \
    case 4: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 4)); break;                                       \
    case 5: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 5)); break;                                       \
    case 6: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 6)); break;                                       \
    case 7: EXPAND_MACRO(S(K, KV, ##__VA_ARGS__, 7)); break;                                       \
    default: throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);      \
    }
// For each element type T, the build system defines NGRAPH_CPU_OPTIMIZE_<T> when T is
// listed in the NGRAPH_CPU_OPTIMIZED_DATATYPES build option. Each block below then
// derives two macros:
//   <T>_EN     - 1/0 constant for use in runtime conditions (see EXPAND_ETS above)
//   <T>_SELECT - expands the kernel-instantiation macro S when enabled, or to nothing
//                when disabled (so no template instantiation is emitted for T).
#if defined(NGRAPH_CPU_OPTIMIZE_boolean)
#define BOOLEAN_EN 1
#define BOOLEAN_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define BOOLEAN_EN 0
#define BOOLEAN_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_f32)
#define F32_EN 1
#define F32_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define F32_EN 0
#define F32_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_f64)
#define F64_EN 1
#define F64_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define F64_EN 0
#define F64_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i8)
#define I8_EN 1
#define I8_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define I8_EN 0
#define I8_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i16)
#define I16_EN 1
#define I16_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define I16_EN 0
#define I16_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i32)
#define I32_EN 1
#define I32_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define I32_EN 0
#define I32_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_i64)
#define I64_EN 1
#define I64_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define I64_EN 0
#define I64_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_u8)
#define U8_EN 1
#define U8_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define U8_EN 0
#define U8_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_u16)
#define U16_EN 1
#define U16_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define U16_EN 0
#define U16_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_u32)
#define U32_EN 1
#define U32_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define U32_EN 0
#define U32_SELECT(S, ...)
#endif
#if defined(NGRAPH_CPU_OPTIMIZE_u64)
#define U64_EN 1
#define U64_SELECT(S, ...) EXPAND_MACRO(S(__VA_ARGS__))
#else
#define U64_EN 0
#define U64_SELECT(S, ...)
#endif
// Returns true iff optimized kernels were enabled for this element type at build time
// (via the per-type *_EN flags derived from NGRAPH_CPU_OPTIMIZED_DATATYPES); callers use
// this to decide between the optimized and reference kernel paths.
//
// Idiom fix: the original wrapped the boolean expression in
// `if (...) { return true; } else { return false; }` — return the expression directly.
static inline bool is_optimized_et(const ngraph::element::Type& et)
{
    return (et == ngraph::element::boolean && BOOLEAN_EN) || (et == ngraph::element::f32 && F32_EN) ||
           (et == ngraph::element::f64 && F64_EN) || (et == ngraph::element::i8 && I8_EN) ||
           (et == ngraph::element::i16 && I16_EN) || (et == ngraph::element::i32 && I32_EN) ||
           (et == ngraph::element::i64 && I64_EN) || (et == ngraph::element::u8 && U8_EN) ||
           (et == ngraph::element::u16 && U16_EN) || (et == ngraph::element::u32 && U32_EN) ||
           (et == ngraph::element::u64 && U64_EN);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment