.. fusion/passes-that-use-matcher.rst:

Passes that use Matcher
=======================

* CPUFusion (GraphRewrite)
* CoreFusion (GraphRewrite)
* ReshapeElimination (GraphRewrite)
* AlgebraicSimplification
* CPUPostLayoutOptimizations (GraphRewrite)
* CPURnnMatFusion
* and many more...


Register ``simplify_neg`` handler
----------------------------------

.. code-block:: cpp

   static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
       initialize_const_values_to_ops()
   {
       return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({
           {TI(op::Add), simplify_add},
           {TI(op::Multiply), simplify_multiply},
           {TI(op::Sum), simplify_sum},
           {TI(op::Negative), simplify_neg}
       });
   }


Add a fusion
~~~~~~~~~~~~

``max(0, A) = Relu(A)``


Pattern for capturing
~~~~~~~~~~~~~~~~~~~~~

|image11|

``max(0, A) = Relu(A)``


.. code-block:: cpp

   namespace ngraph
   {
       namespace pass
       {
           class CoreFusion;
       }
   }

   class ngraph::pass::CoreFusion : public ngraph::pass::GraphRewrite
   {
   public:
       CoreFusion()
           : GraphRewrite()
       {
           construct_relu_pattern();
       }

       //this should go in a cpp file.
       void construct_relu_pattern()
       {
           auto iconst0 = ngraph::make_zero(element::i32, Shape{});
           auto val = make_shared<pattern::op::Label>(iconst0);
           auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});

           auto broadcast_pred = [](std::shared_ptr<Node> n) {
               return static_cast<bool>(std::dynamic_pointer_cast<op::Broadcast>(n));
           };
           auto skip_broadcast = std::make_shared<pattern::op::Skip>(zero, broadcast_pred);
           auto max = make_shared<op::Maximum>(skip_broadcast, val);

           pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
               NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
                            << m.get_match_root()->get_name();

               auto pattern_map = m.get_pattern_map();
               auto mzero = m.get_pattern_map()[zero];
               if (!ngraph::is_zero(mzero))
               {
                   NGRAPH_DEBUG << "zero constant = " << mzero->get_name()
                                << " not equal to 0\n";
                   return false;
               }
               auto mpattern = m.get_match_root();

               auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
               ngraph::replace_node(m.get_match_root(), cg);
               return true;
           };

           auto m = make_shared<pattern::Matcher>(max, callback);
           this->add_matcher(m);
       }
   };


Recurrent patterns
------------------

Equivalent to ``"A(BC)+A"`` in regular expressions:

``(((A + 0) + 0) + 0) = A``

|image12| |image13|

.. code-block:: cpp

   Shape shape{};
   auto a = make_shared<op::Parameter>(element::i32, shape);
   auto b = make_shared<op::Parameter>(element::i32, shape);
   auto rpattern = std::make_shared<pattern::op::Label>(b);
   auto iconst0 = ngraph::make_zero(element::i32, shape);
   auto abs = make_shared<op::Abs>(a);
   auto add1 = iconst0 + b;
   auto add2 = iconst0 + add1;
   auto add3 = iconst0 + add2;
   auto padd = iconst0 + rpattern;
   std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
   RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
   ASSERT_TRUE(rm.match(add3));


.. |image11| image:: ../fusion/mg/fusion_pattern.png
.. |image12| image:: ../fusion/mg/rp_graph1.png
.. |image13| image:: ../fusion/mg/rp_pattern.png