Commit bc968cd0 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

[MLIR] Modify mlir subgraph extraction pass. (#3675)

* [MLIR] Modify mlir subgraph extraction pass.

* Address PR feedback.

* Modify sub-graph construction algorithm.

* Address PR feedback.

* Address PR feedback.

* Change macro to function.
parent 2856e1c1
...@@ -106,17 +106,19 @@ MLIRSubgraphExtractionPass::MLIRSubgraphExtractionPass() ...@@ -106,17 +106,19 @@ MLIRSubgraphExtractionPass::MLIRSubgraphExtractionPass()
} }
// The sub-graph construction algorithm is as follows // The sub-graph construction algorithm is as follows
// For each node, check its predecessors, if // Construct a map of node to number of its input not being processes
// - all predecessors in sub-graphs belong to the same sub-graph (graph ID), then extend the // Put the node with value 0 into a ready list
// sub-graph to include the current node. // Go through the nodes in the ready list until the list is empty:
// Predecessors outside sub-graphs are marked as input to the sub-graph. // - if the last node processed is supported, try to find a supported node and add that node to the
// - predecessors in sub-graphs belong to different sub-graphs, then merge all the sub-graphs into // current sub-graph.
// one, and add current node to it. Predecessors outside sub-graphs are marked as input to the // - if none is available, process an unsupported node.
// - if the last node processed is unsupported, try to find an unsupported node.
// - if none is available, start a new sub-graph, find a supported node and add that node to the new
// sub-graph. // sub-graph.
// - Erase processed node form the ready list, update the value of its successors in the map, and
// add its successor to ready list if value is 0.
// //
// If the node has any external inputs, then it's possible that the input may come from one of the // Sub-graph may contain multiple disjoint clusters.
// predecessor sub-graphs (cycle).
// If a cycle is found, always start a new sub-graph.
// //
// For each sub-graph found build a CompiledKernel(CK) node around it as follows // For each sub-graph found build a CompiledKernel(CK) node around it as follows
// - all inputs edges to the sub-graph are cloned as inputs to CK node as well. // - all inputs edges to the sub-graph are cloned as inputs to CK node as well.
...@@ -136,89 +138,172 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -136,89 +138,172 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
return true; return true;
} }
void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func) static void
process_successors(std::shared_ptr<ngraph::Node> node,
std::unordered_map<std::shared_ptr<ngraph::Node>, size_t>& node_to_size_map,
std::list<std::shared_ptr<ngraph::Node>>& nodes_ready)
{ {
NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs"; for (auto output : node->outputs())
for (auto op : func->get_ordered_ops())
{ {
NodeVector inputs; for (auto input : output.get_target_inputs())
std::unordered_set<int> subgraph_ids;
// unsupported ops, skip
if (!is_supported_mlir_op(op))
{ {
continue; auto user = input.get_node()->shared_from_this();
node_to_size_map[user]--;
if (node_to_size_map[user] == 0)
{
nodes_ready.push_back(user);
}
} }
if (TI(Parameter) == TI(*op) || TI(Result) == TI(*op)) }
}
void MLIRSubgraphExtractionPass::process_supported_op(std::shared_ptr<ngraph::Node> node,
int current_subgraph_id)
{
NodeVector inputs;
for (auto pred : node->get_arguments())
{
int pred_subgraph_id = get_subgraph_id(pred);
if (pred_subgraph_id != current_subgraph_id)
{ {
continue; // predecessor doesn't belong to current sub-graph, it is an
// input
inputs.push_back(pred);
} }
}
// add inputs and op to current sub-graph
MLIRSubgraph& current_subgraph = get_subgraph(current_subgraph_id);
current_subgraph.add_node(node);
current_subgraph.add_inputs(inputs);
NGRAPH_DEBUG << "[CK Extract] Node Processed " << *node;
}
static void erase_node(std::list<std::shared_ptr<ngraph::Node>>::iterator& it,
std::list<std::shared_ptr<ngraph::Node>>& nodes_ready)
{
auto old_it = it;
it++;
nodes_ready.erase(old_it);
}
NGRAPH_DEBUG << "[CK Extract] Processing " << *op; void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
// supported op {
for (auto pred : op->get_arguments()) NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs";
int current_subgraph_id = 0;
std::unordered_map<std::shared_ptr<Node>, size_t> node_to_size_map;
std::list<std::shared_ptr<Node>> nodes_ready;
bool last_op_is_supported = false;
for (auto op : func->get_ops())
{
size_t arg_count = op->get_input_size();
node_to_size_map[op] = arg_count;
if (arg_count == 0)
{
nodes_ready.push_back(op);
}
}
bool change_mode = false;
while (!nodes_ready.empty())
{
for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
{ {
int pred_subgraph_id = get_subgraph_id(pred); auto node = *it;
if (pred_subgraph_id == -1) if (TI(Result) == TI(*node))
{ {
// predecessor doesn't belong to any sub-graph, it is an input erase_node(it, nodes_ready);
inputs.push_back(pred); }
else if (TI(Parameter) == TI(*node))
{
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
}
else if (is_supported_mlir_op(node))
{
if (last_op_is_supported)
{
process_supported_op(node, current_subgraph_id);
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
change_mode = false;
}
else
{
change_mode = true;
it++;
}
} }
else else
{ {
// record sub-graph id of the predecessor if (last_op_is_supported)
subgraph_ids.insert(pred_subgraph_id); {
change_mode = true;
it++;
}
else
{
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
change_mode = false;
}
} }
} }
if (subgraph_ids.size() == 0)
{ if (change_mode)
// we couldn't find any predecessor sub-graphs to extend with this node
// create a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this);
sg.add_inputs(inputs);
sg.add_node(op);
add_subgraph(sg);
NGRAPH_DEBUG << " [CK Extract] Start new sub-graph " << sg.get_id();
}
else
{ {
// we have sub-graphs. if (last_op_is_supported)
// check if adding this node to the sub-graph will create a cycle in the DAG
NGRAPH_DEBUG << " [CK Extract] Extending sub-graph. Check for cycles";
if (!check_cycles(op, subgraph_ids))
{ {
NGRAPH_DEBUG << " [CK Extract] Merging subgraphs "; for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
// merge sub-graphs if needed
std::unordered_set<int>::iterator it = subgraph_ids.begin();
int sg_id = *it;
MLIRSubgraph& first_subgraph = get_subgraph(sg_id);
NGRAPH_CHECK(first_subgraph.get_id() == sg_id);
NGRAPH_DEBUG << " Graph ID: " << sg_id;
while (++it != subgraph_ids.end())
{ {
sg_id = *it; auto node = *it;
NGRAPH_DEBUG << " Graph ID: " << sg_id; if (TI(Result) == TI(*node))
MLIRSubgraph& subgraph = get_subgraph(sg_id); {
NGRAPH_CHECK(subgraph.get_id() == sg_id); erase_node(it, nodes_ready);
first_subgraph.merge(subgraph); }
else if (!is_supported_mlir_op(node))
{
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
change_mode = false;
last_op_is_supported = false;
}
else
{
it++;
}
} }
first_subgraph.add_node(op);
first_subgraph.add_inputs(inputs);
} }
else else
{ {
// we have a cycle, start a new sub-graph // create a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this); MLIRSubgraph sg = MLIRSubgraph::create(this);
NGRAPH_DEBUG << " [CK Extract] Cycle found. Start a new subgraph " NGRAPH_DEBUG << " [CK Extract] Start new sub-graph " << sg.get_id();
<< sg.get_id();
// use all predecessors as graph inputs
NodeVector inputs = op->get_arguments();
sg.add_inputs(inputs);
sg.add_node(op);
add_subgraph(sg); add_subgraph(sg);
current_subgraph_id = sg.get_id();
for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
{
auto node = *it;
if (TI(Result) == TI(*node))
{
erase_node(it, nodes_ready);
}
else if (is_supported_mlir_op(node))
{
process_supported_op(node, current_subgraph_id);
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
change_mode = false;
last_op_is_supported = true;
}
else
{
it++;
}
}
} }
} }
NGRAPH_DEBUG << "[CK Extract] Node Processed " << *op;
} }
NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes"; NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes";
...@@ -456,54 +541,6 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -456,54 +541,6 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
return true; return true;
} }
bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs,
unsigned depth)
{
// Going too deep, bail out.
if (depth >= m_max_cycle_depth)
return true;
// root node is always inside merged sub-graphs.
if (depth != 0)
{
if (subgraph_ids.find(get_subgraph_id(node)) != subgraph_ids.end())
{
// This node is inside a sub-graph. If we are coming from outside the sub-graphs, then
// we formed a cycle.
if (!inside_subgraphs)
{
return true;
}
}
else
{
inside_subgraphs = false;
}
}
NodeVector node_inputs;
auto subgraph_id = get_subgraph_id(node);
// If the node is part of a sub-graph, capture all of sub-graph inputs, else only node's input
if (subgraph_id >= 0)
{
auto subgraph = get_subgraph(subgraph_id);
auto subgraph_inputs = subgraph.get_inputs();
node_inputs.insert(node_inputs.end(), subgraph_inputs.begin(), subgraph_inputs.end());
}
else
{
node_inputs = node->get_arguments();
}
for (auto& input : node_inputs)
{
if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth))
return true;
}
return false;
}
void MLIRSubgraphExtractionPass::clean_up() void MLIRSubgraphExtractionPass::clean_up()
{ {
m_id_to_graph.clear(); m_id_to_graph.clear();
......
...@@ -103,29 +103,10 @@ namespace ngraph ...@@ -103,29 +103,10 @@ namespace ngraph
} }
/// Stores a sub-graph in the map /// Stores a sub-graph in the map
void add_subgraph(MLIRSubgraph& sg) { m_id_to_graph.emplace(sg.get_id(), sg); } void add_subgraph(MLIRSubgraph& sg) { m_id_to_graph.emplace(sg.get_id(), sg); }
/// Checks if adding a node to an extracted sub-graph will cause a DAG cycle
/// inputs: the list of input nodes outside sub-graphs to the node we want to add.
/// subgraph_ids: the sub-graphs the predecessor nodes belong to.
/// It traverses backwards from all input nodes and checks if we left the to-be-merged
/// sub-graphs and entered again. If so, we have a cycle.
///
/// Example:
/// A(1)
/// | \
/// B(1) C
/// | /
/// D
/// we want to add D to sub-graph 1. C is an input to D. sugraph_ids are 1
/// we traverse backwards C->A(1) and find 1, then we cannot add D since we will form a
/// cycle
bool check_cycles(std::shared_ptr<Node> node,
std::unordered_set<int>& subgraph_ids,
bool inside_subgraphs = true,
unsigned depth = 0);
private: private:
void build_subgraphs(std::shared_ptr<Function> func); void build_subgraphs(std::shared_ptr<Function> func);
NodeVector build_ck_nodes(std::shared_ptr<Function> func); NodeVector build_ck_nodes(std::shared_ptr<Function> func);
void process_supported_op(std::shared_ptr<ngraph::Node> node, int current_subgraph_id);
void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes); void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
void clean_up(); void clean_up();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment