Commit 8da348f5 authored by Michał Karzyński's avatar Michał Karzyński Committed by Sang Ik Lee

[ONNX] Refactor quantized ops tests (#2798)

* Add NgraphTestCase test runner class

* Code improvements

* clang-format

* Add support for multiple output types

* Add reading inputs/outputs from files

* Refactor model_quant_conv_linear_3d

* Add shape checking to NgraphTestCase

* Add expected output shape to model_quant_conv_linear_3d

* Code review

* Remove small data files and move values to test code for legibility

* clang-format

* Refactor model_quant_conv_linear_2d

* Refactor model_dequantize_linear_1d_zero_scale_uint8_negative_axis

* Refactor model_dequantize_linear_1d_zero_scale_int8_4d

* Refactor model_dequantize_linear_1d_zero_scale_int8

* Refactor model_dequantize_linear_1d_zero_scale_uint8

* Refactor model_dequantize_linear_scalar_zero_scale_int8

* Refactor model_dequantize_linear_scalar_zero_scale_uint8

* Refactor model_quantize_linear_axis_negative

* Refactor model_quantize_linear_axis_zero

* Refactor model_quantize_linear_zero_point

* Add shape checking to NgraphTestCase::run

* Refactor NgraphTestCase::run
parent d9dfd6c4
--broken encoding: IBM424_ltr
\ No newline at end of file
This diff is collapsed.
......@@ -19,51 +19,48 @@
#include "gtest/gtest.h"
#include "ngraph/assertion.hpp"
namespace ngraph
std::map<ngraph::element::Type_t, ngraph::test::NgraphTestCase::value_comparator_function>
ngraph::test::NgraphTestCase::m_value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
void ngraph::test::NgraphTestCase::run()
{
namespace test
{
std::map<ngraph::element::Type_t, NgraphTestCase::value_comparator_function>
NgraphTestCase::value_comparators = {
{ngraph::element::Type_t::f32, NgraphTestCase::compare_values<float>},
{ngraph::element::Type_t::f64, NgraphTestCase::compare_values<double>},
{ngraph::element::Type_t::i8, NgraphTestCase::compare_values<int8_t>},
{ngraph::element::Type_t::i16, NgraphTestCase::compare_values<int16_t>},
{ngraph::element::Type_t::i32, NgraphTestCase::compare_values<int32_t>},
{ngraph::element::Type_t::i64, NgraphTestCase::compare_values<int64_t>},
{ngraph::element::Type_t::u8, NgraphTestCase::compare_values<uint8_t>},
{ngraph::element::Type_t::u16, NgraphTestCase::compare_values<uint16_t>},
{ngraph::element::Type_t::u32, NgraphTestCase::compare_values<uint32_t>},
{ngraph::element::Type_t::u64, NgraphTestCase::compare_values<uint64_t>}};
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
void NgraphTestCase::run()
{
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(
m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results.");
auto handle = m_backend->compile(m_function);
handle->call_with_validate(m_result_tensors, m_input_tensors);
auto handle = m_backend->compile(m_function);
handle->call_with_validate(m_result_tensors, m_input_tensors);
for (int i = 0; i < m_expected_outputs.size(); ++i)
{
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
for (int i = 0; i < m_expected_outputs.size(); ++i)
{
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
auto expected_shape = expected_result_constant->get_shape();
auto result_shape = result_tensor->get_shape();
EXPECT_EQ(expected_shape, result_shape);
if (value_comparators.count(element_type.get_type_enum()) == 0)
{
NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run()";
}
else
{
auto values_match = value_comparators.at(element_type.get_type_enum());
if (m_value_comparators.count(element_type.get_type_enum()) == 0)
{
NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run()";
}
else
{
auto values_match = m_value_comparators.at(element_type.get_type_enum());
EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
}
EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
}
}
......@@ -55,6 +55,20 @@ namespace ngraph
++m_input_index;
}
template <typename T>
void add_input_from_file(const std::string& basepath, const std::string& filename)
{
auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(filepath);
}
template <typename T>
void add_input_from_file(const std::string& filepath)
{
auto value = read_binary_file<T>(filepath);
add_input(value);
}
template <typename T>
void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
{
......@@ -65,7 +79,7 @@ namespace ngraph
}
template <typename T>
void add_expected_output(const std::vector<T>& values)
void add_expected_output(ngraph::Shape expected_shape, const std::vector<T>& values)
{
auto results = m_function->get_results();
......@@ -78,11 +92,34 @@ namespace ngraph
m_backend->create_tensor(function_output_type, function_output_shape));
m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
function_output_type, function_output_shape, values));
function_output_type, expected_shape, values));
++m_output_index;
}
template <typename T>
void add_expected_output(const std::vector<T>& values)
{
auto shape = m_function->get_results().at(m_output_index)->get_shape();
add_expected_output(shape, values);
}
template <typename T>
void add_expected_output_from_file(ngraph::Shape expected_shape,
const std::string& basepath,
const std::string& filename)
{
auto filepath = ngraph::file_util::path_join(basepath, filename);
add_expected_output_from_file<T>(expected_shape, filepath);
}
template <typename T>
void add_expected_output_from_file(ngraph::Shape expected_shape,
const std::string& filepath)
{
auto value = read_binary_file<T>(filepath);
add_expected_output(expected_shape, value);
}
void run();
protected:
......@@ -112,8 +149,6 @@ namespace ngraph
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>;
static std::map<ngraph::element::Type_t, value_comparator_function> value_comparators;
std::shared_ptr<Function> m_function;
std::unique_ptr<runtime::Backend> m_backend;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
......@@ -121,6 +156,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0;
int m_output_index = 0;
static std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment