Commit 8da348f5 authored by Michał Karzyński's avatar Michał Karzyński Committed by Sang Ik Lee

[ONNX] Refactor quantized ops tests (#2798)

* Add NgraphTestCase test runner class

* Code improvements

* clang-format

* Add support for multiple output types

* Add reading inputs/outputs from files

* Refactor model_quant_conv_linear_3d

* Add shape checking to NgraphTestCase

* Add expected output shape to model_quant_conv_linear_3d

* Code review

* Remove small data files and move values to test code for legibility

* clang-format

* Refactor model_quant_conv_linear_2d

* Refactor model_dequantize_linear_1d_zero_scale_uint8_negative_axis

* Refactor model_dequantize_linear_1d_zero_scale_int8_4d

* Refactor model_dequantize_linear_1d_zero_scale_int8

* Refactor model_dequantize_linear_1d_zero_scale_uint8

* Refactor model_dequantize_linear_scalar_zero_scale_int8

* Refactor model_dequantize_linear_scalar_zero_scale_uint8

* Refactor model_quantize_linear_axis_negative

* Refactor model_quantize_linear_axis_zero

* Refactor model_quantize_linear_zero_point

* Add shape checking to NgraphTestCase::run

* Refactor NgraphTestCase::run
parent d9dfd6c4
--broken encoding: IBM424_ltr
\ No newline at end of file
This diff is collapsed.
...@@ -19,51 +19,48 @@ ...@@ -19,51 +19,48 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
namespace ngraph std::map<ngraph::element::Type_t, ngraph::test::NgraphTestCase::value_comparator_function>
ngraph::test::NgraphTestCase::m_value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void ngraph::test::NgraphTestCase::run()
{ {
namespace test const auto& function_results = m_function->get_results();
{ NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
std::map<ngraph::element::Type_t, NgraphTestCase::value_comparator_function> "Expected number of outputs is different from the function's number of results.");
NgraphTestCase::value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void NgraphTestCase::run() auto handle = m_backend->compile(m_function);
{ handle->call_with_validate(m_result_tensors, m_input_tensors);
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(
m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
auto handle = m_backend->compile(m_function); for (int i = 0; i < m_expected_outputs.size(); ++i)
handle->call_with_validate(m_result_tensors, m_input_tensors); {
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
for (int i = 0; i < m_expected_outputs.size(); ++i) auto expected_shape = expected_result_constant->get_shape();
{ auto result_shape = result_tensor->get_shape();
const auto& result_tensor = m_result_tensors.at(i); EXPECT_EQ(expected_shape, result_shape);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
if (value_comparators.count(element_type.get_type_enum()) == 0) if (m_value_comparators.count(element_type.get_type_enum()) == 0)
{ {
NGRAPH_FAIL() << "Please add support for " << element_type NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run()"; << " to ngraph::test::NgraphTestCase::run()";
} }
else else
{ {
auto values_match = value_comparators.at(element_type.get_type_enum()); auto values_match = m_value_comparators.at(element_type.get_type_enum());
EXPECT_TRUE(values_match(expected_result_constant, result_tensor)); EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
}
} }
} }
} }
...@@ -55,6 +55,20 @@ namespace ngraph ...@@ -55,6 +55,20 @@ namespace ngraph
++m_input_index; ++m_input_index;
} }
template <typename T>
void add_input_from_file(const std::string& basepath, const std::string& filename)
{
auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(filepath);
}
template <typename T>
void add_input_from_file(const std::string& filepath)
{
auto value = read_binary_file<T>(filepath);
add_input(value);
}
template <typename T> template <typename T>
void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values) void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
{ {
...@@ -65,7 +79,7 @@ namespace ngraph ...@@ -65,7 +79,7 @@ namespace ngraph
} }
template <typename T> template <typename T>
void add_expected_output(const std::vector<T>& values) void add_expected_output(ngraph::Shape expected_shape, const std::vector<T>& values)
{ {
auto results = m_function->get_results(); auto results = m_function->get_results();
...@@ -78,11 +92,34 @@ namespace ngraph ...@@ -78,11 +92,34 @@ namespace ngraph
m_backend->create_tensor(function_output_type, function_output_shape)); m_backend->create_tensor(function_output_type, function_output_shape));
m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>( m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
function_output_type, function_output_shape, values)); function_output_type, expected_shape, values));
++m_output_index; ++m_output_index;
} }
template <typename T>
void add_expected_output(const std::vector<T>& values)
{
auto shape = m_function->get_results().at(m_output_index)->get_shape();
add_expected_output(shape, values);
}
template <typename T>
void add_expected_output_from_file(ngraph::Shape expected_shape,
const std::string& basepath,
const std::string& filename)
{
auto filepath = ngraph::file_util::path_join(basepath, filename);
add_expected_output_from_file<T>(expected_shape, filepath);
}
template <typename T>
void add_expected_output_from_file(ngraph::Shape expected_shape,
const std::string& filepath)
{
auto value = read_binary_file<T>(filepath);
add_expected_output(expected_shape, value);
}
void run(); void run();
protected: protected:
...@@ -112,8 +149,6 @@ namespace ngraph ...@@ -112,8 +149,6 @@ namespace ngraph
const std::shared_ptr<ngraph::op::Constant>&, const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>; const std::shared_ptr<ngraph::runtime::Tensor>&)>;
static std::map<ngraph::element::Type_t, value_comparator_function> value_comparators;
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend; std::unique_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
...@@ -121,6 +156,7 @@ namespace ngraph ...@@ -121,6 +156,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs; std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0; int m_input_index = 0;
int m_output_index = 0; int m_output_index = 0;
static std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators;
}; };
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment