Commit 40036d54 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Map Dot2D to MKL SGEMM

parent a48a75d5
......@@ -98,16 +98,17 @@ include_directories(
"${EIGEN_INCLUDE_DIR}"
)
if(LLVM_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR})
if(LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR} ${MKLDNN_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR} ${MKLDNN_LIB_DIR})
# Add sources for the CPU backend
# and all its dependencies
set(SRC ${SRC}
codegen/compiler.cpp
runtime/cpu/call_frame.cpp
runtime/cpu/cpu_manager.cpp
runtime/cpu/cpu_backend.cpp
runtime/cpu/cpu_manager.cpp
runtime/cpu/cpu_kernels.cpp
runtime/cpu/emitter.cpp
runtime/cpu/external_function.cpp
)
......@@ -132,8 +133,10 @@ if (APPLE)
set_property(TARGET ngraph PROPERTY PREFIX "lib")
set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
set_property(TARGET ngraph PROPERTY SUFFIX "")
else()
include_directories("${MKLDNN_INCLUDE_DIR}")
endif()
if(MKLDNN_LIB_DIR)
target_link_libraries(ngraph LINK_PRIVATE mkldnn)
endif()
#-----------------------------------------------------------------------------------------------
......@@ -194,3 +197,7 @@ add_dependencies(ngraph eigen)
if(NOT LLVM_PACKAGED AND LLVM_INCLUDE_DIR)
add_dependencies(ngraph ext_llvm)
endif()
if(MKLDNN_INCLUDE_DIR)
add_dependencies(ngraph ext_mkldnn)
endif()
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/types/element_type.hpp"
// CBLAS types and wrappers
namespace cblas
{
enum class Layout
{
RowMajor = 101,
ColMajor = 102
};
enum class Transpose
{
None = 111,
Transpose = 112,
ConjTrans = 113
};
enum class UpperLower
{
Upper = 121,
Lower = 122
};
enum class Diag
{
NonUnit = 131,
Unit = 132
};
enum class Side
{
Left = 141,
Right = 142
};
enum class Storage
{
Packed = 151
};
enum class Ident
{
AMatrix = 161,
BMatrix = 162
};
enum class Offset
{
RowOffset = 171,
ColOffset = 172,
FixOffset = 173
};
extern "C" {
void cblas_sgemm(const Layout layout,
const Transpose TransA,
const Transpose TransB,
const ngraph::element::Int64::type M,
const ngraph::element::Int64::type N,
const ngraph::element::Int64::type K,
const ngraph::element::Float32::type alpha,
const ngraph::element::Float32::type* A,
const ngraph::element::Int64::type lda,
const ngraph::element::Float32::type* B,
const ngraph::element::Int64::type ldb,
const ngraph::element::Float32::type beta,
ngraph::element::Float32::type* C,
const ngraph::element::Int64::type ldc);
}
}
......@@ -17,6 +17,7 @@
#include <typeindex>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/node.hpp"
......@@ -177,6 +178,28 @@ void Emitter::EMITTER_DECL(EmitDot)
auto arg1_layout = inputs[1].get_layout<DenseTensorViewLayout>();
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
// Emit an MKL SGEMM call if possible
if (arg0_element_type == ngraph::element::Float32::element_type())
{
TU +=
" {\n"
" auto arg0 = call_frame->get_tensor_view_data<" +
element_type_names[TI(arg0_element_type)] + ">(" + to_string(inputs[0].get_index()) +
");\n"
" auto arg1 = call_frame->get_tensor_view_data<" +
element_type_names[TI(arg0_element_type)] + ">(" + to_string(inputs[1].get_index()) +
");\n"
" auto out = call_frame->get_tensor_view_data<" +
element_type_names[TI(arg0_element_type)] + ">(" + to_string(outputs[0].get_index()) +
");\n"
" cblas::cblas_sgemm(cblas::Layout::RowMajor, cblas::Transpose::None, cblas::Transpose::None, " +
to_string(arg0_shape[0]) + ", " + to_string(arg1_shape[1]) + ", " + to_string(arg0_shape[1]) + ",\n"
" 1.0f, arg0, " + to_string(max(1UL, arg0_shape[1])) + ", arg1, " + to_string(max(1UL, arg1_shape[1])) + ", 0.0f,\n"
" out, " + to_string(max(1UL, arg1_shape[1])) + ");\n"
" }\n";
}
else
{
TU +=
" {\n"
" auto arg0 = call_frame->get_tensor_view_data<" +
......@@ -202,6 +225,7 @@ void Emitter::EMITTER_DECL(EmitDot)
");\n"
" }\n";
}
}
else
{
throw ngraph_error("Dot product not implemented for given inputs");
......
......@@ -195,9 +195,10 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
void *__dso_handle = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment