Commit 0d0bd8de authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Add constant folding for Split:v1 and VariadicSplit:v1 (#4041)

* POC of Split ConstantFolding

* Removed unused code

* code styles

* Added constant folding variadic split file

* Added variadic split + UT

* Update VariadicSplit tests
parent e3b442aa
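
For context, a minimal sketch of what the new folding does, written against the public nGraph v1 API as it exists at this commit. It is illustrative only and mirrors the unit tests added further down rather than the matcher internals: a Split whose data and axis inputs are Constants is replaced by one Constant per output.

// Illustrative sketch; mirrors the added unit tests, not part of the commit itself.
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"

using namespace ngraph;

std::shared_ptr<Function> fold_constant_split_example()
{
    // Constant data split into 3 equal pieces along axis 0.
    auto data = op::Constant::create(element::f32, Shape{6}, {.1f, .2f, .3f, .4f, .5f, .6f});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});
    auto split = std::make_shared<op::v1::Split>(data, axis, 3);
    auto f = std::make_shared<Function>(split->outputs(), ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // After the pass there is no v1::Split left; each Result is fed by a Constant
    // holding two of the six values.
    return f;
}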
@@ -485,6 +485,8 @@ set (SRC
     pass/constant_folding_unsqueeze.cpp
     pass/constant_folding_shape_of.cpp
     pass/constant_folding_slice.cpp
+    pass/constant_folding_split.cpp
+    pass/constant_folding_variadic_split.cpp
     pass/constant_folding_strided_slice.cpp
     pass/constant_folding_transpose.cpp
     pass/constant_folding_unary.cpp
......
@@ -137,7 +137,8 @@ void op::v1::Split::validate_and_infer_types()
     if (input_value(1).get_node_shared_ptr()->is_constant())
     {
-        auto axis = axis_value_from_input();
+        const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
+        auto axis = axis_input->cast_vector<int64_t>()[0];
         if (data_ps.is_static())
         {
@@ -178,27 +179,3 @@ shared_ptr<Node> op::v1::Split::copy_with_new_args(const NodeVector& new_args) c
     check_new_args_count(this, new_args);
     return make_shared<v1::Split>(new_args.at(0), new_args.at(1), m_num_splits);
 }
-
-int64_t op::v1::Split::axis_value_from_input() const
-{
-    int64_t axis_value{0};
-    const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
-#if defined(__clang__)
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wswitch-enum"
-#endif
-    switch (static_cast<element::Type_t>(axis_input->get_element_type()))
-    {
-    case element::Type_t::i8: axis_value = axis_input->get_vector<int8_t>().at(0); break;
-    case element::Type_t::i32: axis_value = axis_input->get_vector<int32_t>().at(0); break;
-    case element::Type_t::i64: axis_value = axis_input->get_vector<int64_t>().at(0); break;
-    default: break;
-    }
-#if defined(__clang__)
-#pragma clang diagnostic pop
-#endif
-    return axis_value;
-}
@@ -104,9 +104,6 @@ namespace ngraph
            bool supports_decompose() const override { return false; }
        protected:
            size_t m_num_splits;
-
-        private:
-            int64_t axis_value_from_input() const;
        };
    }
......
@@ -61,7 +61,8 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
        split_lengths_input->is_constant())
    {
        auto data_rank = static_cast<size_t>(data_shape.rank());
-       auto axis_val = as_type_ptr<op::Constant>(axis_input)->get_vector<int64_t>()[0];
+       const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
+       auto axis_val = axis_input->cast_vector<int64_t>()[0];
        // Adjust split axis in case of negatives
        int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
......
@@ -56,7 +56,9 @@ public:
        RANGE,
        SELECT,
        SQUEEZE,
-       UNSQUEEZE
+       UNSQUEEZE,
+       SPLIT,
+       VARIADIC_SPLIT
    };

    ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
@@ -65,6 +67,8 @@ public:
        m_cfmap = cfmap;
        m_enable_shape_inference = true;

+       construct_constant_split();
+       construct_constant_variadic_split();
        construct_constant_reshape();
        construct_constant_broadcast();
        construct_constant_dyn_broadcast();
@@ -130,6 +134,8 @@ public:
            case CFTransformations::SELECT: construct_constant_select(); break;
            case CFTransformations::SQUEEZE: construct_constant_squeeze(); break;
            case CFTransformations::UNSQUEEZE: construct_constant_unsqueeze(); break;
+           case CFTransformations::SPLIT: construct_constant_split(); break;
+           case CFTransformations::VARIADIC_SPLIT: construct_constant_variadic_split(); break;
            }
        }
    }
@@ -159,6 +165,8 @@ private:
    void construct_constant_select();
    void construct_constant_squeeze();
    void construct_constant_unsqueeze();
+   void construct_constant_split();
+   void construct_constant_variadic_split();

    ngraph::BuildNodeExecutorMap m_cfmap;
};
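
With the two enum values added above, constant folding can also be restricted to just the new transformations. A hedged sketch follows; it assumes the selective-transformation constructor that the switch statement above belongs to (an overload taking a list of CFTransformations), which is not shown in full in this diff.

// Illustrative sketch only; the vector-of-transformations constructor is assumed
// from the switch statement above rather than spelled out in this diff.
#include <memory>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"

using namespace ngraph;

void fold_only_splits(std::shared_ptr<Function> f)
{
    pass::Manager manager;
    manager.register_pass<pass::ConstantFolding>(
        std::vector<pass::ConstantFolding::CFTransformations>{
            pass::ConstantFolding::CFTransformations::SPLIT,
            pass::ConstantFolding::CFTransformations::VARIADIC_SPLIT});
    manager.run_passes(f);
}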
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
void pass::ConstantFolding::construct_constant_split()
{
    auto data_label = make_shared<pattern::op::Label>(
        element::f32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
    auto axis_label =
        make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
    auto split_pattern = make_shared<op::v1::Split>(data_label, axis_label, 0);

    auto constant_split_callback = [this, data_label, axis_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_split_callback against node = "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();

        const auto data_node = static_pointer_cast<op::Constant>(pattern_map[data_label]);
        const auto axis_node = static_pointer_cast<op::Constant>(pattern_map[axis_label]);
        const auto split = static_pointer_cast<op::v1::Split>(m.get_match_root());

        const auto axis_val = axis_node->cast_vector<int64_t>()[0];
        const auto norm_axis_val =
            ngraph::normalize_axis(split.get(), axis_val, data_node->get_shape().size());
        // builder::split decomposes the constant data into one Slice node per output.
        const auto slices = builder::split(data_node, split->get_num_splits(), norm_axis_val);

        // Rewire every consumer of each Split output to the corresponding Slice output.
        for (size_t i = 0; i < split->get_output_size(); i++)
        {
            for (auto& input : split->output(i).get_target_inputs())
            {
                input.replace_source_output(slices[i]->output(0));
            }
        }
        split->outputs().clear();
        // Register the Slice folding matcher so the Slice nodes created above are
        // themselves folded into Constants by this pass.
        construct_constant_slice();
        return true;
    };

    auto split_matcher =
        make_shared<pattern::Matcher>(split_pattern, "ConstantFolding.ConstantSplit");
    this->add_matcher(split_matcher, constant_split_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
void pass::ConstantFolding::construct_constant_variadic_split()
{
    auto data_label = make_shared<pattern::op::Label>(
        element::f32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
    auto axis_label =
        make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
    auto lengths_label =
        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
    auto variadic_split_pattern =
        make_shared<op::v1::VariadicSplit>(data_label, axis_label, lengths_label);

    auto constant_variadic_split_callback = [this, data_label, axis_label, lengths_label](
        pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_variadic_split_callback against node = "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();

        const auto data_node = static_pointer_cast<op::Constant>(pattern_map[data_label]);
        const auto axis_node = static_pointer_cast<op::Constant>(pattern_map[axis_label]);
        const auto lengths_node = static_pointer_cast<op::Constant>(pattern_map[lengths_label]);
        const auto variadic_split = static_pointer_cast<op::v1::VariadicSplit>(m.get_match_root());

        const auto axis_val = axis_node->cast_vector<int64_t>()[0];
        const auto norm_axis_val =
            ngraph::normalize_axis(variadic_split.get(), axis_val, data_node->get_shape().size());

        auto split_lengths = lengths_node->cast_vector<int64_t>();
        // Adjust split lengths in case a -1 ("take whatever remains") entry is present
        size_t sum_of_splits = 0;
        int64_t negative_one = -1; // index of the -1 entry, if any
        for (size_t i = 0; i < split_lengths.size(); i++)
        {
            if (split_lengths[i] == -1)
            {
                negative_one = i;
            }
            else
            {
                sum_of_splits += split_lengths[i];
            }
        }
        if (negative_one >= 0)
        {
            split_lengths[negative_one] =
                static_cast<size_t>(data_node->get_shape()[norm_axis_val]) - sum_of_splits;
        }

        const auto slices = builder::split(
            data_node, vector<size_t>(split_lengths.begin(), split_lengths.end()), norm_axis_val);

        // Rewire every consumer of each VariadicSplit output to the corresponding Slice output.
        for (size_t i = 0; i < variadic_split->get_output_size(); i++)
        {
            for (auto& input : variadic_split->output(i).get_target_inputs())
            {
                input.replace_source_output(slices[i]->output(0));
            }
        }
        variadic_split->outputs().clear();
        // Fold the newly created Slice nodes into Constants as well.
        construct_constant_slice();
        return true;
    };

    auto variadic_split_matcher = make_shared<pattern::Matcher>(
        variadic_split_pattern, "ConstantFolding.ConstantVariadicSplit");
    this->add_matcher(variadic_split_matcher,
                      constant_variadic_split_callback,
                      PassProperty::CHANGE_DYNAMIC_STATE);
}
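
To make the -1 handling in the callback above concrete: a free-standing sketch (not part of the commit) of the same length adjustment. For example, lengths {1, 1, -1} on an axis of size 4 become {1, 1, 2}.

// Illustrative helper only; mirrors the adjustment performed in
// constant_variadic_split_callback above.
#include <cstdint>
#include <vector>

std::vector<int64_t> adjust_split_lengths(std::vector<int64_t> lengths, int64_t axis_dim)
{
    int64_t remainder = axis_dim;
    int64_t neg_index = -1;
    for (size_t i = 0; i < lengths.size(); ++i)
    {
        if (lengths[i] == -1)
            neg_index = static_cast<int64_t>(i);
        else
            remainder -= lengths[i];
    }
    // Replace the single -1 entry with whatever remains of the axis length.
    if (neg_index >= 0)
        lengths[neg_index] = remainder;
    return lengths; // e.g. adjust_split_lengths({1, 1, -1}, 4) == {1, 1, 2}
}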
@@ -1866,6 +1866,212 @@ TEST(constant_folding, constant_v1_select)
    ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_v1_split)
{
vector<float> data{.1f, .2f, .3f, .4f, .5f, .6f};
const auto const_data = op::Constant::create(element::f32, Shape{data.size()}, data);
const auto const_axis = op::Constant::create(element::i64, Shape{}, {0});
const auto num_splits = 3;
auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
auto res3 = as_type_ptr<op::Constant>(f->get_results().at(2)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
ASSERT_TRUE(res3);
auto res1_values = res1->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin(), data.begin() + 2), res1_values));
auto res2_values = res2->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 2, data.begin() + 4), res2_values));
auto res3_values = res3->get_vector<float>();
ASSERT_TRUE(test::all_close_f(vector<float>(data.begin() + 4, data.end()), res3_values));
}
TEST(constant_folding, constant_v1_split_axis_1_4_splits)
{
vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
const auto num_splits = 4;
auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
auto res3 = as_type_ptr<op::Constant>(f->get_results().at(2)->get_argument(0));
auto res4 = as_type_ptr<op::Constant>(f->get_results().at(3)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
ASSERT_TRUE(res3);
ASSERT_TRUE(res4);
auto res1_values = res1->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
res1_values);
auto res2_values = res2->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
res2_values);
auto res3_values = res3->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59}),
res3_values);
auto res4_values = res4->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
res4_values);
}
TEST(constant_folding, constant_v1_split_axis_1_2_splits)
{
vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
const auto const_axis = op::Constant::create(element::i64, Shape{}, {1});
const auto num_splits = 2;
auto split_v1 = make_shared<op::v1::Split>(const_data, const_axis, num_splits);
auto f = make_shared<Function>(split_v1->outputs(), ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Split>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), num_splits);
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
auto res1_values = res1->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55}),
res1_values);
auto res2_values = res2->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
res2_values);
}
TEST(constant_folding, constant_v1_variadic_split_axis_1_2_splits)
{
vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
const auto const_axis = op::Constant::create(element::i16, Shape{}, {1});
vector<int64_t> values_lengths{3, 1};
auto constant_lengths =
make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
auto variadic_split_v1 =
make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
auto res1_values = res1->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}),
res1_values);
auto res2_values = res2->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63}),
res2_values);
}
TEST(constant_folding, constant_v1_variadic_split_axis_1_3_splits_neg_length)
{
vector<int64_t> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
const auto const_data = op::Constant::create(element::i64, Shape{4, 4, 4}, data);
const auto const_axis = op::Constant::create(element::i32, Shape{}, {1});
vector<int64_t> values_lengths{1, 1, -1};
auto constant_lengths =
make_shared<op::Constant>(element::i64, Shape{values_lengths.size()}, values_lengths);
auto variadic_split_v1 =
make_shared<op::v1::VariadicSplit>(const_data, const_axis, constant_lengths);
auto f = make_shared<Function>(variadic_split_v1->outputs(), ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::VariadicSplit>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), values_lengths.size());
auto res1 = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
auto res2 = as_type_ptr<op::Constant>(f->get_results().at(1)->get_argument(0));
auto res3 = as_type_ptr<op::Constant>(f->get_results().at(2)->get_argument(0));
ASSERT_TRUE(res1);
ASSERT_TRUE(res2);
ASSERT_TRUE(res3);
auto res1_values = res1->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51}),
res1_values);
auto res2_values = res2->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55}),
res2_values);
auto res3_values = res3->get_vector<int64_t>();
ASSERT_EQ(vector<int64_t>({8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31,
40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63}),
res3_values);
}
TEST(constant_folding, pass_property)
{
    auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......