Commit ecbe0042 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[ONNX] Refactor LSTM tests to use NgraphTestCase (#2827)

* [ONNX] Refactor LSTM tests to use NgraphTestCase

* Enable passing instance values to comparator

* Review comments
parent a6df47e3
......@@ -332,6 +332,7 @@ namespace ngraph
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
#undef REGISTER_OPERATOR
} // namespace onnx_import
} // namespace ngraph
This diff is collapsed.
......@@ -19,19 +19,6 @@
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
std::map<ngraph::element::Type_t, ngraph::test::NgraphTestCase::value_comparator_function>
ngraph::test::NgraphTestCase::m_value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void ngraph::test::NgraphTestCase::run()
{
const auto& function_results = m_function->get_results();
......
......@@ -38,6 +38,7 @@ namespace ngraph
{
}
void set_tolerance(int tolerance_bits) { m_tolerance_bits = tolerance_bits; }
template <typename T>
void add_input(const std::vector<T>& values)
{
......@@ -122,21 +123,20 @@ namespace ngraph
}
void run();
protected:
private:
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
typename std::enable_if<std::is_floating_point<T>::value,
::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close_f(expected, result);
return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value,
::testing::AssertionResult>::type
typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results)
{
......@@ -149,6 +149,33 @@ namespace ngraph
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
#define REGISTER_COMPARATOR(element_type_, type_) \
{ \
ngraph::element::Type_t::element_type_, std::bind(&NgraphTestCase::compare_values<type_>, \
this, \
std::placeholders::_1, \
std::placeholders::_2) \
}
std::map<ngraph::element::Type_t,
std::function<::testing::AssertionResult(
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>>
m_value_comparators = {
REGISTER_COMPARATOR(f32, float),
REGISTER_COMPARATOR(f64, double),
REGISTER_COMPARATOR(i8, int8_t),
REGISTER_COMPARATOR(i16, int16_t),
REGISTER_COMPARATOR(i32, int32_t),
REGISTER_COMPARATOR(i64, int64_t),
REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(u64, uint64_t),
};
#undef REGISTER_COMPARATOR
protected:
std::shared_ptr<Function> m_function;
std::shared_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
......@@ -156,7 +183,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0;
int m_output_index = 0;
static std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators;
int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment