Unverified Commit 91ecac9d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Slice Concat Elimination (#948)

* slice elimination

* add comment for simplify_concat

* fix concat_slice

* another reshape-related fix

* added a missing header

* disable reshape-concat optimization

* test fix
parent 99ea4a4b
...@@ -23,12 +23,16 @@ ...@@ -23,12 +23,16 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
...@@ -59,6 +63,98 @@ static std::shared_ptr<pattern::op::Label> ...@@ -59,6 +63,98 @@ static std::shared_ptr<pattern::op::Label>
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1)); return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
} }
//`simplify_concat` identifies slices-concat sequences
// that cancel each other. Namely it replaces subgraphs
//similar to the one below with `arg`
//
// +----------+
// +----+slice(n/2..n)---+
// +-------+ | +----------+ | +-----------+
// | arg +--+ +--+ concat |
// +-------+ | +----------+ | +-----------+
// +----+slice(0..n/2)---+
// +----------+
static bool simplify_concat(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
std::shared_ptr<Node> goe;
auto lgoe = std::make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto slice =
std::make_shared<op::Slice>(lgoe, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto reshape_pred = [](std::shared_ptr<Node> r) {
return std::dynamic_pointer_cast<op::Reshape>(r) != nullptr;
};
auto skip_reshape = std::make_shared<pattern::op::Skip>(slice, reshape_pred);
auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr);
for (auto carg : n->get_arguments())
{
if (!matcher->match(carg))
{
NGRAPH_DEBUG << carg->get_name() << " doesn't match";
return false;
}
if (goe)
{
if (goe != matcher->get_pattern_map()[lgoe])
{
NGRAPH_DEBUG << goe->get_name() << " doesn't match "
<< matcher->get_pattern_map()[lgoe]->get_name();
return false;
}
}
else
{
goe = matcher->get_pattern_map()[lgoe];
NGRAPH_DEBUG << "setting goe to " << goe->get_name();
}
auto it = carg;
while (it != goe)
{
if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg))
{
Shape default_shape(rcarg->get_argument(0)->get_shape().size());
std::iota(begin(default_shape), end(default_shape), 0);
if (default_shape != rcarg->get_input_order())
{
NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
return false;
}
}
if (carg->get_users().size() > 1)
{
NGRAPH_DEBUG << carg->get_name() << " has more than one user";
return false;
}
it = it->get_argument(0);
}
}
if (!std::dynamic_pointer_cast<op::GetOutputElement>(goe))
{
NGRAPH_DEBUG << goe->get_name() << " isn't GOE ";
return false;
}
auto replacement = goe;
if (goe->get_shape().size() != n->get_shape().size())
{
return false;
}
ngraph::replace_node(n, replacement);
return true;
}
//`simplify_multiply` optimizes the following 4 *base* cases //`simplify_multiply` optimizes the following 4 *base* cases
//(8 cases in total including variants due to commutativity) //(8 cases in total including variants due to commutativity)
// //
...@@ -286,6 +382,7 @@ static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<No ...@@ -286,6 +382,7 @@ static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<No
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>( return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>(
{{TI(op::Add), simplify_add}, {{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply}, {TI(op::Multiply), simplify_multiply},
{TI(op::Concat), simplify_concat},
{TI(op::Sum), {TI(op::Sum),
std::function<bool(std::shared_ptr<Node>)>{ std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Sum, get_sum_constant>}}, simplify_reduction<op::Sum, get_sum_constant>}},
......
...@@ -27,22 +27,23 @@ ...@@ -27,22 +27,23 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp" #include "ngraph/pattern/op/skip.hpp"
...@@ -359,6 +360,101 @@ TEST(algebraic_simplification, multiply_sum_negative) ...@@ -359,6 +360,101 @@ TEST(algebraic_simplification, multiply_sum_negative)
ASSERT_EQ(f_sum, sum_fconst1); ASSERT_EQ(f_sum, sum_fconst1);
} }
TEST(algebraic_simplification, concat_reshape_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto goe = make_shared<op::GetOutputElement>(a, 0);
auto slice1 = make_shared<op::Slice>(goe, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 =
make_shared<op::Slice>(goe, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 =
make_shared<op::Slice>(goe, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
auto reshape1 = make_shared<op::Reshape>(slice1, AxisVector{0, 1}, Shape{32, 1, 100});
auto reshape2 = make_shared<op::Reshape>(slice2, AxisVector{0, 1}, Shape{32, 1, 100});
auto reshape3 = make_shared<op::Reshape>(slice3, AxisVector{0, 1}, Shape{32, 1, 100});
size_t concat_axis = 1;
auto concat = make_shared<op::Concat>(NodeVector{reshape1, reshape2, reshape3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}
TEST(algebraic_simplification, concat_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto goe = make_shared<op::GetOutputElement>(a, 0);
auto slice1 = make_shared<op::Slice>(goe, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 =
make_shared<op::Slice>(goe, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 =
make_shared<op::Slice>(goe, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), goe);
}
TEST(algebraic_simplification, concat_parameter_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}
TEST(algebraic_simplification, concat_different_goes)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto goe1 = make_shared<op::GetOutputElement>(a, 0);
auto goe2 = make_shared<op::GetOutputElement>(a, 0);
auto slice1 =
make_shared<op::Slice>(goe1, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 =
make_shared<op::Slice>(goe2, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 =
make_shared<op::Slice>(goe1, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, op::ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), concat);
}
TEST(algebraic_simplification, log_neg_neg) TEST(algebraic_simplification, log_neg_neg)
{ {
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100}); auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment