Commit e327fe57 authored by Sergey Shalnov's avatar Sergey Shalnov Committed by Robert Kimball

IntelGPU backend: Min and Max operations fix (#2357)

* IntelGPU backend: Min and Max operations fix

* IntelGPU backend: PR2357 style changed if/else to switch/case

* IntelGPU backend: PR2357 passed tests adjusted

* IntelGPU backend: PR2357 failed test removed
parent c19e48a6
...@@ -191,13 +191,16 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology, ...@@ -191,13 +191,16 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology,
{ {
const string function_name = "min_max_" + output_name; const string function_name = "min_max_" + output_name;
const size_t input_size = shape_size<Shape>(input_shape); const size_t input_size = shape_size<Shape>(input_shape);
const string& init_value = is_min ? "INFINITY" : "-INFINITY"; const string& init_value = get_opencl_type_min_max_value(output_type, !is_min);
const string& operation = is_min ? " < " : " > "; const string& operation = is_min ? " < " : " > ";
codegen::CodeWriter writer; codegen::CodeWriter writer;
writer << "__kernel void " << function_name << "(const __global float input" runtime::intelgpu::gen_func_def(writer,
<< array_dims(input_shape) << ", __global float output" << array_dims(output_shape) function_name,
<< ")\n"; {get_opencl_type_name(output_type)},
{input_shape},
get_opencl_type_name(output_type),
output_shape);
writer.block_begin(); writer.block_begin();
{ {
...@@ -231,11 +234,11 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology, ...@@ -231,11 +234,11 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology,
++var_idx; ++var_idx;
} }
writer << "if (input" << access_dims(input_shape) << operation << "output" writer << "if (input0" << access_dims(input_shape) << operation << "output"
<< access_dims(input_shape, "i", axis) << ")\n"; << access_dims(input_shape, "i", axis) << ")\n";
writer.block_begin(); writer.block_begin();
{ {
writer << "output" << access_dims(input_shape, "i", axis) << " = input" writer << "output" << access_dims(input_shape, "i", axis) << " = input0"
<< access_dims(input_shape) << ";\n"; << access_dims(input_shape) << ";\n";
} }
writer.block_end(); writer.block_end();
......
...@@ -28,34 +28,38 @@ using namespace ngraph; ...@@ -28,34 +28,38 @@ using namespace ngraph;
string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type) string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
{ {
if (ngraph_type == ngraph::element::i64) switch (ngraph_type.get_type_enum())
{ {
return "long"; case element::Type_t::i64: return "long";
case element::Type_t::i32: return "int";
case element::Type_t::i16: return "short";
case element::Type_t::u16: return "ushort";
case element::Type_t::i8: return "char";
case element::Type_t::u8: return "uchar";
} }
else if (ngraph_type == ngraph::element::i32)
{ return ngraph_type.c_type_string();
return "int"; }
}
else if (ngraph_type == ngraph::element::i16) string runtime::intelgpu::get_opencl_type_min_max_value(const element::Type& ngraph_type,
{ bool is_min)
return "short"; {
} switch (ngraph_type.get_type_enum())
else if (ngraph_type == ngraph::element::u16)
{
return "ushort";
}
else if (ngraph_type == ngraph::element::i8)
{
return "char";
}
else if (ngraph_type == ngraph::element::u8)
{
return "uchar";
}
else
{ {
return ngraph_type.c_type_string(); case element::Type_t::f32: return is_min ? "-INFINITY" : "INFINITY";
case element::Type_t::f64: return is_min ? "-INFINITY" : "INFINITY";
case element::Type_t::i64: return is_min ? "LONG_MIN" : "LONG_MAX";
case element::Type_t::u64: return is_min ? "0" : "ULONG_MAX";
case element::Type_t::i32: return is_min ? "INT_MIN" : "INT_MAX";
case element::Type_t::u32: return is_min ? "0" : "UINT_MAX";
case element::Type_t::i16: return is_min ? "SHRT_MIN" : "SHRT_MAX";
case element::Type_t::u16: return is_min ? "0" : "USHRT_MAX";
case element::Type_t::i8: return is_min ? "CHAR_MIN" : "CHAR_MAX";
case element::Type_t::u8: return is_min ? "0" : "UCHAR_MAX";
} }
throw ngraph_error("Unsupported type '" + ngraph_type.c_type_string() +
"' in runtime::intelgpu::get_opencl_type_min_max_value()");
} }
vector<cldnn_arg> runtime::intelgpu::get_kernel_args(size_t input, size_t output) vector<cldnn_arg> runtime::intelgpu::get_kernel_args(size_t input, size_t output)
......
...@@ -197,6 +197,8 @@ namespace ngraph ...@@ -197,6 +197,8 @@ namespace ngraph
// Helper functions used in cldnn::custom_gpu_primitive kernels // Helper functions used in cldnn::custom_gpu_primitive kernels
std::string get_opencl_type_name(const element::Type& ngraph_type); std::string get_opencl_type_name(const element::Type& ngraph_type);
std::string get_opencl_type_min_max_value(const element::Type& ngraph_type,
bool is_min);
std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output); std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output);
std::string array_dims(const Shape& dimentions, const AxisSet& axis = {}); std::string array_dims(const Shape& dimentions, const AxisSet& axis = {});
std::string access_dims(const Shape& dimentions, std::string access_dims(const Shape& dimentions,
......
all_2x2x3_eliminate_dims_0_1
argmax_4D_axis_3_i64_in_i32
argmin_trivial_in_double
argmin_trivial_in_i32
avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_3d_uneven_strided_padded avg_pool_3d_uneven_strided_padded
backwards_batch_norm_three_outputs
backwards_batch_norm_training backwards_batch_norm_training
backwards_dot_scalar_tensor backwards_dot_scalar_tensor
backwards_dot_tensor3_tensor3
backwards_dot_tensor_scalar backwards_dot_tensor_scalar
backwards_dot_tensor_vector backwards_dot_tensor_vector
backwards_exp
backwards_maxpool_n2_c1_hw5_3x3_str2_max backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n4_c1_hw4_2x2_max backwards_maxpool_n4_c1_hw4_2x2_max
backwards_replace_slice backwards_replace_slice
backwards_reverse_sequence_n3_c2_h3 backwards_reverse_sequence_n3_c2_h3
backwards_reverse_sequence_n4d2c3h2w2 backwards_reverse_sequence_n4d2c3h2w2
backwards_slice backwards_slice
backwards_tanh batch_norm_bprop_n4c3h2w2
batch_norm_inference_0eps_f64 batch_norm_inference_0eps_f64
batch_norm_inference_f64 batch_norm_inference_f64
batch_norm_training_0eps_f64 batch_norm_training_0eps_f64
batch_norm_one_output
batch_norm_three_outputs
batch_norm_bprop_n4c3h2w2
dequantize dequantize
dequantize_axes dequantize_axes
dequantize_dynamic_offset dequantize_dynamic_offset
...@@ -30,10 +28,11 @@ dequantize_int8 ...@@ -30,10 +28,11 @@ dequantize_int8
dequantize_int8_zero_offset dequantize_int8_zero_offset
dequantize_zero_offset dequantize_zero_offset
divide_by_zero_int32 divide_by_zero_int32
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_4x5_reverse
generate_mask generate_mask
max_3d_to_scalar_double
max_pool_3d max_pool_3d
numeric_double_inf numeric_double_inf
numeric_double_nan numeric_double_nan
...@@ -74,6 +73,9 @@ shape_of_scalar ...@@ -74,6 +73,9 @@ shape_of_scalar
shape_of_vector shape_of_vector
softmax_axis_3d_double softmax_axis_3d_double
sum_stable_acc sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double
topk_1d_max_all topk_1d_max_all
topk_1d_max_one topk_1d_max_one
topk_1d_max_partial topk_1d_max_partial
...@@ -127,21 +129,3 @@ zero_sized_sqrt ...@@ -127,21 +129,3 @@ zero_sized_sqrt
zero_sized_subtract zero_sized_subtract
zero_sized_tan zero_sized_tan
zero_sized_tanh zero_sized_tanh
shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double
max_matrix_rows_zero_int32
max_to_scalar_int8
min_to_scalar_int8
max_3d_to_scalar_double
argmin_trivial_in_i32
argmax_4D_axis_3_i64_in_i32
argmin_trivial_in_double
all_2x2x3_eliminate_dim_1
all_2x2x3_eliminate_dim_2
all_2x2x3_eliminate_dims_0_1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment