Unverified Commit a527d460 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Visualize Graphs After Each Pass (#815)

* visualization tracing

* visualize -> m_visualize. add a programmatic way to enable visualization. tweak pass names
parent 13770af2
...@@ -25,12 +25,18 @@ ...@@ -25,12 +25,18 @@
#include "ngraph/op/reduce.hpp" #include "ngraph/op/reduce.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
ngraph::pass::Manager::Manager() ngraph::pass::Manager::Manager()
{ {
static const auto nevt = std::getenv("NGRAPH_ENABLE_VISUALIZE_TRACING");
if (nevt)
{
m_visualize = true;
}
} }
ngraph::pass::Manager::Manager(bool to_set_is_output) ngraph::pass::Manager::Manager(bool to_set_is_output)
...@@ -54,6 +60,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) ...@@ -54,6 +60,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
set<shared_ptr<Function>> tfs(begin(fs), end(fs)); set<shared_ptr<Function>> tfs(begin(fs), end(fs));
get_state().set_functions(tfs); get_state().set_functions(tfs);
size_t index = 0;
for (shared_ptr<PassBase> pass : m_pass_list) for (shared_ptr<PassBase> pass : m_pass_list)
{ {
pass->set_state(get_state()); pass->set_state(get_state());
...@@ -89,6 +96,20 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) ...@@ -89,6 +96,20 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
call_graph_pass->run_on_call_graph(f->get_ordered_ops()); call_graph_pass->run_on_call_graph(f->get_ordered_ops());
} }
} }
if (m_visualize)
{
//visualizations will be named after the outermost function
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto fname = fs.at(0)->get_name() + std::string("_") + index_str + std::string("_") +
m_pass_names.at(index) + std::string(".") +
pass::VisualizeTree::get_file_ext();
pass::VisualizeTree vt(fname);
vt.run_on_module(fs);
}
index++;
} }
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include <typeinfo>
#include <vector> #include <vector>
#include "ngraph/pass/manager_state.hpp" #include "ngraph/pass/manager_state.hpp"
...@@ -48,13 +49,19 @@ public: ...@@ -48,13 +49,19 @@ public:
auto pass = std::make_shared<T>(args...); auto pass = std::make_shared<T>(args...);
auto pass_base = std::static_pointer_cast<PassBase>(pass); auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base); m_pass_list.push_back(pass_base);
if (m_visualize)
{
m_pass_names.push_back(typeid(T).name());
}
} }
void run_passes(std::shared_ptr<Function>); void run_passes(std::shared_ptr<Function>);
ManagerState& get_state(); ManagerState& get_state();
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
private: private:
std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list; std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state; ManagerState m_state;
bool m_visualize = false;
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment