Commit 4528f86d authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #96 from NervanaSystems/bob/tsort_opt

update top sort to be linear rather than exponential complexit to b…
parents 86128587 410610f4
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <unordered_map>
#include <deque>
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
......@@ -19,31 +22,19 @@
using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p)
{
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get());
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
......@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p)
for (auto user : independent_node->users())
{
promote_node(user);
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
}
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const
const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
}
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list()
{
return m_sorted_list;
}
......@@ -14,10 +14,8 @@
#pragma once
#include <list>
#include <map>
#include <memory>
#include <vector>
#include <list>
namespace ngraph
{
......@@ -32,11 +30,11 @@ public:
TopologicalSort() {}
void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
private:
void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes;
std::vector<Node*> m_sorted_list;
std::list<Node*> m_sorted_list;
};
......@@ -22,11 +22,12 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp"
#include "util.hpp"
using namespace std;
using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes)
static bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
......@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes)
{
dependencies.push_back(n.get());
}
auto tmp = it + 1;
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
......@@ -87,12 +88,18 @@ TEST(topological_sort, basic)
ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz;
vz.add(r0);
vz.save_dot("test.png");
// Visualize vz;
// vz.add(r0);
// vz.save_dot("test.png");
TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) {
node_count++;
});
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment