/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "ngraph/pass/constant_folding.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

namespace
{
    // Runs a pass manager that has only pass::ConstantFolding registered
    // over the function `f`.
    void apply_constant_folding(const shared_ptr<Function>& f)
    {
        pass::Manager folding_manager;
        folding_manager.register_pass<pass::ConstantFolding>();
        folding_manager.run_passes(f);
    }

    // Fetches argument 0 of the first result of `f` and casts it to
    // op::Constant. The returned pointer is null when that node is not an
    // op::Constant, so callers are expected to check it.
    shared_ptr<op::Constant> first_result_as_constant(const shared_ptr<Function>& f)
    {
        auto producer = f->get_results().at(0)->get_argument(0);
        return std::dynamic_pointer_cast<op::Constant>(producer);
    }
}

// Reshape of a {2, 4} f32 constant to {2, 4, 1} with input order {0, 1}.
// After folding, no Reshape op may remain, exactly one Constant must exist,
// and its data must equal the input data.
TEST(constant_folding, constant_reshape)
{
    const Shape source_shape{2, 4};
    const Shape target_shape{2, 4, 1};
    const vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};

    auto source = make_shared<op::Constant>(element::f32, source_shape, values_in);
    auto reshaped = make_shared<op::Reshape>(source, AxisVector{0, 1}, target_shape);
    auto f = make_shared<Function>(reshaped, op::ParameterVector{});

    apply_constant_folding(f);

    ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto new_const = first_result_as_constant(f);
    ASSERT_TRUE(new_const);

    auto values_out = new_const->get_vector<float>();
    ASSERT_EQ(values_in, values_out);
}

// Reshape of a {2, 4} f64 constant to {4, 2} with input order {1, 0}.
// After folding, the single remaining Constant must hold the input data
// rearranged as listed in `values_permute`.
TEST(constant_folding, constant_reshape_permute)
{
    const Shape source_shape{2, 4};
    const Shape target_shape{4, 2};
    const vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};

    auto source = make_shared<op::Constant>(element::f64, source_shape, values_in);
    auto permuted = make_shared<op::Reshape>(source, AxisVector{1, 0}, target_shape);
    auto f = make_shared<Function>(permuted, op::ParameterVector{});

    apply_constant_folding(f);

    ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto new_const = first_result_as_constant(f);
    ASSERT_TRUE(new_const);

    const vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
    auto values_out = new_const->get_vector<double>();
    ASSERT_EQ(values_permute, values_out);
}

// Broadcast of a {2} i32 constant to {2, 4} over AxisSet{1}.
// After folding, no Broadcast op may remain and the single Constant must
// hold each input element repeated four times.
TEST(constant_folding, constant_broadcast)
{
    const Shape source_shape{2};
    const Shape target_shape{2, 4};
    const vector<int> values_in{0, 1};

    auto source = make_shared<op::Constant>(element::i32, source_shape, values_in);
    auto expanded = make_shared<op::Broadcast>(source, target_shape, AxisSet{1});
    auto f = make_shared<Function>(expanded, op::ParameterVector{});

    apply_constant_folding(f);

    ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto new_const = first_result_as_constant(f);
    ASSERT_TRUE(new_const);

    // The name `values_permute` is kept (although this is the expected
    // broadcast output) so the assertion text gtest prints is unchanged.
    const vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
    auto values_out = new_const->get_vector<int>();
    ASSERT_EQ(values_permute, values_out);
}