Commit ef58667f authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fixes for cpu_fusion.validate_fuse_gru_inputs (#3511)

* WIP

* Fix incorrect CK output adjustment

* Bug fix and enroce sanity check

* Change cycle search depth, and fix sanity check

* cpu_fusion.validate_fuse_gru_inputs passes.

* Fix as_single_output to be able to always create a GOE

* minor fix. style-apply

* Clean up debug msgs

* Switch to backward cycle check

* Enable failing test

* PR fixes

* Address feedback: Add fwd cycle checks. Make cycle checking depth configurable
parent a1f3202c
......@@ -45,9 +45,6 @@ using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
// Maximum depth to check for cycles. If exceeded, we conservatively assume a cycle.
#define MAX_CYCLE_DEPTH 100
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
template <typename T>
......@@ -99,6 +96,15 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
m_pass.m_id_to_graph.erase(sg2.get_id());
}
MLIRSubgraphExtractionPass::MLIRSubgraphExtractionPass()
: m_max_cycle_depth(20)
{
if (char* max_cycle_depth = std::getenv("NGRAPH_MLIR_MAX_CYCLE_DEPTH"))
{
m_max_cycle_depth = std::stoi(max_cycle_depth);
}
}
// The sub-graph construction algorithm is as follows
// For each node, check its predecessors, if
// - all predecessors in sub-graphs belong to the same sub-graph (graph ID), then extend the
......@@ -118,7 +124,19 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
// - CK will internally have lists record graph nodes, and graph output nodes.
bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{
NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs" << std::endl;
build_subgraphs(func);
auto ck_nodes = build_ck_nodes(func);
#ifdef NGRAPH_DEBUG_ENABLE
sanity_check(func, ck_nodes);
#endif
return true;
}
void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
{
NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs";
for (auto op : func->get_ordered_ops())
{
NodeVector inputs;
......@@ -133,7 +151,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
continue;
}
NGRAPH_DEBUG << "[CK Extract] Processing " << *op << std::endl;
NGRAPH_DEBUG << "[CK Extract] Processing " << *op;
// supported op
for (auto pred : op->get_arguments())
{
......@@ -151,30 +169,32 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
}
if (subgraph_ids.size() == 0)
{
NGRAPH_DEBUG << "[CK Extract] Start new sub-graph " << std::endl;
// we couldn't find any predecessor sub-graphs to extend with this node
// create a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this);
sg.add_inputs(inputs);
sg.add_node(op);
add_subgraph(sg);
NGRAPH_DEBUG << " [CK Extract] Start new sub-graph " << sg.get_id();
}
else
{
// we have sub-graphs.
// check if adding this node to the sub-graph will create a cycle in the DAG
NGRAPH_DEBUG << "[CK Extract] Extending sub-graph. Check for cycles " << std::endl;
NGRAPH_DEBUG << " [CK Extract] Extending sub-graph. Check for cycles";
if (!check_cycles(op, subgraph_ids))
{
NGRAPH_DEBUG << "[CK Extract] Merging subgraphs";
NGRAPH_DEBUG << " [CK Extract] Merging subgraphs ";
// merge sub-graphs if needed
std::unordered_set<int>::iterator it = subgraph_ids.begin();
int sg_id = *it;
MLIRSubgraph& first_subgraph = get_subgraph(sg_id);
NGRAPH_CHECK(first_subgraph.get_id() == sg_id);
NGRAPH_DEBUG << " Graph ID: " << sg_id;
while (++it != subgraph_ids.end())
{
sg_id = *it;
NGRAPH_DEBUG << " Graph ID: " << sg_id;
MLIRSubgraph& subgraph = get_subgraph(sg_id);
NGRAPH_CHECK(subgraph.get_id() == sg_id);
first_subgraph.merge(subgraph);
......@@ -187,7 +207,8 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{
// we have a cycle, start a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this);
NGRAPH_DEBUG << "[CK Extract] Cycle found. Start a new subgraph";
NGRAPH_DEBUG << " [CK Extract] Cycle found. Start a new subgraph "
<< sg.get_id();
// use all predecessors as graph inputs
NodeVector inputs = op->get_arguments();
sg.add_inputs(inputs);
......@@ -195,10 +216,10 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
add_subgraph(sg);
}
}
NGRAPH_DEBUG << "[CK Extract] Node Processed " << *op << std::endl;
NGRAPH_DEBUG << "[CK Extract] Node Processed " << *op;
}
NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes" << std::endl;
NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes";
// get output nodes for each sub-graph. Do this before attaching CK nodes since we will
// remove output edges from the sub-graphs.
for (IDGraphMap::iterator it = m_id_to_graph.begin(); it != m_id_to_graph.end(); it++)
......@@ -211,8 +232,12 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
false /* ignore output duplicates */));
sg.add_outputs(outputs);
}
}
NGRAPH_DEBUG << "[CK Extract] Construct CK nodes" << std::endl;
ngraph::NodeVector MLIRSubgraphExtractionPass::build_ck_nodes(std::shared_ptr<Function> func)
{
NodeVector ck_nodes;
NGRAPH_DEBUG << "[CK Extract] Construct CK nodes";
// attach CK node to each sub-graph.
for (auto it : m_id_to_graph)
{
......@@ -228,27 +253,40 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
NodeVector nodes_vector(nodes_list.begin(), nodes_list.end());
auto ck = std::make_shared<CompiledKernel>(nodes_vector, outputs_vector, inputs_vector);
NGRAPH_DEBUG << "[CK Extract] Graph ID = " << sg.get_id() << std::endl;
NGRAPH_DEBUG << "[CK Extract] Graph Nodes: " << std::endl;
ck_nodes.push_back(ck);
NGRAPH_DEBUG << "[CK Extract] Graph ID = " << sg.get_id();
NGRAPH_DEBUG << " [CK Extract] Graph Nodes: ";
for (auto node : nodes)
{
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl;
NGRAPH_DEBUG << " [CK Extract] " << *node;
}
NGRAPH_DEBUG << "[CK Extract] Input Nodes: " << std::endl;
NGRAPH_DEBUG << " [CK Extract] Input Nodes: ";
for (auto node : inputs)
{
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl;
NGRAPH_DEBUG << " [CK Extract] " << *node;
}
NGRAPH_DEBUG << "[CK Extract] Output Nodes: " << std::endl;
NGRAPH_DEBUG << " [CK Extract] Output Nodes: ";
for (auto node : outputs)
{
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl;
NGRAPH_DEBUG << " [CK Extract] " << *node;
;
}
NGRAPH_DEBUG << " [CK Extract] CK Node = " << *ck;
}
// Connect CompiledKernel to output nodes by replacing the output descriptors of the output
// Do this after all CK nodes are constructed since they add new edges in the graph (CK inputs)
for (auto& node : ck_nodes)
{
auto ck = std::static_pointer_cast<CompiledKernel>(node);
auto& outputs_vector = ck->get_kernel_outputs();
auto& node_list = ck->get_node_list();
std::unordered_set<std::shared_ptr<Node>> node_set(node_list.begin(), node_list.end());
// Connect CompiledKernel to output nodes by replacing the output descriptors of the output
// nodes.
for (size_t i = 0, end = outputs_vector.size(); i < end; ++i)
{
auto& output_descs = outputs_vector[i]->get_outputs();
......@@ -260,15 +298,88 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for (descriptor::Input* in_desc : input_descs)
{
if (nodes.find(in_desc->get_node()) == nodes.end())
if (node_set.find(in_desc->get_node()) == node_set.end())
{
in_desc->replace_output(ck, i);
}
}
}
}
for (auto& node : ck_nodes)
{
auto ck = std::static_pointer_cast<CompiledKernel>(node);
if (ck->get_output_size() > 1)
{
for (auto& old_output : ck->outputs())
{
auto inputs = old_output.get_target_inputs();
auto goe_node = old_output.as_single_output_node(false);
auto new_output = goe_node->output(0);
for (auto& input : inputs)
{
input.replace_source_output(new_output);
}
}
}
}
return true;
return ck_nodes;
}
// Do a sanity check on graph invariants
// - no cycles
// - inputs to sub-graph are inputs to CK
// - no outputs out of subgraph for output nodes
void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes)
{
NodeVector cycles;
bool is_bkwd_cycle;
if (check_for_cycles(func.get(), cycles, is_bkwd_cycle))
{
NGRAPH_CHECK(cycles.size() != 0, "Empty cycle ?");
if (is_bkwd_cycle)
{
NGRAPH_DEBUG << "Backward cycle:";
}
for (auto& node : cycles)
{
NGRAPH_DEBUG << node;
}
NGRAPH_UNREACHABLE("Function contains cycle after subgraph constructions");
}
for (auto& node : ck_nodes)
{
auto ck_node = std::static_pointer_cast<CompiledKernel>(node);
auto& node_list = ck_node->get_node_list();
std::unordered_set<std::shared_ptr<Node>> node_set(node_list.begin(), node_list.end());
// CK output nodes shouldn't have any users outside the sub-graph,
// they are all moved to the CK node instead
for (auto& ck_output : ck_node->get_kernel_outputs())
{
for (auto& user : ck_output->get_users())
{
NGRAPH_CHECK(node_set.find(user) != node_set.end(),
"CK output nodes users should be in the sub-graph");
}
}
// Any input to CK must also have at least one user in the sub-graph body
for (auto& arg : ck_node->get_arguments())
{
bool found = false;
for (auto& user : arg->get_users())
{
found = (node_set.find(user) != node_set.end());
if (found)
{
break;
}
}
NGRAPH_CHECK(found, "CK input is not input to sub-graph");
}
}
}
#define TI(x) std::type_index(typeid(x))
......@@ -378,7 +489,7 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
unsigned depth)
{
// Going too deep, bail out.
if (depth >= MAX_CYCLE_DEPTH)
if (depth >= m_max_cycle_depth)
return true;
// root node is always inside merged sub-graphs.
......@@ -398,7 +509,21 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
inside_subgraphs = false;
}
}
for (auto& input : node->get_arguments())
NodeVector node_inputs;
auto subgraph_id = get_subgraph_id(node);
// If the node is part of a sub-graph, capture all of sub-graph inputs, else only node's input
if (subgraph_id >= 0)
{
auto subgraph = get_subgraph(subgraph_id);
auto subgraph_inputs = subgraph.get_inputs();
node_inputs.insert(node_inputs.end(), subgraph_inputs.begin(), subgraph_inputs.end());
}
else
{
node_inputs = node->get_arguments();
}
for (auto& input : node_inputs)
{
if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth))
return true;
......
......@@ -83,7 +83,7 @@ namespace ngraph
friend class MLIRSubgraph;
public:
MLIRSubgraphExtractionPass() {}
MLIRSubgraphExtractionPass();
bool run_on_function(std::shared_ptr<Function> func) override;
/// Checks if an ngraph node is supported by MLIR backend
bool is_supported_mlir_op(std::shared_ptr<Node> node);
......@@ -123,7 +123,10 @@ namespace ngraph
unsigned depth = 0);
private:
static const std::set<std::type_index> m_supported_ops;
void build_subgraphs(std::shared_ptr<Function> func);
NodeVector build_ck_nodes(std::shared_ptr<Function> func);
void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
private:
using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
......@@ -132,6 +135,10 @@ namespace ngraph
NodeGraphMap m_node_to_graph;
// Mutex over sub-graph IDs
std::mutex m_subgraph_mutex;
static const std::set<std::type_index> m_supported_ops;
// Maximum depth to check for cycles during merging of sub-graphs.
// If exceeded, we conservatively assume a cycle.
int m_max_cycle_depth;
};
}
}
......@@ -653,3 +653,93 @@ std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst)
return result;
}
static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& input : node->inputs())
{
auto arg = input.get_source_output().get_node_shared_ptr();
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_bkwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
static bool check_for_cycles_fwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& arg : node->get_users())
{
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_fwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
bool ngraph::check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle)
{
for (auto res : func->get_results())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
{
is_bkwd_cycle = true;
return true;
};
}
for (auto param : func->get_parameters())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_fwd(param, path, path_set, cycle_nodes))
{
is_bkwd_cycle = false;
return true;
};
}
// no cycles
return false;
}
......@@ -299,7 +299,7 @@ namespace ngraph
for (size_t i = 0; i < arg_count; ++i)
{
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
{
can_add = false;
nodes_to_do.push(dep);
......@@ -425,4 +425,11 @@ namespace ngraph
/// \return A vector containing a handle for each output of src that is connected to an input
/// of `dst`.
std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst);
/// Checks the func for graph cycles starting from results going backwards, then from parameters
/// going forward.
/// It returns true if a cycle is found and the first cycle encountered.
bool check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle);
}
......@@ -87,7 +87,7 @@ std::shared_ptr<Node> Node::copy_with_new_inputs(const OutputVector& inputs) con
return copy_with_new_inputs(inputs, get_control_dependencies());
}
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i, bool for_get_output_element)
{
for (auto in : output(i).get_target_inputs())
{
......@@ -96,7 +96,7 @@ std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
return in.get_node()->shared_from_this();
}
}
return get_output_element(output(i), true);
return get_output_element(output(i), for_get_output_element);
}
std::shared_ptr<Node>
......
......@@ -281,7 +281,8 @@ namespace ngraph
/// Returns the partial shape for output i
const PartialShape& get_output_partial_shape(size_t i) const;
std::shared_ptr<Node> get_output_as_single_output_node(size_t i);
std::shared_ptr<Node> get_output_as_single_output_node(size_t i,
bool for_get_output_element = true);
/// Checks that there is exactly one output and returns its shape
// TODO: deprecate in favor of node->output(0).get_shape() with a suitable check in the
......@@ -565,9 +566,10 @@ namespace ngraph
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node() const NGRAPH_DEPRECATED("Transitional.")
std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const
NGRAPH_DEPRECATED("Transitional.")
{
return m_node->get_output_as_single_output_node(m_index);
return m_node->get_output_as_single_output_node(m_index, for_get_output_element);
}
/// \return The index of the output referred to by this output handle.
......
......@@ -3888,7 +3888,7 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
}
}
TEST(cpu_fusion, MLIR_DISABLE_TEST(validate_fuse_gru_inputs))
TEST(cpu_fusion, validate_fuse_gru_inputs)
{
const std::string file_name("mxnet/gru_debug.json");
auto cpu_func = make_function_from_file(file_name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment