Commit a174c8c9 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

MaxPoolWithIndices (#900)

* MaxPoolWithIndices CPU Fusion

* fix test to pass checks in cpu_fusion

* pass test

* clean up

* add a new pass, add layouts

* remove the opt from cpu_fusion

* refactor cpu_layout logic for maxpool, clean up comments

* add comment w.r.t. indices tensor

* rename to cpu_workspace_insertion

* add CPUWorkspaceInsertion pass for TF
parent 23913010
......@@ -217,9 +217,11 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/matmul_bias.cpp
runtime/cpu/op/max_pool_with_indices.cpp
runtime/cpu/op/batch_norm_relu.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_workspace_insertion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_nop_elimination.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
......
......@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -2768,6 +2769,45 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolWithIndices)
{
auto max_pool = static_cast<const ngraph::op::MaxPoolWithIndices*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_forward(
mkldnn::algorithm::pooling_max,
input_desc,
result_desc,
max_pool->get_window_movement_strides(),
max_pool->get_window_shape(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index) << ");\n";
}
else
{
throw ngraph_error("MaxPoolWithIndices isn't supported");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Reverse)
{
......@@ -3107,6 +3147,46 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_backward(
mkldnn::algorithm::pooling_max,
diff_dst_desc,
diff_src_desc,
mpb->get_window_movement_strides(),
mpb->get_window_shape(),
mpb->get_padding_below(),
mpb->get_padding_above());
auto& bdeps = mkldnn_emitter->get_primitive_deps(max_pool_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[0])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[1])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index) << ");\n";
}
else
{
throw ngraph_error("MaxPoolWithIndicesBackprop isn't supported");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Product)
{
......
......@@ -119,6 +119,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
......@@ -126,6 +127,7 @@
#include "ngraph/runtime/cpu/pass/cpu_nop_elimination.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
......@@ -262,6 +264,7 @@ static const runtime::cpu::OpMap dispatcher{
&runtime::cpu::CPU_Emitter::emit<runtime::cpu::op::ConvertLayout>},
{TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::emit<op::Not>},
{TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::emit<op::MaxPool>},
{TI(ngraph::op::MaxPoolWithIndices), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndices>},
{TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::emit<op::Reverse>},
{TI(ngraph::op::ReverseSequence), &runtime::cpu::CPU_Emitter::emit<op::ReverseSequence>},
{TI(ngraph::op::Result), &runtime::cpu::CPU_Emitter::emit<op::Result>},
......@@ -274,6 +277,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::Product), &runtime::cpu::CPU_Emitter::emit<op::Product>},
{TI(ngraph::op::Max), &runtime::cpu::CPU_Emitter::emit<op::Max>},
{TI(ngraph::op::Min), &runtime::cpu::CPU_Emitter::emit<op::Min>},
......@@ -317,6 +322,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
......
......@@ -453,6 +453,86 @@ size_t MKLDNNEmitter::build_max_pooling_backward(mkldnn::algorithm pooling_algor
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_with_indices_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& src_desc,
const mkldnn::memory::desc& dst_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t src_index = build_memory_primitive(src_desc);
size_t dst_index = build_memory_primitive(dst_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
src_desc,
dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
auto ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
size_t fwd_primitive_index =
insert_primitive(new mkldnn::pooling_forward(fwd_pd,
*m_mkldnn_primitives[src_index],
*m_mkldnn_primitives[dst_index],
*m_mkldnn_primitives[ws_index]));
m_primitive_deps[fwd_primitive_index] = {src_index, dst_index, ws_index};
return fwd_primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_with_indices_backward(
mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t diff_dst_index = build_memory_primitive(diff_dst_desc);
size_t diff_src_index = build_memory_primitive(diff_src_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
auto fprop_ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
size_t bwd_primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
fwd_pd},
*m_mkldnn_primitives[diff_dst_index],
*m_mkldnn_primitives[fprop_ws_index],
*m_mkldnn_primitives[diff_src_index]));
m_primitive_deps[bwd_primitive_index] = {diff_dst_index, fprop_ws_index, diff_src_index};
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
......
......@@ -133,6 +133,14 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_with_indices_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& src_desc,
const mkldnn::memory::desc& dst_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
......@@ -142,6 +150,15 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_with_indices_backward(
mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::MaxPoolWithIndices::MaxPoolWithIndices(const shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: RequiresTensorViewArgs("MaxPoolWithIndices", {arg})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
{
auto& arg_shape = get_input_shape(0);
//
// Make sure arg: NCDi for some Di of rank>0, N != 0, C != 0.
//
if (arg_shape.size() < 3)
{
throw ngraph_error(
"Max-pool data batch input must have rank of at least 3 (one batch axis, one "
"channel axis, at least one spatial dimension).");
}
size_t batch_size = arg_shape[0];
if (batch_size == 0)
{
throw ngraph_error("Max-pool data batch size is zero.");
}
size_t channel_count = arg_shape[1];
if (channel_count == 0)
{
throw ngraph_error("Max-pool requires at least one feature channel.");
}
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
if (window_shape.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool window shape rank does not match number of spatial dimensions.");
}
if (window_movement_strides.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool window movement stride rank does not match number of spatial "
"dimensions.");
}
if (padding_below.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool below-padding rank does not match number of spatial dimensions.");
}
if (padding_above.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool above-padding rank does not match number of spatial dimensions.");
}
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = arg_shape[1 + 1 + i];
size_t virtual_dim_size = padding_below[i] + dim_size + padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
if (virtual_dim_size == 0)
{
throw ngraph_error("Max-pool input spatial dimension is zero even after padding.");
}
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_shape[i] == 0)
{
throw ngraph_error("Max-pool window shape has a zero-length axis.");
}
}
//
// Make the max pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_shape[i] > input_item_virtual_shape[i])
{
throw ngraph_error(
"Max-pool window shape is larger than the spatial dimensions even after "
"padding.");
}
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_movement_strides[i] == 0)
{
throw ngraph_error("Max-pool window axis movement stride is zero.");
}
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - window_shape[i] + 1,
window_movement_strides[i]));
}
//
// Construct result shape: NCDo.
//
Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size;
result_shape[1] = channel_count;
copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
add_output(get_input_element_type(0), result_shape);
//MKLDNN can pick one of the two following datatypes
//to store maximum indices: s32 and u8.
//For smaller kernels, where 255 positions is enough
//to span the entire kernel, u8 is picked.
//We conservatively always use s32
//to simplify MaxPoolWithIndices c-tor.
add_output(element::i32, result_shape);
}
shared_ptr<Node> op::MaxPoolWithIndices::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<MaxPoolWithIndices>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
}
op::MaxPoolWithIndicesBackprop::MaxPoolWithIndicesBackprop(const shared_ptr<Node>& arg_forward,
const shared_ptr<Node>& delta,
const shared_ptr<Node>& indices,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: RequiresTensorViewArgs("MaxPoolWithIndicesBackprop", {arg_forward, delta, indices})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
{
if (delta->get_shape() != indices->get_shape())
{
throw ngraph_error("delta shape doesn't match indices' ");
}
if (get_input_element_type(0) != get_input_element_type(1))
{
throw ngraph_error("Max-pool backprop: data batch and delta element types do not match.");
}
auto& arg_forward_shape = get_input_shape(0);
auto& delta_shape = get_input_shape(1);
//
// Make sure arg: NCDi for some Di of rank>0, N != 0, C != 0.
//
if (arg_forward_shape.size() < 3)
{
throw ngraph_error(
"Max-pool backprop: data batch shape must have rank of at least 3 (one batch axis, "
"one channel axis, at least one spatial dimension).");
}
size_t batch_size = arg_forward_shape[0];
if (batch_size == 0)
{
throw ngraph_error("Max-pool backprop: data batch size is zero.");
}
size_t channel_count = arg_forward_shape[1];
if (channel_count == 0)
{
throw ngraph_error("Max-pool backprop: requires at least one feature channel.");
}
size_t spatial_dimension_count = arg_forward_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
if (window_shape.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool backprop: window shape rank does not match number of spatial "
"dimensions.");
}
if (window_movement_strides.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool backprop: window movement stride rank does not match number of spatial "
"dimensions.");
}
if (padding_below.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool backprop: below-padding rank does not match number of spatial "
"dimensions.");
}
if (padding_above.size() != spatial_dimension_count)
{
throw ngraph_error(
"Max-pool backprop: above-padding rank does not match number of spatial "
"dimensions.");
}
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = arg_forward_shape[1 + 1 + i];
size_t virtual_dim_size = padding_below[i] + dim_size + padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
if (virtual_dim_size == 0)
{
throw ngraph_error(
"Max-pool backprop: data batch spatial dimension is zero even after padding.");
}
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_shape[i] == 0)
{
throw ngraph_error("Max-pool backprop: window shape has a zero-length axis.");
}
}
//
// Make the max pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_shape[i] > input_item_virtual_shape[i])
{
throw ngraph_error(
"Max-pool backprop: window shape is larger than the spatial dimensions even after "
"padding.");
}
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
if (window_movement_strides[i] == 0)
{
throw ngraph_error("Max-pool backprop: window axis movement stride is zero.");
}
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - window_shape[i] + 1,
window_movement_strides[i]));
}
//
// Construct result shape: NCDo.
//
Shape forward_result_shape(1 + 1 + spatial_dimension_count);
forward_result_shape[0] = batch_size;
forward_result_shape[1] = channel_count;
copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2);
if (forward_result_shape != delta_shape)
{
throw ngraph_error("Max-pool backprop: forward result shape does not match delta shape.");
}
set_value_type_checked(get_input_element_type(0), arg_forward_shape);
}
shared_ptr<Node>
op::MaxPoolWithIndicesBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
MaxPoolWithIndicesBackprop* mpbp = new MaxPoolWithIndicesBackprop(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
return shared_ptr<op::MaxPoolWithIndicesBackprop>(mpbp);
}
void op::MaxPoolWithIndices::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
throw ngraph_error("Differentation of MaxPoolWithIndices isn't supported");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
//MaxPoolWithIndices produces two outputs.
//The first output is equivalent to what MaxPool produces
//The second one contains the indices of the maximum numbers
//for each window in input (arg)
//These indices are used by MKLDNN for a back propagation pass
class MaxPoolWithIndices : public util::RequiresTensorViewArgs
{
public:
MaxPoolWithIndices(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_window_shape() const { return m_window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
};
//MaxPoolWithIndicesBackprop takes MaxPoolWithIndices' outputs and
//pass the indices directly to MKLDNN to avoid max indices recomputation
class MaxPoolWithIndicesBackprop : public util::RequiresTensorViewArgs
{
public:
MaxPoolWithIndicesBackprop(const std::shared_ptr<Node>& arg_forward,
const std::shared_ptr<Node>& delta,
const std::shared_ptr<Node>& indices,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_window_shape() const { return m_window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
protected:
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
};
}
}
......@@ -37,6 +37,7 @@
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -342,6 +343,25 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndices)
{
auto max_pool = static_cast<op::MaxPoolWithIndices*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
......@@ -361,6 +381,25 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
auto max_pool = static_cast<op::MaxPoolWithIndicesBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg1_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{
......@@ -487,8 +526,12 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolWithIndices),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndices>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
......
......@@ -43,6 +43,7 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -794,12 +795,12 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPool)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
template <typename T, prop_kind pk>
void MaxPoolLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{
auto max_pool = static_cast<const ngraph::op::MaxPool*>(node.get());
auto max_pool = static_cast<const T*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto result_shape = node->get_output_shape(0);
......@@ -818,23 +819,18 @@ namespace ngraph
memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
auto input_desc = memory::desc(mkldnn_arg0_shape, et, input_layout);
auto result_desc =
memory::desc(mkldnn_result_shape, et, memory::format::any);
auto result_desc = memory::desc(mkldnn_result_shape, et, memory::format::any);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
try
{
auto prim_desc = pooling_forward::primitive_desc(
{prop_kind::forward_inference,
{pk,
algorithm_enumerator,
input_desc,
result_desc,
......@@ -847,13 +843,29 @@ namespace ngraph
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.dst_primitive_desc().desc().data.format));
// TODO (jbobba): Add workspace layouts here
if (pk == prop_kind::forward_training)
{
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.workspace_primitive_desc().desc().data.format));
}
}
catch (const mkldnn::error& e)
{
throw ngraph_error("MKLDNN Unsupported pooling fwd layout" +
to_string(input_layout) + e.message);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPoolWithIndices)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
MaxPoolLayout<ngraph::op::MaxPoolWithIndices, prop_kind::forward_training>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
......@@ -866,14 +878,35 @@ namespace ngraph
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPoolBackprop)
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPool)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto max_pool = static_cast<const ngraph::op::MaxPoolBackprop*>(node.get());
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
MaxPoolLayout<ngraph::op::MaxPool, prop_kind::forward_inference>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <typename T, bool with_indices>
void MaxPoolBackpropLayout(std::shared_ptr<ngraph::Node> node,
vector<memory::format>& prim_input_formats,
vector<memory::format>& prim_output_formats)
{
auto max_pool = static_cast<const T*>(node.get());
// arg 0 - fprop input
// arg 0 - work input
// arg 1 - delta
// arg 2 - work space
// Propagate fprop's input layout
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
......@@ -894,21 +927,15 @@ namespace ngraph
memory::dims mkldnn_filter_shape(filter_shape.begin(), filter_shape.end());
memory::dims mkldnn_filter_strides(filter_strides.begin(),
filter_strides.end());
memory::dims mkldnn_padding_below(padding_below.begin(),
padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(),
padding_above.end());
memory::dims mkldnn_padding_below(padding_below.begin(), padding_below.end());
memory::dims mkldnn_padding_above(padding_above.begin(), padding_above.end());
auto fprop_input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
auto diff_dst_desc =
memory::desc(mkldnn_arg1_shape, et, fprop_input_layout);
auto diff_src_desc =
memory::desc(mkldnn_arg0_shape, et, memory::format::any);
auto diff_dst_desc = memory::desc(mkldnn_arg1_shape, et, fprop_input_layout);
auto diff_src_desc = memory::desc(mkldnn_arg0_shape, et, memory::format::any);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
try
{
auto fwd_prim_desc = pooling_forward::primitive_desc(
......@@ -936,6 +963,13 @@ namespace ngraph
fwd_prim_desc);
prim_input_formats.push_back(fprop_input_layout);
prim_input_formats.push_back(fprop_input_layout);
if (with_indices)
{
prim_input_formats.push_back(static_cast<memory::format>(
fwd_prim_desc.workspace_primitive_desc().desc().data.format));
}
prim_output_formats.push_back(static_cast<memory::format>(
prim_desc.diff_src_primitive_desc().desc().data.format));
}
......@@ -944,6 +978,37 @@ namespace ngraph
throw ngraph_error("MKLDNN Unsupported pooling layout" +
to_string(fprop_input_layout) + e.message);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPoolBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
MaxPoolBackpropLayout<ngraph::op::MaxPoolBackprop, false>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
MaxPoolBackpropLayout<ngraph::op::MaxPoolWithIndicesBackprop, true>(
node, prim_input_formats, prim_output_formats);
node =
insert_input_conversions(external_function, node, prim_input_formats);
......@@ -1303,8 +1368,12 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolWithIndices),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolWithIndices>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionRelu),
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu_workspace_insertion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_indices()
{
Shape shape_data{1, 1, 14};
auto data = std::make_shared<pattern::op::Label>(element::f32, shape_data);
Shape window_shape{3};
auto max_pool = std::make_shared<op::MaxPool>(data, window_shape);
auto delta = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
auto max_pool_bprop =
std::make_shared<op::MaxPoolBackprop>(data,
delta,
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
pattern::graph_rewrite_callback callback = [data, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_max_pool_bprop = std::dynamic_pointer_cast<op::MaxPoolBackprop>(m.get_match_root());
if (m_max_pool_bprop->get_shape().size() != 4 ||
m_max_pool_bprop->get_window_shape().size() != 2 ||
m_max_pool_bprop->get_input_element_type(0) != element::f32)
{
NGRAPH_DEBUG << "MKLDNN doesn't support inputs of given shape type";
return false;
}
//find the original MaxPool now
std::shared_ptr<op::MaxPool> m_max_pool;
for (auto u : pattern_map[data]->get_users())
{
if (auto mp = std::dynamic_pointer_cast<op::MaxPool>(u))
{
if (mp->get_window_shape() == m_max_pool_bprop->get_window_shape() &&
mp->get_window_movement_strides() ==
m_max_pool_bprop->get_window_movement_strides() &&
mp->get_padding_below() == m_max_pool_bprop->get_padding_below() &&
mp->get_padding_above() == m_max_pool_bprop->get_padding_above())
{
m_max_pool = mp;
break;
}
}
}
if (!m_max_pool)
{
NGRAPH_DEBUG << "MaxPool for " << pattern_map[data]->get_name() << " and "
<< m_max_pool_bprop->get_name() << " not found";
}
auto max_pool_with_indices =
std::make_shared<op::MaxPoolWithIndices>(pattern_map[data],
m_max_pool->get_window_shape(),
m_max_pool->get_window_movement_strides(),
m_max_pool->get_padding_below(),
m_max_pool->get_padding_above());
auto max_pool_with_indices_output =
std::make_shared<op::GetOutputElement>(max_pool_with_indices, 0);
auto max_pool_with_indices_indices =
std::make_shared<op::GetOutputElement>(max_pool_with_indices, 1);
//rewire users to use a new MaxPoolWithIndices (maxpool's output)
for (auto& o : m_max_pool->get_outputs())
{
std::set<ngraph::descriptor::Input*> copy{begin(o.get_inputs()), end(o.get_inputs())};
for (auto i : copy)
{
i->replace_output(max_pool_with_indices_output->get_outputs().at(0));
}
}
//create a new max_pool_with_indices_bprop
auto max_pool_with_indices_bprop = std::make_shared<op::MaxPoolWithIndicesBackprop>(
pattern_map[data],
pattern_map[delta],
max_pool_with_indices_indices,
m_max_pool->get_window_shape(),
m_max_pool->get_window_movement_strides(),
m_max_pool->get_padding_below(),
m_max_pool->get_padding_above());
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
return true;
};
auto m = std::make_shared<pattern::Matcher>(max_pool_bprop, callback);
this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUWorkspaceInsertion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::GraphRewrite
{
public:
CPUWorkspaceInsertion()
: GraphRewrite()
{
construct_max_pool_with_indices();
}
private:
void construct_max_pool_with_indices();
};
......@@ -28,6 +28,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
......@@ -48,6 +49,7 @@
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -1125,3 +1127,99 @@ TEST(cpu_fusion, weight_fusion)
new_convert_layout->get_argument(0)),
cvt_lt_conv);
}
TEST(cpu_fusion, max_pool_with_indices)
{
Shape shape_a{10, 3, 28, 28};
auto input = std::make_shared<op::Parameter>(element::f32, shape_a);
Shape window_shape{2, 2};
auto max_pool = std::make_shared<op::MaxPool>(input, window_shape);
auto C = std::make_shared<op::Parameter>(element::f32, max_pool->get_shape());
ngraph::autodiff::Adjoints adjoints(NodeVector{max_pool}, NodeVector{C});
auto dinput = adjoints.backprop_node(input);
auto df = std::make_shared<Function>(NodeVector{dinput}, op::ParameterVector{input, C});
auto f = std::make_shared<Function>(NodeVector{max_pool}, op::ParameterVector{input});
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_fprop_before.pdf");
pass_manager.run_passes(f);
}
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after.pdf");
pass_manager.run_passes(df);
}
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_fprop_after.pdf");
pass_manager.run_passes(f);
}
auto maxpool_goe_output =
std::dynamic_pointer_cast<op::GetOutputElement>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(maxpool_goe_output);
ASSERT_EQ(maxpool_goe_output->get_n(), 0);
auto maxpool_with_indices = df->get_results().at(0)->get_argument(0);
auto maxpool_goe_indices =
std::dynamic_pointer_cast<op::GetOutputElement>(maxpool_with_indices->get_argument(2));
ASSERT_TRUE(maxpool_goe_indices);
ASSERT_EQ(maxpool_goe_indices->get_n(), 1);
}
TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
{
Shape shape_a{1, 4, 4, 4};
Shape maxpool_shape{1, 4, 3, 3};
auto A = std::make_shared<op::Parameter>(element::f32, shape_a);
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
auto maxpool = std::make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
auto f = std::make_shared<Function>(maxpool, op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> ep = backend->create_tensor(element::f32, maxpool_shape);
vector<float> dataEp(shape_size(maxpool_shape), 4);
shared_ptr<runtime::TensorView> input = backend->create_tensor(element::f32, shape_a);
shared_ptr<runtime::TensorView> output = backend->create_tensor(element::f32, shape_a);
vector<float> dataInput{11.f, 31.f, 40.f, 47.f, 13.f, 61.f, 48.f, 59.f, 17.f, 39.f, 64.f,
62.f, 45.f, 55.f, 36.f, 19.f, 65.f, 33.f, 49.f, 30.f, 56.f, 41.f,
53.f, 58.f, 22.f, 35.f, 52.f, 50.f, 63.f, 54.f, 12.f, 26.f, 44.f,
21.f, 69.f, 24.f, 46.f, 25.f, 51.f, 29.f, 72.f, 15.f, 73.f, 10.f,
16.f, 37.f, 70.f, 32.f, 28.f, 66.f, 57.f, 27.f, 60.f, 42.f, 43.f,
71.f, 18.f, 38.f, 67.f, 68.f, 14.f, 20.f, 34.f, 23.f};
vector<float> expected{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 12.0f, 0.0f, 4.0f, 0.0f, 0.0f, 16.0f,
0.0f, 0.0f, 4.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f, 0.0f, 4.0f, 0.0f,
8.0f, 8.0f, 0.0f, 0.0f, 4.0f, 0.0f, 4.0f, 4.0f, 0.0f, 0.0f, 0.0f,
0.0f, 8.0f, 0.0f, 4.0f, 0.0f, 0.0f, 0.0f, 8.0f, 0.0f, 16.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 8.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f,
8.0f, 0.0f, 4.0f, 8.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f};
copy_data(ep, dataEp);
copy_data(input, dataInput);
auto C = std::make_shared<op::Parameter>(element::f32, maxpool_shape);
auto df = autodiff::backprop_function(f);
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before2.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after2.pdf");
pass_manager.run_passes(df);
}
backend->call(df, {output}, {input, ep});
ASSERT_TRUE(read_vector<float>(output) == expected);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment