Commit d4df5695 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[ONNX] Extended types support for NgraphTestCase (#2801)

* [ONNX] Extended types support for NgraphTestCase

* [ONNX] Move the value comparators to the NgraphTestCase class
parent 490d6698
...@@ -15,43 +15,55 @@ ...@@ -15,43 +15,55 @@
//***************************************************************************** //*****************************************************************************
#include "test_case.hpp" #include "test_case.hpp"
#include "all_close.hpp"
#include "all_close_f.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "test_tools.hpp"
void ngraph::test::NgraphTestCase::run() namespace ngraph
{ {
const auto& function_results = m_function->get_results(); namespace test
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
auto handle = m_backend->compile(m_function);
handle->call_with_validate(m_result_tensors, m_input_tensors);
for (int i = 0; i < m_expected_outputs.size(); ++i)
{ {
const auto& result_tensor = m_result_tensors.at(i); std::map<ngraph::element::Type_t, NgraphTestCase::value_comparator_function>
const auto& expected_result_constant = m_expected_outputs.at(i); NgraphTestCase::value_comparators = {
const auto& element_type = result_tensor->get_element_type(); {ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
if (element_type == ngraph::element::f32) void NgraphTestCase::run()
{ {
const auto result = read_vector<float>(result_tensor); const auto& function_results = m_function->get_results();
const auto expected = expected_result_constant->get_vector<float>(); NGRAPH_CHECK(
EXPECT_TRUE(test::all_close_f(expected, result)); m_expected_outputs.size() == function_results.size(),
} "Expected number of outputs is different from the function's number of results.");
else if (element_type == ngraph::element::u8)
{ auto handle = m_backend->compile(m_function);
const auto result = read_vector<uint8_t>(result_tensor); handle->call_with_validate(m_result_tensors, m_input_tensors);
const auto expected = expected_result_constant->get_vector<uint8_t>();
EXPECT_TRUE(test::all_close(expected, result)); for (int i = 0; i < m_expected_outputs.size(); ++i)
} {
else const auto& result_tensor = m_result_tensors.at(i);
{ const auto& expected_result_constant = m_expected_outputs.at(i);
NGRAPH_FAIL() << "Please add support for " << element_type const auto& element_type = result_tensor->get_element_type();
<< " to ngraph::test::NgraphTestCase::run().";
if (value_comparators.count(element_type.get_type_enum()) == 0)
{
NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run()";
}
else
{
auto values_match = value_comparators.at(element_type.get_type_enum());
EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
}
} }
} }
} }
...@@ -18,8 +18,11 @@ ...@@ -18,8 +18,11 @@
#include <utility> #include <utility>
#include "all_close.hpp"
#include "all_close_f.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "test_tools.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -83,6 +86,34 @@ namespace ngraph ...@@ -83,6 +86,34 @@ namespace ngraph
void run(); void run();
protected: protected:
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close_f(expected, result);
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close(expected, result);
}
using value_comparator_function = std::function<::testing::AssertionResult(
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
static std::map<ngraph::element::Type_t, value_comparator_function> value_comparators;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend; std::unique_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment