Unverified Commit 29014bab authored by Adam Procter's avatar Adam Procter Committed by GitHub

Allow use of MKLDNN backprop when pointer to forward-prop op is missing (#523)

parent 84d236ad
......@@ -2682,14 +2682,13 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolBackprop)
{
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto max_pool_fprop_op = mpb->get_forward_op();
auto delta_shape = args[1].get_shape();
auto delta_rank = delta_shape.size();
auto out_shape = out[0].get_shape();
if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
args[0].get_element_type() == element::f32 && max_pool_fprop_op != nullptr)
args[0].get_element_type() == element::f32)
{
const string& et =
get_mkldnn_data_type(args[1].get_element_type().c_type_string());
......@@ -2725,10 +2724,10 @@ namespace ngraph
"pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_max, "
<< "max_pool_input_desc, max_pool_result_desc, {"
<< join(max_pool_fprop_op->get_window_movement_strides()) << "}, {"
<< join(max_pool_fprop_op->get_window_shape()) << "}, "
<< "{" << join(max_pool_fprop_op->get_padding_below()) << "}, "
<< "{" << join(max_pool_fprop_op->get_padding_above()) << "}, "
<< join(mpb->get_window_movement_strides()) << "}, {"
<< join(mpb->get_window_shape()) << "}, "
<< "{" << join(mpb->get_padding_below()) << "}, "
<< "{" << join(mpb->get_padding_above()) << "}, "
<< "padding_kind::zero}, cpu_engine);\n";
// query the workspace from the forward primitive desc and allocates memory
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment