Commit 530e8023 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fix cycle detection during sub-graph construction (#3435)

* Fix cycle detection during sub-graph construction

* small refactor

* style-apply
parent e5d606b8
......@@ -45,6 +45,9 @@ using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
// Maximum depth to check for cycles. If exceeded, we conservatively assume a cycle.
#define MAX_CYCLE_DEPTH 100
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
template <typename T>
......@@ -158,7 +161,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
// we have sub-graphs.
// check if adding this node to the sub-graph will create a cycle in the DAG
NGRAPH_DEBUG << "[CK Extract] Extending sub-graph. Check for cycles " << std::endl;
if (!check_cycles(inputs, subgraph_ids))
if (!check_cycles(op, subgraph_ids))
{
NGRAPH_DEBUG << "[CK Extract] Merging subgraphs";
// merge sub-graphs if needed
......@@ -351,26 +354,35 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
return true;
}
bool MLIRSubgraphExtractionPass::check_cycles(NodeVector& inputs,
std::unordered_set<int>& subgraph_ids)
bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs,
unsigned depth)
{
NodeVector work_list;
NGRAPH_DEBUG << "[CK Extract] Inputs size: " << inputs.size() << std::endl;
work_list.insert(work_list.end(), inputs.begin(), inputs.end());
while (!work_list.empty())
// Going too deep, bail out.
if (depth >= MAX_CYCLE_DEPTH)
return true;
// root node is always inside merged sub-graphs.
if (depth != 0)
{
auto node = work_list.back();
work_list.pop_back();
if (subgraph_ids.find(get_subgraph_id(node)) != subgraph_ids.end())
{
// we hit one of the sub-graphs we want to extend. we have a cycle.
NGRAPH_DEBUG << "[CK Extract] Cycle found when trying to add node" << std::endl;
// This node is inside a sub-graph. If we are coming from outside the sub-graphs, then we formed a cycle.
if (!inside_subgraphs)
{
return true;
}
for (auto pred : node->get_arguments())
}
else
{
work_list.push_back(pred);
inside_subgraphs = false;
}
}
for (auto& input : node->get_arguments())
{
if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth))
return true;
}
return false;
}
......
......@@ -104,8 +104,8 @@ namespace ngraph
/// Checks if adding a node to an extracted sub-graph will cause a DAG cycle
/// inputs: the list of input nodes outside sub-graphs to the node we want to add.
/// subgraph_ids: the sub-graphs the predecessor nodes belong to.
/// It traverses backwards from all input nodes and checks if we reach any node that already
/// belongs to one of the sub-graph ids. If so, we have a cycle.
/// It traverses backwards from all input nodes and checks if we left the to-be-merged sub-graphs
/// and entered again. If so, we have a cycle.
///
/// Example:
/// A(1)
......@@ -115,7 +115,10 @@ namespace ngraph
/// D
/// we want to add D to sub-graph 1. C is an input to D. sugraph_ids are 1
/// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a cycle
bool check_cycles(NodeVector& inputs, std::unordered_set<int>& subgraph_ids);
bool check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs = true,
unsigned depth = 0);
private:
static const std::set<std::type_index> m_supported_ops;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment