Commit 1e73a52c authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'master' into jmenon/cpu

parents 504d3585 00e56657
......@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter->assign_function(this, i++);
}
traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
traverse_nodes(this, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
}
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
......
......@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
shared_ptr<ngraph::Function> f = stack.front();
stack.pop_front();
functions.insert(f);
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(f, [&](shared_ptr<Node> node) {
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc)
{
......
......@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function>
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(func, [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
......
......@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
for (shared_ptr<Function> f : functions)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
traverse_nodes(f, [&](shared_ptr<Node> node) {
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg);
......
......@@ -18,6 +18,7 @@
#include <map>
#include <unordered_set>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
......@@ -137,12 +138,23 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed;
}
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{
std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack;
stack.push_front(p);
stack.push_front(p->get_result());
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0)
{
......@@ -160,7 +172,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
}
}
void ngraph::free_nodes(shared_ptr<Node> p)
void ngraph::free_nodes(shared_ptr<Function> p)
{
std::deque<Node*> sorted_list;
......
......@@ -26,6 +26,7 @@
namespace ngraph
{
class Node;
class Function;
class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics;
......@@ -195,8 +196,9 @@ namespace ngraph
return a * b;
}
void traverse_nodes(const std::shared_ptr<Node>& p,
std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Node>);
void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Function>);
} // end namespace ngraph
......@@ -24,6 +24,7 @@
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "test_tools.hpp"
using namespace ngraph;
......@@ -38,7 +39,8 @@ TEST(pass_manager, add)
pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result());
size_t node_count = 0;
traverse_nodes(graph, [&](shared_ptr<Node> node) { node_count++; });
pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size());
......
......@@ -78,10 +78,3 @@ shared_ptr<Function> make_test_graph()
return f0;
}
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
return node_count;
}
......@@ -25,4 +25,3 @@ namespace ngraph
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);
......@@ -21,9 +21,13 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
#include "test_tools.hpp"
......@@ -71,7 +75,8 @@ TEST(topological_sort, basic)
pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0);
size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node>) { node_count++; });
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
......@@ -126,12 +131,12 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; });
traverse_nodes(f0, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count;
timer.start();
ngraph::free_nodes(result);
ngraph::free_nodes(f0);
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
......@@ -184,3 +189,26 @@ TEST(topological_sort, collect_functions)
EXPECT_TRUE(contains(fnames, "g"));
EXPECT_TRUE(contains(fnames, "h"));
}
TEST(topological_sort, unused_function_arg)
{
// Create a function with an unused argument
// B is unused in the function but must be in the graph
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto result = A + C + C;
auto f = make_shared<Function>(result, rt_f, op::Parameters{A, B, C}, "f");
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// pass_manager.register_pass<pass::DumpSorted>("sorted.txt");
pass_manager.run_passes(f);
list<shared_ptr<Node>> ops = f->get_ordered_ops();
EXPECT_EQ(5, ops.size());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment