Commit 40069d27 authored by Fenglei, committed by Robert Kimball

gpu deconvolution (#1099)

* add pad_dilation function

* add dilation to gpu_emitter

* add CoordinateDiff constructor to GPUShape

* remove unnecessary cast

* working version for forward

* forward working

* forward test all pass

* deconvolution forward

* backward data dilation

* forward test passed

* initial to 0

* fix bug for get_padded_shape and clang format

* code style, change variable names

* refactor convolution conditions

* fix bug padding_below_diff

* change pad_dilation to pad_dynamic, compare to pad

* remove passed convolution test from skip list, clang format

* change pad to use GPUShape
parent 83e6aa5f
@@ -231,6 +231,106 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape output_shape,
GPUShape padding_below,
GPUShape padding_interior)
{
std::stringstream kernel_name;
kernel_name << "pad_dynamic_" << join(dtypes, "_");
std::string hash = kernel_name.str() + "pad_i" + join(input_shape, "_") + "pad_o" +
join(output_shape) + "_pb" + join(padding_below, "_") + "_pi" +
join(padding_interior, "_");
// For backwards compatibility we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_pad_dynamic_op(writer, kernel_name.str(), dtypes);
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
unsigned int rank = static_cast<unsigned int>(input_shape.size());
unsigned int nthreads = static_cast<unsigned int>(shape_size(input_shape));
GPUShape pad_below(input_shape.size(), 0);
GPUShape pad_interior(input_shape.size(), 1);
int64_t i = padding_below.size() - 1;
int64_t j = input_shape.size() - 1;
for (; i >= 0; i--, j--)
{
pad_below[j] = padding_below[i];
pad_interior[j] = padding_interior[i];
}
GPUShape input_strides = row_major_strides(input_shape);
GPUShape output_strides = row_major_strides(output_shape);
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_input_strides = allocator.reserve_argspace(
input_strides.data(), input_strides.size() * sizeof(unsigned int));
size_t idx_output_strides = allocator.reserve_argspace(
output_strides.data(), output_strides.size() * sizeof(unsigned int));
size_t idx_padding_below =
allocator.reserve_argspace(pad_below.data(), pad_below.size() * sizeof(unsigned int));
size_t idx_padding_interior =
allocator.reserve_argspace(pad_interior.data(), pad_interior.size() * sizeof(unsigned int));
// create the launch primitive
std::unique_ptr<gpu::primitive> pad_dynamic(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* param_input_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_input_strides);
void* param_output_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_output_strides);
void* param_padding_below = runtime::gpu::invoke_memory_primitive(ctx, idx_padding_below);
void* param_padding_interior =
runtime::gpu::invoke_memory_primitive(ctx, idx_padding_interior);
std::vector<void*> args_list{&inputs[0],
&outputs[0],
&param_input_strides,
&param_output_strides,
&param_padding_below,
&param_padding_interior,
&rank,
&nthreads};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // wait for the launched kernel to finish
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pad_dynamic));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
......
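
Aside (illustrative, not part of this commit): build_pad_dynamic above follows the emitter's usual two-level caching, as its comment notes: compiled kernels are keyed by kernel name, launch primitives by the shape-specific hash. A minimal stand-alone C++ sketch of that pattern, using made-up names rather than the nGraph types:

    #include <cstddef>
    #include <limits>
    #include <string>
    #include <unordered_map>

    // Illustrative two-level cache: kernel sources are compiled once per kernel
    // name, while launch primitives are cached per (kernel name + shapes) hash.
    struct KernelPool
    {
        std::unordered_map<std::string, std::string> compiled; // name -> kernel
        const std::string* get(const std::string& name) const
        {
            auto it = compiled.find(name);
            return it == compiled.end() ? nullptr : &it->second;
        }
        const std::string* set(const std::string& name, const std::string& src)
        {
            return &(compiled[name] = src);
        }
    };

    struct PrimitiveCache
    {
        std::unordered_map<std::string, std::size_t> primitives; // hash -> index
        std::size_t lookup(const std::string& hash) const
        {
            auto it = primitives.find(hash);
            return it == primitives.end() ? std::numeric_limits<std::size_t>::max()
                                          : it->second;
        }
        void cache(const std::string& hash, std::size_t index) { primitives[hash] = index; }
    };
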
@@ -47,6 +47,13 @@ namespace ngraph
GPUShape pad_interior,
const std::string& pad_value = "");
size_t build_pad_dynamic(const runtime::gpu::GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape output_shape,
GPUShape padding_below,
GPUShape padding_interior);
size_t build_1d_max_pool(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
......
@@ -279,6 +279,40 @@ void runtime::gpu::CudaKernelBuilder::get_concat_op(codegen::CodeWriter& writer,
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, unsigned int* input_strides, unsigned int* output_strides, "
"unsigned int* padding_below, unsigned int* "
"padding_interior, unsigned int rank, unsigned int n)\n";
writer.block_begin();
{
writer << "unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
writer << "unsigned int output_idx = 0;\n";
writer << "unsigned int input_idx = tid;\n";
writer << "for(unsigned int i = 0; i < rank; i++)\n";
writer.block_begin();
{
writer << "output_idx += (input_idx / input_strides[i] * padding_interior[i] + "
"padding_below[i]) "
"* output_strides[i];\n";
writer << "input_idx %= input_strides[i];\n";
}
writer.block_end();
writer << "out[output_idx] = in[tid];\n";
}
writer.block_end();
}
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
......
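
Aside (illustrative, not part of this commit): the kernel generated above scatters each input element to its dilated/padded position by peeling off one stride per axis. The same arithmetic can be checked on the host; this stand-alone sketch mirrors it for a 2x2 input dilated by 2 on both axes with no below-padding:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // 2x2 input scattered into a 3x3 output: interior padding (dilation) of 2
        // on both axes, i.e. one zero between neighbouring elements, no below-pad.
        std::vector<unsigned> input_strides{2, 1};  // row-major strides of {2, 2}
        std::vector<unsigned> output_strides{3, 1}; // row-major strides of {3, 3}
        std::vector<unsigned> padding_below{0, 0};
        std::vector<unsigned> padding_interior{2, 2};

        std::vector<float> in{1, 2, 3, 4};
        std::vector<float> out(9, 0); // pre-zeroed, like the pad_buffer in the emitter

        for (unsigned tid = 0; tid < in.size(); tid++)
        {
            unsigned output_idx = 0;
            unsigned input_idx = tid;
            for (unsigned i = 0; i < input_strides.size(); i++)
            {
                output_idx += (input_idx / input_strides[i] * padding_interior[i] +
                               padding_below[i]) *
                              output_strides[i];
                input_idx %= input_strides[i];
            }
            out[output_idx] = in[tid];
        }

        // Prints: 1 0 2 / 0 0 0 / 3 0 4
        for (unsigned r = 0; r < 3; r++)
        {
            for (unsigned c = 0; c < 3; c++)
            {
                std::printf("%g ", out[r * 3 + c]);
            }
            std::printf("\n");
        }
        return 0;
    }
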
@@ -81,6 +81,10 @@ namespace ngraph
const std::string& math_kernel,
const std::vector<std::string>& data_types);
static void get_pad_dynamic_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_ew_collective_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
......
@@ -163,40 +163,42 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
Strides data_dilation_strides = convolution->get_data_dilation_strides();
CoordinateDiff padding_below_diff = convolution->get_padding_below();
CoordinateDiff padding_above_diff = convolution->get_padding_above();
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below_diff.size(); i++)
{
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
if (padding_below.size() > 3)
if (padding_below_diff.size() > 3)
{
throw std::runtime_error(node->get_name() +
"with more than 3D is not implemented.");
}
bool is_deconvolution = false;
for (auto a : data_dilation_strides)
{
if (a != 1)
{
throw std::runtime_error(node->get_name() +
"with data dilation is not implemented.");
is_deconvolution = true;
break;
}
}
bool pad_required = false;
if (padding_below != padding_above)
bool pad_required = (padding_below_diff != padding_above_diff);
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below.size(); i++)
{
pad_required = true;
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
auto input_shape = args[0].get_shape();
auto input_shape_padded = input_shape;
if (pad_required)
auto input_shape = args[0].get_shape();
Shape input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
if (pad_required || is_deconvolution)
{
input_shape_padded =
get_padded_shape(input_shape, padding_below, padding_above, {});
input_shape_padded = get_padded_shape(
input_shape, padding_below, padding_above, padding_interior);
Shape input_padded_strides = row_major_strides(input_shape_padded);
auto temp_size =
shape_size(input_shape_padded) * args[0].get_element_type().size();
GPUAllocator allocator =
@@ -204,28 +206,28 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
size_t idx_workspace = allocator.reserve_workspace(temp_size);
writer << "void* pad_buffer = runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_workspace << ");\n";
writer << "std::vector<" << args[0].get_type() << "> pad_buffer_host("
<< shape_size(input_shape_padded) << ", 0);\n";
writer << "runtime::gpu::cuda_memcpyHtD(pad_buffer, pad_buffer_host.data(), "
<< temp_size << ");\n";
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_index =
cuda_emitter->build_pad(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_above,
Shape{},
std::string("0"));
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", ";
auto pad_dynamic_index =
cuda_emitter->build_pad_dynamic(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_dynamic_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
std::fill(padding_above.begin(), padding_above.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
@@ -239,10 +241,8 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
window_dilation_strides,
padding_below);
writer.block_begin(" // " + node->get_name());
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
if (pad_required)
if (pad_required || is_deconvolution)
{
writer << "std::vector<void*>{pad_buffer, " << args[1].get_name()
<< "}.data(), ";
@@ -273,46 +273,52 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
Strides data_dilation_strides = convolution->get_data_dilation_strides_forward();
CoordinateDiff padding_below_diff = convolution->get_padding_below_forward();
CoordinateDiff padding_above_diff = convolution->get_padding_above_forward();
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below_diff.size(); i++)
{
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
if (padding_below.size() > 3)
if (padding_below_diff.size() > 3)
{
throw std::runtime_error(node->get_name() +
"with more than 3D is not implemented.");
}
bool is_deconvolution = false;
for (auto a : data_dilation_strides)
{
if (a != 1)
{
throw std::runtime_error(node->get_name() +
"with data dilation is not implemented.");
is_deconvolution = true;
break;
}
}
bool pad_required = false;
if (padding_below != padding_above)
bool pad_required = (padding_below_diff != padding_above_diff);
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below.size(); i++)
{
pad_required = true;
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
auto output_shape = out[0].get_shape();
auto output_shape_padded = output_shape;
Shape padding_below_back(output_shape.size(), 0);
Shape padding_interior_back(output_shape.size(), 1);
size_t i = padding_below_back.size() - padding_below.size();
size_t j = 0;
for (; i < padding_below_back.size(); i++)
{
padding_below_back[i] = padding_below[j++];
padding_below_back[i] = padding_below[j];
padding_interior_back[i] = data_dilation_strides[j];
j++;
}
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
if (pad_required)
if (pad_required || is_deconvolution)
{
output_shape_padded =
get_padded_shape(output_shape, padding_below, padding_above, {});
output_shape_padded = get_padded_shape(
output_shape, padding_below, padding_above, padding_interior);
auto temp_size =
shape_size(output_shape_padded) * args[0].get_element_type().size();
GPUAllocator allocator =
@@ -320,28 +326,28 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
size_t idx_workspace = allocator.reserve_workspace(temp_size);
writer << "void* pad_buffer = runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_workspace << ");\n";
writer << "std::vector<" << args[0].get_type() << "> pad_buffer_host("
<< shape_size(output_shape_padded) << ", 0);\n";
writer << "runtime::gpu::cuda_memcpyHtD(pad_buffer, pad_buffer_host.data(), "
<< temp_size << ");\n";
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_index =
cuda_emitter->build_pad(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
output_shape,
output_shape_padded,
padding_below,
padding_above,
Shape{},
std::string("0"));
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data(), ";
auto pad_dynamic_index =
cuda_emitter->build_pad_dynamic(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
output_shape,
output_shape_padded,
padding_below,
padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_dynamic_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
std::fill(padding_above.begin(), padding_above.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
@@ -359,8 +365,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "," << args[1].get_name()
<< "}.data(), ";
if (pad_required)
if (pad_required || is_deconvolution)
{
writer << "std::vector<void*>{pad_buffer}.data()";
}
@@ -369,12 +374,10 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
}
writer << ");\n";
// since we padded output with temp buffer, we need to copy back to output
if (pad_required)
// since we padded output with temp buffer, we need to copy back to the real output
if (pad_required || is_deconvolution)
{
const auto arg_rank = output_shape.size();
const Strides slice_strides(output_shape.size(), 1);
const auto input_strides = row_major_strides(output_shape_padded);
const auto output_strides = row_major_strides(output_shape);
GPUAllocator allocator =
@@ -385,8 +388,9 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
output_strides.data(), output_strides.size() * sizeof(size_t));
size_t idx_lower_bounds = allocator.reserve_argspace(
padding_below_back.data(), padding_below_back.size() * sizeof(size_t));
size_t idx_slice_strides = allocator.reserve_argspace(
slice_strides.data(), slice_strides.size() * sizeof(size_t));
size_t idx_slice_strides =
allocator.reserve_argspace(padding_interior_back.data(),
padding_interior_back.size() * sizeof(size_t));
writer << "size_t rank = " << arg_rank << ";\n";
writer << "void* input_strides_d = "
@@ -433,41 +437,41 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
Strides data_dilation_strides = convolution->get_data_dilation_strides_forward();
CoordinateDiff padding_below_diff = convolution->get_padding_below_forward();
CoordinateDiff padding_above_diff = convolution->get_padding_above_forward();
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below_diff.size(); i++)
{
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
if (padding_below.size() > 3)
if (padding_below_diff.size() > 3)
{
throw std::runtime_error(node->get_name() +
"with more than 3D is not implemented.");
}
bool is_deconvolution = false;
for (auto a : data_dilation_strides)
{
if (a != 1)
{
throw std::runtime_error(node->get_name() +
"with data dilation is not implemented.");
is_deconvolution = true;
break;
}
}
bool pad_required = false;
if (padding_below != padding_above)
bool pad_required = (padding_below_diff != padding_above_diff);
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below.size(); i++)
{
pad_required = true;
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
auto input_shape = args[0].get_shape();
auto input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
if (pad_required)
if (pad_required || is_deconvolution)
{
input_shape_padded =
get_padded_shape(input_shape, padding_below, padding_above, {});
input_shape_padded = get_padded_shape(
input_shape, padding_below, padding_above, padding_interior);
auto temp_size =
shape_size(input_shape_padded) * args[0].get_element_type().size();
GPUAllocator allocator =
@@ -475,26 +479,27 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
size_t idx_workspace = allocator.reserve_workspace(temp_size);
writer << "void* pad_buffer = runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_workspace << ");\n";
writer << "std::vector<" << args[0].get_type() << "> pad_buffer_host("
<< shape_size(input_shape_padded) << ", 0);\n";
writer << "runtime::gpu::cuda_memcpyHtD(pad_buffer, pad_buffer_host.data(), "
<< temp_size << ");\n";
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_index =
cuda_emitter->build_pad(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_above,
Shape{},
std::string("0"));
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", ";
auto pad_dynamic_index =
cuda_emitter->build_pad_dynamic(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_dynamic_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
std::fill(padding_above.begin(), padding_above.end(), 0);
}
auto& cudnn_emitter =
@@ -512,7 +517,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
padding_below);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
if (pad_required)
if (pad_required || is_deconvolution)
{
writer << "std::vector<void*>{pad_buffer, " << args[1].get_name()
<< "}.data(), ";
@@ -1912,12 +1917,6 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
const Shape& padding_above,
const Shape& padding_interior)
{
if (padding_interior.size())
{
throw std::runtime_error(
"Interior padding support is not yet available on GPU.");
}
enum class padtype
{
None,
@@ -1925,31 +1924,33 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
Asymmetric
};
auto type = padtype::None;
for (int i = 0; i < padding_below.size(); i++)
if (padding_below == padding_above)
{
if (padding_below[i] != 0 || padding_above[i] != 0)
{
type = padtype::Symmetric;
}
if (padding_below[i] != padding_above[i])
{
type = padtype::Asymmetric;
break;
}
type = padtype::Symmetric;
}
if (type == padtype::None)
else
{
return input_shape;
type = padtype::Asymmetric;
}
Shape padded_shape = input_shape;
for (int i = 0; i < padding_below.size(); i++)
int64_t i = input_shape.size() - 1;
int64_t j = padding_below.size() - 1;
if (padding_interior.empty())
{
padded_shape[padded_shape.size() - 1 - i] +=
(padding_below[padding_below.size() - 1 - i] +
padding_above[padding_above.size() - 1 - i]);
for (; j >= 0; j--, i--)
{
padded_shape[i] += padding_below[j] + padding_above[j];
}
}
else
{
for (; j >= 0; j--, i--)
{
padded_shape[i] = (padded_shape[i] - 1) * padding_interior[j] + 1 +
padding_below[j] + padding_above[j];
}
}
return padded_shape;
}
......
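
Aside (illustrative, not part of this commit): with interior padding p, a dimension d grows to (d - 1) * p + 1 plus the below/above padding, which is what the else branch of get_padded_shape now computes. A stand-alone check of that arithmetic (the helper below is made up for the example, not the nGraph function):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    using shape_t = std::vector<std::size_t>;

    // Illustrative helper mirroring the interior-padding branch above: each
    // padded trailing dimension becomes (d - 1) * interior + 1 + below + above.
    shape_t padded_shape_with_interior(shape_t shape,
                                       const shape_t& below,
                                       const shape_t& above,
                                       const shape_t& interior)
    {
        std::size_t i = shape.size() - 1;
        for (std::size_t j = below.size(); j-- > 0; i--)
        {
            shape[i] = (shape[i] - 1) * interior[j] + 1 + below[j] + above[j];
        }
        return shape;
    }

    int main()
    {
        // NCHW input 1x1x4x4 with data dilation 2 and symmetric padding 1 on H, W:
        // each spatial dimension becomes (4 - 1) * 2 + 1 + 1 + 1 = 9.
        assert((padded_shape_with_interior({1, 1, 4, 4}, {1, 1}, {1, 1}, {2, 2}) ==
                shape_t{1, 1, 9, 9}));
        return 0;
    }
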
@@ -22,6 +22,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
@@ -125,5 +126,19 @@ namespace ngraph
this->push_back(static_cast<uint32_t>(size));
}
}
GPUShape(const CoordinateDiff& coord)
{
for (auto const& size : coord)
{
if (size >> 32 != 0)
{
throw std::runtime_error(
"Request for Coordinate which exceed the bitwidth available for GPUShapes "
"(32)");
}
this->push_back(static_cast<uint32_t>(size));
}
}
};
}
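
Aside (illustrative, not part of this commit): the converting constructor lets padding kept as a CoordinateDiff be passed where the CUDA emitter expects a GPUShape, e.g. the padding arguments of build_pad_dynamic. A hedged usage sketch, assuming GPUShape lives in namespace ngraph as this hunk suggests and that the ngraph headers are on the include path (the GPUShape header itself is not named in the hunk):

    // Illustrative only; include paths and namespace are assumptions based on
    // the hunk above, not verified against the full tree.
    #include "ngraph/coordinate_diff.hpp"

    void example_conversion()
    {
        ngraph::CoordinateDiff padding_below{1, 1};
        ngraph::GPUShape pad_shape(padding_below); // fine: both values fit in 32 bits
        // A value outside the low 32 bits (including a negative one) would make
        // the constructor throw std::runtime_error via the guard shown above.
        (void)pad_shape;
    }
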
@@ -94,15 +94,6 @@ namespace ngraph
namespace gpu
{
void print_gpu_f32_tensor(const void* p, size_t element_count, size_t element_size);
template <typename T>
void print_gpu_tensor(const void* p, size_t element_count)
{
std::vector<T> local(element_count);
size_t size_in_bytes = sizeof(T) * element_count;
cudaMemcpy(local.data(), p, size_in_bytes, cudaMemcpyDeviceToHost);
std::cout << "{" << ngraph::join(local) << "}" << std::endl;
}
void check_cuda_errors(CUresult err);
void* create_gpu_buffer(size_t buffer_size, const void* data = NULL);
void free_gpu_buffer(void* buffer);
@@ -112,6 +103,15 @@ namespace ngraph
void cuda_memset(void* dst, int value, size_t buffer_size);
std::pair<uint64_t, uint64_t> idiv_magic_u32(uint64_t max_numerator, uint64_t divisor);
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
template <typename T>
void print_gpu_tensor(const void* p, size_t element_count)
{
std::vector<T> local(element_count);
size_t size_in_bytes = sizeof(T) * element_count;
cuda_memcpyDtH(local.data(), p, size_in_bytes);
std::cout << "{" << ngraph::join(local) << "}" << std::endl;
}
}
}
}
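
Aside (illustrative, not part of this commit): a hedged usage sketch of the relocated print_gpu_tensor helper, assuming this header is included, that these utilities live in ngraph::runtime::gpu, and that a CUDA context is already initialized:

    #include <vector>
    // The gpu_util header shown above is assumed to be on the include path.

    void dump_device_buffer_example()
    {
        std::vector<float> host{1.f, 2.f, 3.f, 4.f};
        // create_gpu_buffer/free_gpu_buffer are declared in this header; the
        // optional data argument presumably initializes the device allocation.
        void* dev = ngraph::runtime::gpu::create_gpu_buffer(host.size() * sizeof(float),
                                                            host.data());
        // Copies back via cuda_memcpyDtH and prints something like "{1, 2, 3, 4}".
        ngraph::runtime::gpu::print_gpu_tensor<float>(dev, host.size());
        ngraph::runtime::gpu::free_gpu_buffer(dev);
    }
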
@@ -7,20 +7,6 @@ batch_norm_three_outputs
computation_reuse
concat_matrix_int64
constant_equality_bool
convolution_2d_1item_1o1i_data_dilated
convolution_2d_1item_2o1i_data_dilated
convolution_2d_1item_2o2i_data_dilated
convolution_2d_1item_5o3i_data_dilated
convolution_2d_2item_5o3i_data_dilated
convolution_2d_2items_dilated_padded
convolution_2d_2items_strided_padded
convolution_2d_8item_large_5o3i_data_dilated
convolution_2d_8item_large_5o3i_uneven_filter_data_dilated
convolution_2d_8item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_1item_large_5o3i_padded_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_filter_dilated_data_dilated
convolution_3d_2item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated
convolution_4d_2items
convolution_4d_4items
convolution_4d_4items_dilated
......
@@ -45,7 +45,7 @@ namespace ngraph
{
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
{
NGRAPH_INFO << a[i] << " is not close to " << b[i];
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
rc = false;
}
}
......