Unverified commit 94d80ffa authored by Tristan Webb, committed by GitHub

Drwebb/gpu backend dot op (#413)

* Drwebb/gpu backend dot op (#387)

* GPU Dot prod emitter switch statement

* cuBLAS dot kernel call

* Flesh out arg substitution into gpu dot kernel call

* Drwebb/gpu backend dot op (#392)

* Take in CodeWriter into gpu op emitters

* Introduce GPU function gen based on pass functions

* Additional gpu emitter stubs

* link cublas in to unit test and ngraph

* Use static code gen methods for GPU, add new GPU op stubs

* use pass manager to declare functions / cublas Updates

* Prune down gpu_external_function wip

* Switch back to GPU tensor views in GPU backend

* Pass in cublas handle to GPU external function

* cuMalloc memory in gpu tensor view

* Use cuda runtime malloc and free for tensor view management

* change GPU tensor view init, and use GPU tensor view for GPU call frame

* include headers as system dirs

* GPU tensor printing utility function

* cublasSetPointerMode to device mode / Fix copyright notice lowercasing

* Passing GPU dot product test using cuBLAS

Clean up

* Changes from review
parent 2b0a5489
@@ -187,8 +187,7 @@ endif()
 # GPU backend currently requires CPU because they share compiler.cpp,
 # and compiler.cpp requires MKLDNN
 if(NGRAPH_GPU_ENABLE)
-    include_directories(${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR})
-    link_directories(${CUDA_LIBRARIES} ${CUDNN_LIBRARIES})
+    include_directories(SYSTEM ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR})
     # Add sources for the GPU backend
     # and all its dependencies
@@ -201,6 +200,7 @@ endif()
         runtime/gpu/gpu_manager.cpp
         runtime/gpu/gpu_tensor_view.cpp
         runtime/gpu/gpu_tensor_view_wrapper.cpp
+        runtime/gpu/gpu_util.cpp
         )
     set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS
         "CUDA_HEADER_PATHS=\"${CUDA_INCLUDE_DIRS}\";")
@@ -272,8 +272,9 @@ if(NGRAPH_CPU_ENABLE)
     target_link_libraries(ngraph PRIVATE ${TBB_IMPORTED_TARGETS})
 endif()

+# Nvidia
 if(NGRAPH_GPU_ENABLE AND CUDA_LIBRARIES)
-    target_link_libraries(ngraph PRIVATE cuda)
+    target_link_libraries(ngraph PRIVATE ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES})
 endif()

 # Argon
......
@@ -13,7 +13,6 @@
 // ----------------------------------------------------------------------------
 #include "ngraph/runtime/gpu/gpu_backend.hpp"
-#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
 #include "ngraph/runtime/external_function.hpp"
 #include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
@@ -27,9 +26,9 @@ std::shared_ptr<ngraph::runtime::CallFrame> runtime::gpu::GPU_Backend::make_call
 }

 std::shared_ptr<ngraph::runtime::TensorView>
-    runtime::gpu::GPU_Backend::make_device_tensor(const ngraph::element::Type& element_type,
-                                                  const Shape& shape)
+    runtime::gpu::GPU_Backend::make_primary_tensor_view(const ngraph::element::Type& element_type,
+                                                        const Shape& shape)
 {
-    auto rc = make_shared<runtime::HostTensorView>(element_type, shape);
+    auto rc = make_shared<runtime::gpu::GPU_TensorView>(element_type, shape);
     return dynamic_pointer_cast<runtime::TensorView>(rc);
 }
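With this change a tensor requested from the GPU backend is device-backed rather than a HostTensorView. A minimal usage sketch, mirroring the calls the tests below make:

    auto manager = ngraph::runtime::Manager::get("GPU");
    auto backend = manager->allocate_backend();
    // Now returns a runtime::gpu::GPU_TensorView whose storage lives on the device
    auto tv = backend->make_primary_tensor_view(ngraph::element::f32, ngraph::Shape{2, 2});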
 // ----------------------------------------------------------------------------
-// copyright 2017 nervana systems inc.
-// licensed under the apache license, version 2.0 (the "license");
-// you may not use this file except in compliance with the license.
-// you may obtain a copy of the license at
+// Copyright 2017 Nervana Systems Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
 //
-//      http://www.apache.org/licenses/license-2.0
+//      http://www.apache.org/licenses/LICENSE-2.0
 //
-// unless required by applicable law or agreed to in writing, software
-// distributed under the license is distributed on an "as is" basis,
-// without warranties or conditions of any kind, either express or implied.
-// see the license for the specific language governing permissions and
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
 // ----------------------------------------------------------------------------

 #include <cstdlib>
 #include <fstream>
+#include <stdio.h>

-#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
+#include <cuda_runtime.h>
+#include "cublas_v2.h"

 #include "ngraph/runtime/gpu/gpu_call_frame.hpp"
 #include "ngraph/runtime/gpu/gpu_external_function.hpp"
@@ -29,6 +31,13 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
     : m_external_function(external_function)
     , m_compiled_function(compiled_function)
 {
+    cublasStatus_t stat = cublasCreate(&m_cublas_handle);
+    if (stat != CUBLAS_STATUS_SUCCESS)
+    {
+        throw runtime_error("cuBLAS create failed");
+    }
+    // Pass scalars by reference on the device
+    cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
 }
 void runtime::gpu::GPU_CallFrame::tensor_call(
@@ -36,24 +45,23 @@ void runtime::gpu::GPU_CallFrame::tensor_call(
     const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& output_tvs)
 {
-    // Host tensors
-    vector<void*> inputs;
-    vector<void*> outputs;
+    // Device tensors
+    vector<void**> inputs;
+    vector<void**> outputs;

     for (size_t i = 0; i < input_tvs.size(); i++)
     {
-        shared_ptr<runtime::HostTensorView> tv =
-            static_pointer_cast<runtime::HostTensorView>(input_tvs[i]);
-        inputs.push_back(tv->get_data_ptr());
+        shared_ptr<runtime::gpu::GPU_TensorView> tv =
+            static_pointer_cast<runtime::gpu::GPU_TensorView>(input_tvs[i]);
+        inputs.push_back(tv->m_allocated_buffer_pool);
     }
     for (size_t i = 0; i < output_tvs.size(); i++)
     {
-        shared_ptr<runtime::HostTensorView> tv =
-            static_pointer_cast<runtime::HostTensorView>(output_tvs[i]);
-        outputs.push_back(tv->get_data_ptr());
+        shared_ptr<runtime::gpu::GPU_TensorView> tv =
+            static_pointer_cast<runtime::gpu::GPU_TensorView>(output_tvs[i]);
+        outputs.push_back(tv->m_allocated_buffer_pool);
     }

-    // Invoke compiled computation
-    m_compiled_function(inputs.data(), outputs.data());
+    m_compiled_function(inputs.data(), outputs.data(), m_cublas_handle);
 }

 void runtime::gpu::GPU_CallFrame::call(
......
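For reference, the cuBLAS call the compiled code eventually makes depends on the pointer mode set above: with CUBLAS_POINTER_MODE_DEVICE, cublasSdot writes its scalar result directly to device memory, which is why the call frame can hand the generated function raw device buffers. A standalone sketch (not the emitted source; the function and buffer names are illustrative):

    #include <cuda_runtime.h>
    #include "cublas_v2.h"

    // Dot product of two device-resident float vectors. Because the handle is in
    // CUBLAS_POINTER_MODE_DEVICE, d_result must also be a device pointer.
    void device_dot(cublasHandle_t handle, const float* d_x, const float* d_y, int n, float* d_result)
    {
        cublasSdot(handle, n, d_x, 1, d_y, 1, d_result);
    }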
@@ -18,6 +18,8 @@
 #include <memory>
 #include <vector>

+#include "cublas_v2.h"
+
 #include "ngraph/function.hpp"
 #include "ngraph/runtime/call_frame.hpp"
 #include "ngraph/runtime/tensor_view.hpp"

@@ -33,7 +35,9 @@ namespace ngraph
             class GPU_CallFrame;
             class GPU_ExternalFunction;

-            using EntryPoint_t = void(void** inputs, void** outputs);
+            using EntryPoint_t = void(void*** inputs,
+                                      void*** outputs,
+                                      cublasHandle_t& cublas_handle);
             using EntryPoint = std::function<EntryPoint_t>;

@@ -44,6 +48,8 @@ namespace ngraph
                 GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction> external_function,
                               EntryPoint compiled_function);

+                ~GPU_CallFrame() override = default;
+
                 /// @brief Invoke the function with values matching the signature of the function.
                 ///
                 /// Tuples will be expanded into their tensor views to build the call frame.

@@ -59,6 +65,7 @@
             protected:
                 std::shared_ptr<GPU_ExternalFunction> m_external_function;
                 EntryPoint m_compiled_function;
+                cublasHandle_t m_cublas_handle;
             };
         }
     }
......
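To make the widened EntryPoint_t concrete, a hypothetical generated entry point would look roughly like the sketch below. The actual source is produced by GPU_ExternalFunction (its diff is collapsed below), so everything here is illustrative.

    #include "cublas_v2.h"

    // Hypothetical shape of a compiled function matching EntryPoint_t.
    // inputs[i] / outputs[i] carry each tensor's device base address (a void**).
    extern "C" void generated_function(void*** inputs, void*** outputs, cublasHandle_t& cublas_handle)
    {
        float* arg0 = reinterpret_cast<float*>(inputs[0]);
        float* arg1 = reinterpret_cast<float*>(inputs[1]);
        float* out0 = reinterpret_cast<float*>(outputs[0]);
        cublasSdot(cublas_handle, 4, arg0, 1, arg1, 1, out0); // e.g. a 4-element Dot op
    }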
This diff is collapsed.
@@ -23,7 +23,8 @@
 #include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp"

 #define EMITTER_DECL(E)                                                                            \
-    E(const ngraph::Node* n,                                                                       \
+    E(codegen::CodeWriter& writer,                                                                 \
+      const ngraph::Node* n,                                                                       \
       const std::vector<ngraph::runtime::gpu::GPU_TensorViewWrapper>& args,                        \
       const std::vector<ngraph::runtime::gpu::GPU_TensorViewWrapper>& out)
@@ -35,79 +36,60 @@ namespace ngraph
         {
             class GPU_Emitter
             {
-            protected:
-                codegen::CodeWriter m_out;
-                bool m_use_ref_kernels;
-
             public:
-                GPU_Emitter()
-                    : m_out()
-                    , m_use_ref_kernels(std::getenv("NGRAPH_GPU_USE_REF_KERNELS") != nullptr)
-                {
-                }
-                std::string get_code() { return m_out.get_code(); }
-                codegen::CodeWriter& get_code_writer() { return m_out; }
-                void EMITTER_DECL(EmitNop);
-                void EMITTER_DECL(EmitAdd);
-                void EMITTER_DECL(EmitDot);
-                void EMITTER_DECL(EmitMultiply);
-                void EMITTER_DECL(EmitGetOutputElement);
-                void EMITTER_DECL(EmitXLAGetTupleElement);
-                void EMITTER_DECL(EmitTuple);
-                void EMITTER_DECL(EmitAbs);
-                void EMITTER_DECL(EmitConcat);
-                void EMITTER_DECL(EmitDivide);
-                void EMITTER_DECL(EmitEqual);
-                void EMITTER_DECL(EmitGreater);
-                void EMITTER_DECL(EmitGreaterEq);
-                void EMITTER_DECL(EmitLess);
-                void EMITTER_DECL(EmitLessEq);
-                void EMITTER_DECL(EmitLog);
-                void EMITTER_DECL(EmitMaximum);
-                void EMITTER_DECL(EmitMinimum);
-                void EMITTER_DECL(EmitNegative);
-                void EMITTER_DECL(EmitNotEqual);
-                void EMITTER_DECL(EmitSelect);
-                void EMITTER_DECL(EmitSubtract);
-                void EMITTER_DECL(EmitBroadcast);
-                void EMITTER_DECL(EmitConvert);
-                void EMITTER_DECL(EmitConstant);
-                void EMITTER_DECL(EmitReshape);
-                void EMITTER_DECL(EmitFunctionCall);
-                void EMITTER_DECL(EmitReduce);
-                void EMITTER_DECL(EmitSign);
-                void EMITTER_DECL(EmitSlice);
-                void EMITTER_DECL(EmitSum);
-                void EMITTER_DECL(EmitExp);
-                void EMITTER_DECL(EmitSin);
-                void EMITTER_DECL(EmitSinh);
-                void EMITTER_DECL(EmitCos);
-                void EMITTER_DECL(EmitCosh);
-                void EMITTER_DECL(EmitTan);
-                void EMITTER_DECL(EmitTanh);
-                void EMITTER_DECL(EmitAsin);
-                void EMITTER_DECL(EmitAcos);
-                void EMITTER_DECL(EmitAtan);
-                void EMITTER_DECL(EmitPower);
-                void EMITTER_DECL(EmitReplaceSlice);
-                void EMITTER_DECL(EmitOneHot);
-                void EMITTER_DECL(EmitFloor);
-                void EMITTER_DECL(EmitCeiling);
-                void EMITTER_DECL(EmitSqrt);
-                void EMITTER_DECL(EmitConvolution);
-                void EMITTER_DECL(EmitNot);
-                void EMITTER_DECL(EmitMaxPool);
-                void EMITTER_DECL(EmitReverse);
-
-            private:
-                void generate_call(const std::vector<GPU_TensorViewWrapper>& args,
-                                   const std::vector<GPU_TensorViewWrapper>& out,
-                                   std::shared_ptr<Function> function);
-
-                std::string emit_vector(const GPU_TensorViewWrapper&, const std::string& name = "");
-                std::string emit_array1d(const GPU_TensorViewWrapper&,
-                                         const std::string& name = "");
-                std::string emit_matrix(const GPU_TensorViewWrapper&, const std::string& name = "");
+                static void EMITTER_DECL(EmitNop);
+                static void EMITTER_DECL(EmitAdd);
+                static void EMITTER_DECL(EmitDot);
+                static void EMITTER_DECL(EmitMultiply);
+                static void EMITTER_DECL(EmitGetOutputElement);
+                static void EMITTER_DECL(EmitXLAGetTupleElement);
+                static void EMITTER_DECL(EmitTuple);
+                static void EMITTER_DECL(EmitAbs);
+                static void EMITTER_DECL(EmitConcat);
+                static void EMITTER_DECL(EmitDivide);
+                static void EMITTER_DECL(EmitEqual);
+                static void EMITTER_DECL(EmitGreater);
+                static void EMITTER_DECL(EmitGreaterEq);
+                static void EMITTER_DECL(EmitLess);
+                static void EMITTER_DECL(EmitLessEq);
+                static void EMITTER_DECL(EmitLog);
+                static void EMITTER_DECL(EmitMaximum);
+                static void EMITTER_DECL(EmitMinimum);
+                static void EMITTER_DECL(EmitNegative);
+                static void EMITTER_DECL(EmitNotEqual);
+                static void EMITTER_DECL(EmitSelect);
+                static void EMITTER_DECL(EmitSubtract);
+                static void EMITTER_DECL(EmitBroadcast);
+                static void EMITTER_DECL(EmitConvert);
+                static void EMITTER_DECL(EmitConstant);
+                static void EMITTER_DECL(EmitReshape);
+                static void EMITTER_DECL(EmitFunctionCall);
+                static void EMITTER_DECL(EmitReduce);
+                static void EMITTER_DECL(EmitSign);
+                static void EMITTER_DECL(EmitSlice);
+                static void EMITTER_DECL(EmitSum);
+                static void EMITTER_DECL(EmitExp);
+                static void EMITTER_DECL(EmitSin);
+                static void EMITTER_DECL(EmitSinh);
+                static void EMITTER_DECL(EmitCos);
+                static void EMITTER_DECL(EmitCosh);
+                static void EMITTER_DECL(EmitTan);
+                static void EMITTER_DECL(EmitTanh);
+                static void EMITTER_DECL(EmitAsin);
+                static void EMITTER_DECL(EmitAcos);
+                static void EMITTER_DECL(EmitAtan);
+                static void EMITTER_DECL(EmitPower);
+                static void EMITTER_DECL(EmitReplaceSlice);
+                static void EMITTER_DECL(EmitOneHot);
+                static void EMITTER_DECL(EmitFloor);
+                static void EMITTER_DECL(EmitCeiling);
+                static void EMITTER_DECL(EmitSqrt);
+                static void EMITTER_DECL(EmitConvolution);
+                static void EMITTER_DECL(EmitNot);
+                static void EMITTER_DECL(EmitMaxPool);
+                static void EMITTER_DECL(EmitReverse);
+                static void EMITTER_DECL(EmitReduceWindow);
+                static void EMITTER_DECL(EmitSelectAndScatter);
             };
         }
     }
......
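For a sense of how the now-static emitters use the CodeWriter they receive, here is a hedged sketch of one emitter body. The GPU emitters in this commit are largely stubs; ewise_add is a hypothetical device kernel, and the wrapper accessors (get_name, get_size) are assumed to mirror the CPU backend's TensorViewWrapper.

    #include "ngraph/runtime/gpu/gpu_emitter.hpp"

    // Illustrative only: stream placeholder source for an element-wise Add.
    void ngraph::runtime::gpu::GPU_Emitter::EmitAdd(
        codegen::CodeWriter& writer,
        const ngraph::Node* n,
        const std::vector<GPU_TensorViewWrapper>& args,
        const std::vector<GPU_TensorViewWrapper>& out)
    {
        writer << "// " << n->get_name() << "\n";
        writer << "ewise_add(" << args[0].get_name() << ", " << args[1].get_name() << ", "
               << out[0].get_name() << ", " << out[0].get_size() << ");\n";
    }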
@@ -39,7 +39,7 @@ namespace ngraph
             class GPU_CallFrame;

             using OpFunction =
-                std::function<void(GPU_Emitter*,
+                std::function<void(codegen::CodeWriter&,
                                    const ngraph::Node*,
                                    const std::vector<GPU_TensorViewWrapper>& inputs,
                                    const std::vector<GPU_TensorViewWrapper>& outputs)>;
......
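With this signature the op dispatch table can point straight at the static emitters. A sketch in the style of the CPU backend's dispatcher (the GPU table itself is not shown in this diff, so the map below is hypothetical):

    #include <map>
    #include <typeindex>
    #include "ngraph/ops/add.hpp"
    #include "ngraph/ops/dot.hpp"
    #include "ngraph/runtime/gpu/gpu_emitter.hpp"

    using namespace ngraph::runtime::gpu;

    // Hypothetical op -> emitter map; a static member function pointer converts
    // directly to the std::function-based OpFunction.
    static const std::map<std::type_index, OpFunction> dispatcher{
        {std::type_index(typeid(ngraph::op::Add)), &GPU_Emitter::EmitAdd},
        {std::type_index(typeid(ngraph::op::Dot)), &GPU_Emitter::EmitDot},
    };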
@@ -14,7 +14,7 @@
 #include <memory>

-#include <cuda.h>
+#include <cuda_runtime.h>

 #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
 #include "ngraph/descriptor/primary_tensor_view.hpp"
@@ -33,25 +33,35 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
                                              true,
                                              false))
 {
-    // Need to check type and have host/device tensors
     m_descriptor->set_tensor_view_layout(
         std::make_shared<ngraph::descriptor::layout::DenseTensorViewLayout>(*m_descriptor));

     m_buffer_size = m_descriptor->get_tensor_view_layout()->get_size() * element_type.size();
-    // cuMemAlloc(&dev_buffer, m_buffer_size);
+    if (m_buffer_size > 0)
+    {
+        cudaMalloc(&m_allocated_buffer_pool, m_buffer_size);
+    }
 }

 runtime::gpu::GPU_TensorView::~GPU_TensorView()
 {
-    // cuMemFree(dev_buffer);
+    cudaFree(m_allocated_buffer_pool);
 }

 void runtime::gpu::GPU_TensorView::write(const void* source, size_t tensor_offset, size_t n)
 {
-    // cuMemcpyHtoD(dev_buffer, source, n);
+    if (tensor_offset + n > m_buffer_size)
+    {
+        throw out_of_range("write access past end of tensor");
+    }
+    cudaMemcpy(m_allocated_buffer_pool, source, n, cudaMemcpyHostToDevice);
 }

 void runtime::gpu::GPU_TensorView::read(void* target, size_t tensor_offset, size_t n) const
 {
-    // cuMemcpyDtoH(target, dev_buffer, n);
+    if (tensor_offset + n > m_buffer_size)
+    {
+        throw out_of_range("read access past end of tensor");
+    }
+    cudaMemcpy(target, m_allocated_buffer_pool, n, cudaMemcpyDeviceToHost);
 }
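A quick round trip through the new read/write paths, as a usage sketch (values are arbitrary; both calls wrap cudaMemcpy, and this commit copies from the buffer base, so offset 0 is used):

    #include <vector>
    #include "ngraph/runtime/gpu/gpu_tensor_view.hpp"

    // Host -> device -> host copy through the tensor view.
    void round_trip(ngraph::runtime::gpu::GPU_TensorView& tv)
    {
        std::vector<float> in{1, 2, 3, 4};
        std::vector<float> out(in.size());
        tv.write(in.data(), 0, in.size() * sizeof(float));
        tv.read(out.data(), 0, out.size() * sizeof(float));
    }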
@@ -49,12 +49,6 @@ public:
     /// @param n Number of bytes to read, must be integral number of elements.
     void read(void* p, size_t tensor_offset, size_t n) const override;

-    // const char* get_data_ptr();
-    // const char* get_data_ptr() const;
-
-private:
-    CUdeviceptr dev_buffer;
-    // At some point need to deal with alignment
+    void** m_allocated_buffer_pool;
     size_t m_buffer_size;
 };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <stddef.h>
#include <stdio.h>
#include "cuda.h"
#include "cuda_runtime.h"
#include "ngraph/runtime/gpu/gpu_util.hpp"
using namespace ngraph;
using namespace std;
void runtime::gpu::print_gpu_f32_tensor(void* p, size_t element_count, size_t element_size)
{
    float* local;
    size_t size_in_bytes = element_size * element_count;
    local = static_cast<float*>(malloc(size_in_bytes));
    cudaMemcpy(local, p, size_in_bytes, cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < element_count; i++)
    {
        std::cout << local[i] << "\n";
    }
    free(local); // release the host-side staging copy
}
void runtime::gpu::check_cuda_errors(CUresult err)
{
assert(err == CUDA_SUCCESS);
}
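A small driver showing the utility in use, assuming a populated device buffer (d_buf is illustrative):

    #include <cuda_runtime.h>
    #include "ngraph/runtime/gpu/gpu_util.hpp"

    int main()
    {
        float host[4] = {1.5f, 2.5f, 3.5f, 4.5f};
        void* d_buf = nullptr;
        cudaMalloc(&d_buf, sizeof(host));
        cudaMemcpy(d_buf, host, sizeof(host), cudaMemcpyHostToDevice);
        // Prints each element on its own line
        ngraph::runtime::gpu::print_gpu_f32_tensor(d_buf, 4, sizeof(float));
        cudaFree(d_buf);
        return 0;
    }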
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------

#pragma once

#include <stddef.h>

#include <cuda.h> // for CUresult

namespace ngraph
{
    namespace runtime
    {
        namespace gpu
        {
            void print_gpu_f32_tensor(void* p, size_t element_count, size_t element_size);

            void check_cuda_errors(CUresult err);
        }
    }
}
@@ -76,6 +76,8 @@ endif()
 if(NGRAPH_GPU_ENABLE AND LLVM_INCLUDE_DIR)
     include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
     link_directories(${LLVM_LIB_DIR})
+    link_directories(${CUDA_LIBRARIES})
+    link_directories(${CUDA_CUBLAS_LIBRARIES})
     set(SRC
         ${SRC}
         cudnn.cpp)
@@ -123,7 +125,7 @@ if(LLVM_INCLUDE_DIR)
 endif()

 if(CUDA_INCLUDE_DIRS)
-    target_link_libraries(unit-test ${CUDA_LIBRARIES} ${CUDNN_LIBRARIES})
+    target_link_libraries(unit-test ${CUDA_LIBRARIES} ${CUDNN_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES})
 endif()

 target_link_libraries(unit-test ngraph libgtest pthread)
......
@@ -19,6 +19,7 @@
 #include <gtest/gtest.h>

 #include <cuda.h>
+#include <cuda_runtime.h>
 #include <cudnn.h>

 #include "ngraph/codegen/compiler.hpp"
@@ -45,6 +46,7 @@ TEST(cudnn, compileTest)
 #include <cassert>
 #include <fstream>
 #include <iostream>
+#include "cublas_v2.h"
 #include "cuda.h"

 void check_cuda_errors(CUresult err) {
@@ -60,6 +62,15 @@ int main(int argc, char **argv) {
     CUlinkState linker;
     int dev_count;

+    // Cublas init
+    cudaError_t cudaStat;
+    cublasStatus_t stat;
+    cublasHandle_t handle;
+
+    stat = cublasCreate(&handle);
+    cublasDestroy(handle);
+
     // CUDA initialization
     check_cuda_errors(cuInit(0));
     check_cuda_errors(cuDeviceGetCount(&dev_count));
@@ -251,48 +262,48 @@ const auto str = R"(
     auto module = compiler.compile(source);
 }

-TEST(cudnn, abc)
-{
-    auto shape = Shape{2, 2};
-    auto A = make_shared<op::Parameter>(element::f32, shape);
-    auto B = make_shared<op::Parameter>(element::f32, shape);
-    auto C = make_shared<op::Parameter>(element::f32, shape);
-    auto f = make_shared<Function>((A + B) * C, op::Parameters{A, B, C});
-
-    auto manager = runtime::Manager::get("GPU");
-    auto external = manager->compile(f);
-    auto backend = manager->allocate_backend();
-    auto cf = backend->make_call_frame(external);
-
-    // Create some tensors for input/output
-    shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shape);
-    shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shape);
-    shared_ptr<runtime::TensorView> c = backend->make_primary_tensor_view(element::f32, shape);
-    shared_ptr<runtime::TensorView> result = backend->make_primary_tensor_view(element::f32, shape);
-
-    copy_data(a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
-    copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
-    copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
-
-    cf->call({a, b, c}, {result});
-    EXPECT_EQ(result->read_vector<float>(),
-              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
-
-    cf->call({b, a, c}, {result});
-    EXPECT_EQ(result->read_vector<float>(),
-              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
-
-    cf->call({a, c, b}, {result});
-    EXPECT_EQ(result->read_vector<float>(),
-              (test::NDArray<float, 2>({{50, 72}, {98, 128}})).get_vector());
-}
+// TEST(cudnn, abc)
+// {
+//     auto shape = Shape{2, 2};
+//     auto A = make_shared<op::Parameter>(element::f32, shape);
+//     auto B = make_shared<op::Parameter>(element::f32, shape);
+//     auto C = make_shared<op::Parameter>(element::f32, shape);
+//     auto f = make_shared<Function>((A + B) * C, op::Parameters{A, B, C});
+//
+//     auto manager = runtime::Manager::get("GPU");
+//     auto external = manager->compile(f);
+//     auto backend = manager->allocate_backend();
+//     auto cf = backend->make_call_frame(external);
+//
+//     // Create some tensors for input/output
+//     shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shape);
+//     shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shape);
+//     shared_ptr<runtime::TensorView> c = backend->make_primary_tensor_view(element::f32, shape);
+//     shared_ptr<runtime::TensorView> result = backend->make_primary_tensor_view(element::f32, shape);
+//
+//     copy_data(a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
+//     copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
+//     copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
+//
+//     cf->call({a, b, c}, {result});
+//     EXPECT_EQ(result->read_vector<float>(),
+//               (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
+//
+//     cf->call({b, a, c}, {result});
+//     EXPECT_EQ(result->read_vector<float>(),
+//               (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
+//
+//     cf->call({a, c, b}, {result});
+//     EXPECT_EQ(result->read_vector<float>(),
+//               (test::NDArray<float, 2>({{50, 72}, {98, 128}})).get_vector());
+// }
 TEST(cudnn, dot1d)
 {
     auto shape = Shape{4};
     auto A = make_shared<op::Parameter>(element::f32, shape);
     auto B = make_shared<op::Parameter>(element::f32, shape);
-    auto shape_r = Shape{};
+    auto shape_r = Shape{1};
     auto f = make_shared<Function>(make_shared<op::Dot>(A, B), op::Parameters{A, B});

     auto manager = runtime::Manager::get("GPU");
@@ -308,5 +319,5 @@ TEST(cudnn, dot1d)
     auto result = backend->make_primary_tensor_view(element::f32, shape_r);

     cf->call({a, b}, {result});
-    EXPECT_EQ((vector<float>{170}), result->read_vector<float>());
+    EXPECT_EQ((vector<float>{170}), read_vector<float>(result));
 }
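The final assertion swaps the member-style result->read_vector<float>() for a free helper. A minimal sketch of what such a helper looks like, assuming the element count can be taken from the tensor view's layout (names are illustrative, not the exact ngraph test header):

    #include <memory>
    #include <vector>
    #include "ngraph/runtime/tensor_view.hpp"

    // Hypothetical free-function helper: copy the whole tensor back to the host.
    template <typename T>
    std::vector<T> read_vector(const std::shared_ptr<ngraph::runtime::TensorView>& tv)
    {
        // Assumption: the layout exposes the element count via get_size().
        size_t count = tv->get_tensor_view_layout()->get_size();
        std::vector<T> rc(count);
        tv->read(rc.data(), 0, count * sizeof(T));
        return rc;
    }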