Commit 7a39c994 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'jmenon/cpu' into jmenon/cpu_kernels

Conflicts:
	src/ngraph/runtime/cpu/call_frame.cpp
	src/ngraph/runtime/cpu/call_frame.hpp
	src/ngraph/runtime/cpu/external_function.cpp
	test/cpu.cpp
parents 7e001000 1d45c18b
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
cmake_minimum_required (VERSION 2.8) cmake_minimum_required (VERSION 3.1)
set(NGRAPH_INCLUDE_PATH set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR}/src
......
...@@ -13,82 +13,49 @@ ...@@ -13,82 +13,49 @@
include(ExternalProject) include(ExternalProject)
find_package(LLVM CONFIG) if((NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") AND
(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows"))
set(LLVM_PACKAGED FALSE) message(STATUS "Fetching LLVM from llvm.org")
set(LLVM_RELEASE_URL http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu16.04.tar.xz)
if(LLVM_FOUND) set(LLVM_SHA1_HASH 9cb81c92aa4d3f9707a9b8413c4d24b8dee90c59)
if(${LLVM_PACKAGE_VERSION} VERSION_GREATER "4.0")
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") # Override default LLVM binaries
set(LLVM_PACKAGED TRUE) if(PREBUILT_LLVM)
else() if(NOT DEFINED PREBUILT_LLVM_HASH)
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION} but need atleast 4.0") message(FATAL_ERROR "SHA1 hash of prebuilt llvm tarball not provided in PREBUILT_LLVM_HASH.")
set(LLVM_FOUND FALSE)
endif()
endif()
if(NOT LLVM_FOUND)
if((NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") AND
(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows"))
message(STATUS "Fetching LLVM from llvm.org")
set(LLVM_RELEASE_URL http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu16.04.tar.xz)
# Override default LLVM binaries
if(PREBUILT_LLVM)
set(LLVM_RELEASE_URL ${PREBUILT_LLVM})
endif() endif()
set(LLVM_RELEASE_URL ${PREBUILT_LLVM})
set(LLVM_SHA1_HASH ${PREBUILT_LLVM_HASH})
endif()
# The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2. # The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2.
if(${CMAKE_VERSION} VERSION_LESS 3.2) if(${CMAKE_VERSION} VERSION_LESS 3.2)
ExternalProject_Add( ExternalProject_Add(
ext_llvm ext_llvm
URL ${LLVM_RELEASE_URL} URL ${LLVM_RELEASE_URL}
CONFIGURE_COMMAND "" URL_HASH SHA1=${LLVM_SHA1_HASH}
BUILD_COMMAND "" CONFIGURE_COMMAND ""
INSTALL_COMMAND "" BUILD_COMMAND ""
UPDATE_COMMAND "" INSTALL_COMMAND ""
) UPDATE_COMMAND ""
else() )
ExternalProject_Add( else()
ext_llvm ExternalProject_Add(
URL ${LLVM_RELEASE_URL} ext_llvm
CONFIGURE_COMMAND "" URL ${LLVM_RELEASE_URL}
BUILD_COMMAND "" URL_HASH SHA1=${LLVM_SHA1_HASH}
INSTALL_COMMAND "" CONFIGURE_COMMAND ""
UPDATE_COMMAND "" BUILD_COMMAND ""
BUILD_BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/ext_llvm-prefix/src/ext_llvm/lib/libLLVMCore.a" INSTALL_COMMAND ""
) UPDATE_COMMAND ""
endif() BUILD_BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/ext_llvm-prefix/src/ext_llvm/lib/libLLVMCore.a"
)
ExternalProject_Get_Property(ext_llvm source_dir)
set(LLVM_INCLUDE_DIR "${source_dir}/include")
set(LLVM_LIB_DIR "${source_dir}/lib")
set(LLVM_LINK_LIBS clangTooling clangFrontendTool clangFrontend clangDriver clangSerialization clangCodeGen clangParse clangSema clangStaticAnalyzerFrontend clangStaticAnalyzerCheckers clangStaticAnalyzerCore clangAnalysis clangARCMigrate clangRewriteFrontend clangEdit clangAST clangLex clangBasic LLVMLTO LLVMPasses LLVMObjCARCOpts LLVMSymbolize LLVMDebugInfoPDB LLVMDebugInfoDWARF LLVMMIRParser LLVMCoverage LLVMTableGen LLVMDlltoolDriver LLVMOrcJIT LLVMXCoreDisassembler LLVMXCoreCodeGen LLVMXCoreDesc LLVMXCoreInfo LLVMXCoreAsmPrinter LLVMSystemZDisassembler LLVMSystemZCodeGen LLVMSystemZAsmParser LLVMSystemZDesc LLVMSystemZInfo LLVMSystemZAsmPrinter LLVMSparcDisassembler LLVMSparcCodeGen LLVMSparcAsmParser LLVMSparcDesc LLVMSparcInfo LLVMSparcAsmPrinter LLVMPowerPCDisassembler LLVMPowerPCCodeGen LLVMPowerPCAsmParser LLVMPowerPCDesc LLVMPowerPCInfo LLVMPowerPCAsmPrinter LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo LLVMNVPTXAsmPrinter LLVMMSP430CodeGen LLVMMSP430Desc LLVMMSP430Info LLVMMSP430AsmPrinter LLVMMipsDisassembler LLVMMipsCodeGen LLVMMipsAsmParser LLVMMipsDesc LLVMMipsInfo LLVMMipsAsmPrinter LLVMLanaiDisassembler LLVMLanaiCodeGen LLVMLanaiAsmParser LLVMLanaiDesc LLVMLanaiAsmPrinter LLVMLanaiInfo LLVMHexagonDisassembler LLVMHexagonCodeGen LLVMHexagonAsmParser LLVMHexagonDesc LLVMHexagonInfo LLVMBPFDisassembler LLVMBPFCodeGen LLVMBPFDesc LLVMBPFInfo LLVMBPFAsmPrinter LLVMARMDisassembler LLVMARMCodeGen LLVMARMAsmParser LLVMARMDesc LLVMARMInfo LLVMARMAsmPrinter LLVMAMDGPUDisassembler LLVMAMDGPUCodeGen LLVMAMDGPUAsmParser LLVMAMDGPUDesc LLVMAMDGPUInfo LLVMAMDGPUAsmPrinter LLVMAMDGPUUtils LLVMAArch64Disassembler LLVMAArch64CodeGen LLVMAArch64AsmParser LLVMAArch64Desc LLVMAArch64Info LLVMAArch64AsmPrinter LLVMAArch64Utils LLVMObjectYAML LLVMLibDriver LLVMOption LLVMX86Disassembler LLVMX86AsmParser LLVMX86CodeGen LLVMGlobalISel LLVMSelectionDAG LLVMAsmPrinter LLVMDebugInfoCodeView LLVMDebugInfoMSF LLVMX86Desc LLVMMCDisassembler LLVMX86Info LLVMX86AsmPrinter LLVMX86Utils LLVMMCJIT LLVMLineEditor LLVMInterpreter LLVMExecutionEngine LLVMRuntimeDyld LLVMCodeGen LLVMTarget LLVMCoroutines LLVMipo LLVMInstrumentation LLVMVectorize LLVMScalarOpts LLVMLinker LLVMIRReader LLVMAsmParser LLVMInstCombine LLVMTransformUtils LLVMBitWriter LLVMAnalysis LLVMProfileData LLVMObject LLVMMCParser LLVMMC LLVMBitReader LLVMCore LLVMBinaryFormat LLVMSupport LLVMDemangle tinfo z m)
set(LLVM_FOUND TRUE)
set(Clang_FOUND TRUE)
endif() endif()
endif()
if(LLVM_FOUND AND NOT Clang_FOUND) ExternalProject_Get_Property(ext_llvm source_dir)
find_package(Clang CONFIG) set(LLVM_INCLUDE_DIR "${source_dir}/include" PARENT_SCOPE)
endif() set(LLVM_LIB_DIR "${source_dir}/lib" PARENT_SCOPE)
# TODO: Figure out if this terminates the build or do we allow interpretation-only builds set(LLVM_LINK_LIBS clangTooling clangFrontendTool clangFrontend clangDriver clangSerialization clangCodeGen clangParse clangSema clangStaticAnalyzerFrontend clangStaticAnalyzerCheckers clangStaticAnalyzerCore clangAnalysis clangARCMigrate clangRewriteFrontend clangEdit clangAST clangLex clangBasic LLVMLTO LLVMPasses LLVMObjCARCOpts LLVMSymbolize LLVMDebugInfoPDB LLVMDebugInfoDWARF LLVMMIRParser LLVMCoverage LLVMTableGen LLVMDlltoolDriver LLVMOrcJIT LLVMXCoreDisassembler LLVMXCoreCodeGen LLVMXCoreDesc LLVMXCoreInfo LLVMXCoreAsmPrinter LLVMSystemZDisassembler LLVMSystemZCodeGen LLVMSystemZAsmParser LLVMSystemZDesc LLVMSystemZInfo LLVMSystemZAsmPrinter LLVMSparcDisassembler LLVMSparcCodeGen LLVMSparcAsmParser LLVMSparcDesc LLVMSparcInfo LLVMSparcAsmPrinter LLVMPowerPCDisassembler LLVMPowerPCCodeGen LLVMPowerPCAsmParser LLVMPowerPCDesc LLVMPowerPCInfo LLVMPowerPCAsmPrinter LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo LLVMNVPTXAsmPrinter LLVMMSP430CodeGen LLVMMSP430Desc LLVMMSP430Info LLVMMSP430AsmPrinter LLVMMipsDisassembler LLVMMipsCodeGen LLVMMipsAsmParser LLVMMipsDesc LLVMMipsInfo LLVMMipsAsmPrinter LLVMLanaiDisassembler LLVMLanaiCodeGen LLVMLanaiAsmParser LLVMLanaiDesc LLVMLanaiAsmPrinter LLVMLanaiInfo LLVMHexagonDisassembler LLVMHexagonCodeGen LLVMHexagonAsmParser LLVMHexagonDesc LLVMHexagonInfo LLVMBPFDisassembler LLVMBPFCodeGen LLVMBPFDesc LLVMBPFInfo LLVMBPFAsmPrinter LLVMARMDisassembler LLVMARMCodeGen LLVMARMAsmParser LLVMARMDesc LLVMARMInfo LLVMARMAsmPrinter LLVMAMDGPUDisassembler LLVMAMDGPUCodeGen LLVMAMDGPUAsmParser LLVMAMDGPUDesc LLVMAMDGPUInfo LLVMAMDGPUAsmPrinter LLVMAMDGPUUtils LLVMAArch64Disassembler LLVMAArch64CodeGen LLVMAArch64AsmParser LLVMAArch64Desc LLVMAArch64Info LLVMAArch64AsmPrinter LLVMAArch64Utils LLVMObjectYAML LLVMLibDriver LLVMOption LLVMX86Disassembler LLVMX86AsmParser LLVMX86CodeGen LLVMGlobalISel LLVMSelectionDAG LLVMAsmPrinter LLVMDebugInfoCodeView LLVMDebugInfoMSF LLVMX86Desc LLVMMCDisassembler LLVMX86Info LLVMX86AsmPrinter LLVMX86Utils LLVMMCJIT LLVMLineEditor LLVMInterpreter LLVMExecutionEngine LLVMRuntimeDyld LLVMCodeGen LLVMTarget LLVMCoroutines LLVMipo LLVMInstrumentation LLVMVectorize LLVMScalarOpts LLVMLinker LLVMIRReader LLVMAsmParser LLVMInstCombine LLVMTransformUtils LLVMBitWriter LLVMAnalysis LLVMProfileData LLVMObject LLVMMCParser LLVMMC LLVMBitReader LLVMCore LLVMBinaryFormat LLVMSupport LLVMDemangle tinfo z m PARENT_SCOPE)
#if(NOT LLVM_FOUND OR NOT Clang_FOUND)
#endif()
# Populate header and library paths from package-exported info
# if we found system-level LLVM and Clang packages
if(LLVM_FOUND AND NOT LLVM_INCLUDE_DIR)
set(LLVM_INCLUDE_DIR ${LLVM_INCLUDE_DIRS})
set(LLVM_LIB_DIR ${LLVM_LIBRARY_DIRS})
llvm_map_components_to_libnames(llvm_libs support core engine)
set(LLVM_LINK_LIBS ${llvm_libs})
endif() endif()
# Export all necessary info
set(LLVM_INCLUDE_DIR ${LLVM_INCLUDE_DIR} PARENT_SCOPE)
set(LLVM_LIB_DIR ${LLVM_LIB_DIR} PARENT_SCOPE)
set(LLVM_LINK_LIBS ${LLVM_LINK_LIBS} PARENT_SCOPE)
set(LLVM_PACKAGED ${LLVM_PACKAGED} PARENT_SCOPE)
...@@ -33,7 +33,7 @@ void Reshape::propagate_types() ...@@ -33,7 +33,7 @@ void Reshape::propagate_types()
throw ngraph_error("Argument to reshape is missing type."); throw ngraph_error("Argument to reshape is missing type.");
} }
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type); auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_type) if (nullptr == arg_tensor_view_type)
{ {
throw ngraph_error("Argument to reshape is not a tensor view"); throw ngraph_error("Argument to reshape is not a tensor view");
} }
......
...@@ -20,32 +20,31 @@ using namespace std; ...@@ -20,32 +20,31 @@ using namespace std;
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(EntryPoint compiled_function, CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs, size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps, const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees) const std::vector<std::shared_ptr<CallFrame>>& callees)
: m_n_outputs(n_outputs)
: m_n_inputs(n_inputs) , m_n_inputs(n_inputs)
, m_n_outputs(n_outputs) , m_tensor_views(n_outputs + n_inputs + temps.size())
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
, m_callees(callees) , m_callees(callees)
{ {
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs); copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_outputs + m_n_inputs);
} }
void CallFrame::tensor_call( void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs, const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs) const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{ {
copy(inputs.begin(), inputs.end(), m_tensor_views.begin()); copy(outputs.begin(), outputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs); copy(inputs.begin(), inputs.end(), m_tensor_views.begin() + m_n_outputs);
// Invoke compiled computation // Invoke compiled computation
m_compiled_function(this, m_tensor_views, m_callees); m_compiled_function(this, m_tensor_views, m_callees);
// Don't hold onto inputs/outputs // Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr); fill_n(m_tensor_views.begin(), m_n_outputs + m_n_inputs, nullptr);
} }
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments, void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
......
...@@ -41,8 +41,8 @@ namespace ngraph ...@@ -41,8 +41,8 @@ namespace ngraph
{ {
public: public:
CallFrame(EntryPoint compiled_function, CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs, size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps, const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees); const std::vector<std::shared_ptr<CallFrame>>& callees);
...@@ -71,8 +71,8 @@ namespace ngraph ...@@ -71,8 +71,8 @@ namespace ngraph
} }
protected: protected:
size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
size_t m_n_inputs;
TensorViewPtrs m_tensor_views; TensorViewPtrs m_tensor_views;
bool m_return; bool m_return;
EntryPoint m_compiled_function; EntryPoint m_compiled_function;
......
...@@ -153,7 +153,17 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -153,7 +153,17 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Determine tensor requirements for the call frame // Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index; unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
// First come the function inputs
// First come the function outputs
for (const descriptor::Output& output : m_function->get_result()->get_outputs())
{
auto tv = output.get_tensor_view();
size_t index = tensor_index.size();
tensor_index[tv] = index;
}
m_n_outputs = tensor_index.size();
// Next are the function inputs
for (auto param : m_function->get_parameters()) for (auto param : m_function->get_parameters())
{ {
for (const descriptor::Output& output : param->get_outputs()) for (const descriptor::Output& output : param->get_outputs())
...@@ -163,16 +173,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -163,16 +173,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
tensor_index[tv] = index; tensor_index[tv] = index;
} }
} }
m_n_inputs = tensor_index.size(); m_n_inputs = tensor_index.size() - m_n_outputs;
// Next are the function outputs
for (const descriptor::Output& output : m_function->get_result()->get_outputs())
{
auto tv = output.get_tensor_view();
size_t index = tensor_index.size();
tensor_index[tv] = index;
}
m_n_outputs = tensor_index.size() - m_n_inputs;
// All remaining tensor views // All remaining tensor views
for (shared_ptr<Node> node : m_function->get_ordered_ops()) for (shared_ptr<Node> node : m_function->get_ordered_ops())
...@@ -344,5 +345,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame() ...@@ -344,5 +345,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
#undef M #undef M
} }
return make_shared<ngraph::runtime::cpu::CallFrame>( return make_shared<ngraph::runtime::cpu::CallFrame>(
compiled_function, m_n_inputs, m_n_outputs, temps, callees); compiled_function, m_n_outputs, m_n_inputs, temps, callees);
} }
...@@ -24,7 +24,6 @@ include_directories( ...@@ -24,7 +24,6 @@ include_directories(
set (SRC set (SRC
build_graph.cpp build_graph.cpp
eigen.cpp eigen.cpp
execute.cpp
input_output_assign.cpp input_output_assign.cpp
main.cpp main.cpp
op.cpp op.cpp
...@@ -42,6 +41,14 @@ set (SRC ...@@ -42,6 +41,14 @@ set (SRC
uuid.cpp uuid.cpp
) )
#================================================================================================
# To auto generate a suite of unit tests for a backend add a line like this
# set(BACKEND_NAMES ${BACKEND_NAMES} "BACKEND_NAME_GOES_HERE")
# and replace BACKEND_NAME_GOES_HERE with your backend name.
# The code for the unit test suite is in test/backend_test.in.cpp
#================================================================================================
set(BACKEND_NAMES ${BACKEND_NAMES} "NGVM")
if(MKLDNN_INCLUDE_DIR) if(MKLDNN_INCLUDE_DIR)
include_directories(SYSTEM ${MKLDNN_INCLUDE_DIR}) include_directories(SYSTEM ${MKLDNN_INCLUDE_DIR})
link_directories(${MKLDNN_LIB_DIR}) link_directories(${MKLDNN_LIB_DIR})
...@@ -50,9 +57,16 @@ endif() ...@@ -50,9 +57,16 @@ endif()
if(LLVM_INCLUDE_DIR) if(LLVM_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR}) include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
set(SRC ${SRC} codegen.cpp cpu.cpp) set(SRC ${SRC} codegen.cpp)
set(BACKEND_NAMES ${BACKEND_NAMES} "CPU")
endif() endif()
foreach(BACKEND_NAME ${BACKEND_NAMES})
configure_file(backend_test.in.cpp backend_test_${BACKEND_NAME}.cpp)
set(SRC ${SRC} ${CMAKE_CURRENT_BINARY_DIR}/backend_test_${BACKEND_NAME}.cpp)
message(STATUS "Adding unit test for backend ${BACKEND_NAME}")
endforeach()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCURDIR=\\\"${CMAKE_CURRENT_SOURCE_DIR}\\\"") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCURDIR=\\\"${CMAKE_CURRENT_SOURCE_DIR}\\\"")
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment