Unverified Commit 3424d70a authored by Pruthvi's avatar Pruthvi Committed by GitHub

- fix output replacement in split constant folding (#4406)

- unit test for split with specialized function
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent f73cfcf0
......@@ -45,12 +45,10 @@ void pass::ConstantFolding::construct_constant_split()
split.get(), axis_val, data_node->get_output_partial_shape(0).rank());
const auto slices = builder::split(data_node, split->get_num_splits(), norm_axis_val);
for (size_t i = 0; i < split->get_output_size(); i++)
int index = 0;
for (auto& output : split->outputs())
{
for (auto& input : split->output(i).get_target_inputs())
{
input.replace_source_output((slices[i]->output(0)));
}
output.replace(slices[index++]->output(0));
}
split->outputs().clear();
construct_constant_slice();
......
......@@ -1898,6 +1898,41 @@ TEST(constant_folding, constant_v1_split)
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
}
TEST(constant_folding, constant_v1_split_specialized)
{
vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
const auto num_splits = 3;
auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
auto specialized_function = ::ngraph::specialize_function(
std::const_pointer_cast<ngraph::Function>(f), {}, {}, {}, true, true);
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
auto res3 = as_type_ptr<op::Constant>(f->get_results().at(2)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
ASSERT_TRUE(res3);
auto res1_values = res1->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
auto res2_values = res2->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
auto res3_values = res3->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
}
TEST(constant_folding, constant_v1_split_axis_1_4_splits)
{
vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment