Unverified Commit 52dce8bb authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3348 from NervanaSystems/bob/fused_tolerance

Allow override of default tolerance for NgraphTestCase
parents 45ad33ea 11d93be4
...@@ -1540,8 +1540,7 @@ NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose) ...@@ -1540,8 +1540,7 @@ NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose)
-0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f,
0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f,
-0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}); -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f});
test_case.set_tolerance(3); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
test_case.run();
} }
NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape) NGRAPH_TEST(${BACKEND_NAME}, group_conv_transpose_output_shape)
......
...@@ -104,8 +104,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip) ...@@ -104,8 +104,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
// We have to enlarge tolerance bits to 3 - it's only one bit more than default value. // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
// The discrepancies may occur at most on 7th decimal position. // The discrepancies may occur at most on 7th decimal position.
test_case.set_tolerance(3); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
...@@ -144,8 +143,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq) ...@@ -144,8 +143,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_mixed_seq)
// We have to enlarge tolerance bits to 3 - it's only one bit more than default value. // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
// The discrepancies may occur at most on 7th decimal position. // The discrepancies may occur at most on 7th decimal position.
test_case.set_tolerance(3); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
...@@ -201,8 +199,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation) ...@@ -201,8 +199,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f}); test_case.add_expected_output<float>(Shape{1, 1, 2}, {0.19017234f, 0.00356848f});
// The discrepancies occur at most at 18th mantissa bit - 8th decimal position. // The discrepancies occur at most at 18th mantissa bit - 8th decimal position.
test_case.set_tolerance(6); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
...@@ -307,8 +304,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq) ...@@ -307,8 +304,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_bdir_short_input_seq)
test_case.add_expected_output<float>(Shape{2, 1, 2}, test_case.add_expected_output<float>(Shape{2, 1, 2},
{-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f}); {-0.0251062f, 0.0561262f, -0.0318928f, 0.0762679f});
test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 3); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
...@@ -353,6 +349,5 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse) ...@@ -353,6 +349,5 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_mixed_seq_reverse)
Shape{1, 2, 3}, Shape{1, 2, 3},
{0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f}); {0.52497941f, 0.54983425f, 0.5744428f, 1.34960834f, 1.54772296f, 1.65633056f});
test_case.set_tolerance(DEFAULT_FLOAT_TOLERANCE_BITS + 1); test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
test_case.run();
} }
...@@ -19,8 +19,9 @@ ...@@ -19,8 +19,9 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
void ngraph::test::NgraphTestCase::run() void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
{ {
m_tolerance_bits = tolerance_bits;
const auto& function_results = m_function->get_results(); const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(), NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number of results."); "Expected number of outputs is different from the function's number of results.");
...@@ -52,12 +53,6 @@ void ngraph::test::NgraphTestCase::run() ...@@ -52,12 +53,6 @@ void ngraph::test::NgraphTestCase::run()
} }
} }
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::set_tolerance(int tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
return *this;
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump) ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
{ {
m_dump_results = dump; m_dump_results = dump;
......
...@@ -38,8 +38,6 @@ namespace ngraph ...@@ -38,8 +38,6 @@ namespace ngraph
{ {
} }
NgraphTestCase& set_tolerance(int tolerance_bits);
/// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes. /// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
/// ///
/// Just before the assertion is done, the current test case will gather expected and computed values, /// Just before the assertion is done, the current test case will gather expected and computed values,
...@@ -130,7 +128,7 @@ namespace ngraph ...@@ -130,7 +128,7 @@ namespace ngraph
add_expected_output(expected_shape, value); add_expected_output(expected_shape, value);
} }
void run(); void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
private: private:
template <typename T> template <typename T>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment