Commit 4e0e0f56 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Concat (#3317)

* CF for Concat

* Switch from Nodes to Inputs/Outputs
parent 34499001
......@@ -21,6 +21,7 @@
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
......@@ -43,6 +44,7 @@
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
......@@ -1156,3 +1158,108 @@ void pass::ConstantFolding::construct_constant_product()
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantProduct");
this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
}
template <typename T>
static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op::Concat>& concat)
{
auto concat_inputs = concat->inputs();
std::vector<const T*> arg_bufs;
std::vector<Shape> arg_shapes;
for (auto& input : concat_inputs)
{
auto k = static_cast<op::Constant*>(input.get_source_output().get_node());
arg_bufs.push_back(k->get_data_ptr<T>());
arg_shapes.push_back(input.get_shape());
}
std::vector<T> result_vec(shape_size(concat->get_shape()));
runtime::reference::concat<T>(arg_bufs,
result_vec.data(),
arg_shapes,
concat->get_shape(),
concat->get_concatenation_axis());
return make_shared<op::Constant>(
concat->get_output_element_type(0), concat->get_output_shape(0), result_vec);
}
void pass::ConstantFolding::construct_constant_concat()
{
auto concat_op = make_shared<pattern::op::Label>(
element::f32, Shape{2, 3, 4}, pattern::has_class<op::Concat>());
auto constant_concat_callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_concat_callback against node = "
<< m.get_match_root()->get_name();
auto concat_node = static_pointer_cast<op::Concat>(m.get_match_root());
auto concat_inputs = concat_node->inputs();
if (std::any_of(concat_inputs.begin(), concat_inputs.end(), [](const Input<Node>& input) {
return !(input.get_source_output().get_node()->is_constant());
}))
{
return false;
}
std::shared_ptr<op::Constant> replacement;
switch (concat_node->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_concat");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_concat");
break;
case element::Type_t::boolean:
replacement = fold_constant_concat_helper<char>(concat_node);
break;
case element::Type_t::bf16:
replacement = fold_constant_concat_helper<bfloat16>(concat_node);
break;
case element::Type_t::f16:
replacement = fold_constant_concat_helper<float16>(concat_node);
break;
case element::Type_t::f32:
replacement = fold_constant_concat_helper<float>(concat_node);
break;
case element::Type_t::f64:
replacement = fold_constant_concat_helper<double>(concat_node);
break;
case element::Type_t::i8:
replacement = fold_constant_concat_helper<int8_t>(concat_node);
break;
case element::Type_t::i16:
replacement = fold_constant_concat_helper<int16_t>(concat_node);
break;
case element::Type_t::i32:
replacement = fold_constant_concat_helper<int32_t>(concat_node);
break;
case element::Type_t::i64:
replacement = fold_constant_concat_helper<int64_t>(concat_node);
break;
case element::Type_t::u8:
replacement = fold_constant_concat_helper<uint8_t>(concat_node);
break;
case element::Type_t::u16:
replacement = fold_constant_concat_helper<uint16_t>(concat_node);
break;
case element::Type_t::u32:
replacement = fold_constant_concat_helper<uint32_t>(concat_node);
break;
case element::Type_t::u64:
replacement = fold_constant_concat_helper<uint64_t>(concat_node);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto concat_matcher =
make_shared<pattern::Matcher>(concat_op, "ConstantFolding.ConstantConcat");
this->add_matcher(concat_matcher, constant_concat_callback, all_pass_property_off);
}
......@@ -42,7 +42,8 @@ public:
CONVERT,
SHAPE_OF,
REVERSE,
PRODUCT
PRODUCT,
CONCAT
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -60,6 +61,7 @@ public:
construct_constant_shape_of();
construct_constant_reverse();
construct_constant_product();
construct_constant_concat();
}
//this allows to specify the order in which matchers will be run
......@@ -84,6 +86,7 @@ public:
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
}
}
}
......@@ -100,6 +103,7 @@ private:
void construct_constant_shape_of();
void construct_constant_reverse();
void construct_constant_product();
void construct_constant_concat();
ngraph::BuildNodeExecutorMap m_cfmap;
};
......@@ -408,6 +408,30 @@ TEST(constant_folding, const_product)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat)
{
auto constant0 =
op::Constant::create(element::i32, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6});
auto constant1 = op::Constant::create(element::i32, Shape{2, 1}, vector<int32_t>{7, 8});
auto concat = make_shared<op::Concat>(NodeVector{constant0, constant1}, 1);
auto f = make_shared<Function>(concat, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment