Commit c4c5c471 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

update visualize tree file extenstions and output formats (#2954)

* update visualize tree file extenstions and output formats

* fix runtime error
parent a65b5155
......@@ -16,6 +16,7 @@
#include <fstream>
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
......@@ -347,25 +348,15 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str();
}
string pass::VisualizeTree::get_file_ext()
void pass::VisualizeTree::render() const
{
const char* format = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
string ext = file_util::get_file_ext(m_name);
string output_format = ext.substr(1);
string dot_file = m_name;
if (to_lower(ext) != ".dot")
{
format = "dot";
dot_file += ".dot";
}
if (format[0] == '.')
{
format += 1;
}
return string(format);
}
void pass::VisualizeTree::render() const
{
auto dot_file = m_name + ".dot";
ofstream out(dot_file);
if (out)
{
......@@ -374,12 +365,11 @@ void pass::VisualizeTree::render() const
out << "}\n";
out.close();
if (!m_dot_only && get_file_ext() != "dot")
if (!m_dot_only && to_lower(ext) != ".dot")
{
#ifndef _WIN32
stringstream ss;
ss << "dot -T" << get_file_ext() << " " << dot_file << " -o" << m_name << "."
<< get_file_ext();
ss << "dot -T" << output_format << " " << dot_file << " -o" << m_name;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
if (stream)
......
......@@ -46,7 +46,6 @@ public:
bool dot_only = false);
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
static std::string get_file_ext();
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
private:
std::string add_attributes(std::shared_ptr<Node> node);
......
......@@ -319,8 +319,8 @@ OPTIONS
if (visualize)
{
shared_ptr<Function> f = deserialize(model);
auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") +
(dot_file ? "dot" : pass::VisualizeTree::get_file_ext());
auto model_file_name = ngraph::file_util::get_file_name(model) +
(dot_file ? ".dot" : ngraph::file_util::get_file_ext(model));
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(model_file_name, nullptr, true);
......
......@@ -656,7 +656,7 @@ TEST(cpu_fusion, conv_bias_bprop)
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<pass::VisualizeTree>("conv_bias_bprop_fusion");
pass_manager.register_pass<pass::VisualizeTree>("conv_bias_bprop_fusion.png");
auto f = make_shared<Function>(conv_bias, ParameterVector{data_batch, filters, bias});
ngraph::autodiff::Adjoints adjoints(NodeVector{conv_bias}, NodeVector{delta});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment