Commit ecf7a396 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Fix non-determ. order of nodes in MLIR sub-graphs (#3614)

Replace set with vector container in MLIR sub-graphs for
nodes, inputs and outputs. Using a set caused different order of
operations in MLIR code across different executions and this led to
differences in optimizations.
parent fd67a574
......@@ -48,29 +48,28 @@ using namespace ngraph::pass;
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
template <typename T>
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_inputs(T& inputs)
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_inputs(NodeVector& inputs)
{
// inputs list are not exclusive, avoid duplication
for (auto node : inputs)
{
if (m_input_nodes.find(node) == m_input_nodes.end())
if (m_input_node_set.insert(node).second)
{
m_input_nodes.insert(node);
m_input_node_vector.push_back(node);
}
}
}
template <typename T>
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_outputs(T& outputs)
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_outputs(NodeVector& outputs)
{
m_output_nodes.insert(outputs.begin(), outputs.end());
m_output_nodes.insert(m_output_nodes.end(), outputs.begin(), outputs.end());
}
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_node(std::shared_ptr<Node> node)
{
NGRAPH_CHECK(m_nodes.find(node) == m_nodes.end(), "node added to graph before");
m_nodes.insert(node);
NGRAPH_CHECK(m_pass.m_node_to_graph.find(node) == m_pass.m_node_to_graph.end(),
"node added to graph before");
m_nodes.emplace_back(node);
m_pass.m_node_to_graph[node] = get_id();
}
......@@ -89,7 +88,7 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
}
// nodes of sub-graphs are exclusive
m_nodes.insert(sg2.get_nodes().begin(), sg2.get_nodes().end());
m_nodes.insert(m_nodes.end(), sg2.get_nodes().begin(), sg2.get_nodes().end());
// merge inputs
add_inputs(sg2.get_inputs());
......
......@@ -53,17 +53,15 @@ namespace ngraph
/// Get sub-graph id
int get_id() const { return m_graph_id; }
/// Get all nodes in the sub-graph.
NodeSet& get_nodes() { return m_nodes; }
NodeVector& get_nodes() { return m_nodes; }
/// Get input nodes. Predecessors to head nodes.
NodeSet& get_inputs() { return m_input_nodes; }
NodeVector& get_inputs() { return m_input_node_vector; }
/// Get output nodes. Nodes in the sub-graph with edges to external nodes.
NodeSet& get_outputs() { return m_output_nodes; }
NodeVector& get_outputs() { return m_output_nodes; }
/// Add a list of input nodes to the sub-graph.
template <typename T>
void add_inputs(T& inputs);
void add_inputs(NodeVector& inputs);
/// Add a list of output nodes to the sub-graph.
template <typename T>
void add_outputs(T& outputs);
void add_outputs(NodeVector& outputs);
/// Merges sub-graph (other) into this sub-graph. other will be destroyed.
void merge(MLIRSubgraph& other);
/// Add one node to the sub-graph.
......@@ -73,10 +71,13 @@ namespace ngraph
// Unique ID for this sub-graph.
int m_graph_id;
// Actual nodes of the sub-graph
NodeSet m_nodes;
// Predecessor to head nodes in the sub-graph.
NodeSet m_input_nodes;
NodeSet m_output_nodes;
NodeVector m_nodes;
// Predecessor to head nodes in the sub-graph. Both containers have the same
// elements. Set is only used for efficient look-up operations.
NodeVector m_input_node_vector;
NodeSet m_input_node_set;
NodeVector m_output_nodes;
MLIRSubgraphExtractionPass& m_pass;
static int m_curr_graph_id;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment