Commit 098c9118 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Sum (#3318)

* CF for Sum

* style
parent 4e0e0f56
......@@ -39,6 +39,7 @@
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp"
......@@ -60,6 +61,7 @@
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -1084,11 +1086,6 @@ static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> c
{
auto& input_element_type = constant->get_output_element_type(0);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
......@@ -1126,10 +1123,6 @@ static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> c
}
NGRAPH_UNREACHABLE("Unexpected switch case");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}
void pass::ConstantFolding::construct_constant_product()
......@@ -1159,6 +1152,95 @@ void pass::ConstantFolding::construct_constant_product()
this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
}
// TODO(amprocte): Find a way to reduce duplication with Product. (The fact
// that we bottom out in a reference call makes it a bit tricky.)
template <typename T>
static shared_ptr<op::Constant> fold_constant_sum_helper(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
{
vector<T> out_vec(shape_size(result_shape));
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
result_shape,
reduction_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
}
static shared_ptr<op::Constant> fold_constant_sum(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
{
auto& input_element_type = constant->get_output_element_type(0);
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_sum");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_sum");
break;
case element::Type_t::boolean:
return fold_constant_sum_helper<char>(constant, reduction_axes, result_shape);
case element::Type_t::bf16:
return fold_constant_sum_helper<bfloat16>(constant, reduction_axes, result_shape);
case element::Type_t::f16:
return fold_constant_sum_helper<float16>(constant, reduction_axes, result_shape);
case element::Type_t::f32:
return fold_constant_sum_helper<float>(constant, reduction_axes, result_shape);
case element::Type_t::f64:
return fold_constant_sum_helper<double>(constant, reduction_axes, result_shape);
case element::Type_t::i8:
return fold_constant_sum_helper<int8_t>(constant, reduction_axes, result_shape);
case element::Type_t::i16:
return fold_constant_sum_helper<int16_t>(constant, reduction_axes, result_shape);
case element::Type_t::i32:
return fold_constant_sum_helper<int32_t>(constant, reduction_axes, result_shape);
case element::Type_t::i64:
return fold_constant_sum_helper<int64_t>(constant, reduction_axes, result_shape);
case element::Type_t::u8:
return fold_constant_sum_helper<uint8_t>(constant, reduction_axes, result_shape);
case element::Type_t::u16:
return fold_constant_sum_helper<uint16_t>(constant, reduction_axes, result_shape);
case element::Type_t::u32:
return fold_constant_sum_helper<uint32_t>(constant, reduction_axes, result_shape);
case element::Type_t::u64:
return fold_constant_sum_helper<uint64_t>(constant, reduction_axes, result_shape);
}
NGRAPH_UNREACHABLE("Unexpected switch case");
}
void pass::ConstantFolding::construct_constant_sum()
{
auto constant_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Sum>(constant_label, AxisSet{0, 1, 2});
auto constant_sum_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_sum_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto sum_match = static_pointer_cast<op::Sum>(m.get_match_root());
replace_node(m.get_match_root(),
fold_constant_sum(constant_match,
sum_match->get_reduction_axes(),
sum_match->get_output_shape(0)));
return true;
};
auto convert_matcher = make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantSum");
this->add_matcher(convert_matcher, constant_sum_callback, all_pass_property_off);
}
template <typename T>
static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op::Concat>& concat)
{
......
......@@ -43,6 +43,7 @@ public:
SHAPE_OF,
REVERSE,
PRODUCT,
SUM,
CONCAT
};
......@@ -61,6 +62,7 @@ public:
construct_constant_shape_of();
construct_constant_reverse();
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
}
......@@ -86,6 +88,7 @@ public:
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
}
}
......@@ -103,6 +106,7 @@ private:
void construct_constant_shape_of();
void construct_constant_reverse();
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
ngraph::BuildNodeExecutorMap m_cfmap;
......
......@@ -408,6 +408,32 @@ TEST(constant_folding, const_product)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_sum)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Sum>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{6, 15, 24};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat)
{
auto constant0 =
......@@ -429,6 +455,7 @@ TEST(constant_folding, const_concat)
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 2, 3, 7, 4, 5, 6, 8};
ASSERT_EQ(values_expected, values_out);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment