Commit 3289bcbb authored by gaurides's avatar gaurides Committed by Robert Kimball

Gauri/static initialization problem r08 cherry-pick from #1691 (#1702)

parent 8caa2717
......@@ -42,7 +42,9 @@ using namespace std;
#define TI(x) std::type_index(typeid(x))
static const std::unordered_set<std::type_index> s_op_registry{
std::unordered_set<std::type_index>& runtime::cpu::mkldnn_utils::get_op_registry()
{
static std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Add),
TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop),
......@@ -60,9 +62,14 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Relu),
TI(ngraph::op::ReluBackprop),
TI(ngraph::op::Reshape)};
return s_op_registry;
}
// Mapping from POD types to MKLDNN data types
static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map{
std::map<element::Type, const mkldnn::memory::data_type>&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_map()
{
// Mapping from POD types to MKLDNN data types
static std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map = {
{element::boolean, mkldnn::memory::data_type::s8},
{element::f32, mkldnn::memory::data_type::f32},
{element::f64, mkldnn::memory::data_type::data_undef},
......@@ -73,9 +80,15 @@ static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_d
{element::u8, mkldnn::memory::data_type::u8},
{element::u16, mkldnn::memory::data_type::data_undef},
{element::u32, mkldnn::memory::data_type::data_undef},
{element::u64, mkldnn::memory::data_type::data_undef}};
{element::u64, mkldnn::memory::data_type::data_undef},
};
return s_mkldnn_data_type_map;
}
static const std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
std::map<element::Type, const std::string>&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string_map()
{
static std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
{element::boolean, "mkldnn::memory::data_type::s8"},
{element::f32, "mkldnn::memory::data_type::f32"},
{element::f64, "mkldnn::memory::data_type::data_undef"},
......@@ -87,9 +100,14 @@ static const std::map<element::Type, const std::string> s_mkldnn_data_type_strin
{element::u16, "mkldnn::memory::data_type::data_undef"},
{element::u32, "mkldnn::memory::data_type::data_undef"},
{element::u64, "mkldnn::memory::data_type::data_undef"}};
return s_mkldnn_data_type_string_map;
}
// TODO (jbobba): Add the rest of memory formats to this map as well
static const std::map<memory::format, const std::string> s_mkldnn_format_string_map{
std::map<memory::format, const std::string>&
runtime::cpu::mkldnn_utils::get_mkldnn_format_string_map()
{
// TODO (jbobba): Add the rest of memory formats to this map as well
static std::map<memory::format, const std::string> s_mkldnn_format_string_map{
{memory::format::format_undef, "memory::format::format_undef"},
{memory::format::any, "memory::format::any"},
{memory::format::blocked, "memory::format::blocked"},
......@@ -131,11 +149,13 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
{memory::format::ldgo, "memory::format::Goihw8g"},
{memory::format::ldgo, "memory::format::Goihw16g"},
};
};
return s_mkldnn_format_string_map;
}
static const std::set<memory::format> s_filter_formats{
std::set<memory::format>& runtime::cpu::mkldnn_utils::get_filter_formats()
{
static std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
......@@ -158,10 +178,11 @@ static const std::set<memory::format> s_filter_formats{
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
return s_filter_formats;
}
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{
return (s_op_registry.find(TI(op)) != s_op_registry.end());
return (get_op_registry().find(TI(op)) != get_op_registry().end());
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
......@@ -185,27 +206,31 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(const
const std::string&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_string_map.find(type);
if (it == s_mkldnn_data_type_string_map.end() || it->second.empty())
throw ngraph_error("No MKLDNN data type exists for the given element type");
auto it = get_mkldnn_data_type_string_map().find(type);
if (it == get_mkldnn_data_type_string_map().end() || it->second.empty())
{
throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
}
return it->second;
}
mkldnn::memory::data_type
runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_map.find(type);
if (it == s_mkldnn_data_type_map.end())
auto it = get_mkldnn_data_type_map().find(type);
if (it == get_mkldnn_data_type_map().end())
{
throw ngraph_error("No MKLDNN data type exists for the given element type");
throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
}
return it->second;
}
const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::format fmt)
{
auto it = s_mkldnn_format_string_map.find(fmt);
if (it == s_mkldnn_format_string_map.end())
auto it = get_mkldnn_format_string_map().find(fmt);
if (it == get_mkldnn_format_string_map().end())
throw ngraph_error("No MKLDNN format exists for the given format type " +
std::to_string(fmt));
return it->second;
......@@ -252,12 +277,13 @@ bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims,
const Strides& strides,
const ngraph::element::Type type)
{
auto it = s_mkldnn_data_type_map.find(type);
auto it = get_mkldnn_data_type_map().find(type);
if (dims.size() == 0)
{
return false;
}
if (it == s_mkldnn_data_type_map.end() || it->second == mkldnn::memory::data_type::data_undef)
if (it == get_mkldnn_data_type_map().end() ||
it->second == mkldnn::memory::data_type::data_undef)
{
return false;
}
......@@ -452,7 +478,7 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc&
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{
if (s_filter_formats.find(fmt) != s_filter_formats.end())
if (get_filter_formats().find(fmt) != get_filter_formats().end())
{
return true;
}
......
......@@ -62,6 +62,12 @@ namespace ngraph
const mkldnn::memory::desc& rhs);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
std::unordered_set<std::type_index>& get_op_registry();
std::map<element::Type, const mkldnn::memory::data_type>&
get_mkldnn_data_type_map();
std::map<element::Type, const std::string>& get_mkldnn_data_type_string_map();
std::map<mkldnn::memory::format, const std::string>& get_mkldnn_format_string_map();
std::set<mkldnn::memory::format>& get_filter_formats();
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment