Commit 5fbf30e3 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

NgraphTestCase - dump expected and computed data to the console (#2893)

* Dump the expected and actual values for NgraphTestCase

* Adapt to changes in master

* Some docs and API unification
parent fffbaa89
......@@ -51,3 +51,15 @@ void ngraph::test::NgraphTestCase::run()
}
}
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
return *this;
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
{
m_dump_results = dump;
return *this;
}
......@@ -38,7 +38,16 @@ namespace ngraph
{
}
void set_tolerance(int tolerance_bits) { m_tolerance_bits = tolerance_bits; }
NgraphTestCase& set_tolerance(int tolerance_bits);
/// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
///
/// Just before the assertion is done, the current test case will gather expected and computed values,
/// format them as 2 columns and print out to the console along with a corresponding index in the vector.
///
/// \param dump - Indicates if the test case should perform the console printout
NgraphTestCase& dump_results(bool dump = true);
template <typename T>
void add_input(const std::vector<T>& values)
{
......@@ -121,6 +130,7 @@ namespace ngraph
auto value = read_binary_file<T>(filepath);
add_expected_output(expected_shape, value);
}
void run();
private:
......@@ -132,6 +142,12 @@ namespace ngraph
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
if (m_dump_results)
{
std::cout << get_results_str<T>(expected, result, expected.size());
}
return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
}
......@@ -142,6 +158,12 @@ namespace ngraph
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
if (m_dump_results)
{
std::cout << get_results_str<T>(expected, result, expected.size());
}
return ngraph::test::all_close(expected, result);
}
......@@ -157,21 +179,17 @@ namespace ngraph
std::placeholders::_2) \
}
std::map<ngraph::element::Type_t,
std::function<::testing::AssertionResult(
const std::shared_ptr<ngraph::op::Constant>&,
const std::shared_ptr<ngraph::runtime::Tensor>&)>>
m_value_comparators = {
REGISTER_COMPARATOR(f32, float),
REGISTER_COMPARATOR(f64, double),
REGISTER_COMPARATOR(i8, int8_t),
REGISTER_COMPARATOR(i16, int16_t),
REGISTER_COMPARATOR(i32, int32_t),
REGISTER_COMPARATOR(i64, int64_t),
REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(u64, uint64_t),
std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators = {
REGISTER_COMPARATOR(f32, float),
REGISTER_COMPARATOR(f64, double),
REGISTER_COMPARATOR(i8, int8_t),
REGISTER_COMPARATOR(i16, int16_t),
REGISTER_COMPARATOR(i32, int32_t),
REGISTER_COMPARATOR(i64, int64_t),
REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(u64, uint64_t),
};
#undef REGISTER_COMPARATOR
......@@ -183,6 +201,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0;
int m_output_index = 0;
bool m_dump_results = false;
int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
};
}
......
......@@ -284,8 +284,9 @@ void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine
}
template <>
string
get_results_str(std::vector<char>& ref_data, std::vector<char>& actual_data, size_t max_results)
string get_results_str(const std::vector<char>& ref_data,
const std::vector<char>& actual_data,
size_t max_results)
{
stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
......
......@@ -222,8 +222,9 @@ std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>&
}
template <typename T>
std::string
get_results_str(std::vector<T>& ref_data, std::vector<T>& actual_data, size_t max_results = 16)
std::string get_results_str(const std::vector<T>& ref_data,
const std::vector<T>& actual_data,
size_t max_results = 16)
{
std::stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
......@@ -240,8 +241,8 @@ std::string
}
template <>
std::string get_results_str(std::vector<char>& ref_data,
std::vector<char>& actual_data,
std::string get_results_str(const std::vector<char>& ref_data,
const std::vector<char>& actual_data,
size_t max_results);
/// \brief Reads a binary file to a vector.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment