Commit feab44b5 authored by Tristan Webb's avatar Tristan Webb Committed by Robert Kimball

Drwebb/gpu runtime boilerplate (#314)

* Simple boilerplate for GPU runtime files

  - GPUBackend
  - GPU ExternalFunction
  - GPUManager
  - GPUCallFrame

* Test for construction all GPU runtime classes

* Comment out calls, constructors haven't been defined

* Clang CUDA source example to later test compiling

Clang cuda example from:
https://gist.github.com/anonymous/855e277884eb6b388cd2f00d956c2fd4

* Initial nvptx compiler copied from CPU compiler sources

* Define FunctionMap and Instruction for gpu external function

* Rename Compiler -> NVPTXCompiler for gpu compile. Add call to compile for test

* Rename StaticCompiler -> NVPTXStaticCompiler for GPU code gen

* CAdd nvptx_compiler and nvptx_execution_engine to gpu sources

* Compiling source unit test using hardcoded PTX

* (a+b)*c test for GPU

* WIP Fix compile

* rmed accidentally included file

* Fix compile, and LLVM link errosr from nvptx_compiler.cpp

* Stub out parts needed for GPU manager

* Test GPU runtime method stubs

* Cleanup

* Add GPU runtime to same cmake block as GPU, include CUDA headers if GPU enabled

* Kill reflexive assertion

* change GPU naming convention to match CPU

* Snake case functions and identifiers in test case

* Change element type to match changes in master

* Make CUDA headers accessible for codegen with GPU transformer

* clang-format

* apply-code-format
parent 2218cf9f
......@@ -13,7 +13,7 @@
include(ExternalProject)
if((NGRAPH_CPU_ENABLE OR USE_CUDA) AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") AND
if((NGRAPH_CPU_ENABLE OR NGRAPH_GPU_ENABLE) AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") AND
(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows"))
message(STATUS "Fetching LLVM from llvm.org")
set(LLVM_RELEASE_URL http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu16.04.tar.xz)
......
......@@ -171,13 +171,32 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
"EIGEN_HEADERS_PATH=\"${EIGEN_INCLUDE_DIR}\";CLANG_BUILTIN_HEADERS_PATH=\"${LLVM_LIB_DIR}/clang/5.0.0/include\";TBB_HEADERS_PATH=\"${TBB_ROOT}/include\";NGRAPH_HEADERS_PATH=\"${NGRAPH_INCLUDE_PATH}\";INSTALLED_HEADERS_PATH=\"${CMAKE_INSTALL_PREFIX}/include\";")
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS
"NGCPU_DEBUGINFO=${NGRAPH_CPU_DEBUGINFO_ENABLE}")
endif()
"NGCPU_DEBUGINFO=${NGRAPH_CPU_DEBUGINFO_ENABLE};")
if (NGRAPH_ARGON_ENABLE)
link_directories(${ARGON_LIB_DIR})
endif()
# GPU backend current requires CPU because they share compiler.cpp,
# and compiler.cpp requires MKLDNN
if(NGRAPH_GPU_ENABLE)
include_directories(${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR})
link_directories(${CUDA_LIBRARIES} ${CUDNN_LIBRARIES})
# Add sources for the GPU backend
# and all its dependencies
set(SRC ${SRC}
runtime/gpu/gpu_call_frame.cpp
runtime/gpu/gpu_backend.cpp
runtime/gpu/gpu_manager.cpp
runtime/gpu/gpu_external_function.cpp
runtime/gpu/gpu_tensor_view.cpp
)
set_property(SOURCE codegen/compiler.cpp APPEND_STRING PROPERTY COMPILE_DEFINITIONS
"CUDA_HEADER_PATHS=\"${CUDA_INCLUDE_DIRS}\";")
endif()
endif()
add_library(ngraph SHARED ${SRC})
if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
# Generate the resource file containing all headers used by the codegen compiler
......
......@@ -354,6 +354,11 @@ void StaticCompiler::configure_search_path()
add_header_search_path(NGRAPH_HEADERS_PATH);
add_header_search_path(INSTALLED_HEADERS_PATH);
#endif
#ifdef CUDA_HEADER_PATHS
// Only needed for GPU backend
add_header_search_path(CUDA_HEADER_PATHS);
#endif
}
void StaticCompiler::load_header_search_path_from_resource()
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
using namespace ngraph;
using namespace std;
std::shared_ptr<ngraph::runtime::CallFrame> runtime::gpu::GPU_Backend::make_call_frame(
const std::shared_ptr<ExternalFunction>& external_function)
{
return external_function->make_call_frame();
}
std::shared_ptr<ngraph::runtime::TensorView>
runtime::gpu::GPU_Backend::make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape)
{
auto rc = make_shared<runtime::gpu::GPU_TensorView>(element_type, shape);
return dynamic_pointer_cast<runtime::TensorView>(rc);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/backend.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
static size_t alignment = 64;
class GPU_Backend : public Backend
{
public:
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(
const std::shared_ptr<ngraph::runtime::ExternalFunction>& external_function)
override;
std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape) override;
};
}
}
}
// ----------------------------------------------------------------------------
// copyright 2017 nervana systems inc.
// licensed under the apache license, version 2.0 (the "license");
// you may not use this file except in compliance with the license.
// you may obtain a copy of the license at
//
// http://www.apache.org/licenses/license-2.0
//
// unless required by applicable law or agreed to in writing, software
// distributed under the license is distributed on an "as is" basis,
// without warranties or conditions of any kind, either express or implied.
// see the license for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
using namespace std;
using namespace ngraph::runtime::gpu;
GPU_CallFrame::GPU_CallFrame(shared_ptr<GPU_ExternalFunction> external_function,
shared_ptr<Function> func)
: m_external_function(external_function)
, m_function(func)
{
}
void GPU_CallFrame::call(const vector<shared_ptr<Value>>& input_tvs,
const vector<shared_ptr<Value>>& output_tvs)
{
}
void GPU_CallFrame::tensor_call(const std::vector<std::shared_ptr<TensorView>>& inputs,
const std::vector<std::shared_ptr<TensorView>>& outputs)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_CallFrame;
class GPU_ExternalFunction;
using EntryPoint_t = void(void** inputs, void** outputs);
using EntryPoint = std::function<EntryPoint_t>;
// Compile and execute graphs
class GPU_CallFrame : public ngraph::runtime::CallFrame
{
public:
GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction> external_function,
std::shared_ptr<Function> func);
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying
/// tensor views.
void tensor_call(const std::vector<std::shared_ptr<TensorView>>& inputs,
const std::vector<std::shared_ptr<TensorView>>& outputs);
protected:
std::shared_ptr<GPU_ExternalFunction> m_external_function;
std::shared_ptr<Function> m_function;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <string>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
using namespace std;
using namespace ngraph::runtime::gpu;
using namespace ngraph;
ngraph::runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
const std::shared_ptr<ngraph::Function>& function, bool release_function)
: runtime::ExternalFunction(function, release_function)
, m_function(function)
{
}
void runtime::gpu::GPU_ExternalFunction::compile()
{
}
shared_ptr<runtime::CallFrame> runtime::gpu::GPU_ExternalFunction::make_call_frame()
{
if (!m_is_compiled)
{
compile();
}
return make_shared<runtime::gpu::GPU_CallFrame>(shared_from_this(), m_function);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_ExternalFunction : public ngraph::runtime::ExternalFunction,
public std::enable_shared_from_this<GPU_ExternalFunction>
{
public:
GPU_ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
protected:
void compile();
std::shared_ptr<ngraph::Function> m_function;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/gpu/gpu_manager.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
using namespace ngraph::runtime::gpu;
std::shared_ptr<ngraph::runtime::Backend> GPU_Manager::allocate_backend()
{
return std::make_shared<GPU_Backend>();
}
std::shared_ptr<ngraph::runtime::ExternalFunction>
GPU_Manager::compile(const std::shared_ptr<ngraph::Function>& fun)
{
return std::make_shared<GPU_ExternalFunction>(fun);
}
ngraph::runtime::Manager::Factory GPU_Manager::factory = ngraph::runtime::Manager::register_factory(
"GPU", [](const std::string& name) -> std::shared_ptr<ngraph::runtime::Manager> {
return std::make_shared<GPU_Manager>();
});
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/manager.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_Manager : public Manager
{
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
static Factory factory;
};
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
using namespace ngraph;
using namespace std;
runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& element_type,
const Shape& shape)
: runtime::TensorView(std::make_shared<ngraph::descriptor::PrimaryTensorView>(
std::make_shared<ngraph::TensorViewType>(element_type, shape),
"external",
true,
true,
false))
, m_allocated_buffer_pool(nullptr)
, m_aligned_buffer_pool(nullptr)
{
m_descriptor->set_tensor_view_layout(
std::make_shared<ngraph::descriptor::layout::DenseTensorViewLayout>(*m_descriptor));
m_buffer_size = m_descriptor->get_tensor_view_layout()->get_size() * element_type.size();
if (m_buffer_size > 0)
{
size_t allocation_size = m_buffer_size + runtime::gpu::alignment;
m_allocated_buffer_pool = static_cast<char*>(malloc(allocation_size));
m_aligned_buffer_pool = m_allocated_buffer_pool;
size_t mod = size_t(m_aligned_buffer_pool) % alignment;
if (mod != 0)
{
m_aligned_buffer_pool += (alignment - mod);
}
}
}
runtime::gpu::GPU_TensorView::~GPU_TensorView()
{
if (m_allocated_buffer_pool != nullptr)
{
free(m_allocated_buffer_pool);
}
}
char* runtime::gpu::GPU_TensorView::get_data_ptr()
{
return m_aligned_buffer_pool;
}
const char* runtime::gpu::GPU_TensorView::get_data_ptr() const
{
return m_aligned_buffer_pool;
}
void runtime::gpu::GPU_TensorView::write(const void* source, size_t tensor_offset, size_t n)
{
if (tensor_offset + n > m_buffer_size)
{
throw out_of_range("write access past end of tensor");
}
char* target = get_data_ptr();
}
void runtime::gpu::GPU_TensorView::read(void* target, size_t tensor_offset, size_t n) const
{
if (tensor_offset + n > m_buffer_size)
{
throw out_of_range("read access past end of tensor");
}
const char* source = get_data_ptr();
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_TensorView;
}
}
}
class ngraph::runtime::gpu::GPU_TensorView : public ngraph::runtime::TensorView
{
public:
GPU_TensorView(const ngraph::element::Type& element_type, const Shape& shape);
virtual ~GPU_TensorView();
char* get_data_ptr();
const char* get_data_ptr() const;
/// @brief Write bytes directly into the tensor
/// @param p Pointer to source of data
/// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// @param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n) override;
/// @brief Read bytes directly from the tensor
/// @param p Pointer to destination for data
/// @param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// @param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
private:
char* m_allocated_buffer_pool;
char* m_aligned_buffer_pool;
size_t m_buffer_size;
};
......@@ -72,9 +72,11 @@ if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
set(BACKEND_NAMES ${BACKEND_NAMES} "CPU")
endif()
if(NGRAPH_GPU_ENABLE)
if(NGRAPH_GPU_ENABLE AND LLVM_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR})
set(SRC
main.cpp
${SRC}
cudnn.cpp)
# Disabled for testing
# set(BACKEND_NAMES ${BACKEND_NAMES} "GPU")
......
......@@ -21,8 +21,244 @@
#include <cuda.h>
#include <cudnn.h>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/ngraph.hpp"
using namespace ngraph;
using namespace std;
TEST(cudnn, loadTest)
{
auto cudnn_version = cudnnGetVersion();
EXPECT_FLOAT_EQ(cudnn_version, CUDNN_VERSION);
}
TEST(cudnn, compileTest)
{
const auto source = R"###(
#include <cassert>
#include <fstream>
#include <iostream>
#include "cuda.h"
void check_cuda_errors(CUresult err) {
assert(err == CUDA_SUCCESS);
}
/// main - Program entry point
int main(int argc, char **argv) {
CUdevice device;
CUmodule cuda_module;
CUcontext context;
CUfunction function;
CUlinkState linker;
int dev_count;
// CUDA initialization
check_cuda_errors(cuInit(0));
check_cuda_errors(cuDeviceGetCount(&dev_count));
check_cuda_errors(cuDeviceGet(&device, 0));
char name[128];
check_cuda_errors(cuDeviceGetName(name, 128, device));
std::cout << "Using CUDA Device [0]: " << name << "\n";
int dev_major, dev_minor;
check_cuda_errors(cuDeviceComputeCapability(&dev_major, &dev_minor, device));
std::cout << "Device Compute Capability: "
<< dev_major << "." << dev_minor << "\n";
if (dev_major < 2) {
std::cerr << "ERROR: Device 0 is not SM 2.0 or greater\n";
return 1;
}
const auto str = R"(
.version 5.0
.target sm_60
.address_size 64
// .globl _Z7ew_multPfS_S_ // -- Begin function _Z7ew_multPfS_S_
.global .align 1 .b8 threadIdx[1];
// @_Z7ew_multPfS_S_
.visible .entry _Z7ew_multPfS_S_(
.param .u64 _Z7ew_multPfS_S__param_0,
.param .u64 _Z7ew_multPfS_S__param_1,
.param .u64 _Z7ew_multPfS_S__param_2
)
{
.local .align 8 .b8 __local_depot0[24];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .f32 %f<4>;
.reg .b32 %r<2>;
.reg .b64 %rd<17>;
// BB#0:
mov.u64 %SPL, __local_depot0;
cvta.local.u64 %SP, %SPL;
ld.param.u64 %rd3, [_Z7ew_multPfS_S__param_2];
ld.param.u64 %rd2, [_Z7ew_multPfS_S__param_1];
ld.param.u64 %rd1, [_Z7ew_multPfS_S__param_0];
cvta.to.global.u64 %rd4, %rd3;
cvta.global.u64 %rd5, %rd4;
cvta.to.global.u64 %rd6, %rd2;
cvta.global.u64 %rd7, %rd6;
cvta.to.global.u64 %rd8, %rd1;
cvta.global.u64 %rd9, %rd8;
st.u64 [%SP+0], %rd9;
st.u64 [%SP+8], %rd7;
st.u64 [%SP+16], %rd5;
ld.u64 %rd10, [%SP+0];
mov.u32 %r1, %tid.x;
mul.wide.u32 %rd11, %r1, 4;
add.s64 %rd12, %rd10, %rd11;
ld.f32 %f1, [%rd12];
ld.u64 %rd13, [%SP+8];
add.s64 %rd14, %rd13, %rd11;
ld.f32 %f2, [%rd14];
mul.rn.f32 %f3, %f1, %f2;
ld.u64 %rd15, [%SP+16];
add.s64 %rd16, %rd15, %rd11;
st.f32 [%rd16], %f3;
ret;
}
// -- End function
// .globl _Z6ew_addPfS_S_ // -- Begin function _Z6ew_addPfS_S_
.visible .entry _Z6ew_addPfS_S_(
.param .u64 _Z6ew_addPfS_S__param_0,
.param .u64 _Z6ew_addPfS_S__param_1,
.param .u64 _Z6ew_addPfS_S__param_2
) // @_Z6ew_addPfS_S_
{
.local .align 8 .b8 __local_depot1[24];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .f32 %f<4>;
.reg .b32 %r<2>;
.reg .b64 %rd<17>;
// BB#0:
mov.u64 %SPL, __local_depot1;
cvta.local.u64 %SP, %SPL;
ld.param.u64 %rd3, [_Z6ew_addPfS_S__param_2];
ld.param.u64 %rd2, [_Z6ew_addPfS_S__param_1];
ld.param.u64 %rd1, [_Z6ew_addPfS_S__param_0];
cvta.to.global.u64 %rd4, %rd3;
cvta.global.u64 %rd5, %rd4;
cvta.to.global.u64 %rd6, %rd2;
cvta.global.u64 %rd7, %rd6;
cvta.to.global.u64 %rd8, %rd1;
cvta.global.u64 %rd9, %rd8;
st.u64 [%SP+0], %rd9;
st.u64 [%SP+8], %rd7;
st.u64 [%SP+16], %rd5;
ld.u64 %rd10, [%SP+0];
mov.u32 %r1, %tid.x;
mul.wide.u32 %rd11, %r1, 4;
add.s64 %rd12, %rd10, %rd11;
ld.f32 %f1, [%rd12];
ld.u64 %rd13, [%SP+8];
add.s64 %rd14, %rd13, %rd11;
ld.f32 %f2, [%rd14];
add.rn.f32 %f3, %f1, %f2;
ld.u64 %rd15, [%SP+16];
add.s64 %rd16, %rd15, %rd11;
st.f32 [%rd16], %f3;
ret;
}
// -- End function
)";
// Create driver context
check_cuda_errors(cuCtxCreate(&context, 0, device));
// Create module for object
check_cuda_errors(cuModuleLoadDataEx(&cuda_module, str, 0, 0, 0));
// Get kernel function
check_cuda_errors(cuModuleGetFunction(&function, cuda_module, "_Z7ew_multPfS_S_"));
// Device data
CUdeviceptr dev_bufferA;
CUdeviceptr dev_bufferB;
CUdeviceptr dev_bufferC;
check_cuda_errors(cuMemAlloc(&dev_bufferA, sizeof(float)*16));
check_cuda_errors(cuMemAlloc(&dev_bufferB, sizeof(float)*16));
check_cuda_errors(cuMemAlloc(&dev_bufferC, sizeof(float)*16));
float* host_A = new float[16];
float* host_B = new float[16];
float* host_C = new float[16];
// Populate input
for (unsigned i = 0; i != 16; ++i) {
host_A[i] = (float)i;
host_B[i] = (float)(2*i);
host_C[i] = 0.0f;
}
check_cuda_errors(cuMemcpyHtoD(dev_bufferA, &host_A[0], sizeof(float)*16));
check_cuda_errors(cuMemcpyHtoD(dev_bufferB, &host_B[0], sizeof(float)*16));
unsigned block_size_X = 16;
unsigned block_size_Y = 1;
unsigned block_size_Z = 1;
unsigned grid_size_X = 1;
unsigned grid_size_Y = 1;
unsigned grid_size_Z = 1;
// Kernel parameters
void *kernel_params[] = { &dev_bufferA, &dev_bufferB, &dev_bufferC };
std::cout << "Launching kernel\n";
// Kernel launch
check_cuda_errors(cuLaunchKernel(function, grid_size_X, grid_size_Y, grid_size_Z,
block_size_X, block_size_Y, block_size_Z,
0, NULL, kernel_params, NULL));
// Retrieve device data
check_cuda_errors(cuMemcpyDtoH(&host_C[0], dev_bufferC, sizeof(float)*16));
std::cout << "Results:\n";
for (unsigned i = 0; i != 16; ++i) {
std::cout << host_A[i] << " + " << host_B[i] << " = " << host_C[i] << "\n";
}
// Clean up after ourselves
delete [] host_A;
delete [] host_B;
delete [] host_C;
// Clean-up
check_cuda_errors(cuMemFree(dev_bufferA));
check_cuda_errors(cuMemFree(dev_bufferB));
check_cuda_errors(cuMemFree(dev_bufferC));
check_cuda_errors(cuModuleUnload(cuda_module));
check_cuda_errors(cuCtxDestroy(context));
return 0;
})###";
codegen::Compiler compiler;
auto module = compiler.compile(source);
}
TEST(cudnn, abc)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("GPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment