Commit 8af5a886 authored by Tristan Webb's avatar Tristan Webb

Changes from review

parent 17ac2f3e
...@@ -92,6 +92,24 @@ if (NOT DEFINED NGRAPH_CPU_ENABLE) ...@@ -92,6 +92,24 @@ if (NOT DEFINED NGRAPH_CPU_ENABLE)
SET(NGRAPH_CPU_ENABLE TRUE) SET(NGRAPH_CPU_ENABLE TRUE)
endif() endif()
#-----------------------------------------------------------------------------------------------
# GPU support
#-----------------------------------------------------------------------------------------------
# Setup CUDA and cuDNN if NGRAPH_GPU_ENABLE=TRUE
if(NGRAPH_GPU_ENABLE)
find_package(CUDA 8 REQUIRED)
find_package(CUDNN 5 QUIET REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR} ${LLVM_INCLUDE_DIR})
if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND
NOT CMAKE_C_COMPILER_VERSION VERSION_LESS 6.0 AND
CUDA_HOST_COMPILER STREQUAL CMAKE_C_COMPILER)
message(FATAL_ERROR
"CUDA 8.0 is not compatible with GCC version >= 6.\n"
"Please select a correct compiler version\n"
)
endif()
endif()
#----------------------------------------------------------------------------------------------- #-----------------------------------------------------------------------------------------------
# External projects install directory # External projects install directory
#----------------------------------------------------------------------------------------------- #-----------------------------------------------------------------------------------------------
......
...@@ -79,35 +79,6 @@ set (SRC ...@@ -79,35 +79,6 @@ set (SRC
util.cpp util.cpp
) )
if(USE_CUDA)
find_package(CUDA 8 REQUIRED)
if(CUDA_FOUND)
include_directories(SYSTEM ${CUDA_INCLUDE_DIR})
link_directories(${CUDA_LIB_DIR})
if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND
NOT CMAKE_C_COMPILER_VERSION VERSION_LESS 6.0 AND
CUDA_HOST_COMPILER STREQUAL CMAKE_C_COMPILER)
message(FATAL_ERROR
"CUDA 8.0 is not compatible with GCC version >= 6.\n"
"Please select a correct compiler version\n"
)
endif()
find_package(CUDNN 5 QUIET)
if(NOT CUDNN_FOUND)
message(WARNING
"If cuDNN is installed, try setting -DCUDNN_ROOT_DIR"
"Not compiling with CUDA. Suppress this warning with -DUSE_CUDA=OFF")
set(USE_CUDA OFF)
else()
include_directories(SYSTEM ${CUDNN_INCLUDE_DIR} ${LLVM_INCLUDE_DIR})
link_directories(${CUDNN_LIB_DIR} ${LLVM_LIB_DIR})
endif()
else()
message(WARNING "Not compiling with CUDA. Suppress this warning with -DUSE_CUDA=OFF")
set(USE_CUDA OFF)
endif()
endif()
# find_program (GRAPHVIZ dot) # find_program (GRAPHVIZ dot)
# message (STATUS "graphviz '${GRAPHVIZ}'") # message (STATUS "graphviz '${GRAPHVIZ}'")
find_package(Graphviz) find_package(Graphviz)
...@@ -191,7 +162,7 @@ message(STATUS "LIBRARY_OUTPUT_DIRECTORY set to: ${COMMON_LIBRARY_OUTPUT_DIRECTO ...@@ -191,7 +162,7 @@ message(STATUS "LIBRARY_OUTPUT_DIRECTORY set to: ${COMMON_LIBRARY_OUTPUT_DIRECTO
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}") target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
if((NGRAPH_CPU_ENABLE OR USE_CUDA) AND LLVM_LINK_LIBS) if((NGRAPH_CPU_ENABLE OR NGRAPH_GPU_ENABLE) AND LLVM_LINK_LIBS)
target_link_libraries(ngraph PRIVATE ${LLVM_LINK_LIBS}) target_link_libraries(ngraph PRIVATE ${LLVM_LINK_LIBS})
endif() endif()
...@@ -205,6 +176,10 @@ if(NGRAPH_CPU_ENABLE AND MKLDNN_LIB_DIR) ...@@ -205,6 +176,10 @@ if(NGRAPH_CPU_ENABLE AND MKLDNN_LIB_DIR)
target_link_libraries(ngraph PRIVATE mkldnn) target_link_libraries(ngraph PRIVATE mkldnn)
endif() endif()
if(NGRAPH_GPU_ENABLE AND CUDA_LIBRARIES)
target_link_libraries(ngraph PRIVATE ${CUDA_LIBRARIES} ${CUDNN_LIBRARIES})
endif()
#----------------------------------------------------------------------------------------------- #-----------------------------------------------------------------------------------------------
# Installation logic... # Installation logic...
...@@ -256,7 +231,7 @@ endif() ...@@ -256,7 +231,7 @@ endif()
add_dependencies(ngraph eigen) add_dependencies(ngraph eigen)
if((NGRAPH_CPU_ENABLE OR USE_CUDA) AND LLVM_INCLUDE_DIR) if((NGRAPH_CPU_ENABLE OR NGRAPH_GPU_ENABLE) AND LLVM_INCLUDE_DIR)
add_dependencies(ngraph ext_llvm) add_dependencies(ngraph ext_llvm)
endif() endif()
......
...@@ -64,12 +64,12 @@ if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR) ...@@ -64,12 +64,12 @@ if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
set(BACKEND_NAMES ${BACKEND_NAMES} "CPU") set(BACKEND_NAMES ${BACKEND_NAMES} "CPU")
endif() endif()
if(USE_CUDA) if(NGRAPH_GPU_ENABLE)
find_package(CUDA QUIET) set(SRC
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR} ${LLVM_INCLUDE_DIR}) main.cpp
link_directories(${LLVM_LIB_DIR}) cudnn.cpp)
set(SRC main.cpp cudnn.cpp) # Disabled for testing
set(BACKEND_NAMES) # set(BACKEND_NAMES ${BACKEND_NAMES} "GPU")
endif() endif()
foreach(BACKEND_NAME ${BACKEND_NAMES}) foreach(BACKEND_NAME ${BACKEND_NAMES})
...@@ -93,9 +93,8 @@ if(LLVM_INCLUDE_DIR) ...@@ -93,9 +93,8 @@ if(LLVM_INCLUDE_DIR)
add_dependencies(unit-test ext_llvm) add_dependencies(unit-test ext_llvm)
endif() endif()
if(CUDNN_INCLUDE_DIR) if(CUDA_INCLUDE_DIRS)
target_link_libraries(unit-test ${LLVM_LINK_LIBS}) target_link_libraries(unit-test ${CUDA_LIBRARIES} ${CUDNN_LIBRARIES})
target_link_libraries(unit-test cudnn)
endif() endif()
target_link_libraries(unit-test ngraph libgtest pthread) target_link_libraries(unit-test ngraph libgtest pthread)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment