Commit 1e73a52c authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'master' into jmenon/cpu

parents 504d3585 00e56657
...@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result, ...@@ -39,7 +39,7 @@ Function::Function(const std::shared_ptr<Node>& result,
parameter->assign_function(this, i++); parameter->assign_function(this, i++);
} }
traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); }); traverse_nodes(this, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
} }
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops) void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
......
...@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func) ...@@ -35,7 +35,7 @@ bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
shared_ptr<ngraph::Function> f = stack.front(); shared_ptr<ngraph::Function> f = stack.front();
stack.pop_front(); stack.pop_front();
functions.insert(f); functions.insert(f);
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node); shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc) if (fc)
{ {
......
...@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> ...@@ -32,7 +32,7 @@ bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function>
unordered_map<const Node*, size_t> node_depencency_count; unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map; unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(func, [&](shared_ptr<Node> node) {
node_map[node.get()] = node; node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size(); node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0) if (node->get_arguments().size() == 0)
......
...@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu ...@@ -28,7 +28,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
for (shared_ptr<Function> f : functions) for (shared_ptr<Function> f : functions)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << add_attributes(arg); m_ss << add_attributes(arg);
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <map> #include <map>
#include <unordered_set> #include <unordered_set>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -137,12 +138,23 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -137,12 +138,23 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
return seed; return seed;
} }
void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Node>)> f) std::function<void(shared_ptr<Node>)> f)
{
traverse_nodes(p.get(), f);
}
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{ {
std::unordered_set<shared_ptr<Node>> instances_seen; std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack; deque<shared_ptr<Node>> stack;
stack.push_front(p);
stack.push_front(p->get_result());
for (auto param : p->get_parameters())
{
stack.push_front(param);
}
while (stack.size() > 0) while (stack.size() > 0)
{ {
...@@ -160,7 +172,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, ...@@ -160,7 +172,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
} }
} }
void ngraph::free_nodes(shared_ptr<Node> p) void ngraph::free_nodes(shared_ptr<Function> p)
{ {
std::deque<Node*> sorted_list; std::deque<Node*> sorted_list;
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
namespace ngraph namespace ngraph
{ {
class Node; class Node;
class Function;
class stopwatch; class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics; extern std::map<std::string, stopwatch*> stopwatch_statistics;
...@@ -195,8 +196,9 @@ namespace ngraph ...@@ -195,8 +196,9 @@ namespace ngraph
return a * b; return a * b;
} }
void traverse_nodes(const std::shared_ptr<Node>& p, void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Node>); void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void free_nodes(std::shared_ptr<Function>);
} // end namespace ngraph } // end namespace ngraph
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -38,7 +39,8 @@ TEST(pass_manager, add) ...@@ -38,7 +39,8 @@ TEST(pass_manager, add)
pass_manager.register_pass<pass::AssignTensors>(); pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph(); auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result()); size_t node_count = 0;
traverse_nodes(graph, [&](shared_ptr<Node> node) { node_count++; });
pass_manager.run_passes(graph); pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops(); auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size()); EXPECT_EQ(node_count, sorted.size());
......
...@@ -78,10 +78,3 @@ shared_ptr<Function> make_test_graph() ...@@ -78,10 +78,3 @@ shared_ptr<Function> make_test_graph()
return f0; return f0;
} }
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
return node_count;
}
...@@ -25,4 +25,3 @@ namespace ngraph ...@@ -25,4 +25,3 @@ namespace ngraph
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes); bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph(); std::shared_ptr<ngraph::Function> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);
...@@ -21,9 +21,13 @@ ...@@ -21,9 +21,13 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/collect_functions.hpp" #include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
...@@ -71,7 +75,8 @@ TEST(topological_sort, basic) ...@@ -71,7 +75,8 @@ TEST(topological_sort, basic)
pass_manager.run_passes(f0); pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops(); auto sorted_list = f0->get_ordered_ops();
size_t node_count = get_node_count(r0); size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node>) { node_count++; });
EXPECT_EQ(node_count, sorted_list.size()); EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list)); EXPECT_TRUE(validate_list(sorted_list));
...@@ -126,12 +131,12 @@ TEST(benchmark, topological_sort) ...@@ -126,12 +131,12 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; }); traverse_nodes(f0, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count; NGRAPH_INFO << "node count " << node_count;
timer.start(); timer.start();
ngraph::free_nodes(result); ngraph::free_nodes(f0);
timer.stop(); timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms"; NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
} }
...@@ -184,3 +189,26 @@ TEST(topological_sort, collect_functions) ...@@ -184,3 +189,26 @@ TEST(topological_sort, collect_functions)
EXPECT_TRUE(contains(fnames, "g")); EXPECT_TRUE(contains(fnames, "g"));
EXPECT_TRUE(contains(fnames, "h")); EXPECT_TRUE(contains(fnames, "h"));
} }
TEST(topological_sort, unused_function_arg)
{
// Create a function with an unused argument
// B is unused in the function but must be in the graph
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto result = A + C + C;
auto f = make_shared<Function>(result, rt_f, op::Parameters{A, B, C}, "f");
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// pass_manager.register_pass<pass::DumpSorted>("sorted.txt");
pass_manager.run_passes(f);
list<shared_ptr<Node>> ops = f->get_ordered_ops();
EXPECT_EQ(5, ops.size());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment