Commit a7c1841e authored by Robert Kimball's avatar Robert Kimball

topological sort working with unit test

parent a89c33b4
......@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
// Add this node as a user of each argument.
for (auto node : m_arguments)
{
node->m_users.insert(node.get());
node->m_users.insert(this);
}
}
......
......@@ -32,7 +32,7 @@ namespace ngraph
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{
public:
using ptr = std::shared_ptr<Node>;
......@@ -74,11 +74,11 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
};
using node_ptr = std::shared_ptr<Node>;
......
......@@ -12,8 +12,53 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "util.hpp"
void ngraph::TopologicalSort::process(node_ptr node)
using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p)
{
traverse_nodes(p, [&](node_ptr node)
{
list<Node*>& node_list = m_dependent_nodes[node->arguments().size()];
node_list.push_back(node.get());
});
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
m_sorted_list.push_back(independent_node);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
promote_node(user);
}
}
}
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
}
......@@ -15,6 +15,8 @@
#pragma once
#include <memory>
#include <map>
#include <list>
namespace ngraph
{
......@@ -26,9 +28,14 @@ namespace ngraph
class ngraph::TopologicalSort
{
public:
TopologicalSort();
TopologicalSort() {}
static void process(node_ptr);
void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
private:
void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes;
std::vector<Node*> m_sorted_list;
};
......@@ -40,7 +40,7 @@ namespace ngraph
virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
};
/**
......@@ -140,6 +140,7 @@ namespace ngraph
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected:
ValueType::ptr m_type;
};
......
......@@ -26,28 +26,74 @@
using namespace std;
using namespace ngraph;
TEST(top_sort, basic)
static bool validate_list(const vector<Node*>& nodes)
{
auto arg0 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg0);
auto arg1 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg1);
auto t0 = op::add(arg0, arg1);
bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++)
{
Node* node = *it;
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it+1;
for (; tmp!=nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++)
{
auto arg = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
auto t0 = op::add(args[0], args[1]);
ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1);
auto t1 = op::dot(t0, args[2]);
ASSERT_NE(nullptr, t1);
Node::ptr r0 = op::add(t0, t1);
auto t2 = op::multiply(t0, args[3]);
ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]);
ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]);
ASSERT_NE(nullptr, t3);
Node::ptr r0 = op::add(t3, t4);
ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1});
auto f0 = op::function(r0, args);
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
Visualize vz;
vz.add(r0);
vz.save_dot("test.png");
TopologicalSort::process(r0);
TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
EXPECT_TRUE(validate_list(sorted_list));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment