Unverified Commit c29c9f89 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Generalize a ReLU pattern to fuse it in VGG (#529)

* generalize Relu

* format fixes
parent a6c86263
......@@ -15,13 +15,13 @@
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/pattern/core_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/parameter.hpp"
......@@ -51,7 +51,12 @@ void pass::CoreFusion::construct_relu_pattern()
auto iconst0 = construct_constant_node(0);
auto val = make_shared<pattern::op::Label>(iconst0);
auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, Nodes{iconst0});
auto max = make_shared<op::Maximum>(zero, val);
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Broadcast>(n));
};
auto skip_broadcast = std::make_shared<pattern::op::Any>(zero, broadcast_pred);
auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment