Commit e327fe57 authored by Sergey Shalnov's avatar Sergey Shalnov Committed by Robert Kimball

IntelGPU backend: Min and Max operations fix (#2357)

* IntelGPU backend: Min and Max operations fix

* IntelGPU backend: PR2357 style changed if/else to switch/case

* IntelGPU backend: PR2357 passed tests adjusted

* IntelGPU backend: PR2357 failed test removed
parent c19e48a6
......@@ -191,13 +191,16 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology,
{
const string function_name = "min_max_" + output_name;
const size_t input_size = shape_size<Shape>(input_shape);
const string& init_value = is_min ? "INFINITY" : "-INFINITY";
const string& init_value = get_opencl_type_min_max_value(output_type, !is_min);
const string& operation = is_min ? " < " : " > ";
codegen::CodeWriter writer;
writer << "__kernel void " << function_name << "(const __global float input"
<< array_dims(input_shape) << ", __global float output" << array_dims(output_shape)
<< ")\n";
runtime::intelgpu::gen_func_def(writer,
function_name,
{get_opencl_type_name(output_type)},
{input_shape},
get_opencl_type_name(output_type),
output_shape);
writer.block_begin();
{
......@@ -231,11 +234,11 @@ void runtime::intelgpu::do_max_min_operation(cldnn::topology& topology,
++var_idx;
}
writer << "if (input" << access_dims(input_shape) << operation << "output"
writer << "if (input0" << access_dims(input_shape) << operation << "output"
<< access_dims(input_shape, "i", axis) << ")\n";
writer.block_begin();
{
writer << "output" << access_dims(input_shape, "i", axis) << " = input"
writer << "output" << access_dims(input_shape, "i", axis) << " = input0"
<< access_dims(input_shape) << ";\n";
}
writer.block_end();
......
......@@ -28,34 +28,38 @@ using namespace ngraph;
string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
{
if (ngraph_type == ngraph::element::i64)
switch (ngraph_type.get_type_enum())
{
return "long";
case element::Type_t::i64: return "long";
case element::Type_t::i32: return "int";
case element::Type_t::i16: return "short";
case element::Type_t::u16: return "ushort";
case element::Type_t::i8: return "char";
case element::Type_t::u8: return "uchar";
}
else if (ngraph_type == ngraph::element::i32)
{
return "int";
}
else if (ngraph_type == ngraph::element::i16)
{
return "short";
}
else if (ngraph_type == ngraph::element::u16)
{
return "ushort";
}
else if (ngraph_type == ngraph::element::i8)
{
return "char";
}
else if (ngraph_type == ngraph::element::u8)
{
return "uchar";
}
else
return ngraph_type.c_type_string();
}
string runtime::intelgpu::get_opencl_type_min_max_value(const element::Type& ngraph_type,
bool is_min)
{
switch (ngraph_type.get_type_enum())
{
return ngraph_type.c_type_string();
case element::Type_t::f32: return is_min ? "-INFINITY" : "INFINITY";
case element::Type_t::f64: return is_min ? "-INFINITY" : "INFINITY";
case element::Type_t::i64: return is_min ? "LONG_MIN" : "LONG_MAX";
case element::Type_t::u64: return is_min ? "0" : "ULONG_MAX";
case element::Type_t::i32: return is_min ? "INT_MIN" : "INT_MAX";
case element::Type_t::u32: return is_min ? "0" : "UINT_MAX";
case element::Type_t::i16: return is_min ? "SHRT_MIN" : "SHRT_MAX";
case element::Type_t::u16: return is_min ? "0" : "USHRT_MAX";
case element::Type_t::i8: return is_min ? "CHAR_MIN" : "CHAR_MAX";
case element::Type_t::u8: return is_min ? "0" : "UCHAR_MAX";
}
throw ngraph_error("Unsupported type '" + ngraph_type.c_type_string() +
"' in runtime::intelgpu::get_opencl_type_min_max_value()");
}
vector<cldnn_arg> runtime::intelgpu::get_kernel_args(size_t input, size_t output)
......
......@@ -197,6 +197,8 @@ namespace ngraph
// Helper functions used in cldnn::custom_gpu_primitive kernels
std::string get_opencl_type_name(const element::Type& ngraph_type);
std::string get_opencl_type_min_max_value(const element::Type& ngraph_type,
bool is_min);
std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output);
std::string array_dims(const Shape& dimentions, const AxisSet& axis = {});
std::string access_dims(const Shape& dimentions,
......
all_2x2x3_eliminate_dims_0_1
argmax_4D_axis_3_i64_in_i32
argmin_trivial_in_double
argmin_trivial_in_i32
avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_3d_uneven_strided_padded
backwards_batch_norm_three_outputs
backwards_batch_norm_training
backwards_dot_scalar_tensor
backwards_dot_tensor3_tensor3
backwards_dot_tensor_scalar
backwards_dot_tensor_vector
backwards_exp
backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n4_c1_hw4_2x2_max
backwards_replace_slice
backwards_reverse_sequence_n3_c2_h3
backwards_reverse_sequence_n4d2c3h2w2
backwards_slice
backwards_tanh
batch_norm_bprop_n4c3h2w2
batch_norm_inference_0eps_f64
batch_norm_inference_f64
batch_norm_training_0eps_f64
batch_norm_one_output
batch_norm_three_outputs
batch_norm_bprop_n4c3h2w2
dequantize
dequantize_axes
dequantize_dynamic_offset
......@@ -30,10 +28,11 @@ dequantize_int8
dequantize_int8_zero_offset
dequantize_zero_offset
divide_by_zero_int32
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_4x5_reverse
generate_mask
max_3d_to_scalar_double
max_pool_3d
numeric_double_inf
numeric_double_nan
......@@ -74,6 +73,9 @@ shape_of_scalar
shape_of_vector
softmax_axis_3d_double
sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double
topk_1d_max_all
topk_1d_max_one
topk_1d_max_partial
......@@ -127,21 +129,3 @@ zero_sized_sqrt
zero_sized_subtract
zero_sized_tan
zero_sized_tanh
shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double
max_matrix_rows_zero_int32
max_to_scalar_int8
min_to_scalar_int8
max_3d_to_scalar_double
argmin_trivial_in_i32
argmax_4D_axis_3_i64_in_i32
argmin_trivial_in_double
all_2x2x3_eliminate_dim_1
all_2x2x3_eliminate_dim_2
all_2x2x3_eliminate_dims_0_1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment