    LSTM fusion + RNN fusion across time slices for single layer (#826) · 1d08f073
    Pruthvi authored
    * - Added pattern matcher for LSTM cell
    
    * WIP added support to replace lstm cell instead of subgraph
    
    * WIP LSTM pattern matcher, fuses recurrent cells
    
    * WIP added RNN CPU op
    
    * WIP mkldnn emitter code for fprop RNN
    
    * WIP RNN mkldnn integration
    - Added mkldnn kernel for unidirectional LSTM in the CPU emitter
    
    * add a getter for root node
    
    * recurrent graph rewrite
    
    * fix perms, rename match_root -> get_match_root
    
    * fix comp errors
    
    * make match_root return the topmost match; fix tests
    
    * - WIP GetOutputElement for handling multiple LSTM o/ps
    - use RecurrentGraphRewrite for replacing node after matching LSTM cells
    
    * WIP LSTM multi Output + debug prints
    
    * moved LSTM fusion to cpu_fusion
    
    * WIP added RNN superfused OP
    
    * WIP towards RNN layer fusion
    
    * WIP multiple output slicing RNN
    
    * WIP RNN multiple o/ps fusion across layer
    
    * WIP corrected input params for fused RNN OP
    
    * concat corresponding params across different LSTMs to form inputs to RNN fused op
    
    * i) Added test case for RNN kernel ii) runs without errors
    
    * refactored and moved LSTM class to standalone file
    
    * Rename RNN -> Rnn , LSTM -> Lstm
    
    * WIP replace lstm slices to the consumer op
    
    * Slicing works on multiple RNN layers
    
    * fixed all bugs
    
    * - Added CPU RNN Recurrent Fusion
    - Added CPU LSTM fusion
    - removed debug code
    - style fix
    
    * - Added support to compute src_iter and dst_iter instead of taking zero_memory_desc
    - Added unit test to compute one LSTM cell
    
    * changed RNN op signature to accept number of states in basic unit of RNN (GRU/LSTM/vanilla RNN) cell
    
    * added sanity checks for RNN op
    
    * Fixed issue related to patching the graph while replacing the RNN sliced outputs
    
    * Fixed issue to feed the input symbols in the order X0, X1, ...Xt to the RNN op
    
    * Added unit test for multi layer RNN fusion
    
    * Removed debug statements
    
    * i) Added multilayered serialized graph ii) fixed compilation issue
    
    * Addressed PR comments
    
    * i) WIP MKLDNN layout for RNN Op ii) added test case for INTERPRETER v/s CPU Rnn results
    
    * - Fixed bug w.r.t. src_layer feature size in rnn mkldnn emitter code
    - Refactored cpu_fusion rnn test case
    
    * merge origin/master with branch pruthvi/lstm_fusion
    
    * style fix
    
    * Added test case for multiple RNN layers
    
    * i) make rnn as mkldnn op if it meets the constraints ii) assert if rnn is not mkldnn op
    
    * fix unit test failure
    
    * - Added support to reliably identify the hidden state and input symbols from the nodes collected by Pattern matcher
    - Fixed failing unit tests
    
    * style fix
    
    * - removed "node type" dependency to replace the intermediate LSTM outputs
    
    * Addressed PR comments
    
    * Fix unit test
    
    * - added MKLDNN emitter for LSTM op
    - graph pass to concat LSTM input recurrent state tensors
    - CPU layout assignment for LSTM Op
    - Fixed bug in rnn/lstm unit tests
    - made changes to use replace_output instead of replace_node for replacing matched graph nodes in LSTM/RNN fusion pass
    
    (cherry picked from commit d16fc709265cc0a73e60c6d5f6d2878e7b908aca)
    
    * style fix
    
    * Renamed passes and style fixes
1_lstm_cell_forward.json 7.89 KB
[{
  "name" : "Function_0",
  "ops" : [
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Parameter_10",
      "op" : "Parameter",
      "outputs" : ["Parameter_10_0"],
      "shape" : [400]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Parameter_9",
      "op" : "Parameter",
      "outputs" : ["Parameter_9_0"],
      "shape" : [ 400, 100 ]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Parameter_2",
      "op" : "Parameter",
      "outputs" : ["Parameter_2_0"],
      "shape" : [400]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Parameter_1",
      "op" : "Parameter",
      "outputs" : ["Parameter_1_0"],
      "shape" : [ 400, 100 ]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Parameter_0",
      "op" : "Parameter",
      "outputs" : ["Parameter_0_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Constant_7",
      "op" : "Constant",
      "outputs" : ["Constant_7_0"],
      "shape" : [],
      "value" : ["1"]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Constant_34",
      "op" : "Constant",
      "outputs" : ["Constant_34_0"],
      "shape" : [],
      "value" : ["1"]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Constant_30",
      "op" : "Constant",
      "outputs" : ["Constant_30_0"],
      "shape" : [],
      "value" : ["1"]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Constant_24",
      "op" : "Constant",
      "outputs" : ["Constant_24_0"],
      "shape" : [],
      "value" : ["1"]
    },
    {
      "element_type" : "float",
      "inputs" : [],
      "name" : "Constant_17",
      "op" : "Constant",
      "outputs" : ["Constant_17_0"],
      "shape" : [],
      "value" : ["1"]
    },
    {
      "axes" : [0],
      "inputs" : ["Parameter_10"],
      "name" : "Broadcast_13",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_13_0"],
      "shape" : [ 10, 400 ]
    },
    {
      "input_order" : [ 1, 0 ],
      "inputs" : ["Parameter_9"],
      "name" : "Reshape_11",
      "op" : "Reshape",
      "output_shape" : [ 100, 400 ],
      "outputs" : ["Reshape_11_0"]
    },
    {
      "axes" : [0],
      "inputs" : ["Parameter_2"],
      "name" : "Broadcast_5",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_5_0"],
      "shape" : [ 10, 400 ]
    },
    {
      "input_order" : [ 1, 0 ],
      "inputs" : ["Parameter_1"],
      "name" : "Reshape_3",
      "op" : "Reshape",
      "output_shape" : [ 100, 400 ],
      "outputs" : ["Reshape_3_0"]
    },
    {
      "axes" : [ 0, 1 ],
      "inputs" : ["Constant_7"],
      "name" : "Broadcast_8",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_8_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "axes" : [ 0, 1 ],
      "inputs" : ["Constant_34"],
      "name" : "Broadcast_35",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_35_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "axes" : [ 0, 1 ],
      "inputs" : ["Constant_30"],
      "name" : "Broadcast_31",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_31_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "axes" : [ 0, 1 ],
      "inputs" : ["Constant_24"],
      "name" : "Broadcast_25",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_25_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "axes" : [ 0, 1 ],
      "inputs" : ["Constant_17"],
      "name" : "Broadcast_18",
      "op" : "Broadcast",
      "outputs" : ["Broadcast_18_0"],
      "shape" : [ 10, 100 ]
    },
    {
      "inputs" : [ "Parameter_0", "Reshape_3" ],
      "name" : "Dot_4",
      "op" : "Dot",
      "outputs" : ["Dot_4_0"],
      "reduction_axes_count" : 1
    },
    {
      "inputs" : [ "Broadcast_8", "Reshape_11" ],
      "name" : "Dot_12",
      "op" : "Dot",
      "outputs" : ["Dot_12_0"],
      "reduction_axes_count" : 1
    },
    {
      "inputs" : [ "Dot_4", "Broadcast_5" ],
      "name" : "Add_6",
      "op" : "Add",
      "outputs" : ["Add_6_0"]
    },
    {
      "inputs" : [ "Dot_12", "Broadcast_13" ],
      "name" : "Add_14",
      "op" : "Add",
      "outputs" : ["Add_14_0"]
    },
    {
      "inputs" : [ "Add_6", "Add_14" ],
      "name" : "Add_15",
      "op" : "Add",
      "outputs" : ["Add_15_0"]
    },
    {
      "inputs" : ["Add_15"],
      "lower_bounds" : [ 0, 300 ],
      "name" : "Slice_16",
      "op" : "Slice",
      "outputs" : ["Slice_16_0"],
      "strides" : [ 1, 1 ],
      "upper_bounds" : [ 10, 400 ]
    },
    {
      "inputs" : ["Add_15"],
      "lower_bounds" : [ 0, 100 ],
      "name" : "Slice_23",
      "op" : "Slice",
      "outputs" : ["Slice_23_0"],
      "strides" : [ 1, 1 ],
      "upper_bounds" : [ 10, 200 ]
    },
    {
      "inputs" : ["Add_15"],
      "lower_bounds" : [ 0, 0 ],
      "name" : "Slice_33",
      "op" : "Slice",
      "outputs" : ["Slice_33_0"],
      "strides" : [ 1, 1 ],
      "upper_bounds" : [ 10, 100 ]
    },
    {
      "inputs" : ["Add_15"],
      "lower_bounds" : [ 0, 200 ],
      "name" : "Slice_40",
      "op" : "Slice",
      "outputs" : ["Slice_40_0"],
      "strides" : [ 1, 1 ],
      "upper_bounds" : [ 10, 300 ]
    },
    {
      "inputs" : ["Slice_16"],
      "name" : "Negative_19",
      "op" : "Negative",
      "outputs" : ["Negative_19_0"]
    },
    {
      "inputs" : ["Slice_23"],
      "name" : "Negative_26",
      "op" : "Negative",
      "outputs" : ["Negative_26_0"]
    },
    {
      "inputs" : ["Slice_33"],
      "name" : "Negative_36",
      "op" : "Negative",
      "outputs" : ["Negative_36_0"]
    },
    {
      "inputs" : ["Slice_40"],
      "name" : "Tanh_41",
      "op" : "Tanh",
      "outputs" : ["Tanh_41_0"]
    },
    {
      "inputs" : ["Negative_19"],
      "name" : "Exp_20",
      "op" : "Exp",
      "outputs" : ["Exp_20_0"]
    },
    {
      "inputs" : ["Negative_26"],
      "name" : "Exp_27",
      "op" : "Exp",
      "outputs" : ["Exp_27_0"]
    },
    {
      "inputs" : ["Negative_36"],
      "name" : "Exp_37",
      "op" : "Exp",
      "outputs" : ["Exp_37_0"]
    },
    {
      "inputs" : [ "Broadcast_18", "Exp_20" ],
      "name" : "Add_21",
      "op" : "Add",
      "outputs" : ["Add_21_0"]
    },
    {
      "inputs" : [ "Broadcast_25", "Exp_27" ],
      "name" : "Add_28",
      "op" : "Add",
      "outputs" : ["Add_28_0"]
    },
    {
      "inputs" : [ "Broadcast_35", "Exp_37" ],
      "name" : "Add_38",
      "op" : "Add",
      "outputs" : ["Add_38_0"]
    },
    {
      "inputs" : [ "Broadcast_18", "Add_21" ],
      "name" : "Divide_22",
      "op" : "Divide",
      "outputs" : ["Divide_22_0"]
    },
    {
      "inputs" : [ "Broadcast_25", "Add_28" ],
      "name" : "Divide_29",
      "op" : "Divide",
      "outputs" : ["Divide_29_0"]
    },
    {
      "inputs" : [ "Broadcast_35", "Add_38" ],
      "name" : "Divide_39",
      "op" : "Divide",
      "outputs" : ["Divide_39_0"]
    },
    {
      "inputs" : [ "Divide_29", "Broadcast_31" ],
      "name" : "Multiply_32",
      "op" : "Multiply",
      "outputs" : ["Multiply_32_0"]
    },
    {
      "inputs" : [ "Divide_39", "Tanh_41" ],
      "name" : "Multiply_42",
      "op" : "Multiply",
      "outputs" : ["Multiply_42_0"]
    },
    {
      "inputs" : [ "Multiply_32", "Multiply_42" ],
      "name" : "Add_43",
      "op" : "Add",
      "outputs" : ["Add_43_0"]
    },
    {
      "inputs" : ["Add_43"],
      "name" : "Tanh_44",
      "op" : "Tanh",
      "outputs" : ["Tanh_44_0"]
    },
    {
      "inputs" : [ "Divide_22", "Tanh_44" ],
      "name" : "Multiply_45",
      "op" : "Multiply",
      "outputs" : ["Multiply_45_0"]
    },
    {
      "inputs" : ["Multiply_45"],
      "name" : "Result_46",
      "op" : "Result",
      "outputs" : ["Result_46_0"]
    }
  ],
  "parameters" : [
    "Parameter_0", "Parameter_1", "Parameter_2", "Parameter_9",
    "Parameter_10"
  ],
  "result" : ["Result_46"]
}]
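
The serialized graph above appears to encode one forward step of an LSTM cell: two Dot-plus-bias branches are summed into a [10, 400] tensor, which is sliced into four [10, 100] blocks, three passed through a sigmoid written as 1 / (1 + exp(-x)) and one through tanh. The pure-Python sketch below restates that data flow for reading purposes only. The function and argument names (`lstm_cell`, `w_x`, `w_h`, and so on) and the gate labels (input, forget, candidate, output) are assumptions based on the usual LSTM formulation and do not come from the file; in the graph, the previous hidden and cell states are constant-1 tensors (Broadcast_8 and Broadcast_31), and only the hidden output reaches Result_46.

```python
import math


def sigmoid(x):
    # The graph spells this out as Divide(1, Add(1, Exp(Negative(x)))).
    return 1.0 / (1.0 + math.exp(-x))


def dot_transposed(a, w):
    # a: [batch, k], w: [n, k]. Returns a . w^T with shape [batch, n],
    # mirroring Reshape(input_order=[1, 0]) followed by Dot.
    return [[sum(p * q for p, q in zip(row, w_row)) for w_row in w] for row in a]


def lstm_cell(x, w_x, b_x, h_prev, w_h, b_h, c_prev):
    """One LSTM step (hypothetical helper, not part of nGraph).

    x, h_prev, c_prev: [batch, n]; w_x, w_h: [4n, n]; b_x, b_h: [4n].
    """
    n = len(c_prev[0])
    xw = dot_transposed(x, w_x)        # Dot_4
    hw = dot_transposed(h_prev, w_h)   # Dot_12
    h_out, c_out = [], []
    for b in range(len(x)):
        # Add_6, Add_14 and Add_15: both branches plus their biases.
        gates = [xw[b][j] + b_x[j] + hw[b][j] + b_h[j] for j in range(4 * n)]
        i_gate = [sigmoid(v) for v in gates[0:n]]          # Slice_33
        f_gate = [sigmoid(v) for v in gates[n:2 * n]]      # Slice_23
        cand = [math.tanh(v) for v in gates[2 * n:3 * n]]  # Slice_40
        o_gate = [sigmoid(v) for v in gates[3 * n:4 * n]]  # Slice_16
        # Multiply_32 + Multiply_42 -> Add_43 (new cell state).
        c = [f_gate[j] * c_prev[b][j] + i_gate[j] * cand[j] for j in range(n)]
        # Tanh_44 and Multiply_45 (new hidden state, the graph's only result).
        h = [o_gate[j] * math.tanh(c[j]) for j in range(n)]
        c_out.append(c)
        h_out.append(h)
    return h_out, c_out
```

With the shapes in the file, `x` would be Parameter_0 ([10, 100]), `w_x` and `b_x` would be Parameter_1 and Parameter_2, `w_h` and `b_h` would be Parameter_9 and Parameter_10, and both `h_prev` and `c_prev` would be all-ones [10, 100] lists.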