Commit f076fea9 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Add `--visualize` to nbench (#679)

* add visualize option to nbench

* check for dot, amend help msg
parent 2e1823fe
......@@ -114,6 +114,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str();
}
std::string pass::VisualizeTree::get_file_ext()
{
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
{
format = "png";
}
if (format[0] == '.')
{
format += 1;
}
return std::string(format);
}
void pass::VisualizeTree::render() const
{
#ifdef GRAPHVIZ_FOUND
......@@ -128,13 +144,7 @@ void pass::VisualizeTree::render() const
stringstream ss;
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
{
format = "png";
}
ss << "dot -T" << format << " " << tmp_file << " -o " << m_name;
ss << "dot -T" << get_file_ext() << " " << tmp_file << " -o " << m_name;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
......
......@@ -36,6 +36,8 @@ public:
VisualizeTree(const std::string& file_name);
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
static std::string get_file_ext();
private:
std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node);
......
......@@ -22,6 +22,9 @@
#include <fstream>
#include <ngraph/file_util.hpp>
#include <ngraph/file_util.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <ngraph/runtime/backend.hpp>
#include <ngraph/runtime/call_frame.hpp>
#include <ngraph/runtime/manager.hpp>
......@@ -41,6 +44,7 @@ int main(int argc, char** argv)
bool failed = false;
bool statistics = false;
bool timing_detail = false;
bool visualize = false;
for (size_t i = 1; i < argc; i++)
{
string arg = argv[i];
......@@ -72,6 +76,10 @@ int main(int argc, char** argv)
{
timing_detail = true;
}
else if (arg == "-v" || arg == "--visualize")
{
visualize = true;
}
else
{
cout << "Unknown option: " << arg << endl;
......@@ -98,6 +106,7 @@ OPTIONS
-b|--backend Backend to use (default: CPU)
-i|--iterations Iterations (default: 10)
-s|--statistics Display op stastics
-v|--visualize Visualize a model (WARNING: requires GraphViz installed)
--timing_detail Gather detailed timing
)###";
return 1;
......@@ -106,6 +115,17 @@ OPTIONS
const string json_string = file_util::read_file_to_string(model);
stringstream ss(json_string);
shared_ptr<Function> f = deserialize(ss);
if (visualize)
{
auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") +
pass::VisualizeTree::get_file_ext();
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(model_file_name);
pass_manager.run_passes(f);
}
if (statistics)
{
cout << "statistics:" << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment