.. fusion/graph-rewrite.rst:

Using ``GraphRewrite`` to fuse ops
==================================

* :ref:`exact`
* :ref:`labels`

.. _exact: 

Exact pattern matching
----------------------

Taking ``-(-A) = A`` as the running example, the graphs below range from
trivial to more complex. Each can be matched by a pattern and then replaced
through a graph-rewrite recipe. We start with a trivial graph and follow it
with a more complex one:


|image3|

.. code-block:: cpp 

    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto absn = make_shared<op::Abs>(a);
    auto neg1 = make_shared<op::Negative>(absn);
    auto neg2 = make_shared<op::Negative>(neg1);


|image4|

	
.. code-block:: cpp 

    auto a = make_shared<op::Parameter>(element::i32, shape);
    auto b = make_shared<op::Parameter>(element::i32, shape);
    auto c = a + b;
    auto absn = make_shared<op::Abs>(c);
    auto neg1 = make_shared<op::Negative>(absn);
    auto neg2 = make_shared<op::Negative>(neg1);

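To make the idea of exact matching concrete, here is a minimal, self-contained
sketch. It is a toy model, not the nGraph API: ``Node``, ``make_node``,
``exact_match``, ``graph1`` and ``graph2`` are hypothetical names introduced
only for illustration, and op kinds are represented as plain strings.

.. code-block:: cpp

    // Toy model for illustration only; these types are NOT the nGraph API.
    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct Node
    {
        std::string op;                          // e.g. "Parameter", "Abs", "Negative"
        std::vector<std::shared_ptr<Node>> args; // input nodes
    };
    using NodePtr = std::shared_ptr<Node>;

    NodePtr make_node(const std::string& op, std::vector<NodePtr> args = {})
    {
        return std::make_shared<Node>(Node{op, std::move(args)});
    }

    // Exact matching: op kind and argument structure must agree at every level.
    bool exact_match(const NodePtr& pattern, const NodePtr& graph)
    {
        assert(pattern && graph);
        if (pattern->op != graph->op || pattern->args.size() != graph->args.size())
            return false;
        for (std::size_t i = 0; i < pattern->args.size(); ++i)
            if (!exact_match(pattern->args[i], graph->args[i]))
                return false;
        return true;
    }

    // -(-(abs(a))), mirroring the first graph above.
    NodePtr graph1()
    {
        auto a = make_node("Parameter");
        auto absn = make_node("Abs", {a});
        return make_node("Negative", {make_node("Negative", {absn})});
    }

    // -(-(abs(a + b))), mirroring the second graph above.
    NodePtr graph2()
    {
        auto c = make_node("Add", {make_node("Parameter"), make_node("Parameter")});
        auto absn = make_node("Abs", {c});
        return make_node("Negative", {make_node("Negative", {absn})});
    }

In this toy model, a pattern shaped like the first graph matches the first
graph but not the second, because the input of ``Abs`` is a ``Parameter`` in
one and an ``Add`` in the other. An exact pattern therefore has to be written
once per graph shape.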

Label AKA ``.`` in regexes
~~~~~~~~~~~~~~~~~~~~~~~~~~~


|image5|

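A ``Label`` acts as a wildcard: it matches any node, much like ``.`` in a
regular expression. The element type given to the label does not restrict the
match, so the pattern below, declared with ``element::f32``, still matches the
integer graphs Graph1 and Graph2.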

.. code-block:: cpp

    // Note: although the label is declared as element::f32, the pattern
    // still matches the integer graphs Graph1 and Graph2.
    auto lbl = std::make_shared<pattern::op::Label>(element::f32, Shape{});  
    auto neg1 = make_shared<op::Negative>(lbl);
    auto neg2 = make_shared<op::Negative>(neg1);


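The wildcard behaviour of a label, and the ``-(-A) = A`` rewrite it enables,
can be sketched with a small toy matcher. This is not the nGraph
implementation: ``Node``, ``Bindings``, ``make_node``, ``match`` and
``remove_double_negative`` are hypothetical names, and a label is modelled
simply as a node whose op kind is the string ``"Label"``.

.. code-block:: cpp

    // Toy model for illustration only; these types are NOT the nGraph API.
    #include <cassert>
    #include <cstddef>
    #include <map>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct Node
    {
        std::string op;                          // "Label" marks a wildcard
        std::vector<std::shared_ptr<Node>> args; // input nodes
    };
    using NodePtr = std::shared_ptr<Node>;
    using Bindings = std::map<NodePtr, NodePtr>; // label -> matched graph node

    NodePtr make_node(const std::string& op, std::vector<NodePtr> args = {})
    {
        return std::make_shared<Node>(Node{op, std::move(args)});
    }

    // A label matches any graph node and records what it matched;
    // every other pattern node must agree in op kind and arguments.
    bool match(const NodePtr& pattern, const NodePtr& graph, Bindings& bound)
    {
        assert(pattern && graph);
        if (pattern->op == "Label")
        {
            bound[pattern] = graph;
            return true;
        }
        if (pattern->op != graph->op || pattern->args.size() != graph->args.size())
            return false;
        for (std::size_t i = 0; i < pattern->args.size(); ++i)
            if (!match(pattern->args[i], graph->args[i], bound))
                return false;
        return true;
    }

    // Apply -(-A) = A at the root: on a match, return the node bound to the
    // label; otherwise return the root unchanged.
    NodePtr remove_double_negative(const NodePtr& root)
    {
        auto lbl = make_node("Label");
        auto pattern = make_node("Negative", {make_node("Negative", {lbl})});
        Bindings bound;
        return match(pattern, root, bound) ? bound.at(lbl) : root;
    }

Because the label stands in for whatever sits below the two negations, the
same toy pattern covers both ``-(-(abs(a)))`` and ``-(-(abs(a + b)))``, and in
each case the rewrite yields the ``Abs`` node that was bound to the label.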

.. _labels:

Constructing labels from existing nodes
---------------------------------------

Double Negative w/ Addition
~~~~~~~~~~~~~~~~~~~~~~~~~~~

|image6|


.. code-block:: cpp

    auto a = make_shared<op::Parameter>(element::i32, shape);
    // `lbl` borrows the type and shape information from `a`
    auto lbl = std::make_shared<pattern::op::Label>(a);
    // Build the pattern on top of the label rather than on `a` itself
    auto neg1 = make_shared<op::Negative>(lbl);
    auto neg2 = make_shared<op::Negative>(neg1);

Double Negative w/ Subtraction  
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

|image7|


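A label can also take a predicate that restricts which nodes it matches.
Predicates are of type ``std::function<bool(std::shared_ptr<Node>)>``; the one
below accepts only addition and subtraction nodes.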


.. code-block:: cpp
   
    // Predicates are of type std::function<bool(std::shared_ptr<Node>)>.
    // This one accepts only Add and Subtract nodes.
    auto add_or_sub = [](std::shared_ptr<Node> n) {
        return std::dynamic_pointer_cast<op::Add>(n) != nullptr ||
               std::dynamic_pointer_cast<op::Subtract>(n) != nullptr;
    };

    // The label matches only nodes for which the predicate returns true.
    auto lbl = std::make_shared<pattern::op::Label>(
        element::f32,
        Shape{},
        add_or_sub
    );
    // Build the pattern on top of the label.
    auto neg1 = make_shared<op::Negative>(lbl);
    auto neg2 = make_shared<op::Negative>(neg1);

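How a predicate narrows what a label accepts can be sketched with a toy
matcher. This is a simplified model under stated assumptions, not the nGraph
implementation: ``Node``, ``Predicate``, ``make_node``, ``make_label``,
``match`` and ``add_or_sub_pattern`` are hypothetical names, and op kinds are
plain strings rather than C++ types.

.. code-block:: cpp

    // Toy model for illustration only; these types are NOT the nGraph API.
    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;
    using Predicate = std::function<bool(NodePtr)>;

    struct Node
    {
        std::string op;            // "Label" marks a wildcard
        std::vector<NodePtr> args; // input nodes
        Predicate pred;            // used by labels only; empty accepts anything
    };

    NodePtr make_node(const std::string& op, std::vector<NodePtr> args = {})
    {
        return std::make_shared<Node>(Node{op, std::move(args), Predicate{}});
    }

    NodePtr make_label(Predicate pred = {})
    {
        return std::make_shared<Node>(Node{"Label", {}, std::move(pred)});
    }

    // A label matches a node only if it has no predicate or the predicate
    // returns true; other pattern nodes must agree in op kind and arguments.
    bool match(const NodePtr& pattern, const NodePtr& graph)
    {
        assert(pattern && graph);
        if (pattern->op == "Label")
            return !pattern->pred || pattern->pred(graph);
        if (pattern->op != graph->op || pattern->args.size() != graph->args.size())
            return false;
        for (std::size_t i = 0; i < pattern->args.size(); ++i)
            if (!match(pattern->args[i], graph->args[i]))
                return false;
        return true;
    }

    // Double negative whose operand must be an addition or a subtraction.
    NodePtr add_or_sub_pattern()
    {
        auto add_or_sub = [](NodePtr n) {
            return n->op == "Add" || n->op == "Subtract";
        };
        auto lbl = make_label(add_or_sub);
        return make_node("Negative", {make_node("Negative", {lbl})});
    }

In this sketch the pattern matches ``-(-(a - b))`` and ``-(-(a + b))`` but
rejects ``-(-(abs(a)))``, since the predicate returns false for the ``Abs``
node.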





.. |image3| image:: mg/pr1_graph2.png
.. |image4| image:: mg/pr1_graph3.png
.. |image5| image:: mg/pr1_pattern2.png
.. |image6| image:: mg/pr1_graph4.png
.. |image7| image:: mg/pr1_graph5.png
.. |image8| image:: mg/pr2_graph1.png
.. |image9| image:: mg/pr2_graph2.png
.. |image10| image:: mg/pr2_pattern2.png