Commit f076fea9 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Add `--visualize` to nbench (#679)

* add visualize option to nbench

* check for dot, amend help msg
parent 2e1823fe
...@@ -114,6 +114,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -114,6 +114,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
return ss.str(); return ss.str();
} }
std::string pass::VisualizeTree::get_file_ext()
{
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
{
format = "png";
}
if (format[0] == '.')
{
format += 1;
}
return std::string(format);
}
void pass::VisualizeTree::render() const void pass::VisualizeTree::render() const
{ {
#ifdef GRAPHVIZ_FOUND #ifdef GRAPHVIZ_FOUND
...@@ -128,13 +144,7 @@ void pass::VisualizeTree::render() const ...@@ -128,13 +144,7 @@ void pass::VisualizeTree::render() const
stringstream ss; stringstream ss;
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT"); ss << "dot -T" << get_file_ext() << " " << tmp_file << " -o " << m_name;
if (!format)
{
format = "png";
}
ss << "dot -T" << format << " " << tmp_file << " -o " << m_name;
auto cmd = ss.str(); auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r"); auto stream = popen(cmd.c_str(), "r");
pclose(stream); pclose(stream);
......
...@@ -36,6 +36,8 @@ public: ...@@ -36,6 +36,8 @@ public:
VisualizeTree(const std::string& file_name); VisualizeTree(const std::string& file_name);
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override; bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
static std::string get_file_ext();
private: private:
std::string add_attributes(std::shared_ptr<Node> node); std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node); std::string get_attributes(std::shared_ptr<Node> node);
......
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#include <fstream> #include <fstream>
#include <ngraph/file_util.hpp> #include <ngraph/file_util.hpp>
#include <ngraph/file_util.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <ngraph/runtime/backend.hpp> #include <ngraph/runtime/backend.hpp>
#include <ngraph/runtime/call_frame.hpp> #include <ngraph/runtime/call_frame.hpp>
#include <ngraph/runtime/manager.hpp> #include <ngraph/runtime/manager.hpp>
...@@ -41,6 +44,7 @@ int main(int argc, char** argv) ...@@ -41,6 +44,7 @@ int main(int argc, char** argv)
bool failed = false; bool failed = false;
bool statistics = false; bool statistics = false;
bool timing_detail = false; bool timing_detail = false;
bool visualize = false;
for (size_t i = 1; i < argc; i++) for (size_t i = 1; i < argc; i++)
{ {
string arg = argv[i]; string arg = argv[i];
...@@ -72,6 +76,10 @@ int main(int argc, char** argv) ...@@ -72,6 +76,10 @@ int main(int argc, char** argv)
{ {
timing_detail = true; timing_detail = true;
} }
else if (arg == "-v" || arg == "--visualize")
{
visualize = true;
}
else else
{ {
cout << "Unknown option: " << arg << endl; cout << "Unknown option: " << arg << endl;
...@@ -98,6 +106,7 @@ OPTIONS ...@@ -98,6 +106,7 @@ OPTIONS
-b|--backend Backend to use (default: CPU) -b|--backend Backend to use (default: CPU)
-i|--iterations Iterations (default: 10) -i|--iterations Iterations (default: 10)
-s|--statistics Display op stastics -s|--statistics Display op stastics
-v|--visualize Visualize a model (WARNING: requires GraphViz installed)
--timing_detail Gather detailed timing --timing_detail Gather detailed timing
)###"; )###";
return 1; return 1;
...@@ -106,6 +115,17 @@ OPTIONS ...@@ -106,6 +115,17 @@ OPTIONS
const string json_string = file_util::read_file_to_string(model); const string json_string = file_util::read_file_to_string(model);
stringstream ss(json_string); stringstream ss(json_string);
shared_ptr<Function> f = deserialize(ss); shared_ptr<Function> f = deserialize(ss);
if (visualize)
{
auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") +
pass::VisualizeTree::get_file_ext();
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(model_file_name);
pass_manager.run_passes(f);
}
if (statistics) if (statistics)
{ {
cout << "statistics:" << endl; cout << "statistics:" << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment