Commit 02affea5 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

faster topological sort (#1886)

parent 51104813
......@@ -42,6 +42,8 @@ namespace ngraph
/// \return the node that this is an input of
std::shared_ptr<Node> get_node() const;
/// \return the raw pointer to the node that this is an input of
Node* get_raw_pointer_node() const { return m_node; }
/// \return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
/// \return the connected output
......
......@@ -75,13 +75,13 @@ namespace ngraph
{
for (auto cd : node->get_control_dependencies())
{
control_deps_count++;
control_deps_users[cd.get()].insert(node.get());
}
control_deps_count = node->get_control_dependencies().size();
}
node_map[node.get()] = node;
size_t deps_count = node->get_arguments().size() + control_deps_count;
size_t deps_count = node->get_inputs().size() + control_deps_count;
node_dependency_count[node.get()] = deps_count;
if (deps_count == 0)
{
......@@ -96,9 +96,11 @@ namespace ngraph
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user_sp : independent_node->get_users())
for (auto& output : independent_node->get_outputs())
{
Node* user = user_sp.get();
for (auto& input : output.get_inputs())
{
auto user = input->get_raw_pointer_node();
node_dependency_count[user] -= 1;
size_t count = node_dependency_count[user];
if (count == 0)
......@@ -106,6 +108,7 @@ namespace ngraph
independent_nodes.push_back(user);
}
}
}
if (include_control_deps)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment