Commit 5c5690db authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Fix broken test in shape specialization pass (#2814)

* Fix broken test in shape specialization pass

* Roll back unnecessary changes to Add
parent 118efa26
......@@ -107,7 +107,7 @@ bool pass::ShapeSpecialization::run_on_function(std::shared_ptr<Function> f)
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (node->input(i).get_is_relevant_to_values())
if (!node->input(i).get_is_relevant_to_values())
{
continue;
}
......
......@@ -153,3 +153,35 @@ TEST(shape_specialization, specialization_pass_concat_transpose)
ASSERT_EQ(constant_after->get_element_type(), element::i64);
ASSERT_EQ(constant_after->get_vector<int64_t>(), (vector<int64_t>{1, 0}));
}
// Slight variation on the above test, where the "Concat" does not already have constants going
// into it. (The permutation is Concat(Const<1>,Concat(Const<>,Const<0>)) rather than simply
// Concat(Const<1>,Const<0>).)
TEST(shape_specialization, specialization_pass_add_concat_transpose)
{
auto param0 = make_shared<op::Parameter>(element::boolean, Shape{4, 6});
auto k0 = op::Constant::create(element::i64, Shape{1}, {0});
auto k1 = op::Constant::create(element::i64, Shape{1}, {1});
auto kempty = op::Constant::create(element::i64, Shape{0}, vector<int64_t>{});
auto concat = make_shared<op::Concat>(
NodeVector{k1, make_shared<op::Concat>(NodeVector{kempty, k0}, 0)}, 0);
auto transpose = make_shared<op::Transpose>(param0, concat);
auto f = make_shared<Function>(transpose, ParameterVector{param0});
pass::Manager manager;
manager.register_pass<pass::ShapeSpecialization>();
manager.run_passes(f);
auto transpose_after =
dynamic_pointer_cast<op::Transpose>(f->get_results().at(0)->get_argument(0));
ASSERT_NE(transpose_after, nullptr);
auto constant_after = dynamic_pointer_cast<op::Constant>(transpose_after->get_argument(1));
ASSERT_NE(constant_after, nullptr);
ASSERT_EQ(constant_after->get_shape(), Shape{2});
ASSERT_EQ(constant_after->get_element_type(), element::i64);
ASSERT_EQ(constant_after->get_vector<int64_t>(), (vector<int64_t>{1, 0}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment