Commit 02affea5 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

faster topological sort (#1886)

parent 51104813
...@@ -42,6 +42,8 @@ namespace ngraph ...@@ -42,6 +42,8 @@ namespace ngraph
/// \return the node that this is an input of /// \return the node that this is an input of
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
/// \return the raw pointer to the node that this is an input of
Node* get_raw_pointer_node() const { return m_node; }
/// \return the position within all supplied tensors of this input /// \return the position within all supplied tensors of this input
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
/// \return the connected output /// \return the connected output
......
...@@ -75,13 +75,13 @@ namespace ngraph ...@@ -75,13 +75,13 @@ namespace ngraph
{ {
for (auto cd : node->get_control_dependencies()) for (auto cd : node->get_control_dependencies())
{ {
control_deps_count++;
control_deps_users[cd.get()].insert(node.get()); control_deps_users[cd.get()].insert(node.get());
} }
control_deps_count = node->get_control_dependencies().size();
} }
node_map[node.get()] = node; node_map[node.get()] = node;
size_t deps_count = node->get_arguments().size() + control_deps_count; size_t deps_count = node->get_inputs().size() + control_deps_count;
node_dependency_count[node.get()] = deps_count; node_dependency_count[node.get()] = deps_count;
if (deps_count == 0) if (deps_count == 0)
{ {
...@@ -96,14 +96,17 @@ namespace ngraph ...@@ -96,14 +96,17 @@ namespace ngraph
result_list.push_back(node_map[independent_node]); result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front(); independent_nodes.pop_front();
for (auto user_sp : independent_node->get_users()) for (auto& output : independent_node->get_outputs())
{ {
Node* user = user_sp.get(); for (auto& input : output.get_inputs())
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
{ {
independent_nodes.push_back(user); auto user = input->get_raw_pointer_node();
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment