Unverified Commit 80af2c7f authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[Tests] Add NgraphTestCase test runner class (#2789)

parent e4d5355b
......@@ -30,6 +30,7 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
......@@ -41,7 +42,31 @@ using Inputs = std::vector<std::vector<float>>;
using Outputs = std::vector<std::vector<float>>;
// ############################################################################ CORE TESTS
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_output_names_check)
NGRAPH_TEST(onnx_${BACKEND_NAME}, test_test_case)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<float>{1});
test_case.add_input(std::vector<float>{2});
test_case.add_input(std::vector<float>{3});
test_case.add_expected_output(std::vector<float>{6});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, test_test_case_mutliple_inputs)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_multiple_inputs(Inputs{{1}, {2}, {3}});
test_case.add_expected_output(std::vector<float>{6});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, output_names_check)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.prototxt"));
......
......@@ -20,6 +20,7 @@ set (SRC
float_util.cpp
test_tools.cpp
test_control.cpp
test_case.cpp
)
add_library(ngraph_test_util STATIC ${SRC})
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "test_case.hpp"
#include "all_close.hpp"
#include "all_close_f.hpp"
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
#include "test_tools.hpp"
void ngraph::test::NgraphTestCase::run()
{
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
auto handle = m_backend->compile(m_function);
handle->call_with_validate(m_result_tensors, m_input_tensors);
for (int i = 0; i < m_expected_outputs.size(); ++i)
{
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
if (element_type == ngraph::element::f32)
{
const auto result = read_vector<float>(result_tensor);
const auto expected = expected_result_constant->get_vector<float>();
EXPECT_TRUE(test::all_close_f(expected, result));
}
else if (element_type == ngraph::element::u8)
{
const auto result = read_vector<uint8_t>(result_tensor);
const auto expected = expected_result_constant->get_vector<uint8_t>();
EXPECT_TRUE(test::all_close(expected, result));
}
else
{
NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run().";
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <utility>
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
namespace ngraph
{
namespace test
{
class NgraphTestCase
{
public:
NgraphTestCase(const std::shared_ptr<Function>& function,
const std::string& backend_name)
: m_function(function)
, m_backend(ngraph::runtime::Backend::create(backend_name))
{
}
template <typename T>
void add_input(const std::vector<T>& values)
{
auto params = m_function->get_parameters();
NGRAPH_CHECK(m_input_index < params.size(),
"All function parameters already have inputs.");
auto tensor = m_backend->create_tensor(params.at(m_input_index)->get_element_type(),
params.at(m_input_index)->get_shape());
copy_data(tensor, values);
m_input_tensors.push_back(tensor);
++m_input_index;
}
template <typename T>
void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
{
for (const auto& value : vector_of_values)
{
add_input(value);
}
}
template <typename T>
void add_expected_output(const std::vector<T>& values)
{
auto results = m_function->get_results();
NGRAPH_CHECK(m_output_index < results.size(),
"All function results already have expected outputs.");
auto function_output_type = results.at(m_output_index)->get_element_type();
auto function_output_shape = results.at(m_output_index)->get_shape();
m_result_tensors.emplace_back(
m_backend->create_tensor(function_output_type, function_output_shape));
m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
function_output_type, function_output_shape, values));
++m_output_index;
}
void run();
protected:
std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0;
int m_output_index = 0;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment