Commit 67fb65b8 authored by pthoreho

Addressed PR review comments

parent 466854c6
@@ -2724,7 +2724,7 @@ namespace ngraph
     auto out_shape = out[0].get_shape();
     if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
-        args[0].get_element_type() == element::f32)
+        args[0].get_element_type() == element::f32 && max_pool_fprop_op != nullptr)
     {
         const string& et =
             get_mkldnn_data_type(args[1].get_element_type().c_type_string());
@@ -2743,7 +2743,7 @@ namespace ngraph
     //----------------------------------------------------------------------------------------------
     // create a forward primitive_desc, use this to query the workspace
-    // FIXME: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding
+    // TODO: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding
    // MKLDNN fprop kernel. this impacts performance
     writer << "memory::desc max_pool_input_desc = memory::desc({"
            << join(args[0].get_shape()) << "}, " << et
...
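The emitter comment above describes the pattern the generated code follows: because there is not yet a global context from which to fetch the fprop kernel, the bprop emitter re-creates a max-pool forward primitive_desc solely to query the workspace (the max-index buffer that pooling_backward needs). Below is a minimal standalone sketch of that pattern against the MKLDNN 0.x C++ API; the shapes, variable names, and engine setup are illustrative only and are not the code this emitter writes.

#include <mkldnn.hpp>
using namespace mkldnn;

int main()
{
    engine cpu_engine(engine::cpu, 0);

    // Illustrative shapes: 4x1x4x4 input, 2x2 window, stride 1 -> 4x1x3x3 output.
    memory::desc input_desc({4, 1, 4, 4}, memory::data_type::f32, memory::format::nchw);
    memory::desc result_desc({4, 1, 3, 3}, memory::data_type::f32, memory::format::nchw);

    // Forward descriptor created only so the workspace layout can be queried.
    pooling_forward::desc fwd_desc(prop_kind::forward_training,
                                   algorithm::pooling_max,
                                   input_desc,
                                   result_desc,
                                   {1, 1},  // strides
                                   {2, 2},  // window shape
                                   {0, 0},  // padding below
                                   {0, 0},  // padding above
                                   padding_kind::zero);
    auto fwd_pd = pooling_forward::primitive_desc(fwd_desc, cpu_engine);

    // The workspace holds the argmax positions; pooling_backward reuses it to
    // route the deltas back to the locations that won the forward max.
    memory workspace(fwd_pd.workspace_primitive_desc());
    return 0;
}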
@@ -1410,15 +1410,13 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
 {
     auto manager = runtime::Manager::get("${BACKEND_NAME}");
     auto backend = manager->allocate_backend();
-    Shape shape_a{1, 4, 4, 4}; //in CHWN
+    Shape shape_a{4, 1, 4, 4}; //in NCHW
     Shape maxpool_shape{4, 1, 3, 3};
     auto A = make_shared<op::Parameter>(element::f32, shape_a);
-    auto reshape = make_shared<op::Reshape>(
-        A, AxisVector{0, 3, 1, 2}, Shape{4, 1, 4, 4}); //convert CHWN to NCHW
     Shape window_shape{2, 2};
     auto window_movement_strides = Strides{1, 1};
-    auto maxpool = make_shared<op::MaxPool>(reshape, window_shape, window_movement_strides);
+    auto maxpool = make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
     auto f = make_shared<Function>(maxpool, op::Parameters{A});
     shared_ptr<runtime::TensorView> ep =
         backend->make_primary_tensor_view(element::f32, maxpool_shape);
@@ -1435,9 +1433,9 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
         45, 63, 16, 14, 55, 54, 37, 20, 36, 12, 70, 34, 19, 26, 32, 23};
     vector<float> expected{//delta
-        0, 4, 0, 0, 0, 0, 0, 8, 0, 0, 8, 0, 0, 0, 0, 0, 0, 4, 4, 4, 12, 0,
-        0, 0, 0, 8, 0, 0, 4, 8, 0, 8, 0, 0, 8, 0, 0, 0, 0, 4, 16, 4, 16, 8,
-        0, 0, 0, 4, 0, 4, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+        0, 8, 0, 0, 0, 0, 0, 4, 0, 8, 16, 0, 0, 0, 0, 0, 0, 4, 0, 4, 8, 0,
+        0, 0, 0, 4, 4, 0, 4, 4, 0, 4, 0, 0, 8, 0, 4, 0, 0, 0, 8, 0, 16, 0,
+        0, 0, 0, 0, 0, 8, 0, 0, 4, 0, 4, 0, 4, 0, 16, 0, 0, 0, 0, 0};
     copy_data(ep, dataEp);
     copy_data(input, dataInput);
@@ -1447,7 +1445,6 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1)
     auto external = manager->compile(df);
     auto cf = backend->make_call_frame(external);
     cf->tensor_call({input, ep}, {output});
     ASSERT_TRUE(read_vector<float>(output) == expected);
 }
@@ -1456,15 +1453,13 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
     auto manager = runtime::Manager::get("${BACKEND_NAME}");
     auto backend = manager->allocate_backend();
-    Shape shape_a{1, 5, 5, 2}; //in CHWN
+    Shape shape_a{1, 2, 5, 5}; //in NCHW
     Shape maxpool_shape{1, 2, 2, 2};
     auto A = make_shared<op::Parameter>(element::f32, shape_a);
-    auto reshape = make_shared<op::Reshape>(
-        A, AxisVector{0, 3, 1, 2}, Shape{1, 2, 5, 5}); //convert CHWN to NCHW
     Shape window_shape{3, 3};
     auto window_movement_strides = Strides{2, 2};
-    auto maxpool = make_shared<op::MaxPool>(reshape, window_shape, window_movement_strides);
+    auto maxpool = make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
     auto f = make_shared<Function>(maxpool, op::Parameters{A});
     shared_ptr<runtime::TensorView> ep =
@@ -1481,9 +1476,9 @@ TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
         19, 40, 10, 46, 34, 53, 26, 55, 50, 13, 24, 14, 49, 56, 59, 11};
     vector<float> expected{//delta
-        4, 0, 0, 0, 0, 4, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0,
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-        0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 0, 0, 4, 4, 0};
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0};
     copy_data(ep, dataEp);
     copy_data(input, dataInput);
...
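For reference, the maxpool_shape values asserted in both tests follow from the usual valid-pooling arithmetic, out = (in - window) / stride + 1 per spatial axis. A throwaway check of that arithmetic is sketched below; the helper name is made up for illustration and is not part of the test file.

#include <cstddef>
#include <iostream>

// out = (in - window) / stride + 1 for valid (unpadded) pooling along one axis.
static std::size_t pooled_extent(std::size_t in, std::size_t window, std::size_t stride)
{
    return (in - window) / stride + 1;
}

int main()
{
    // backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1: 4x4 spatial, 2x2 window, stride 1 -> 3x3,
    // matching maxpool_shape{4, 1, 3, 3}.
    std::cout << pooled_extent(4, 2, 1) << "\n"; // prints 3
    // backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2: 5x5 spatial, 3x3 window, stride 2 -> 2x2,
    // matching maxpool_shape{1, 2, 2, 2}.
    std::cout << pooled_extent(5, 3, 2) << "\n"; // prints 2
    return 0;
}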