Commit 3289bcbb authored by gaurides's avatar gaurides Committed by Robert Kimball

Gauri/static initialization problem r08 cherry-pick from #1691 (#1702)

parent 8caa2717
...@@ -42,7 +42,9 @@ using namespace std; ...@@ -42,7 +42,9 @@ using namespace std;
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
static const std::unordered_set<std::type_index> s_op_registry{ std::unordered_set<std::type_index>& runtime::cpu::mkldnn_utils::get_op_registry()
{
static std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Add), TI(ngraph::op::Add),
TI(ngraph::op::AvgPool), TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop), TI(ngraph::op::AvgPoolBackprop),
...@@ -60,9 +62,14 @@ static const std::unordered_set<std::type_index> s_op_registry{ ...@@ -60,9 +62,14 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Relu), TI(ngraph::op::Relu),
TI(ngraph::op::ReluBackprop), TI(ngraph::op::ReluBackprop),
TI(ngraph::op::Reshape)}; TI(ngraph::op::Reshape)};
return s_op_registry;
}
// Mapping from POD types to MKLDNN data types std::map<element::Type, const mkldnn::memory::data_type>&
static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map{ runtime::cpu::mkldnn_utils::get_mkldnn_data_type_map()
{
// Mapping from POD types to MKLDNN data types
static std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map = {
{element::boolean, mkldnn::memory::data_type::s8}, {element::boolean, mkldnn::memory::data_type::s8},
{element::f32, mkldnn::memory::data_type::f32}, {element::f32, mkldnn::memory::data_type::f32},
{element::f64, mkldnn::memory::data_type::data_undef}, {element::f64, mkldnn::memory::data_type::data_undef},
...@@ -73,9 +80,15 @@ static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_d ...@@ -73,9 +80,15 @@ static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_d
{element::u8, mkldnn::memory::data_type::u8}, {element::u8, mkldnn::memory::data_type::u8},
{element::u16, mkldnn::memory::data_type::data_undef}, {element::u16, mkldnn::memory::data_type::data_undef},
{element::u32, mkldnn::memory::data_type::data_undef}, {element::u32, mkldnn::memory::data_type::data_undef},
{element::u64, mkldnn::memory::data_type::data_undef}}; {element::u64, mkldnn::memory::data_type::data_undef},
};
return s_mkldnn_data_type_map;
}
static const std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{ std::map<element::Type, const std::string>&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string_map()
{
static std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
{element::boolean, "mkldnn::memory::data_type::s8"}, {element::boolean, "mkldnn::memory::data_type::s8"},
{element::f32, "mkldnn::memory::data_type::f32"}, {element::f32, "mkldnn::memory::data_type::f32"},
{element::f64, "mkldnn::memory::data_type::data_undef"}, {element::f64, "mkldnn::memory::data_type::data_undef"},
...@@ -87,9 +100,14 @@ static const std::map<element::Type, const std::string> s_mkldnn_data_type_strin ...@@ -87,9 +100,14 @@ static const std::map<element::Type, const std::string> s_mkldnn_data_type_strin
{element::u16, "mkldnn::memory::data_type::data_undef"}, {element::u16, "mkldnn::memory::data_type::data_undef"},
{element::u32, "mkldnn::memory::data_type::data_undef"}, {element::u32, "mkldnn::memory::data_type::data_undef"},
{element::u64, "mkldnn::memory::data_type::data_undef"}}; {element::u64, "mkldnn::memory::data_type::data_undef"}};
return s_mkldnn_data_type_string_map;
}
// TODO (jbobba): Add the rest of memory formats to this map as well std::map<memory::format, const std::string>&
static const std::map<memory::format, const std::string> s_mkldnn_format_string_map{ runtime::cpu::mkldnn_utils::get_mkldnn_format_string_map()
{
// TODO (jbobba): Add the rest of memory formats to this map as well
static std::map<memory::format, const std::string> s_mkldnn_format_string_map{
{memory::format::format_undef, "memory::format::format_undef"}, {memory::format::format_undef, "memory::format::format_undef"},
{memory::format::any, "memory::format::any"}, {memory::format::any, "memory::format::any"},
{memory::format::blocked, "memory::format::blocked"}, {memory::format::blocked, "memory::format::blocked"},
...@@ -131,11 +149,13 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_ ...@@ -131,11 +149,13 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::ldsnc, "memory::format::ldsnc"}, {memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"}, {memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"}, {memory::format::ldgo, "memory::format::ldgo"},
{memory::format::ldgo, "memory::format::Goihw8g"}, };
{memory::format::ldgo, "memory::format::Goihw16g"}, return s_mkldnn_format_string_map;
}; }
static const std::set<memory::format> s_filter_formats{ std::set<memory::format>& runtime::cpu::mkldnn_utils::get_filter_formats()
{
static std::set<memory::format> s_filter_formats{
memory::format::oihw, memory::format::oihw,
memory::format::ihwo, memory::format::ihwo,
memory::format::hwio, memory::format::hwio,
...@@ -158,10 +178,11 @@ static const std::set<memory::format> s_filter_formats{ ...@@ -158,10 +178,11 @@ static const std::set<memory::format> s_filter_formats{
memory::format::Ohwi8o, memory::format::Ohwi8o,
memory::format::Ohwi16o, memory::format::Ohwi16o,
memory::format::OhIw16o4i}; memory::format::OhIw16o4i};
return s_filter_formats;
}
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op) bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{ {
return (s_op_registry.find(TI(op)) != s_op_registry.end()); return (get_op_registry().find(TI(op)) != get_op_registry().end());
} }
mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat( mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
...@@ -185,27 +206,31 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(const ...@@ -185,27 +206,31 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(const
const std::string& const std::string&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type) runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type)
{ {
auto it = s_mkldnn_data_type_string_map.find(type); auto it = get_mkldnn_data_type_string_map().find(type);
if (it == s_mkldnn_data_type_string_map.end() || it->second.empty()) if (it == get_mkldnn_data_type_string_map().end() || it->second.empty())
throw ngraph_error("No MKLDNN data type exists for the given element type"); {
throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
}
return it->second; return it->second;
} }
mkldnn::memory::data_type mkldnn::memory::data_type
runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type) runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
{ {
auto it = s_mkldnn_data_type_map.find(type); auto it = get_mkldnn_data_type_map().find(type);
if (it == s_mkldnn_data_type_map.end()) if (it == get_mkldnn_data_type_map().end())
{ {
throw ngraph_error("No MKLDNN data type exists for the given element type"); throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
} }
return it->second; return it->second;
} }
const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::format fmt) const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::format fmt)
{ {
auto it = s_mkldnn_format_string_map.find(fmt); auto it = get_mkldnn_format_string_map().find(fmt);
if (it == s_mkldnn_format_string_map.end()) if (it == get_mkldnn_format_string_map().end())
throw ngraph_error("No MKLDNN format exists for the given format type " + throw ngraph_error("No MKLDNN format exists for the given format type " +
std::to_string(fmt)); std::to_string(fmt));
return it->second; return it->second;
...@@ -252,12 +277,13 @@ bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims, ...@@ -252,12 +277,13 @@ bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims,
const Strides& strides, const Strides& strides,
const ngraph::element::Type type) const ngraph::element::Type type)
{ {
auto it = s_mkldnn_data_type_map.find(type); auto it = get_mkldnn_data_type_map().find(type);
if (dims.size() == 0) if (dims.size() == 0)
{ {
return false; return false;
} }
if (it == s_mkldnn_data_type_map.end() || it->second == mkldnn::memory::data_type::data_undef) if (it == get_mkldnn_data_type_map().end() ||
it->second == mkldnn::memory::data_type::data_undef)
{ {
return false; return false;
} }
...@@ -452,7 +478,7 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc& ...@@ -452,7 +478,7 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc&
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt) bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{ {
if (s_filter_formats.find(fmt) != s_filter_formats.end()) if (get_filter_formats().find(fmt) != get_filter_formats().end())
{ {
return true; return true;
} }
......
...@@ -62,6 +62,12 @@ namespace ngraph ...@@ -62,6 +62,12 @@ namespace ngraph
const mkldnn::memory::desc& rhs); const mkldnn::memory::desc& rhs);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt); bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt); bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
std::unordered_set<std::type_index>& get_op_registry();
std::map<element::Type, const mkldnn::memory::data_type>&
get_mkldnn_data_type_map();
std::map<element::Type, const std::string>& get_mkldnn_data_type_string_map();
std::map<mkldnn::memory::format, const std::string>& get_mkldnn_format_string_map();
std::set<mkldnn::memory::format>& get_filter_formats();
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment