Unverified Commit a527d460 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Visualize Graphs After Each Pass (#815)

* visualization tracing

* visualize -> m_visualize. add a programmatic way to enable visualization. tweak pass names
parent 13770af2
......@@ -25,12 +25,18 @@
#include "ngraph/op/reduce.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
using namespace std;
using namespace ngraph;
ngraph::pass::Manager::Manager()
{
static const auto nevt = std::getenv("NGRAPH_ENABLE_VISUALIZE_TRACING");
if (nevt)
{
m_visualize = true;
}
}
ngraph::pass::Manager::Manager(bool to_set_is_output)
......@@ -54,6 +60,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
set<shared_ptr<Function>> tfs(begin(fs), end(fs));
get_state().set_functions(tfs);
size_t index = 0;
for (shared_ptr<PassBase> pass : m_pass_list)
{
pass->set_state(get_state());
......@@ -89,6 +96,20 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
call_graph_pass->run_on_call_graph(f->get_ordered_ops());
}
}
if (m_visualize)
{
//visualizations will be named after the outermost function
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto fname = fs.at(0)->get_name() + std::string("_") + index_str + std::string("_") +
m_pass_names.at(index) + std::string(".") +
pass::VisualizeTree::get_file_ext();
pass::VisualizeTree vt(fname);
vt.run_on_module(fs);
}
index++;
}
}
......
......@@ -18,6 +18,7 @@
#include <list>
#include <memory>
#include <typeinfo>
#include <vector>
#include "ngraph/pass/manager_state.hpp"
......@@ -48,13 +49,19 @@ public:
auto pass = std::make_shared<T>(args...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
if (m_visualize)
{
m_pass_names.push_back(typeid(T).name());
}
}
void run_passes(std::shared_ptr<Function>);
ManagerState& get_state();
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
private:
std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state;
bool m_visualize = false;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment