Commit 606f3f93 authored by Fenglei, committed by Robert Kimball

nvgpu cuda reduce with stable sum (#2076)

* add some helper function

* update with new helper function

* update reduce to nd with new helper function

* update float sum to stable sum

* fix bug

* update all reduce to stable sum for float

* fix bug and pass the sum stable test

* remove debug info

* style

* update with shape

* fix bug

* add host parameters to cuda_emitter

* clang format

* fix bugs

* add element::type support

* format

* add a cached value with datatype name

* add init_reduce_value

* unroll loop

* optimization

* remove the need for init_value

* add memset kernel

* add memcpy

* working version

* remove debug info

* add comments, clean up code.

* change in_idx to input_idx

* fix bug

* change args name for memset in emitter

* pass element::Type instead of string

* the op::reduce come with init value, add support

* resolve codacy-bot comment

* fix bug

* resolve codacy-bot comment

* remove unused comments, resolve comments

* cuda reduce for max, min, mul, reduce op init value, format

* use type::info

* use type info for numeric_limits

* remove code from gpu_host_parameters

* header

* remove outdated comments

* add helper to check if stable sum is needed

* add stable sum test for double

* remove extra line

* consolidate helper functions

* no need list now.

* remove extra ;

* clang format

* style

* add skip test for cpu and intelGPU side

* add line between groups of headers

* add two simple stable sum test for float and double

* skip test for intelGPU
parent 4b0445d1
......@@ -12,6 +12,7 @@ shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
sum_stable_acc_double
quantize_clamp_int32
# failing in CI build but passing on local machine
......
This diff is collapsed.
......@@ -19,6 +19,7 @@
#include <array>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_host_parameters.hpp"
#include "ngraph/runtime/gpu/nvdiff.hpp"
#include "ngraph/runtime/gpu/nvshape.hpp"
#include "ngraph/strides.hpp"
......@@ -49,6 +50,8 @@ namespace ngraph
size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op);
public:
size_t build_memset(const std::string& dtype, uint32_t tensor_size);
size_t build_topk(const std::vector<element::Type>& dtypes,
const NVShape& input_shape,
const size_t topk_axis,
......@@ -125,17 +128,19 @@ namespace ngraph
const double& eps);
template <typename T>
size_t build_reduce(const std::vector<std::string>& dtypes,
const size_t data_bytes,
const NVShape& input_shape,
const NVShape& reduce_axis)
size_t build_reduce(const std::vector<element::Type>& dtypes,
NVShape input_shape,
NVShape output_shape,
NVShape reduce_axis,
const bool with_init_value = false)
{
return build_reduce(dtypes,
data_bytes,
input_shape,
output_shape,
reduce_axis,
CudaOpMap<T>::op,
CudaOpMap<T>::math_kernel);
CudaOpMap<T>::math_kernel,
with_init_value);
}
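As a rough orientation, a hypothetical call site for this templated wrapper might look like the sketch below; the variable names, the shapes, and the choice of ngraph::op::Add as the elementwise template parameter are illustrative assumptions, not taken from the emitter sources.

```cpp
// Hypothetical usage (illustrative only): emit a Sum reduction over axes {1,2,3}
// of an f32 tensor shaped {10,10,20,300}, producing an output of shape {10}.
std::vector<element::Type> dtypes{element::f32, element::f32}; // input type, output type
size_t reduce_index =
    cuda_emitter->build_reduce<ngraph::op::Add>(dtypes,
                                                NVShape{10, 10, 20, 300}, // input_shape
                                                NVShape{10},              // output_shape
                                                NVShape{1, 2, 3});        // reduce_axis
```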
template <typename ELEMENTWISE_OP_TYPE, typename REDUCE_OP_TYPE = ngraph::op::Nop>
......@@ -193,7 +198,9 @@ namespace ngraph
void sync();
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
CUDAEmitter(GPUPrimitiveEmitter* emitter,
GPURuntimeContext* ctx,
std::shared_ptr<GPUHostParameters> params);
uint32_t align_to_block_size(uint32_t threads, uint32_t block_size);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
......@@ -211,32 +218,71 @@ namespace ngraph
const char* kernel,
const char* reduce_op,
bool save_elementwise);
size_t build_reduce(const std::vector<std::string>& dtypes,
const size_t data_bytes,
size_t build_reduce(const std::vector<element::Type>& dtypes,
const NVShape& input_shape,
const NVShape& output_shape,
const NVShape& reduce_axis,
const char* op,
const char* kernel);
size_t build_reduce_to_nd(const std::vector<std::string>& dtypes,
const char* kernel,
const bool with_init_value);
size_t build_reduce_to_nd(const std::vector<element::Type>& dtypes,
NVShape input_shape,
NVShape reduce_axis,
const char* op,
const char* kernel);
size_t build_reduce_to_scalar(const std::vector<std::string>& dtypes,
const size_t data_bytes,
size_t build_reduce_to_scalar(const std::vector<element::Type>& dtypes,
NVShape input_shape,
const char* op,
const char* kernel);
//This is the preprocess for reduce to scalar if the data size is large than a number.
//The number can be tuned based on hardware.
//This cuda kernel will accumulate reduction to a certain number of bins depends on hardware.
size_t build_reduce_to_scalar_acc(const std::vector<std::string>& dtypes,
/// \brief This is the preprocess for reduce to scalar when the data size is larger than a threshold.
/// The threshold can be tuned based on hardware.
/// This cuda kernel accumulates the reduction into a certain number of bins, depending on the hardware.
size_t build_reduce_to_scalar_acc(const std::vector<element::Type>& dtypes,
NVShape input_shape,
NVShape output_shape,
uint32_t block_size_x,
const char* op,
const char* kernel);
/// \brief Simplify the reduce shape and reduce axis: remove dimensions of size 1 and
/// combine two or more adjacent reduce/non-reduce axes.
/// The simplified reduce shape and reduce axis make index calculation simpler in the cuda kernel.
/// Examples:
/// {1 1 2 2} with reduce axis {3} simplifies to {2 2} with reduce_axis {1};
/// {2 3 4} with reduce axis {0 1} simplifies to {6 4} with reduce_axis {0};
/// {2 3 4} with reduce axis {0} simplifies to {2 12} with reduce_axis {0};
void simplify_reduce_shape(NVShape in,
NVShape reduce_axis,
NVShape& simplified_shape,
NVShape& simplified_reduce_axis);
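A minimal host-side sketch of the simplification described above, written with plain std::vector in place of NVShape (illustrative, not the actual implementation):

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Drop size-1 dimensions, then merge runs of adjacent axes that are either all
// reduced or all non-reduced into a single axis.
void simplify_reduce_shape_sketch(const std::vector<size_t>& in,
                                  const std::vector<size_t>& reduce_axis,
                                  std::vector<size_t>& simplified_shape,
                                  std::vector<size_t>& simplified_reduce_axis)
{
    std::set<size_t> axes(reduce_axis.begin(), reduce_axis.end());
    bool have_prev = false;
    bool prev_is_reduce = false;
    for (size_t i = 0; i < in.size(); i++)
    {
        if (in[i] == 1)
        {
            continue; // size-1 dimensions contribute nothing to indexing
        }
        bool is_reduce = (axes.count(i) != 0);
        if (have_prev && is_reduce == prev_is_reduce)
        {
            simplified_shape.back() *= in[i]; // merge with the previous axis of the same kind
        }
        else
        {
            simplified_shape.push_back(in[i]);
            if (is_reduce)
            {
                simplified_reduce_axis.push_back(simplified_shape.size() - 1);
            }
        }
        have_prev = true;
        prev_is_reduce = is_reduce;
    }
}
```

Running this on the three examples in the comment reproduces {2 2}/{1}, {6 4}/{0}, and {2 12}/{0}.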
/// \brief Separate input_shape into reduce_shape and non_reduce_shape, and calculate their strides
/// as well as their strides in the input. This helps calculate input and output indices in the cuda kernel.
/// Example:
/// input_shape {2 3 4 5} with reduce_axis {0 2}:
/// input_strides: {60, 20, 5, 1}
/// reduce_shape {2 4}, reduce_strides {4 1}, reduce_strides_in_input {60 5}
/// non_reduce_shape {3 5}, non_reduce_strides {5 1}, non_reduce_strides_in_input {20 1}
void get_reduce_strides(NVShape input_shape,
NVShape reduce_axis,
NVShape& non_reduce_shape,
NVShape& non_reduce_strides,
NVShape& non_reduce_strides_in_input,
NVShape& reduce_shape,
NVShape& reduce_strides,
NVShape& reduce_strides_in_input);
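A sketch of the decomposition in the example above, again with plain vectors; non_reduce_strides and reduce_strides are omitted since they are just the row-major strides of the two sub-shapes (illustrative, not the actual implementation):

```cpp
#include <cstddef>
#include <set>
#include <vector>

void get_reduce_strides_sketch(const std::vector<size_t>& input_shape,
                               const std::vector<size_t>& reduce_axis,
                               std::vector<size_t>& non_reduce_shape,
                               std::vector<size_t>& non_reduce_strides_in_input,
                               std::vector<size_t>& reduce_shape,
                               std::vector<size_t>& reduce_strides_in_input)
{
    // row-major strides of the full input, e.g. {2 3 4 5} -> {60, 20, 5, 1}
    std::vector<size_t> input_strides(input_shape.size(), 1);
    for (size_t i = input_shape.size(); i-- > 1;)
    {
        input_strides[i - 1] = input_strides[i] * input_shape[i];
    }
    std::set<size_t> axes(reduce_axis.begin(), reduce_axis.end());
    for (size_t i = 0; i < input_shape.size(); i++)
    {
        if (axes.count(i) != 0)
        {
            reduce_shape.push_back(input_shape[i]);
            reduce_strides_in_input.push_back(input_strides[i]);
        }
        else
        {
            non_reduce_shape.push_back(input_shape[i]);
            non_reduce_strides_in_input.push_back(input_strides[i]);
        }
    }
}
```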
/// \brief Calculate the magic and shift parts for each element of a shape vector (the denominators),
/// so that division can be replaced by multiplication in the cuda kernel.
void div_to_mul(const NVShape& shape,
std::vector<int>& magic,
std::vector<int>& shift);
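The magic/shift trick referred to here is the standard constant-division-by-multiplication transform. A self-contained illustration for one divisor; the pair (0xCCCCCCCD, 2) for unsigned division by 5 is a textbook example, not a value taken from this code:

```cpp
#include <cassert>
#include <cstdint>

// x / 5 computed as a multiply-high followed by a shift.
uint32_t div_by_5(uint32_t x)
{
    const uint64_t magic = 0xCCCCCCCDull;                   // ceil(2^34 / 5)
    uint32_t hi = static_cast<uint32_t>((x * magic) >> 32); // multiply-high
    return hi >> 2;                                         // shift
}

int main()
{
    for (uint32_t x = 0; x < 1000000; x++)
    {
        assert(div_by_5(x) == x / 5);
    }
    return 0;
}
```

div_to_mul computes such a (magic, shift) pair for every dimension of the shape so the generated kernels can turn index divisions into multiplies.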
/// \brief Get initial value for reduce op
void* get_init_reduce_val(std::string reduce_op, std::string data_type);
/// \brief Get a vector<string> of datatype names from a vector<element::Type>
std::vector<std::string>
get_string_vector(const std::vector<element::Type>& dtypes);
std::shared_ptr<GPUHostParameters> m_host_parameters;
GPUPrimitiveEmitter* m_primitive_emitter;
GPURuntimeContext* m_ctx;
};
......
......@@ -40,6 +40,11 @@ namespace ngraph
const std::string& op,
const std::vector<std::string>& data_types);
static void get_memset_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& data_type,
runtime::gpu::GPUKernelArgs& args);
static void get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args);
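For reference, a plausible shape for the kernel that get_memset_op emits; this is an illustrative CUDA sketch, not the generated source:

```cuda
// One thread per element; "value" arrives through GPUKernelArgs as a host parameter.
extern "C" __global__ void memset_sketch(float* out, float value, unsigned int n)
{
    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n)
    {
        out[tid] = value;
    }
}
```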
......@@ -78,22 +83,34 @@ namespace ngraph
const std::string& data_type,
uint32_t block_size);
/// \brief Reduce op for output that is not scalar.
/// A stable Kahan sum is used for floating-point sums.
/// No initial value is needed since one input value is loaded as the initial value.
/// 0-sized input is not supported.
static void get_reduce_to_nd_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
const std::string& reduce_op,
size_t out_rank,
size_t non_reduce_rank,
size_t reduce_rank);
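The "stable Kahan sum" mentioned in these comments is compensated summation; below is a minimal scalar sketch of the accumulation pattern (the actual kernels apply it per thread over strided indices derived from the shapes above):

```cpp
#include <cstddef>

// Assumes n >= 1, matching the "0-sized input is not supported" note above.
float kahan_sum_sketch(const float* in, size_t n)
{
    float sum = in[0]; // the first loaded element doubles as the initial value
    float c = 0.0f;    // running compensation for lost low-order bits
    for (size_t i = 1; i < n; i++)
    {
        float y = in[i] - c;
        float t = sum + y;
        c = (t - sum) - y;
        sum = t;
    }
    return sum;
}
```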
static void get_topk(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args,
bool use_malloc);
/// \brief This is the preprocess to reduce to scalar when the input data size is larger than a threshold.
/// The threshold can be tuned based on hardware.
/// This cuda kernel accumulates the reduction into a certain number of bins, depending on the hardware.
/// A stable Kahan sum is used for floating-point sums.
/// No initial value is needed since one input value is loaded as the initial value.
/// 0-sized input is not supported.
static void get_reduce_to_scalar_acc_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
const std::string& reduce_op);
//using one block with at most 512 threads to reduce to scalar.
/// \brief This op uses one block with at most 512 threads to reduce to a scalar.
/// A stable Kahan sum is used for floating-point sums.
/// No initial value is needed since one input value is loaded as the initial value.
/// 0-sized input is not supported.
static void get_reduce_to_scalar_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
......@@ -101,14 +118,12 @@ namespace ngraph
const std::string& reduce_op,
uint32_t block_size_x);
//This is the preprocess to reduce to scalar if the data size is large than a number.
//The number can be tuned based on hardware.
//This cuda kernel will accumulate reduction to a certain number of bins depends on hardware.
static void get_reduce_to_scalar_acc_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
const std::string& reduce_op);
static void get_topk(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args,
bool use_malloc);
static void get_slice_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -195,6 +210,15 @@ namespace ngraph
static void add_pod_typedefs(codegen::CodeWriter& writer);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank,
bool register_arguments = false);
/// \brief Given kernel input variables i_*, produce register variables o_coordinates{i}
/// of the non-reduced tensor and return the string name of the integer index into the reduced tensor.
static std::string
......@@ -206,15 +230,11 @@ namespace ngraph
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank,
bool register_arguments = false);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank,
bool register_arguments = false);
bool register_arguments = true,
std::string reduced_idx = "reduced_idx");
static bool stable_sum_check_helper(const std::string& op,
const std::string& data_type);
};
}
}
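stable_sum_check_helper above decides whether the stable-sum path is emitted; its body lives in the collapsed .cpp diff, but a reasonable guess at what the check amounts to (not the actual implementation) is:

```cpp
#include <string>

// Guessed condition (the real helper may differ): Kahan accumulation is only
// worth emitting for floating-point "add" reductions.
static bool stable_sum_check_sketch(const std::string& op, const std::string& data_type)
{
    return op == "add" && (data_type == "float" || data_type == "double");
}
```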
......
This diff is collapsed.
......@@ -19,6 +19,8 @@
#include <cinttypes>
#include <list>
#include "ngraph/except.hpp"
namespace ngraph
{
namespace runtime
......@@ -86,6 +88,110 @@ namespace ngraph
return &m_uint64_t_params.back();
}
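/// \brief Cast val to T1 and cache it host-side, returning a pointer to the
/// cached value (stored in the m_*_params lists below).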
template <typename T1, typename T2>
void* getVal(T2 val)
{
return cache(static_cast<T1>(val));
}
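/// \brief Dispatch on the dtype name string and cache val as the corresponding
/// concrete C++ type, returning a host pointer usable as a kernel argument.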
void* val_by_datatype(const std::string& type, double val)
{
if (type == "char")
{
return getVal<char>(val);
}
else if (type == "float")
{
return getVal<float>(val);
}
else if (type == "double")
{
return getVal<double>(val);
}
else if (type == "int8_t")
{
return getVal<int8_t>(val);
}
else if (type == "int16_t")
{
return getVal<int16_t>(val);
}
else if (type == "int32_t")
{
return getVal<int32_t>(val);
}
else if (type == "int64_t")
{
return getVal<int64_t>(val);
}
else if (type == "uint8_t")
{
return getVal<uint8_t>(val);
}
else if (type == "uint16_t")
{
return getVal<uint16_t>(val);
}
else if (type == "uint32_t")
{
return getVal<uint32_t>(val);
}
else if (type == "uint64_t")
{
return getVal<uint64_t>(val);
}
throw ngraph_error("Cast requested for invalid dtype");
}
void* val_by_datatype(const std::string& type, int64_t val)
{
if (type == "char")
{
return getVal<char>(val);
}
else if (type == "float")
{
return getVal<float>(val);
}
else if (type == "double")
{
return getVal<double>(val);
}
else if (type == "int8_t")
{
return getVal<int8_t>(val);
}
else if (type == "int16_t")
{
return getVal<int16_t>(val);
}
else if (type == "int32_t")
{
return getVal<int32_t>(val);
}
else if (type == "int64_t")
{
return getVal<int64_t>(val);
}
else if (type == "uint8_t")
{
return getVal<uint8_t>(val);
}
else if (type == "uint16_t")
{
return getVal<uint16_t>(val);
}
else if (type == "uint32_t")
{
return getVal<uint32_t>(val);
}
else if (type == "uint64_t")
{
return getVal<uint64_t>(val);
}
throw ngraph_error("Cast requested for invalid dtype");
}
private:
std::list<char> m_char_params;
std::list<float> m_float_params;
......
......@@ -24,7 +24,7 @@ using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, nullptr))
, m_cuda_emitter(new CUDAEmitter(this, nullptr, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
, m_cublas_emitter(new CUBLASEmitter(this, nullptr))
{
......@@ -33,7 +33,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx)
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, ctx.get()))
, m_cuda_emitter(new CUDAEmitter(this, ctx.get(), this->m_host_parameters))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
, m_cublas_emitter(new CUBLASEmitter(this, ctx.get()))
......
......@@ -39,6 +39,9 @@ namespace ngraph
virtual std::string lowest() const = 0;
virtual std::string min() const = 0;
virtual std::string max() const = 0;
virtual void* lowest_ptr() = 0;
virtual void* min_ptr() = 0;
virtual void* max_ptr() = 0;
using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>;
static const std::shared_ptr<TypeInfo>& Get(const element::Type& type)
......@@ -68,6 +71,17 @@ namespace ngraph
class TypeInfo_Impl : public TypeInfo
{
public:
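// Cache host copies of min/max/lowest for T; when T has an infinity, max/lowest
// become +/-infinity so the *_ptr() accessors can serve, e.g., as reduction initial values.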
TypeInfo_Impl()
: m_min(std::numeric_limits<T>::min())
, m_max(std::numeric_limits<T>::has_infinity
? std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::max())
, m_lowest(std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::lowest())
{
}
std::string lowest() const override
{
return to_string<T>(std::numeric_limits<T>::lowest());
......@@ -80,6 +94,13 @@ namespace ngraph
{
return to_string<T>(std::numeric_limits<T>::max());
}
void* lowest_ptr() override { return &m_lowest; }
void* min_ptr() override { return &m_min; }
void* max_ptr() override { return &m_max; }
private:
T m_min;
T m_max;
T m_lowest;
};
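A possible use of the new pointer accessors (an illustrative call, not a line from this change): fetching the identity value for a Max reduction over f32.

```cpp
// -inf acts as the identity for a float max-reduce; the pointer stays valid
// because the value is stored inside the TypeInfo instance held by Get().
void* init_value = TypeInfo::Get(element::f32)->lowest_ptr();
```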
}
}
......
#int64 is not supported by cuDNN
abc_int64
batch_norm_one_output
batch_norm_three_outputs
backwards_batch_norm_three_outputs
#need to check
computation_reuse
#int64 is not supported
concat_matrix_int64
#cuda does not support throw
divide_by_zero_int32
#int64 is not supported by cuDNN
dot_matrix_vector_int64
generate_mask
#no mkldnn on GPU
#error throw is not the same on GPU, not supported yet
one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3
......
......@@ -135,6 +135,8 @@ shape_of_vector
shape_of_matrix
shape_of_5d
sum_stable_acc
sum_stable_acc_double
sum_stable_simple_double
sum_trivial_in_double
product_2d_to_scalar_int32
product_to_scalar_int32
......
......@@ -535,4 +535,106 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc)
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 3));
}
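For intuition about what the sum_stable_* tests guard against, here is a standalone sketch (not part of the test suite) contrasting naive float accumulation with the compensated form once the running total is large:

```cpp
#include <cstdio>

int main()
{
    // At a running total of 1e7 the float spacing (ulp) is 1.0, so adding 0.1
    // naively rounds away every time; Kahan compensation carries the lost bits.
    float naive = 10000000.0f;
    float sum = 10000000.0f, c = 0.0f;
    for (int i = 0; i < 10000; i++)
    {
        naive += 0.1f;
        float y = 0.1f - c;
        float t = sum + y;
        c = (t - sum) - y;
        sum = t;
    }
    // The true sum is ~10001000: naive stays at 10000000, while the compensated
    // sum lands within a few ulp of the true value.
    std::printf("naive=%.1f  kahan=%.1f\n", naive, sum);
}
```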
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc_double)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{10, 10, 20, 300};
auto A = make_shared<op::Parameter>(element::f64, shape_a);
Shape shape_rt{10};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{1, 2, 3}), ParameterVector{A});
test::Uniform<double> rng(1000000000.0L, 1000000000.001L, 2112);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close(ref_results.at(0), bk_results.at(0), 0.0, 1e-5));
}
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_float)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{20};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_rt{};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{0}), ParameterVector{A});
vector<vector<float>> args;
args.push_back(vector<float>{10000000.0f, 0.9f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f,
0.8f, 0.1f, 0.9f, 0.5f, 0.2f, 0.3f, 0.4f,
0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 0.1f});
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 1));
}
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_double)
{
std::string backend_name = "${BACKEND_NAME}";
if (backend_name == "INTERPRETER")
{
return;
}
Shape shape_a{20};
auto A = make_shared<op::Parameter>(element::f64, shape_a);
Shape shape_rt{};
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{0}), ParameterVector{A});
vector<vector<double>> args;
args.push_back(vector<double>{10000000000000000.0L,
0.2L,
0.3L,
0.4L,
0.5L,
0.6L,
0.7L,
0.8L,
0.9L,
0.7L,
0.9L,
0.7L,
0.3L,
0.6L,
0.8L,
0.4L,
0.6L,
0.5L,
0.8L,
0.7L});
auto ref_func = clone_function(*f);
auto bk_func = clone_function(*f);
auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close(ref_results.at(0), bk_results.at(0), 0.0, 2.0));
}
#endif
......@@ -50,7 +50,9 @@ namespace ngraph
{
if (count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
NGRAPH_INFO
<< std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
}
count++;
rc = false;
......
......@@ -166,7 +166,8 @@ bool test::all_close_f(const vector<float>& a,
{
if (diff_count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
}
rc = false;
......@@ -191,10 +192,12 @@ bool test::all_close_f(const vector<float>& a,
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
......