mlir_subgraph_extraction.hpp 5.71 KB
Newer Older
1
//*****************************************************************************
nmostafa's avatar
nmostafa committed
2
// Copyright 2017-2019 Intel Corporation
3 4 5 6 7 8 9 10 11 12 13 14 15 16
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

17 18 19
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.

20 21
#pragma once

22
#include <mutex>
23 24 25 26 27
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
    namespace pass
    {
28
        /// This pass creates CompiledKernel ops enclosing maximal sub-graphs of ops that are supported by MLIR
29 30
        class MLIRSubgraphExtractionPass : public ngraph::pass::FunctionPass
        {
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
            /// Unordered set of shared Node pointers; MLIRSubgraph keeps its node, input and
            /// output collections in this type.
            using NodeSet = std::unordered_set<std::shared_ptr<Node>>;

            /// A set of nodes grouped under one sub-graph ID, together with the nodes
            /// feeding it and the nodes leaving it.
            ///
            /// The constructor is private; instances are obtained through create().
            class MLIRSubgraph
            {
            private:
                /// Returns the current value of the shared ID counter and advances it.
                /// Performs no locking itself; create() serializes the calls.
                static int get_new_graph_id() { return m_curr_graph_id++; }
                /// Create a sub-graph with a new ID.
                /// \param pass The pass this sub-graph is associated with. It is dereferenced
                ///             and kept by reference, so it must not be null.
                explicit MLIRSubgraph(MLIRSubgraphExtractionPass* pass)
                    : m_graph_id(MLIRSubgraph::get_new_graph_id())
                    , m_pass(*pass)
                {
                }

            public:
                /// Factory method that creates a new sub-graph with a unique ID.
                /// \param pass The pass this sub-graph is associated with; must not be null.
                static MLIRSubgraph create(MLIRSubgraphExtractionPass* pass)
                {
                    // m_curr_graph_id is static, i.e. one counter shared by every pass
                    // instance. A per-pass mutex cannot serialize two different pass
                    // instances allocating IDs at the same time, so the counter is guarded
                    // by a single function-local mutex shared by all callers.
                    static std::mutex s_graph_id_mutex;

                    // The per-pass lock is still taken (always before the shared one, so the
                    // lock order is fixed) to keep create() exclusive with respect to any
                    // other user of that mutex.
                    std::lock_guard<std::mutex> pass_lock(pass->m_subgraph_mutex);
                    std::lock_guard<std::mutex> id_lock(s_graph_id_mutex);
                    return MLIRSubgraph(pass);
                }
                /// Get sub-graph id
                int get_id() const { return m_graph_id; }
                /// Get all nodes in the sub-graph.
                NodeSet& get_nodes() { return m_nodes; }
                /// Get input nodes. Predecessors to head nodes.
                NodeSet& get_inputs() { return m_input_nodes; }
                /// Get output nodes. Nodes in the sub-graph with edges to external nodes.
                NodeSet& get_outputs() { return m_output_nodes; }
                /// Add a list of input nodes to the sub-graph.
                /// Declared here only; the definition is not part of this header.
                template <typename T>
                void add_inputs(T& inputs);
                /// Add a list of output nodes to the sub-graph.
                /// Declared here only; the definition is not part of this header.
                template <typename T>
                void add_outputs(T& outputs);
                /// Merges sub-graph (other) into this sub-graph. other will be destroyed.
                void merge(MLIRSubgraph& other);
                /// Add one node to the sub-graph.
                void add_node(std::shared_ptr<Node> node);

            private:
                // Unique ID for this sub-graph.
                int m_graph_id;
                // Actual nodes of the sub-graph
                NodeSet m_nodes;
                // Predecessor to head nodes in the sub-graph.
                NodeSet m_input_nodes;
                // Nodes in the sub-graph with edges to external nodes.
                NodeSet m_output_nodes;
                // The pass this sub-graph was created for (bound in the constructor).
                MLIRSubgraphExtractionPass& m_pass;
                // Source of new sub-graph IDs, shared by all sub-graphs of all passes.
                // Only declared here; the definition lives outside this header.
                static int m_curr_graph_id;
            };
            // Grants MLIRSubgraph access to this pass's private members
            // (MLIRSubgraph::create locks m_subgraph_mutex).
            friend class MLIRSubgraph;

84 85 86
        public:
            /// Default-constructs the pass; the constructor body is empty.
            MLIRSubgraphExtractionPass() {}
            /// FunctionPass hook, overridden here and defined outside this header.
            /// NOTE(review): the bool presumably reports whether func was modified -
            /// confirm against the FunctionPass contract.
            bool run_on_function(std::shared_ptr<Function> func) override;
87 88
            /// Checks if an ngraph node is supported by the MLIR backend.
            /// Declared here, defined outside this header.
            /// NOTE(review): presumably consults m_supported_ops - confirm in the implementation.
            bool is_supported_mlir_op(std::shared_ptr<Node> node);
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
            /// Look up the sub-graph a node has been assigned to.
            /// \return the ID recorded for the node, or -1 when the node has no entry.
            int get_subgraph_id(std::shared_ptr<Node> node)
            {
                const auto entry = m_node_to_graph.find(node);
                if (entry != m_node_to_graph.end())
                {
                    return entry->second;
                }
                // Node is not part of any sub-graph.
                return -1;
            }
            /// Get sub-graph by ID.
            /// \param id ID to look up in the ID -> sub-graph map.
            /// \return reference to the stored sub-graph; it refers to the map's own element,
            ///         not a copy.
            /// An unknown ID is reported through NGRAPH_CHECK (macro defined elsewhere;
            /// presumably it raises a check failure - confirm).
            MLIRSubgraph& get_subgraph(int id)
            {
                auto it = m_id_to_graph.find(id);
                // Must run before the dereference below: it would be end() for an unknown ID.
                NGRAPH_CHECK(it != m_id_to_graph.end(), "Cannot find subgraph with ID: ", id);
                return it->second;
            }
            /// Registers a copy of the sub-graph in the ID map, keyed by its own ID.
            /// If the ID is already present, emplace leaves the existing entry untouched.
            void add_subgraph(MLIRSubgraph& sg)
            {
                const int graph_id = sg.get_id();
                m_id_to_graph.emplace(graph_id, sg);
            }
            /// Checks if adding a node to an extracted sub-graph will cause a DAG cycle.
            /// inputs: the list of input nodes outside sub-graphs to the node we want to add.
            /// subgraph_ids: the sub-graphs the predecessor nodes belong to.
            /// It traverses backwards from all input nodes and checks if we reach any node that already
            /// belongs to one of the sub-graph ids. If so, we have a cycle.
            /// NOTE(review): presumably returns true when a cycle would be formed - confirm
            /// against the implementation.
            ///
            /// Example (no diagram line ends in a backslash, which would splice the
            /// following line into the comment):
            ///      A(1)
            ///     /   |
            ///    C   B(1)
            ///     \   |
            ///       D
            /// we want to add D to sub-graph 1. C is an input to D. subgraph_ids are 1
            /// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a cycle
            bool check_cycles(NodeVector& inputs, std::unordered_set<int>& subgraph_ids);
119 120 121

        private:
            /// Set of std::type_index values; only declared here, defined outside this header.
            /// NOTE(review): presumably the op types accepted by is_supported_mlir_op - confirm.
            static const std::set<std::type_index> m_supported_ops;
122 123 124 125 126 127 128 129

        private:
            /// Maps a sub-graph ID to the stored MLIRSubgraph (see add_subgraph / get_subgraph).
            using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
            /// Maps a node to the ID of the sub-graph it belongs to (see get_subgraph_id).
            using NodeGraphMap = std::unordered_map<std::shared_ptr<Node>, int>;
            // Sub-graphs registered with this pass, keyed by ID.
            IDGraphMap m_id_to_graph;
            // Sub-graph membership of individual nodes.
            NodeGraphMap m_node_to_graph;
            // Mutex over sub-graph IDs; locked by MLIRSubgraph::create while a new ID is taken.
            std::mutex m_subgraph_mutex;
130 131 132
        };
    }
}