Unverified Commit 7f57d4e1 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/fuseddesc

parents f6528bf5 f50e12a1
This diff is collapsed.
...@@ -34,6 +34,7 @@ public: ...@@ -34,6 +34,7 @@ public:
{ {
RESHAPE, RESHAPE,
BROADCAST, BROADCAST,
DYN_BROADCAST,
PAD, PAD,
DEQUANTIZE, DEQUANTIZE,
UNARY, UNARY,
...@@ -42,8 +43,8 @@ public: ...@@ -42,8 +43,8 @@ public:
CONVERT, CONVERT,
SHAPE_OF, SHAPE_OF,
REVERSE, REVERSE,
PRODUCT, ARITHMETIC_REDUCTION,
SUM, LOGICAL_REDUCTION,
CONCAT, CONCAT,
GATHER, GATHER,
SLICE, SLICE,
...@@ -60,6 +61,7 @@ public: ...@@ -60,6 +61,7 @@ public:
m_cfmap = cfmap; m_cfmap = cfmap;
construct_constant_reshape(); construct_constant_reshape();
construct_constant_broadcast(); construct_constant_broadcast();
construct_constant_dyn_broadcast();
construct_constant_pad(); construct_constant_pad();
construct_constant_unary(); construct_constant_unary();
construct_constant_binary(); construct_constant_binary();
...@@ -68,8 +70,8 @@ public: ...@@ -68,8 +70,8 @@ public:
construct_constant_convert(); construct_constant_convert();
construct_constant_shape_of(); construct_constant_shape_of();
construct_constant_reverse(); construct_constant_reverse();
construct_constant_product(); construct_constant_arithmetic_reduction();
construct_constant_sum(); construct_constant_logical_reduction();
construct_constant_concat(); construct_constant_concat();
construct_constant_gather(); construct_constant_gather();
construct_constant_slice(); construct_constant_slice();
...@@ -93,6 +95,7 @@ public: ...@@ -93,6 +95,7 @@ public:
{ {
case CFTransformations::RESHAPE: construct_constant_reshape(); break; case CFTransformations::RESHAPE: construct_constant_reshape(); break;
case CFTransformations::BROADCAST: construct_constant_broadcast(); break; case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
case CFTransformations::DYN_BROADCAST: construct_constant_dyn_broadcast(); break;
case CFTransformations::PAD: construct_constant_pad(); break; case CFTransformations::PAD: construct_constant_pad(); break;
case CFTransformations::UNARY: construct_constant_unary(); break; case CFTransformations::UNARY: construct_constant_unary(); break;
case CFTransformations::BINARY: construct_constant_binary(); break; case CFTransformations::BINARY: construct_constant_binary(); break;
...@@ -101,8 +104,12 @@ public: ...@@ -101,8 +104,12 @@ public:
case CFTransformations::CONVERT: construct_constant_convert(); break; case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break; case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
case CFTransformations::REVERSE: construct_constant_reverse(); break; case CFTransformations::REVERSE: construct_constant_reverse(); break;
case CFTransformations::PRODUCT: construct_constant_product(); break; case CFTransformations::ARITHMETIC_REDUCTION:
case CFTransformations::SUM: construct_constant_sum(); break; construct_constant_arithmetic_reduction();
break;
case CFTransformations::LOGICAL_REDUCTION:
construct_constant_logical_reduction();
break;
case CFTransformations::CONCAT: construct_constant_concat(); break; case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break; case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break; case CFTransformations::SLICE: construct_constant_slice(); break;
...@@ -118,6 +125,7 @@ public: ...@@ -118,6 +125,7 @@ public:
private: private:
void construct_constant_reshape(); void construct_constant_reshape();
void construct_constant_broadcast(); void construct_constant_broadcast();
void construct_constant_dyn_broadcast();
void construct_constant_pad(); void construct_constant_pad();
void construct_constant_unary(); void construct_constant_unary();
void construct_constant_binary(); void construct_constant_binary();
...@@ -126,8 +134,8 @@ private: ...@@ -126,8 +134,8 @@ private:
void construct_constant_convert(); void construct_constant_convert();
void construct_constant_shape_of(); void construct_constant_shape_of();
void construct_constant_reverse(); void construct_constant_reverse();
void construct_constant_product(); void construct_constant_arithmetic_reduction();
void construct_constant_sum(); void construct_constant_logical_reduction();
void construct_constant_concat(); void construct_constant_concat();
void construct_constant_gather(); void construct_constant_gather();
void construct_constant_slice(); void construct_constant_slice();
......
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
{ {
T minval = std::numeric_limits<T>::has_infinity T minval = std::numeric_limits<T>::has_infinity
? -std::numeric_limits<T>::infinity() ? T(-std::numeric_limits<T>::infinity())
: std::numeric_limits<T>::min(); : std::numeric_limits<T>::min();
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
......
...@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast) ...@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast)
ASSERT_TRUE(new_const); ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>(); auto values_out = new_const->get_vector<int>();
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1}; vector<int> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out); ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_dyn_broadcast)
{
vector<int32_t> values_in{0, 1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
vector<int64_t> shape_in{2, 4};
auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
vector<int64_t> axes_in{1};
auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
auto dyn_broadcast = make_shared<op::DynBroadcast>(constant_in, constant_shape, constant_axes);
auto f = make_shared<Function>(dyn_broadcast, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynBroadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, constant_pad_exterior) TEST(constant_folding, constant_pad_exterior)
...@@ -434,6 +461,110 @@ TEST(constant_folding, const_sum) ...@@ -434,6 +461,110 @@ TEST(constant_folding, const_sum)
ASSERT_EQ(values_expected, values_out); ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, const_max)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Max>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{3, 6, 9};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_min)
{
Shape input_shape{3, 3};
vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto constant = op::Constant::create(element::i32, input_shape, values_in);
auto convert = make_shared<op::Min>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 4, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_all)
{
Shape input_shape{3, 3};
vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::All>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{0, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_any)
{
Shape input_shape{3, 3};
vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::Any>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 1, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat) TEST(constant_folding, const_concat)
{ {
auto constant0 = auto constant0 =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment