Commit 8eed208b authored by Adam Procter's avatar Adam Procter

Unify folders for arithmetic reduction ops; add Max and Min

parent d8d940d0
......@@ -40,7 +40,9 @@
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
......@@ -76,7 +78,9 @@
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
......@@ -1584,180 +1588,138 @@ void pass::ConstantFolding::construct_constant_reverse()
}
template <typename T>
static shared_ptr<op::Constant> fold_constant_product_helper(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<T> out_vec(shape_size(result_shape));
vector<T> out_vec(shape_size(reduction_node->get_shape()));
runtime::reference::product<T>(constant->get_vector<T>().data(),
if (auto p = dynamic_pointer_cast<op::Max>(reduction_node))
{
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
result_shape,
reduction_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
}
static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
{
auto& input_element_type = constant->get_output_element_type(0);
switch (input_element_type.get_type_enum())
reduction_node->get_shape(),
p->get_reduction_axes());
}
else if (auto p = dynamic_pointer_cast<op::Min>(reduction_node))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_product");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_product");
break;
case element::Type_t::boolean:
return fold_constant_product_helper<char>(constant, reduction_axes, result_shape);
case element::Type_t::bf16:
return fold_constant_product_helper<bfloat16>(constant, reduction_axes, result_shape);
case element::Type_t::f16:
return fold_constant_product_helper<float16>(constant, reduction_axes, result_shape);
case element::Type_t::f32:
return fold_constant_product_helper<float>(constant, reduction_axes, result_shape);
case element::Type_t::f64:
return fold_constant_product_helper<double>(constant, reduction_axes, result_shape);
case element::Type_t::i8:
return fold_constant_product_helper<int8_t>(constant, reduction_axes, result_shape);
case element::Type_t::i16:
return fold_constant_product_helper<int16_t>(constant, reduction_axes, result_shape);
case element::Type_t::i32:
return fold_constant_product_helper<int32_t>(constant, reduction_axes, result_shape);
case element::Type_t::i64:
return fold_constant_product_helper<int64_t>(constant, reduction_axes, result_shape);
case element::Type_t::u8:
return fold_constant_product_helper<uint8_t>(constant, reduction_axes, result_shape);
case element::Type_t::u16:
return fold_constant_product_helper<uint16_t>(constant, reduction_axes, result_shape);
case element::Type_t::u32:
return fold_constant_product_helper<uint32_t>(constant, reduction_axes, result_shape);
case element::Type_t::u64:
return fold_constant_product_helper<uint64_t>(constant, reduction_axes, result_shape);
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
p->get_reduction_axes());
}
else if (auto p = dynamic_pointer_cast<op::Product>(reduction_node))
{
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
p->get_reduction_axes());
}
else if (auto s = dynamic_pointer_cast<op::Sum>(reduction_node))
{
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
s->get_reduction_axes());
}
else
{
NGRAPH_CHECK(false,
"Internal nGraph error: Ops handled in "
"fold_constant_arithmetic_reduction_helper must be consistent with those "
"matched in construct_constant_arithmetic_reduction");
}
NGRAPH_UNREACHABLE("Unexpected switch case");
}
void pass::ConstantFolding::construct_constant_product()
{
auto constant_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Product>(constant_label, AxisSet{0, 1, 2});
auto constant_product_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_product_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto product_match = static_pointer_cast<op::Product>(m.get_match_root());
replace_node(m.get_match_root(),
fold_constant_product(constant_match,
product_match->get_reduction_axes(),
product_match->get_output_shape(0)));
return true;
};
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantProduct");
this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
}
// TODO(amprocte): Find a way to reduce duplication with Product. (The fact
// that we bottom out in a reference call makes it a bit tricky.)
template <typename T>
static shared_ptr<op::Constant> fold_constant_sum_helper(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
{
vector<T> out_vec(shape_size(result_shape));
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
result_shape,
reduction_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
}
static shared_ptr<op::Constant> fold_constant_sum(shared_ptr<op::Constant> constant,
const AxisSet& reduction_axes,
const Shape& result_shape)
static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
auto& input_element_type = constant->get_output_element_type(0);
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_sum");
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_sum");
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::boolean:
return fold_constant_sum_helper<char>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
case element::Type_t::bf16:
return fold_constant_sum_helper<bfloat16>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
case element::Type_t::f16:
return fold_constant_sum_helper<float16>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
case element::Type_t::f32:
return fold_constant_sum_helper<float>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
case element::Type_t::f64:
return fold_constant_sum_helper<double>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
case element::Type_t::i8:
return fold_constant_sum_helper<int8_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
case element::Type_t::i16:
return fold_constant_sum_helper<int16_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
case element::Type_t::i32:
return fold_constant_sum_helper<int32_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
case element::Type_t::i64:
return fold_constant_sum_helper<int64_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
case element::Type_t::u8:
return fold_constant_sum_helper<uint8_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
case element::Type_t::u16:
return fold_constant_sum_helper<uint16_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
case element::Type_t::u32:
return fold_constant_sum_helper<uint32_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
case element::Type_t::u64:
return fold_constant_sum_helper<uint64_t>(constant, reduction_axes, result_shape);
return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
}
NGRAPH_UNREACHABLE("Unexpected switch case");
}
void pass::ConstantFolding::construct_constant_sum()
void pass::ConstantFolding::construct_constant_arithmetic_reduction()
{
auto constant_label = make_shared<pattern::op::Label>(
auto constant_data_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Sum>(constant_label, AxisSet{0, 1, 2});
auto constant_sum_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_sum_callback against node = "
auto constant_axes_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto is_supported_reduction = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n));
};
auto reduction =
std::make_shared<pattern::op::Any>(element::i32,
Shape{2},
is_supported_reduction,
NodeVector{constant_data_label, constant_axes_label});
auto constant_arithmetic_reduction_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto sum_match = static_pointer_cast<op::Sum>(m.get_match_root());
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
replace_node(m.get_match_root(),
fold_constant_sum(constant_match,
sum_match->get_reduction_axes(),
sum_match->get_output_shape(0)));
replace_node(reduction_match,
fold_constant_arithmetic_reduction(constant_match, reduction_match));
return true;
};
auto convert_matcher = make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantSum");
this->add_matcher(convert_matcher, constant_sum_callback, all_pass_property_off);
auto arithmetic_reduction_matcher =
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
this->add_matcher(arithmetic_reduction_matcher,
constant_arithmetic_reduction_callback,
all_pass_property_off);
}
template <typename T>
......
......@@ -42,8 +42,7 @@ public:
CONVERT,
SHAPE_OF,
REVERSE,
PRODUCT,
SUM,
ARITHMETIC_REDUCTION,
CONCAT,
GATHER,
SLICE,
......@@ -66,8 +65,7 @@ public:
construct_constant_convert();
construct_constant_shape_of();
construct_constant_reverse();
construct_constant_product();
construct_constant_sum();
construct_constant_arithmetic_reduction();
construct_constant_concat();
construct_constant_gather();
construct_constant_slice();
......@@ -97,8 +95,9 @@ public:
case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::ARITHMETIC_REDUCTION:
construct_constant_arithmetic_reduction();
break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
......@@ -120,8 +119,7 @@ private:
void construct_constant_convert();
void construct_constant_shape_of();
void construct_constant_reverse();
void construct_constant_product();
void construct_constant_sum();
void construct_constant_arithmetic_reduction();
void construct_constant_concat();
void construct_constant_gather();
void construct_constant_slice();
......
......@@ -36,7 +36,7 @@ namespace ngraph
const AxisSet& reduction_axes)
{
T minval = std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity()
? T(-std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::min();
CoordinateTransform output_transform(out_shape);
......
......@@ -434,6 +434,58 @@ TEST(constant_folding, const_sum)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_max)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Max>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{3, 6, 9};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_min)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Min>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 4, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat)
{
auto constant0 =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment