Commit 8de62221 authored by nikolay.korovaiko's avatar nikolay.korovaiko

sink pad and slice

parent 20bd8bbc
...@@ -29,10 +29,13 @@ ...@@ -29,10 +29,13 @@
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp" #include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -41,11 +44,30 @@ extern template ngraph::AxisVector ...@@ -41,11 +44,30 @@ extern template ngraph::AxisVector
ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input, ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input,
ngraph::AxisVector order); ngraph::AxisVector order);
extern template ngraph::Coordinate
ngraph::apply_permutation<ngraph::Coordinate>(ngraph::Coordinate input,
ngraph::AxisVector order);
extern template ngraph::Strides
ngraph::apply_permutation<ngraph::Strides>(ngraph::Strides input, ngraph::AxisVector order);
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input, extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order); ngraph::AxisVector order);
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>; using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>;
static std::string describe_reshape(std::shared_ptr<Node> node)
{
std::stringstream ss;
auto reshape = std::dynamic_pointer_cast<op::Reshape>(node);
ss << reshape->get_name()
<< " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
<< " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
<< " , child = " << reshape->get_argument(0)->get_name();
return ss.str();
}
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1, static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2) std::shared_ptr<op::Reshape> r2)
{ {
...@@ -64,18 +86,6 @@ static void ...@@ -64,18 +86,6 @@ static void
target->get_inputs().at(input_index).replace_output(new_reshape->get_outputs().at(0)); target->get_inputs().at(input_index).replace_output(new_reshape->get_outputs().at(0));
} }
std::string describe_reshape(std::shared_ptr<Node> node)
{
std::stringstream ss;
auto reshape = std::dynamic_pointer_cast<op::Reshape>(node);
ss << reshape->get_name()
<< " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
<< " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
<< " , child = " << reshape->get_argument(0)->get_name();
return ss.str();
}
static void delete_reshape(std::shared_ptr<Node> reshape) static void delete_reshape(std::shared_ptr<Node> reshape)
{ {
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name(); NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
...@@ -256,6 +266,7 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape, ...@@ -256,6 +266,7 @@ static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete); mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
//replace reshape with combined one //replace reshape with combined one
ngraph::replace_node(reshape, new_reshape); ngraph::replace_node(reshape, new_reshape);
mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
reorders[new_reshape] = new_reshape; reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and" NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape); << describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
...@@ -309,6 +320,61 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b ...@@ -309,6 +320,61 @@ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> b
} }
} }
static void sink_slice(std::shared_ptr<op::Slice> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new slice will have the right shape
auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_lower_bounds(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_upper_bounds(), def_order);
auto new_strides = ngraph::apply_permutation(n->get_strides(), def_order);
auto new_slice =
std::make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
ngraph::replace_node(n, new_slice);
auto new_reshape = std::make_shared<op::Reshape>(new_slice, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_slice] = new_reshape;
}
static void sink_pad(std::shared_ptr<op::Pad> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto order = arg_reshape->get_input_order();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new pad will have the right shape
auto def_order = ngraph::get_permutation_to_default_order(order);
auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
auto dummy_correct_shape =
std::make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);
auto new_lower = ngraph::apply_permutation(n->get_padding_below(), def_order);
auto new_upper = ngraph::apply_permutation(n->get_padding_above(), def_order);
auto new_interior = ngraph::apply_permutation(n->get_padding_interior(), def_order);
auto new_pad = std::make_shared<op::Pad>(
dummy_correct_shape, n->get_argument(1), new_lower, new_upper, new_interior);
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
ngraph::replace_node(n, new_pad);
auto new_reshape = std::make_shared<op::Reshape>(new_pad, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_pad] = new_reshape;
}
static void sink_quantize(std::shared_ptr<op::Quantize> quantize, static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders, ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete) std::set<std::shared_ptr<Node>>& reshapes_to_delete)
...@@ -419,6 +485,14 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -419,6 +485,14 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
{ {
sink_dequantize(dequantize, reorders, reshapes_to_delete); sink_dequantize(dequantize, reorders, reshapes_to_delete);
} }
else if (auto slice = std::dynamic_pointer_cast<op::Slice>(n))
{
sink_slice(slice, reorders, reshapes_to_delete);
}
else if (auto pad = std::dynamic_pointer_cast<op::Pad>(n))
{
sink_pad(pad, reorders, reshapes_to_delete);
}
else else
{ {
materialize_shapes(n, reorders, reshapes_to_delete); materialize_shapes(n, reorders, reshapes_to_delete);
......
...@@ -478,6 +478,10 @@ T ngraph::apply_permutation(T input, AxisVector order) ...@@ -478,6 +478,10 @@ T ngraph::apply_permutation(T input, AxisVector order)
template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input, AxisVector order); template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input, AxisVector order);
template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order); template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
template ngraph::Coordinate ngraph::apply_permutation<ngraph::Coordinate>(ngraph::Coordinate input,
ngraph::AxisVector order);
template ngraph::Strides ngraph::apply_permutation<ngraph::Strides>(ngraph::Strides input,
ngraph::AxisVector order);
AxisVector ngraph::get_default_order(const Shape& shape) AxisVector ngraph::get_default_order(const Shape& shape)
{ {
......
...@@ -163,3 +163,57 @@ TEST(reshape_sinking, nasnet_pooladd) ...@@ -163,3 +163,57 @@ TEST(reshape_sinking, nasnet_pooladd)
size_t before_after = count_ops_of_type<op::Reshape>(func); size_t before_after = count_ops_of_type<op::Reshape>(func);
ASSERT_LE(before_after, before_count); ASSERT_LE(before_after, before_count);
} }
TEST(reshape_sinking, slice_pad)
{
Shape shape_a{1, 8, 8, 1};
AxisVector to_nhwc{0, 2, 3, 1};
AxisVector to_nchw{0, 3, 1, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto pad_value = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{0.0f});
Shape padding_below{0, 0, 0, 0};
Shape padding_above{0, 1, 1, 0};
Shape padding_interior{0, 0, 0, 0};
auto reshape1 = make_shared<op::Reshape>(A, to_nchw, Shape{1, 1, 8, 8});
auto maxpool =
make_shared<op::MaxPool>(reshape1, Shape{1, 1}, Strides{2, 2}, Shape{0, 0}, Shape{0, 0});
auto reshape2 = make_shared<op::Reshape>(maxpool, to_nhwc, Shape{1, 4, 4, 1});
auto pad =
make_shared<op::Pad>(reshape2, pad_value, padding_below, padding_above, padding_interior);
auto slice = make_shared<op::Slice>(
pad, Coordinate{0, 1, 1, 0}, Coordinate{1, 5, 5, 1}, Strides{1, 1, 1, 1});
auto reshape3 = make_shared<op::Reshape>(slice, to_nchw, Shape{1, 1, 4, 4});
auto avgpool = make_shared<op::AvgPool>(reshape3, Shape{1, 1}, Strides{2, 2});
auto reshape4 = make_shared<op::Reshape>(avgpool, to_nhwc, Shape{1, 1, 2, 2});
auto f = make_shared<Function>(reshape4, ParameterVector{A});
// auto reshape1 = make_shared<op::Reshape>(A, to_nchw, Shape{1, 1, 4, 4});
// auto pad = make_shared<op::Pad>(reshape1, pad_value, padding_below, padding_above, padding_interior);
// auto reshape2 = make_shared<op::Reshape>(pad, to_nhwc, Shape{1, 5, 5, 1});
// auto absn = make_shared<op::Abs>(reshape2);
// auto reshape3 = make_shared<op::Reshape>(absn, to_nchw, Shape{1, 1, 5, 5});
// auto slice = make_shared<op::Slice>(
// reshape3, Coordinate{0, 0, 1, 1}, Coordinate{1, 1, 5, 5}, Strides{1, 1, 1, 1});
// auto reshape4 = make_shared<op::Reshape>(slice, to_nhwc, Shape{1, 4, 4, 1});
// auto absn2 = make_shared<op::Abs>(reshape4);
// auto reshape5 = make_shared<op::Reshape>(absn2, to_nchw, Shape{1, 1, 4, 4});
// auto avgpool = make_shared<op::AvgPool>(reshape5, Shape{1, 1}, Strides{2, 2});
// auto reshape6 = make_shared<op::Reshape>(avgpool, to_nhwc, Shape{1, 1, 2, 2});
// auto f = make_shared<Function>(reshape6, ParameterVector{A});
pass::Manager pass_manager;
size_t before_count = count_ops_of_type<op::Reshape>(f);
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::ReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
size_t before_after = count_ops_of_type<op::Reshape>(f);
ASSERT_LE(before_after, before_count);
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment