Unverified Commit d1d27d9e authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into silee2/pragma

parents 66b6f186 f50e12a1
This diff is collapsed.
...@@ -34,6 +34,7 @@ public: ...@@ -34,6 +34,7 @@ public:
{ {
RESHAPE, RESHAPE,
BROADCAST, BROADCAST,
DYN_BROADCAST,
PAD, PAD,
DEQUANTIZE, DEQUANTIZE,
UNARY, UNARY,
...@@ -60,6 +61,7 @@ public: ...@@ -60,6 +61,7 @@ public:
m_cfmap = cfmap; m_cfmap = cfmap;
construct_constant_reshape(); construct_constant_reshape();
construct_constant_broadcast(); construct_constant_broadcast();
construct_constant_dyn_broadcast();
construct_constant_pad(); construct_constant_pad();
construct_constant_unary(); construct_constant_unary();
construct_constant_binary(); construct_constant_binary();
...@@ -93,6 +95,7 @@ public: ...@@ -93,6 +95,7 @@ public:
{ {
case CFTransformations::RESHAPE: construct_constant_reshape(); break; case CFTransformations::RESHAPE: construct_constant_reshape(); break;
case CFTransformations::BROADCAST: construct_constant_broadcast(); break; case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
case CFTransformations::DYN_BROADCAST: construct_constant_dyn_broadcast(); break;
case CFTransformations::PAD: construct_constant_pad(); break; case CFTransformations::PAD: construct_constant_pad(); break;
case CFTransformations::UNARY: construct_constant_unary(); break; case CFTransformations::UNARY: construct_constant_unary(); break;
case CFTransformations::BINARY: construct_constant_binary(); break; case CFTransformations::BINARY: construct_constant_binary(); break;
...@@ -122,6 +125,7 @@ public: ...@@ -122,6 +125,7 @@ public:
private: private:
void construct_constant_reshape(); void construct_constant_reshape();
void construct_constant_broadcast(); void construct_constant_broadcast();
void construct_constant_dyn_broadcast();
void construct_constant_pad(); void construct_constant_pad();
void construct_constant_unary(); void construct_constant_unary();
void construct_constant_binary(); void construct_constant_binary();
......
...@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast) ...@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast)
ASSERT_TRUE(new_const); ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>(); auto values_out = new_const->get_vector<int>();
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1}; vector<int> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out); ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_dyn_broadcast)
{
vector<int32_t> values_in{0, 1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
vector<int64_t> shape_in{2, 4};
auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
vector<int64_t> axes_in{1};
auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
auto dyn_broadcast = make_shared<op::DynBroadcast>(constant_in, constant_shape, constant_axes);
auto f = make_shared<Function>(dyn_broadcast, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynBroadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, constant_pad_exterior) TEST(constant_folding, constant_pad_exterior)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment