Unverified Commit a1ee816e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

AlgebraicSimplification simplification (#4326)

* Start of rework

* Faster multiply simplification

* Fix uniform constant check

* Support Add op

* Fix build error
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent b2e7114d
...@@ -119,7 +119,7 @@ TEST(algebraic_simplification, add_broadcast) ...@@ -119,7 +119,7 @@ TEST(algebraic_simplification, add_broadcast)
} }
} }
TEST(algebraic_simplification, multiply_broadcast) TEST(algebraic_simplification, multiply_broadcast_0)
{ {
Shape shape{2, 2}; Shape shape{2, 2};
pass::Manager pass_manager; pass::Manager pass_manager;
...@@ -139,7 +139,7 @@ TEST(algebraic_simplification, multiply_broadcast) ...@@ -139,7 +139,7 @@ TEST(algebraic_simplification, multiply_broadcast)
ParameterVector{a, b, c}); ParameterVector{a, b, c});
pass_manager.run_passes(f); pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0); ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast}; auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
auto results = f->get_results(); auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++) for (size_t i = 0; i < results.size(); i++)
...@@ -148,6 +148,34 @@ TEST(algebraic_simplification, multiply_broadcast) ...@@ -148,6 +148,34 @@ TEST(algebraic_simplification, multiply_broadcast)
} }
} }
TEST(algebraic_simplification, multiply_broadcast_1)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
auto mul_a_0 = a * const_broadcast;
auto mul_a_0_0 = mul_a_0 * const_broadcast;
auto mul_b_0 = b * const_broadcast;
auto mul_b_0_0 = mul_b_0 * const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected[i], results[i]->get_argument(0));
}
}
TEST(algebraic_simplification, zero_plus_zero_commutativity) TEST(algebraic_simplification, zero_plus_zero_commutativity)
{ {
Shape shape{}; Shape shape{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment