Commit 026bede0 authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

Add GPURuntimeContext and GPUPrimitiveEmitter to the gpu transformer (#837)

* Begin prototype of cudnn_emitter.

* Added GPURuntimeContext to gpu_external_function for passing through to JIT functions.

* gpu_emitters now utilize gpu runtime context.

* Moved cublas and cudnn handles into GPURuntimeContext pointer and out of callframe EntryPoint.

* Added CUDNNEmitter, comparable to MKLDNNEmitter,
which allows for cudnn kernels to be defined via
lambda primitives that are emitted and
subsequently called during graph execution.
An example implementation is provided for op::Sum.

* Added GPURuntimeContext to gpu_external_function for passing through to JIT functions.

* gpu_emitters now utilize gpu runtime context.

* Moved cublas and cudnn handles into GPURuntimeContext pointer and out of callframe EntryPoint.

* GPURuntimeContext should be stored as unique_ptr in external function.

* GPURuntimeContext should be stored as unique_ptr in external function.

* Extract raw pointer from unique for cudnn_emitter.

* Removing unrelated code from PR.

* GPURuntimeContext needs to be a strict C interface in case
the native compiler and clang are utilizing different glibc ABIs.
Updated to reflect this.

* Added cudnn::primitive typedef for better readability.

* Moved allocation of CudaFunctionPool to external function
so that it is available during gpu emission.

* Fixed too-late initialization of cudart.

* Fixed too-late initialization of cudart.

* CUDNNEmitter moved into superset class GPUPrimitiveEmitter.
The GPUPrimitiveEmitter handles the emission of all gpu primitives,
including cudnn, cuda, and cublas. CUBLASEmitter support not yet included.

* Added unordered_map for cacheing primitives in the gpu_emitter.

* Added dtor to GPUPrimitiveEmitter to cleanup compiled functions.

* Adding back a serialized model graph that was accidentally rem* Added a few additional helpers to use ngraph::row_major_strides.

* added whitespace per @fengleitian's comment

* added whitespace per @fengleitian's comment

* Remove implicit type conversions from size_t to int.

* Add op::MaxPool, op::MaxPoolBackprop and op::Pad to GPU transformer (#817)

* Added pooling for 1 and 2dimensions. 1d uses a cuda kernel and 2d utilizes cudnn.
Padding is not yet supported.

* Normalized call signature on gpu emission for 1d max pool. Added a few comments.

* Max pool backprop impl. inprogress. Amend this commit.

* Max pool backprop implemented. Note that cuDNN
requests the output tensor for the maxpool operation but it is not required for computation.

* Formatting and invokation for maxpool changed.

* Fixed too-late initialization of cudart.

* Added padding kernel that is used with maxpool. Need to investigate remaining tests.

* Changed dimensionality check to correctly
determine if data is 1d or not.

* Added 3d MaxPooling (forward), verified by forcing 2d case to use Nd pooling routines.

* Added 3d MaxPooling (backward), verified by forcing 2d case to use Nd pooling routines.

* Moved cudnn prologues for maxpool into ngraph runtime and out of primitive so
that the only execution occuring on the JIT runtime is the evaluation of the op kernel.

* Refactored forward and backward pooling into single CUDNNEmitter::build_pooling interface
with a runtime switch to determine if the op is forward or backward propagation.

* Cache preconstructed cudnn kernel for maxpool if it has already been constructed.

* Forgot to add padding arrays back into cudnn kernel for MaxPool in the 2d case.

* Fixed namespace issues and use join(...,'_')

* Refactored 4d/Nd tensor descriptor builder into single function.

* Changed conditionals and comments. Now throws if MaxPool on more than 3 spatial dimensions is requested.

* Fixed forward declare for GPURuntimeContext (class -> struct).

* Clang complains about missing braces on brace-initializer. Fixed implicit conversions.

* Fixed implicit conversions (clang).

* Reverting changes on autodiff test for maxpool. @Krovatkin will update later.
parent dfae57c1
......@@ -100,3 +100,9 @@ python-wheels/
# python wheels
python/share/*
# git merge
*.orig
\#*
\.#*
......@@ -277,6 +277,11 @@ endif()
runtime/gpu/gpu_cuda_function_builder.cpp
runtime/gpu/gpu_cuda_function_pool.cpp
runtime/gpu/gpu_cuda_context_manager.cpp
runtime/gpu/gpu_runtime_context.cpp
runtime/gpu/gpu_primitive_emitter.cpp
runtime/gpu/gpu_invoke.cpp
runtime/gpu/cudnn_emitter.cpp
runtime/gpu/cuda_emitter.cpp
)
set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS
"CUDA_HEADER_PATHS=\"${CUDA_INCLUDE_DIRS}\";")
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
runtime::gpu::CUDAEmitter::CUDAEmitter(runtime::gpu::GPUPrimitiveEmitter* emitter)
: m_primitive_emitter(emitter)
{
}
size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
const Shape& padding_below,
const Shape& padding_above,
const Shape& padding_interior)
{
if (padding_interior.size())
{
throw std::runtime_error("Interior padding is not yet supported in the GPU transformer.");
}
std::string hash = "pad_i" + join(input_shape, "_") + "_pb" + join(padding_below, "_") + "_pa" +
join(padding_above, "_") + "_pi" + join(padding_interior, "_");
// For backwards compatability we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// Once all previously implemented cuda kernels are refactored to use the
// CUDAEmitter/GPUPrimittiveEmitter interface, only one map (from hash to primitive index)
// will be required.
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// if the kernel has not been compiled, build it
auto compiled_kernel = ctx->compiled_kernel_pool->get(hash);
auto nthreads = shape_size(input_shape);
if (compiled_kernel == nullptr)
{
// normalize pad dimensions to shape dimensions
Shape pad_below(input_shape.size(), 0);
Shape pad_above(input_shape.size(), 0);
if (padding_below.size() != input_shape.size())
{
for (int64_t i = padding_below.size() - 1; i >= 0; i--)
{
pad_below[i + input_shape.size() - padding_below.size()] = padding_below[i];
pad_above[i + input_shape.size() - padding_above.size()] = padding_above[i];
}
}
else
{
pad_below = padding_below;
pad_above = padding_above;
}
auto input_strides = row_major_strides(input_shape);
auto output_strides = row_major_strides(output_shape);
int offset = 0;
for (size_t i = 0; i < output_strides.size(); i++)
{
offset += (output_strides[i] * pad_below[i]);
}
codegen::CodeWriter writer;
writer << "extern \"C\" __global__ void cuda_" << hash << "(" << dtypes[0] << "* in, "
<< dtypes[1] << "* out)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "if (tid < " << nthreads << ")\n";
writer.block_begin();
{
writer << "size_t idx = ";
writer << offset << " + tid % " << input_shape.back();
int64_t last = input_strides.size() - 1;
for (int64_t i = last - 1; i > 0; i--)
{
writer << " + ((tid / " << input_strides[i] << ") % " << input_shape[i + 1]
<< ") * " << output_strides[i];
}
writer << ";\n";
writer << "out[idx] = in[tid];\n";
}
writer.block_end();
}
writer.block_end();
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
auto pad = new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
primitive_index = this->m_primitive_emitter->insert(pad);
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
void print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
const Shape& shape)
{
auto strides = row_major_strides(shape);
writer << "__syncthreads();\n";
writer << "if (tid==0)\n";
writer.block_begin();
{
std::string element = tensor_name + "[i]";
writer << "for (int i=0; i<" << shape_size(shape) << "; i++)\n";
writer.block_begin();
{
for (int64_t i = strides.size() - 1; i >= 0; i--)
{
writer << "if (i % " << strides[i] << " == 0)\n";
writer.block_begin();
{
writer << "printf(\"";
for (int64_t j = 0; j < strides.size() - 1 - i; j++)
{
writer << "\\n";
}
writer << "\");\n";
}
writer.block_end();
}
writer << "printf(\"%4.2f \", " << element << ");\n";
}
writer.block_end();
writer << "printf(\"\\n\");\n";
}
writer.block_end();
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <array>
namespace ngraph
{
class Shape;
namespace runtime
{
namespace gpu
{
struct GPURuntimeContext;
class GPUPrimitiveEmitter;
class CUDAEmitter
{
friend class GPUPrimitiveEmitter;
public:
size_t build_pad(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
const Shape& input_shape,
const Shape& output_shape,
const Shape& pad_below,
const Shape& pad_above,
const Shape& pad_interior);
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
GPUPrimitiveEmitter* m_primitive_emitter;
};
}
}
}
This diff is collapsed.
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <functional>
#include <vector>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn_v7.h>
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace cudnn_util
{
std::vector<int> compute_strides(const Shape&);
std::vector<int> compute_strides(const std::vector<int>&);
std::vector<int> get_vector_int_from_size_t(const std::vector<size_t>&);
cudnnTensorDescriptor_t tensor_descriptor_from_shape(const Shape& shape);
}
class GPUPrimitiveEmitter;
class CUDNNEmitter
{
friend class GPUPrimitiveEmitter;
public:
enum class Prop
{
Forward,
Backward
};
size_t build_reduce_forward(const GPURuntimeContext* ctx,
const cudnnReduceTensorOp_t& reduce_op,
const Shape& input_shape,
const AxisSet& reduction_axes);
size_t build_pooling(const GPURuntimeContext* ctx,
const cudnnPoolingMode_t& pool_op,
const Prop& direction,
const ngraph::Shape& input_shape,
const ngraph::Shape& output_shape,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter);
GPUPrimitiveEmitter* m_primitive_emitter;
};
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <string>
#include "cudnn_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
extern "C" void ngraph::runtime::gpu::cudnn_utils::cudnn_invoke_primitive(GPURuntimeContext* ctx,
size_t primitive_index,
void** args,
void** result)
{
(*ctx->cudnn_primitives[primitive_index])(args, result);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace gpu
{
struct GPURuntimeContext;
namespace cudnn_utils
{
extern "C" void cudnn_invoke_primitive(GPURuntimeContext* ctx,
size_t primitive_index,
void** args,
void** result);
}
}
}
}
......@@ -32,28 +32,14 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
: m_external_function(external_function)
, m_compiled_function(compiled_function)
{
//Create context use driver API and make it current, the runtime call will pickup the context
//http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#interoperability-between-runtime-and-driver-apis
ngraph::runtime::gpu::CudaContextManager::instance();
cublasStatus_t cublasStatus = cublasCreate(&m_cublas_handle);
if (cublasStatus != CUBLAS_STATUS_SUCCESS)
{
throw runtime_error("cuBLAS create handle failed");
}
cudnnStatus_t cudnnStatus = cudnnCreate(&m_cudnn_handle);
if (cudnnStatus != CUDNN_STATUS_SUCCESS)
{
throw runtime_error("cuDnn create handle failed");
}
// Pass scalars as reference on the Device
cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
setup_runtime_context();
}
runtime::gpu::GPU_CallFrame::~GPU_CallFrame()
{
cublasDestroy(m_cublas_handle);
cudnnDestroy(m_cudnn_handle);
cleanup_runtime_context();
}
void runtime::gpu::GPU_CallFrame::tensor_call(
......@@ -77,7 +63,7 @@ void runtime::gpu::GPU_CallFrame::tensor_call(
outputs.push_back(tv->m_allocated_buffer_pool);
}
m_compiled_function(inputs.data(), outputs.data(), m_cublas_handle, m_cudnn_handle);
m_compiled_function(inputs.data(), outputs.data(), m_external_function->m_ctx.get());
}
void runtime::gpu::GPU_CallFrame::call(
......@@ -99,3 +85,31 @@ void runtime::gpu::GPU_CallFrame::call(
tensor_call(outputs, inputs);
}
void runtime::gpu::GPU_CallFrame::setup_runtime_context()
{
cublasStatus_t cublasStatus = cublasCreate(&m_cublas_handle);
if (cublasStatus != CUBLAS_STATUS_SUCCESS)
{
throw runtime_error("cuBLAS create handle failed");
}
cudnnStatus_t cudnnStatus = cudnnCreate(&m_cudnn_handle);
if (cudnnStatus != CUDNN_STATUS_SUCCESS)
{
throw runtime_error("cuDnn create handle failed");
}
// Pass scalars as reference on the Device
cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
const auto& primitive_emitter = m_external_function->get_primitive_emitter();
m_external_function->m_ctx->gpu_primitives = primitive_emitter->get_primitives().data();
// register with c-api runtime context
m_external_function->m_ctx->cublas_handle = &m_cublas_handle;
m_external_function->m_ctx->cudnn_handle = &m_cudnn_handle;
}
void runtime::gpu::GPU_CallFrame::cleanup_runtime_context()
{
}
......@@ -25,6 +25,7 @@
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -38,10 +39,7 @@ namespace ngraph
class GPU_CallFrame;
class GPU_ExternalFunction;
using EntryPoint_t = void(void** inputs,
void** outputs,
cublasHandle_t& cublas_handle,
cudnnHandle_t& cudnn_handle);
using EntryPoint_t = void(void** inputs, void** outputs, GPURuntimeContext* ctx);
using EntryPoint = std::function<EntryPoint_t>;
......@@ -65,6 +63,9 @@ namespace ngraph
void tensor_call(const std::vector<std::shared_ptr<TensorView>>& outputs,
const std::vector<std::shared_ptr<TensorView>>& inputs) override;
void setup_runtime_context();
void cleanup_runtime_context();
protected:
std::shared_ptr<GPU_ExternalFunction> m_external_function;
EntryPoint m_compiled_function;
......
......@@ -14,6 +14,7 @@
* limitations under the License.
*******************************************************************************/
#include <iostream>
#include <string>
#include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp"
......
......@@ -28,13 +28,8 @@ static const std::string s_output_dir = "gpu_codegen";
using namespace ngraph;
runtime::gpu::CudaFunctionPool& runtime::gpu::CudaFunctionPool::instance()
{
static CudaFunctionPool pool;
return pool;
}
void runtime::gpu::CudaFunctionPool::set(const std::string& name, const std::string& kernel)
std::shared_ptr<CUfunction> runtime::gpu::CudaFunctionPool::set(const std::string& name,
const std::string& kernel)
{
const char* opts[] = {"--gpu-architecture=compute_35", "--relocatable-device-code=true"};
std::string filename =
......@@ -42,7 +37,9 @@ void runtime::gpu::CudaFunctionPool::set(const std::string& name, const std::str
std::ofstream out(filename);
out << kernel;
out.close();
m_function_map.insert({name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts)});
auto cu_compiled_function = CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts);
m_function_map.insert({name, cu_compiled_function});
return cu_compiled_function;
}
std::shared_ptr<CUfunction> runtime::gpu::CudaFunctionPool::get(const std::string& name)
......
......@@ -30,18 +30,12 @@ namespace ngraph
class CudaFunctionPool
{
public:
static CudaFunctionPool& instance();
CudaFunctionPool(CudaFunctionPool const&) = delete;
CudaFunctionPool(CudaFunctionPool&&) = delete;
CudaFunctionPool& operator=(CudaFunctionPool const&) = delete;
CudaFunctionPool& operator=(CudaFunctionPool&&) = delete;
void set(const std::string& name, const std::string& kernel);
std::shared_ptr<CUfunction> get(const std::string& name);
protected:
CudaFunctionPool() {}
~CudaFunctionPool() {}
std::shared_ptr<CUfunction> set(const std::string& name, const std::string& kernel);
std::shared_ptr<CUfunction> get(const std::string& name);
private:
std::unordered_map<std::string, std::shared_ptr<CUfunction>> m_function_map;
};
}
......
......@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include <algorithm>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
using namespace ngraph;
......@@ -165,6 +167,48 @@ void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_1d_max_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
{
// assumes data is in NCW format
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, "
<< "int width, "
<< "int stride, "
<< "int input_size, "
<< "int output_size, "
<< "int n)\n";
writer.block_begin();
{
// index into output tensor
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
// index into input tensor
writer << "size_t start = (tid / output_size) * input_size + (tid % output_size) * "
"stride;\n";
writer << data_types[0] << " max_val = 0;\n";
writer << "for (size_t i = start; i < start+width; i++)\n";
writer.block_begin();
{
writer << "const " << data_types[0] << " input = in[i];\n";
writer << "if (input > max_val)";
writer.block_begin();
{
writer << "max_val = input;\n";
}
writer.block_end();
}
writer.block_end();
writer << "out[tid] = max_val;\n";
}
writer.block_end();
}
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_device_helper(
codegen::CodeWriter& writer,
const std::string& name,
......
......@@ -55,6 +55,10 @@ namespace ngraph
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_1d_max_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
......@@ -24,9 +24,10 @@ using namespace ngraph;
using namespace ngraph::runtime::gpu;
void runtime::gpu::emit_broadcast(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
size_t repeat_size,
size_t repeat_times,
size_t count)
......@@ -34,17 +35,18 @@ void runtime::gpu::emit_broadcast(const std::string& name,
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name_signature) == nullptr)
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_broadcast_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name_signature, kernel);
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name_signature).get(),
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
......@@ -59,9 +61,10 @@ void runtime::gpu::emit_broadcast(const std::string& name,
}
void runtime::gpu::emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
size_t repeat_size,
size_t repeat_times,
size_t count)
......@@ -69,17 +72,18 @@ void runtime::gpu::emit_onehot(const std::string& name,
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name_signature) == nullptr)
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_onehot_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name_signature, kernel);
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name_signature).get(),
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
......@@ -94,9 +98,10 @@ void runtime::gpu::emit_onehot(const std::string& name,
}
void runtime::gpu::emit_reshape(const std::string& name,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
CUdeviceptr input_strides,
CUdeviceptr trans_strides,
size_t rank,
......@@ -104,18 +109,18 @@ void runtime::gpu::emit_reshape(const std::string& name,
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name_signature) == nullptr)
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reshape_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name_signature, kernel);
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &input_strides, &trans_strides, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name_signature).get(),
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
......@@ -133,6 +138,7 @@ void runtime::gpu::emit_slice(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_strides,
CUdeviceptr lower_bounds,
CUdeviceptr slice_strides,
......@@ -142,19 +148,19 @@ void runtime::gpu::emit_slice(const std::string& name,
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name_signature) == nullptr)
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_slice_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name_signature, kernel);
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {
&in, &out, &input_strides, &lower_bounds, &slice_strides, &output_strides, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name_signature).get(),
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
......
......@@ -23,6 +23,7 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/strides.hpp"
namespace ngraph
......@@ -35,25 +36,28 @@ namespace ngraph
struct CudaOpMap;
void emit_broadcast(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
size_t repeat_size,
size_t repeat_times,
size_t count);
void emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
size_t repeat_size,
size_t repeat_times,
size_t count);
void emit_reshape(const std::string& name,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
CUdeviceptr input_strides,
CUdeviceptr trans_strides,
size_t rank,
......@@ -63,6 +67,7 @@ namespace ngraph
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_strides,
CUdeviceptr lower_bounds,
CUdeviceptr slice_strides,
......@@ -73,13 +78,15 @@ namespace ngraph
template <typename T, typename... Inputs>
void emit_elementwise_op(const std::string& name,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
size_t count,
CUdeviceptr out,
Inputs&&... inputs)
{
std::string type_signature = "_" + data_types[0] + "_" + data_types[1];
std::replace(type_signature.begin(), type_signature.end(), ' ', '_');
if (CudaFunctionPool::instance().get(name + type_signature) == nullptr)
auto compiled_kernel = ctx->compiled_kernel_pool->get(name + type_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
......@@ -99,25 +106,60 @@ namespace ngraph
writer, name + type_signature, op_name, data_types, sizeof...(inputs));
std::string kernel = writer.get_code();
CudaFunctionPool::instance().set(name + type_signature, kernel);
compiled_kernel = ctx->compiled_kernel_pool->set(name + type_signature, kernel);
}
//convert runtime ptr to driver api ptr
void* args_list[] = {&inputs..., &out, &count};
CUDA_SAFE_CALL(
cuLaunchKernel(*CudaFunctionPool::instance().get(name + type_signature).get(),
count,
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
count,
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
template <typename... Args>
void emit_1d_max_pool(GPURuntimeContext* ctx,
const std::string& name,
const std::array<std::string, 2>& data_types,
size_t count,
Args&&... args)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::get_1d_max_pool(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
if (sizeof...(args))
{
std::vector<void*> args_list = {&args..., &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
&args_list[0],
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
}
}
}
}
This diff is collapsed.
......@@ -64,6 +64,10 @@ namespace ngraph
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out);
};
Shape get_padded_shape(const Shape& input_shape,
const Shape& padding_below,
const Shape& padding_above,
const Shape& padding_interior);
}
}
}
......@@ -112,6 +112,7 @@
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
using namespace std;
using namespace ngraph;
......@@ -241,7 +242,18 @@ runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
: ngraph::runtime::ExternalFunction(function, release_function)
, m_compiled_function(nullptr)
, m_emit_timing(std::getenv("NGRAPH_GPU_EMIT_TIMING") != nullptr)
, m_ctx(new GPURuntimeContext)
{
// Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis
ngraph::runtime::gpu::CudaContextManager::instance();
m_ctx->compiled_kernel_pool = new CudaFunctionPool;
}
runtime::gpu::GPU_ExternalFunction::~GPU_ExternalFunction()
{
delete m_ctx->compiled_kernel_pool;
}
void runtime::gpu::GPU_ExternalFunction::compile()
......@@ -251,6 +263,8 @@ void runtime::gpu::GPU_ExternalFunction::compile()
return;
}
m_primitive_emitter.reset(new GPUPrimitiveEmitter());
string function_name = m_function->get_name();
string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt");
......@@ -288,6 +302,8 @@ void runtime::gpu::GPU_ExternalFunction::compile()
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/util.hpp"
)";
......@@ -296,6 +312,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
writer += R"(
using namespace ngraph;
using namespace ngraph::runtime;
using namespace std;
)";
......@@ -402,8 +419,7 @@ using namespace std;
for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
{
writer << "extern \"C\" void " << f->get_name() << "(void** inputs, void** outputs, "
"cublasHandle_t& cublas_handle, "
"cudnnHandle_t& cudnn_handle);\n";
<< "gpu::GPURuntimeContext* ctx);\n";
}
writer << "\n";
......@@ -506,9 +522,8 @@ using namespace std;
}
writer << "extern \"C\" void " << current_function->get_name();
writer << "(void** inputs, void** outputs, cublasHandle_t& cublas_handle, "
"cudnnHandle_t& "
"cudnn_handle)\n";
writer << "(void** inputs, void** outputs, "
<< "gpu::GPURuntimeContext* ctx)\n";
writer << "{\n";
writer.indent++;
......@@ -809,3 +824,8 @@ void runtime::gpu::GPU_ExternalFunction::emit_debug_function_exit(
{
writer << "timer_" << node->get_name() << ".stop();\n";
}
std::unique_ptr<runtime::gpu::GPURuntimeContext>& runtime::gpu::GPU_ExternalFunction::ctx()
{
return m_ctx;
}
......@@ -28,6 +28,7 @@
#include "ngraph/function.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp"
namespace ngraph
......@@ -39,6 +40,7 @@ namespace ngraph
class GPU_ExternalFunction;
class GPU_Emitter;
class GPU_CallFrame;
struct GPURuntimeContext;
using OpFunction =
std::function<void(GPU_ExternalFunction* external_function,
......@@ -57,7 +59,13 @@ namespace ngraph
public:
GPU_ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
~GPU_ExternalFunction();
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::unique_ptr<runtime::gpu::GPURuntimeContext>& ctx();
const std::unique_ptr<GPUPrimitiveEmitter>& get_primitive_emitter() const
{
return m_primitive_emitter;
}
protected:
void compile();
......@@ -82,6 +90,8 @@ namespace ngraph
std::unique_ptr<codegen::ExecutionEngine> m_execution_engine;
bool m_emit_timing;
std::unordered_map<std::string, std::string> m_variable_name_map;
std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter;
std::unique_ptr<GPURuntimeContext> m_ctx;
};
}
}
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <string>
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
extern "C" void ngraph::runtime::gpu::invoke_primitive(GPURuntimeContext* ctx,
size_t primitive_index,
void** args,
void** result)
{
(*ctx->gpu_primitives[primitive_index])(args, result);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace gpu
{
struct GPURuntimeContext;
extern "C" void invoke_primitive(GPURuntimeContext* ctx,
size_t primitive_index,
void** args,
void** result);
}
}
}
......@@ -209,7 +209,7 @@ void runtime::gpu::kernel::emit_cudnnReduceTensor(codegen::CodeWriter& writer,
writer << " CUDNN_REDUCE_TENSOR_NO_INDICES,\n";
writer << " CUDNN_32BIT_INDICES);\n";
writer << "size_t workspace_size = 0;\n";
writer << "cudnnGetReductionWorkspaceSize(cudnn_handle,\n";
writer << "cudnnGetReductionWorkspaceSize(*ctx->cudnn_handle,\n";
writer << " reduceTensorDesc,\n";
writer << " " << input_desc << ",\n";
writer << " " << output_desc << ",\n";
......@@ -217,7 +217,7 @@ void runtime::gpu::kernel::emit_cudnnReduceTensor(codegen::CodeWriter& writer,
writer << "void* workspace_ptr = "
"ngraph::runtime::gpu::create_gpu_buffer(workspace_size);\n";
writer << "float alpha = " << alpha << ", beta = " << beta << ";\n";
writer << "cudnnReduceTensor(cudnn_handle,\n";
writer << "cudnnReduceTensor(*ctx->cudnn_handle,\n";
writer << " reduceTensorDesc,\n";
writer << " nullptr,\n";
writer << " 0,\n";
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <limits>
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
using namespace ngraph;
using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_cuda_emitter(new CUDAEmitter(this))
, m_cudnn_emitter(new CUDNNEmitter(this))
{
}
GPUPrimitiveEmitter::~GPUPrimitiveEmitter()
{
for (auto& primitive : m_gpu_primitives)
{
delete primitive;
}
}
std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter()
{
return m_cuda_emitter;
}
std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{
return m_cudnn_emitter;
}
size_t GPUPrimitiveEmitter::insert(gpu::primitive* f)
{
m_gpu_primitives.push_back(f);
return m_gpu_primitives.size() - 1;
}
size_t GPUPrimitiveEmitter::lookup(std::string hash)
{
if (m_primitive_map.count(hash) > 0)
{
return m_primitive_map[hash];
}
return std::numeric_limits<size_t>::max();
}
void GPUPrimitiveEmitter::cache(const std::string& hash, const size_t& index)
{
m_primitive_map.insert({hash, index});
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <functional>
#include <unordered_map>
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class CUDAEmitter;
class CUDNNEmitter;
class GPUPrimitiveEmitter
{
public:
GPUPrimitiveEmitter();
~GPUPrimitiveEmitter();
std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::vector<gpu::primitive*>& get_primitives() { return m_gpu_primitives; }
size_t insert(gpu::primitive* f);
size_t lookup(std::string hash);
void cache(const std::string& hash, const size_t& index);
private:
std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::vector<gpu::primitive*> m_gpu_primitives;
std::unordered_map<std::string, size_t> m_primitive_map;
};
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
using namespace ngraph;
using namespace ngraph::runtime::gpu;
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <string>
#include <unordered_map>
#include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
typedef std::function<void(void**, void**)> primitive;
extern "C" {
struct GPURuntimeContext
{
cudnnHandle_t* cudnn_handle;
cublasHandle_t* cublas_handle;
gpu::primitive* const* gpu_primitives;
CudaFunctionPool* compiled_kernel_pool;
// Note that in it's current state, calling methods of CudaFunctionPool
// or other native compiled C++ functions in ngraph from the JIT code is
// unsafe and will fail if the GLIBCXX versions are diffent for the
// native compiler and clang. If all of the emitted CUDA ops are refactored
// to use the GPUPrimitiveEmitter, the above pointer can be removed. It is left now
// for backward compatability.
};
}
}
}
}
......@@ -52,6 +52,16 @@
} \
} while (0)
#define CUDNN_SAFE_CALL(func) \
{ \
cudnnStatus_t e = (func); \
if (e != CUDNN_STATUS_SUCCESS) \
{ \
auto msg = cudnnGetErrorString(e); \
throw std::runtime_error("\ncuDNN error: " + std::string(msg)); \
} \
}
namespace ngraph
{
namespace runtime
......
......@@ -32,7 +32,6 @@ using namespace ngraph;
TEST(${BACKEND_NAME}, backwards_maxpool_n4_c1_hw4_2x2_max)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......@@ -79,7 +78,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4_c1_hw4_2x2_max)
TEST(${BACKEND_NAME}, backwards_maxpool_n2_c1_hw5_3x3_str2_max)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......@@ -1595,7 +1593,6 @@ TEST(${BACKEND_NAME}, backwards_reverse_3d_02)
TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
Shape shape_a{4, 1, 4, 4}; //in NCHW
......@@ -1638,7 +1635,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
......
......@@ -4059,7 +4059,6 @@ TEST(DISABLED_${BACKEND_NAME}, dot_4d_5d_multi_axis_big_fp64_VERY_SLOW)
TEST(${BACKEND_NAME}, max_pool_1d_1channel_1image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 1, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -4082,7 +4081,6 @@ TEST(${BACKEND_NAME}, max_pool_1d_1channel_1image)
TEST(${BACKEND_NAME}, max_pool_1d_1channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 1, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -4109,7 +4107,6 @@ TEST(${BACKEND_NAME}, max_pool_1d_1channel_2image)
TEST(${BACKEND_NAME}, max_pool_1d_2channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2, 14};
Shape window_shape{3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -4141,7 +4138,6 @@ TEST(${BACKEND_NAME}, max_pool_1d_2channel_2image)
TEST(${BACKEND_NAME}, max_pool_2d_2channel_2image)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2, 5, 5};
Shape window_shape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
......@@ -4249,7 +4245,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_overpadded)
TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{1, 1, 5, 5};
......@@ -4330,7 +4325,6 @@ TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_padded_negative_values)
TEST(${BACKEND_NAME}, max_pool_2d_1channel_1image_strided)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{1, 1, 8, 8};
Shape window_shape{2, 3};
auto window_movement_strides = Strides{3, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment