// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of the nGraph codebase.

#pragma once

#include <memory>
#include <mutex>
#include <set>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>

#include "ngraph/pass/pass.hpp"
namespace ngraph
    namespace pass
        /// This pass creates CompiledKernel ops enclosing maximal sub-graphs of ops that are supported by MLIR
29 30
        class MLIRSubgraphExtractionPass : public ngraph::pass::FunctionPass
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
            using NodeSet = std::unordered_set<std::shared_ptr<Node>>;

            class MLIRSubgraph
                static int get_new_graph_id() { return m_curr_graph_id++; }
                /// Create a sub-graph with a new ID.
                MLIRSubgraph(MLIRSubgraphExtractionPass* pass)
                    : m_graph_id(MLIRSubgraph::get_new_graph_id())
                    , m_pass(*pass)

                /// Factory method to creates a new sub-graph with unique ID
                static MLIRSubgraph create(MLIRSubgraphExtractionPass* pass)
                    // mutex on global graph ID
                    std::lock_guard<std::mutex> lock(pass->m_subgraph_mutex);
                    return MLIRSubgraph(pass);
                /// Get sub-graph id
                int get_id() const { return m_graph_id; }
                /// Get all nodes in the sub-graph.
                NodeSet& get_nodes() { return m_nodes; }
                /// Get input nodes. Predecessors to head nodes.
                NodeSet& get_inputs() { return m_input_nodes; }
                /// Get output nodes. Nodes in the sub-graph with edges to external nodes.
                NodeSet& get_outputs() { return m_output_nodes; }
                /// Add a list of input nodes to the sub-graph.
                template <typename T>
                void add_inputs(T& inputs);
                /// Add a list of output nodes to the sub-graph.
                template <typename T>
                void add_outputs(T& outputs);
                /// Merges sub-graph (other) into this sub-graph. other will be destroyed.
                void merge(MLIRSubgraph& other);
                /// Add one node to the sub-graph.
                void add_node(std::shared_ptr<Node> node);

                // Unique ID for this sub-graph.
                int m_graph_id;
                // Actual nodes of the sub-graph
                NodeSet m_nodes;
                // Predecessor to head nodes in the sub-graph.
                NodeSet m_input_nodes;
                NodeSet m_output_nodes;
                MLIRSubgraphExtractionPass& m_pass;
                static int m_curr_graph_id;
            friend class MLIRSubgraph;

84 85 86
            MLIRSubgraphExtractionPass() {}
            bool run_on_function(std::shared_ptr<Function> func) override;
87 88
            /// Checks if an ngraph node is supported by MLIR backend
            bool is_supported_mlir_op(std::shared_ptr<Node> node);
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
            /// Get the sub-graph ID that a node belongs to
            int get_subgraph_id(std::shared_ptr<Node> node)
                auto it = m_node_to_graph.find(node);
                return (it == m_node_to_graph.end()) ? -1 : it->second;
            /// Get sub-graph by ID
            MLIRSubgraph& get_subgraph(int id)
                auto it = m_id_to_graph.find(id);
                NGRAPH_CHECK(it != m_id_to_graph.end(), "Cannot find subgraph with ID: ", id);
                return it->second;
            /// Stores a sub-graph in the map
            void add_subgraph(MLIRSubgraph& sg) { m_id_to_graph.emplace(sg.get_id(), sg); }
            /// Checks if adding a node to an extracted sub-graph will cause a DAG cycle
            /// inputs: the list of input nodes outside sub-graphs to the node we want to add.
            /// subgraph_ids: the sub-graphs the predecessor nodes belong to.
            /// It traverses backwards from all input nodes and checks if we reach any node that already
            /// belongs to one of the sub-graph ids. If so, we have a cycle.
            /// Example:
            /// A(1)
            /// |   \
            /// B(1) C
            /// |  /
            /// D
            /// we want to add D to sub-graph 1. C is an input to D. sugraph_ids are 1
            /// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a cycle
            bool check_cycles(NodeVector& inputs, std::unordered_set<int>& subgraph_ids);
119 120 121

            static const std::set<std::type_index> m_supported_ops;
122 123 124 125 126 127 128 129

            using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
            using NodeGraphMap = std::unordered_map<std::shared_ptr<Node>, int>;
            IDGraphMap m_id_to_graph;
            NodeGraphMap m_node_to_graph;
            // Mutex over sub-graph IDs
            std::mutex m_subgraph_mutex;
130 131 132