Unverified Commit 2b26df18 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

nbench: add option to run all models in a directory (#1279)

* add option to run all models in a directory

* add print for exception from benchmark
parent 11b992a7
......@@ -37,6 +37,7 @@ int main(int argc, char** argv)
{
string model;
string backend = "CPU";
string directory;
int iterations = 10;
bool failed = false;
bool statistics = false;
......@@ -69,7 +70,7 @@ int main(int argc, char** argv)
{
statistics = true;
}
else if (arg == "--timing_detail")
else if (arg == "--timing_detail" || arg == "--timing-detail")
{
timing_detail = true;
}
......@@ -77,17 +78,31 @@ int main(int argc, char** argv)
{
visualize = true;
}
else if (arg == "-d" || arg == "--directory")
{
directory = argv[++i];
}
else
{
cout << "Unknown option: " << arg << endl;
failed = true;
}
}
if (!static_cast<bool>(ifstream(model)))
if (!model.empty() && !file_util::exists(model))
{
cout << "File " << model << " not found\n";
failed = true;
}
else if (!directory.empty() && !file_util::exists(directory))
{
cout << "Directory " << model << " not found\n";
failed = true;
}
else if (directory.empty() && model.empty())
{
cout << "Either file or directory must be specified\n";
failed = true;
}
if (failed)
{
......@@ -101,20 +116,18 @@ SYNOPSIS
OPTIONS
-f|--file Serialized model file
-b|--backend Backend to use (default: CPU)
-d|--directory Directory to scan for models. All models are benchmarked.
-i|--iterations Iterations (default: 10)
-s|--statistics Display op stastics
-v|--visualize Visualize a model (WARNING: requires GraphViz installed)
--timing_detail Gather detailed timing
--timing-detail Gather detailed timing
)###";
return 1;
}
const string json_string = file_util::read_file_to_string(model);
stringstream ss(json_string);
shared_ptr<Function> f = deserialize(ss);
if (visualize)
{
shared_ptr<Function> f = deserialize(model);
auto model_file_name = ngraph::file_util::get_file_name(model) + std::string(".") +
pass::VisualizeTree::get_file_ext();
......@@ -125,6 +138,8 @@ OPTIONS
if (statistics)
{
shared_ptr<Function> f = deserialize(model);
cout << "statistics:" << endl;
cout << "total nodes: " << f->get_ops().size() << endl;
size_t total_constant_bytes = 0;
......@@ -157,8 +172,35 @@ OPTIONS
cout << op_info.first << ": " << op_info.second << " ops" << endl;
}
}
else if (!directory.empty())
{
vector<string> models;
file_util::iterate_files(directory,
[&](const string& file, bool is_dir) {
if (!is_dir)
{
models.push_back(file);
}
},
true);
for (const string& m : models)
{
try
{
shared_ptr<Function> f = deserialize(m);
cout << "Benchmarking " << m << ", " << backend << " backend, " << iterations
<< " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail);
}
catch (exception e)
{
cout << "Exception caught on '" << m << "'\n" << e.what() << endl;
}
}
}
else if (iterations > 0)
{
shared_ptr<Function> f = deserialize(model);
cout << "Benchmarking " << model << ", " << backend << " backend, " << iterations
<< " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment