Commit 67fb65b8 authored by pthoreho's avatar pthoreho

Addressed PR review comments

parent 466854c6
......@@ -2724,7 +2724,7 @@ namespace ngraph
auto out_shape = out[0].get_shape();
if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32)
args[0].get_element_type() == element::f32 && max_pool_fprop_op != nullptr)
{
const string& et =
get_mkldnn_data_type(args[1].get_element_type().c_type_string());
......@@ -2743,7 +2743,7 @@ namespace ngraph
//----------------------------------------------------------------------------------------------
// create a forward primitive_desc, use this to query the workspace
// FIXME: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding
// TODO: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding
// MKLDNN fprop kernel. this impacts performance
writer << "memory::desc max_pool_input_desc = memory::desc({"
<< join(args[0].get_shape()) << "}, " << et
......
......@@ -1410,15 +1410,13 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
{
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
Shape shape_a{1, 4, 4, 4}; //in CHWN
Shape shape_a{4, 1, 4, 4}; //in NCHW
Shape maxpool_shape{4, 1, 3, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto reshape = make_shared<op::Reshape>(
A, AxisVector{0, 3, 1, 2}, Shape{4, 1, 4, 4}); //convert CHWN to NCHW
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
auto maxpool = make_shared<op::MaxPool>(reshape, window_shape, window_movement_strides);
auto maxpool = make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
auto f = make_shared<Function>(maxpool, op::Parameters{A});
shared_ptr<runtime::TensorView> ep =
backend->make_primary_tensor_view(element::f32, maxpool_shape);
......@@ -1435,9 +1433,9 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
45, 63, 16, 14, 55, 54, 37, 20, 36, 12, 70, 34, 19, 26, 32, 23};
vector<float> expected{//delta
0, 4, 0, 0, 0, 0, 0, 8, 0, 0, 8, 0, 0, 0, 0, 0, 0, 4, 4, 4, 12, 0,
0, 0, 0, 8, 0, 0, 4, 8, 0, 8, 0, 0, 8, 0, 0, 0, 0, 4, 16, 4, 16, 8,
0, 0, 0, 4, 0, 4, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
0, 8, 0, 0, 0, 0, 0, 4, 0, 8, 16, 0, 0, 0, 0, 0, 0, 4, 0, 4, 8, 0,
0, 0, 0, 4, 4, 0, 4, 4, 0, 4, 0, 0, 8, 0, 4, 0, 0, 0, 8, 0, 16, 0,
0, 0, 0, 0, 0, 8, 0, 0, 4, 0, 4, 0, 4, 0, 16, 0, 0, 0, 0, 0};
copy_data(ep, dataEp);
copy_data(input, dataInput);
......@@ -1447,7 +1445,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
auto external = manager->compile(df);
auto cf = backend->make_call_frame(external);
cf->tensor_call({input, ep}, {output});
ASSERT_TRUE(read_vector<float>(output) == expected);
}
......@@ -1456,15 +1453,13 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend();
Shape shape_a{1, 5, 5, 2}; //in CHWN
Shape shape_a{1, 2, 5, 5}; //in NCHW
Shape maxpool_shape{1, 2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto reshape = make_shared<op::Reshape>(
A, AxisVector{0, 3, 1, 2}, Shape{1, 2, 5, 5}); //convert CHWN to NCHW
Shape window_shape{3, 3};
auto window_movement_strides = Strides{2, 2};
auto maxpool = make_shared<op::MaxPool>(reshape, window_shape, window_movement_strides);
auto maxpool = make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
auto f = make_shared<Function>(maxpool, op::Parameters{A});
shared_ptr<runtime::TensorView> ep =
......@@ -1481,9 +1476,9 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
19, 40, 10, 46, 34, 53, 26, 55, 50, 13, 24, 14, 49, 56, 59, 11};
vector<float> expected{//delta
4, 0, 0, 0, 0, 4, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0,
4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 0, 0, 4, 4, 0};
0, 0, 0, 4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0};
copy_data(ep, dataEp);
copy_data(input, dataInput);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment