Commit 943bfe19 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Sang Ik Lee

Fallback to default pooling kernels when mkldnn doesn't support input format (#2526)

* Fallback to default pooling kernels when mkldnn doesn't support input format

* create default output descriptors
parent d85311a8
......@@ -990,8 +990,23 @@ namespace ngraph
}
catch (const mkldnn::error& e)
{
throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(input_desc.data.format) + e.message);
if (arg0_shape.size() == 4 || arg0_shape.size() == 5)
{
auto default_format = arg0_shape.size() == 4
? mkldnn::memory::format::nchw
: mkldnn::memory::format::ncdhw;
auto default_desc_i = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, false, default_format);
auto default_desc_o = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, default_format);
i_mds.push_back(default_desc_i);
o_mds.push_back(default_desc_o);
}
else
{
throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(input_desc.data.format) + e.message);
}
}
}
......@@ -1147,8 +1162,39 @@ namespace ngraph
}
catch (const mkldnn::error& e)
{
throw ngraph_error("MKLDNN Unsupported pooling fwd layout" +
to_string(input_desc.data.format) + e.message);
if (arg0_shape.size() == 4 || arg0_shape.size() == 5)
{
auto default_format = arg0_shape.size() == 4
? mkldnn::memory::format::nchw
: mkldnn::memory::format::ncdhw;
auto default_desc_i = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, false, default_format);
auto default_desc_o = mkldnn_utils::create_default_mkldnn_md(
node.get(), 0, true, default_format);
i_mds.push_back(default_desc_i);
o_mds.push_back(default_desc_o);
if (pk == prop_kind::forward_training)
{
o_mds.push_back(
pooling_forward::primitive_desc({pk,
algorithm_enumerator,
default_desc_i,
result_desc,
mkldnn_filter_strides,
mkldnn_filter_shape,
mkldnn_padding_below,
mkldnn_padding_above,
padding_kind::zero},
executor::global_cpu_engine)
.workspace_primitive_desc()
.desc());
}
}
else
{
throw ngraph_error("MKLDNN Unsupported pooling fwd layout" +
to_string(input_desc.data.format) + e.message);
}
}
}
......
......@@ -56,8 +56,8 @@ public:
}
};
static void compare_backends(std::shared_ptr<Function>& f1,
std::shared_ptr<Function>& f2,
static void compare_backends(const std::shared_ptr<Function>& f1,
const std::shared_ptr<Function>& f2,
const string backend1,
const string backend2,
float rtol = 1e-5,
......@@ -890,3 +890,31 @@ TEST(cpu_test, convert_inplace)
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int8_t>{1, 2, 3, -2}), read_vector<int8_t>(result));
}
// Verifies that pooling on a non-default (rotated) input layout produces the
// same results on the CPU backend as on the reference INTERPRETER backend,
// for both AvgPool and MaxPool in 4D and 5D.
TEST(cpu_test, rotated_pooling)
{
    // Builds a graph that reshapes a channels-first-rotated input back into
    // N{C,spatial...} order and then applies the requested pooling op.
    auto build_fn = [&](bool four_dim, bool use_avgpool) {
        auto in_shape = four_dim ? Shape{2, 4, 4, 1} : Shape{2, 4, 4, 4, 1};
        auto perm = four_dim ? AxisVector{3, 0, 1, 2} : AxisVector{4, 0, 1, 2, 3};
        auto pooled_shape = four_dim ? Shape{1, 2, 4, 4} : Shape{1, 2, 4, 4, 4};
        auto win = four_dim ? Shape{2, 2} : Shape{2, 2, 2};
        // Parameter arrives rotated: C, H, W, N (plus D for the 5D case).
        auto param = make_shared<op::Parameter>(element::f32, in_shape);
        // Rotate back to standard N, C, spatial... order before pooling.
        auto rotated = make_shared<op::Reshape>(param, perm, pooled_shape);
        if (use_avgpool)
        {
            return make_shared<Function>(make_shared<op::AvgPool>(rotated, win),
                                         ParameterVector{param});
        }
        else
        {
            return make_shared<Function>(make_shared<op::MaxPool>(rotated, win),
                                         ParameterVector{param});
        }
    };
    // Each comparison builds the function twice (once per backend run).
    compare_backends(build_fn(true, true), build_fn(true, true), "INTERPRETER", "CPU"); // 4D AvgPool
    compare_backends(
        build_fn(true, false), build_fn(true, false), "INTERPRETER", "CPU"); // 4D MaxPool
    compare_backends(
        build_fn(false, true), build_fn(false, true), "INTERPRETER", "CPU"); // 5D AvgPool
    compare_backends(
        build_fn(false, false), build_fn(false, false), "INTERPRETER", "CPU"); // 5D MaxPool
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment