Commit 19bdb2ff authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

Jmenon/nbench (#1365)

* Add an option to exclude the first iteration

* Switch to warmup iterations

* Cleanup
parent ad7e1ea1
...@@ -83,7 +83,8 @@ multimap<size_t, string> aggregate_timing(const vector<runtime::PerformanceCount ...@@ -83,7 +83,8 @@ multimap<size_t, string> aggregate_timing(const vector<runtime::PerformanceCount
void run_benchmark(const string& json_path, void run_benchmark(const string& json_path,
const string& backend_name, const string& backend_name,
size_t iterations, size_t iterations,
bool timing_detail) bool timing_detail,
int warmup_iterations)
{ {
stopwatch timer; stopwatch timer;
timer.start(); timer.start();
...@@ -92,7 +93,7 @@ void run_benchmark(const string& json_path, ...@@ -92,7 +93,7 @@ void run_benchmark(const string& json_path,
shared_ptr<Function> f = deserialize(ss); shared_ptr<Function> f = deserialize(ss);
timer.stop(); timer.stop();
cout << "deserialize time: " << timer.get_milliseconds() << "ms" << endl; cout << "deserialize time: " << timer.get_milliseconds() << "ms" << endl;
run_benchmark(f, backend_name, iterations, timing_detail); run_benchmark(f, backend_name, iterations, timing_detail, warmup_iterations);
} }
void print_times(const multimap<size_t, string>& timing) void print_times(const multimap<size_t, string>& timing)
...@@ -238,7 +239,8 @@ static void random_init(shared_ptr<runtime::TensorView> tv) ...@@ -238,7 +239,8 @@ static void random_init(shared_ptr<runtime::TensorView> tv)
void run_benchmark(shared_ptr<Function> f, void run_benchmark(shared_ptr<Function> f,
const string& backend_name, const string& backend_name,
size_t iterations, size_t iterations,
bool timing_detail) bool timing_detail,
int warmup_iterations)
{ {
stopwatch timer; stopwatch timer;
timer.start(); timer.start();
...@@ -272,6 +274,15 @@ void run_benchmark(shared_ptr<Function> f, ...@@ -272,6 +274,15 @@ void run_benchmark(shared_ptr<Function> f,
args[i]->set_stale(false); args[i]->set_stale(false);
} }
} }
if (warmup_iterations)
{
for (int i = 0; i < warmup_iterations; i++)
{
backend->call(f, results, args);
}
}
stopwatch t1; stopwatch t1;
t1.start(); t1.start();
for (size_t i = 0; i < static_cast<size_t>(iterations); i++) for (size_t i = 0; i < static_cast<size_t>(iterations); i++)
......
...@@ -31,9 +31,11 @@ std::multimap<size_t, std::string> ...@@ -31,9 +31,11 @@ std::multimap<size_t, std::string>
void run_benchmark(std::shared_ptr<ngraph::Function> f, void run_benchmark(std::shared_ptr<ngraph::Function> f,
const std::string& backend_name, const std::string& backend_name,
size_t iterations, size_t iterations,
bool timing_detail); bool timing_detail,
int warmup_iterations);
void run_benchmark(const std::string& json_path, void run_benchmark(const std::string& json_path,
const std::string& backend_name, const std::string& backend_name,
size_t iterations, size_t iterations,
bool timing_detail = false); bool timing_detail = false,
int warmup_iterations = 1);
...@@ -43,6 +43,8 @@ int main(int argc, char** argv) ...@@ -43,6 +43,8 @@ int main(int argc, char** argv)
bool statistics = false; bool statistics = false;
bool timing_detail = false; bool timing_detail = false;
bool visualize = false; bool visualize = false;
int warmup_iterations = 1;
for (size_t i = 1; i < argc; i++) for (size_t i = 1; i < argc; i++)
{ {
string arg = argv[i]; string arg = argv[i];
...@@ -82,6 +84,18 @@ int main(int argc, char** argv) ...@@ -82,6 +84,18 @@ int main(int argc, char** argv)
{ {
directory = argv[++i]; directory = argv[++i];
} }
else if (arg == "-w" || arg == "--warmup_iterations")
{
try
{
warmup_iterations = stoi(argv[++i]);
}
catch (...)
{
cout << "Invalid Argument\n";
failed = true;
}
}
else else
{ {
cout << "Unknown option: " << arg << endl; cout << "Unknown option: " << arg << endl;
...@@ -121,6 +135,7 @@ OPTIONS ...@@ -121,6 +135,7 @@ OPTIONS
-s|--statistics Display op stastics -s|--statistics Display op stastics
-v|--visualize Visualize a model (WARNING: requires GraphViz installed) -v|--visualize Visualize a model (WARNING: requires GraphViz installed)
--timing-detail Gather detailed timing --timing-detail Gather detailed timing
-w|--warmup_iterations Number of warm-up iterations
)###"; )###";
return 1; return 1;
} }
...@@ -190,7 +205,7 @@ OPTIONS ...@@ -190,7 +205,7 @@ OPTIONS
shared_ptr<Function> f = deserialize(m); shared_ptr<Function> f = deserialize(m);
cout << "Benchmarking " << m << ", " << backend << " backend, " << iterations cout << "Benchmarking " << m << ", " << backend << " backend, " << iterations
<< " iterations.\n"; << " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail); run_benchmark(f, backend, iterations, timing_detail, warmup_iterations);
} }
catch (exception e) catch (exception e)
{ {
...@@ -203,7 +218,7 @@ OPTIONS ...@@ -203,7 +218,7 @@ OPTIONS
shared_ptr<Function> f = deserialize(model); shared_ptr<Function> f = deserialize(model);
cout << "Benchmarking " << model << ", " << backend << " backend, " << iterations cout << "Benchmarking " << model << ", " << backend << " backend, " << iterations
<< " iterations.\n"; << " iterations.\n";
run_benchmark(f, backend, iterations, timing_detail); run_benchmark(f, backend, iterations, timing_detail, warmup_iterations);
} }
return 0; return 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment