Unverified Commit 7e89f1bb authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Remove an optimization for caching a list of ordered ops (#360)

* remove caching of ordered_ops

* graph_util logging msgs

* small cleanup

* remove files for the TopologicalSort pass

* remove NGRAPH_DEBUG from graph_util.hpp
parent 8627c495
...@@ -83,7 +83,6 @@ set (SRC ...@@ -83,7 +83,6 @@ set (SRC
pass/memory_layout.cpp pass/memory_layout.cpp
pass/memory_visualize.cpp pass/memory_visualize.cpp
pass/pass.cpp pass/pass.cpp
pass/topological_sort.cpp
pass/visualize_tree.cpp pass/visualize_tree.cpp
pattern/matcher.cpp pattern/matcher.cpp
runtime/aligned_buffer.cpp runtime/aligned_buffer.cpp
......
...@@ -31,7 +31,6 @@ Function::Function(const Nodes& results, ...@@ -31,7 +31,6 @@ Function::Function(const Nodes& results,
: m_results(results) : m_results(results)
, m_parameters(parameters) , m_parameters(parameters)
, m_name(name) , m_name(name)
, m_ordered_ops_valid(false)
, m_temporary_pool_size(0) , m_temporary_pool_size(0)
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
{ {
...@@ -57,28 +56,9 @@ Function::Function(const std::shared_ptr<Node>& result, ...@@ -57,28 +56,9 @@ Function::Function(const std::shared_ptr<Node>& result,
{ {
} }
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops) std::list<shared_ptr<Node>> Function::get_ordered_ops()
{ {
m_ordered_ops = ordered_ops; return topological_sort(get_ops());
m_ordered_ops_valid = true;
}
std::list<shared_ptr<Node>>& Function::get_ordered_ops()
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
}
const std::list<shared_ptr<Node>>& Function::get_ordered_ops() const
{
if (!m_ordered_ops_valid)
{
throw ngraph_error("Access to ordered ops invalid");
}
return m_ordered_ops;
} }
std::string Function::get_name() const std::string Function::get_name() const
......
...@@ -73,11 +73,7 @@ namespace ngraph ...@@ -73,11 +73,7 @@ namespace ngraph
const std::string& const std::string&
name); //so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with an XLA or regular function name); //so we can use `dynamic_cast` in FunctionCall to double check if we are dealing with an XLA or regular function
std::list<std::shared_ptr<Node>> get_ops() const; std::list<std::shared_ptr<Node>> get_ops() const;
std::list<std::shared_ptr<Node>>& get_ordered_ops(); std::list<std::shared_ptr<Node>> get_ordered_ops();
const std::list<std::shared_ptr<Node>>& get_ordered_ops() const;
void set_ordered_ops(const std::list<std::shared_ptr<Node>>&);
void set_ordered_ops_valid() { m_ordered_ops_valid = true; }
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&); friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; } size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size(); size_t get_temporary_pool_size();
...@@ -87,8 +83,6 @@ namespace ngraph ...@@ -87,8 +83,6 @@ namespace ngraph
Nodes m_results; Nodes m_results;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
bool m_ordered_ops_valid;
std::list<std::shared_ptr<Node>> m_ordered_ops;
size_t m_temporary_pool_size; size_t m_temporary_pool_size;
private: private:
......
...@@ -112,9 +112,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -112,9 +112,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
return; return;
} }
//fix input/output descriptors //fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size()); assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++) for (size_t i = 0; i < target->get_outputs().size(); i++)
{ {
...@@ -135,16 +132,11 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -135,16 +132,11 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target, void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement) std::shared_ptr<Node> replacement)
{ {
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users()) for (auto user : target->users())
{ {
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY()); auto& args = const_cast<ngraph::Nodes&>(user->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(args), end(args), target); auto it = std::find(begin(args), end(args), target);
assert(it != end(args)); assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it); it = args.erase(it);
args.insert(it, replacement); args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user); const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
......
...@@ -109,13 +109,13 @@ namespace nervana ...@@ -109,13 +109,13 @@ namespace nervana
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
/*
#define NGRAPH_DEBUG \ //#define NGRAPH_DEBUG \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \ // nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \
nervana::get_file_name(__FILE__), \ // nervana::get_file_name(__FILE__), \
__LINE__, \ // __LINE__, \
__PRETTY_FUNCTION__) \ // __PRETTY_FUNCTION__) \
.stream() // .stream()
*/
#define NGRAPH_DEBUG nervana::get_nil_stream() #define NGRAPH_DEBUG nervana::get_nil_stream()
} }
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
class AssignLayout : public CallGraphPass class AssignLayout : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override
{ {
for (const std::shared_ptr<Node>& node : nodes) for (const std::shared_ptr<Node>& node : nodes)
{ {
......
...@@ -31,7 +31,7 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list( ...@@ -31,7 +31,7 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
return rewritten; return rewritten;
} }
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) bool ngraph::pass::GraphRewrite::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{ {
return run_matchers_on_nodes_list(nodes, m_matchers); return run_matchers_on_nodes_list(nodes, m_matchers);
} }
...@@ -49,7 +49,7 @@ public: ...@@ -49,7 +49,7 @@ public:
} }
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); } void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override; virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) override;
static bool static bool
run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes, run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers); const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
......
...@@ -28,7 +28,7 @@ using namespace std; ...@@ -28,7 +28,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::descriptor; using namespace ngraph::descriptor;
bool pass::Liveness::run_on_call_graph(list<shared_ptr<Node>>& ops) bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
{ {
unordered_set<Tensor*> currently_live; unordered_set<Tensor*> currently_live;
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
class ngraph::pass::Liveness : public CallGraphPass class ngraph::pass::Liveness : public CallGraphPass
{ {
public: public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override; virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>&) override;
private: private:
bool is_temporary(const descriptor::Tensor&); bool is_temporary(const descriptor::Tensor&);
......
...@@ -74,5 +74,5 @@ class ngraph::pass::CallGraphPass : public PassBase ...@@ -74,5 +74,5 @@ class ngraph::pass::CallGraphPass : public PassBase
{ {
public: public:
virtual ~CallGraphPass() {} virtual ~CallGraphPass() {}
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) = 0; virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) = 0;
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <deque>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
using namespace ngraph;
using namespace std;
bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> func)
{
func->set_ordered_ops(topological_sort(func->get_ops()));
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class TopologicalSort;
}
}
class ngraph::pass::TopologicalSort : public FunctionPass
{
public:
TopologicalSort() {}
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
...@@ -84,7 +84,6 @@ ...@@ -84,7 +84,6 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp" #include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp" #include "ngraph/runtime/cpu/cpu_emitter.hpp"
...@@ -208,7 +207,6 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -208,7 +207,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt"); string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt");
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
// For now, just make everyone row-major. // For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>(); pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
......
...@@ -70,7 +70,6 @@ ...@@ -70,7 +70,6 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp" #include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/runtime/interpreter/int_call_frame.hpp" #include "ngraph/runtime/interpreter/int_call_frame.hpp"
#include "ngraph/runtime/interpreter/int_external_function.hpp" #include "ngraph/runtime/interpreter/int_external_function.hpp"
...@@ -108,7 +107,6 @@ void runtime::interpreter::ExternalFunction::compile() ...@@ -108,7 +107,6 @@ void runtime::interpreter::ExternalFunction::compile()
string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt"); string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt");
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
// For now, just make everyone row-major. // For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
......
...@@ -40,7 +40,6 @@ set (SRC ...@@ -40,7 +40,6 @@ set (SRC
pattern.cpp pattern.cpp
shape.cpp shape.cpp
tensor.cpp tensor.cpp
topological_sort.cpp
type_prop.cpp type_prop.cpp
util/autodiff/backprop_function.cpp util/autodiff/backprop_function.cpp
util/test_tools.cpp util/test_tools.cpp
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -41,7 +40,6 @@ TEST(liveness, constant) ...@@ -41,7 +40,6 @@ TEST(liveness, constant)
auto f = make_shared<Function>(make_shared<op::Negative>(c), op::Parameters{}); auto f = make_shared<Function>(make_shared<op::Negative>(c), op::Parameters{});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
...@@ -64,7 +62,6 @@ TEST(liveness, liveness) ...@@ -64,7 +62,6 @@ TEST(liveness, liveness)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(image); pass_manager.register_pass<pass::VisualizeTree>(image);
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -32,8 +31,6 @@ TEST(pass_manager, add) ...@@ -32,8 +31,6 @@ TEST(pass_manager, add)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
auto graph = make_test_graph(); auto graph = make_test_graph();
size_t node_count = 0; size_t node_count = 0;
traverse_nodes(graph, [&](shared_ptr<Node> node) { node_count++; }); traverse_nodes(graph, [&](shared_ptr<Node> node) { node_count++; });
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -207,7 +206,6 @@ TEST(memory_layout, basic) ...@@ -207,7 +206,6 @@ TEST(memory_layout, basic)
{ {
string dump_file = "memory_layout.txt"; string dump_file = "memory_layout.txt";
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(); pass_manager.register_pass<pass::MemoryLayout>();
pass_manager.register_pass<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
...@@ -223,7 +221,6 @@ TEST(memory_layout, constant) ...@@ -223,7 +221,6 @@ TEST(memory_layout, constant)
{ {
string dump_file = "constant.txt"; string dump_file = "constant.txt";
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(); pass_manager.register_pass<pass::MemoryLayout>();
pass_manager.register_pass<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
...@@ -196,7 +195,6 @@ TEST(pattern, graph_rewrite) ...@@ -196,7 +195,6 @@ TEST(pattern, graph_rewrite)
auto shape = Shape{1}; auto shape = Shape{1};
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<TestGraphRewrite>(); pass_manager.register_pass<TestGraphRewrite>();
{ {
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace std; using namespace std;
...@@ -34,7 +33,6 @@ TEST(tensor, size) ...@@ -34,7 +33,6 @@ TEST(tensor, size)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
{ {
...@@ -112,7 +110,6 @@ TEST(tensor, read_write) ...@@ -112,7 +110,6 @@ TEST(tensor, read_write)
TEST(tensor, output_flag) TEST(tensor, output_flag)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{1}); auto arg0 = make_shared<op::Parameter>(element::f32, Shape{1});
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
TEST(topological_sort, basic)
{
vector<shared_ptr<op::Parameter>> args;
for (int i = 0; i < 10; i++)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
auto t0 = make_shared<op::Add>(args[0], args[1]);
ASSERT_NE(nullptr, t0);
auto t1 = make_shared<op::Dot>(t0, args[2]);
ASSERT_NE(nullptr, t1);
auto t2 = make_shared<op::Multiply>(t0, args[3]);
ASSERT_NE(nullptr, t2);
auto t3 = make_shared<op::Add>(t1, args[4]);
ASSERT_NE(nullptr, t2);
auto t4 = make_shared<op::Add>(t2, args[5]);
ASSERT_NE(nullptr, t3);
auto r0 = make_shared<op::Add>(t3, t4);
ASSERT_NE(nullptr, r0);
auto f0 = make_shared<Function>(r0, args);
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_input_ops().size());
// Visualize vz;
// vz.add(r0);
// vz.save_dot("test.png");
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops();
size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node>) { node_count++; });
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
}
// TEST(topological_sort, cycle)
// {
// vector<shared_ptr<op::Parameter>> args;
// for (int i = 0; i < 10; i++)
// {
// auto arg = make_shared<op::Parameter>(element::f32, Shape{1});
// ASSERT_NE(nullptr, arg);
// args.push_back(arg);
// }
// auto add_0 = make_shared<op::Add>(args[0], args[1]);
// auto add_1 = make_shared<op::Add>(args[0], args[1]);
// }
shared_ptr<Node> make_cell(shared_ptr<Node> in_0, shared_ptr<Node> in_1, shared_ptr<Node> in_2)
{
auto t0 = make_shared<op::Dot>(in_0, in_1);
auto t1 = make_shared<op::Add>(t0, in_2);
auto t2 = make_shared<op::Negative>(t1); // no tanh yet, this will do
return static_pointer_cast<Node>(t2);
}
TEST(benchmark, topological_sort)
{
stopwatch timer;
// x[i+1] = tanh(dot(W,x[i])+b)
shared_ptr<Node> result;
vector<shared_ptr<op::Parameter>> args;
result = make_shared<op::Parameter>(element::f32, Shape{});
for (int i = 0; i < 1000000; i++)
{
auto in_1 = make_shared<op::Parameter>(element::f32, Shape{});
auto in_2 = make_shared<op::Parameter>(element::f32, Shape{});
args.push_back(in_1);
args.push_back(in_2);
result = make_cell(result, in_1, in_2);
}
auto f0 = make_shared<Function>(result, args);
timer.start();
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.run_passes(f0);
auto sorted_list = f0->get_ordered_ops();
timer.stop();
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
traverse_nodes(f0, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count;
timer.start();
ngraph::free_nodes(f0);
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}
TEST(topological_sort, collect_functions)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, op::Parameters{A, B, C}, "f");
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::f32, shape);
auto Y = make_shared<op::Parameter>(element::f32, shape);
auto Z = make_shared<op::Parameter>(element::f32, shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
op::Parameters{X, Y, Z},
"g");
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto X1 = make_shared<op::Parameter>(element::f32, shape);
auto Y1 = make_shared<op::Parameter>(element::f32, shape);
auto Z1 = make_shared<op::Parameter>(element::f32, shape);
auto h = make_shared<Function>(make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}) +
make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}),
op::Parameters{X1, Y1, Z1},
"h");
pass::Manager pass_manager;
pass_manager.run_passes(h);
set<string> expected = {"f", "g", "h"};
auto functions = pass_manager.get_state().get_functions();
vector<string> fnames;
for (shared_ptr<Function> func : functions)
{
fnames.push_back(func->get_name());
}
EXPECT_EQ(expected.size(), functions.size());
EXPECT_TRUE(contains(fnames, "f"));
EXPECT_TRUE(contains(fnames, "g"));
EXPECT_TRUE(contains(fnames, "h"));
}
TEST(topological_sort, unused_function_arg)
{
// Create a function with an unused argument
// B is unused in the function but must be in the graph
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto result = A + C + C;
auto f = make_shared<Function>(result, op::Parameters{A, B, C}, "f");
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
// pass_manager.register_pass<pass::DumpSorted>("sorted.txt");
pass_manager.run_passes(f);
list<shared_ptr<Node>> ops = f->get_ordered_ops();
EXPECT_EQ(5, ops.size());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment