Commit 8476dea0 authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

Auto. gen. kernel signatures and argument expansion (#1326)

* Add GPUKernelArgs for storing kernel arguments.

* Formatting.

* Resolve tensor addresses when extracting arg list via GPUKernelArgs.

* Updated arg list resolution so that placeholder arguments can be added anywhere in the argument list.

* const ref. args and changed add_args to use add_arg. also expanded type_names map.

* GPUKernelArgs bug fix for return values.

* add_placeholders expects pointers for later resolution

* Formatting.

* Add comments to GPUKernelArgs

* Changed GPUKernelArgs interface to use a runtime variable number of arguments.

* Removed/updated comment.

* Address review comments: Remove combined address resolution and argument list retrieval. Remove unecessary extra type entries in type_map.

* Add space between pragma once and includes.

* Broadcast optimization (#1322)

* Implement GPUKernelArgs with op::Broadcast.

* Removed excess type insertion in kernel signature for broadcast impl.

* Support new auto kernel signature generation for op::Broadcast. Add boolean to helpers to determine if parameters are registers or arrays.

* Removed commented code.

* Update broadcast impl. for new GPUKernelArgs interface.

* Updated based on interface change to GPUKernelArgs.

* Formatting.

* CUDNNHostParameters now implement GPUHostParameters. (#1324)

* Formatting.
parent 69c51c27
......@@ -39,6 +39,7 @@ set(SRC
gpu_util.cpp
type_info.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
)
if (NGRAPH_GPU_ENABLE)
......
......@@ -1719,17 +1719,6 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string,
return primitive_index;
}
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
runtime::gpu::CudaKernelBuilder::get_broadcast_op(
writer, kernel_name, dtypes, result_shape.size());
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
// calculate strides
GPUShape strides = row_major_strides(result_shape);
// precacluate invariants for integer division via multiplication
......@@ -1755,15 +1744,6 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string,
reduced_strides[axis] = 0;
}
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_strides = allocator.reserve_argspace(strides.data(), strides.size() * sizeof(int));
size_t idx_stride_magic =
allocator.reserve_argspace(stride_magic.data(), stride_magic.size() * sizeof(int));
size_t idx_stride_shift =
allocator.reserve_argspace(stride_shift.data(), stride_shift.size() * sizeof(int));
size_t idx_reduced_strides =
allocator.reserve_argspace(reduced_strides.data(), reduced_strides.size() * sizeof(int));
// TODO: blending factors are not currently implemented
float alpha = 1.0f;
float beta = 0.0f;
......@@ -1774,36 +1754,48 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string,
uint32_t aligned_grid_size_x =
align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
std::unique_ptr<gpu::primitive> broadcast(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* strides_d = runtime::gpu::invoke_memory_primitive(m_ctx, idx_strides);
void* stride_magic_d = runtime::gpu::invoke_memory_primitive(m_ctx, idx_stride_magic);
void* stride_shift_d = runtime::gpu::invoke_memory_primitive(m_ctx, idx_stride_shift);
void* reduced_strides_d = runtime::gpu::invoke_memory_primitive(m_ctx, idx_reduced_strides);
auto args = this->m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("strides", strides)
.add("stride_magic", stride_magic)
.add("stride_shift", stride_shift)
.add("reduced_strides", reduced_strides)
.add("alpha", alpha)
.add("beta", beta)
.add("nthreads", nthreads);
void* args_list[] = {&inputs[0],
&outputs[0],
&strides_d,
&stride_magic_d,
&stride_shift_d,
&reduced_strides_d,
&alpha,
&beta,
&nthreads};
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
runtime::gpu::CudaKernelBuilder::get_kernel_signature(
writer, kernel_name, args.get_input_signature());
runtime::gpu::CudaKernelBuilder::get_broadcast_op(writer, result_shape.size());
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
0,
NULL,
args_list,
0));
debug_sync();
}});
std::unique_ptr<gpu::primitive> broadcast(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
0,
NULL,
args_list,
0));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(broadcast));
m_primitive_emitter->cache(hash, primitive_index);
......
......@@ -109,8 +109,11 @@ std::vector<int>
return low_vec;
}
runtime::gpu::CUDNNEmitter::CUDNNEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx)
: m_primitive_emitter(emitter)
runtime::gpu::CUDNNEmitter::CUDNNEmitter(GPUPrimitiveEmitter* emitter,
GPURuntimeContext* ctx,
std::shared_ptr<GPUHostParameters> params)
: m_host_parameters(params)
, m_primitive_emitter(emitter)
{
m_ctx = ctx;
}
......
......@@ -17,6 +17,7 @@
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include <cublas_v2.h>
......@@ -132,7 +133,9 @@ namespace ngraph
void sync();
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
CUDNNEmitter(GPUPrimitiveEmitter* emitter,
GPURuntimeContext* ctx,
std::shared_ptr<GPUHostParameters> params);
void* get_data_by_type(cudnnDataType_t data_type, double value);
......
......@@ -21,6 +21,7 @@
#include <cudnn.h>
#include "ngraph/log.hpp"
#include "ngraph/runtime/gpu/gpu_host_parameters.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
namespace ngraph
......@@ -34,7 +35,10 @@ namespace ngraph
class CUDNNHostParameters
{
public:
CUDNNHostParameters() = default;
CUDNNHostParameters(const std::shared_ptr<GPUHostParameters> params)
: m_host_parameters(params)
{
}
~CUDNNHostParameters() = default;
void* allocate_by_datatype(const cudnnDataType_t data_type, const double value)
......@@ -43,20 +47,16 @@ namespace ngraph
switch (data_type)
{
case CUDNN_DATA_FLOAT:
m_host_parameters_float.push_back(static_cast<float>(value));
r = static_cast<void*>(&m_host_parameters_float.back());
r = m_host_parameters->cache(static_cast<float>(value));
break;
case CUDNN_DATA_DOUBLE:
m_host_parameters_double.push_back(value);
r = static_cast<void*>(&m_host_parameters_double.back());
r = m_host_parameters->cache(static_cast<double>(value));
break;
case CUDNN_DATA_INT8:
m_host_parameters_int8_t.push_back(static_cast<int8_t>(value));
r = static_cast<void*>(&m_host_parameters_int8_t.back());
r = m_host_parameters->cache(static_cast<int8_t>(value));
break;
case CUDNN_DATA_INT32:
m_host_parameters_int32_t.push_back(static_cast<int32_t>(value));
r = static_cast<void*>(&m_host_parameters_int32_t.back());
r = m_host_parameters->cache(static_cast<int32_t>(value));
break;
case CUDNN_DATA_HALF:
case CUDNN_DATA_INT8x4:
......@@ -71,10 +71,7 @@ namespace ngraph
}
private:
std::list<int8_t> m_host_parameters_int8_t;
std::list<int32_t> m_host_parameters_int32_t;
std::list<float> m_host_parameters_float;
std::list<double> m_host_parameters_double;
std::shared_ptr<GPUHostParameters> m_host_parameters;
};
}
}
......
......@@ -147,19 +147,8 @@ void runtime::gpu::CudaKernelBuilder::get_ew_collective_op(
}
void runtime::gpu::CudaKernelBuilder::get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
const size_t rank)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, "
<< "int* strides, "
<< "int* stride_magic, "
<< "int* stride_shift, "
<< "int* reduced_strides, "
<< "float alpha, float beta, "
<< "size_t nthreads"
<< ")\n";
writer.block_begin();
{
writer << "const int tid = blockDim.x*blockIdx.x + threadIdx.x;\n";
......@@ -174,7 +163,8 @@ void runtime::gpu::CudaKernelBuilder::get_broadcast_op(codegen::CodeWriter& writ
"stride_shift",
"reduced_strides",
"coordinate",
rank);
rank,
true);
writer << "out[tid] = load(in, " << reduced_idx << ");\n";
}
writer.block_end();
......@@ -1178,8 +1168,12 @@ void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::C
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank)
size_t rank,
bool register_arguments)
{
std::string brace_open = (register_arguments) ? "" : "[";
std::string brace_close = (register_arguments) ? "" : "]";
// Translation from flat index to dense tensor coordinates:
// Given tensor shape [d0 d1 ... dN] with strides [d1*...*dN, d2*...*dN, ... 1],
// calculate coordinates as:
......@@ -1195,11 +1189,11 @@ void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::C
if (i != 0)
{
writer << "coordinate_product -= (" << o_coordinates << i - 1 << " * " << i_strides
<< "[" << i - 1 << "]);\n";
<< brace_open << i - 1 << brace_close << ");\n";
}
writer << "int " << o_coordinates << i << " = division_by_invariant_multiplication("
<< "coordinate_product, " << i_stride_magic << "[" << i << "], " << i_stride_shift
<< "[" << i << "]);\n";
<< "coordinate_product, " << i_stride_magic << brace_open << i << brace_close << ", "
<< i_stride_shift << brace_open << i << brace_close << ");\n";
}
}
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
......@@ -1210,22 +1204,33 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank)
size_t rank,
bool register_arguments)
{
coordinate_transform_to_multi_d(
writer, i_strides, i_stride_magic, i_stride_shift, i_thread_index, o_coordinates, rank);
coordinate_transform_to_multi_d(writer,
i_strides,
i_stride_magic,
i_stride_shift,
i_thread_index,
o_coordinates,
rank,
register_arguments);
std::string brace_open = (register_arguments) ? "" : "[";
std::string brace_close = (register_arguments) ? "" : "]";
// index into reduced tensor from coordinates of non-reduced tensor
std::string reduced_idx = "reduced_idx";
writer << "int " << reduced_idx << " = 0;\n";
for (size_t i = 0; i < rank; i++)
{
writer << "reduced_idx += " << o_coordinates << i << " * " << i_reduced_strides << "[" << i
<< "];\n";
writer << "reduced_idx += " << o_coordinates << i << " * " << i_reduced_strides
<< brace_open << i << brace_close << ";\n";
}
return reduced_idx;
}
void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
......@@ -34,15 +34,20 @@ namespace ngraph
class CudaKernelBuilder
{
public:
static void get_kernel_signature(codegen::CodeWriter& writer,
const std::string& name,
const std::string& input_signature)
{
writer << "extern \"C\" __global__ void cuda_" << name;
writer << input_signature;
}
static void get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::vector<std::string>& data_types);
static void get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
const size_t rank);
static void get_broadcast_op(codegen::CodeWriter& writer, const size_t rank);
static void get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -138,14 +143,16 @@ namespace ngraph
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank);
size_t rank,
bool register_arguments = false);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank);
size_t rank,
bool register_arguments = false);
};
}
}
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <list>
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUHostParameters
{
public:
GPUHostParameters() = default;
void* cache(const char& value)
{
m_char_params.push_back(value);
return &m_char_params.back();
}
void* cache(const float& value)
{
m_float_params.push_back(value);
return &m_float_params.back();
}
void* cache(const double& value)
{
m_double_params.push_back(value);
return &m_double_params.back();
}
void* cache(const int8_t& value)
{
m_int8_t_params.push_back(value);
return &m_int8_t_params.back();
}
void* cache(const int16_t& value)
{
m_int16_t_params.push_back(value);
return &m_int16_t_params.back();
}
void* cache(const int32_t& value)
{
m_int32_t_params.push_back(value);
return &m_int32_t_params.back();
}
void* cache(const int64_t& value)
{
m_int64_t_params.push_back(value);
return &m_int64_t_params.back();
}
void* cache(const uint8_t& value)
{
m_uint8_t_params.push_back(value);
return &m_uint8_t_params.back();
}
void* cache(const uint16_t& value)
{
m_uint16_t_params.push_back(value);
return &m_uint16_t_params.back();
}
void* cache(const uint32_t& value)
{
m_uint32_t_params.push_back(value);
return &m_uint32_t_params.back();
}
void* cache(const uint64_t& value)
{
m_uint64_t_params.push_back(value);
return &m_uint64_t_params.back();
}
private:
std::list<char> m_char_params;
std::list<float> m_float_params;
std::list<double> m_double_params;
std::list<int8_t> m_int8_t_params;
std::list<int16_t> m_int16_t_params;
std::list<int32_t> m_int32_t_params;
std::list<int64_t> m_int64_t_params;
std::list<uint8_t> m_uint8_t_params;
std::list<uint16_t> m_uint16_t_params;
std::list<uint32_t> m_uint32_t_params;
std::list<uint64_t> m_uint64_t_params;
};
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <string>
#include "ngraph/runtime/gpu/gpu_kernel_args.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace ngraph;
const std::unordered_map<std::type_index, std::string> runtime::gpu::GPUKernelArgs::type_names = {
{TI(size_t), "size_t"},
{TI(char), "char"},
{TI(float), "float"},
{TI(double), "double"},
{TI(int8_t), "int8_t"},
{TI(int16_t), "int16_t"},
{TI(int32_t), "int32_t"},
{TI(int64_t), "int64_t"},
{TI(uint8_t), "uint8_t"},
{TI(uint16_t), "uint16_t"},
{TI(uint32_t), "uint32_t"},
{TI(uint64_t), "uint64_t"}};
runtime::gpu::GPUKernelArgs::GPUKernelArgs(const std::shared_ptr<GPUHostParameters>& params)
: m_signature_generated(false)
, m_host_parameters(params)
{
m_input_signature << "(";
}
runtime::gpu::GPUKernelArgs::GPUKernelArgs(const GPUKernelArgs& args)
{
m_signature_generated = args.m_signature_generated;
m_argument_list = args.m_argument_list;
m_placeholder_positions = args.m_placeholder_positions;
m_input_signature << args.m_input_signature.str();
m_host_parameters = args.m_host_parameters;
}
void runtime::gpu::GPUKernelArgs::validate()
{
if (m_signature_generated)
{
throw std::runtime_error(
"Kernel input signature already generated. Adding additional kernel arguments has no "
"effect.");
}
}
void runtime::gpu::GPUKernelArgs::add_to_signature(const std::string& type, const std::string& name)
{
if (m_input_signature.str() == "(")
{
m_input_signature << type << " " << name;
}
else
{
m_input_signature << ", " << type << " " << name;
}
}
runtime::gpu::GPUKernelArgs& runtime::gpu::GPUKernelArgs::add_placeholder(const std::string& type,
const std::string& name)
{
validate();
m_argument_list.push_back(nullptr);
m_placeholder_positions.push_back(true);
add_to_signature(type + "*", name);
return *this;
}
runtime::gpu::GPUKernelArgs& runtime::gpu::GPUKernelArgs::resolve_placeholder(size_t arg_num,
void* address)
{
if (m_placeholder_positions.at(arg_num))
{
m_argument_list[arg_num] = address;
}
else
{
throw std::runtime_error("Resolution of specified non-placeholder argument is unallowed.");
}
return *this;
}
std::string runtime::gpu::GPUKernelArgs::get_input_signature()
{
if (m_signature_generated == false)
{
m_signature_generated = true;
m_input_signature << ")";
}
return m_input_signature.str();
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include <sstream>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>
#include "ngraph/runtime/gpu/gpu_host_parameters.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
template <typename T>
struct has_const_iterator;
template <typename T>
struct is_container;
class GPUKernelArgs
{
public:
GPUKernelArgs(const std::shared_ptr<GPUHostParameters>& params);
GPUKernelArgs(const GPUKernelArgs& args);
//
// Add a placeholder parameter for a tensor pointer which will be resolved at runtime.
//
GPUKernelArgs& add_placeholder(const std::string& type, const std::string& name);
//
// Add a POD argument to the kernel signature and argument list.
//
template <typename T>
typename std::enable_if<!is_container<T>::value, GPUKernelArgs&>::type
add(const std::string& name, const T& arg)
{
return add_argument(name, arg);
}
//
// Add POD arguments as above, but by expanding the array arguments and
// and adding each individual arg as kernel register arguments.
//
template <typename T>
typename std::enable_if<is_container<T>::value, GPUKernelArgs&>::type
add(const std::string& name, const T& arg)
{
return add_arguments(name, arg);
}
//
// Retrieve the kernel argument list for use with the launch primitive.
//
void** get_argument_list() { return m_argument_list.data(); }
//
// Replace placeholder argument with specifed address.
//
GPUKernelArgs& resolve_placeholder(size_t arg_num, void* address);
//
// Retrieve the kernel parameter signature given the added kernel arguments.
//
std::string get_input_signature();
private:
//
// Cache the host argument for persistence, add it to the argument list,
// and add its signature to the kernel input signature.
//
template <typename T>
GPUKernelArgs& add_argument(const std::string& name, const T& arg)
{
validate();
void* host_arg = m_host_parameters->cache(arg);
m_argument_list.push_back(host_arg);
m_placeholder_positions.push_back(false);
add_to_signature(type_names.at(std::type_index(typeid(T))), name);
return *this;
}
//
// Same as above for a container type T.
//
template <typename T>
GPUKernelArgs& add_arguments(const std::string& name, const T& args)
{
validate();
size_t i = 0;
for (auto const& arg : args)
{
add_argument(name + std::to_string(i++), arg);
}
return *this;
}
void validate();
//
// Given an input argument type and name, add it to the kernel parameter signature.
//
void add_to_signature(const std::string& type, const std::string& name);
private:
bool m_signature_generated;
std::vector<void*> m_argument_list;
std::vector<bool> m_placeholder_positions;
std::stringstream m_input_signature;
std::shared_ptr<GPUHostParameters> m_host_parameters;
static const std::unordered_map<std::type_index, std::string> type_names;
};
//
// Helper structs to deduce whether a type is iterable.
//
template <typename T>
struct has_const_iterator
{
private:
typedef struct
{
char x;
} true_type;
typedef struct
{
char x[2];
} false_type;
template <typename U>
static true_type check(typename U::const_iterator*);
template <typename U>
static false_type check(...);
public:
static const bool value = sizeof(check<T>(0)) == sizeof(true_type);
typedef T type;
};
template <typename T>
struct is_container : std::integral_constant<bool, has_const_iterator<T>::value>
{
};
}
}
}
......@@ -23,16 +23,19 @@ using namespace ngraph;
using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_cuda_emitter(new CUDAEmitter(this, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr))
, m_memory_manager(this)
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
{
}
GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx)
: m_cuda_emitter(new CUDAEmitter(this, ctx.get()))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get()))
, m_memory_manager(this)
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, ctx.get()))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
{
}
......
......@@ -20,6 +20,7 @@
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_args.hpp"
#include "ngraph/runtime/gpu/gpu_memory_manager.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
......@@ -50,14 +51,16 @@ namespace ngraph
GPUAllocator get_memory_allocator() { return m_memory_manager.build_allocator(); }
void allocate_primitive_memory() { m_memory_manager.allocate(); }
size_t sizeof_device_allocation() { return m_memory_manager.get_allocation_size(); }
GPUKernelArgs add_kernel_args() { return GPUKernelArgs(m_host_parameters); }
private:
std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::vector<gpu::primitive*> m_gpu_primitives;
std::vector<gpu::memory_primitive> m_gpu_mem_primitives;
std::unordered_map<std::string, size_t> m_primitive_map;
std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives;
GPUMemoryManager m_memory_manager;
std::shared_ptr<GPUHostParameters> m_host_parameters;
std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment