Commit 530a9e09 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Friendlier graph visualization (#3041)

* Use friendly_name in visualization

* sytle

* wip

* cleanup

* style
parent eb6bca0e
...@@ -228,7 +228,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -228,7 +228,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
auto color = (arg->description() == "Parameter" ? "blue" : "black"); auto color = (arg->description() == "Parameter" ? "blue" : "black");
m_ss << " " << clone_name m_ss << " " << clone_name
<< "[shape=\"box\" style=\"dashed,filled\" color=\"" << color << "[shape=\"box\" style=\"dashed,filled\" color=\"" << color
<< "\" fillcolor=\"white\" label=\"" << arg->get_name() << "\"]\n"; << "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
m_ss << " " << clone_name << " -> " << node->get_name() m_ss << " " << clone_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n"; << label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++; fake_node_ctr++;
...@@ -341,7 +341,7 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -341,7 +341,7 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
// Construct the label attribute // Construct the label attribute
{ {
stringstream label; stringstream label;
label << "label=\"" << node->get_name(); label << "label=\"" << get_node_name(node);
static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES"); static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr) if (nvtos != nullptr)
...@@ -384,6 +384,16 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -384,6 +384,16 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str(); return ss.str();
} }
string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{
string rc = node->get_friendly_name();
if (node->get_friendly_name() != node->get_name())
{
rc += "\\n" + node->get_name();
}
return rc;
}
void pass::VisualizeTree::render() const void pass::VisualizeTree::render() const
{ {
string ext = file_util::get_file_ext(m_name); string ext = file_util::get_file_ext(m_name);
......
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
private: private:
std::string add_attributes(std::shared_ptr<Node> node); std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node); std::string get_attributes(std::shared_ptr<Node> node);
std::string get_node_name(std::shared_ptr<Node> node);
void render() const; void render() const;
std::stringstream m_ss; std::stringstream m_ss;
......
...@@ -1600,6 +1600,10 @@ static shared_ptr<ngraph::Function> ...@@ -1600,6 +1600,10 @@ static shared_ptr<ngraph::Function>
{ {
node->set_friendly_name(friendly_name); node->set_friendly_name(friendly_name);
} }
else
{
node->set_friendly_name(node_name);
}
node_map[node_name] = node; node_map[node_name] = node;
} }
catch (...) catch (...)
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
...@@ -309,4 +311,16 @@ TEST(serialize, constant_infinity_nan) ...@@ -309,4 +311,16 @@ TEST(serialize, constant_infinity_nan)
EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data)); EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data));
EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data)); EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data));
EXPECT_EQ(d->get_vector<int64_t>(), d_data); EXPECT_EQ(d->get_vector<int64_t>(), d_data);
string filename = "constant_infinity_nan_test.dot";
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(filename);
pass_manager.run_passes(g);
ifstream file(filename);
ASSERT_TRUE(file);
string str((istreambuf_iterator<char>(file)), istreambuf_iterator<char>());
EXPECT_NE(str.find(R"(label="A)"), string::npos);
EXPECT_NE(str.find(R"(label="B)"), string::npos);
EXPECT_NE(str.find(R"(label="C)"), string::npos);
EXPECT_NE(str.find(R"(label="D)"), string::npos);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment