Commit fc0455ba authored by Scott Cyphers's avatar Scott Cyphers

Merge branch 'master' into cyphers/names

parents 235c8ea0 fac27c37
...@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::sha ...@@ -25,7 +25,7 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::sha
// Add this node as a user of each argument. // Add this node as a user of each argument.
for (auto node : m_arguments) for (auto node : m_arguments)
{ {
node->m_users.insert(node.get()); node->m_users.insert(this);
} }
} }
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{ {
protected: protected:
...@@ -68,10 +68,10 @@ namespace ngraph ...@@ -68,10 +68,10 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
}; };
} }
...@@ -12,8 +12,53 @@ ...@@ -12,8 +12,53 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "util.hpp"
void ngraph::TopologicalSort::process(node_ptr node) using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p)
{
traverse_nodes(p, [&](node_ptr node)
{
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get());
});
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
m_sorted_list.push_back(independent_node);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
promote_node(user);
}
}
}
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{ {
return m_sorted_list;
} }
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <map>
#include <list>
namespace ngraph namespace ngraph
{ {
...@@ -26,9 +28,14 @@ namespace ngraph ...@@ -26,9 +28,14 @@ namespace ngraph
class ngraph::TopologicalSort class ngraph::TopologicalSort
{ {
public: public:
TopologicalSort(); TopologicalSort() {}
static void process(node_ptr); void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
private: private:
void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes;
std::vector<Node*> m_sorted_list;
}; };
...@@ -26,28 +26,73 @@ ...@@ -26,28 +26,73 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(top_sort, basic) static bool validate_list(const vector<Node*>& nodes)
{ {
auto arg0 = op::parameter(element::Float::element_type(), {1}); bool rc = true;
ASSERT_NE(nullptr, arg0); for (auto it=nodes.rbegin(); it!=nodes.rend(); it++)
auto arg1 = op::parameter(element::Float::element_type(), {1}); {
ASSERT_NE(nullptr, arg1); auto node_tmp = *it;
auto t0 = op::add(arg0, arg1); auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
dependencies.push_back(n.get());
}
auto tmp = it+1;
for (; tmp!=nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end())
{
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
rc = false;
}
}
return rc;
}
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++)
{
auto arg = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
auto t0 = op::add(args[0], args[1]);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1); auto t1 = op::dot(t0, args[2]);
ASSERT_NE(nullptr, t1); ASSERT_NE(nullptr, t1);
auto r0 = op::add(t0, t1); auto t2 = op::multiply(t0, args[3]);
ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]);
ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]);
ASSERT_NE(nullptr, t3);
auto r0 = op::add(t3, t4);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1}); auto f0 = op::function(r0, args);
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
Visualize vz; Visualize vz;
vz.add(r0); vz.add(r0);
vz.save_dot("test.png"); vz.save_dot("test.png");
TopologicalSort::process(r0); TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
EXPECT_TRUE(validate_list(sorted_list));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment