Commit 0051f201 authored by Fenglei, committed by Scott Cyphers

nvgpu maxpool bug fix (#1741)

* add a test that failed on gpu and passed on cpu

* fixed bug

* get datatype size

* add description for test

* update comment

* update comments and name
parent 2b289df0
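In outline, the fix works around the fact that cuDNN's pooling descriptor accepts only a single padding value per spatial dimension: when padding_below differs from padding_above, the input is first padded explicitly with the element type's lowest value (the identity for max) and the descriptor padding is then zeroed. The sketch below illustrates that control flow only; pad_explicitly and cudnn_max_pool are hypothetical stand-ins, not nGraph or cuDNN APIs.

```cpp
// Sketch only: illustrates the control flow of the fix, not nGraph code.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

using Shape = std::vector<size_t>;

static void pad_explicitly(const Shape&, float) {}        // stub for the pad primitive
static void cudnn_max_pool(const Shape&, const Shape&) {} // stub for the cuDNN pooling primitive

void max_pool_forward(Shape input_shape,   // NCHW; spatial dims assumed to start at index 2
                      Shape padding_below, // one entry per spatial dim
                      Shape padding_above)
{
    Shape shape_to_pool = input_shape;
    if (padding_below != padding_above)
    {
        // cuDNN's pooling descriptor takes one pad value per spatial dimension,
        // so asymmetric padding is materialized up front; filling the padded
        // buffer with the lowest representable value is the identity for max.
        for (size_t i = 0; i < padding_below.size(); ++i)
        {
            shape_to_pool[i + 2] += padding_below[i] + padding_above[i];
        }
        pad_explicitly(shape_to_pool, std::numeric_limits<float>::lowest());

        // The pooling descriptor itself must then see zero padding.
        std::fill(padding_below.begin(), padding_below.end(), 0);
        std::fill(padding_above.begin(), padding_above.end(), 0);
    }
    cudnn_max_pool(padding_below, padding_above);
}

int main()
{
    // Shapes from the new test in this commit: 2x2x4x4 input, below {0,0}, above {1,1};
    // the spatial dims grow from 4x4 to 5x5 before pooling.
    max_pool_forward({2, 2, 4, 4}, {0, 0}, {1, 1});
    std::cout << "ok\n";
}
```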
@@ -712,38 +712,37 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
     /// asymmetric padding detection
     bool pad_required = false;
-    auto shape_to_pool =
-        runtime::gpu::get_padded_shape(input_shape, padding_below, padding_above, {});
-    if (shape_to_pool != input_shape)
-    {
-        pad_required = true;
-    }
-    pad_required = pad_required && (padding_below != padding_above);
+    auto input_shape_padded = input_shape;
+    size_t padded_size;
     // asymmetric padding
     size_t idx_workspace = std::numeric_limits<size_t>::max();
     size_t pad_index = std::numeric_limits<size_t>::max();
-    if (pad_required)
+    if (padding_below != padding_above)
     {
-        auto temp_size = shape_size(shape_to_pool) * args[0].get_element_type().size();
+        Shape padding_interior(padding_below.size(), 1);
+        input_shape_padded =
+            runtime::gpu::get_padded_shape(input_shape, padding_below, padding_above, {});
+        padded_size = shape_size(input_shape_padded);
+        // currently we set this to floating point only; need to add other datatype support later
+        float pad_value = std::numeric_limits<float>::lowest();
+        std::vector<float> temp(padded_size, pad_value);
         GPUAllocator allocator = m_primitive_emitter->get_memory_allocator();
-        idx_workspace = allocator.reserve_workspace(temp_size);
-        auto pad_value = TypeInfo::Get(args[0].get_element_type())->lowest();
+        idx_workspace = allocator.reserve_argspace(temp.data(),
+                                                   padded_size * args[0].get_element_type().size());
         auto& cuda_emitter = m_primitive_emitter->get_cuda_emitter();
-        pad_index = cuda_emitter->build_pad({{input_type, output_type}},
-                                            input_shape,
-                                            shape_to_pool,
-                                            padding_below,
-                                            padding_above,
-                                            Shape{},
-                                            pad_value);
+        pad_index = cuda_emitter->build_pad_dynamic({{input_type, output_type}},
+                                                    input_shape,
+                                                    input_shape_padded,
+                                                    padding_below,
+                                                    padding_interior);
         // asymmetric padding has been applied, zero out padding vectors to
         // ensure cuDNN does not assume padding during pooling
         std::fill(padding_below.begin(), padding_below.end(), 0);
         std::fill(padding_above.begin(), padding_above.end(), 0);
+        pad_required = true;
     }
     /// end asymmetric padding detection
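One detail in the hunk above that is easy to miss: the padded buffer is pre-filled with std::numeric_limits<float>::lowest() rather than zero. Zero-filled padding would silently corrupt results whenever every in-bounds value under a window is negative, because the pad value would win the max. A tiny standalone check of that point, using assumed example values rather than anything from the test suite:

```cpp
#include <algorithm>
#include <cassert>
#include <limits>

int main()
{
    // Suppose a pooling window hangs over the padded region and the only
    // in-bounds value it covers is negative.
    float in_bounds = -4.0f;

    float max_with_zero_pad = std::max(in_bounds, 0.0f);
    float max_with_lowest_pad = std::max(in_bounds, std::numeric_limits<float>::lowest());

    assert(max_with_zero_pad == 0.0f);    // wrong: the pad value wins the max
    assert(max_with_lowest_pad == -4.0f); // correct: lowest() never wins
    return 0;
}
```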
@@ -751,7 +750,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
     size_t max_pool_index = build_pooling(CUDNN_POOLING_MAX,
                                           output_type,
                                           CUDNNEmitter::Prop::Forward,
-                                          shape_to_pool,
+                                          input_shape_padded,
                                           result_shape,
                                           node->get_window_movement_strides(),
                                           node->get_window_shape(),
@@ -760,18 +759,8 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
     std::unique_ptr<gpu::primitive> kernel_launch(
         new gpu::primitive{[=](void** inputs, void** outputs) mutable {
-            if (idx_workspace != std::numeric_limits<size_t>::max() &&
-                pad_index != std::numeric_limits<size_t>::max())
+            if (pad_required)
             {
-                // void* pad_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, idx_workspace);
-                // gpu::invoke_primitive(m_ctx,
-                //                       pad_dynamic_index,
-                //                       std::vector<void*>{inputs[0]}.data(),
-                //                       std::vector<void*>{pad_buffer}.data());
-                // gpu::invoke_primitive(
-                //     m_ctx, conv_index, std::vector<void*>{pad_buffer, inputs[1]}.data(), outputs);
                 void* pad_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, idx_workspace);
                 gpu::invoke_primitive(m_ctx,
                                       pad_index,
......
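At run time the captured lambda picks between the padded and unpadded paths. The tail of the lambda is collapsed in this view, so the sketch below illustrates only the general pad-then-pool dispatch it implements; get_workspace, invoke_pad, and invoke_pool are hypothetical stand-ins for the invoke_memory_primitive / invoke_primitive calls visible in the diff, not the elided code itself.

```cpp
// Sketch only: a generic "pad first if required, then pool" dispatch.
#include <cstddef>
#include <vector>

static void* get_workspace(size_t) { return nullptr; }        // stub: fetch padded buffer
static void invoke_pad(size_t, void* const*, void* const*) {} // stub: pad primitive
static void invoke_pool(size_t, void* const*, void* const*) {}// stub: pooling primitive

void launch(bool pad_required,
            size_t idx_workspace,
            size_t pad_index,
            size_t pool_index,
            void** inputs,
            void** outputs)
{
    if (pad_required)
    {
        // Copy the input into the pre-filled padded buffer, then pool that buffer.
        void* pad_buffer = get_workspace(idx_workspace);
        std::vector<void*> pad_in{inputs[0]};
        std::vector<void*> pad_out{pad_buffer};
        invoke_pad(pad_index, pad_in.data(), pad_out.data());
        invoke_pool(pool_index, pad_out.data(), outputs);
    }
    else
    {
        // Symmetric (or no) padding: cuDNN handles it directly.
        invoke_pool(pool_index, inputs, outputs);
    }
}
```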
@@ -5207,6 +5207,65 @@ NGRAPH_TEST(${BACKEND_NAME}, max_pool_2d_2channel_2image)
               read_vector<float>(result));
 }
 
+// this test covers the case with multiple images and asymmetric padding;
+// it exercises a bug that was found on the GPU side
+NGRAPH_TEST(${BACKEND_NAME}, max_pool_2d_2channel_2image_asym_pad)
+{
+    Shape shape_a{2, 2, 4, 4};
+    Shape window_shape{3, 3};
+    auto window_movement_strides = Strides{2, 2};
+    Shape padding_below{0, 0};
+    Shape padding_above{1, 1};
+    auto A = make_shared<op::Parameter>(element::f32, shape_a);
+    Shape shape_r{2, 2, 2, 2};
+    auto f = make_shared<Function>(
+        make_shared<op::MaxPool>(
+            A, window_shape, window_movement_strides, padding_below, padding_above),
+        op::ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::f32, shape_a);
+    copy_data(a,
+              test::NDArray<float, 4>({{{{0, 1, 0, 2}, // img 0 chan 0
+                                         {0, 3, 2, 0},
+                                         {2, 0, 0, 0},
+                                         {0, 2, 1, 0}},
+                                        {{0, 0, 0, 2}, // img 0 chan 1
+                                         {0, 2, 3, 0},
+                                         {2, 0, 1, 0},
+                                         {2, 0, 0, 0}}},
+                                       {{{0, 2, 1, 1}, // img 1 chan 0
+                                         {0, 0, 2, 0},
+                                         {0, 0, 1, 2},
+                                         {0, 0, 0, 0}},
+                                        {{2, 1, 0, 0}, // img 1 chan 1
+                                         {0, 2, 0, 0},
+                                         {1, 1, 2, 0},
+                                         {1, 0, 0, 0}}}})
+                  .get_vector());
+    auto result = backend->create_tensor(element::f32, shape_r);
+
+    backend->call_with_validate(f, {result}, {a});
+    EXPECT_EQ((test::NDArray<float, 4>({{{{3, 2}, // img 0 chan 0
+                                          {2, 1}},
+                                         {{3, 3}, // img 0 chan 1
+                                          {2, 1}}},
+                                        {{{2, 2}, // img 1 chan 0
+                                          {1, 2}},
+                                         {{2, 2}, // img 1 chan 1
+                                          {2, 2}}}})
+                   .get_vector()),
+              read_vector<float>(result));
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded)
 {
     Shape shape_a{1, 1, 5, 5};
......
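For reference, the expected values in max_pool_2d_2channel_2image_asym_pad can be reproduced by hand. The minimal host-side check below (not part of the commit; names are illustrative) recomputes image 0, channel 0 with a 3x3 window, stride 2, and padding of {0, 0} below and {1, 1} above, and arrives at the same {{3, 2}, {2, 1}} block the test expects.

```cpp
// Standalone reference check for img 0, chan 0 of the test above.
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

int main()
{
    // img 0, chan 0 from the test input (4x4), row-major
    std::vector<float> in = {0, 1, 0, 2,
                             0, 3, 2, 0,
                             2, 0, 0, 0,
                             0, 2, 1, 0};
    const int H = 4, W = 4;
    const int win = 3, stride = 2;
    const int pad_below = 0, pad_above = 1; // asymmetric: padding only at the high end

    // Output spatial size: floor((H + below + above - win) / stride) + 1 = 2
    const int HO = (H + pad_below + pad_above - win) / stride + 1;
    const int WO = (W + pad_below + pad_above - win) / stride + 1;

    std::vector<float> out(HO * WO, std::numeric_limits<float>::lowest());
    for (int ho = 0; ho < HO; ++ho)
        for (int wo = 0; wo < WO; ++wo)
            for (int kh = 0; kh < win; ++kh)
                for (int kw = 0; kw < win; ++kw)
                {
                    int h = ho * stride + kh - pad_below;
                    int w = wo * stride + kw - pad_below;
                    if (h < 0 || h >= H || w < 0 || w >= W)
                        continue; // padded taps contribute nothing to the max
                    out[ho * WO + wo] = std::max(out[ho * WO + wo], in[h * W + w]);
                }

    // Matches the first 2x2 block expected by the test: {{3, 2}, {2, 1}}
    assert((out == std::vector<float>{3.0f, 2.0f, 2.0f, 1.0f}));
    return 0;
}
```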