Unverified Commit 2b26df18 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

nbench: add option to run all models in a directory (#1279)

* add option to run all models in a directory

* add print for exception from benchmark
parent 11b992a7
...@@ -37,6 +37,7 @@ int main(int argc, char** argv) ...@@ -37,6 +37,7 @@ int main(int argc, char** argv)
{ {
string model; string model;
string backend = "CPU"; string backend = "CPU";
string directory;
int iterations = 10; int iterations = 10;
bool failed = false; bool failed = false;
bool statistics = false; bool statistics = false;
...@@ -69,7 +70,7 @@ int main(int argc, char** argv) ...@@ -69,7 +70,7 @@ int main(int argc, char** argv)
{ {
statistics = true; statistics = true;
} }
else if (arg == "--timing_detail") else if (arg == "--timing_detail" || arg == "--timing-detail")
{ {
timing_detail = true; timing_detail = true;
} }
...@@ -77,17 +78,31 @@ int main(int argc, char** argv) ...@@ -77,17 +78,31 @@ int main(int argc, char** argv)
{ {
visualize = true; visualize = true;
} }
else if (arg == "-d" || arg == "--directory")
{
directory = argv[++i];
}
else else
{ {
cout << "Unknown option: " << arg << endl; cout << "Unknown option: " << arg << endl;
failed = true; failed = true;
} }
} }
if (!static_cast<bool>(ifstream(model))) if (!model.empty() && !file_util::exists(model))
{ {
cout << "File " << model << " not found\n"; cout << "File " << model << " not found\n";
failed = true; failed = true;
} }
else if (!directory.empty() && !file_util::exists(directory))
{
cout << "Directory " << model << " not found\n";
failed = true;
}
else if (directory.empty() && model.empty())
{
cout << "Either file or directory must be specified\n";
failed = true;
}
if (failed) if (failed)
{ {
...@@ -101,20 +116,18 @@ SYNOPSIS ...@@ -101,20 +116,18 @@ SYNOPSIS
OPTIONS OPTIONS
-f|--file Serialized model file -f|--file Serialized model file
-b|--backend Backend to use (default: CPU) -b|--backend Backend to use (default: CPU)
-d|--directory Directory to scan for models. All models are benchmarked.
-i|--iterations Iterations (default: 10) -i|--iterations Iterations (default: 10)
-s|--statistics Display op stastics -s|--statistics Display op stastics
-v|--visualize Visualize a model (WARNING: requires GraphViz installed) -v|--visualize Visualize a model (WARNING: requires GraphViz installed)
--timing_detail Gather detailed timing --timing-detail Gather detailed timing
)###"; )###";
return 1; return 1;
} }
const string json_string = file_util::read_file_to_string(model);
stringstream ss(json_string);
shared_ptr<Function> f = deserialize(ss);
if (visualize) if (visualize)
{ {
shared_ptr<Function> f = deserialize(model);
auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") + auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") +
pass::VisualizeTree::get_file_ext(); pass::VisualizeTree::get_file_ext();
...@@ -125,6 +138,8 @@ OPTIONS ...@@ -125,6 +138,8 @@ OPTIONS
if (statistics) if (statistics)
{ {
shared_ptr<Function> f = deserialize(model);
cout << "statistics:" << endl; cout << "statistics:" << endl;
cout << "total nodes: " << f->get_ops().size() << endl; cout << "total nodes: " << f->get_ops().size() << endl;
size_t total_constant_bytes = 0; size_t total_constant_bytes = 0;
...@@ -157,8 +172,35 @@ OPTIONS ...@@ -157,8 +172,35 @@ OPTIONS
cout << op_info.first << ": " << op_info.second << " ops" << endl; cout << op_info.first << ": " << op_info.second << " ops" << endl;
} }
} }
else if (!directory.empty())
{
vector<string> models;
file_util::iterate_files(directory,
[&](const string& file, bool is_dir) {
if (!is_dir)
{
models.push_back(file);
}
},
true);
for (const string& m : models)
{
try
{
shared_ptr<Function> f = deserialize(m);
cout << "Benchmarking " << m << ", " << backend << " backend, " << iterations
<< " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail);
}
catch (exception e)
{
cout << "Exception caught on '" << m << "'\n" << e.what() << endl;
}
}
}
else if (iterations > 0) else if (iterations > 0)
{ {
shared_ptr<Function> f = deserialize(model);
cout << "Benchmarking " << model << ", " << backend << " backend, " << iterations cout << "Benchmarking " << model << ", " << backend << " backend, " << iterations
<< " iterations.\n"; << " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail); run_benchmark(f, backend, iterations, timing_detail);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment