Commit a174c8c9 authored by Nick Korovaiko, committed by Scott Cyphers

MaxPoolWithIndices (#900)

* MaxPoolWithIndices CPU Fusion

* fix test to pass checks in cpu_fusion

* pass test

* clean up

* add a new pass, add layouts

* remove the opt from cpu_fusion

* refactor cpu_layout logic for maxpool, clean up comments

* add comment w.r.t. indices tensor

* rename to cpu_workspace_insertion

* add CPUWorkspaceInsertion pass for TF
Parent commit: 23913010
@@ -217,9 +217,11 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/matmul_bias.cpp
runtime/cpu/op/max_pool_with_indices.cpp
runtime/cpu/op/batch_norm_relu.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_workspace_insertion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_nop_elimination.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
...
@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
@@ -2768,6 +2769,45 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolWithIndices)
{
auto max_pool = static_cast<const ngraph::op::MaxPoolWithIndices*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_forward(
mkldnn::algorithm::pooling_max,
input_desc,
result_desc,
max_pool->get_window_movement_strides(),
max_pool->get_window_shape(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
auto& deps = mkldnn_emitter->get_primitive_deps(max_pool_index);
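// deps = {src, dst, workspace}: the workspace holds the per-window max
// indices and is surfaced as the op's second output (out[1]).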
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index) << ");\n";
}
else
{
throw ngraph_error("MaxPoolWithIndices isn't supported");
}
}
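For reference, the writer calls above emit code along these lines into the compiled function; the primitive indices and tensor names below are purely illustrative, since both are assigned during codegen:

// Hypothetical output of the emitter above for one MaxPoolWithIndices node.
cpu::mkldnn_utils::set_memory_ptr(ctx, 11, arg0);    // deps[0]: input tensor
cpu::mkldnn_utils::set_memory_ptr(ctx, 12, out0);    // deps[1]: pooled maxima
cpu::mkldnn_utils::set_memory_ptr(ctx, 13, out1);    // deps[2]: indices workspace
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, 42); // max_pool_index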
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Reverse)
{
@@ -3107,6 +3147,46 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node);
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto diff_dst_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto diff_src_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t max_pool_index = mkldnn_emitter->build_max_pooling_with_indices_backward(
mkldnn::algorithm::pooling_max,
diff_dst_desc,
diff_src_desc,
mpb->get_window_movement_strides(),
mpb->get_window_shape(),
mpb->get_padding_below(),
mpb->get_padding_above());
auto& bdeps = mkldnn_emitter->get_primitive_deps(max_pool_index);
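// bdeps = {diff_dst, fprop workspace, diff_src}: args[1] is delta and
// args[2] is the indices tensor produced by the forward MaxPoolWithIndices.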
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[0])
<< ", " << args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[1])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(bdeps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(max_pool_index) << ");\n";
}
else
{
throw ngraph_error("MaxPoolWithIndicesBackprop isn't supported");
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Product)
{
...
@@ -119,6 +119,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
@@ -126,6 +127,7 @@
#include "ngraph/runtime/cpu/pass/cpu_nop_elimination.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
@@ -262,6 +264,7 @@ static const runtime::cpu::OpMap dispatcher{
&runtime::cpu::CPU_Emitter::emit<runtime::cpu::op::ConvertLayout>},
{TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::emit<op::Not>},
{TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::emit<op::MaxPool>},
{TI(ngraph::op::MaxPoolWithIndices), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndices>},
{TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::emit<op::Reverse>},
{TI(ngraph::op::ReverseSequence), &runtime::cpu::CPU_Emitter::emit<op::ReverseSequence>},
{TI(ngraph::op::Result), &runtime::cpu::CPU_Emitter::emit<op::Result>},
@@ -274,6 +277,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::Product), &runtime::cpu::CPU_Emitter::emit<op::Product>},
{TI(ngraph::op::Max), &runtime::cpu::CPU_Emitter::emit<op::Max>},
{TI(ngraph::op::Min), &runtime::cpu::CPU_Emitter::emit<op::Min>},
@@ -317,6 +322,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
...
@@ -453,6 +453,86 @@ size_t MKLDNNEmitter::build_max_pooling_backward(mkldnn::algorithm pooling_algor
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_with_indices_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& src_desc,
const mkldnn::memory::desc& dst_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t src_index = build_memory_primitive(src_desc);
size_t dst_index = build_memory_primitive(dst_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
src_desc,
dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
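// A forward_training pooling pd exposes a workspace in which MKLDNN records
// the per-window max indices; it backs this op's second output.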
auto ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
size_t fwd_primitive_index =
insert_primitive(new mkldnn::pooling_forward(fwd_pd,
*m_mkldnn_primitives[src_index],
*m_mkldnn_primitives[dst_index],
*m_mkldnn_primitives[ws_index]));
m_primitive_deps[fwd_primitive_index] = {src_index, dst_index, ws_index};
return fwd_primitive_index;
}
size_t MKLDNNEmitter::build_max_pooling_with_indices_backward(
mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above)
{
size_t diff_dst_index = build_memory_primitive(diff_dst_desc);
size_t diff_src_index = build_memory_primitive(diff_src_desc);
mkldnn::pooling_forward::primitive_desc fwd_pd{
{mkldnn::prop_kind::forward_training,
pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine};
auto fprop_ws_index = build_memory_primitive(fwd_pd.workspace_primitive_desc().desc());
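// Constructing pooling_backward requires the forward pd as a hint;
// the fprop workspace supplies the indices at execution time.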
size_t bwd_primitive_index = insert_primitive(new mkldnn::pooling_backward(
{{pooling_algorithm,
diff_src_desc,
diff_dst_desc,
mkldnn::memory::dims(window_strides.begin(), window_strides.end()),
mkldnn::memory::dims(window_shape.begin(), window_shape.end()),
mkldnn::memory::dims(padding_below.begin(), padding_below.end()),
mkldnn::memory::dims(padding_above.begin(), padding_above.end()),
mkldnn::padding_kind::zero},
mkldnn_utils::global_cpu_engine,
fwd_pd},
*m_mkldnn_primitives[diff_dst_index],
*m_mkldnn_primitives[fprop_ws_index],
*m_mkldnn_primitives[diff_src_index]));
m_primitive_deps[bwd_primitive_index] = {diff_dst_index, fprop_ws_index, diff_src_index};
return bwd_primitive_index;
}
size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
...
@@ -133,6 +133,14 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_with_indices_forward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& src_desc,
const mkldnn::memory::desc& dst_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_backward(mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& fprop_src_desc,
const mkldnn::memory::desc& diff_dst_desc,
@@ -142,6 +150,15 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_max_pooling_with_indices_backward(
mkldnn::algorithm pooling_algorithm,
const mkldnn::memory::desc& diff_dst_desc,
const mkldnn::memory::desc& diff_src_desc,
const ngraph::Strides& window_strides,
const ngraph::Shape& window_shape,
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
...
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
// MaxPoolWithIndices produces two outputs:
// the first is equivalent to what MaxPool produces, and
// the second contains the index of the maximum value
// within each window of the input (arg).
// MKLDNN uses these indices during the back-propagation pass.
class MaxPoolWithIndices : public util::RequiresTensorViewArgs
{
public:
MaxPoolWithIndices(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_window_shape() const { return m_window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
};
// MaxPoolWithIndicesBackprop takes MaxPoolWithIndices' outputs and
// passes the indices directly to MKLDNN, avoiding recomputation of the max indices.
class MaxPoolWithIndicesBackprop : public util::RequiresTensorViewArgs
{
public:
MaxPoolWithIndicesBackprop(const std::shared_ptr<Node>& arg_forward,
const std::shared_ptr<Node>& delta,
const std::shared_ptr<Node>& indices,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_window_shape() const { return m_window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
protected:
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
};
}
}
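Because MaxPoolWithIndices is multi-output, consumers access its results through op::GetOutputElement. A minimal sketch (shapes and variable names are illustrative; this mirrors how the CPUWorkspaceInsertion pass below wires the op):

auto data = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 8, 8});
auto mpwi = std::make_shared<op::MaxPoolWithIndices>(
    data, Shape{2, 2}, Strides{2, 2}, Shape{0, 0}, Shape{0, 0});
// Output 0: the pooled maxima, identical to what op::MaxPool would produce.
auto values = std::make_shared<op::GetOutputElement>(mpwi, 0);
// Output 1: the per-window max indices, consumed by MaxPoolWithIndicesBackprop.
auto indices = std::make_shared<op::GetOutputElement>(mpwi, 1);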
@@ -37,6 +37,7 @@
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
@@ -342,6 +343,25 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndices)
{
auto max_pool = static_cast<op::MaxPoolWithIndices*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
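// MKLDNN's 2D pooling kernel handles 4D (NCHW) f32 inputs with a
// 2D window, hence the checks below.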
if (arg0_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
@@ -361,6 +381,25 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
auto max_pool = static_cast<op::MaxPoolWithIndicesBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg1_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{
@@ -487,8 +526,12 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolWithIndices),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndices>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::MaxPoolWithIndicesBackprop>},
{TI(ngraph::op::ConvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
...
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu_workspace_insertion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
void ngraph::runtime::cpu::pass::CPUWorkspaceInsertion::construct_max_pool_with_indices()
{
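// The label shapes below are placeholders: the matcher matches graph
// structure, and the callback filters for the shapes and types MKLDNN
// supports (the tests run this pass on 4D data with a 2D window).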
Shape shape_data{1, 1, 14};
auto data = std::make_shared<pattern::op::Label>(element::f32, shape_data);
Shape window_shape{3};
auto max_pool = std::make_shared<op::MaxPool>(data, window_shape);
auto delta = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
auto max_pool_bprop =
std::make_shared<op::MaxPoolBackprop>(data,
delta,
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
pattern::graph_rewrite_callback callback = [data, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_max_pool_bprop = std::dynamic_pointer_cast<op::MaxPoolBackprop>(m.get_match_root());
if (m_max_pool_bprop->get_shape().size() != 4 ||
m_max_pool_bprop->get_window_shape().size() != 2 ||
m_max_pool_bprop->get_input_element_type(0) != element::f32)
{
NGRAPH_DEBUG << "MKLDNN doesn't support inputs of given shape type";
return false;
}
// find the original MaxPool now
std::shared_ptr<op::MaxPool> m_max_pool;
for (auto u : pattern_map[data]->get_users())
{
if (auto mp = std::dynamic_pointer_cast<op::MaxPool>(u))
{
if (mp->get_window_shape() == m_max_pool_bprop->get_window_shape() &&
mp->get_window_movement_strides() ==
m_max_pool_bprop->get_window_movement_strides() &&
mp->get_padding_below() == m_max_pool_bprop->get_padding_below() &&
mp->get_padding_above() == m_max_pool_bprop->get_padding_above())
{
m_max_pool = mp;
break;
}
}
}
if (!m_max_pool)
{
NGRAPH_DEBUG << "MaxPool for " << pattern_map[data]->get_name() << " and "
<< m_max_pool_bprop->get_name() << " not found";
return false;
}
auto max_pool_with_indices =
std::make_shared<op::MaxPoolWithIndices>(pattern_map[data],
m_max_pool->get_window_shape(),
m_max_pool->get_window_movement_strides(),
m_max_pool->get_padding_below(),
m_max_pool->get_padding_above());
auto max_pool_with_indices_output =
std::make_shared<op::GetOutputElement>(max_pool_with_indices, 0);
auto max_pool_with_indices_indices =
std::make_shared<op::GetOutputElement>(max_pool_with_indices, 1);
// rewire users to use a new MaxPoolWithIndices (maxpool's output)
for (auto& o : m_max_pool->get_outputs())
{
std::set<ngraph::descriptor::Input*> copy{begin(o.get_inputs()), end(o.get_inputs())};
for (auto i : copy)
{
i->replace_output(max_pool_with_indices_output->get_outputs().at(0));
}
}
// create a new max_pool_with_indices_bprop
auto max_pool_with_indices_bprop = std::make_shared<op::MaxPoolWithIndicesBackprop>(
pattern_map[data],
pattern_map[delta],
max_pool_with_indices_indices,
m_max_pool->get_window_shape(),
m_max_pool->get_window_movement_strides(),
m_max_pool->get_padding_below(),
m_max_pool->get_padding_above());
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
return true;
};
auto m = std::make_shared<pattern::Matcher>(max_pool_bprop, callback);
this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUWorkspaceInsertion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::GraphRewrite
{
public:
CPUWorkspaceInsertion()
: GraphRewrite()
{
construct_max_pool_with_indices();
}
private:
void construct_max_pool_with_indices();
};
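A minimal usage sketch, assuming a backprop function df (this mirrors both the registration in CPU_ExternalFunction::compile above and the tests below; the helper name run_workspace_insertion is hypothetical):

#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"

// Rewrites each matching MaxPool/MaxPoolBackprop pair into
// MaxPoolWithIndices/MaxPoolWithIndicesBackprop so MKLDNN can reuse the
// forward indices instead of recomputing them in the backward pass.
void run_workspace_insertion(const std::shared_ptr<ngraph::Function>& df)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::runtime::cpu::pass::CPUWorkspaceInsertion>();
    pass_manager.run_passes(df);
}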
@@ -28,6 +28,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
@@ -48,6 +49,7 @@
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
@@ -1125,3 +1127,99 @@ TEST(cpu_fusion, weight_fusion)
new_convert_layout->get_argument(0)),
cvt_lt_conv);
}
TEST(cpu_fusion, max_pool_with_indices)
{
Shape shape_a{10, 3, 28, 28};
auto input = std::make_shared<op::Parameter>(element::f32, shape_a);
Shape window_shape{2, 2};
auto max_pool = std::make_shared<op::MaxPool>(input, window_shape);
auto C = std::make_shared<op::Parameter>(element::f32, max_pool->get_shape());
ngraph::autodiff::Adjoints adjoints(NodeVector{max_pool}, NodeVector{C});
auto dinput = adjoints.backprop_node(input);
auto df = std::make_shared<Function>(NodeVector{dinput}, op::ParameterVector{input, C});
auto f = std::make_shared<Function>(NodeVector{max_pool}, op::ParameterVector{input});
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_fprop_before.pdf");
pass_manager.run_passes(f);
}
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after.pdf");
pass_manager.run_passes(df);
}
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_fprop_after.pdf");
pass_manager.run_passes(f);
}
auto maxpool_goe_output =
std::dynamic_pointer_cast<op::GetOutputElement>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(maxpool_goe_output);
ASSERT_EQ(maxpool_goe_output->get_n(), 0);
auto maxpool_with_indices = df->get_results().at(0)->get_argument(0);
auto maxpool_goe_indices =
std::dynamic_pointer_cast<op::GetOutputElement>(maxpool_with_indices->get_argument(2));
ASSERT_TRUE(maxpool_goe_indices);
ASSERT_EQ(maxpool_goe_indices->get_n(), 1);
}
TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
{
Shape shape_a{1, 4, 4, 4};
Shape maxpool_shape{1, 4, 3, 3};
auto A = std::make_shared<op::Parameter>(element::f32, shape_a);
Shape window_shape{2, 2};
auto window_movement_strides = Strides{1, 1};
auto maxpool = std::make_shared<op::MaxPool>(A, window_shape, window_movement_strides);
auto f = std::make_shared<Function>(maxpool, op::ParameterVector{A});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> ep = backend->create_tensor(element::f32, maxpool_shape);
vector<float> dataEp(shape_size(maxpool_shape), 4);
shared_ptr<runtime::TensorView> input = backend->create_tensor(element::f32, shape_a);
shared_ptr<runtime::TensorView> output = backend->create_tensor(element::f32, shape_a);
vector<float> dataInput{11.f, 31.f, 40.f, 47.f, 13.f, 61.f, 48.f, 59.f, 17.f, 39.f, 64.f,
62.f, 45.f, 55.f, 36.f, 19.f, 65.f, 33.f, 49.f, 30.f, 56.f, 41.f,
53.f, 58.f, 22.f, 35.f, 52.f, 50.f, 63.f, 54.f, 12.f, 26.f, 44.f,
21.f, 69.f, 24.f, 46.f, 25.f, 51.f, 29.f, 72.f, 15.f, 73.f, 10.f,
16.f, 37.f, 70.f, 32.f, 28.f, 66.f, 57.f, 27.f, 60.f, 42.f, 43.f,
71.f, 18.f, 38.f, 67.f, 68.f, 14.f, 20.f, 34.f, 23.f};
vector<float> expected{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 12.0f, 0.0f, 4.0f, 0.0f, 0.0f, 16.0f,
0.0f, 0.0f, 4.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f, 0.0f, 4.0f, 0.0f,
8.0f, 8.0f, 0.0f, 0.0f, 4.0f, 0.0f, 4.0f, 4.0f, 0.0f, 0.0f, 0.0f,
0.0f, 8.0f, 0.0f, 4.0f, 0.0f, 0.0f, 0.0f, 8.0f, 0.0f, 16.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 8.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f,
8.0f, 0.0f, 4.0f, 8.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f};
copy_data(ep, dataEp);
copy_data(input, dataInput);
auto C = std::make_shared<op::Parameter>(element::f32, maxpool_shape);
auto df = autodiff::backprop_function(f);
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_before2.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>();
pass_manager.register_pass<pass::VisualizeTree>("max_pool_bprop_after2.pdf");
pass_manager.run_passes(df);
}
backend->call(df, {output}, {input, ep});
ASSERT_TRUE(read_vector<float>(output) == expected);
}