Commit 71f13654 authored by gaurides's avatar gaurides Committed by Scott Cyphers

Skip Broadcast in sigmoid fusion (#2197)

* Skip Broadcast in sigmoid fusion

* added test case; modified file perms

* incorporate review comments

* using is_one() to check the node is constant&1
parent fc216f39
......@@ -92,15 +92,15 @@ void pass::CoreFusion::construct_sigmoid()
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// broadcast input
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto skip_broadcast =
std::make_shared<pattern::op::Skip>(constant, pattern::has_class<op::Broadcast>());
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
auto add_exp = std::make_shared<op::Add>(exp_neg_input, skip_broadcast);
auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -119,6 +119,11 @@ void pass::CoreFusion::construct_sigmoid()
return false;
}
if (!is_one(pattern_map[constant]))
{
NGRAPH_DEBUG << "Node not constant or not 1";
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true;
......@@ -151,11 +156,11 @@ void pass::CoreFusion::construct_sigmoid_bprop()
auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
auto negtive_2 = std::make_shared<op::Negative>(multiply_2);
auto negative_2 = std::make_shared<op::Negative>(multiply_2);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
NGRAPH_DEBUG << "In a callback for construct_bprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
......@@ -178,7 +183,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
};
auto m =
std::make_shared<ngraph::pattern::Matcher>(negtive_2, callback, "CoreFusion.SigmoidBprop");
std::make_shared<ngraph::pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
this->add_matcher(m);
}
......
......@@ -74,6 +74,53 @@ TEST(core_fusion, sigmoid_fprop_fusion)
ASSERT_EQ(ccg, 1);
}
TEST(core_fusion, sigmoid_fprop_fusion_no_broadcast)
{
auto make_function = []() {
auto input = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
auto constant =
op::Constant::create(element::f32, Shape{3, 4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(constant, add_exp);
return make_shared<Function>(NodeVector{divide_1_over_exp}, ParameterVector{input});
};
auto func = make_function();
// Check fusion happens
pass::Manager pass_manager;
pass_manager.register_pass<pass::CoreFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::Sigmoid>(func);
ASSERT_EQ(ccg, 1);
}
TEST(core_fusion, sigmoid_fprop_fusion_no_broadcast2)
{
auto make_function = []() {
auto input = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
auto constant =
op::Constant::create(element::f32, Shape{3, 4}, {1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(constant, add_exp);
return make_shared<Function>(NodeVector{divide_1_over_exp}, ParameterVector{input});
};
auto func = make_function();
pass::Manager pass_manager;
pass_manager.register_pass<pass::CoreFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::Sigmoid>(func);
ASSERT_EQ(ccg, 0);
}
TEST(core_fusion, sigmoid_bprop_fusion)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment