Commit a51c2e80 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

Graph Visualization of CompilerKernel (#3515)

* Print CK info in visualizer

* style-apply

* Fix parameters printing

*  Fixes

* typo
parent 1683e200
...@@ -20,7 +20,10 @@ ...@@ -20,7 +20,10 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -109,6 +112,9 @@ using namespace std; ...@@ -109,6 +112,9 @@ using namespace std;
// be careful to avoid splitting the components. I have some rough ideas on how this could be // be careful to avoid splitting the components. I have some rough ideas on how this could be
// dealt with, but have not had time to implement them yet. --amprocte // dealt with, but have not had time to implement them yet. --amprocte
// //
const int ngraph::pass::VisualizeTree::max_jump_distance = 20;
class HeightMap class HeightMap
{ {
public: public:
...@@ -212,23 +218,58 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -212,23 +218,58 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
} }
// TODO(amprocte): Maybe find a way to make this tunable. // TODO(amprocte): Maybe find a way to make this tunable.
const int max_jump_distance = 20;
size_t fake_node_ctr = 0; size_t fake_node_ctr = 0;
traverse_nodes(f, [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
{
// print sub-graph
auto nodes_list = ck->get_node_list();
// all nodes inside the CK sub-graph
for (auto& ck_node : nodes_list)
{
m_ss << add_attributes(ck_node);
}
// all edges to each node in the sub-graph
for (auto& subgraph_node : nodes_list)
{
add_node_arguments(subgraph_node, height_maps, fake_node_ctr);
}
}
add_node_arguments(node, height_maps, fake_node_ctr);
});
}
render();
return false;
}
pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm, bool dot_only)
: m_name{file_name}
, m_node_modifiers{nm}
, m_dot_only(dot_only)
{
}
void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node,
unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr)
{
size_t arg_index = 0; size_t arg_index = 0;
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]); size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
if (arg->description() == ngraph::op::Constant::type_name ||
if (arg->description() == "Constant" || arg->description() == "Parameter") arg->description() == ngraph::op::Parameter::type_name)
{ {
auto clone_name = "CLONE_" + to_string(fake_node_ctr); auto clone_name = "CLONE_" + to_string(fake_node_ctr);
auto color = (arg->description() == "Parameter" ? "blue" : "black"); auto color = (arg->description() == "Parameter" ? "blue" : "black");
m_ss << " " << clone_name m_ss << " " << clone_name << "[shape=\"box\" style=\"dashed,filled\" color=\""
<< "[shape=\"box\" style=\"dashed,filled\" color=\"" << color << color << "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
<< "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
m_ss << " " << clone_name << " -> " << node->get_name() m_ss << " " << clone_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n"; << label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++; fake_node_ctr++;
...@@ -239,14 +280,12 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -239,14 +280,12 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
m_ss << add_attributes(node); m_ss << add_attributes(node);
auto recv_node_name = "RECV_" + to_string(fake_node_ctr); auto recv_node_name = "RECV_" + to_string(fake_node_ctr);
auto send_node_name = "SEND_" + to_string(fake_node_ctr); auto send_node_name = "SEND_" + to_string(fake_node_ctr);
m_ss << " " << recv_node_name << "[shape=\"box\" style=\"solid,filled\" " m_ss << " " << recv_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ffcccc\" label=\"Receive[" "fillcolor=\"#ffcccc\" label=\"Receive["
<< arg->get_name() << "]\"]\n"; << arg->get_name() << "]\"]\n";
m_ss << " " << send_node_name << "[shape=\"box\" style=\"solid,filled\" " m_ss << " " << send_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ccffcc\" label=\"Send[" "fillcolor=\"#ccffcc\" label=\"Send["
<< node->get_name() << "]\"]\n"; << node->get_name() << "]\"]\n";
m_ss << " " << arg->get_name() << " -> " << send_node_name m_ss << " " << arg->get_name() << " -> " << send_node_name
<< label_edge(arg, node, arg_index, jump_distance) << "\n"; << label_edge(arg, node, arg_index, jump_distance) << "\n";
m_ss << " " << recv_node_name << " -> " << node->get_name() m_ss << " " << recv_node_name << " -> " << node->get_name()
...@@ -262,19 +301,6 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -262,19 +301,6 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
} }
arg_index++; arg_index++;
} }
});
}
render();
return false;
}
pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm, bool dot_only)
: m_name{file_name}
, m_node_modifiers{nm}
, m_dot_only(dot_only)
{
} }
string pass::VisualizeTree::add_attributes(shared_ptr<Node> node) string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
...@@ -391,6 +417,17 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node) ...@@ -391,6 +417,17 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{ {
rc += "\\n" + node->get_name(); rc += "\\n" + node->get_name();
} }
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
{
rc += "\\n{";
// add sub-graph node names
for (auto& ck_node : ck->get_node_list())
{
rc += ck_node->get_name();
rc += ", ";
}
rc += "}\\n";
}
return rc; return rc;
} }
......
...@@ -36,6 +36,8 @@ namespace ngraph ...@@ -36,6 +36,8 @@ namespace ngraph
} }
} }
class HeightMap;
class ngraph::pass::VisualizeTree : public ModulePass class ngraph::pass::VisualizeTree : public ModulePass
{ {
public: public:
...@@ -48,6 +50,9 @@ public: ...@@ -48,6 +50,9 @@ public:
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; } void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
private: private:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node); std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node); std::string get_attributes(std::shared_ptr<Node> node);
std::string get_node_name(std::shared_ptr<Node> node); std::string get_node_name(std::shared_ptr<Node> node);
...@@ -60,4 +65,5 @@ private: ...@@ -60,4 +65,5 @@ private:
m_ops_to_details; m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr; node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only; bool m_dot_only;
static const int max_jump_distance;
}; };
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment