Commit 410610f4 authored by Bob Kimball's avatar Bob Kimball

update top sort to be linear rather than exponential computation to better handle large graphs

parent 8bc8579c
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <unordered_map>
#include <deque>
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp" #include "node.hpp"
#include "util.hpp" #include "util.hpp"
...@@ -19,31 +22,19 @@ ...@@ -19,31 +22,19 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p) void ngraph::TopologicalSort::process(node_ptr p)
{ {
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) { traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()]; node_depencency_count[node.get()] = node->get_arguments().size();
node_list.push_back(node.get()); if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
}); });
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0) while (independent_nodes.size() > 0)
{ {
auto independent_node = independent_nodes.front(); auto independent_node = independent_nodes.front();
...@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p) ...@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p)
for (auto user : independent_node->users()) for (auto user : independent_node->users())
{ {
promote_node(user); node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
} }
} }
} }
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
}
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list()
{ {
return m_sorted_list; return m_sorted_list;
} }
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
#pragma once #pragma once
#include <list>
#include <map>
#include <memory> #include <memory>
#include <vector> #include <list>
namespace ngraph namespace ngraph
{ {
...@@ -32,11 +30,11 @@ public: ...@@ -32,11 +30,11 @@ public:
TopologicalSort() {} TopologicalSort() {}
void process(node_ptr); void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const; const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
private: private:
void promote_node(Node* n); void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes; std::list<Node*> m_sorted_list;
std::vector<Node*> m_sorted_list;
}; };
...@@ -22,11 +22,12 @@ ...@@ -22,11 +22,12 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp" #include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp" #include "ngraph/visualize.hpp"
#include "util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes) static bool validate_list(const list<Node*>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
...@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes)
{ {
dependencies.push_back(n.get()); dependencies.push_back(n.get());
} }
auto tmp = it + 1; auto tmp = it++;
for (; tmp != nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
...@@ -87,12 +88,18 @@ TEST(topological_sort, basic) ...@@ -87,12 +88,18 @@ TEST(topological_sort, basic)
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz; // Visualize vz;
vz.add(r0); // vz.add(r0);
vz.save_dot("test.png"); // vz.save_dot("test.png");
TopologicalSort ts; TopologicalSort ts;
ts.process(r0); ts.process(r0);
auto sorted_list = ts.get_sorted_list(); auto sorted_list = ts.get_sorted_list();
size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) {
node_count++;
});
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list)); EXPECT_TRUE(validate_list(sorted_list));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment