Commit faad7d1b authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Support 3-D convolution with mkldnn (#1061)

parent 1f5b690d
......@@ -98,11 +98,21 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::chwn, "memory::format::chwn"},
{memory::format::nChw8c, "memory::format::nChw8c"},
{memory::format::nChw16c, "memory::format::nChw16c"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::nCdhw16c, "memory::format::nCdhw16c"},
{memory::format::oi, "memory::format::oi"},
{memory::format::io, "memory::format::io"},
{memory::format::oihw, "memory::format::oihw"},
{memory::format::ihwo, "memory::format::ihwo"},
{memory::format::hwio, "memory::format::hwio"},
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//{memory::format::dhwio, "memory::format::dhwio"},
{memory::format::oidhw, "memory::format::oidhw"},
{memory::format::OIdhw16i16o, "memory::format::OIdhw16i16o"},
{memory::format::OIdhw16o16i, "memory::format::OIdhw16o16i"},
{memory::format::Oidhw16o, "memory::format::Oidhw16o"},
{memory::format::Odhwi16o, "memory::format::Odhwi16o"},
{memory::format::oIhw8i, "memory::format::oIhw8i"},
{memory::format::oIhw16i, "memory::format::oIhw16i"},
{memory::format::OIhw8i8o, "memory::format::OIhw8i8o"},
......@@ -125,6 +135,13 @@ static const std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//memory::format::dhwio,
memory::format::oidhw,
memory::format::OIdhw16i16o,
memory::format::OIdhw16o16i,
memory::format::Oidhw16o,
memory::format::Odhwi16o,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
......@@ -151,6 +168,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
case 1: return mkldnn::memory::format::x;
case 2: return mkldnn::memory::format::nc;
case 4: return mkldnn::memory::format::nchw;
case 5: return mkldnn::memory::format::ncdhw;
default: return mkldnn::memory::format::format_undef;
}
}
......
......@@ -110,7 +110,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
......@@ -198,7 +199,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
......@@ -225,7 +227,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1);
}
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 &&
if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment