Commit 3289bcbb authored by gaurides's avatar gaurides Committed by Robert Kimball

Gauri/static initialization problem r08 cherry-pick from #1691 (#1702)

parent 8caa2717
......@@ -42,126 +42,147 @@ using namespace std;
#define TI(x) std::type_index(typeid(x))
static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Add),
TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm),
TI(ngraph::op::BatchNormBackprop),
TI(ngraph::op::Concat),
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::ConvolutionBias),
TI(ngraph::op::ConvolutionRelu),
TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
TI(ngraph::op::MaxPool),
TI(ngraph::op::MaxPoolBackprop),
TI(ngraph::op::Relu),
TI(ngraph::op::ReluBackprop),
TI(ngraph::op::Reshape)};
// Mapping from POD types to MKLDNN data types
static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map{
{element::boolean, mkldnn::memory::data_type::s8},
{element::f32, mkldnn::memory::data_type::f32},
{element::f64, mkldnn::memory::data_type::data_undef},
{element::i8, mkldnn::memory::data_type::s8},
{element::i16, mkldnn::memory::data_type::s16},
{element::i32, mkldnn::memory::data_type::s32},
{element::i64, mkldnn::memory::data_type::data_undef},
{element::u8, mkldnn::memory::data_type::u8},
{element::u16, mkldnn::memory::data_type::data_undef},
{element::u32, mkldnn::memory::data_type::data_undef},
{element::u64, mkldnn::memory::data_type::data_undef}};
static const std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
{element::boolean, "mkldnn::memory::data_type::s8"},
{element::f32, "mkldnn::memory::data_type::f32"},
{element::f64, "mkldnn::memory::data_type::data_undef"},
{element::i8, "mkldnn::memory::data_type::s8"},
{element::i16, "mkldnn::memory::data_type::s16"},
{element::i32, "mkldnn::memory::data_type::s32"},
{element::i64, "mkldnn::memory::data_type::data_undef"},
{element::u8, "mkldnn::memory::data_type::u8"},
{element::u16, "mkldnn::memory::data_type::data_undef"},
{element::u32, "mkldnn::memory::data_type::data_undef"},
{element::u64, "mkldnn::memory::data_type::data_undef"}};
// TODO (jbobba): Add the rest of memory formats to this map as well
static const std::map<memory::format, const std::string> s_mkldnn_format_string_map{
{memory::format::format_undef, "memory::format::format_undef"},
{memory::format::any, "memory::format::any"},
{memory::format::blocked, "memory::format::blocked"},
{memory::format::x, "memory::format::x"},
{memory::format::nc, "memory::format::nc"},
{memory::format::nchw, "memory::format::nchw"},
{memory::format::nhwc, "memory::format::nhwc"},
{memory::format::chwn, "memory::format::chwn"},
{memory::format::nChw8c, "memory::format::nChw8c"},
{memory::format::nChw16c, "memory::format::nChw16c"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::nCdhw16c, "memory::format::nCdhw16c"},
{memory::format::oi, "memory::format::oi"},
{memory::format::io, "memory::format::io"},
{memory::format::oihw, "memory::format::oihw"},
{memory::format::ihwo, "memory::format::ihwo"},
{memory::format::hwio, "memory::format::hwio"},
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//{memory::format::dhwio, "memory::format::dhwio"},
{memory::format::oidhw, "memory::format::oidhw"},
{memory::format::OIdhw16i16o, "memory::format::OIdhw16i16o"},
{memory::format::OIdhw16o16i, "memory::format::OIdhw16o16i"},
{memory::format::Oidhw16o, "memory::format::Oidhw16o"},
{memory::format::Odhwi16o, "memory::format::Odhwi16o"},
{memory::format::oIhw8i, "memory::format::oIhw8i"},
{memory::format::oIhw16i, "memory::format::oIhw16i"},
{memory::format::OIhw8i8o, "memory::format::OIhw8i8o"},
{memory::format::OIhw16i16o, "memory::format::OIhw16i16o"},
{memory::format::IOhw16o16i, "memory::format::IOhw16o16i"},
{memory::format::OIhw8o8i, "memory::format::OIhw8o8i"},
{memory::format::OIhw16o16i, "memory::format::OIhw16o16i"},
{memory::format::Oihw8o, "memory::format::Oihw8o"},
{memory::format::Oihw16o, "memory::format::Oihw16o"},
{memory::format::Ohwi8o, "memory::format::Ohwi8o"},
{memory::format::Ohwi16o, "memory::format::Ohwi16o"},
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
{memory::format::tnc, "memory::format::tnc"},
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
{memory::format::ldgo, "memory::format::Goihw8g"},
{memory::format::ldgo, "memory::format::Goihw16g"},
};
static const std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// TODO (nishant): Uncomment after the next release of mkl-dnn"
// memory::format::dhwio,
memory::format::oidhw,
memory::format::OIdhw16i16o,
memory::format::OIdhw16o16i,
memory::format::Oidhw16o,
memory::format::Odhwi16o,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
memory::format::OIhw16i16o,
memory::format::IOhw16o16i,
memory::format::OIhw8o8i,
memory::format::OIhw16o16i,
memory::format::Oihw8o,
memory::format::Oihw16o,
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
std::unordered_set<std::type_index>& runtime::cpu::mkldnn_utils::get_op_registry()
{
static std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Add),
TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm),
TI(ngraph::op::BatchNormBackprop),
TI(ngraph::op::Concat),
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::ConvolutionBias),
TI(ngraph::op::ConvolutionRelu),
TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
TI(ngraph::op::MaxPool),
TI(ngraph::op::MaxPoolBackprop),
TI(ngraph::op::Relu),
TI(ngraph::op::ReluBackprop),
TI(ngraph::op::Reshape)};
return s_op_registry;
}
std::map<element::Type, const mkldnn::memory::data_type>&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_map()
{
// Mapping from POD types to MKLDNN data types
static std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map = {
{element::boolean, mkldnn::memory::data_type::s8},
{element::f32, mkldnn::memory::data_type::f32},
{element::f64, mkldnn::memory::data_type::data_undef},
{element::i8, mkldnn::memory::data_type::s8},
{element::i16, mkldnn::memory::data_type::s16},
{element::i32, mkldnn::memory::data_type::s32},
{element::i64, mkldnn::memory::data_type::data_undef},
{element::u8, mkldnn::memory::data_type::u8},
{element::u16, mkldnn::memory::data_type::data_undef},
{element::u32, mkldnn::memory::data_type::data_undef},
{element::u64, mkldnn::memory::data_type::data_undef},
};
return s_mkldnn_data_type_map;
}
std::map<element::Type, const std::string>&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string_map()
{
static std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
{element::boolean, "mkldnn::memory::data_type::s8"},
{element::f32, "mkldnn::memory::data_type::f32"},
{element::f64, "mkldnn::memory::data_type::data_undef"},
{element::i8, "mkldnn::memory::data_type::s8"},
{element::i16, "mkldnn::memory::data_type::s16"},
{element::i32, "mkldnn::memory::data_type::s32"},
{element::i64, "mkldnn::memory::data_type::data_undef"},
{element::u8, "mkldnn::memory::data_type::u8"},
{element::u16, "mkldnn::memory::data_type::data_undef"},
{element::u32, "mkldnn::memory::data_type::data_undef"},
{element::u64, "mkldnn::memory::data_type::data_undef"}};
return s_mkldnn_data_type_string_map;
}
std::map<memory::format, const std::string>&
runtime::cpu::mkldnn_utils::get_mkldnn_format_string_map()
{
// TODO (jbobba): Add the rest of memory formats to this map as well
static std::map<memory::format, const std::string> s_mkldnn_format_string_map{
{memory::format::format_undef, "memory::format::format_undef"},
{memory::format::any, "memory::format::any"},
{memory::format::blocked, "memory::format::blocked"},
{memory::format::x, "memory::format::x"},
{memory::format::nc, "memory::format::nc"},
{memory::format::nchw, "memory::format::nchw"},
{memory::format::nhwc, "memory::format::nhwc"},
{memory::format::chwn, "memory::format::chwn"},
{memory::format::nChw8c, "memory::format::nChw8c"},
{memory::format::nChw16c, "memory::format::nChw16c"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::nCdhw16c, "memory::format::nCdhw16c"},
{memory::format::oi, "memory::format::oi"},
{memory::format::io, "memory::format::io"},
{memory::format::oihw, "memory::format::oihw"},
{memory::format::ihwo, "memory::format::ihwo"},
{memory::format::hwio, "memory::format::hwio"},
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//{memory::format::dhwio, "memory::format::dhwio"},
{memory::format::oidhw, "memory::format::oidhw"},
{memory::format::OIdhw16i16o, "memory::format::OIdhw16i16o"},
{memory::format::OIdhw16o16i, "memory::format::OIdhw16o16i"},
{memory::format::Oidhw16o, "memory::format::Oidhw16o"},
{memory::format::Odhwi16o, "memory::format::Odhwi16o"},
{memory::format::oIhw8i, "memory::format::oIhw8i"},
{memory::format::oIhw16i, "memory::format::oIhw16i"},
{memory::format::OIhw8i8o, "memory::format::OIhw8i8o"},
{memory::format::OIhw16i16o, "memory::format::OIhw16i16o"},
{memory::format::IOhw16o16i, "memory::format::IOhw16o16i"},
{memory::format::OIhw8o8i, "memory::format::OIhw8o8i"},
{memory::format::OIhw16o16i, "memory::format::OIhw16o16i"},
{memory::format::Oihw8o, "memory::format::Oihw8o"},
{memory::format::Oihw16o, "memory::format::Oihw16o"},
{memory::format::Ohwi8o, "memory::format::Ohwi8o"},
{memory::format::Ohwi16o, "memory::format::Ohwi16o"},
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
{memory::format::tnc, "memory::format::tnc"},
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
};
return s_mkldnn_format_string_map;
}
std::set<memory::format>& runtime::cpu::mkldnn_utils::get_filter_formats()
{
static std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// TODO (nishant): Uncomment after the next release of mkl-dnn"
// memory::format::dhwio,
memory::format::oidhw,
memory::format::OIdhw16i16o,
memory::format::OIdhw16o16i,
memory::format::Oidhw16o,
memory::format::Odhwi16o,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
memory::format::OIhw16i16o,
memory::format::IOhw16o16i,
memory::format::OIhw8o8i,
memory::format::OIhw16o16i,
memory::format::Oihw8o,
memory::format::Oihw16o,
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
return s_filter_formats;
}
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{
return (s_op_registry.find(TI(op)) != s_op_registry.end());
return (get_op_registry().find(TI(op)) != get_op_registry().end());
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
......@@ -185,27 +206,31 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(const
const std::string&
runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_string_map.find(type);
if (it == s_mkldnn_data_type_string_map.end() || it->second.empty())
throw ngraph_error("No MKLDNN data type exists for the given element type");
auto it = get_mkldnn_data_type_string_map().find(type);
if (it == get_mkldnn_data_type_string_map().end() || it->second.empty())
{
throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
}
return it->second;
}
mkldnn::memory::data_type
runtime::cpu::mkldnn_utils::get_mkldnn_data_type(const ngraph::element::Type& type)
{
auto it = s_mkldnn_data_type_map.find(type);
if (it == s_mkldnn_data_type_map.end())
auto it = get_mkldnn_data_type_map().find(type);
if (it == get_mkldnn_data_type_map().end())
{
throw ngraph_error("No MKLDNN data type exists for the given element type");
throw ngraph_error("No MKLDNN data type exists for the given element type" +
type.c_type_string());
}
return it->second;
}
const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::format fmt)
{
auto it = s_mkldnn_format_string_map.find(fmt);
if (it == s_mkldnn_format_string_map.end())
auto it = get_mkldnn_format_string_map().find(fmt);
if (it == get_mkldnn_format_string_map().end())
throw ngraph_error("No MKLDNN format exists for the given format type " +
std::to_string(fmt));
return it->second;
......@@ -252,12 +277,13 @@ bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims,
const Strides& strides,
const ngraph::element::Type type)
{
auto it = s_mkldnn_data_type_map.find(type);
auto it = get_mkldnn_data_type_map().find(type);
if (dims.size() == 0)
{
return false;
}
if (it == s_mkldnn_data_type_map.end() || it->second == mkldnn::memory::data_type::data_undef)
if (it == get_mkldnn_data_type_map().end() ||
it->second == mkldnn::memory::data_type::data_undef)
{
return false;
}
......@@ -452,7 +478,7 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc&
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{
if (s_filter_formats.find(fmt) != s_filter_formats.end())
if (get_filter_formats().find(fmt) != get_filter_formats().end())
{
return true;
}
......
......@@ -62,6 +62,12 @@ namespace ngraph
const mkldnn::memory::desc& rhs);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
std::unordered_set<std::type_index>& get_op_registry();
std::map<element::Type, const mkldnn::memory::data_type>&
get_mkldnn_data_type_map();
std::map<element::Type, const std::string>& get_mkldnn_data_type_string_map();
std::map<mkldnn::memory::format, const std::string>& get_mkldnn_format_string_map();
std::set<mkldnn::memory::format>& get_filter_formats();
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment