Commit 8eed208b authored by Adam Procter's avatar Adam Procter

Unify folders for arithmetic reduction ops; add Max and Min

parent d8d940d0
This diff is collapsed.
...@@ -42,8 +42,7 @@ public: ...@@ -42,8 +42,7 @@ public:
CONVERT, CONVERT,
SHAPE_OF, SHAPE_OF,
REVERSE, REVERSE,
PRODUCT, ARITHMETIC_REDUCTION,
SUM,
CONCAT, CONCAT,
GATHER, GATHER,
SLICE, SLICE,
...@@ -66,8 +65,7 @@ public: ...@@ -66,8 +65,7 @@ public:
construct_constant_convert(); construct_constant_convert();
construct_constant_shape_of(); construct_constant_shape_of();
construct_constant_reverse(); construct_constant_reverse();
construct_constant_product(); construct_constant_arithmetic_reduction();
construct_constant_sum();
construct_constant_concat(); construct_constant_concat();
construct_constant_gather(); construct_constant_gather();
construct_constant_slice(); construct_constant_slice();
...@@ -97,8 +95,9 @@ public: ...@@ -97,8 +95,9 @@ public:
case CFTransformations::CONVERT: construct_constant_convert(); break; case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break; case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break; case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break; case CFTransformations::ARITHMETIC_REDUCTION:
case CFTransformations::SUM: construct_constant_sum(); break; construct_constant_arithmetic_reduction();
break;
case CFTransformations::CONCAT: construct_constant_concat(); break; case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break; case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break; case CFTransformations::SLICE: construct_constant_slice(); break;
...@@ -120,8 +119,7 @@ private: ...@@ -120,8 +119,7 @@ private:
void construct_constant_convert(); void construct_constant_convert();
void construct_constant_shape_of(); void construct_constant_shape_of();
void construct_constant_reverse(); void construct_constant_reverse();
void construct_constant_product(); void construct_constant_arithmetic_reduction();
void construct_constant_sum();
void construct_constant_concat(); void construct_constant_concat();
void construct_constant_gather(); void construct_constant_gather();
void construct_constant_slice(); void construct_constant_slice();
......
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
{ {
T minval = std::numeric_limits<T>::has_infinity T minval = std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity() ? T(-std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::min(); : std::numeric_limits<T>::min();
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
......
...@@ -434,6 +434,58 @@ TEST(constant_folding, const_sum) ...@@ -434,6 +434,58 @@ TEST(constant_folding, const_sum)
ASSERT_EQ(values_expected, values_out); ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, const_max)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Max>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{3, 6, 9};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_min)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Min>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 4, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat) TEST(constant_folding, const_concat)
{ {
auto constant0 = auto constant0 =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment