Commit 530e8023 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fix cycle detection during sub-graph construction (#3435)

* Fix cycle detection during sub-graph construction

* small refactor

* style-apply
parent e5d606b8
...@@ -45,6 +45,9 @@ using namespace ngraph::pass; ...@@ -45,6 +45,9 @@ using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
// Maximum depth to check for cycles. If exceeded, we conservatively assume a cycle.
#define MAX_CYCLE_DEPTH 100
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0; int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
template <typename T> template <typename T>
...@@ -158,7 +161,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -158,7 +161,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
// we have sub-graphs. // we have sub-graphs.
// check if adding this node to the sub-graph will create a cycle in the DAG // check if adding this node to the sub-graph will create a cycle in the DAG
NGRAPH_DEBUG << "[CK Extract] Extending sub-graph. Check for cycles " << std::endl; NGRAPH_DEBUG << "[CK Extract] Extending sub-graph. Check for cycles " << std::endl;
if (!check_cycles(inputs, subgraph_ids)) if (!check_cycles(op, subgraph_ids))
{ {
NGRAPH_DEBUG << "[CK Extract] Merging subgraphs"; NGRAPH_DEBUG << "[CK Extract] Merging subgraphs";
// merge sub-graphs if needed // merge sub-graphs if needed
...@@ -351,27 +354,36 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -351,27 +354,36 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
return true; return true;
} }
bool MLIRSubgraphExtractionPass::check_cycles(NodeVector& inputs, bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids) std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs,
unsigned depth)
{ {
NodeVector work_list; // Going too deep, bail out.
NGRAPH_DEBUG << "[CK Extract] Inputs size: " << inputs.size() << std::endl; if (depth >= MAX_CYCLE_DEPTH)
work_list.insert(work_list.end(), inputs.begin(), inputs.end()); return true;
while (!work_list.empty())
// root node is always inside merged sub-graphs.
if (depth != 0)
{ {
auto node = work_list.back();
work_list.pop_back();
if (subgraph_ids.find(get_subgraph_id(node)) != subgraph_ids.end()) if (subgraph_ids.find(get_subgraph_id(node)) != subgraph_ids.end())
{ {
// we hit one of the sub-graphs we want to extend. we have a cycle. // This node is inside a sub-graph. If we are coming from outside the sub-graphs, then we formed a cycle.
NGRAPH_DEBUG << "[CK Extract] Cycle found when trying to add node" << std::endl; if (!inside_subgraphs)
return true; {
return true;
}
} }
for (auto pred : node->get_arguments()) else
{ {
work_list.push_back(pred); inside_subgraphs = false;
} }
} }
for (auto& input : node->get_arguments())
{
if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth))
return true;
}
return false; return false;
} }
......
...@@ -104,8 +104,8 @@ namespace ngraph ...@@ -104,8 +104,8 @@ namespace ngraph
/// Checks if adding a node to an extracted sub-graph will cause a DAG cycle /// Checks if adding a node to an extracted sub-graph will cause a DAG cycle
/// inputs: the list of input nodes outside sub-graphs to the node we want to add. /// inputs: the list of input nodes outside sub-graphs to the node we want to add.
/// subgraph_ids: the sub-graphs the predecessor nodes belong to. /// subgraph_ids: the sub-graphs the predecessor nodes belong to.
/// It traverses backwards from all input nodes and checks if we reach any node that already /// It traverses backwards from all input nodes and checks if we left the to-be-merged sub-graphs
/// belongs to one of the sub-graph ids. If so, we have a cycle. /// and entered again. If so, we have a cycle.
/// ///
/// Example: /// Example:
/// A(1) /// A(1)
...@@ -115,7 +115,10 @@ namespace ngraph ...@@ -115,7 +115,10 @@ namespace ngraph
/// D /// D
/// we want to add D to sub-graph 1. C is an input to D. sugraph_ids are 1 /// we want to add D to sub-graph 1. C is an input to D. sugraph_ids are 1
/// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a cycle /// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a cycle
bool check_cycles(NodeVector& inputs, std::unordered_set<int>& subgraph_ids); bool check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs = true,
unsigned depth = 0);
private: private:
static const std::set<std::type_index> m_supported_ops; static const std::set<std::type_index> m_supported_ops;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment