Commit 5fbf30e3 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

NgraphTestCase - dump expected and computed data to the console (#2893)

* Dump the expected and actual values for NgraphTestCase

* Adapt to changes in master

* Some docs and API unification
parent fffbaa89
...@@ -51,3 +51,15 @@ void ngraph::test::NgraphTestCase::run() ...@@ -51,3 +51,15 @@ void ngraph::test::NgraphTestCase::run()
} }
} }
} }
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
return *this;
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
{
m_dump_results = dump;
return *this;
}
...@@ -38,7 +38,16 @@ namespace ngraph ...@@ -38,7 +38,16 @@ namespace ngraph
{ {
} }
void set_tolerance(int tolerance_bits) { m_tolerance_bits = tolerance_bits; } NgraphTestCase& set_tolerance(int tolerance_bits);
/// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
///
/// Just before the assertion is done, the current test case will gather expected and computed values,
/// format them as 2 columns and print out to the console along with a corresponding index in the vector.
///
/// \param dump - Indicates if the test case should perform the console printout
NgraphTestCase& dump_results(bool dump = true);
template <typename T> template <typename T>
void add_input(const std::vector<T>& values) void add_input(const std::vector<T>& values)
{ {
...@@ -121,6 +130,7 @@ namespace ngraph ...@@ -121,6 +130,7 @@ namespace ngraph
auto value = read_binary_file<T>(filepath); auto value = read_binary_file<T>(filepath);
add_expected_output(expected_shape, value); add_expected_output(expected_shape, value);
} }
void run(); void run();
private: private:
...@@ -132,6 +142,12 @@ namespace ngraph ...@@ -132,6 +142,12 @@ namespace ngraph
{ {
const auto expected = expected_results->get_vector<T>(); const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results); const auto result = read_vector<T>(results);
if (m_dump_results)
{
std::cout << get_results_str<T>(expected, result, expected.size());
}
return ngraph::test::all_close_f(expected, result, m_tolerance_bits); return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
} }
...@@ -142,6 +158,12 @@ namespace ngraph ...@@ -142,6 +158,12 @@ namespace ngraph
{ {
const auto expected = expected_results->get_vector<T>(); const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results); const auto result = read_vector<T>(results);
if (m_dump_results)
{
std::cout << get_results_str<T>(expected, result, expected.size());
}
return ngraph::test::all_close(expected, result); return ngraph::test::all_close(expected, result);
} }
...@@ -157,21 +179,17 @@ namespace ngraph ...@@ -157,21 +179,17 @@ namespace ngraph
std::placeholders::_2) \ std::placeholders::_2) \
} }
std::map<ngraph::element::Type_t, std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators = {
std::function<::testing::AssertionResult( REGISTER_COMPARATOR(f32, float),
const std::shared_ptr<ngraph::op::Constant>&, REGISTER_COMPARATOR(f64, double),
const std::shared_ptr<ngraph::runtime::Tensor>&)>> REGISTER_COMPARATOR(i8, int8_t),
m_value_comparators = { REGISTER_COMPARATOR(i16, int16_t),
REGISTER_COMPARATOR(f32, float), REGISTER_COMPARATOR(i32, int32_t),
REGISTER_COMPARATOR(f64, double), REGISTER_COMPARATOR(i64, int64_t),
REGISTER_COMPARATOR(i8, int8_t), REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(i16, int16_t), REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(i32, int32_t), REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(i64, int64_t), REGISTER_COMPARATOR(u64, uint64_t),
REGISTER_COMPARATOR(u8, uint8_t),
REGISTER_COMPARATOR(u16, uint16_t),
REGISTER_COMPARATOR(u32, uint32_t),
REGISTER_COMPARATOR(u64, uint64_t),
}; };
#undef REGISTER_COMPARATOR #undef REGISTER_COMPARATOR
...@@ -183,6 +201,7 @@ namespace ngraph ...@@ -183,6 +201,7 @@ namespace ngraph
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs; std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
int m_input_index = 0; int m_input_index = 0;
int m_output_index = 0; int m_output_index = 0;
bool m_dump_results = false;
int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS; int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
}; };
} }
......
...@@ -284,8 +284,9 @@ void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine ...@@ -284,8 +284,9 @@ void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine
} }
template <> template <>
string string get_results_str(const std::vector<char>& ref_data,
get_results_str(std::vector<char>& ref_data, std::vector<char>& actual_data, size_t max_results) const std::vector<char>& actual_data,
size_t max_results)
{ {
stringstream ss; stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size()); size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
......
...@@ -222,8 +222,9 @@ std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>& ...@@ -222,8 +222,9 @@ std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>&
} }
template <typename T> template <typename T>
std::string std::string get_results_str(const std::vector<T>& ref_data,
get_results_str(std::vector<T>& ref_data, std::vector<T>& actual_data, size_t max_results = 16) const std::vector<T>& actual_data,
size_t max_results = 16)
{ {
std::stringstream ss; std::stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size()); size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
...@@ -240,8 +241,8 @@ std::string ...@@ -240,8 +241,8 @@ std::string
} }
template <> template <>
std::string get_results_str(std::vector<char>& ref_data, std::string get_results_str(const std::vector<char>& ref_data,
std::vector<char>& actual_data, const std::vector<char>& actual_data,
size_t max_results); size_t max_results);
/// \brief Reads a binary file to a vector. /// \brief Reads a binary file to a vector.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment