Commit 606f3f93 authored by Fenglei, committed by Robert Kimball

nvgpu cuda reduce with stable sum (#2076)

* add some helper function

* update with new helper function

* update reduce to nd with new helper function

* update float sum to stable sum

* fix bug

* update all reduce to stable sum for float

* fix bug and pass the sum stable test

* remove debug info

* style

* update with shape

* fix bug

* add host parameters to cuda_emitter

* clang format

* fix bugs

* add element::type support

* format

* add a cached value with datatype name

* add init_reduce_value

* unroll loop

* optimization

* remove the need for init_value

* add memset kernel

* add memcpy

* working version

* remove debug info

* add comments, clean up code.

* change in_idx to input_idx

* fix bug

* change args name for memset in emitter

* pass element::Type instead of string

* the op::reduce comes with an init value; add support

* resolve codacy-bot comment

* fix bug

* resolve codacy-bot comment

* remove unused comments, resolve comments

* cuda reduce for max, min, mul, reduce op init value, format

* use type::info

* use type info for numeric_limits

* remove code from gpu_host_parameters

* header

* remove outdated comments

* add helper to check if stable sum is needed

* add stable sum test for double

* remove extra line

* consolidate helper functions

* no need for the list now

* remove extra ;

* clang format

* style

* add skip test for cpu and intelGPU side

* add line between groups of headers

* add two simple stable sum tests for float and double

* skip test for intelGPU
parent 4b0445d1
...@@ -12,6 +12,7 @@ shape_of_scalar ...@@ -12,6 +12,7 @@ shape_of_scalar
shape_of_vector shape_of_vector
shape_of_matrix shape_of_matrix
shape_of_5d shape_of_5d
sum_stable_acc_double
quantize_clamp_int32 quantize_clamp_int32
# failing in CI build but passing on local machine # failing in CI build but passing on local machine
......
...@@ -69,8 +69,10 @@ std::ostream& operator<<(std::ostream& os, pooling_op_shape& shape) ...@@ -69,8 +69,10 @@ std::ostream& operator<<(std::ostream& os, pooling_op_shape& shape)
} }
runtime::gpu::CUDAEmitter::CUDAEmitter(runtime::gpu::GPUPrimitiveEmitter* emitter, runtime::gpu::CUDAEmitter::CUDAEmitter(runtime::gpu::GPUPrimitiveEmitter* emitter,
runtime::gpu::GPURuntimeContext* ctx) runtime::gpu::GPURuntimeContext* ctx,
: m_primitive_emitter(emitter) std::shared_ptr<GPUHostParameters> params)
: m_host_parameters(params)
, m_primitive_emitter(emitter)
{ {
m_ctx = ctx; m_ctx = ctx;
} }
...@@ -227,11 +229,7 @@ size_t runtime::gpu::CUDAEmitter::build_topk(const std::vector<element::Type>& d ...@@ -227,11 +229,7 @@ size_t runtime::gpu::CUDAEmitter::build_topk(const std::vector<element::Type>& d
<< " The axis along which topk is computed should be the last axis"; << " The axis along which topk is computed should be the last axis";
size_t num_cols = input_shape[rank - 1]; size_t num_cols = input_shape[rank - 1];
size_t num_rows = ((rank == 2) ? input_shape[0] : 1); size_t num_rows = ((rank == 2) ? input_shape[0] : 1);
std::vector<std::string> dtypes_string; std::vector<std::string> dtypes_string = get_string_vector(dtypes);
for (auto& dtype : dtypes)
{
dtypes_string.push_back(dtype.c_type_string());
}
/* The struct 'Entry' used in the kernel looks like this: /* The struct 'Entry' used in the kernel looks like this:
struct Entry struct Entry
...@@ -1404,6 +1402,68 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const std::vector<std ...@@ -1404,6 +1402,68 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const std::vector<std
return this->m_primitive_emitter->register_primitive(ew, hash); return this->m_primitive_emitter->register_primitive(ew, hash);
} }
size_t runtime::gpu::CUDAEmitter::build_memset(const std::string& dtype, uint32_t tensor_size)
{
// kernel_name is used to check if the cuda kernel has been previously compiled
std::stringstream kernel_name;
kernel_name << "memset_" << dtype;
// hash is used to check if the emitted primitive already exists
std::stringstream ss;
ss << kernel_name.str() << "_s_" << tensor_size;
auto hash = ss.str();
// if the primitive exists, we are done
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtype, "in").add_placeholder(dtype, "out").add("nthreads", tensor_size);
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_memset_op(writer, kernel_name.str(), dtype, args);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// TODO: currently we set it to 512, will add tuning method later
uint32_t block_size_x = 512;
int num_SMs;
CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, 0));
uint32_t aligned_grid_size_x =
fmin(num_SMs * 32, align_to_block_size(tensor_size, block_size_x));
// create the launch primitive
std::unique_ptr<gpu::primitive> memset(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
nullptr, // shared mem and stream
args_list,
nullptr)); // arguments
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(memset, hash);
}
size_t runtime::gpu::CUDAEmitter::build_cudnn_bn_inv_var(const std::vector<std::string>& dtypes, size_t runtime::gpu::CUDAEmitter::build_cudnn_bn_inv_var(const std::vector<std::string>& dtypes,
NVShape tensor_shape, NVShape tensor_shape,
const double& eps) const double& eps)
...@@ -1574,7 +1634,7 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>& ...@@ -1574,7 +1634,7 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>&
{ {
size_t rank = input_shape.size(); size_t rank = input_shape.size();
size_t reduce_rank = reduce_axis.size(); size_t reduce_rank = reduce_axis.size();
size_t out_rank = rank - reduce_rank; size_t non_reduce_rank = rank - reduce_rank;
// assumes NC{d1,...,dn} format // assumes NC{d1,...,dn} format
std::string kernel_name = "softmax_" + join(dtypes, "_"); std::string kernel_name = "softmax_" + join(dtypes, "_");
kernel_name += kernel_name +=
...@@ -1639,7 +1699,7 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>& ...@@ -1639,7 +1699,7 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>&
codegen::CodeWriter writer; codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer); CudaKernelBuilder::add_pod_typedefs(writer);
runtime::gpu::CudaKernelBuilder::get_softmax_op( runtime::gpu::CudaKernelBuilder::get_softmax_op(
writer, kernel_name, args, dtypes, out_rank, reduce_rank); writer, kernel_name, args, dtypes, non_reduce_rank, reduce_rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code()); compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
} }
...@@ -1666,23 +1726,30 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>& ...@@ -1666,23 +1726,30 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>&
return this->m_primitive_emitter->register_primitive(softmax, hash); return this->m_primitive_emitter->register_primitive(softmax, hash);
} }
size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::string>& dtypes, size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<element::Type>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape reduce_axis, NVShape reduce_axis,
const char* op, const char* op,
const char* kernel) const char* kernel)
{ {
size_t rank = input_shape.size(); std::vector<std::string> dtypes_str = get_string_vector(dtypes);
size_t reduce_rank = reduce_axis.size(); //if call from reduce, this is duplicated
size_t out_rank = rank - reduce_rank; NVShape simplified_reduce_axis;
NVShape simplified_input_shape;
// simplified_reduce_axis will not be empty, since gpu_emitter already checked whether the input size equals the output size
simplify_reduce_shape(input_shape, reduce_axis, simplified_input_shape, simplified_reduce_axis);
size_t rank = simplified_input_shape.size();
size_t reduce_rank = simplified_reduce_axis.size();
size_t non_reduce_rank = rank - reduce_rank;
// assumes NC{d1,...,dn} format // assumes NC{d1,...,dn} format
std::string kernel_name = "reduce_nd_" + join(dtypes, "_") + "_" + op; std::string kernel_name = "reduce_nd_" + join(dtypes_str, "_") + "_" + op;
kernel_name += kernel_name += "_ri_" + std::to_string(simplified_input_shape.size()) + "_rr_" +
"_ri_" + std::to_string(input_shape.size()) + "_rr_" + std::to_string(reduce_axis.size()); std::to_string(simplified_reduce_axis.size());
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_'); std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss; std::stringstream ss;
ss << kernel_name << "_s_" << join(input_shape, "_") << "_axis_" << join(reduce_axis, "_"); ss << kernel_name << "_s_" << join(simplified_input_shape, "_") << "_axis_"
<< join(simplified_reduce_axis, "_");
auto hash = ss.str(); auto hash = ss.str();
// check if the requested kernel is already an inserted primitive // check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash); size_t primitive_index = m_primitive_emitter->lookup(hash);
...@@ -1691,41 +1758,41 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri ...@@ -1691,41 +1758,41 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri
return primitive_index; return primitive_index;
} }
NVShape reduce_flag(rank, 0); NVShape non_reduce_shape;
for (auto a : reduce_axis)
{
reduce_flag[a] = 1;
}
NVShape output_shape;
NVShape non_reduce_strides; NVShape non_reduce_strides;
NVShape non_reduce_strides_in_input;
NVShape reduce_shape; NVShape reduce_shape;
NVShape reduce_strides; NVShape reduce_strides;
NVShape input_strides = row_major_strides(input_shape); NVShape reduce_strides_in_input;
for (int i = 0; i < rank; i++) get_reduce_strides(simplified_input_shape,
{ simplified_reduce_axis,
if (reduce_flag[i] != 0) non_reduce_shape,
{ non_reduce_strides,
reduce_shape.push_back(input_shape[i]); non_reduce_strides_in_input,
reduce_strides.push_back(input_strides[i]); reduce_shape,
} reduce_strides,
else reduce_strides_in_input);
{
non_reduce_strides.push_back(input_strides[i]); std::vector<int> non_reduce_strides_magic;
output_shape.push_back(input_shape[i]); std::vector<int> non_reduce_strides_shift;
}
} div_to_mul(non_reduce_strides, non_reduce_strides_magic, non_reduce_strides_shift);
NVShape output_strides = row_major_strides(output_shape);
uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape)); uint32_t reduce_count = static_cast<uint32_t>(shape_size(reduce_shape));
uint32_t nthreads = static_cast<uint32_t>(shape_size(non_reduce_shape));
// TODO: currently we set it to 64, will add tuning method later // TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64; uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x); uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
auto args = m_primitive_emitter->add_kernel_args(); auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in") args.add_placeholder(dtypes_str[0], "in0")
.add_placeholder(dtypes[1], "out") .add_placeholder(dtypes_str[1], "out")
.add("out_strides", output_strides)
.add("non_reduce_strides", non_reduce_strides) .add("non_reduce_strides", non_reduce_strides)
.add("non_reduce_strides_magic", non_reduce_strides_magic)
.add("non_reduce_strides_shift", non_reduce_strides_shift)
.add("non_reduce_strides_in_input", non_reduce_strides_in_input)
.add("reduce_shape", reduce_shape) .add("reduce_shape", reduce_shape)
.add("reduce_strides", reduce_strides) .add("reduce_strides_in_input", reduce_strides_in_input)
.add("reduce_count", reduce_count)
.add("nthreads", nthreads); .add("nthreads", nthreads);
// if the kernel has not been compiled, build it // if the kernel has not been compiled, build it
...@@ -1737,10 +1804,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri ...@@ -1737,10 +1804,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri
if (kernel) if (kernel)
{ {
CudaKernelBuilder::get_device_helper( CudaKernelBuilder::get_device_helper(
writer, op, kernel, {{dtypes[0], dtypes[0], dtypes[1]}}); writer, op, kernel, {{dtypes_str[0], dtypes_str[0], dtypes_str[1]}});
} }
runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op( runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op(
writer, kernel_name, args, dtypes, op, out_rank, reduce_rank); writer, kernel_name, args, dtypes_str, op, non_reduce_rank, reduce_rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code()); compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
} }
...@@ -1767,14 +1834,15 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri ...@@ -1767,14 +1834,15 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::stri
return this->m_primitive_emitter->register_primitive(reduce, hash); return this->m_primitive_emitter->register_primitive(reduce, hash);
} }
size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std::string>& dtypes, size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<element::Type>& dtypes,
const size_t data_bytes,
NVShape input_shape, NVShape input_shape,
const char* op, const char* op,
const char* kernel) const char* kernel)
{ {
std::vector<std::string> dtypes_str = get_string_vector(dtypes);
uint32_t data_bytes = dtypes[0].size();
// assumes NC{d1,...,dn} format // assumes NC{d1,...,dn} format
std::string kernel_name = "reduce_scalar_" + join(dtypes, "_") + "_" + op; std::string kernel_name = "reduce_scalar_" + join(dtypes_str, "_") + "_" + op;
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_'); std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss; std::stringstream ss;
...@@ -1790,17 +1858,15 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std:: ...@@ -1790,17 +1858,15 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std::
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape)); uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
uint32_t n = nthreads; uint32_t n = nthreads;
uint32_t block_size_x = 1; uint32_t block_size_x = 1;
while (n > 1) while ((block_size_x << 1) <= fmin(512, n))
{ {
block_size_x <<= 1; block_size_x <<= 1;
n >>= 1;
} }
block_size_x = fmin(512, block_size_x);
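The rewritten loop above picks the largest power of two that does not exceed min(512, n). A quick standalone check of that behaviour (a sketch, not part of this commit):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Largest power of two <= min(512, n), matching the new while-loop above.
static uint32_t pick_block_size(uint32_t n)
{
    uint32_t block_size_x = 1;
    while ((block_size_x << 1) <= std::min<uint32_t>(512, n))
    {
        block_size_x <<= 1;
    }
    return block_size_x;
}

int main()
{
    assert(pick_block_size(1) == 1);
    assert(pick_block_size(300) == 256);    // 512 would exceed n
    assert(pick_block_size(100000) == 512); // capped at 512 threads per block
    return 0;
}
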
uint32_t shared_data_bytes = block_size_x * static_cast<uint32_t>(data_bytes); uint32_t shared_data_bytes = block_size_x * static_cast<uint32_t>(data_bytes);
kernel_name += "_b_" + std::to_string(block_size_x); kernel_name += "_b_" + std::to_string(block_size_x);
auto args = m_primitive_emitter->add_kernel_args(); auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in") args.add_placeholder(dtypes_str[0], "in")
.add_placeholder(dtypes[1], "out") .add_placeholder(dtypes_str[1], "out")
.add("nthreads", nthreads); .add("nthreads", nthreads);
// if the kernel has not been compiled, build it // if the kernel has not been compiled, build it
...@@ -1812,10 +1878,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std:: ...@@ -1812,10 +1878,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std::
if (kernel) if (kernel)
{ {
CudaKernelBuilder::get_device_helper( CudaKernelBuilder::get_device_helper(
writer, op, kernel, {{dtypes[0], dtypes[0], dtypes[1]}}); writer, op, kernel, {{dtypes_str[0], dtypes_str[0], dtypes_str[1]}});
} }
runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op( runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op(
writer, kernel_name, args, dtypes, op, block_size_x); writer, kernel_name, args, dtypes_str, op, block_size_x);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code()); compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
} }
...@@ -1842,15 +1908,17 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std:: ...@@ -1842,15 +1908,17 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar(const std::vector<std::
return this->m_primitive_emitter->register_primitive(reduce, hash); return this->m_primitive_emitter->register_primitive(reduce, hash);
} }
size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<std::string>& dtypes, size_t
runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<element::Type>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape output_shape, NVShape output_shape,
uint32_t block_size_x, uint32_t block_size_x,
const char* op, const char* op,
const char* kernel) const char* kernel)
{ {
std::vector<std::string> dtypes_str = get_string_vector(dtypes);
// assumes NC{d1,...,dn} format // assumes NC{d1,...,dn} format
std::string kernel_name = "reduce_acc_" + join(dtypes, "_") + "_" + op; std::string kernel_name = "reduce_acc_" + join(dtypes_str, "_") + "_" + op;
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_'); std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss; std::stringstream ss;
...@@ -1865,8 +1933,8 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s ...@@ -1865,8 +1933,8 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape)); uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
auto args = m_primitive_emitter->add_kernel_args(); auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in") args.add_placeholder(dtypes_str[0], "in")
.add_placeholder(dtypes[1], "out") .add_placeholder(dtypes_str[1], "out")
.add("nthreads", nthreads); .add("nthreads", nthreads);
uint32_t aligned_grid_size_x = static_cast<uint32_t>(shape_size(output_shape)) / block_size_x; uint32_t aligned_grid_size_x = static_cast<uint32_t>(shape_size(output_shape)) / block_size_x;
...@@ -1880,10 +1948,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s ...@@ -1880,10 +1948,10 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s
if (kernel) if (kernel)
{ {
CudaKernelBuilder::get_device_helper( CudaKernelBuilder::get_device_helper(
writer, op, kernel, {{dtypes[0], dtypes[0], dtypes[1]}}); writer, op, kernel, {{dtypes_str[0], dtypes_str[0], dtypes_str[1]}});
} }
runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_acc_op( runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_acc_op(
writer, kernel_name, args, dtypes, op); writer, kernel_name, args, dtypes_str, op);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code()); compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
} }
...@@ -1908,27 +1976,37 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s ...@@ -1908,27 +1976,37 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_to_scalar_acc(const std::vector<s
return this->m_primitive_emitter->register_primitive(reduce_acc, hash); return this->m_primitive_emitter->register_primitive(reduce_acc, hash);
} }
size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& dtypes, size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<element::Type>& dtypes,
const size_t data_bytes,
const NVShape& input_shape, const NVShape& input_shape,
const NVShape& output_shape,
const NVShape& reduce_axis, const NVShape& reduce_axis,
const char* op, const char* op,
const char* kernel) const char* kernel,
const bool with_init_value)
{ {
size_t rank = input_shape.size(); NVShape simplified_reduce_axis;
size_t reduce_rank = reduce_axis.size(); NVShape simplified_input_shape;
size_t out_rank = rank - reduce_rank; // simplified_reduce_axis will not be empty, since we checked if input size is same as output size in gpu_emitter
simplify_reduce_shape(input_shape, reduce_axis, simplified_input_shape, simplified_reduce_axis);
size_t rank = simplified_input_shape.size();
size_t reduce_rank = simplified_reduce_axis.size();
size_t non_reduce_rank = rank - reduce_rank;
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
uint32_t data_bytes = dtypes[0].size();
std::vector<std::string> dtypes_str = get_string_vector(dtypes);
// assumes NC{d1,...,dn} format // assumes NC{d1,...,dn} format
std::string kernel_name = "reduce_" + join(dtypes, "_") + "_" + op; std::string kernel_name = "reduce_" + join(dtypes_str, "_") + "_" + op;
if (out_rank != 0) if (non_reduce_rank != 0)
{ {
kernel_name += "_ri_" + std::to_string(input_shape.size()) + "_rr_" + kernel_name += "_ri_" + std::to_string(simplified_input_shape.size()) + "_rr_" +
std::to_string(reduce_axis.size()); std::to_string(simplified_reduce_axis.size());
} }
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_'); std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss; std::stringstream ss;
ss << kernel_name << "_s_" << join(input_shape, "_") << "_axis_" << join(reduce_axis, "_"); ss << kernel_name << "_s_" << join(simplified_input_shape, "_") << "_axis_"
<< join(simplified_reduce_axis, "_");
auto hash = ss.str(); auto hash = ss.str();
// check if the requested kernel is already an inserted primitive // check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash); size_t primitive_index = m_primitive_emitter->lookup(hash);
...@@ -1941,10 +2019,57 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d ...@@ -1941,10 +2019,57 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d
CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, 0)); CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, 0));
uint32_t block_size_x_acc = 256; uint32_t block_size_x_acc = 256;
uint32_t nthreads_acc = num_SMs * block_size_x_acc; uint32_t nthreads_acc = num_SMs * block_size_x_acc;
//call reduce_to_nd
if (out_rank != 0) // if input size is 0, memset output to initial value
if (nthreads == 0)
{
size_t memset_idx =
build_memset(dtypes_str[0], static_cast<uint32_t>(shape_size(output_shape)));
if (with_init_value)
{ {
size_t reduce_idx = build_reduce_to_nd(dtypes, input_shape, reduce_axis, op, kernel); std::unique_ptr<gpu::primitive> memset(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
gpu::invoke_primitive(m_ctx,
memset_idx,
std::vector<void*>{inputs[1]}.data(),
std::vector<void*>{outputs[0]}.data());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(memset));
}
else
{
void* init_value = get_init_reduce_val(op, dtypes_str[0]);
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
// (lazy) allocation for kernel arguments
size_t idx_init_value = allocator.reserve_argspace(init_value, data_bytes);
std::unique_ptr<gpu::primitive> memset(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* init_value_buff =
runtime::gpu::invoke_memory_primitive(m_ctx, idx_init_value);
gpu::invoke_primitive(m_ctx,
memset_idx,
std::vector<void*>{init_value_buff}.data(),
std::vector<void*>{outputs[0]}.data());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(memset));
}
}
// if input size is same as output size, do a copy
else if (nthreads == static_cast<uint32_t>(shape_size(output_shape)))
{
size_t size = nthreads * data_bytes;
std::unique_ptr<gpu::primitive> memcopy(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
runtime::gpu::cuda_memcpyDtD(outputs[0], inputs[0], size);
}});
primitive_index = this->m_primitive_emitter->insert(std::move(memcopy));
}
// if output is not scalar, do reduce_to_nd
else if (non_reduce_rank != 0)
{
size_t reduce_idx =
build_reduce_to_nd(dtypes, simplified_input_shape, simplified_reduce_axis, op, kernel);
std::unique_ptr<gpu::primitive> reduce( std::unique_ptr<gpu::primitive> reduce(
new gpu::primitive{[=](void** inputs, void** outputs) mutable { new gpu::primitive{[=](void** inputs, void** outputs) mutable {
...@@ -1957,7 +2082,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d ...@@ -1957,7 +2082,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d
} }
else else
{ {
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
//if the data size is large, call reduce_to_scalar_acc first and then reduce_to_scalar. //if the data size is large, call reduce_to_scalar_acc first and then reduce_to_scalar.
//otherwise, call reduce to scalar directly. //otherwise, call reduce to scalar directly.
const uint32_t unroll_size = 8; const uint32_t unroll_size = 8;
...@@ -1965,9 +2089,8 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d ...@@ -1965,9 +2089,8 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d
{ {
NVShape acc_output_shape{nthreads_acc}; NVShape acc_output_shape{nthreads_acc};
size_t reduce_scalar_acc_idx = build_reduce_to_scalar_acc( size_t reduce_scalar_acc_idx = build_reduce_to_scalar_acc(
dtypes, input_shape, acc_output_shape, block_size_x_acc, op, kernel); dtypes, simplified_input_shape, acc_output_shape, block_size_x_acc, op, kernel);
size_t reduce_scalar_idx = size_t reduce_scalar_idx = build_reduce_to_scalar(dtypes, acc_output_shape, op, kernel);
build_reduce_to_scalar(dtypes, data_bytes, acc_output_shape, op, kernel);
// get an allocator for transient per kernel gpu memory // get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator(); GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_workspace = allocator.reserve_workspace(nthreads_acc * data_bytes); size_t idx_workspace = allocator.reserve_workspace(nthreads_acc * data_bytes);
...@@ -1988,7 +2111,7 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d ...@@ -1988,7 +2111,7 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d
else else
{ {
size_t reduce_scalar_idx = size_t reduce_scalar_idx =
build_reduce_to_scalar(dtypes, data_bytes, input_shape, op, kernel); build_reduce_to_scalar(dtypes, simplified_input_shape, op, kernel);
std::unique_ptr<gpu::primitive> reduce_scalar( std::unique_ptr<gpu::primitive> reduce_scalar(
new gpu::primitive{[=](void** inputs, void** outputs) mutable { new gpu::primitive{[=](void** inputs, void** outputs) mutable {
gpu::invoke_primitive(m_ctx, gpu::invoke_primitive(m_ctx,
...@@ -2805,3 +2928,178 @@ void runtime::gpu::CUDAEmitter::debug_sync() ...@@ -2805,3 +2928,178 @@ void runtime::gpu::CUDAEmitter::debug_sync()
#endif #endif
return; return;
} }
void runtime::gpu::CUDAEmitter::simplify_reduce_shape(NVShape in,
NVShape reduce_axis,
NVShape& simplified_shape,
NVShape& simplified_reduce_axis)
{
int32_t rank = in.size();
// Sort the axes in case they are not sorted.
std::sort(reduce_axis.begin(), reduce_axis.end());
// Clear simplified_shape and axis
simplified_shape.clear();
simplified_reduce_axis.clear();
// Combine axes if there are two or more adjacent reduce_axis entries;
// combine axes if there are two or more adjacent non_reduce_axis entries;
// then update the combined shape and axes.
NVShape combined_reduce_axis;
NVShape adj_map(rank, 0);
size_t combined_axis_count = 0;
if (reduce_axis.empty())
{
simplified_shape = in;
simplified_reduce_axis = reduce_axis;
return;
}
for (int32_t i = 0; i < static_cast<int32_t>(reduce_axis[0]) - 1; i++)
{
adj_map[i] = 1;
combined_axis_count++;
}
for (int32_t i = 0; i < reduce_axis.size() - 1; i++)
{
if (static_cast<int32_t>(reduce_axis[i + 1]) - static_cast<int32_t>(reduce_axis[i]) == 1)
{
adj_map[reduce_axis[i]] = 1;
combined_axis_count++;
}
else
{
combined_reduce_axis.push_back(reduce_axis[i] - combined_axis_count);
for (int32_t j = static_cast<int32_t>(reduce_axis[i]) + 1;
j < static_cast<int32_t>(reduce_axis[i + 1]) - 1;
j++)
{
adj_map[j] = 1;
combined_axis_count++;
}
}
}
combined_reduce_axis.push_back(reduce_axis.back() - combined_axis_count);
for (int32_t i = static_cast<int32_t>(reduce_axis.back()) + 1; i < rank - 1; i++)
{
adj_map[i] = 1;
}
NVShape combined_shape;
size_t shape_i = 1;
for (int i = 0; i < rank; i++)
{
if (adj_map[i] == 1)
{
shape_i *= in[i];
}
else
{
combined_shape.push_back(shape_i * in[i]);
shape_i = 1;
}
}
// eliminate dimensions whose size is 1, then update the shape and reduce axes
size_t reduce_idx = 0;
size_t eliminated_axis_count = 0;
for (int32_t i = 0; i < combined_shape.size(); i++)
{
if (combined_shape[i] == 1)
{
eliminated_axis_count++;
}
else
{
simplified_shape.push_back(combined_shape[i]);
if (i == combined_reduce_axis[reduce_idx])
{
simplified_reduce_axis.push_back(i - eliminated_axis_count);
}
}
if (reduce_idx < combined_reduce_axis.size() - 1)
{
reduce_idx = (i == combined_reduce_axis[reduce_idx]) ? reduce_idx + 1 : reduce_idx;
}
}
}
void runtime::gpu::CUDAEmitter::get_reduce_strides(NVShape input_shape,
NVShape reduce_axis,
NVShape& non_reduce_shape,
NVShape& non_reduce_strides,
NVShape& non_reduce_strides_in_input,
NVShape& reduce_shape,
NVShape& reduce_strides,
NVShape& reduce_strides_in_input)
{
size_t rank = input_shape.size();
NVShape reduce_flag(rank, 0);
for (auto a : reduce_axis)
{
reduce_flag[a] = 1;
}
NVShape input_strides = row_major_strides(input_shape);
for (int i = 0; i < rank; i++)
{
if (reduce_flag[i] != 0)
{
reduce_shape.push_back(input_shape[i]);
reduce_strides_in_input.push_back(input_strides[i]);
}
else
{
non_reduce_shape.push_back(input_shape[i]);
non_reduce_strides_in_input.push_back(input_strides[i]);
}
}
reduce_strides = row_major_strides(reduce_shape);
non_reduce_strides = row_major_strides(non_reduce_shape);
}
void runtime::gpu::CUDAEmitter::div_to_mul(const NVShape& shape,
std::vector<int>& magic,
std::vector<int>& shift)
{
for (int i = 0; i < shape.size(); i++)
{
int _magic;
int _shift;
std::tie(_magic, _shift) = idiv_magic_u64(shape[i]);
magic.push_back(_magic);
shift.push_back(_shift);
}
}
void* runtime::gpu::CUDAEmitter::get_init_reduce_val(std::string reduce_op, std::string data_type)
{
if (reduce_op == "fmaxf" || reduce_op == "max")
{
return TypeInfo::Get(data_type)->lowest_ptr();
}
else if (reduce_op == "fminf" || reduce_op == "min")
{
return TypeInfo::Get(data_type)->max_ptr();
}
else if (reduce_op == "mul" || reduce_op == "logical_and")
{
return m_host_parameters->val_by_datatype(data_type, static_cast<int64_t>(1));
}
else if (reduce_op == "add" || reduce_op == "logical_or")
{
return m_host_parameters->val_by_datatype(data_type, static_cast<int64_t>(0));
}
else
{
// not defined.
throw std::runtime_error(data_type + " is currently not supported with an init value.");
}
}
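For reference, the identities this helper hands back, written out for float (a sketch using std::numeric_limits; the real code goes through TypeInfo and GPUHostParameters and covers every supported element type):

#include <limits>

// Identity (init) values per reduce op, shown for float only.
const float init_max = std::numeric_limits<float>::lowest(); // max(x, lowest) == x
const float init_min = std::numeric_limits<float>::max();    // min(x, max)    == x
const float init_mul = 1.0f;                                 // x * 1 == x  (also logical_and)
const float init_add = 0.0f;                                 // x + 0 == x  (also logical_or)
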
std::vector<std::string>
runtime::gpu::CUDAEmitter::get_string_vector(const std::vector<element::Type>& dtypes)
{
std::vector<std::string> str;
for (auto const& a : dtypes)
{
str.push_back(a.c_type_string());
}
return str;
}
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <array> #include <array>
#include "ngraph/codegen/code_writer.hpp" #include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_host_parameters.hpp"
#include "ngraph/runtime/gpu/nvdiff.hpp" #include "ngraph/runtime/gpu/nvdiff.hpp"
#include "ngraph/runtime/gpu/nvshape.hpp" #include "ngraph/runtime/gpu/nvshape.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
...@@ -49,6 +50,8 @@ namespace ngraph ...@@ -49,6 +50,8 @@ namespace ngraph
size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op); size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op);
public: public:
size_t build_memset(const std::string& dtype, uint32_t tensor_size);
size_t build_topk(const std::vector<element::Type>& dtypes, size_t build_topk(const std::vector<element::Type>& dtypes,
const NVShape& input_shape, const NVShape& input_shape,
const size_t topk_axis, const size_t topk_axis,
...@@ -125,17 +128,19 @@ namespace ngraph ...@@ -125,17 +128,19 @@ namespace ngraph
const double& eps); const double& eps);
template <typename T> template <typename T>
size_t build_reduce(const std::vector<std::string>& dtypes, size_t build_reduce(const std::vector<element::Type>& dtypes,
const size_t data_bytes, NVShape input_shape,
const NVShape& input_shape, NVShape output_shape,
const NVShape& reduce_axis) NVShape reduce_axis,
const bool with_init_value = false)
{ {
return build_reduce(dtypes, return build_reduce(dtypes,
data_bytes,
input_shape, input_shape,
output_shape,
reduce_axis, reduce_axis,
CudaOpMap<T>::op, CudaOpMap<T>::op,
CudaOpMap<T>::math_kernel); CudaOpMap<T>::math_kernel,
with_init_value);
} }
template <typename ELEMENTWISE_OP_TYPE, typename REDUCE_OP_TYPE = ngraph::op::Nop> template <typename ELEMENTWISE_OP_TYPE, typename REDUCE_OP_TYPE = ngraph::op::Nop>
...@@ -193,7 +198,9 @@ namespace ngraph ...@@ -193,7 +198,9 @@ namespace ngraph
void sync(); void sync();
private: private:
CUDAEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx); CUDAEmitter(GPUPrimitiveEmitter* emitter,
GPURuntimeContext* ctx,
std::shared_ptr<GPUHostParameters> params);
uint32_t align_to_block_size(uint32_t threads, uint32_t block_size); uint32_t align_to_block_size(uint32_t threads, uint32_t block_size);
void print_tensor_from_gpu(codegen::CodeWriter& writer, void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name, const std::string& tensor_name,
...@@ -211,32 +218,71 @@ namespace ngraph ...@@ -211,32 +218,71 @@ namespace ngraph
const char* kernel, const char* kernel,
const char* reduce_op, const char* reduce_op,
bool save_elementwise); bool save_elementwise);
size_t build_reduce(const std::vector<std::string>& dtypes,
const size_t data_bytes, size_t build_reduce(const std::vector<element::Type>& dtypes,
const NVShape& input_shape, const NVShape& input_shape,
const NVShape& output_shape,
const NVShape& reduce_axis, const NVShape& reduce_axis,
const char* op, const char* op,
const char* kernel); const char* kernel,
size_t build_reduce_to_nd(const std::vector<std::string>& dtypes, const bool with_init_value);
size_t build_reduce_to_nd(const std::vector<element::Type>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape reduce_axis, NVShape reduce_axis,
const char* op, const char* op,
const char* kernel); const char* kernel);
size_t build_reduce_to_scalar(const std::vector<std::string>& dtypes, size_t build_reduce_to_scalar(const std::vector<element::Type>& dtypes,
const size_t data_bytes,
NVShape input_shape, NVShape input_shape,
const char* op, const char* op,
const char* kernel); const char* kernel);
/// \brief This is the preprocess for reduce to scalar if the data size is larger than a certain number.
//This is the preprocess for reduce to scalar if the data size is large than a number. /// The number can be tuned based on hardware.
//The number can be tuned based on hardware. /// This cuda kernel will accumulate the reduction into a certain number of bins, depending on hardware.
//This cuda kernel will accumulate reduction to a certain number of bins depends on hardware. size_t build_reduce_to_scalar_acc(const std::vector<element::Type>& dtypes,
size_t build_reduce_to_scalar_acc(const std::vector<std::string>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape output_shape, NVShape output_shape,
uint32_t block_size_x, uint32_t block_size_x,
const char* op, const char* op,
const char* kernel); const char* kernel);
/// \brief Simplify the reduce shape and reduce axes: remove dimensions of size 1 and
/// combine two or more adjacent reduce/non-reduce axes.
/// The simplified reduce shape and reduce axes make the index calculation in the cuda kernel simpler.
/// Examples:
/// {1 1 2 2} with reduce axis {3} simplifies to {2 2} with reduce_axis {1};
/// {2 3 4} with reduce axis {0 1} simplifies to {6 4} with reduce_axis {0};
/// {2 3 4} with reduce axis {0} simplifies to {2 12} with reduce_axis {0};
void simplify_reduce_shape(NVShape in,
NVShape reduce_axis,
NVShape& simplified_shape,
NVShape& simplified_reduce_axis);
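A standalone sketch of the same axis-merging idea, using std::vector<size_t> instead of NVShape and assuming a sorted, non-empty reduce_axis (ngraph's implementation above handles more cases); it reproduces the three documented examples:

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

// Drop size-1 dimensions, then fuse runs of adjacent dimensions that are all
// reduced or all non-reduced (a simplified re-statement, not ngraph's code).
static void simplify_sketch(const Shape& shape, const Shape& reduce_axis,
                            Shape& out_shape, Shape& out_reduce_axis)
{
    std::vector<bool> is_reduce(shape.size(), false);
    for (size_t a : reduce_axis)
    {
        is_reduce[a] = true;
    }
    out_shape.clear();
    out_reduce_axis.clear();
    std::vector<bool> group_is_reduce;
    for (size_t i = 0; i < shape.size(); i++)
    {
        if (shape[i] == 1)
        {
            continue; // size-1 dims contribute nothing to indexing
        }
        if (!out_shape.empty() && group_is_reduce.back() == is_reduce[i])
        {
            out_shape.back() *= shape[i]; // fuse with the previous group
        }
        else
        {
            out_shape.push_back(shape[i]);
            group_is_reduce.push_back(is_reduce[i]);
        }
    }
    for (size_t i = 0; i < out_shape.size(); i++)
    {
        if (group_is_reduce[i])
        {
            out_reduce_axis.push_back(i);
        }
    }
}

int main()
{
    Shape s, r;
    simplify_sketch({1, 1, 2, 2}, {3}, s, r);
    assert((s == Shape{2, 2} && r == Shape{1}));
    simplify_sketch({2, 3, 4}, {0, 1}, s, r);
    assert((s == Shape{6, 4} && r == Shape{0}));
    simplify_sketch({2, 3, 4}, {0}, s, r);
    assert((s == Shape{2, 12} && r == Shape{0}));
    return 0;
}
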
/// \brief Separate input_shape into reduce_shape and non_reduce_shape, and calculate strides for them
/// as well as their strides in the input. This helps the cuda kernel calculate input and output indices.
/// Example:
/// input_shape {2 3 4 5} with reduce_axis {0 2}:
/// input_strides: {60, 20, 5, 1}
/// reduce_shape {2 4}, reduce_strides {4 1}, reduce_strides_in_input {60 5}
/// non_reduce_shape {3 5}, non_reduce_strides {5 1}, non_reduce_strides_in_input {20 1}
void get_reduce_strides(NVShape input_shape,
NVShape reduce_axis,
NVShape& non_reduce_shape,
NVShape& non_reduce_strides,
NVShape& non_reduce_strides_in_input,
NVShape& reduce_shape,
NVShape& reduce_strides,
NVShape& reduce_strides_in_input);
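A small standalone check of the documented example, using plain std::vector instead of NVShape and re-implementing row_major_strides just for the sketch:

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

// Row-major strides: stride[i] is the product of the dimensions to the right of i.
static Shape strides_of(const Shape& shape)
{
    Shape strides(shape.size(), 1);
    for (size_t i = shape.size(); i-- > 1;)
    {
        strides[i - 1] = strides[i] * shape[i];
    }
    return strides;
}

int main()
{
    const Shape input_shape{2, 3, 4, 5};
    const Shape reduce_axis{0, 2};
    const Shape input_strides = strides_of(input_shape);
    assert((input_strides == Shape{60, 20, 5, 1}));

    Shape reduce_shape, reduce_strides_in_input;
    Shape non_reduce_shape, non_reduce_strides_in_input;
    std::vector<bool> is_reduce(input_shape.size(), false);
    for (size_t a : reduce_axis)
    {
        is_reduce[a] = true;
    }
    for (size_t i = 0; i < input_shape.size(); i++)
    {
        (is_reduce[i] ? reduce_shape : non_reduce_shape).push_back(input_shape[i]);
        (is_reduce[i] ? reduce_strides_in_input : non_reduce_strides_in_input)
            .push_back(input_strides[i]);
    }
    // The values documented in the comment above.
    assert((reduce_shape == Shape{2, 4} && strides_of(reduce_shape) == Shape{4, 1}));
    assert((reduce_strides_in_input == Shape{60, 5}));
    assert((non_reduce_shape == Shape{3, 5} && strides_of(non_reduce_shape) == Shape{5, 1}));
    assert((non_reduce_strides_in_input == Shape{20, 1}));
    return 0;
}
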
/// \brief Calculate the magic and shift constants for each element of a shape vector (the divisors),
/// so division can be replaced by multiplication and shift in the cuda kernel.
void div_to_mul(const NVShape& shape,
std::vector<int>& magic,
std::vector<int>& shift);
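A sketch of the trick div_to_mul prepares for: replace an integer division by a multiply and a shift. This is the simple round-up variant, not ngraph's idiv_magic_u64, and it assumes 32-bit dividends, modest divisors, and a compiler with unsigned __int128 (GCC/Clang):

#include <cassert>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// For divisor d: shift = ceil(log2(d)), magic = ceil(2^(32+shift) / d).
// Then n / d == (n * magic) >> (32 + shift) for all 32-bit n.
// (Sketch only: assumes d is small enough that 32 + shift < 64.)
static std::pair<uint64_t, int> magic_u32(uint32_t d)
{
    int shift = 0;
    while ((1ull << shift) < d)
    {
        shift++;
    }
    uint64_t two_k = 1ull << (32 + shift);
    uint64_t magic = (two_k + d - 1) / d;
    return {magic, shift};
}

int main()
{
    // Strides of the {2, 3, 4, 5} example: each becomes a (magic, shift) pair.
    for (uint32_t d : std::vector<uint32_t>{60, 20, 5, 1})
    {
        uint64_t magic;
        int shift;
        std::tie(magic, shift) = magic_u32(d);
        for (uint32_t n = 0; n < 1000000; n++)
        {
            uint32_t q = static_cast<uint32_t>(
                (static_cast<unsigned __int128>(n) * magic) >> (32 + shift));
            assert(q == n / d);
        }
    }
    return 0;
}
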
/// \brief Get initial value for reduce op
void* get_init_reduce_val(std::string reduce_op, std::string data_type);
/// \brief Get a vector<string> of data type names from a vector<element::Type>
std::vector<std::string>
get_string_vector(const std::vector<element::Type>& dtypes);
std::shared_ptr<GPUHostParameters> m_host_parameters;
GPUPrimitiveEmitter* m_primitive_emitter; GPUPrimitiveEmitter* m_primitive_emitter;
GPURuntimeContext* m_ctx; GPURuntimeContext* m_ctx;
}; };
......
...@@ -57,6 +57,28 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr ...@@ -57,6 +57,28 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
return; return;
} }
void runtime::gpu::CudaKernelBuilder::get_memset_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
runtime::gpu::GPUKernelArgs& args)
{
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "uint32_t step = gridDim.x * blockDim.x; \n";
writer << "for ( ;tid < nthreads; tid += step)\n";
writer.block_begin();
{
writer << "out[tid] = in[0];\n";
}
writer.block_end();
}
writer.block_end();
return;
}
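Restated on the host, the kernel emitted above is a grid-stride fill of the output with the single value read from in[0] (a sketch, not part of this commit):

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    float in[1] = {42.0f};             // the "fill value" tensor
    std::vector<float> out(10, 0.0f);
    uint32_t nthreads = static_cast<uint32_t>(out.size());
    for (uint32_t tid = 0; tid < nthreads; tid++) // kernel: tid strides over the grid
    {
        out[tid] = in[0];
    }
    for (float v : out)
    {
        assert(v == 42.0f);
    }
    return 0;
}
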
void runtime::gpu::CudaKernelBuilder::get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer, void runtime::gpu::CudaKernelBuilder::get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
runtime::gpu::GPUKernelArgs& args) runtime::gpu::GPUKernelArgs& args)
...@@ -483,9 +505,24 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op( ...@@ -483,9 +505,24 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op(
runtime::gpu::GPUKernelArgs& args, runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types, const std::vector<std::string>& data_types,
const std::string& reduce_op, const std::string& reduce_op,
size_t out_rank, size_t non_reduce_rank,
size_t reduce_rank) size_t reduce_rank)
{ {
bool stable_sum = stable_sum_check_helper(reduce_op, data_types[1]);
auto stable_sum_lambda = [&]() {
writer << "input_i = in0[input_idx];\n";
if (stable_sum)
{
writer << "y = input_i - c;\n";
writer << "t = r + y;\n";
writer << "c = (t - r) - y;\n";
writer << "r = t;\n";
}
else
{
writer << "r = " << reduce_op << "(r , input_i);\n";
}
};
writer << runtime::gpu::nvrtc::helpers(); writer << runtime::gpu::nvrtc::helpers();
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature(); writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin(); writer.block_begin();
...@@ -494,68 +531,74 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op( ...@@ -494,68 +531,74 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op(
writer << "if (tid < nthreads)\n"; writer << "if (tid < nthreads)\n";
writer.block_begin(); writer.block_begin();
{ {
if (out_rank > 0) collective_coordinate_transform_helper(writer,
{ "tid",
writer << "uint32_t dim_idx_generator = tid;\n"; "non_reduce_strides",
} "non_reduce_strides_magic",
writer << "uint32_t in_idx = 0;\n"; "non_reduce_strides_shift",
"non_reduce_strides_in_input",
// loop through all reduction axis "non_reduce_coordinate",
for (int64_t i = 0; i < static_cast<int64_t>(out_rank); i++) non_reduce_rank,
true,
"non_reduce_input_index");
writer << "uint32_t input_idx = non_reduce_input_index;\n";
writer << "uint32_t step = reduce_strides_in_input" << reduce_rank - 1 << ";\n";
writer << data_types[1] << " r = in0[non_reduce_input_index];\n";
if (stable_sum)
{ {
writer << "in_idx += (dim_idx_generator / out_strides" << i writer << data_types[1] << " c = 0;\n";
<< ") * non_reduce_strides" << i << ";\n"; writer << data_types[1] << " y;\n";
writer << "dim_idx_generator %= out_strides" << i << ";\n"; writer << data_types[1] << " t;\n";
} }
writer << "uint32_t init_in_idx = in_idx;\n"; writer << data_types[1] << " input_i;\n";
writer << data_types[1] << " r = in[init_in_idx];\n"; for (uint32_t i = 0; i < reduce_rank - 1; i++)
int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
for (int64_t j = 0; j < last_r_idx; j++)
{ {
writer << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape" << j << "; idx" writer << "for (uint32_t reduce_coordinate_" << i << " = 0; reduce_coordinate_" << i
<< j << "++)\n"; << " < reduce_shape" << i << "; reduce_coordinate_" << i << "++)\n";
writer.block_begin(); writer.block_begin();
} }
{ {
writer << "uint32_t reduce_idx = in_idx;\n"; uint32_t loop_unroll = 8;
for (int64_t j = 0; j < last_r_idx; j++) uint32_t i = reduce_rank - 1;
{ writer << "if (input_idx != non_reduce_input_index)\n";
writer << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
}
writer << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
writer << "if(reduce_idx != init_in_idx)\n";
writer.block_begin(); writer.block_begin();
{ {
writer << "r = " << reduce_op << "(r , in[reduce_idx]);\n"; stable_sum_lambda();
} }
writer.block_end(); writer.block_end();
writer << "reduce_idx += step;\n"; writer << "input_idx += step;\n";
writer << "int idx" << last_r_idx << " = 1;\n";
// unroll last reduction axis writer << "uint32_t reduce_coordinate_" << i << " = 1;\n";
uint32_t unroll_num = 8; writer << "for (; reduce_coordinate_" << i << " + " << loop_unroll - 1
writer << "for(; idx" << last_r_idx << " + " << unroll_num << " - 1 < reduce_shape" << " < reduce_shape" << i << "; reduce_coordinate_" << i
<< last_r_idx << "; idx" << last_r_idx << " += " << unroll_num << ")\n"; << " += " << loop_unroll << ")\n";
writer.block_begin(); writer.block_begin();
{ {
for (int k = 0; k < unroll_num; k++) for (uint32_t j = 0; j < 8; j++)
{ {
writer << "r = " << reduce_op << "(r , in[reduce_idx]);\n"; stable_sum_lambda();
writer << "reduce_idx += step;\n"; writer << "input_idx += step;\n";
} }
} }
writer.block_end(); writer.block_end();
writer << "for(; idx" << last_r_idx << " < reduce_shape" << last_r_idx << "; idx" writer << "for (; reduce_coordinate_" << i << " < reduce_shape" << i
<< last_r_idx << "++)\n"; << "; reduce_coordinate_" << i << "++)\n";
writer.block_begin(); writer.block_begin();
{ {
writer << "r = " << reduce_op << "(r , in[reduce_idx]);\n"; stable_sum_lambda();
writer << "reduce_idx += step;\n"; writer << "input_idx += step;\n";
} }
writer.block_end(); writer.block_end();
writer << "input_idx -= "
<< "step * reduce_shape" << i << ";\n";
} }
for (int64_t j = 0; j < last_r_idx; j++) for (int32_t i = static_cast<int32_t>(reduce_rank - 2); i >= 0; i--)
{ {
writer << "input_idx += "
<< "reduce_strides_in_input" << i << ";\n";
writer.block_end(); writer.block_end();
writer << "input_idx -= "
<< "reduce_strides_in_input" << i << " * reduce_shape" << i << ";\n";
} }
writer << "out[tid] = r;\n"; writer << "out[tid] = r;\n";
} }
...@@ -573,6 +616,21 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op( ...@@ -573,6 +616,21 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op(
const std::string& reduce_op, const std::string& reduce_op,
uint32_t block_size_x) uint32_t block_size_x)
{ {
bool stable_sum = stable_sum_check_helper(reduce_op, data_types[1]);
auto stable_sum_lambda = [&]() {
writer << "input_i = in[input_idx];\n";
if (stable_sum)
{
writer << "y = input_i - c;\n";
writer << "t = r + y;\n";
writer << "c = (t - r) - y;\n";
writer << "r = t;\n";
}
else
{
writer << "r = " << reduce_op << "(r , input_i);\n";
}
};
writer << runtime::gpu::nvrtc::helpers(); writer << runtime::gpu::nvrtc::helpers();
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature(); writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin(); writer.block_begin();
...@@ -581,30 +639,37 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op( ...@@ -581,30 +639,37 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op(
writer << "uint32_t tid = threadIdx.x; \n"; writer << "uint32_t tid = threadIdx.x; \n";
writer << "uint32_t step = blockDim.x; \n"; writer << "uint32_t step = blockDim.x; \n";
writer << "sdata[tid] = 0;\n"; writer << "sdata[tid] = 0;\n";
writer << "uint32_t in_idx = tid;\n"; writer << "uint32_t input_idx = tid;\n";
writer << data_types[1] << " r = 0;\n"; writer << data_types[1] << " r = 0;\n";
writer << "if(in_idx < nthreads)\n"; writer << data_types[1] << " input_i;\n";
writer << "if(input_idx < nthreads)\n";
writer.block_begin(); writer.block_begin();
writer << "r = in[in_idx];\n"; writer << "r = in[input_idx];\n";
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
writer.block_end(); writer.block_end();
//accumulate reduction to blockDim.x threads //accumulate reduction to blockDim.x threads
if (stable_sum)
{
writer << data_types[1] << " c = 0;\n";
writer << data_types[1] << " y;\n";
writer << data_types[1] << " t;\n";
}
uint32_t unroll_num = 8; uint32_t unroll_num = 8;
writer << "while(in_idx + (step * " << unroll_num - 1 << ") < nthreads)\n"; writer << "while(input_idx + (step * " << unroll_num - 1 << ") < nthreads)\n";
writer.block_begin(); writer.block_begin();
{ {
for (int i = 0; i < unroll_num; i++) for (int i = 0; i < unroll_num; i++)
{ {
writer << "r = " << reduce_op << "(r , in[in_idx]);\n"; stable_sum_lambda();
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
} }
} }
writer.block_end(); writer.block_end();
writer << "while(in_idx < nthreads)\n"; writer << "while(input_idx < nthreads)\n";
writer.block_begin(); writer.block_begin();
{ {
writer << "r = " << reduce_op << "(r , in[in_idx]);\n"; stable_sum_lambda();
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
} }
writer.block_end(); writer.block_end();
...@@ -667,36 +732,58 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_acc_op( ...@@ -667,36 +732,58 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_acc_op(
const std::vector<std::string>& data_types, const std::vector<std::string>& data_types,
const std::string& reduce_op) const std::string& reduce_op)
{ {
bool stable_sum = stable_sum_check_helper(reduce_op, data_types[1]);
auto stable_sum_lambda = [&]() {
writer << "input_i = in[input_idx];\n";
if (stable_sum)
{
writer << "y = input_i - c;\n";
writer << "t = r + y;\n";
writer << "c = (t - r) - y;\n";
writer << "r = t;\n";
}
else
{
writer << "r = " << reduce_op << "(r , input_i);\n";
}
};
writer << runtime::gpu::nvrtc::helpers(); writer << runtime::gpu::nvrtc::helpers();
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature(); writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin(); writer.block_begin();
{ {
writer << "uint32_t tid = blockDim.x*blockIdx.x + threadIdx.x;\n"; writer << "uint32_t tid = blockDim.x*blockIdx.x + threadIdx.x;\n";
writer << "uint32_t step = gridDim.x * blockDim.x; \n"; writer << "uint32_t step = gridDim.x * blockDim.x; \n";
writer << "uint32_t in_idx = tid;\n"; writer << "uint32_t input_idx = tid;\n";
writer << data_types[1] << " r = 0;\n"; writer << data_types[1] << " r = 0;\n";
writer << "if(in_idx < nthreads)\n"; writer << data_types[1] << " input_i;\n";
writer << "if(input_idx < nthreads)\n";
writer.block_begin(); writer.block_begin();
writer << "r = in[in_idx];\n"; writer << "r = in[input_idx];\n";
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
writer.block_end(); writer.block_end();
//accumulate reduction to step threads //accumulate reduction to step threads
if (stable_sum)
{
writer << data_types[1] << " c = 0;\n";
writer << data_types[1] << " y;\n";
writer << data_types[1] << " t;\n";
}
uint32_t unroll_num = 8; uint32_t unroll_num = 8;
writer << "while(in_idx + (step * " << unroll_num - 1 << ") < nthreads)\n"; writer << "while(input_idx + (step * " << unroll_num - 1 << ") < nthreads)\n";
writer.block_begin(); writer.block_begin();
{ {
for (int i = 0; i < unroll_num; i++) for (int i = 0; i < unroll_num; i++)
{ {
writer << "r = " << reduce_op << "(r , in[in_idx]);\n"; stable_sum_lambda();
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
} }
} }
writer.block_end(); writer.block_end();
writer << "while(in_idx < nthreads)\n"; writer << "while(input_idx < nthreads)\n";
writer.block_begin(); writer.block_begin();
{ {
writer << "r = " << reduce_op << "(r , in[in_idx]);\n"; stable_sum_lambda();
writer << "in_idx += step;\n"; writer << "input_idx += step;\n";
} }
writer.block_end(); writer.block_end();
writer << "out[tid] = r;\n"; writer << "out[tid] = r;\n";
...@@ -1856,19 +1943,20 @@ void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::C ...@@ -1856,19 +1943,20 @@ void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::C
// product = product % stride[0] // product = product % stride[0]
// d1 = product/stride[1] // d1 = product/stride[1]
// ... // ...
writer << "int coordinate_product = " << i_coord_product << ";\n"; writer << "int " << o_coordinates << "product = " << i_coord_product << ";\n";
for (size_t i = 0; i < rank; i++) for (size_t i = 0; i < rank; i++)
{ {
if (i != 0) if (i != 0)
{ {
writer << "coordinate_product -= (" << o_coordinates << i - 1 << " * " << i_strides writer << o_coordinates << "product -= (" << o_coordinates << i - 1 << " * "
<< brace_open << i - 1 << brace_close << ");\n"; << i_strides << brace_open << i - 1 << brace_close << ");\n";
} }
writer << "int " << o_coordinates << i << " = division_by_invariant_multiplication(" writer << "int " << o_coordinates << i << " = division_by_invariant_multiplication("
<< "coordinate_product, " << i_stride_magic << brace_open << i << brace_close << ", " << o_coordinates << "product, " << i_stride_magic << brace_open << i << brace_close
<< i_stride_shift << brace_open << i << brace_close << ");\n"; << ", " << i_stride_shift << brace_open << i << brace_close << ");\n";
} }
} }
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper( std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
std::string i_thread_index, std::string i_thread_index,
...@@ -1878,7 +1966,8 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel ...@@ -1878,7 +1966,8 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel
std::string i_reduced_strides, std::string i_reduced_strides,
std::string o_coordinates, std::string o_coordinates,
size_t rank, size_t rank,
bool register_arguments) bool register_arguments,
std::string reduced_idx)
{ {
coordinate_transform_to_multi_d(writer, coordinate_transform_to_multi_d(writer,
i_strides, i_strides,
...@@ -1893,14 +1982,12 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel ...@@ -1893,14 +1982,12 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel
std::string brace_close = (register_arguments) ? "" : "]"; std::string brace_close = (register_arguments) ? "" : "]";
// index into reduced tensor from coordinates of non-reduced tensor // index into reduced tensor from coordinates of non-reduced tensor
std::string reduced_idx = "reduced_idx"; writer << "uint32_t " << reduced_idx << " = 0;\n";
writer << "int " << reduced_idx << " = 0;\n";
for (size_t i = 0; i < rank; i++) for (size_t i = 0; i < rank; i++)
{ {
writer << "reduced_idx += " << o_coordinates << i << " * " << i_reduced_strides writer << reduced_idx << " += " << o_coordinates << i << " * " << i_reduced_strides
<< brace_open << i << brace_close << ";\n"; << brace_open << i << brace_close << ";\n";
} }
return reduced_idx; return reduced_idx;
} }
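A host-side sketch of the index arithmetic this helper emits: decompose the thread id against one set of strides, then rebuild an index with another set. The numbers match the {2 3 4 5} / reduce_axis {0 2} example documented in the CUDAEmitter header earlier in this diff; the generated kernel replaces each division with the magic multiply-and-shift above:

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // The output (non-reduce) tensor has shape {3, 5}; the input tensor is {2, 3, 4, 5}.
    std::vector<uint32_t> non_reduce_strides{5, 1};           // strides inside the output
    std::vector<uint32_t> non_reduce_strides_in_input{20, 1}; // same axes, strides inside the input
    uint32_t tid = 7; // output element with coordinates (1, 2)
    uint32_t product = tid;
    uint32_t input_base = 0;
    for (size_t i = 0; i < non_reduce_strides.size(); i++)
    {
        uint32_t coordinate = product / non_reduce_strides[i]; // kernel: magic multiply + shift
        product -= coordinate * non_reduce_strides[i];
        input_base += coordinate * non_reduce_strides_in_input[i];
    }
    // Starting offset in the input of the reduction that produces output element (1, 2).
    assert(input_base == 1 * 20 + 2 * 1);
    return 0;
}
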
...@@ -1942,3 +2029,9 @@ void runtime::gpu::CudaKernelBuilder::add_pod_typedefs(codegen::CodeWriter& writ ...@@ -1942,3 +2029,9 @@ void runtime::gpu::CudaKernelBuilder::add_pod_typedefs(codegen::CodeWriter& writ
writer << "typedef unsigned long int uint64_t;\n"; writer << "typedef unsigned long int uint64_t;\n";
writer << "\n"; writer << "\n";
} }
bool runtime::gpu::CudaKernelBuilder::stable_sum_check_helper(const std::string& op,
const std::string& data_type)
{
return ((op == "add") && (data_type == "float" || data_type == "double"));
}
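stable_sum_check_helper gates the compensated update the builders above emit (y = input_i - c; t = r + y; c = (t - r) - y; r = t), i.e. Kahan summation. A minimal host-side illustration of why plain float accumulation is not enough (a sketch, not part of this commit):

#include <cstdio>

int main()
{
    const int n = 10000000; // ten million tiny addends after a large one
    float naive = 1.0f;
    float r = 1.0f, c = 0.0f; // running sum and compensation, as in the kernel
    double reference = 1.0;
    for (int i = 0; i < n; i++)
    {
        float input_i = 1e-8f;
        naive += input_i;      // each add rounds the small term away
        float y = input_i - c; // Kahan update emitted by the kernel builder
        float t = r + y;
        c = (t - r) - y;
        r = t;
        reference += 1e-8;
    }
    // naive stays at 1.0, while the compensated sum tracks the double reference (~1.1).
    std::printf("naive=%f kahan=%f reference=%f\n", naive, r, reference);
    return 0;
}
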
...@@ -40,6 +40,11 @@ namespace ngraph ...@@ -40,6 +40,11 @@ namespace ngraph
const std::string& op, const std::string& op,
const std::vector<std::string>& data_types); const std::vector<std::string>& data_types);
static void get_memset_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
runtime::gpu::GPUKernelArgs& args);
static void get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer, static void get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
runtime::gpu::GPUKernelArgs& args); runtime::gpu::GPUKernelArgs& args);
...@@ -78,22 +83,34 @@ namespace ngraph ...@@ -78,22 +83,34 @@ namespace ngraph
const std::string& data_type, const std::string& data_type,
uint32_t block_size); uint32_t block_size);
/// \brief Reduce op for an output that is not a scalar.
/// A stable Kahan sum is used for floating point summation.
/// No initial value is needed since the first input value is loaded as the initial value.
/// Zero-sized input is not supported.
static void get_reduce_to_nd_op(codegen::CodeWriter& writer, static void get_reduce_to_nd_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
runtime::gpu::GPUKernelArgs& args, runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types, const std::vector<std::string>& data_types,
const std::string& reduce_op, const std::string& reduce_op,
size_t out_rank, size_t non_reduce_rank,
size_t reduce_rank); size_t reduce_rank);
static void get_topk(codegen::CodeWriter& writer, /// \brief This is the preprocess to reduce to scalar if the input data size is larger than a certain number.
/// The number can be tuned based on hardware.
/// This cuda kernel will accumulate the reduction into a certain number of bins, depending on hardware.
/// A stable Kahan sum is used for floating point summation.
/// No initial value is needed since the first input value is loaded as the initial value.
/// Zero-sized input is not supported.
static void get_reduce_to_scalar_acc_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args, runtime::gpu::GPUKernelArgs& args,
bool use_malloc); const std::vector<std::string>& data_types,
const std::string& reduce_op);
//using one block with at most 512 threads to reduce to scalar. /// \brief This op uses one block with at most 512 threads to reduce to a scalar.
/// A stable Kahan sum is used for floating point summation.
/// No initial value is needed since the first input value is loaded as the initial value.
/// Zero-sized input is not supported.
static void get_reduce_to_scalar_op(codegen::CodeWriter& writer, static void get_reduce_to_scalar_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
runtime::gpu::GPUKernelArgs& args, runtime::gpu::GPUKernelArgs& args,
...@@ -101,14 +118,12 @@ namespace ngraph ...@@ -101,14 +118,12 @@ namespace ngraph
const std::string& reduce_op, const std::string& reduce_op,
uint32_t block_size_x); uint32_t block_size_x);
//This is the preprocess to reduce to scalar if the data size is large than a number. static void get_topk(codegen::CodeWriter& writer,
//The number can be tuned based on hardware.
//This cuda kernel will accumulate reduction to a certain number of bins depends on hardware.
static void get_reduce_to_scalar_acc_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args, runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types, bool use_malloc);
const std::string& reduce_op);
static void get_slice_op(codegen::CodeWriter& writer, static void get_slice_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
...@@ -195,26 +210,31 @@ namespace ngraph ...@@ -195,26 +210,31 @@ namespace ngraph
static void add_pod_typedefs(codegen::CodeWriter& writer); static void add_pod_typedefs(codegen::CodeWriter& writer);
/// \brief Given kernel input variables i_* produce register variables o_coordinates{i} static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
/// of the non-reduced tensor and return the string name of integer index into reduced tensor
static std::string
collective_coordinate_transform_helper(codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides, std::string i_strides,
std::string i_stride_magic, std::string i_stride_magic,
std::string i_stride_shift, std::string i_stride_shift,
std::string i_reduced_strides, std::string i_coord_product,
std::string o_coordinates, std::string o_coordinates,
size_t rank, size_t rank,
bool register_arguments = false); bool register_arguments = false);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
/// \brief Given kernel input variables i_*, produce register variables o_coordinates{i}
/// of the non-reduced tensor and return the name of the integer index into the reduced tensor
static std::string
collective_coordinate_transform_helper(codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides, std::string i_strides,
std::string i_stride_magic, std::string i_stride_magic,
std::string i_stride_shift, std::string i_stride_shift,
std::string i_coord_product, std::string i_reduced_strides,
std::string o_coordinates, std::string o_coordinates,
size_t rank, size_t rank,
bool register_arguments = false); bool register_arguments = true,
std::string reduced_idx = "reduced_idx");
static bool stable_sum_check_helper(const std::string& op,
const std::string& data_type);
}; };
} }
} }
......
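Note: the two reduce-to-scalar declarations above describe a two-stage strategy: get_reduce_to_scalar_acc_op first folds a large input into a bounded number of partial bins, and get_reduce_to_scalar_op then reduces those bins in a single block. A host-side illustration of the same structure for a sum (the bin count is a hypothetical, hardware-tunable number, and the real kernels would use the Kahan step shown earlier for float/double):

#include <cstddef>
#include <vector>

double reduce_to_scalar_two_stage(const std::vector<double>& in, size_t num_bins = 64)
{
    // Stage 1: strided accumulation, one bin per "thread".
    std::vector<double> bins(num_bins, 0.0);
    for (size_t i = 0; i < in.size(); i++)
    {
        bins[i % num_bins] += in[i];
    }
    // Stage 2: final reduction of the bins by a single "block".
    double result = 0.0;
    for (size_t i = 0; i < num_bins; i++)
    {
        result += bins[i];
    }
    return result;
}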
...@@ -732,45 +732,12 @@ void runtime::gpu::GPU_Emitter::emit_Max(EMIT_ARGS) ...@@ -732,45 +732,12 @@ void runtime::gpu::GPU_Emitter::emit_Max(EMIT_ARGS)
} }
const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node); const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node);
vector<element::Type> dtypes;
size_t index; dtypes.push_back(args[0].get_element_type());
if ((args[0].get_element_type() == element::i32) || (args[0].get_element_type() == element::i8)) dtypes.push_back(out[0].get_element_type());
{
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
{
writer << out[0].get_type()
<< " init_value = " << TypeInfo::Get(args[0].get_type())->min() << ";\n";
writer << "vector<" << out[0].get_type() << "> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
return;
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
return;
}
else
{
vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter(); auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_reduce<ngraph::op::Max>(dtypes, size_t index = cuda_emitter->build_reduce<ngraph::op::Max>(
out[0].get_element_type().size(), dtypes, args[0].get_shape(), out[0].get_shape(), max->get_reduction_axes());
args[0].get_shape(),
max->get_reduction_axes());
}
}
else
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
index = cudnn_emitter->build_primitive(max);
}
writer.block_begin(); writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
...@@ -867,43 +834,12 @@ void runtime::gpu::GPU_Emitter::emit_Min(EMIT_ARGS) ...@@ -867,43 +834,12 @@ void runtime::gpu::GPU_Emitter::emit_Min(EMIT_ARGS)
const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node); const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node);
size_t index; size_t index;
if ((args[0].get_element_type() == element::i32) || (args[0].get_element_type() == element::i8)) vector<element::Type> dtypes;
{ dtypes.push_back(args[0].get_element_type());
// one of args0 axes has zero size, zero output, use args1 value dtypes.push_back(out[0].get_element_type());
if (args[0].get_size() == 0)
{
writer << out[0].get_type()
<< " init_value = " << TypeInfo::Get(args[0].get_type())->max() << ";\n";
writer << "vector<" << out[0].get_type() << "> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
return;
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
return;
}
else
{
vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter(); auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_reduce<ngraph::op::Min>(dtypes, index = cuda_emitter->build_reduce<ngraph::op::Min>(
out[0].get_element_type().size(), dtypes, args[0].get_shape(), out[0].get_shape(), min->get_reduction_axes());
args[0].get_shape(),
min->get_reduction_axes());
}
}
else
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
index = cudnn_emitter->build_primitive(min);
}
writer.block_begin(); writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
...@@ -1013,59 +949,20 @@ void runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS) ...@@ -1013,59 +949,20 @@ void runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS)
writer.block_begin(); writer.block_begin();
{ {
if (out[0].get_size() != 0) if (out[0].get_size() != 0)
{
// one of args[] axes has zero size, fill output with 1
if (args[0].get_size() == 0)
{
writer << out[0].get_type() << " init_value = 1;\n";
writer << "vector<" << out[0].get_type() << "> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
// descriptors for tensors with <= 4 dimensions
else
{ {
size_t prod_index; size_t prod_index;
if ((args[0].get_element_type() == element::i32) || vector<element::Type> dtypes;
(args[0].get_element_type() == element::i8)) dtypes.push_back(args[0].get_element_type());
{ dtypes.push_back(out[0].get_element_type());
vector<string> dtypes; auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
prod_index = cuda_emitter->build_reduce<ngraph::op::Multiply>( prod_index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, dtypes, args[0].get_shape(), out[0].get_shape(), prod->get_reduction_axes());
out[0].get_element_type().size(),
args[0].get_shape(),
prod->get_reduction_axes());
}
else
{
std::vector<element::Type> dtypes{args[0].get_element_type(),
out[0].get_element_type()};
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
prod_index =
cudnn_emitter->build_reduce_forward(CUDNN_REDUCE_TENSOR_MUL,
dtypes,
args[0].get_shape(),
prod->get_reduction_axes(),
CUDNNEmitter::ReductionMode::Reduce);
}
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << prod_index << ", input, output);\n"; writer << "gpu::invoke_primitive(ctx, " << prod_index << ", input, output);\n";
} }
} }
}
writer.block_end(); writer.block_end();
} }
...@@ -1076,54 +973,15 @@ void runtime::gpu::GPU_Emitter::emit_Quantize(EMIT_ARGS) ...@@ -1076,54 +973,15 @@ void runtime::gpu::GPU_Emitter::emit_Quantize(EMIT_ARGS)
void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS)
{ {
// reduction function supported by GPU
// CUDNN_REDUCE_TENSOR_ADD
// CUDNN_REDUCE_TENSOR_MUL
// CUDNN_REDUCE_TENSOR_MIN
// CUDNN_REDUCE_TENSOR_MAX
// CUDNN_REDUCE_TENSOR_AMAX
// CUDNN_REDUCE_TENSOR_AVG
// CUDNN_REDUCE_TENSOR_NORM1
// CUDNN_REDUCE_TENSOR_NORM2
// CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS
static const unordered_map<type_index, cudnnReduceTensorOp_t> reduce_map{
{TI(ngraph::op::Add), CUDNN_REDUCE_TENSOR_ADD},
{TI(ngraph::op::Multiply), CUDNN_REDUCE_TENSOR_MUL},
{TI(ngraph::op::Maximum), CUDNN_REDUCE_TENSOR_MAX},
{TI(ngraph::op::Minimum), CUDNN_REDUCE_TENSOR_MIN}};
const ngraph::op::Reduce* reduce_op = static_cast<const ngraph::op::Reduce*>(node); const ngraph::op::Reduce* reduce_op = static_cast<const ngraph::op::Reduce*>(node);
writer.block_begin(); writer.block_begin();
{ {
if (out[0].get_size() != 0) if (out[0].get_size() != 0)
{
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
{
writer << out[0].get_type() << " init_value;\n";
writer << "runtime::gpu::cuda_memcpyDtH(&init_value, " << args[1].get_name() << " ,"
<< args[1].get_element_type().size() << ");\n";
writer << "vector<" << out[0].get_type() << "> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
else
{ {
auto axes_set = reduce_op->get_reduction_axes(); auto axes_set = reduce_op->get_reduction_axes();
ngraph::AxisVector axes_vec; std::vector<element::Type> dtypes;
for (auto a : axes_set) dtypes.push_back(args[0].get_element_type());
{ dtypes.push_back(out[0].get_element_type());
axes_vec.push_back(a);
}
std::vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter(); auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto reduction_function_ops = reduce_op->get_functions()[0]->get_ops(); auto reduction_function_ops = reduce_op->get_functions()[0]->get_ops();
...@@ -1150,44 +1008,42 @@ void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS) ...@@ -1150,44 +1008,42 @@ void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS)
if (dynamic_pointer_cast<ngraph::op::Add>(reduce_func)) if (dynamic_pointer_cast<ngraph::op::Add>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::Add>( emitter_index = cuda_emitter->build_reduce<ngraph::op::Add>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else if (dynamic_pointer_cast<ngraph::op::Multiply>(reduce_func)) else if (dynamic_pointer_cast<ngraph::op::Multiply>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::Multiply>( emitter_index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else if (dynamic_pointer_cast<ngraph::op::Maximum>(reduce_func)) else if (dynamic_pointer_cast<ngraph::op::Maximum>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::Maximum>( emitter_index = cuda_emitter->build_reduce<ngraph::op::Maximum>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else if (dynamic_pointer_cast<ngraph::op::Minimum>(reduce_func)) else if (dynamic_pointer_cast<ngraph::op::Minimum>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::Minimum>( emitter_index = cuda_emitter->build_reduce<ngraph::op::Minimum>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else if (dynamic_pointer_cast<ngraph::op::And>(reduce_func)) else if (dynamic_pointer_cast<ngraph::op::And>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::And>( emitter_index = cuda_emitter->build_reduce<ngraph::op::And>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else if (dynamic_pointer_cast<ngraph::op::Or>(reduce_func)) else if (dynamic_pointer_cast<ngraph::op::Or>(reduce_func))
{ {
emitter_index = cuda_emitter->build_reduce<ngraph::op::Or>( emitter_index = cuda_emitter->build_reduce<ngraph::op::Or>(
dtypes, out[0].get_element_type().size(), args[0].get_shape(), axes_vec); dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
} }
else else
{ {
throw runtime_error("reduce with function " + op_name + throw runtime_error("reduce with function " + op_name + " is not implemented yet.");
" is not implement yet.");
} }
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << emitter_index << ", input, output);\n"; writer << "gpu::invoke_primitive(ctx, " << emitter_index << ", input, output);\n";
} }
} }
}
writer.block_end(); writer.block_end();
} }
...@@ -1621,16 +1477,11 @@ void runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS) ...@@ -1621,16 +1477,11 @@ void runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS)
writer.block_begin(); writer.block_begin();
{ {
auto axes_set = softmax->get_axes(); auto axes_set = softmax->get_axes();
ngraph::AxisVector axes_vec;
for (auto a : axes_set)
{
axes_vec.push_back(a);
}
std::vector<string> dtypes; std::vector<string> dtypes;
dtypes.push_back(args[0].get_type()); dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type()); dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter(); auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_vec); size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_set);
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
...@@ -1656,14 +1507,7 @@ void runtime::gpu::GPU_Emitter::emit_Subtract(EMIT_ARGS) ...@@ -1656,14 +1507,7 @@ void runtime::gpu::GPU_Emitter::emit_Subtract(EMIT_ARGS)
void runtime::gpu::GPU_Emitter::emit_Sum(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_Sum(EMIT_ARGS)
{ {
if ((args[0].get_element_type() == element::i32) || (args[0].get_element_type() == element::i8))
{
runtime::gpu::GPU_Emitter::emit_Sum_0(external_function, writer, node, args, out); runtime::gpu::GPU_Emitter::emit_Sum_0(external_function, writer, node, args, out);
}
else
{
runtime::gpu::GPU_Emitter::emit_Sum_1(external_function, writer, node, args, out);
}
} }
void runtime::gpu::GPU_Emitter::emit_Sum_0(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_Sum_0(EMIT_ARGS)
...@@ -1677,33 +1521,19 @@ to fail */ ...@@ -1677,33 +1521,19 @@ to fail */
{ {
if (out[0].get_size() != 0) if (out[0].get_size() != 0)
{ {
// one of args[] axes has zero size, zero output auto axes_set = sum->get_reduction_axes();
if (args[0].get_size() == 0) vector<element::Type> dtypes;
{ dtypes.push_back(args[0].get_element_type());
kernel::emit_memset(writer, out[0], 0); dtypes.push_back(out[0].get_element_type());
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
else
{
vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter(); auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto sum_index = auto sum_index = cuda_emitter->build_reduce<ngraph::op::Add>(
cuda_emitter->build_reduce<ngraph::op::Add>(dtypes, dtypes, args[0].get_shape(), out[0].get_shape(), axes_set);
out[0].get_element_type().size(),
args[0].get_shape(),
sum->get_reduction_axes());
writer << "void* input[] = {" << node_names(args) << "};\n"; writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n"; writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << sum_index << ", input, output);\n"; writer << "gpu::invoke_primitive(ctx, " << sum_index << ", input, output);\n";
} }
} }
}
writer.block_end(); writer.block_end();
} }
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include <cinttypes> #include <cinttypes>
#include <list> #include <list>
#include "ngraph/except.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -86,6 +88,110 @@ namespace ngraph ...@@ -86,6 +88,110 @@ namespace ngraph
return &m_uint64_t_params.back(); return &m_uint64_t_params.back();
} }
template <typename T1, typename T2>
void* getVal(T2 val)
{
return cache(static_cast<T1>(val));
}
void* val_by_datatype(const std::string& type, double val)
{
if (type == "char")
{
return getVal<char>(val);
}
else if (type == "float")
{
return getVal<float>(val);
}
else if (type == "double")
{
return getVal<double>(val);
}
else if (type == "int8_t")
{
return getVal<int8_t>(val);
}
else if (type == "int16_t")
{
return getVal<int16_t>(val);
}
else if (type == "int32_t")
{
return getVal<int32_t>(val);
}
else if (type == "int64_t")
{
return getVal<int64_t>(val);
}
else if (type == "uint8_t")
{
return getVal<uint8_t>(val);
}
else if (type == "uint16_t")
{
return getVal<uint16_t>(val);
}
else if (type == "uint32_t")
{
return getVal<uint32_t>(val);
}
else if (type == "uint64_t")
{
return getVal<uint64_t>(val);
}
throw ngraph_error("Cast requested for invalid dtype");
}
void* val_by_datatype(const std::string& type, int64_t val)
{
if (type == "char")
{
return getVal<char>(val);
}
else if (type == "float")
{
return getVal<float>(val);
}
else if (type == "double")
{
return getVal<double>(val);
}
else if (type == "int8_t")
{
return getVal<int8_t>(val);
}
else if (type == "int16_t")
{
return getVal<int16_t>(val);
}
else if (type == "int32_t")
{
return getVal<int32_t>(val);
}
else if (type == "int64_t")
{
return getVal<int64_t>(val);
}
else if (type == "uint8_t")
{
return getVal<uint8_t>(val);
}
else if (type == "uint16_t")
{
return getVal<uint16_t>(val);
}
else if (type == "uint32_t")
{
return getVal<uint32_t>(val);
}
else if (type == "uint64_t")
{
return getVal<uint64_t>(val);
}
throw ngraph_error("Cast requested for invalid dtype");
}
private: private:
std::list<char> m_char_params; std::list<char> m_char_params;
std::list<float> m_float_params; std::list<float> m_float_params;
......
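Note: the point of routing every scalar through GPUHostParameters is pointer stability. Each value is copied into a type-specific std::list, and std::list never relocates its nodes, so the returned void* stays valid for the lifetime of the cache and can be captured by a primitive and dereferenced at launch time. A condensed, self-contained sketch of the same pattern (only float and double shown; the class above covers all element types):

#include <list>
#include <stdexcept>
#include <string>

class HostParamsSketch
{
public:
    void* val_by_datatype(const std::string& type, double val)
    {
        if (type == "float")
        {
            m_float_params.push_back(static_cast<float>(val));
            return &m_float_params.back(); // stable: std::list nodes never move
        }
        if (type == "double")
        {
            m_double_params.push_back(val);
            return &m_double_params.back();
        }
        throw std::runtime_error("Cast requested for invalid dtype");
    }

private:
    std::list<float> m_float_params;
    std::list<double> m_double_params;
};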
...@@ -24,7 +24,7 @@ using namespace ngraph::runtime::gpu; ...@@ -24,7 +24,7 @@ using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter() GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_memory_manager(this) : m_memory_manager(this)
, m_host_parameters(new GPUHostParameters) , m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, nullptr)) , m_cuda_emitter(new CUDAEmitter(this, nullptr, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr)) , m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
, m_cublas_emitter(new CUBLASEmitter(this, nullptr)) , m_cublas_emitter(new CUBLASEmitter(this, nullptr))
{ {
...@@ -33,7 +33,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter() ...@@ -33,7 +33,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx) GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx)
: m_memory_manager(this) : m_memory_manager(this)
, m_host_parameters(new GPUHostParameters) , m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, ctx.get())) , m_cuda_emitter(new CUDAEmitter(this, ctx.get(), this->m_host_parameters))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters)) , m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
, m_cublas_emitter(new CUBLASEmitter(this, ctx.get())) , m_cublas_emitter(new CUBLASEmitter(this, ctx.get()))
......
...@@ -39,6 +39,9 @@ namespace ngraph ...@@ -39,6 +39,9 @@ namespace ngraph
virtual std::string lowest() const = 0; virtual std::string lowest() const = 0;
virtual std::string min() const = 0; virtual std::string min() const = 0;
virtual std::string max() const = 0; virtual std::string max() const = 0;
virtual void* lowest_ptr() = 0;
virtual void* min_ptr() = 0;
virtual void* max_ptr() = 0;
using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>; using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>;
static const std::shared_ptr<TypeInfo>& Get(const element::Type& type) static const std::shared_ptr<TypeInfo>& Get(const element::Type& type)
...@@ -68,6 +71,17 @@ namespace ngraph ...@@ -68,6 +71,17 @@ namespace ngraph
class TypeInfo_Impl : public TypeInfo class TypeInfo_Impl : public TypeInfo
{ {
public: public:
TypeInfo_Impl()
: m_min(std::numeric_limits<T>::min())
, m_max(std::numeric_limits<T>::has_infinity
? std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::max())
, m_lowest(std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::lowest())
{
}
std::string lowest() const override std::string lowest() const override
{ {
return to_string<T>(std::numeric_limits<T>::lowest()); return to_string<T>(std::numeric_limits<T>::lowest());
...@@ -80,6 +94,13 @@ namespace ngraph ...@@ -80,6 +94,13 @@ namespace ngraph
{ {
return to_string<T>(std::numeric_limits<T>::max()); return to_string<T>(std::numeric_limits<T>::max());
} }
void* lowest_ptr() override { return &m_lowest; }
void* min_ptr() override { return &m_min; }
void* max_ptr() override { return &m_max; }
private:
T m_min;
T m_max;
T m_lowest;
}; };
} }
} }
......
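Note: the new min/max/lowest members exist so a reduction's identity value can be handed to the kernel as a typed pointer rather than a string. The constructor prefers infinities when the type has them because, for example, the identity of a max-reduce must compare below every representable value; lowest() remains the right choice for integer types. A small self-contained check of that logic (illustrative only):

#include <cassert>
#include <limits>

template <typename T>
T max_reduce_identity()
{
    return std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity()
                                                : std::numeric_limits<T>::lowest();
}

int main()
{
    // -inf is strictly below even the most negative finite float.
    assert(max_reduce_identity<float>() < std::numeric_limits<float>::lowest());
    // Integers have no infinity, so lowest() (== min()) is used instead.
    assert(max_reduce_identity<int>() == std::numeric_limits<int>::min());
    return 0;
}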
#int64 is not supported by cuDNN #int64 is not supported by cuDNN
abc_int64
batch_norm_one_output batch_norm_one_output
batch_norm_three_outputs batch_norm_three_outputs
backwards_batch_norm_three_outputs backwards_batch_norm_three_outputs
#need to check #need to check
computation_reuse computation_reuse
#int64 is not supprted #cuda does not support throw
concat_matrix_int64
divide_by_zero_int32 divide_by_zero_int32
#int64 is not supported by cuDNN #int64 is not supported by cuDNN
dot_matrix_vector_int64 dot_matrix_vector_int64
generate_mask generate_mask
#no mkldnn on GPU
#error throw is not the same on GPU, not supported yet #error throw is not the same on GPU, not supported yet
one_hot_scalar_fp_nonint_in_3 one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3 one_hot_scalar_oob_in_3
......
...@@ -135,6 +135,8 @@ shape_of_vector ...@@ -135,6 +135,8 @@ shape_of_vector
shape_of_matrix shape_of_matrix
shape_of_5d shape_of_5d
sum_stable_acc sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double sum_trivial_in_double
product_2d_to_scalar_int32 product_2d_to_scalar_int32
product_to_scalar_int32 product_to_scalar_int32
......
...@@ -535,4 +535,106 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc) ...@@ -535,4 +535,106 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc)
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 3)); EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 3));
} }
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc_double)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{10, 10, 20, 300};
auto A = make_shared<op::Parameter>(element::f64, shape_a);
Shape shape_rt{10};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{1, 2, 3}), ParameterVector{A});
test::Uniform<double> rng(1000000000.0L, 1000000000.001L, 2112);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close(ref_results.at(0), bk_results.at(0), 0.0, 1e-5));
}
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_float)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{20};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{0}), ParameterVector{A});
vector<vector<float>> args;
args.push_back(vector<float>{10000000.0f, 0.9f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f,
0.8f, 0.1f, 0.9f, 0.5f, 0.2f, 0.3f, 0.4f,
0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 0.1f});
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 1));
}
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_double)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{20};
auto A = make_shared<op::Parameter>(element::f64, shape_a);
Shape shape_rt{};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{0}), ParameterVector{A});
vector<vector<double>> args;
args.push_back(vector<double>{10000000000000000.0L,
0.2L,
0.3L,
0.4L,
0.5L,
0.6L,
0.7L,
0.8L,
0.9L,
0.7L,
0.9L,
0.7L,
0.3L,
0.6L,
0.8L,
0.4L,
0.6L,
0.5L,
0.8L,
0.7L});
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close(ref_results.at(0), bk_results.at(0), 0.0, 2.0));
}
#endif #endif
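Note on why the double test seeds the vector with 1e16: at that magnitude a double's spacing (ulp) is 2.0, so a naive accumulation silently drops every addend smaller than 1.0, which is exactly the failure mode the stable (Kahan) sum avoids. A self-contained demonstration, assuming standard IEEE double arithmetic and no -ffast-math (which would defeat the compensation):

#include <cassert>

int main()
{
    const double big = 10000000000000000.0; // 1e16, ulp == 2.0
    double naive = big;
    double sum = big;
    double c = 0.0;
    for (int i = 0; i < 20; i++)
    {
        const double x = 0.5;
        naive += x; // rounds straight back to 1e16 every iteration

        double y = x - c; // Kahan step: re-inject the previously lost bits
        double t = sum + y;
        c = (t - sum) - y;
        sum = t;
    }
    assert(naive == big);      // the accumulated 10.0 was lost entirely
    assert(sum == big + 10.0); // the compensated sum kept it
    return 0;
}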
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
{ {
if (count < 5) if (count < 5)
{ {
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i; NGRAPH_INFO
<< std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
} }
count++; count++;
rc = false; rc = false;
......
...@@ -166,7 +166,8 @@ bool test::all_close_f(const vector<float>& a, ...@@ -166,7 +166,8 @@ bool test::all_close_f(const vector<float>& a,
{ {
if (diff_count < 5) if (diff_count < 5)
{ {
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i; NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
} }
rc = false; rc = false;
...@@ -191,10 +192,12 @@ bool test::all_close_f(const vector<float>& a, ...@@ -191,10 +192,12 @@ bool test::all_close_f(const vector<float>& a,
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits (" NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)"; << mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance) NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index] << " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])"; << " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance) NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index] << " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])"; << " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance) NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
......