Commit ef58667f authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fixes for cpu_fusion.validate_fuse_gru_inputs (#3511)

* WIP

* Fix incorrect CK output adjustment

* Bug fix and enroce sanity check

* Change cycle search depth, and fix sanity check

* cpu_fusion.validate_fuse_gru_inputs passes.

* Fix as_single_output to be able to always create a GOE

* minor fix. style-apply

* Clean up debug msgs

* Switch to backward cycle check

* Enable failing test

* PR fixes

* Address feedback: Add fwd cycle checks. Make cycle checking depth configurable
parent a1f3202c
......@@ -83,7 +83,7 @@ namespace ngraph
friend class MLIRSubgraph;
public:
MLIRSubgraphExtractionPass() {}
MLIRSubgraphExtractionPass();
bool run_on_function(std::shared_ptr<Function> func) override;
/// Checks if an ngraph node is supported by MLIR backend
bool is_supported_mlir_op(std::shared_ptr<Node> node);
......@@ -123,7 +123,10 @@ namespace ngraph
unsigned depth = 0);
private:
static const std::set<std::type_index> m_supported_ops;
void build_subgraphs(std::shared_ptr<Function> func);
NodeVector build_ck_nodes(std::shared_ptr<Function> func);
void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
private:
using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
......@@ -132,6 +135,10 @@ namespace ngraph
NodeGraphMap m_node_to_graph;
// Mutex over sub-graph IDs
std::mutex m_subgraph_mutex;
static const std::set<std::type_index> m_supported_ops;
// Maximum depth to check for cycles during merging of sub-graphs.
// If exceeded, we conservatively assume a cycle.
int m_max_cycle_depth;
};
}
}
......@@ -653,3 +653,93 @@ std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst)
return result;
}
static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& input : node->inputs())
{
auto arg = input.get_source_output().get_node_shared_ptr();
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_bkwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
static bool check_for_cycles_fwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& arg : node->get_users())
{
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_fwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
bool ngraph::check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle)
{
for (auto res : func->get_results())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
{
is_bkwd_cycle = true;
return true;
};
}
for (auto param : func->get_parameters())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_fwd(param, path, path_set, cycle_nodes))
{
is_bkwd_cycle = false;
return true;
};
}
// no cycles
return false;
}
......@@ -299,7 +299,7 @@ namespace ngraph
for (size_t i = 0; i < arg_count; ++i)
{
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0)
if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
{
can_add = false;
nodes_to_do.push(dep);
......@@ -425,4 +425,11 @@ namespace ngraph
/// \return A vector containing a handle for each output of src that is connected to an input
/// of `dst`.
std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst);
/// Checks the func for graph cycles starting from results going backwards, then from parameters
/// going forward.
/// It returns true if a cycle is found and the first cycle encountered.
bool check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle);
}
......@@ -87,7 +87,7 @@ std::shared_ptr<Node> Node::copy_with_new_inputs(const OutputVector& inputs) con
return copy_with_new_inputs(inputs, get_control_dependencies());
}
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i, bool for_get_output_element)
{
for (auto in : output(i).get_target_inputs())
{
......@@ -96,7 +96,7 @@ std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
return in.get_node()->shared_from_this();
}
}
return get_output_element(output(i), true);
return get_output_element(output(i), for_get_output_element);
}
std::shared_ptr<Node>
......
......@@ -281,7 +281,8 @@ namespace ngraph
/// Returns the partial shape for output i
const PartialShape& get_output_partial_shape(size_t i) const;
std::shared_ptr<Node> get_output_as_single_output_node(size_t i);
std::shared_ptr<Node> get_output_as_single_output_node(size_t i,
bool for_get_output_element = true);
/// Checks that there is exactly one output and returns its shape
// TODO: deprecate in favor of node->output(0).get_shape() with a suitable check in the
......@@ -565,9 +566,10 @@ namespace ngraph
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node() const NGRAPH_DEPRECATED("Transitional.")
std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const
NGRAPH_DEPRECATED("Transitional.")
{
return m_node->get_output_as_single_output_node(m_index);
return m_node->get_output_as_single_output_node(m_index, for_get_output_element);
}
/// \return The index of the output referred to by this output handle.
......
......@@ -3888,7 +3888,7 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
}
}
TEST(cpu_fusion, MLIR_DISABLE_TEST(validate_fuse_gru_inputs))
TEST(cpu_fusion, validate_fuse_gru_inputs)
{
const std::string file_name("mxnet/gru_debug.json");
auto cpu_func = make_function_from_file(file_name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment