Commit d4df5695 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[ONNX] Extended types support for NgraphTestCase (#2801)

* [ONNX] Extended types support for NgraphTestCase

* [ONNX] Move the value comparators to the NgraphTestCase class
parent 490d6698
...@@ -15,16 +15,32 @@ ...@@ -15,16 +15,32 @@
//***************************************************************************** //*****************************************************************************
#include "test_case.hpp" #include "test_case.hpp"
#include "all_close.hpp"
#include "all_close_f.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "test_tools.hpp"
void ngraph::test::NgraphTestCase::run() namespace ngraph
{ {
namespace test
{
std::map<ngraph::element::Type_t, NgraphTestCase::value_comparator_function>
NgraphTestCase::value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void NgraphTestCase::run()
{
const auto& function_results = m_function->get_results(); const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(), NGRAPH_CHECK(
m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results."); "Expected number of outputs is different from the function's number of results.");
auto handle = m_backend->compile(m_function); auto handle = m_backend->compile(m_function);
...@@ -36,22 +52,18 @@ void ngraph::test::NgraphTestCase::run() ...@@ -36,22 +52,18 @@ void ngraph::test::NgraphTestCase::run()
const auto& expected_result_constant = m_expected_outputs.at(i); const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type(); const auto& element_type = result_tensor->get_element_type();
if (element_type == ngraph::element::f32) if (value_comparators.count(element_type.get_type_enum()) == 0)
{
const auto result = read_vector<float>(result_tensor);
const auto expected = expected_result_constant->get_vector<float>();
EXPECT_TRUE(test::all_close_f(expected, result));
}
else if (element_type == ngraph::element::u8)
{ {
const auto result = read_vector<uint8_t>(result_tensor); NGRAPH_FAIL() << "Please add support for " << element_type
const auto expected = expected_result_constant->get_vector<uint8_t>(); << " to ngraph::test::NgraphTestCase::run()";
EXPECT_TRUE(test::all_close(expected, result));
} }
else else
{ {
NGRAPH_FAIL() << "Please add support for " << element_type auto values_match = value_comparators.at(element_type.get_type_enum());
<< " to ngraph::test::NgraphTestCase::run().";
EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
}
} }
} }
} }
...@@ -18,8 +18,11 @@ ...@@ -18,8 +18,11 @@
#include <utility> #include <utility>
#include "all_close.hpp"
#include "all_close_f.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "test_tools.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -83,6 +86,34 @@ namespace ngraph ...@@ -83,6 +86,34 @@ namespace ngraph
void run(); void run();
protected: protected:
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close_f(expected, result);
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close(expected, result);
}
using value_comparator_function = std::function<::testing::AssertionResult(
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
static std::map<ngraph::element::Type_t, value_comparator_function> value_comparators;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend; std::unique_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment