Commit a51c2e80 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

Graph Visualization of CompilerKernel (#3515)

* Print CK info in visualizer

* style-apply

* Fix parameters printing

*  Fixes

* typo
parent 1683e200
......@@ -20,7 +20,10 @@
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
......@@ -109,6 +112,9 @@ using namespace std;
// be careful to avoid splitting the components. I have some rough ideas on how this could be
// dealt with, but have not had time to implement them yet. --amprocte
//
const int ngraph::pass::VisualizeTree::max_jump_distance = 20;
class HeightMap
{
public:
......@@ -212,56 +218,28 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
}
// TODO(amprocte): Maybe find a way to make this tunable.
const int max_jump_distance = 20;
size_t fake_node_ctr = 0;
traverse_nodes(f, [&](shared_ptr<Node> node) {
size_t arg_index = 0;
for (auto arg : node->get_arguments())
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
{
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
// print sub-graph
auto nodes_list = ck->get_node_list();
if (arg->description() == "Constant" || arg->description() == "Parameter")
{
auto clone_name = "CLONE_" + to_string(fake_node_ctr);
auto color = (arg->description() == "Parameter" ? "blue" : "black");
m_ss << " " << clone_name
<< "[shape=\"box\" style=\"dashed,filled\" color=\"" << color
<< "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
m_ss << " " << clone_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++;
}
else if (jump_distance > max_jump_distance)
// all nodes inside the CK sub-graph
for (auto& ck_node : nodes_list)
{
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
auto recv_node_name = "RECV_" + to_string(fake_node_ctr);
auto send_node_name = "SEND_" + to_string(fake_node_ctr);
m_ss << " " << recv_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ffcccc\" label=\"Receive["
<< arg->get_name() << "]\"]\n";
m_ss << " " << send_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ccffcc\" label=\"Send["
<< node->get_name() << "]\"]\n";
m_ss << " " << arg->get_name() << " -> " << send_node_name
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
m_ss << " " << recv_node_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++;
m_ss << add_attributes(ck_node);
}
else
// all edges to each node in the sub-graph
for (auto& subgraph_node : nodes_list)
{
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
add_node_arguments(subgraph_node, height_maps, fake_node_ctr);
}
arg_index++;
}
add_node_arguments(node, height_maps, fake_node_ctr);
});
}
......@@ -277,6 +255,54 @@ pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm,
{
}
void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node,
unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr)
{
size_t arg_index = 0;
for (auto arg : node->get_arguments())
{
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
if (arg->description() == ngraph::op::Constant::type_name ||
arg->description() == ngraph::op::Parameter::type_name)
{
auto clone_name = "CLONE_" + to_string(fake_node_ctr);
auto color = (arg->description() == "Parameter" ? "blue" : "black");
m_ss << " " << clone_name << "[shape=\"box\" style=\"dashed,filled\" color=\""
<< color << "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
m_ss << " " << clone_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++;
}
else if (jump_distance > max_jump_distance)
{
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
auto recv_node_name = "RECV_" + to_string(fake_node_ctr);
auto send_node_name = "SEND_" + to_string(fake_node_ctr);
m_ss << " " << recv_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ffcccc\" label=\"Receive["
<< arg->get_name() << "]\"]\n";
m_ss << " " << send_node_name << "[shape=\"box\" style=\"solid,filled\" "
"fillcolor=\"#ccffcc\" label=\"Send["
<< node->get_name() << "]\"]\n";
m_ss << " " << arg->get_name() << " -> " << send_node_name
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
m_ss << " " << recv_node_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++;
}
else
{
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
m_ss << " " << arg->get_name() << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
}
arg_index++;
}
}
string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
{
string rc;
......@@ -391,6 +417,17 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{
rc += "\\n" + node->get_name();
}
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
{
rc += "\\n{";
// add sub-graph node names
for (auto& ck_node : ck->get_node_list())
{
rc += ck_node->get_name();
rc += ", ";
}
rc += "}\\n";
}
return rc;
}
......
......@@ -36,6 +36,8 @@ namespace ngraph
}
}
class HeightMap;
class ngraph::pass::VisualizeTree : public ModulePass
{
public:
......@@ -48,6 +50,9 @@ public:
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
private:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node);
std::string get_node_name(std::shared_ptr<Node> node);
......@@ -60,4 +65,5 @@ private:
m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment