Unverified Commit 1974aed1 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Test tools - small refactoring (#1471)

parent 0804f5e2
......@@ -96,13 +96,13 @@ size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
}
template <typename T>
std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f,
std::vector<std::vector<T>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
std::string cbackend)
const std::string& backend_id)
{
auto backend = ngraph::runtime::Backend::create(cbackend);
auto backend = ngraph::runtime::Backend::create(backend_id);
auto parms = f->get_parameters();
auto parms = function->get_parameters();
if (parms.size() != args.size())
{
......@@ -117,7 +117,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f,
arg_tensors.at(i) = t;
}
auto results = f->get_results();
auto results = function->get_results();
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> result_tensors(results.size());
for (size_t i = 0; i < results.size(); i++)
......@@ -126,7 +126,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f,
backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
}
backend->call_with_validate(f, result_tensors, arg_tensors);
backend->call_with_validate(function, result_tensors, arg_tensors);
std::vector<std::vector<T>> result_vectors;
for (auto rt : result_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment