Commit 1fcfbca7 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Rework utils

parent c36b1b10
...@@ -173,6 +173,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND ...@@ -173,6 +173,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/cpu_tensor_view_wrapper.cpp runtime/cpu/cpu_tensor_view_wrapper.cpp
runtime/cpu/cpu_layout_descriptor.cpp runtime/cpu/cpu_layout_descriptor.cpp
runtime/cpu/cpu_tracing.cpp runtime/cpu/cpu_tracing.cpp
runtime/cpu/mkldnn_emitter.cpp
runtime/cpu/mkldnn_invoke.cpp runtime/cpu/mkldnn_invoke.cpp
runtime/cpu/mkldnn_utils.cpp runtime/cpu/mkldnn_utils.cpp
runtime/cpu/ops/convert_layout.cpp runtime/cpu/ops/convert_layout.cpp
......
...@@ -1780,6 +1780,7 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1780,6 +1780,7 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
if (!filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 && if (!filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32)
{ {
#if 0
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string()); const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
writer << "{\n"; writer << "{\n";
...@@ -1811,6 +1812,10 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution) ...@@ -1811,6 +1812,10 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitConvolution)
<< "s.submit({conv}).wait();\n"; << "s.submit({conv}).wait();\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
#else
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc = mkldnn_emitter->build_memory_descriptor(args[0]);
#endif
} }
else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 && else if (filter_dilated && !data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32)
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "mkldnn_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace ngraph::runtime::cpu;
mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const TensorViewWrapper& tvw)
{
return mkldnn::memory::desc(mkldnn::memory::dims(tvw.get_shape().begin(), tvw.get_shape().end()),
mkldnn_utils::GetDataType(tvw.get_element_type()),
mkldnn::memory::format::nchw);
}
...@@ -26,6 +26,7 @@ namespace ngraph ...@@ -26,6 +26,7 @@ namespace ngraph
namespace cpu namespace cpu
{ {
class CPU_ExternalFunction; class CPU_ExternalFunction;
class TensorViewWrapper;
class MKLDNNEmitter class MKLDNNEmitter
{ {
...@@ -35,7 +36,8 @@ namespace ngraph ...@@ -35,7 +36,8 @@ namespace ngraph
{ {
} }
void build_memory_descriptor(); // TODO(jmenon): Get rid of TensorViewWrappers at some point
mkldnn::memory::desc build_memory_descriptor(const TensorViewWrapper& tvw);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include <string>
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -35,7 +37,7 @@ namespace ngraph ...@@ -35,7 +37,7 @@ namespace ngraph
{ {
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
const std::unordered_set<std::type_index> s_op_registry{ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::AvgPool), TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop), TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::Convolution), TI(ngraph::op::Convolution),
...@@ -43,6 +45,27 @@ namespace ngraph ...@@ -43,6 +45,27 @@ namespace ngraph
TI(ngraph::op::ConvolutionBackpropFilters), TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::MaxPool)}; TI(ngraph::op::MaxPool)};
static const std::unordered_map<std::string, const mkldnn::memory::data_type> s_data_type_map{
{"char", mkldnn::memory::data_type::s8},
{"float", mkldnn::memory::data_type::f32},
{"double", mkldnn::memory::data_type::data_undef},
{"int8_t", mkldnn::memory::data_type::s8},
{"int16_t", mkldnn::memory::data_type::s16},
{"int32_t", mkldnn::memory::data_type::s32},
{"int64_t", mkldnn::memory::data_type::data_undef},
{"uint8_t", mkldnn::memory::data_type::u8},
{"uint16_t", mkldnn::memory::data_type::data_undef},
{"uint32_t", mkldnn::memory::data_type::data_undef},
{"uint64_t", mkldnn::memory::data_type::data_undef}};
mkldnn::memory::data_type GetDataType(const ngraph::element::Type& et)
{
auto it = s_data_type_map.find(et.c_type_string());
if (it == s_data_type_map.end() || it->second == mkldnn::memory::data_type::data_undef)
throw ngraph_error("No MKLDNN data type exists for the given element type");
return it->second;
}
bool IsMKLDNNOp(ngraph::Node& op) bool IsMKLDNNOp(ngraph::Node& op)
{ {
return (s_op_registry.find(TI(op)) != s_op_registry.end()); return (s_op_registry.find(TI(op)) != s_op_registry.end());
......
...@@ -16,14 +16,11 @@ ...@@ -16,14 +16,11 @@
#pragma once #pragma once
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include <mkldnn.hpp> #include <mkldnn.hpp>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,6 +32,8 @@ namespace ngraph ...@@ -35,6 +32,8 @@ namespace ngraph
{ {
extern mkldnn::engine global_cpu_engine; extern mkldnn::engine global_cpu_engine;
mkldnn::memory::data_type GetDataType(const ngraph::element::Type& et);
bool IsMKLDNNOp(ngraph::Node& op); bool IsMKLDNNOp(ngraph::Node& op);
mkldnn::memory::format mkldnn::memory::format
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment