Commit 2ba6aea4 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support constant folding on V1 broadcast ops (#3814)

parent 55d33755
...@@ -24,7 +24,7 @@ using namespace ngraph; ...@@ -24,7 +24,7 @@ using namespace ngraph;
template <class T> template <class T>
shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> constant,
shared_ptr<op::Broadcast> broadcast, shared_ptr<Node> broadcast,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = broadcast->get_shape(); auto out_shape = broadcast->get_shape();
...@@ -39,13 +39,33 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -39,13 +39,33 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
func(inputs, outputs); func(inputs, outputs);
} }
else if (auto broadcast_v1 = as_type_ptr<op::v1::Broadcast>(broadcast))
{
auto static_bcast_axes = broadcast_v1->get_broadcast_axes();
if (static_bcast_axes.first)
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(),
constant->get_shape(),
out_shape,
static_bcast_axes.second);
}
else else
{
throw ngraph_error("Unexpected failure due to inability to obtain broadcast axes.");
}
}
else if (auto broadcast_v0 = as_type_ptr<op::v0::Broadcast>(broadcast))
{ {
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(), runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(), out_vec.data(),
constant->get_shape(), constant->get_shape(),
out_shape, out_shape,
broadcast->get_broadcast_axes()); broadcast_v0->get_broadcast_axes());
}
else
{
throw ngraph_error("Unsupported op in broadcast constant folding.");
} }
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
...@@ -56,7 +76,14 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -56,7 +76,14 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto constant_label = auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>()); make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
auto broadcast = make_shared<op::Broadcast>(constant_label, Shape{2, 4}, AxisSet{1}); auto broadcast_v0 = make_shared<op::v0::Broadcast>(constant_label, Shape{2, 4}, AxisSet{1});
auto constant_shape =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto constant_axes =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto broadcast_v1 =
make_shared<op::v1::Broadcast>(constant_label, constant_shape, constant_axes);
auto constant_broadcast_callback = [&, constant_label](pattern::Matcher& m) { auto constant_broadcast_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_broadcast_callback against node = " NGRAPH_DEBUG << "In callback for constant_broadcast_callback against node = "
...@@ -65,7 +92,7 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -65,7 +92,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto broadcast_match = static_pointer_cast<op::Broadcast>(m.get_match_root()); auto broadcast_match = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(broadcast_match)); NGRAPH_CHECK(revalidate_and_ensure_static(broadcast_match));
...@@ -135,8 +162,13 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -135,8 +162,13 @@ void pass::ConstantFolding::construct_constant_broadcast()
return true; return true;
}; };
auto broadcast_matcher =
make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
this->add_matcher( this->add_matcher(
broadcast_matcher, constant_broadcast_callback, PassProperty::CHANGE_DYNAMIC_STATE); make_shared<pattern::Matcher>(broadcast_v0, "ConstantFolding.ConstantBroadcastV0"),
constant_broadcast_callback,
PassProperty::CHANGE_DYNAMIC_STATE);
this->add_matcher(
make_shared<pattern::Matcher>(broadcast_v1, "ConstantFolding.ConstantBroadcastV1"),
constant_broadcast_callback,
PassProperty::CHANGE_DYNAMIC_STATE);
} }
...@@ -180,6 +180,56 @@ TEST(constant_folding, constant_dyn_broadcast) ...@@ -180,6 +180,56 @@ TEST(constant_folding, constant_dyn_broadcast)
ASSERT_EQ(values_expected, values_out); ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, constant_broadcast_v1)
{
vector<int32_t> values_in{0, 1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
vector<int64_t> shape_in{2, 4};
auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
vector<int64_t> axes_in{0};
auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape, constant_axes);
auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_broadcast_v1_numpy)
{
vector<int32_t> values_in{0, 1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
vector<int64_t> shape_in{4, 2};
auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, constant_shape);
auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{0, 1, 0, 1, 0, 1, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_pad_exterior) TEST(constant_folding, constant_pad_exterior)
{ {
Shape shape_in{2}; Shape shape_in{2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment