Unverified Commit 1974aed1 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Test tools - small refactoring (#1471)

parent 0804f5e2
...@@ -96,13 +96,13 @@ size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f) ...@@ -96,13 +96,13 @@ size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
} }
template <typename T> template <typename T>
std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f, std::vector<std::vector<T>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args, std::vector<std::vector<T>> args,
std::string cbackend) const std::string& backend_id)
{ {
auto backend = ngraph::runtime::Backend::create(cbackend); auto backend = ngraph::runtime::Backend::create(backend_id);
auto parms = f->get_parameters(); auto parms = function->get_parameters();
if (parms.size() != args.size()) if (parms.size() != args.size())
{ {
...@@ -117,7 +117,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f, ...@@ -117,7 +117,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f,
arg_tensors.at(i) = t; arg_tensors.at(i) = t;
} }
auto results = f->get_results(); auto results = function->get_results();
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> result_tensors(results.size()); std::vector<std::shared_ptr<ngraph::runtime::TensorView>> result_tensors(results.size());
for (size_t i = 0; i < results.size(); i++) for (size_t i = 0; i < results.size(); i++)
...@@ -126,7 +126,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f, ...@@ -126,7 +126,7 @@ std::vector<std::vector<T>> execute(std::shared_ptr<ngraph::Function> f,
backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape()); backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
} }
backend->call_with_validate(f, result_tensors, arg_tensors); backend->call_with_validate(function, result_tensors, arg_tensors);
std::vector<std::vector<T>> result_vectors; std::vector<std::vector<T>> result_vectors;
for (auto rt : result_tensors) for (auto rt : result_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment