Commit ba486a1b authored by Robert Kimball's avatar Robert Kimball

unreferenced args unit test pass

parent 452a9a1a
...@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result, ...@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter->assign_function(this, i++); parameter->assign_function(this, i++);
} }
traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); }); traverse_nodes(this, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
} }
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops) void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
......
...@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func) ...@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
shared_ptr<ngraph::Function> f = stack.front(); shared_ptr<ngraph::Function> f = stack.front();
stack.pop_front(); stack.pop_front();
functions.insert(f); functions.insert(f);
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node); shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc) if (fc)
{ {
......
...@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> ...@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function>
unordered_map<const Node*, size_t> node_depencency_count; unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map; unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(func, [&](shared_ptr<Node> node) {
node_map[node.get()] = node; node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size(); node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0) if (node->get_arguments().size() == 0)
......
...@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu ...@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
for (shared_ptr<Function> f : functions) for (shared_ptr<Function> f : functions)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg); m_ss << add_attributes(arg);
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <map> #include <map>
#include <unordered_set> #include <unordered_set>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -137,7 +138,41 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -137,7 +138,41 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_front(p->get_result());
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0)
{
shared_ptr<Node> n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments())
{
stack.push_front(arg);
}
}
}
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(shared_ptr<Node>)> f) std::function<void(shared_ptr<Node>)> f)
{ {
std::unordered_set<shared_ptr<Node>> instances_seen; std::unordered_set<shared_ptr<Node>> instances_seen;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
namespace ngraph namespace ngraph
{ {
class Node; class Node;
class Function;
class stopwatch; class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics; extern std::map<std::string, stopwatch*> stopwatch_statistics;
...@@ -195,8 +196,11 @@ namespace ngraph ...@@ -195,8 +196,11 @@ namespace ngraph
return a * b; return a * b;
} }
void traverse_nodes(const std::shared_ptr<Node>& p, void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Node>); void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph } // end namespace ngraph
...@@ -79,9 +79,30 @@ shared_ptr<Function> make_test_graph() ...@@ -79,9 +79,30 @@ shared_ptr<Function> make_test_graph()
return f0; return f0;
} }
size_t get_node_count(std::shared_ptr<Node> n) size_t get_node_count(std::shared_ptr<Node> node)
{ {
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_back(node);
while (stack.size() > 0)
{
shared_ptr<Node> n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(n);
stack.push_back(n);
node_count++;
}
stack.pop_front();
for (auto arg : n->get_arguments())
{
stack.push_front(arg);
}
}
return node_count; return node_count;
} }
...@@ -75,7 +75,8 @@ TEST(topological_sort, basic) ...@@ -75,7 +75,8 @@ TEST(topological_sort, basic)
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops(); auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0); size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node>) { node_count++; });
EXPECT_EQ(node_count, sorted_list.size()); EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list)); EXPECT_TRUE(validate_list(sorted_list));
...@@ -130,12 +131,12 @@ TEST(benchmark, topological_sort) ...@@ -130,12 +131,12 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; }); traverse_nodes(f0, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count; NGRAPH_INFO << "node count " << node_count;
timer.start(); timer.start();
ngraph::free_nodes(result); ngraph::free_nodes(f0->get_result());
timer.stop(); timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment