Commit 2b674e34 authored by Ayan Moitra's avatar Ayan Moitra Committed by Scott Cyphers

Reduce concat fusion test input sizes (#2675)

* Reduce concat fusion test input sizes

* missed something
parent 3af7837b
...@@ -50,7 +50,7 @@ using namespace std; ...@@ -50,7 +50,7 @@ using namespace std;
TEST(concat_fusion, single_branch) TEST(concat_fusion, single_branch)
{ {
Shape shape_a{128, 2048, 1, 1}; Shape shape_a{12, 8, 1, 1};
auto generate_func = [shape_a]() { auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
...@@ -94,7 +94,7 @@ TEST(concat_fusion, single_branch) ...@@ -94,7 +94,7 @@ TEST(concat_fusion, single_branch)
TEST(concat_fusion, multiple_branches_1) TEST(concat_fusion, multiple_branches_1)
{ {
Shape shape_a{128, 2048, 1, 1}; Shape shape_a{16, 8, 1, 1};
auto generate_func = [shape_a]() { auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
...@@ -142,7 +142,7 @@ TEST(concat_fusion, multiple_branches_1) ...@@ -142,7 +142,7 @@ TEST(concat_fusion, multiple_branches_1)
TEST(concat_fusion, multiple_branches_2) TEST(concat_fusion, multiple_branches_2)
{ {
Shape shape_a{128, 2048, 1, 1}; Shape shape_a{16, 8, 1, 1};
auto generate_func = [shape_a]() { auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto concat_3 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2); auto concat_3 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2);
...@@ -185,8 +185,8 @@ TEST(concat_fusion, multiple_branches_2) ...@@ -185,8 +185,8 @@ TEST(concat_fusion, multiple_branches_2)
TEST(concat_fusion, non_fusable_self_concat) TEST(concat_fusion, non_fusable_self_concat)
{ {
Shape shape_a{128, 1, 1, 1}; Shape shape_a{32, 1, 1, 1};
Shape shape_b{128, 1, 1}; Shape shape_b{32, 1, 1};
auto generate_func = [shape_a, shape_b]() { auto generate_func = [shape_a, shape_b]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a); auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = make_shared<op::Parameter>(element::f32, shape_b); auto B = make_shared<op::Parameter>(element::f32, shape_b);
...@@ -199,7 +199,7 @@ TEST(concat_fusion, non_fusable_self_concat) ...@@ -199,7 +199,7 @@ TEST(concat_fusion, non_fusable_self_concat)
auto concat_5 = make_shared<op::Concat>(NodeVector{B, B, B, B, B, B, B}, 1); auto concat_5 = make_shared<op::Concat>(NodeVector{B, B, B, B, B, B, B}, 1);
auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 2); auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 2);
auto broadcast = make_shared<op::Broadcast>(concat_6, Shape{128, 8, 7, 3}, AxisSet{1}); auto broadcast = make_shared<op::Broadcast>(concat_6, Shape{32, 8, 7, 3}, AxisSet{1});
auto add = make_shared<op::Add>(concat_4, broadcast); auto add = make_shared<op::Add>(concat_4, broadcast);
auto f_concat_1 = make_shared<Function>(NodeVector{add}, ParameterVector{A, B}); auto f_concat_1 = make_shared<Function>(NodeVector{add}, ParameterVector{A, B});
return f_concat_1; return f_concat_1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment