Commit 466854c6 authored by pthoreho's avatar pthoreho

style fix

parent 5ce09de0
...@@ -2726,76 +2726,85 @@ namespace ngraph ...@@ -2726,76 +2726,85 @@ namespace ngraph
if (delta_rank == 4 && mpb->get_window_shape().size() == 2 && if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32)
{ {
const string& et = get_mkldnn_data_type(args[1].get_element_type().c_type_string()); const string& et =
get_mkldnn_data_type(args[1].get_element_type().c_type_string());
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape) << "}, " writer << "memory::desc input_data_desc = memory::desc({" << join(delta_shape)
<< et << ", memory::format::nchw);\n"; << "}, " << et << ", memory::format::nchw);\n";
writer << "memory::desc result_desc = memory::desc({" << join(out_shape) << "}, " << et writer << "memory::desc result_desc = memory::desc({" << join(out_shape)
<< ", memory::format::nchw);\n"; << "}, " << et << ", memory::format::nchw);\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[1].get_name() writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< ");\n"; << args[1].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name() writer << "memory result = memory({result_desc, cpu_engine}, "
<< ");\n"; << out[0].get_name() << ");\n";
//---------------------------------------------------------------------------------------------- //----------------------------------------------------------------------------------------------
// create a forward primitive_desc, use this to query the workspace // create a forward primitive_desc, use this to query the workspace
// FIXME: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding // FIXME: (pruthvi) this is a workaround, till we maintain a global context to refer to the corrosponding
// MKLDNN fprop kernel. this impacts performance // MKLDNN fprop kernel. this impacts performance
writer << "memory::desc max_pool_input_desc = memory::desc({" << join(args[0].get_shape()) writer << "memory::desc max_pool_input_desc = memory::desc({"
<< "}, " << et << ", memory::format::nchw);\n"; << join(args[0].get_shape()) << "}, " << et
writer << "memory::desc max_pool_result_desc = memory::desc({" << join(args[1].get_shape()) << ", memory::format::nchw);\n";
<< "}, " << et << ", memory::format::nchw);\n"; writer << "memory::desc max_pool_result_desc = memory::desc({"
writer << "memory maxpool_input_data = memory({max_pool_input_desc, cpu_engine}, " << join(args[1].get_shape()) << "}, " << et
<< ", memory::format::nchw);\n";
writer
<< "memory maxpool_input_data = memory({max_pool_input_desc, cpu_engine}, "
<< args[0].get_name() << ");\n"; << args[0].get_name() << ");\n";
writer << "memory maxpool_result = memory({max_pool_result_desc, cpu_engine}, " writer << "memory maxpool_result = memory({max_pool_result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n"; << out[0].get_name() << ");\n";
writer << "pooling_forward::primitive_desc pool_fwd_pd = pooling_forward::primitive_desc(" writer << "pooling_forward::primitive_desc pool_fwd_pd = "
<< "{prop_kind::forward, algorithm::pooling_max, " "pooling_forward::primitive_desc("
<< "max_pool_input_desc, max_pool_result_desc, {" << "{prop_kind::forward, algorithm::pooling_max, "
<< join(max_pool_fprop_op->get_window_movement_strides()) << "}, {" << "max_pool_input_desc, max_pool_result_desc, {"
<< join(max_pool_fprop_op->get_window_shape()) << "}, " << join(max_pool_fprop_op->get_window_movement_strides()) << "}, {"
<< "{" << join(max_pool_fprop_op->get_padding_below()) << "}, " << join(max_pool_fprop_op->get_window_shape()) << "}, "
<< "{" << join(max_pool_fprop_op->get_padding_above()) << "}, " << "{" << join(max_pool_fprop_op->get_padding_below()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n"; << "{" << join(max_pool_fprop_op->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
// query the workspace from the forward primitive desc and allocates memory // query the workspace from the forward primitive desc and allocates memory
writer << "auto max_pool_workspace_memory = " writer << "auto max_pool_workspace_memory = "
"memory(pool_fwd_pd.workspace_primitive_desc());\n"; "memory(pool_fwd_pd.workspace_primitive_desc());\n";
//run fprop with this workspace attached //run fprop with this workspace attached
writer << "pooling_forward max_pooling_fwd = pooling_forward(" writer << "pooling_forward max_pooling_fwd = pooling_forward("
<< "pool_fwd_pd, maxpool_input_data, maxpool_result, max_pool_workspace_memory);\n"; << "pool_fwd_pd, maxpool_input_data, maxpool_result, "
"max_pool_workspace_memory);\n";
writer << "stream s_fprop = stream(stream::kind::eager);\n" writer << "stream s_fprop = stream(stream::kind::eager);\n"
<< "s_fprop.submit({max_pooling_fwd}).wait();\n"; << "s_fprop.submit({max_pooling_fwd}).wait();\n";
//--------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------
writer << "auto max_pooling_bwd = pooling_backward(pooling_backward::primitive_desc(" writer << "auto max_pooling_bwd = "
<< "pooling_backward::desc(algorithm::pooling_max, " "pooling_backward(pooling_backward::primitive_desc("
<< "result_desc, input_data_desc, {" << join(mpb->get_window_movement_strides()) << "pooling_backward::desc(algorithm::pooling_max, "
<< "}, {" << join(mpb->get_window_shape()) << "}, " << "result_desc, input_data_desc, {"
<< "{" << join(mpb->get_padding_below()) << "}, " << join(mpb->get_window_movement_strides()) << "}, {"
<< "{" << join(mpb->get_padding_above()) << "}, " << join(mpb->get_window_shape()) << "}, "
<< "padding_kind::zero), cpu_engine, pool_fwd_pd), " << "{" << join(mpb->get_padding_below()) << "}, "
<< "input_data, max_pool_workspace_memory, result);\n"; << "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero), cpu_engine, pool_fwd_pd), "
<< "input_data, max_pool_workspace_memory, result);\n";
writer << "auto s_bwd = stream(stream::kind::eager);\n" writer << "auto s_bwd = stream(stream::kind::eager);\n"
<< "s_bwd.submit({max_pooling_bwd}).wait();\n"; << "s_bwd.submit({max_pooling_bwd}).wait();\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
else else
{ {
writer << "kernel::max_pool_backprop<" << out[0].get_type() << ">(" << args[0].get_name() writer << "kernel::max_pool_backprop<" << out[0].get_type() << ">("
<< ",\n"; << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n"; writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n"; writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(delta_shape) << "},\n"; writer << " {" << join(delta_shape) << "},\n";
writer << " {" << join(out_shape) << "},\n"; writer << " {" << join(out_shape) << "},\n";
writer << " {" << join(mpb->get_window_shape()) << "},\n"; writer << " {" << join(mpb->get_window_shape()) << "},\n";
writer << " {" << join(mpb->get_window_movement_strides()) << "},\n"; writer << " {" << join(mpb->get_window_movement_strides())
<< "},\n";
writer << " {" << join(mpb->get_padding_below()) << "},\n"; writer << " {" << join(mpb->get_padding_below()) << "},\n";
writer << " {" << join(mpb->get_padding_above()) << "}\n"; writer << " {" << join(mpb->get_padding_above()) << "}\n";
writer << " );\n"; writer << " );\n";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment