Commit 8ee4c69d authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

Enable self-concat to broadcast replacement for a single concat (#4063)

* Enable self-concat to broadcast replacement for a single concat.

* Update tests.
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 9a01762a
...@@ -160,8 +160,6 @@ bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> f ...@@ -160,8 +160,6 @@ bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> f
NGRAPH_DEBUG << print_state_of_bounded_vectors(); NGRAPH_DEBUG << print_state_of_bounded_vectors();
remove_single_concat_op_pattern();
for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors) for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors)
{ {
modify_graph = replace_patterns(concat_op_pattern_node_vector); modify_graph = replace_patterns(concat_op_pattern_node_vector);
......
...@@ -171,10 +171,10 @@ TEST(concat_fusion, multiple_branches_2) ...@@ -171,10 +171,10 @@ TEST(concat_fusion, multiple_branches_2)
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0))); EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f); size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
size_t num_broadcast_optimzed = count_ops_of_type<op::Broadcast>(optimized_f); size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 1); ASSERT_EQ(num_reshapes_optimized, 2);
ASSERT_EQ(num_broadcast_optimzed, 1); ASSERT_EQ(num_broadcast_optimized, 2);
} }
TEST(concat_fusion, non_fusable_self_concat) TEST(concat_fusion, non_fusable_self_concat)
...@@ -226,8 +226,8 @@ TEST(concat_fusion, non_fusable_self_concat) ...@@ -226,8 +226,8 @@ TEST(concat_fusion, non_fusable_self_concat)
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f); size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
size_t num_broadcast_optimzed = count_ops_of_type<op::Broadcast>(optimized_f); size_t num_broadcast_optimzed = count_ops_of_type<op::Broadcast>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2); ASSERT_EQ(num_reshapes_optimized, 3);
ASSERT_EQ(num_broadcast_optimzed, 3); ASSERT_EQ(num_broadcast_optimzed, 4);
} }
TEST(concat_fusion, self_concat_with_fan_out) TEST(concat_fusion, self_concat_with_fan_out)
...@@ -279,8 +279,8 @@ TEST(concat_fusion, self_concat_with_fan_out) ...@@ -279,8 +279,8 @@ TEST(concat_fusion, self_concat_with_fan_out)
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f); size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
size_t num_broadcast_optimzed = count_ops_of_type<op::Broadcast>(optimized_f); size_t num_broadcast_optimzed = count_ops_of_type<op::Broadcast>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 1); ASSERT_EQ(num_reshapes_optimized, 3);
ASSERT_EQ(num_broadcast_optimzed, 1); ASSERT_EQ(num_broadcast_optimzed, 3);
} }
TEST(concat_fusion, pass_property) TEST(concat_fusion, pass_property)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment