Commit f2b73a76 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Added predicate for alpha, in BoundedRelu (#1205)

parent 0768a969
......@@ -1236,7 +1236,11 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_bounded_relu()
auto relu = std::make_shared<op::Relu>(relu_input);
auto iconst1 = op::Constant::create(element::f32, Shape{}, {1});
auto alpha = std::make_shared<pattern::op::Label>(iconst1);
auto min = std::make_shared<op::Minimum>(relu, alpha);
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return (std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr);
};
auto skip_broadcast = std::make_shared<pattern::op::Skip>(alpha, broadcast_pred);
auto min = std::make_shared<op::Minimum>(relu, skip_broadcast);
pattern::graph_rewrite_callback callback = [relu_input, alpha](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_bounded_relu against "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment