Commit f8632ea0 authored by Sergey Shalnov's avatar Sergey Shalnov Committed by Sang Ik Lee

IntelGPU backend: Max and Avg pool fix (#2482)

parent fd0ed37c
......@@ -765,14 +765,40 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::MaxPool:
{
arguments_check(op, 1, 1);
const shared_ptr<op::MaxPool> max_pool = static_pointer_cast<op::MaxPool>(op);
if ((get_input_shape(op).size() > 4) || (get_output_type(op) != element::f32) ||
!max_pool->get_padding_below().empty() || !max_pool->get_padding_above().empty())
{
const shared_ptr<Node> def_val = max_pool->get_default_value();
const shared_ptr<op::Constant> def_const =
static_pointer_cast<op::Constant>(def_val);
const vector<std::string>& values = def_const->get_value_strings();
do_max_avg_pool_operation(topology,
get_input_name(op),
get_input_shape(op),
get_output_name(op),
get_output_shape(op),
get_output_type(op),
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
false,
values.at(0),
true);
}
else
{
do_pooling_operation(topology,
op,
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
cldnn::pooling_mode::max);
}
break;
}
case OP_TYPEID::MaxPoolBackprop:
......@@ -804,7 +830,34 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::AvgPool:
{
arguments_check(op, 1, 1);
const shared_ptr<op::AvgPool> avg_pool = static_pointer_cast<op::AvgPool>(op);
if ((get_input_shape(op).size() > 4) || (get_output_type(op) != element::f32) ||
avg_pool->get_include_padding_in_avg_computation() ||
!avg_pool->get_padding_below().empty() || !avg_pool->get_padding_above().empty())
{
const shared_ptr<Node> def_val = avg_pool->get_default_value();
const shared_ptr<op::Constant> def_const =
static_pointer_cast<op::Constant>(def_val);
const vector<std::string>& values = def_const->get_value_strings();
do_max_avg_pool_operation(topology,
get_input_name(op),
get_input_shape(op),
get_output_name(op),
get_output_shape(op),
get_output_type(op),
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
avg_pool->get_include_padding_in_avg_computation(),
values.at(0),
false);
}
else
{
const cldnn::pooling_mode mode = avg_pool->get_include_padding_in_avg_computation()
? cldnn::pooling_mode::average
: cldnn::pooling_mode::average_no_padding;
......@@ -815,6 +868,7 @@ shared_ptr<runtime::Executable>
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
mode);
}
break;
}
case OP_TYPEID::AvgPoolBackprop:
......@@ -825,8 +879,8 @@ shared_ptr<runtime::Executable>
static_pointer_cast<op::AvgPoolBackprop>(op);
do_avg_pool_backprop_operation(topology,
get_input_name(op, 0),
get_input_shape(op, 0),
get_input_name(op),
get_input_shape(op),
get_output_name(op),
get_output_shape(op),
get_output_type(op),
......
......@@ -58,6 +58,19 @@ namespace ngraph
const Shape& win_stride,
const Shape& pad_below);
void do_max_avg_pool_operation(cldnn::topology& topology,
const std::string& input_name,
const Shape& input_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const Shape& win_shape,
const Shape& win_stride,
const Shape& pad_below,
bool include_padding,
const std::string& def_val,
bool is_max_pool);
void do_avg_pool_backprop_operation(cldnn::topology& topology,
const std::string& delta_name,
const Shape& delta_shape,
......
all_2x2x3_eliminate_dims_0_1
avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_3d_uneven_strided_padded
backwards_batch_norm_training
backwards_dot_scalar_tensor
backwards_dot_tensor_scalar
backwards_dot_tensor_vector
backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n4_c1_hw4_2x2_max
backwards_replace_slice
backwards_reverse_sequence_n3_c2_h3
backwards_reverse_sequence_n4d2c3h2w2
......@@ -18,7 +13,6 @@ embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_4x5_reverse
generate_mask
max_pool_3d
replace_slice_3d
replace_slice_3d_strided
replace_slice_3d_strided_different_strides
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment