Commit faad7d1b authored by Nishant Patel's avatar Nishant Patel Committed by Scott Cyphers

Support 3-D convolution with mkldnn (#1061)

parent 1f5b690d
...@@ -98,11 +98,21 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_ ...@@ -98,11 +98,21 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::chwn, "memory::format::chwn"}, {memory::format::chwn, "memory::format::chwn"},
{memory::format::nChw8c, "memory::format::nChw8c"}, {memory::format::nChw8c, "memory::format::nChw8c"},
{memory::format::nChw16c, "memory::format::nChw16c"}, {memory::format::nChw16c, "memory::format::nChw16c"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::ncdhw, "memory::format::ndhwc"},
{memory::format::nCdhw16c, "memory::format::nCdhw16c"},
{memory::format::oi, "memory::format::oi"}, {memory::format::oi, "memory::format::oi"},
{memory::format::io, "memory::format::io"}, {memory::format::io, "memory::format::io"},
{memory::format::oihw, "memory::format::oihw"}, {memory::format::oihw, "memory::format::oihw"},
{memory::format::ihwo, "memory::format::ihwo"}, {memory::format::ihwo, "memory::format::ihwo"},
{memory::format::hwio, "memory::format::hwio"}, {memory::format::hwio, "memory::format::hwio"},
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//{memory::format::dhwio, "memory::format::dhwio"},
{memory::format::oidhw, "memory::format::oidhw"},
{memory::format::OIdhw16i16o, "memory::format::OIdhw16i16o"},
{memory::format::OIdhw16o16i, "memory::format::OIdhw16o16i"},
{memory::format::Oidhw16o, "memory::format::Oidhw16o"},
{memory::format::Odhwi16o, "memory::format::Odhwi16o"},
{memory::format::oIhw8i, "memory::format::oIhw8i"}, {memory::format::oIhw8i, "memory::format::oIhw8i"},
{memory::format::oIhw16i, "memory::format::oIhw16i"}, {memory::format::oIhw16i, "memory::format::oIhw16i"},
{memory::format::OIhw8i8o, "memory::format::OIhw8i8o"}, {memory::format::OIhw8i8o, "memory::format::OIhw8i8o"},
...@@ -125,6 +135,13 @@ static const std::set<memory::format> s_filter_formats{ ...@@ -125,6 +135,13 @@ static const std::set<memory::format> s_filter_formats{
memory::format::oihw, memory::format::oihw,
memory::format::ihwo, memory::format::ihwo,
memory::format::hwio, memory::format::hwio,
// TODO (nishant): Uncomment after the next release of mkl-dnn"
//memory::format::dhwio,
memory::format::oidhw,
memory::format::OIdhw16i16o,
memory::format::OIdhw16o16i,
memory::format::Oidhw16o,
memory::format::Odhwi16o,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c // memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i, // memory::format::oIhw16i,
memory::format::OIhw8i8o, memory::format::OIhw8i8o,
...@@ -151,6 +168,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat( ...@@ -151,6 +168,7 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(
case 1: return mkldnn::memory::format::x; case 1: return mkldnn::memory::format::x;
case 2: return mkldnn::memory::format::nc; case 2: return mkldnn::memory::format::nc;
case 4: return mkldnn::memory::format::nchw; case 4: return mkldnn::memory::format::nchw;
case 5: return mkldnn::memory::format::ncdhw;
default: return mkldnn::memory::format::format_undef; default: return mkldnn::memory::format::format_undef;
} }
} }
......
...@@ -110,7 +110,8 @@ namespace ngraph ...@@ -110,7 +110,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1); data_dilated = data_dilated || (s != 1);
} }
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 && if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32) node->get_input_element_type(0) == element::f32)
{ {
auto op_annotations = auto op_annotations =
...@@ -198,7 +199,8 @@ namespace ngraph ...@@ -198,7 +199,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1); data_dilated = data_dilated || (s != 1);
} }
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 && if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32) node->get_input_element_type(0) == element::f32)
{ {
auto op_annotations = auto op_annotations =
...@@ -225,7 +227,8 @@ namespace ngraph ...@@ -225,7 +227,8 @@ namespace ngraph
data_dilated = data_dilated || (s != 1); data_dilated = data_dilated || (s != 1);
} }
if (!data_dilated && arg0_rank == 4 && arg1_rank == 4 && if (!data_dilated && ((arg0_rank == 4 && arg1_rank == 4) ||
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32) node->get_input_element_type(0) == element::f32)
{ {
auto op_annotations = auto op_annotations =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment