Commit b9017681 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Extend concat elimination to fold slice + reshape (#1511)

* extend concat elimination to fold slice + reshape

* relax branch tip to be anything not just goe

* add support for transpose when concat of slice + reshape occurred on internal axis.

* simplify reshape order permutation

* multi-axis slice + concat do not cancel and are now disabled. generalize detection of axis reordering when intermediate reshape is present
to include logical reshape and reshape that results in axis reordering for the cases: parent_shape.size gt, eq, and lt concat_shape.size.

* check that slices are in order

* add one user check on reshape

* add more checks

* fix warnings

* Reshape axis order did not include enough dimensions when the transposed and reshaped result was of lower rank.
parent 0aaae2bb
......@@ -43,6 +43,8 @@ using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order);
template <typename T>
static std::shared_ptr<pattern::Matcher>
create_binary_matcher(std::shared_ptr<pattern::op::Label> label,
......@@ -77,18 +79,23 @@ static bool simplify_concat(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
std::shared_ptr<Node> goe;
std::shared_ptr<Node> branch_tip;
auto lgoe = std::make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto ltip = std::make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto slice =
std::make_shared<op::Slice>(lgoe, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto pslice =
std::make_shared<op::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto lslice = std::make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
auto skip_reshape =
std::make_shared<pattern::op::Skip>(slice, pattern::has_class<op::Reshape>());
std::make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr);
Coordinate prev_lower_bounds;
Shape prev_slice_shape;
for (auto carg : n->get_arguments())
{
if (!matcher->match(carg))
......@@ -97,53 +104,140 @@ static bool simplify_concat(std::shared_ptr<Node> n)
return false;
}
if (goe)
auto slice = std::dynamic_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]);
if (branch_tip)
{
if (goe != matcher->get_pattern_map()[lgoe])
if (branch_tip != matcher->get_pattern_map()[ltip])
{
NGRAPH_DEBUG << branch_tip->get_name() << " doesn't match "
<< matcher->get_pattern_map()[ltip]->get_name();
return false;
}
//slice chunks should be slice in the same order as slice nodes in concat's argument list
auto cur_lower_bounds = slice->get_lower_bounds();
if (cur_lower_bounds < prev_lower_bounds)
{
NGRAPH_DEBUG << goe->get_name() << " doesn't match "
<< matcher->get_pattern_map()[lgoe]->get_name();
NGRAPH_DEBUG << slice->get_name() << " is in the wrong order";
return false;
}
prev_lower_bounds.assign(cur_lower_bounds.begin(), cur_lower_bounds.end());
//slice shapes need to match
if (slice->get_shape() != prev_slice_shape)
{
NGRAPH_DEBUG << slice->get_name()
<< " doesn't match the shape of the previous slice";
return false;
}
}
else
{
goe = matcher->get_pattern_map()[lgoe];
NGRAPH_DEBUG << "setting goe to " << goe->get_name();
branch_tip = matcher->get_pattern_map()[ltip];
prev_lower_bounds.assign(slice->get_lower_bounds().begin(),
slice->get_lower_bounds().end());
prev_slice_shape.assign(slice->get_shape().begin(), slice->get_shape().end());
NGRAPH_DEBUG << "setting branch_tip to " << branch_tip->get_name();
}
if (slice->get_users().size() > 1)
{
NGRAPH_DEBUG << slice->get_name() << " has more than one user";
return false;
}
auto it = carg;
while (it != goe)
if (shape_size(slice->get_strides()) != 1)
{
if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg))
NGRAPH_DEBUG << slice->get_name() << " is strided";
return false;
}
//check that no other node uses slices and reshapes
if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg))
{
auto default_shape = ngraph::get_default_order(rcarg->get_argument(0)->get_shape());
if (default_shape != rcarg->get_input_order())
{
auto default_shape = ngraph::get_default_order(rcarg->get_argument(0)->get_shape());
if (default_shape != rcarg->get_input_order())
{
NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
return false;
}
NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
return false;
}
if (carg->get_users().size() > 1)
if (rcarg->get_users().size() > 1)
{
NGRAPH_DEBUG << carg->get_name() << " has more than one user";
NGRAPH_DEBUG << rcarg->get_name() << " has more than one user";
return false;
}
it = it->get_argument(0);
}
}
if (!std::dynamic_pointer_cast<op::GetOutputElement>(goe))
auto concat = std::dynamic_pointer_cast<op::Concat>(n);
size_t concat_axis = concat->get_concatenation_axis();
auto slice_shape = branch_tip->get_users().at(0)->get_shape();
size_t slice_axis = std::numeric_limits<size_t>::max();
auto btip_shape = branch_tip->get_shape();
//slices should cover all elements
if (shape_size(btip_shape) != shape_size(n->get_shape()))
{
NGRAPH_DEBUG << goe->get_name() << " isn't GOE ";
NGRAPH_DEBUG << "The number of elements in Concat (" << shape_size(n->get_shape())
<< ") and the total of elements in slices (" << shape_size(btip_shape)
<< ") don't match";
return false;
}
auto replacement = goe;
if (goe->get_shape().size() != n->get_shape().size())
for (size_t i = 0; i < btip_shape.size(); i++)
{
return false;
if (btip_shape[i] != slice_shape[i])
{
if (slice_axis != std::numeric_limits<size_t>::max())
{
// multi-axis slice + concat do not cancel
return false;
}
slice_axis = i;
}
}
auto replacement = branch_tip;
if (btip_shape != n->get_shape())
{
auto default_order = ngraph::get_default_order(btip_shape);
if (concat_axis == slice_axis)
{
// logical reshape only
replacement =
std::make_shared<op::Reshape>(branch_tip, default_order, concat->get_shape());
}
else
{
// axis reordering required
auto transposed_shape = n->get_shape();
if (btip_shape.size() >= transposed_shape.size())
{
AxisVector order = ngraph::get_default_order(btip_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
replacement = std::make_shared<op::Reshape>(branch_tip, order, transposed_shape);
}
else if (btip_shape.size() < transposed_shape.size())
{
// intermediate logical reshape
AxisVector order = ngraph::get_default_order(transposed_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
auto output_shape = ngraph::apply_permutation(transposed_shape, order);
auto logical_reshape =
std::make_shared<op::Reshape>(branch_tip, default_order, output_shape);
// transpose to final concatenated shape
replacement =
std::make_shared<op::Reshape>(logical_reshape, order, transposed_shape);
}
}
}
ngraph::replace_node(n, replacement);
......
......@@ -384,7 +384,7 @@ TEST(algebraic_simplification, concat_reshape_slice)
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
ASSERT_TRUE(std::dynamic_pointer_cast<op::Reshape>(f->get_results().at(0)->get_argument(0)));
}
TEST(algebraic_simplification, concat_slice)
......@@ -425,6 +425,67 @@ TEST(algebraic_simplification, concat_parameter_slice)
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), a);
}
TEST(algebraic_simplification, concat_parameter_slices_reversed)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice3, slice2, slice1}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}
TEST(algebraic_simplification, concat_parameter_slices_element_count)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
//slicing 30 elements out of 96; should trigger a check that some elements are missing
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{10, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{10, 0}, Coordinate{20, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{20, 0}, Coordinate{30, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}
TEST(algebraic_simplification, concat_parameter_non_uniform_slices)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{38, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{38, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment