Commit 7ac35345 authored by Ayan Moitra's avatar Ayan Moitra Committed by Robert Kimball

cublas emitter for NVGPU backend (#1705)

* cublas emitter

* clang format fixes

* Initial comment incorporation from Chris

* Chris's If-else change comment incorporation

* incorporating Bob's comments phase 1

*  Remove unnecessary headers in cublas emitter hpp & cpp (as per Bob's comments)

* clang format on previous commit

* incorporate fenglei's refactoring comment

* incorporating comments

* Incorporate Chris's final comment

* All comments resolved

* Resolve Geoff's comments

* Change cache_primitive to register_primitive
parent 0e008cc5
......@@ -21,6 +21,7 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS})
set(SRC
cuda_emitter.cpp
cudnn_emitter.cpp
cublas_emitter.cpp
gpu_backend.cpp
gpu_cuda_context_manager.cpp
gpu_cuda_function_builder.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/cublas_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
runtime::gpu::CUBLASEmitter::CUBLASEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx)
: m_primitive_emitter(emitter)
{
m_ctx = ctx;
}
size_t runtime::gpu::CUBLASEmitter::build_dot(const element::Type& dtype,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t reduction_axes,
const Node* node)
{
std::stringstream ss;
ss << "dot_op"
<< "_dtype_" << dtype.c_type_string() << "_reduction_axes_count_" << reduction_axes;
std::string hash = ss.str() + "_i_" + join(arg0_shape, "_") + "_i_" + join(arg1_shape, "_");
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::unique_ptr<gpu::primitive> dot;
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
size_t count = shape_size(second);
size_t firstIndex = (arg0_shape.empty() ? 0 : 1);
size_t secondIndex = (arg0_shape.empty() ? 1 : 0);
dot.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
CUBLAS_SAFE_CALL(cublasScopy(*m_ctx->cublas_handle,
count,
static_cast<const float*>(inputs[secondIndex]),
1,
static_cast<float*>(outputs[0]),
1));
CUBLAS_SAFE_CALL(cublasSscal(*m_ctx->cublas_handle,
count,
static_cast<const float*>(inputs[firstIndex]),
static_cast<float*>(outputs[0]),
1));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->register_primitive(dot, hash);
}
// case that can be treat as dot1d
else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
{
for (int i = 0; i < arg0_shape.size(); i++)
{
if (arg0_shape[i] != arg1_shape[i])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<Shape> shape_vec{arg0_shape, arg1_shape};
throw std::invalid_argument(get_error_string(arg_vec, shape_vec, node));
}
}
size_t count = shape_size(arg0_shape);
dot.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
CUBLAS_SAFE_CALL(cublasSdot(*m_ctx->cublas_handle,
count,
static_cast<const float*>(inputs[0]),
1,
static_cast<const float*>(inputs[1]),
1,
static_cast<float*>(outputs[0])));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->register_primitive(dot, hash);
}
// matrix vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
{
dot.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
const float alpha = 1.0;
const float beta = 0;
CUBLAS_SAFE_CALL(cublasSetPointerMode(*m_ctx->cublas_handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_SAFE_CALL(cublasSgemv(*m_ctx->cublas_handle,
CUBLAS_OP_T,
arg0_shape[1],
arg0_shape[0],
&alpha,
static_cast<const float*>(inputs[0]),
arg0_shape[1],
static_cast<const float*>(inputs[1]),
1,
&beta,
static_cast<float*>(outputs[0]),
1));
CUBLAS_SAFE_CALL(
cublasSetPointerMode(*m_ctx->cublas_handle, CUBLAS_POINTER_MODE_DEVICE));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->register_primitive(dot, hash);
}
else
{
size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
size_t axes_for_k_count = reduction_axes;
size_t m = 1;
size_t n = 1;
size_t k = 1;
// check if input and output size correct
// check and calculate k for arg0 and arg1
size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k
size_t arg1_k_idx = 0; // first axe in arg1 for k
for (size_t i = 0; i < axes_for_k_count; i++)
{
k *= arg0_shape[arg0_k_idx];
if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<Shape> shape_vec{arg0_shape, arg1_shape};
throw std::invalid_argument(get_error_string(arg_vec, shape_vec, node));
}
}
// check and calculate m for arg0 and out
size_t arg0_m_idx = 0; // first axe in arg0 for m
size_t out_m_idx = 0; // first axe in out for m
for (size_t i = 0; i < axes_for_m_count; i++)
{
m *= arg0_shape[arg0_m_idx];
if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
{
std::vector<std::string> arg_vec{"arg0", "output"};
std::vector<Shape> shape_vec{arg0_shape, out_shape};
throw std::invalid_argument(get_error_string(arg_vec, shape_vec, node));
}
}
// check and calculate n for arg1 and out
size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n
size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n
for (size_t i = 0; i < axes_for_n_count; i++)
{
n *= arg1_shape[arg1_n_idx];
if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
{
std::vector<std::string> arg_vec{"arg1", "output"};
std::vector<Shape> shape_vec{arg1_shape, out_shape};
throw std::invalid_argument(get_error_string(arg_vec, shape_vec, node));
}
}
dot.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
const float alpha = 1.0;
const float beta = 0;
CUBLAS_SAFE_CALL(cublasSetPointerMode(*m_ctx->cublas_handle, CUBLAS_POINTER_MODE_HOST));
CUBLAS_SAFE_CALL(cublasSgemm(*m_ctx->cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
n,
m,
k,
&alpha,
static_cast<const float*>(inputs[1]),
n,
static_cast<const float*>(inputs[0]),
k,
&beta,
static_cast<float*>(outputs[0]),
n));
CUBLAS_SAFE_CALL(
cublasSetPointerMode(*m_ctx->cublas_handle, CUBLAS_POINTER_MODE_DEVICE));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->register_primitive(dot, hash);
}
return primitive_index;
}
void runtime::gpu::CUBLASEmitter::sync()
{
CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
return;
}
void runtime::gpu::CUBLASEmitter::debug_sync()
{
#ifdef NGRAPH_DEBUG_ENABLE
CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
#endif
return;
}
std::string runtime::gpu::CUBLASEmitter::get_error_string(std::vector<std::string>& arg_names,
std::vector<Shape>& shapes,
const Node* node)
{
std::stringstream ss_err;
ss_err << ngraph::join(arg_names) << " with " << ngraph::join(shapes)
<< " respectively, at Node " << node->get_name() << ", do not match for dot op";
return ss_err.str();
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cublas_v2.h>
#include "ngraph/op/dot.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUPrimitiveEmitter;
class CUBLASEmitter
{
friend class GPUPrimitiveEmitter;
public:
size_t build_dot(const element::Type& dtype,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t reduction_axes,
const Node* node);
void debug_sync();
void sync();
private:
CUBLASEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
GPUPrimitiveEmitter* m_primitive_emitter;
GPURuntimeContext* m_ctx;
std::string get_error_string(std::vector<std::string>& arg_names,
std::vector<Shape>& shapes,
const Node* node);
};
}
}
}
This diff is collapsed.
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include <sstream>
......@@ -198,9 +199,7 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(reduce));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(reduce, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_tensor_op(const cudnnOpTensorOp_t& tensor_op,
......@@ -250,9 +249,7 @@ size_t runtime::gpu::CUDNNEmitter::build_tensor_op(const cudnnOpTensorOp_t& tens
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(tensor));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(tensor, hash);
}
cudnnFilterDescriptor_t& runtime::gpu::CUDNNEmitter::get_cudnn_filter_descriptor(
......@@ -440,9 +437,8 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Convolution* node)
gpu::invoke_primitive(m_ctx, conv_index, inputs, outputs);
}
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::ConvolutionBackpropData* node)
......@@ -573,9 +569,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::ConvolutionBackprop
}
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::ConvolutionBackpropFilters* node)
......@@ -686,9 +680,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::ConvolutionBackprop
}
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
......@@ -794,9 +786,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::MaxPool* node)
}
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Max* node)
......@@ -855,9 +845,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Max* node)
}});
}
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
......@@ -916,9 +904,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
}});
}
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_convolution(const std::string& dtype,
......@@ -1270,9 +1256,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
}
}
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(pool, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& bn_op,
......@@ -1416,9 +1400,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
}
}
primitive_index = this->m_primitive_emitter->insert(std::move(batchnorm));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(batchnorm, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
......@@ -1486,9 +1468,7 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t&
}
}
primitive_index = this->m_primitive_emitter->insert(std::move(softmax));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
return this->m_primitive_emitter->register_primitive(softmax, hash);
}
void runtime::gpu::CUDNNEmitter::sync()
......
......@@ -519,154 +519,38 @@ void runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
{
return;
}
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node);
auto dot = static_cast<const ngraph::op::Dot*>(node);
size_t reduction_axes_count = dot->get_reduction_axes_count();
const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape();
const Shape& out_shape = out[0].get_shape();
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& first = (arg0_shape.empty() ? args[0] : args[1]);
auto& second = (arg0_shape.empty() ? args[1] : args[0]);
writer.block_begin();
writer << "int count = " << second.get_size() << ";\n";
writer << "CUBLAS_SAFE_CALL(cublasScopy("
<< "*ctx->cublas_handle,"
<< "count ," << second.get_name() << ","
<< "1," << out[0].get_name() << ", 1));\n";
writer << "CUBLAS_SAFE_CALL(cublasSscal("
<< "*ctx->cublas_handle,"
<< "count ," << first.get_name() << "," << out[0].get_name() << ", 1));\n";
writer.block_end();
return;
}
// set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
writer.block_begin();
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, " << out[0].get_size()
<< " * " << out[0].get_element_type().size() << ");\n";
writer.block_end();
return;
}
// case that can be treat as dot1d
if ((arg0_shape.size() == arg1_shape.size()) &&
(arg0_shape.size() == dot->get_reduction_axes_count()))
writer.block_begin();
{
for (int i = 0; i < arg0_shape.size(); i++)
{
if (arg0_shape[i] != arg1_shape[i])
{
throw invalid_argument("arg0 and arg1 shape does not match for dot.");
}
}
writer.block_begin();
writer << "CUBLAS_SAFE_CALL(cublasSdot("
<< "*ctx->cublas_handle," << args[0].get_size() << "," << args[0].get_name() << ","
<< "1," << args[1].get_name() << ","
<< "1," << out[0].get_name() << "));\n";
writer.block_end();
}
// matrix vector
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1))
{
writer.block_begin();
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0;\n";
writer << "CUBLAS_SAFE_CALL(cublasSetPointerMode(*ctx->cublas_handle, "
"CUBLAS_POINTER_MODE_HOST));\n";
writer << "CUBLAS_SAFE_CALL(cublasSgemv("
<< "*ctx->cublas_handle,"
<< "CUBLAS_OP_T," << arg0_shape[1] << "," << arg0_shape[0] << ","
<< "&alpha," // Alpha
<< args[0].get_name() << "," << arg0_shape[1] << "," << args[1].get_name() << ","
<< "1,"
<< "&beta," // beta
<< out[0].get_name() << ","
<< "1));\n";
writer << "CUBLAS_SAFE_CALL(cublasSetPointerMode(*ctx->cublas_handle, "
"CUBLAS_POINTER_MODE_DEVICE));\n";
writer.block_end();
}
// cases that can be treat as matrix multiply
else
{
// treat as out[m,n] = arg0[m,k] * arg1[k,n]
size_t reduction_axes = dot->get_reduction_axes_count();
size_t num_of_axes_for_m = arg0_shape.size() - reduction_axes;
size_t num_of_axes_for_n = arg1_shape.size() - reduction_axes;
size_t num_of_axes_for_k = reduction_axes;
size_t m = 1;
size_t n = 1;
size_t k = 1;
// check if input and output size correct
// check and calculate k for arg0 and arg1
size_t arg0_k_idx = num_of_axes_for_m; // first axe in arg0 for k
size_t arg1_k_idx = 0; // first axe in arg1 for k
for (size_t i = 0; i < num_of_axes_for_k; i++)
{
k *= arg0_shape[arg0_k_idx];
if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
{
throw invalid_argument("arg0 and arg1 shape does not match for dot.");
}
}
// check and calculate m for arg0 and out
size_t arg0_m_idx = 0; // first axe in arg0 for m
size_t out_m_idx = 0; // first axe in out for m
for (size_t i = 0; i < num_of_axes_for_m; i++)
// set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
m *= arg0_shape[arg0_m_idx];
if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
{
throw invalid_argument("arg0 and output shape does not match for dot.");
}
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * " << out[0].get_element_type().size() << ");\n";
}
// check and calculate n for arg1 and out
size_t arg1_n_idx = num_of_axes_for_k; // first axe in arg1 for n
size_t out_n_idx = num_of_axes_for_m; // first axe in arg1 for n
for (size_t i = 0; i < num_of_axes_for_n; i++)
else
{
n *= arg1_shape[arg1_n_idx];
if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
{
throw invalid_argument("arg1 and output shape does not match for dot.");
}
}
auto& cublas_emitter = external_function->get_primitive_emitter()->get_cublas_emitter();
// GEMM Call
writer.block_begin();
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0.0;\n";
writer << "int m = " << m << ";\n";
writer << "int n = " << n << ";\n";
writer << "int k = " << k << ";\n";
writer << "CUBLAS_SAFE_CALL(cublasSetPointerMode(*ctx->cublas_handle, "
"CUBLAS_POINTER_MODE_HOST));\n";
writer << "CUBLAS_SAFE_CALL(cublasSgemm("
<< "*ctx->cublas_handle,"
<< "CUBLAS_OP_N,"
<< "CUBLAS_OP_N,"
<< "n,"
<< "m,"
<< "k,"
<< "&alpha," // Alpha
<< args[1].get_name() << ","
<< "n," << args[0].get_name() << ","
<< "k,"
<< "&beta," // beta
<< out[0].get_name() << ","
<< "n));\n";
writer << "CUBLAS_SAFE_CALL(cublasSetPointerMode(*ctx->cublas_handle, "
"CUBLAS_POINTER_MODE_DEVICE));\n";
writer.block_end();
auto index = cublas_emitter->build_dot(out[0].get_element_type(),
arg0_shape,
arg1_shape,
out_shape,
reduction_axes_count,
node);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
}
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_Equal(EMIT_ARGS)
......
......@@ -16,7 +16,6 @@
#include <limits>
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
using namespace ngraph;
......@@ -27,6 +26,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
, m_cublas_emitter(new CUBLASEmitter(this, nullptr))
{
}
......@@ -35,6 +35,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext
, m_host_parameters(new GPUHostParameters)
, m_cuda_emitter(new CUDAEmitter(this, ctx.get()))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
, m_cublas_emitter(new CUBLASEmitter(this, ctx.get()))
{
}
......@@ -47,6 +48,11 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{
return m_cudnn_emitter;
}
std::unique_ptr<CUBLASEmitter>& GPUPrimitiveEmitter::get_cublas_emitter()
{
return m_cublas_emitter;
}
size_t GPUPrimitiveEmitter::insert(std::unique_ptr<gpu::primitive>&& f)
{
m_managed_primitives.emplace_back(std::move(f));
......@@ -70,3 +76,10 @@ void GPUPrimitiveEmitter::cache(const std::string& hash, const size_t& index)
{
m_primitive_map.insert({hash, index});
}
size_t GPUPrimitiveEmitter::register_primitive(std::unique_ptr<gpu::primitive>& f, std::string hash)
{
size_t primitive_index = this->insert(std::move(f));
this->cache(hash, primitive_index);
return primitive_index;
}
......@@ -15,9 +15,8 @@
//*****************************************************************************
#pragma once
#include <functional>
#include <unordered_map>
#include "ngraph/runtime/gpu/cublas_emitter.hpp"
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_args.hpp"
......@@ -30,8 +29,6 @@ namespace ngraph
{
namespace gpu
{
class CUDAEmitter;
class CUDNNEmitter;
class GPUPrimitiveEmitter
{
public:
......@@ -39,6 +36,7 @@ namespace ngraph
GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx);
std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::unique_ptr<CUBLASEmitter>& get_cublas_emitter();
std::vector<gpu::primitive*>& get_primitives() { return m_gpu_primitives; }
std::vector<gpu::memory_primitive>& get_memory_primitives()
{
......@@ -52,6 +50,8 @@ namespace ngraph
void allocate_primitive_memory() { m_memory_manager.allocate(); }
size_t sizeof_device_allocation() { return m_memory_manager.get_allocation_size(); }
GPUKernelArgs add_kernel_args() { return GPUKernelArgs(m_host_parameters); }
size_t register_primitive(std::unique_ptr<gpu::primitive>&, std::string);
private:
std::vector<gpu::primitive*> m_gpu_primitives;
std::vector<gpu::memory_primitive> m_gpu_mem_primitives;
......@@ -61,6 +61,7 @@ namespace ngraph
std::shared_ptr<GPUHostParameters> m_host_parameters;
std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::unique_ptr<CUBLASEmitter> m_cublas_emitter;
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment