Commit b1b3d4d6 authored by Chris Sullivan, committed by Robert Kimball

CUDNN and CUDA kernels for AvgPool (forward/backward) (#951)

* Added an op::AvgPool cuDNN implementation that works for 2-3 spatial dimensions and no/symmetric padding. Enabled tests.

* Added a CUDA C implementation of average pooling which handles 1-3 spatial
dimensions as well as asymmetric padding. This commit also introduces several
helper functions for fast integer division and fast constant memory access.
(A short host-side reference sketch of the pooling semantics appears after the commit metadata below.)

* Formatting. Removed a bool that was used during testing to force the CUDA implementation over cuDNN.

* Added CUDNN AvgPoolBackprop implementation.

* Replaced an inline enum with a helper struct. Removed instances of multiple declarations on a single line. Updated comments.

* Removed the underscore prefix from helper functions in the anonymous namespace.
parent 9e6d67f2
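
For orientation only (not part of the commit): a minimal host-side sketch of the average-pooling semantics these kernels implement, including the include-pad distinction. The function avg_pool_1d and its small main harness are hypothetical names introduced here for illustration; the example input is arbitrary and not taken from the test suite.

#include <vector>

// Reference 1D average pool for a single channel: window size `win`,
// stride `stride`, and `pad` zeros of virtual padding on each side.
// With include_pad the divisor is always the window size; otherwise only
// in-bounds elements are counted, mirroring the two
// CUDNN_POOLING_AVERAGE_COUNT_* modes selected in the emitter below.
std::vector<float> avg_pool_1d(
    const std::vector<float>& in, int win, int stride, int pad, bool include_pad)
{
    int n = static_cast<int>(in.size());
    int out_len = (n + 2 * pad - win) / stride + 1;
    std::vector<float> out(out_len, 0.0f);
    for (int q = 0; q < out_len; q++)
    {
        float sum = 0.0f;
        int count = 0;
        for (int s = 0; s < win; s++)
        {
            int x = q * stride - pad + s;
            if (x >= 0 && x < n)
            {
                sum += in[x];
                count++;
            }
        }
        // guard against a window that lies entirely in the padding
        out[q] = include_pad ? sum / win : (count ? sum / count : 0.0f);
    }
    return out;
}

int main()
{
    // arbitrary example input (not from the test suite)
    std::vector<float> in{1, 2, 3, 4};
    std::vector<float> padded = avg_pool_1d(in, 2, 1, 1, false);
    // padded.front() == 1.0f: the first window covers one pad element and in[0],
    // and the pad element is excluded from the divisor
    return padded.front() == 1.0f ? 0 : 1;
}
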
......@@ -16,17 +16,51 @@
#include <algorithm>
#include <iostream>
#include <limits>
#include <ostream>
#include <vector>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
struct pooling_op_shape
{
// input tensor dimensions (NCDHW)
int N;
int C;
int D;
int H;
int W;
// output tensor dimensions (NKMPQ); K = C since feature-map pooling is not supported
int K;
int M;
int P;
int Q;
// pooling window dimensions (JTRS); J is currently unused
int J;
int T;
int R;
int S;
// window strides and below-padding per spatial dimension
int STRIDE_D;
int STRIDE_H;
int STRIDE_W;
int PAD_D;
int PAD_H;
int PAD_W;
};
std::ostream& operator<<(std::ostream& os, pooling_op_shape& shape)
{
return os << shape.N << "_" << shape.C << "_" << shape.D << "_" << shape.H << "_" << shape.W
<< "_" << shape.K << "_" << shape.M << "_" << shape.P << "_" << shape.Q << "_"
<< shape.J << "_" << shape.T << "_" << shape.R << "_" << shape.S << "_"
<< shape.STRIDE_D << "_" << shape.STRIDE_H << "_" << shape.STRIDE_W << "_"
<< shape.PAD_D << "_" << shape.PAD_H << "_" << shape.PAD_W;
}
runtime::gpu::CUDAEmitter::CUDAEmitter(runtime::gpu::GPUPrimitiveEmitter* emitter)
: m_primitive_emitter(emitter)
{
......@@ -280,6 +314,295 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
return primitive_index;
}
pooling_op_shape avgpool_shape(
const Shape& in, const Shape& out, const Shape& window, const Shape& strides, const Shape& pad)
{
pooling_op_shape shape;
shape.N = static_cast<int>(in[0]);
shape.C = static_cast<int>(in[1]);
shape.K = shape.C; // pooling feature maps is
shape.J = shape.C; // not currently supported
if (in.size() == 3)
{
shape.D = 1;
shape.H = 1;
shape.W = static_cast<int>(in[2]);
shape.M = 1;
shape.P = 1;
shape.Q = static_cast<int>(out[2]);
shape.T = 1;
shape.R = 1;
shape.S = static_cast<int>(window[0]);
shape.STRIDE_D = 0;
shape.STRIDE_H = 0;
shape.STRIDE_W = static_cast<int>(strides[0]);
shape.PAD_D = 0;
shape.PAD_H = 0;
shape.PAD_W = static_cast<int>(pad[0]);
}
else if (in.size() == 4)
{
shape.D = 1;
shape.H = static_cast<int>(in[2]);
shape.W = static_cast<int>(in[3]);
shape.M = 1;
shape.P = static_cast<int>(out[2]);
shape.Q = static_cast<int>(out[3]);
shape.T = 1;
shape.R = static_cast<int>(window[0]);
shape.S = static_cast<int>(window[1]);
shape.STRIDE_D = 0;
shape.STRIDE_H = static_cast<int>(strides[0]);
shape.STRIDE_W = static_cast<int>(strides[1]);
shape.PAD_D = 0;
shape.PAD_H = static_cast<int>(pad[0]);
shape.PAD_W = static_cast<int>(pad[1]);
}
else if (in.size() == 5)
{
shape.D = static_cast<int>(in[2]);
shape.H = static_cast<int>(in[3]);
shape.W = static_cast<int>(in[4]);
shape.M = static_cast<int>(out[2]);
shape.P = static_cast<int>(out[3]);
shape.Q = static_cast<int>(out[4]);
shape.T = static_cast<int>(window[0]);
shape.R = static_cast<int>(window[1]);
shape.S = static_cast<int>(window[2]);
shape.STRIDE_D = static_cast<int>(strides[0]);
shape.STRIDE_H = static_cast<int>(strides[1]);
shape.STRIDE_W = static_cast<int>(strides[2]);
shape.PAD_D = static_cast<int>(pad[0]);
shape.PAD_H = static_cast<int>(pad[1]);
shape.PAD_W = static_cast<int>(pad[2]);
}
else
{
throw std::runtime_error("AvgPool currently supports up to 3 spatial dimensions.");
}
return shape;
}
size_t runtime::gpu::CUDAEmitter::build_avg_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
const Shape& window_shape,
const Shape& window_stride,
const Shape& padding_below,
bool include_pad)
{
// assumes NCDHW format
pooling_op_shape shape =
avgpool_shape(input_shape, output_shape, window_shape, window_stride, padding_below);
std::string kernel_name = "avgpool";
std::stringstream ss;
ss << kernel_name << "_s" << shape << "_st" << join(window_stride, "_") << "_ip"
<< int(include_pad);
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// if the kernel has not been compiled, build it
kernel_name += "_ip" + std::to_string(int(include_pad));
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
// In the pooling operation out = P(in) where in: NCDHW -> out: NKMPQ
// via pooling window: JTRS. Currently feature pooling
// is not supported and so K = C and J is unused
writer << "extern \"C\" __global__ void cuda_" << kernel_name << "(" << dtypes[0]
<< "* in, " << dtypes[1] << "* out, "
<< "float alpha, float beta, "
<< "int N, int C, int D, int H, int W, "
<< "int HW, int DHW, int CDHW, int magic_N, int shift_N, "
<< "int P, int Q, int magic_P, int shift_P, "
<< "int PQ, int MPQ, int KMPQ, "
<< "int S, int RS, int TRS, "
<< "int magic_S, int shift_S, int magic_RS, int shift_RS, "
<< "int str_d, int str_h, int str_w, "
<< "int pad_d, int pad_h, int pad_w"
<< ")\n";
writer.block_begin();
{
writer << "const int tid = threadIdx.x;\n";
writer << "if (tid < 32)\n";
writer.block_begin();
{
writer << "const int q = blockIdx.x;\n";
writer << "const int mp = blockIdx.y;\n";
writer << "const int nk = blockIdx.z;\n";
writer << "const int k = div64(nk, magic_N, shift_N);\n";
writer << "const int n = nk - k * N;\n";
writer << "const int m = div64(mp, magic_P, shift_P);\n";
writer << "const int p = mp - m * P;\n";
writer << "out += n*KMPQ + k*MPQ + m*PQ + mad16(p, Q, q);\n";
// coordinate transform factors from MPQ to DHW
writer << "int qs = q * str_w - pad_w;\n";
writer << "int pr = p * str_h - pad_h;\n";
writer << "int mt = m * str_d - pad_d;\n";
writer << "int pool_size = ";
auto pool_size = include_pad ? "TRS" : "0";
writer << pool_size << ";\n";
writer << "float sum = 0.0f;\n";
writer << "float rcp_pool_size = 1.0f;\n";
// each warp operates on a single pooling window and
// reduces the contents of the window within the warp
writer << "for (int trs = tid; trs < TRS; trs += 32)\n";
writer.block_begin();
{
writer << "int t = div64(trs, magic_RS, shift_RS);\n";
writer << "int rs = mod16(trs, t, RS);\n";
writer << "int r = div64(rs, magic_S, shift_S);\n";
writer << "int s = mod16(rs, r, S);\n";
// coordinate transformation from TRS to DHW
// via MPQ transform factors above
writer << "int x = qs + s;\n";
writer << "int y = pr + r;\n";
writer << "int z = mt + t;\n";
// helper to check participating threads
writer << "bool bounds_x = (x >= 0) && (x < W);\n";
writer << "bool bounds_y = (y >= 0) && (y < H);\n";
writer << "bool bounds_z = (z >= 0) && (z < D);\n";
writer << "bool within_tensor_bounds = bounds_x && bounds_y && bounds_z;\n";
if (include_pad == false)
{
// count the number of (non-padded) elements
writer << "pool_size += __popc(__ballot(within_tensor_bounds));\n";
}
// this will need to change to k->c once
// feature pooling support is added
writer << "int idx = n*CDHW + k*DHW + z*HW + y*W + x;\n";
writer << "sum += load(in,idx,within_tensor_bounds);\n";
}
writer.block_end();
writer << "rcp_pool_size = 1.0f / (float)pool_size;\n";
// reduce pooling window within warp.
// this could be improved by calculating the
// pooling windows each thread can partake in to
// reduce loads and increase coalescing. in that case,
// multiple warps per block would be required and the
// warp reduced sums would need to be accumulated in
// shared memory
writer << "for (int i = 16; i > 0; i >>= 1)\n";
writer.block_begin();
{
writer << "sum += __shfl_xor(sum,i);\n";
}
writer.block_end();
// write result to output
writer << "if (tid == 0)\n";
writer.block_begin();
{
writer << "*out = sum * rcp_pool_size;\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
// precompute for fast constant memory access
int HW = shape.H * shape.W;
int DHW = shape.D * HW;
int CDHW = shape.C * DHW;
int PQ = shape.P * shape.Q;
int MPQ = shape.M * PQ;
int KMPQ = shape.K * MPQ;
int RS = shape.R * shape.S;
int TRS = shape.T * RS;
// precompute magic numbers and shifts for fast integer division
int magic_N;
int shift_N;
std::tie(magic_N, shift_N) = idiv_magic_u64(shape.N);
int magic_P;
int shift_P;
std::tie(magic_P, shift_P) = idiv_magic_u64(shape.P);
int magic_S;
int shift_S;
std::tie(magic_S, shift_S) = idiv_magic_u64(shape.S);
int magic_RS;
int shift_RS;
std::tie(magic_RS, shift_RS) = idiv_magic_u64(RS);
// TODO: blending factors are not currently implemented
float alpha = 1.0f;
float beta = 0.0f;
std::unique_ptr<gpu::primitive> pool(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* args_list[] = {&inputs[0],
&outputs[0],
&alpha,
&beta,
&shape.N,
&shape.C,
&shape.D,
&shape.H,
&shape.W,
&HW,
&DHW,
&CDHW,
&magic_N,
&shift_N,
&shape.P,
&shape.Q,
&magic_P,
&shift_P,
&PQ,
&MPQ,
&KMPQ,
&shape.S,
&RS,
&TRS,
&magic_S,
&shift_S,
&magic_RS,
&shift_RS,
&shape.STRIDE_D,
&shape.STRIDE_H,
&shape.STRIDE_W,
&shape.PAD_D,
&shape.PAD_H,
&shape.PAD_W};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
shape.Q,
shape.M * shape.P,
shape.N * shape.K,
32,
1,
1,
0,
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
void runtime::gpu::CUDAEmitter::print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape)
......@@ -314,3 +637,57 @@ void runtime::gpu::CUDAEmitter::print_tensor_from_gpu(codegen::CodeWriter& write
}
writer.block_end();
}
std::string runtime::gpu::CUDAEmitter::include_helpers()
{
// div64: fast integer division via magic multiplication and shifting
// if value is a power of 2, magic will be 1 and only shifting
// is required (predicate p in div64)
// load: helper to load from constant memory for fast access
std::stringstream ss;
ss << R"(
__device__ __forceinline__ int div64(int value, int magic, int shift)
{
int result;
asm("{\n\t"
".reg .pred p;\n\t"
".reg .u64 res64;\n\t"
".reg .u32 lo32, hi32;\n\t"
"setp.ne.s32 p, %2, 1;\n\t"
"mul.wide.u32 res64, %1, %2;\n\t"
"mov.b64 {lo32, hi32}, res64;\n\t"
"selp.u32 hi32, hi32, %1, p;\n\t"
"shr.u32 %0, hi32, %3;\n\t"
"}" : "=r"(result) : "r"(value), "r"(magic), "r"(shift));
return result;
}
__device__ __forceinline__ int mod16(int numerator, int div, int maxdiv)
{
int res;
asm("vmad.s32.u32.u32 %0, -%1.h0, %2.h0, %3;" : "=r"(res) : "r"(div), "r"(maxdiv), "r"(numerator));
return res;
}
__device__ __forceinline__ int mad16(int a, int b, int c)
{
int res;
asm("vmad.s32.u32.u32 %0, %1.h0, %2.h0, %3;" : "=r"(res) : "r"(a), "r"(b), "r"(c));
return res;
}
__device__ __forceinline__ int msub16(int a, int b, int c)
{
int res;
asm("vmad.s32.u32.u32 %0, %1.h0, %2.h0, -%3;" : "=r"(res) : "r"(a), "r"(b), "r"(c));
return res;
}
__device__ __forceinline__ float load(const float* __restrict__ in, int i=0, bool b=true)
{
float v = 0.0f;
if (b)
{
v = __ldg(in + i);
}
return v;
}
)";
return ss.str();
}
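
As a reading aid (not part of the commit), here is a host-side mirror of the arithmetic the div64 PTX above performs. The names fast_div and main are hypothetical; the (magic, shift) convention assumed is the one produced by idiv_magic_u64 below, which keeps the shift unchanged for power-of-two divisors (magic == 1) and otherwise subtracts 32 from it because the high word of the widening multiply is taken first. The PTX declares its operands as int but treats them as unsigned (mul.wide.u32 / shr.u32), so the mirror uses unsigned arithmetic.

#include <cassert>
#include <cstdint>

// Host equivalent of div64: widen-multiply by the magic constant, keep the
// high 32 bits of the product, then shift. A power-of-two divisor has
// magic == 1, in which case only the shift is applied to the value itself.
static int fast_div(int value, uint32_t magic, int shift)
{
    uint64_t product = static_cast<uint64_t>(static_cast<uint32_t>(value)) * magic;
    uint32_t hi = (magic != 1) ? static_cast<uint32_t>(product >> 32)
                               : static_cast<uint32_t>(value);
    return static_cast<int>(hi >> shift);
}

int main()
{
    assert(fast_div(100, 1, 3) == 100 / 8);           // divide by 8: magic 1, shift 3
    assert(fast_div(100, 0xAAAAAAABu, 1) == 100 / 3); // classic divide-by-3 pair
    return 0;
}
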
......@@ -51,11 +51,21 @@ namespace ngraph
size_t window_width,
size_t window_stride);
size_t build_avg_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
const Shape& window_shape,
const Shape& window_stride,
const Shape& padding_below,
bool include_pad = false);
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape);
std::string include_helpers();
GPUPrimitiveEmitter* m_primitive_emitter;
};
......
......@@ -1656,6 +1656,125 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return padded_shape;
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPool)
{
// assumes NC{d1,d2,...} format
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
writer.block_begin(" // " + node->get_name());
{
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
auto padding_below = avg_pool->get_padding_below();
auto padding_above = avg_pool->get_padding_above();
int num_nontrivial_dims = 0;
for (int64_t i = input_shape.size() - 1; i > 1; i--)
{
if (input_shape[i] > 1)
{
num_nontrivial_dims++;
}
}
size_t avg_pool_index = 0;
// if 1d or has asymmetric padding, must handle pooling manually
if (input_shape.size() == 3 || num_nontrivial_dims == 1 ||
padding_below != padding_above)
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
avg_pool_index =
cuda_emitter->build_avg_pool(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
result_shape,
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
padding_below);
}
else if (input_shape.size() <= 5)
{
// 2d and 3d avg pool (NCHW) with either symmetric padding or no padding
if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto cudnn_avg_type = avg_pool->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
avg_pool_index = cudnn_emitter->build_pooling(
external_function->ctx().get(),
cudnn_avg_type,
CUDNNEmitter::Prop::Forward,
input_shape,
result_shape,
avg_pool->get_window_movement_strides(),
avg_pool->get_window_shape(),
padding_below,
padding_above);
}
}
else
{
throw std::runtime_error(
"Pooling currently only supports up to 3 spatial dimensions.");
}
writer << "gpu::invoke_primitive(ctx, " << avg_pool_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPoolBackprop)
{
writer.block_begin(" // " + node->get_name());
{
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto output_shape = out[0].get_shape();
auto delta_shape = args[0].get_shape();
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
if (output_shape.size() >= 4)
{
auto cudnn_avg_type = apb->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
auto avg_pool_bp_index =
cudnn_emitter->build_pooling(external_function->ctx().get(),
cudnn_avg_type,
CUDNNEmitter::Prop::Backward,
output_shape,
delta_shape,
apb->get_window_movement_strides(),
apb->get_window_shape(),
apb->get_padding_below(),
apb->get_padding_above());
writer << "gpu::invoke_primitive(ctx, " << avg_pool_bp_index << ", ";
// CUDNN backwards pooling requests input and output tensors from
// the forward pass but does not use them. It also behaves differently
// for max pool vs avg pool. The repetition of args below is to address
// this interface in a way that supports both max and avg pooling.
writer << "std::vector<void*>{" << args[0].get_name() << ", "
<< args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
}
writer.block_end();
}
}
}
}
......@@ -77,3 +77,101 @@ void runtime::gpu::cuda_memset(void* dst, int value, size_t buffer_size)
{
cudaMemset(dst, value, buffer_size);
}
namespace
{
uint64_t powU64(uint64_t base, uint64_t exp)
{
uint64_t result = 1;
do
{
if (exp & 1)
{
result *= base;
}
exp >>= 1;
if (!exp)
{
break;
}
base *= base;
} while (true);
return result;
}
uint32_t msbDeBruijnU32(uint32_t v)
{
static const int multiply_de_Bruijn_bit_position[32] = {
0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31};
v |= v >> 1; // first round down to one less than a power of 2
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
return multiply_de_Bruijn_bit_position[(uint32_t)(v * 0x07C4ACDDU) >> 27];
}
int msbU64(uint64_t val)
{
if (val > 0x00000000FFFFFFFFul)
{
return 32 + msbDeBruijnU32(static_cast<uint32_t>(val >> 32));
}
// Number is no more than 32 bits,
// so calculate number of bits in the bottom half.
return msbDeBruijnU32(static_cast<uint32_t>(val & 0xFFFFFFFF));
}
// Magic numbers and shift amounts for integer division
// Suitable for when nmax*magic fits in 32 bits
// Translated from http://www.hackersdelight.org/hdcodetxt/magicgu.py.txt
std::pair<uint64_t, uint64_t> magicU32(uint64_t nmax, uint64_t d)
{
uint64_t nc = ((nmax + 1) / d) * d - 1;
uint64_t nbits = msbU64(nmax) + 1;
for (uint64_t p = 0; p < 2 * nbits + 1; p++)
{
uint64_t pow2 = powU64(2, p);
if (pow2 > nc * (d - 1 - (pow2 - 1) % d))
{
uint64_t m = (pow2 + d - 1 - (pow2 - 1) % d) / d;
return std::pair<uint64_t, uint64_t>{m, p};
}
}
throw std::runtime_error("Magic for unsigned integer division could not be found.");
}
// Magic numbers and shift amounts for integer division
// Suitable for when nmax*magic fits in 64 bits and the shift
// lops off the lower 32 bits
std::pair<uint64_t, uint64_t> magicU64(uint64_t d)
{
// 3 is a special case that only ends up in the high bits
// if nmax is 0xffffffff. We can't use 0xffffffff for all cases,
// as some divisors return a 33-bit magic number.
uint64_t nmax = (d == 3) ? 0xffffffff : 0x7fffffff;
uint64_t magic, shift;
std::tie(magic, shift) = magicU32(nmax, d);
if (magic != 1)
{
shift -= 32;
}
return std::pair<uint64_t, uint64_t>{magic, shift};
}
}
std::pair<uint64_t, uint64_t> runtime::gpu::idiv_magic_u32(uint64_t max_numerator, uint64_t divisor)
{
return magicU32(max_numerator, divisor);
}
std::pair<uint64_t, uint64_t> runtime::gpu::idiv_magic_u64(uint64_t divisor)
{
return magicU64(divisor);
}
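
A hypothetical spot-check (not part of the commit, and assuming the ngraph GPU runtime headers are on the include path) of the invariant these magic numbers provide: for (magic, shift) = idiv_magic_u32(nmax, d), as declared in gpu_util.hpp above, floor(n / d) == (n * magic) >> shift for every n in [0, nmax].

#include <cassert>
#include <cstdint>
#include <tuple>
#include "ngraph/runtime/gpu/gpu_util.hpp"

int main()
{
    const uint64_t nmax = 1000;
    for (uint64_t d = 1; d <= 16; d++)
    {
        uint64_t magic;
        uint64_t shift;
        std::tie(magic, shift) = ngraph::runtime::gpu::idiv_magic_u32(nmax, d);
        for (uint64_t n = 0; n <= nmax; n++)
        {
            // the magic multiply followed by a shift reproduces integer division
            assert(n / d == ((n * magic) >> shift));
        }
    }
    return 0;
}
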
......@@ -19,7 +19,9 @@
#include <memory>
#include <sstream>
#include <stdexcept>
#include <stdint.h>
#include <string>
#include <tuple>
#include <vector>
#include <cublas_v2.h>
......@@ -97,6 +99,8 @@ namespace ngraph
void cuda_memcpyHtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyDtH(void* dst, void* src, size_t buffer_size);
void cuda_memset(void* dst, int value, size_t buffer_size);
std::pair<uint64_t, uint64_t> idiv_magic_u32(uint64_t max_numerator, uint64_t divisor);
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
}
}
}
......@@ -110,7 +110,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2_c1_hw5_3x3_str2_max)
TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw2x2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape padding{1, 1};
......@@ -147,7 +146,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw2x2)
TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw4x4)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{1, 1, 4, 4};
......@@ -181,7 +179,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n1_c1_hw4x4)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
......@@ -281,7 +278,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......@@ -304,7 +300,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_numeric)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......@@ -327,7 +322,6 @@ TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric)
TEST(${BACKEND_NAME}, backwards_avgpool_n2_c2_hw2x2_win_2x2_str_1x1_padding_numeric)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape_a{2, 2, 4, 4};
test::Uniform<float> rng(1.0f, 10.0f);
......
......@@ -5875,7 +5875,6 @@ TEST(${BACKEND_NAME}, computation_reuse)
TEST(${BACKEND_NAME}, avg_pool_1d_1channel_1image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 1, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -5894,7 +5893,7 @@ TEST(${BACKEND_NAME}, avg_pool_1d_1channel_1image)
float denom = 3.0;
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 3>({{{1 / denom,
EXPECT_TRUE(test::all_close(test::NDArray<float, 3>({{{1 / denom,
3 / denom,
3 / denom,
3 / denom,
......@@ -5906,13 +5905,12 @@ TEST(${BACKEND_NAME}, avg_pool_1d_1channel_1image)
2 / denom,
2 / denom,
0 / denom}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_1d_1channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -5933,7 +5931,7 @@ TEST(${BACKEND_NAME}, avg_pool_1d_1channel_2image)
float denom = 3.0;
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 3>({{{1 / denom,
EXPECT_TRUE(test::all_close(test::NDArray<float, 3>({{{1 / denom,
3 / denom,
3 / denom,
3 / denom,
......@@ -5957,13 +5955,12 @@ TEST(${BACKEND_NAME}, avg_pool_1d_1channel_2image)
1 / denom,
1 / denom,
3 / denom}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_1d_2channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -5987,7 +5984,7 @@ TEST(${BACKEND_NAME}, avg_pool_1d_2channel_2image)
float denom = 3.0;
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 3>({{{1 / denom,
EXPECT_TRUE(test::all_close(test::NDArray<float, 3>({{{1 / denom,
3 / denom,
3 / denom,
3 / denom,
......@@ -6036,13 +6033,12 @@ TEST(${BACKEND_NAME}, avg_pool_1d_2channel_2image)
2 / denom,
4 / denom,
3 / denom}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2, 5, 5};
Shape window_shape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -6084,7 +6080,9 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image)
float denom = 2 * 3;
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{6 / denom, 8 / denom, 5 / denom}, // img 0 chan 0
EXPECT_TRUE(test::all_close(
test::NDArray<float, 4>({{{{6 / denom, 8 / denom, 5 / denom}, // img 0 chan 0
{7 / denom, 5 / denom, 3 / denom},
{5 / denom, 2 / denom, 5 / denom},
{6 / denom, 5 / denom, 5 / denom}},
......@@ -6103,13 +6101,12 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image)
{6 / denom, 5 / denom, 4 / denom},
{7 / denom, 5 / denom, 6 / denom},
{4 / denom, 2 / denom, 4 / denom}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_1channel_1image_strided)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 1, 8, 8};
Shape window_shape{2, 3};
auto window_movement_strides = Strides{3, 2};
......@@ -6137,16 +6134,15 @@ TEST(${BACKEND_NAME}, avg_pool_2d_1channel_1image_strided)
float denom = 2 * 3;
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{6 / denom, 5 / denom, 4 / denom},
EXPECT_TRUE(test::all_close(test::NDArray<float, 4>({{{{6 / denom, 5 / denom, 4 / denom},
{6 / denom, 5 / denom, 8 / denom},
{6 / denom, 2 / denom, 4 / denom}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_1channel_1image_padded)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 1, 3, 3};
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
......@@ -6167,17 +6163,17 @@ TEST(${BACKEND_NAME}, avg_pool_2d_1channel_1image_padded)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2, 0.0f / 1},
EXPECT_TRUE(
test::all_close(test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2, 0.0f / 1},
{0.0f / 2, 4.0f / 4, 6.0f / 4, 2.0f / 2},
{2.0f / 2, 5.0f / 4, 5.0f / 4, 2.0f / 2},
{2.0f / 1, 2.0f / 2, 0.0f / 2, 0.0f / 1}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
......@@ -6201,7 +6197,8 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2, 0.0f / 1},
EXPECT_TRUE(
test::all_close(test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2, 0.0f / 1},
{0.0f / 2, 4.0f / 4, 6.0f / 4, 2.0f / 2},
{2.0f / 2, 5.0f / 4, 5.0f / 4, 2.0f / 2},
{2.0f / 1, 2.0f / 2, 0.0f / 2, 0.0f / 1}},
......@@ -6209,13 +6206,12 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded)
{5.0f / 2, 10.0f / 4, 16.0f / 4, 11.0f / 2},
{5.0f / 2, 11.0f / 4, 20.0f / 4, 14.0f / 2},
{3.0f / 1, 9.0f / 2, 11.0f / 2, 5.0f / 1}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_only_below)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
......@@ -6239,19 +6235,18 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_only_below)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2},
EXPECT_TRUE(test::all_close(test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 2},
{0.0f / 2, 4.0f / 4, 6.0f / 4},
{2.0f / 2, 5.0f / 4, 5.0f / 4}},
{{3.0f / 1, 8.0f / 2, 7.0f / 2},
{5.0f / 2, 10.0f / 4, 16.0f / 4},
{5.0f / 2, 11.0f / 4, 20.0f / 4}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_only_above)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
......@@ -6275,19 +6270,18 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_only_above)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{4.0f / 4, 6.0f / 4, 2.0f / 2},
EXPECT_TRUE(test::all_close(test::NDArray<float, 4>({{{{4.0f / 4, 6.0f / 4, 2.0f / 2},
{5.0f / 4, 5.0f / 4, 2.0f / 2},
{2.0f / 2, 0.0f / 2, 0.0f / 1}},
{{10.0f / 4, 16.0f / 4, 11.0f / 2},
{11.0f / 4, 20.0f / 4, 14.0f / 2},
{9.0f / 2, 11.0f / 2, 5.0f / 1}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{3, 3};
auto window_movement_strides = Strides{1, 1};
......@@ -6311,7 +6305,8 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 3, 1.0f / 2, 0.0f / 1},
EXPECT_TRUE(test::all_close(
test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 2, 1.0f / 3, 1.0f / 2, 0.0f / 1},
{0.0f / 2, 4.0f / 4, 6.0f / 6, 6.0f / 4, 2.0f / 2},
{2.0f / 3, 6.0f / 6, 8.0f / 9, 6.0f / 6, 2.0f / 3},
{2.0f / 2, 5.0f / 4, 7.0f / 6, 5.0f / 4, 2.0f / 2},
......@@ -6321,13 +6316,12 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3)
{8.0f / 3, 19.0f / 6, 35.0f / 9, 27.0f / 6, 16.0f / 3},
{5.0f / 2, 11.0f / 4, 25.0f / 6, 20.0f / 4, 14.0f / 2},
{3.0f / 1, 9.0f / 2, 14.0f / 3, 11.0f / 2, 5.0f / 1}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{3, 3};
auto window_movement_strides = Strides{2, 2};
......@@ -6351,19 +6345,18 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 3, 0.0f / 1},
EXPECT_TRUE(test::all_close(test::NDArray<float, 4>({{{{0.0f / 1, 1.0f / 3, 0.0f / 1},
{2.0f / 3, 8.0f / 9, 2.0f / 3},
{2.0f / 1, 2.0f / 3, 0.0f / 1}},
{{3.0f / 1, 10.0f / 3, 2.0f / 1},
{8.0f / 3, 35.0f / 9, 16.0f / 3},
{3.0f / 1, 14.0f / 3, 5.0f / 1}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided_uneven)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 3, 3};
Shape window_shape{3, 3};
auto window_movement_strides = Strides{2, 3};
......@@ -6387,11 +6380,12 @@ TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image_padded_3x3_strided_uneven)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((test::NDArray<float, 4>(
EXPECT_TRUE(test::all_close(
test::NDArray<float, 4>(
{{{{0.0f / 1, 1.0f / 2}, {2.0f / 3, 6.0f / 6}, {2.0f / 1, 0.0f / 2}},
{{3.0f / 1, 7.0f / 2}, {8.0f / 3, 27.0f / 6}, {3.0f / 1, 11.0f / 2}}}})
.get_vector()),
read_vector<float>(result));
.get_vector(),
read_vector<float>(result)));
}
TEST(${BACKEND_NAME}, pad_interior_1d)
......