Commit a7c1841e authored by Robert Kimball's avatar Robert Kimball

topological sort working with unit test

parent a89c33b4
...@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type) ...@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
// Add this node as a user of each argument. // Add this node as a user of each argument.
for (auto node : m_arguments) for (auto node : m_arguments)
{ {
node->m_users.insert(node.get()); node->m_users.insert(this);
} }
} }
......
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
** zero or more nodes as arguments and one value, which is either a tensor ** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values. ** view or a (possibly empty) tuple of values.
**/ **/
class Node : public TypedValueMixin class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{ {
public: public:
using ptr = std::shared_ptr<Node>; using ptr = std::shared_ptr<Node>;
...@@ -74,11 +74,11 @@ namespace ngraph ...@@ -74,11 +74,11 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
}; };
using node_ptr = std::shared_ptr<Node>; using node_ptr = std::shared_ptr<Node>;
......
...@@ -12,8 +12,53 @@ ...@@ -12,8 +12,53 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "util.hpp"
void ngraph::TopologicalSort::process(node_ptr node) using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p)
{
traverse_nodes(p, [&](node_ptr node)
{
list<Node*>& node_list = m_dependent_nodes[node->arguments().size()];
node_list.push_back(node.get());
});
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
m_sorted_list.push_back(independent_node);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
promote_node(user);
}
}
}
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{ {
return m_sorted_list;
} }
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <map>
#include <list>
namespace ngraph namespace ngraph
{ {
...@@ -26,9 +28,14 @@ namespace ngraph ...@@ -26,9 +28,14 @@ namespace ngraph
class ngraph::TopologicalSort class ngraph::TopologicalSort
{ {
public: public:
TopologicalSort(); TopologicalSort() {}
static void process(node_ptr); void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
private: private:
void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes;
std::vector<Node*> m_sorted_list;
}; };
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0; virtual bool operator==(const ValueType::ptr& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); } bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
}; };
/** /**
...@@ -140,6 +140,7 @@ namespace ngraph ...@@ -140,6 +140,7 @@ namespace ngraph
** The type associated with this value. ** The type associated with this value.
**/ **/
const ValueType::ptr type() const { return m_type; } const ValueType::ptr type() const { return m_type; }
protected: protected:
ValueType::ptr m_type; ValueType::ptr m_type;
}; };
......
...@@ -26,28 +26,74 @@ ...@@ -26,28 +26,74 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(top_sort, basic) static bool validate_list(const vector<Node*>& nodes)
{ {
auto arg0 = op::parameter(element::Float::type, {1}); bool rc = true;
ASSERT_NE(nullptr, arg0); for (auto it=nodes.rbegin(); it!=nodes.rend(); it++)
auto arg1 = op::parameter(element::Float::type, {1}); {
ASSERT_NE(nullptr, arg1); Node* node = *it;
auto t0 = op::add(arg0, arg1); auto node_tmp = *it;
auto dependencies_tmp = node_tmp->arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it+1;
for (; tmp!=nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++)
{
auto arg = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
auto t0 = op::add(args[0], args[1]);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1); auto t1 = op::dot(t0, args[2]);
ASSERT_NE(nullptr, t1); ASSERT_NE(nullptr, t1);
Node::ptr r0 = op::add(t0, t1); auto t2 = op::multiply(t0, args[3]);
ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]);
ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]);
ASSERT_NE(nullptr, t3);
Node::ptr r0 = op::add(t3, t4);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1}); auto f0 = op::function(r0, args);
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size()); ASSERT_EQ(2, r0->arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
Visualize vz; Visualize vz;
vz.add(r0); vz.add(r0);
vz.save_dot("test.png"); vz.save_dot("test.png");
TopologicalSort::process(r0); TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
EXPECT_TRUE(validate_list(sorted_list));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment