Commit ba486a1b authored by Robert Kimball's avatar Robert Kimball

unreferenced args unit test pass

parent 452a9a1a
......@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter->assign_function(this, i++);
}
traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
traverse_nodes(this, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
}
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
......
......@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
shared_ptr<ngraph::Function> f = stack.front();
stack.pop_front();
functions.insert(f);
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(f, [&](shared_ptr<Node> node) {
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc)
{
......
......@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function>
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(func, [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
......
......@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
for (shared_ptr<Function> f : functions)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(f, [&](shared_ptr<Node> node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg);
......
......@@ -18,6 +18,7 @@
#include <map>
#include <unordered_set>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
......@@ -137,7 +138,41 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed;
}
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_front(p->get_result());
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0)
{
shared_ptr<Node> n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(n);
f(n);
}
stack.pop_front();
for (auto arg : n->get_arguments())
{
stack.push_front(arg);
}
}
}
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<shared_ptr<Node>> instances_seen;
......
......@@ -26,6 +26,7 @@
namespace ngraph
{
class Node;
class Function;
class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics;
......@@ -195,8 +196,11 @@ namespace ngraph
return a * b;
}
void traverse_nodes(const std::shared_ptr<Node>& p,
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph
......@@ -79,9 +79,30 @@ shared_ptr<Function> make_test_graph()
return f0;
}
size_t get_node_count(std::shared_ptr<Node> n)
size_t get_node_count(std::shared_ptr<Node> node)
{
size_t node_count = 0;
traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_back(node);
while (stack.size() > 0)
{
shared_ptr<Node> n = stack.front();
if (instances_seen.find(n) == instances_seen.end())
{
instances_seen.insert(n);
stack.push_back(n);
node_count++;
}
stack.pop_front();
for (auto arg : n->get_arguments())
{
stack.push_front(arg);
}
}
return node_count;
}
......@@ -75,7 +75,8 @@ TEST(topological_sort, basic)
pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0);
size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node>) { node_count++; });
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
......@@ -130,12 +131,12 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; });
traverse_nodes(f0, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count;
timer.start();
ngraph::free_nodes(result);
ngraph::free_nodes(f0->get_result());
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment