Commit 530a9e09 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Friendlier graph visualization (#3041)

* Use friendly_name in visualization

* sytle

* wip

* cleanup

* style
parent eb6bca0e
......@@ -228,7 +228,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
auto color = (arg->description() == "Parameter" ? "blue" : "black");
m_ss << " " << clone_name
<< "[shape=\"box\" style=\"dashed,filled\" color=\"" << color
<< "\" fillcolor=\"white\" label=\"" << arg->get_name() << "\"]\n";
<< "\" fillcolor=\"white\" label=\"" << get_node_name(arg) << "\"]\n";
m_ss << " " << clone_name << " -> " << node->get_name()
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
fake_node_ctr++;
......@@ -341,7 +341,7 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
// Construct the label attribute
{
stringstream label;
label << "label=\"" << node->get_name();
label << "label=\"" << get_node_name(node);
static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr)
......@@ -384,6 +384,16 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str();
}
string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{
string rc = node->get_friendly_name();
if (node->get_friendly_name() != node->get_name())
{
rc += "\\n" + node->get_name();
}
return rc;
}
void pass::VisualizeTree::render() const
{
string ext = file_util::get_file_ext(m_name);
......
......@@ -50,6 +50,7 @@ public:
private:
std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node);
std::string get_node_name(std::shared_ptr<Node> node);
void render() const;
std::stringstream m_ss;
......
......@@ -1600,6 +1600,10 @@ static shared_ptr<ngraph::Function>
{
node->set_friendly_name(friendly_name);
}
else
{
node->set_friendly_name(node_name);
}
node_map[node_name] = node;
}
catch (...)
......
......@@ -25,6 +25,8 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -309,4 +311,16 @@ TEST(serialize, constant_infinity_nan)
EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data));
EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data));
EXPECT_EQ(d->get_vector<int64_t>(), d_data);
string filename = "constant_infinity_nan_test.dot";
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(filename);
pass_manager.run_passes(g);
ifstream file(filename);
ASSERT_TRUE(file);
string str((istreambuf_iterator<char>(file)), istreambuf_iterator<char>());
EXPECT_NE(str.find(R"(label="A)"), string::npos);
EXPECT_NE(str.find(R"(label="B)"), string::npos);
EXPECT_NE(str.find(R"(label="C)"), string::npos);
EXPECT_NE(str.find(R"(label="D)"), string::npos);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment