Commit b1b3d4d6 authored by Chris Sullivan, committed by Robert Kimball

CUDNN and CUDA kernels for AvgPool (forward/backward) (#951)

* Added op::AvgPool cudnn impl. which works for 2-3 spatial dimensions and no/symmetric padding. Enabled tests.

* Added a CUDA C implementation of average pooling which handles 1-3 spatial
dimensions as well as asymmetric padding. This commit also introduces
several helper functions for performing fast integer division and
fast constant memory access (see the sketch below).

* Formatting. Removed a bool that was used during testing to force the CUDA implementation over cuDNN.

* Added CUDNN AvgPoolBackprop implementation.

* Removed inline enum in favor of a helper struct. Removed instances of multiple declarations on a single line. Updated comments.

* Removed the underscore prefix from helper functions in the anonymous namespace.
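
The CUDA C kernel bodies themselves sit in parts of this diff that are not expanded here, so the following is only a minimal sketch of the fast-integer-division idea referenced above; the helper names (fast_div, decompose_nchw) are hypothetical and not part of this commit. Division by a fixed divisor is replaced by a multiply with a precomputed magic constant followed by a shift, which lets a kernel decompose a flat output index into pooling coordinates without '/' or '%'.

#include <cstdint>

// Hypothetical helper: computes n / d for n <= nmax, given the
// (magic, shift) pair produced for that (nmax, d) by idiv_magic_u32.
inline uint32_t fast_div(uint32_t n, uint64_t magic, uint64_t shift)
{
    return static_cast<uint32_t>((static_cast<uint64_t>(n) * magic) >> shift);
}

// Hypothetical helper: decompose a flat NCHW index into (n, c, h, w)
// using only multiplies, shifts, and subtractions.
inline void decompose_nchw(uint32_t idx,
                           uint32_t C, uint32_t H, uint32_t W,
                           uint64_t magic_w, uint64_t shift_w,
                           uint64_t magic_h, uint64_t shift_h,
                           uint64_t magic_c, uint64_t shift_c,
                           uint32_t& n, uint32_t& c, uint32_t& h, uint32_t& w)
{
    uint32_t t = fast_div(idx, magic_w, shift_w); // idx / W
    w = idx - t * W;                              // idx % W
    uint32_t u = fast_div(t, magic_h, shift_h);   // t / H
    h = t - u * H;                                // t % H
    n = fast_div(u, magic_c, shift_c);            // u / C
    c = u - n * C;                                // u % C
}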
parent 9e6d67f2
......@@ -51,11 +51,21 @@ namespace ngraph
size_t window_width,
size_t window_stride);
size_t build_avg_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
const Shape& window_shape,
const Shape& window_stride,
const Shape& padding_below,
bool include_pad = false);
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape);
std::string include_helpers();
GPUPrimitiveEmitter* m_primitive_emitter;
};
......
......@@ -1656,6 +1656,125 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return padded_shape;
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPool)
{
// assumes NC{d1,d2,...} format
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
writer.block_begin(" // " + node->get_name());
{
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
auto padding_below = avg_pool->get_padding_below();
auto padding_above = avg_pool->get_padding_above();
int num_nontrivial_dims = 0;
for (int64_t i = input_shape.size() - 1; i > 1; i--)
{
if (input_shape[i] > 1)
{
num_nontrivial_dims++;
}
}
size_t avg_pool_index = 0;
// if 1d or has asymmetric padding, must handle pooling manually
if (input_shape.size() == 3 || num_nontrivial_dims == 1 ||
padding_below != padding_above)
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
avg_pool_index =
cuda_emitter->build_avg_pool(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
result_shape,
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
padding_below);
}
else if (input_shape.size() <= 5)
{
// 2d and 3d avg pool (NCHW) with either symmetric padding or no padding
if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto cudnn_avg_type = avg_pool->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
avg_pool_index = cudnn_emitter->build_pooling(
external_function->ctx().get(),
cudnn_avg_type,
CUDNNEmitter::Prop::Forward,
input_shape,
result_shape,
avg_pool->get_window_movement_strides(),
avg_pool->get_window_shape(),
padding_below,
padding_above);
}
}
else
{
throw std::runtime_error(
"Pooling currently only supports up to 3 spatial dimensions.");
}
writer << "gpu::invoke_primitive(ctx, " << avg_pool_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPoolBackprop)
{
writer.block_begin(" // " + node->get_name());
{
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto output_shape = out[0].get_shape();
auto delta_shape = args[0].get_shape();
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
if (output_shape.size() >= 4)
{
auto cudnn_avg_type = apb->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
auto avg_pool_bp_index =
cudnn_emitter->build_pooling(external_function->ctx().get(),
cudnn_avg_type,
CUDNNEmitter::Prop::Backward,
output_shape,
delta_shape,
apb->get_window_movement_strides(),
apb->get_window_shape(),
apb->get_padding_below(),
apb->get_padding_above());
writer << "gpu::invoke_primitive(ctx, " << avg_pool_bp_index << ", ";
// CUDNN backwards pooling requests input and output tensors from
// the forward pass but does not use them. It also behaves differently
// for max pool vs avg pool. The repetition of args below is to address
// this interface in a way that supports both max and avg pooling
writer << "std::vector<void*>{" << args[0].get_name() << ", "
<< args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
}
writer.block_end();
}
}
}
}
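
For reference, the writer calls above emit an invocation of the compiled primitive at runtime. Assuming a forward node named AvgPool_3 with input arg0, output out0, and primitive index 7, and a backprop node AvgPoolBackprop_4 with delta tensor delta and primitive index 8 (all names and indices hypothetical), the generated code looks roughly like:

// AvgPool_3 (forward): one input tensor, one output tensor
gpu::invoke_primitive(ctx, 7,
                      std::vector<void*>{arg0}.data(),
                      std::vector<void*>{out0}.data());

// AvgPoolBackprop_4: the delta tensor is passed twice to satisfy the
// cuDNN backward-pooling interface described in the comment above
gpu::invoke_primitive(ctx, 8,
                      std::vector<void*>{delta, delta}.data(),
                      std::vector<void*>{out0}.data());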
......@@ -77,3 +77,101 @@ void runtime::gpu::cuda_memset(void* dst, int value, size_t buffer_size)
{
cudaMemset(dst, value, buffer_size);
}
namespace
{
uint64_t powU64(uint64_t base, uint64_t exp)
{
uint64_t result = 1;
do
{
if (exp & 1)
{
result *= base;
}
exp >>= 1;
if (!exp)
{
break;
}
base *= base;
} while (true);
return result;
}
uint32_t msbDeBruijnU32(uint32_t v)
{
static const int multiply_de_Bruijn_bit_position[32] = {
0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31};
v |= v >> 1; // first round down to one less than a power of 2
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
return multiply_de_Bruijn_bit_position[(uint32_t)(v * 0x07C4ACDDU) >> 27];
}
int msbU64(uint64_t val)
{
if (val > 0x00000000FFFFFFFFul)
{
return 32 + msbDeBruijnU32(static_cast<uint32_t>(val >> 32));
}
// Number is no more than 32 bits,
// so calculate number of bits in the bottom half.
return msbDeBruijnU32(static_cast<uint32_t>(val & 0xFFFFFFFF));
}
// Magic numbers and shift amounts for integer division
// Suitable for when nmax*magic fits in 32 bits
// Translated from http://www.hackersdelight.org/hdcodetxt/magicgu.py.txt
std::pair<uint64_t, uint64_t> magicU32(uint64_t nmax, uint64_t d)
{
uint64_t nc = ((nmax + 1) / d) * d - 1;
uint64_t nbits = msbU64(nmax) + 1;
for (uint64_t p = 0; p < 2 * nbits + 1; p++)
{
uint64_t pow2 = powU64(2, p);
if (pow2 > nc * (d - 1 - (pow2 - 1) % d))
{
uint64_t m = (pow2 + d - 1 - (pow2 - 1) % d) / d;
return std::pair<uint64_t, uint64_t>{m, p};
}
}
throw std::runtime_error("Magic for unsigned integer division could not be found.");
}
// Magic numbers and shift amounts for integer division
// Suitable for when nmax*magic fits in 64 bits and the shift
// lops off the lower 32 bits
std::pair<uint64_t, uint64_t> magicU64(uint64_t d)
{
// 3 is a special case whose magic only ends up in the high bits
// when nmax is 0xffffffff. We can't use 0xffffffff for all divisors,
// as some would return a 33-bit magic number.
uint64_t nmax = (d == 3) ? 0xffffffff : 0x7fffffff;
uint64_t magic, shift;
std::tie(magic, shift) = magicU32(nmax, d);
if (magic != 1)
{
shift -= 32;
}
return std::pair<uint64_t, uint64_t>{magic, shift};
}
}
std::pair<uint64_t, uint64_t> runtime::gpu::idiv_magic_u32(uint64_t max_numerator, uint64_t divisor)
{
return magicU32(max_numerator, divisor);
}
std::pair<uint64_t, uint64_t> runtime::gpu::idiv_magic_u64(uint64_t divisor)
{
return magicU64(divisor);
}
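
As a quick sanity check of the contract these helpers provide (a host-side sketch under the standard Hacker's Delight usage, not code from this commit): for a pair (magic, shift) returned by magicU32(nmax, d), the quotient n / d should equal (n * magic) >> shift for every n <= nmax, provided n * magic does not overflow the 64-bit multiply.

#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical verification of the (magic, shift) contract exposed via
// runtime::gpu::idiv_magic_u32, assuming the multiply-and-shift usage
// from Hacker's Delight.
void check_magic_u32(uint64_t nmax, uint64_t d, std::pair<uint64_t, uint64_t> ms)
{
    for (uint64_t n = 0; n <= nmax; ++n)
    {
        uint64_t quotient = (n * ms.first) >> ms.second;
        assert(quotient == n / d);
    }
}

// Example (hypothetical): check_magic_u32(1023, 7, runtime::gpu::idiv_magic_u32(1023, 7));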
......@@ -19,7 +19,9 @@
#include <memory>
#include <sstream>
#include <stdexcept>
#include <stdint.h>
#include <string>
#include <tuple>
#include <vector>
#include <cublas_v2.h>
......@@ -97,6 +99,8 @@ namespace ngraph
void cuda_memcpyHtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyDtH(void* dst, void* src, size_t buffer_size);
void cuda_memset(void* dst, int value, size_t buffer_size);
std::pair<uint64_t, uint64_t> idiv_magic_u32(uint64_t max_numerator, uint64_t divisor);
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
}
}
}
......@@ -110,7 +110,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2_c1_hw5_3x3_str2_max)
TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw2x2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape padding{1, 1};
......@@ -147,7 +146,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw2x2)
TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw4x4)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{1, 1, 4, 4};
......@@ -181,7 +179,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw4x4)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
......@@ -281,7 +278,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......@@ -304,7 +300,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_numeric)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......@@ -327,7 +322,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw2x2_win_2x2_str_1x1_padding_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......