Commit 1c74ad24 authored by fenglei.tian's avatar fenglei.tian

fix bugs for add and maximum

parent 94429375
./build/test/unit-test --gtest_filter=GPU.ab #./build/test/unit-test --gtest_filter=GPU.ab
./build/test/unit-test --gtest_filter=GPU.maximum
#./build/test/unit-test --gtest_filter=GPU.abs
#./build/test/unit-test --gtest_filter=GPU.dot* #./build/test/unit-test --gtest_filter=GPU.dot*
...@@ -16,12 +16,9 @@ ...@@ -16,12 +16,9 @@
#include <fstream> #include <fstream>
#include <stdio.h> #include <stdio.h>
#include <cuda_runtime.h>
#include "cublas.h"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp" #include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view.hpp" #include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -31,19 +28,25 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction> ...@@ -31,19 +28,25 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
: m_external_function(external_function) : m_external_function(external_function)
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
{ {
cublasStatus_t stat = cublasCreate(&m_cublas_handle); cublasStatus_t cublasStatus = cublasCreate(&m_cublas_handle);
if (stat != CUBLAS_STATUS_SUCCESS) if (cublasStatus != CUBLAS_STATUS_SUCCESS)
{ {
throw runtime_error("cuBLAS create failed"); throw runtime_error("cuBLAS create handle failed");
} }
cudnnStatus_t cudnnStatus = cudnnCreate(&m_cudnn_handle);
if (cudnnStatus != CUDNN_STATUS_SUCCESS)
{
throw runtime_error("cuDnn create handle failed");
}
// Pass scalars as reference on the Host
cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_HOST); cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_HOST);
// Pass scalars as reference on the device
cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
} }
runtime::gpu::GPU_CallFrame::~GPU_CallFrame() runtime::gpu::GPU_CallFrame::~GPU_CallFrame()
{ {
cublasDestroy(m_cublas_handle); cublasDestroy(m_cublas_handle);
cudnnDestroy(m_cudnn_handle);
} }
void runtime::gpu::GPU_CallFrame::tensor_call( void runtime::gpu::GPU_CallFrame::tensor_call(
...@@ -67,7 +70,7 @@ void runtime::gpu::GPU_CallFrame::tensor_call( ...@@ -67,7 +70,7 @@ void runtime::gpu::GPU_CallFrame::tensor_call(
outputs.push_back(tv->m_allocated_buffer_pool); outputs.push_back(tv->m_allocated_buffer_pool);
} }
m_compiled_function(inputs.data(), outputs.data(), m_cublas_handle); m_compiled_function(inputs.data(), outputs.data(), m_cublas_handle, m_cudnn_handle);
} }
void runtime::gpu::GPU_CallFrame::call( void runtime::gpu::GPU_CallFrame::call(
......
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <cuda_runtime.h>
#include "cublas_v2.h" #include "cublas_v2.h"
#include <cudnn.h>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
...@@ -37,7 +39,8 @@ namespace ngraph ...@@ -37,7 +39,8 @@ namespace ngraph
using EntryPoint_t = void(void** inputs, using EntryPoint_t = void(void** inputs,
void** outputs, void** outputs,
cublasHandle_t& cublas_handle); cublasHandle_t& cublas_handle,
cudnnHandle_t& cudnn_handle);
using EntryPoint = std::function<EntryPoint_t>; using EntryPoint = std::function<EntryPoint_t>;
...@@ -66,6 +69,7 @@ namespace ngraph ...@@ -66,6 +69,7 @@ namespace ngraph
std::shared_ptr<GPU_ExternalFunction> m_external_function; std::shared_ptr<GPU_ExternalFunction> m_external_function;
EntryPoint m_compiled_function; EntryPoint m_compiled_function;
cublasHandle_t m_cublas_handle; cublasHandle_t m_cublas_handle;
cudnnHandle_t m_cudnn_handle;
}; };
} }
} }
......
...@@ -260,53 +260,45 @@ void runtime::gpu::GPU_Emitter::EmitMaximum(codegen::CodeWriter& writer, ...@@ -260,53 +260,45 @@ void runtime::gpu::GPU_Emitter::EmitMaximum(codegen::CodeWriter& writer,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args, const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out) const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{ {
const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape();
// clang-format off
writer << "{ // " << n->get_name() << "\n"; writer << "{ // " << n->get_name() << "\n";
writer.indent++; writer.indent++;
writer << "int count = " << out[0].get_size() << ";\n"; writer << "int count = " << out[0].get_size() << ";\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";; writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
cudnnTensorDescriptor_t descriptor;
(cudnnCreateTensorDescriptor(&descriptor));
(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
cudnnOpTensorDescriptor_t opTensorDesc;
(cudnnCreateOpTensorDescriptor(&opTensorDesc));
(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MAX,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "cudnnOpTensor(cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor,"
<< args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor,"
<< args[1].get_name() << ","
<< "&beta,"
<< "descriptor,"
<< out[0].get_name() << ");\n";
writer += R"( writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0; )";
cudnnHandle_t cudnnHandle; writer.indent--;
(cudnnCreate(&cudnnHandle)); writer << "}\n";
cudnnTensorDescriptor_t descriptor;
(cudnnCreateTensorDescriptor(&descriptor));
(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
cudnnOpTensorDescriptor_t opTensorDesc;
(cudnnCreateOpTensorDescriptor(&opTensorDesc));
(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MAX,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "cudnnOpTensor(cudnnHandle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor,"
<< args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor,"
<< args[1].get_name() << ","
<< "&beta,"
<< "descriptor,"
<< out[0].get_name() << ");\n";
writer += R"(
cudnnDestroy(cudnnHandle);
)";
writer.indent--;
writer << "}\n";
// clang-format on
} }
void runtime::gpu::GPU_Emitter::EmitMinimum(codegen::CodeWriter& writer, void runtime::gpu::GPU_Emitter::EmitMinimum(codegen::CodeWriter& writer,
......
...@@ -358,7 +358,7 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -358,7 +358,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
for (shared_ptr<Function> f : pass_manager.get_state().get_functions()) for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
{ {
writer << "extern \"C\" void " << f->get_name() writer << "extern \"C\" void " << f->get_name()
<< "(void** inputs, void** outputs, cublasHandle_t& cublas_handle);\n"; << "(void** inputs, void** outputs, cublasHandle_t& cublas_handle, cudnnHandle_t& cudnn_handle);\n";
} }
writer << "\n"; writer << "\n";
...@@ -476,7 +476,7 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -476,7 +476,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
} }
writer << "extern \"C\" void " << current_function->get_name(); writer << "extern \"C\" void " << current_function->get_name();
writer << "(void** inputs, void** outputs, cublasHandle_t& cublas_handle)\n"; writer << "(void** inputs, void** outputs, cublasHandle_t& cublas_handle, cudnnHandle_t& cudnn_handle)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment