Commit 3d1335fe authored by Yimei Sun's avatar Yimei Sun Committed by Scott Cyphers

Update the tolerance on auto_broadcast_test (#3959)

parent c2bb6d99
...@@ -31,6 +31,15 @@ ...@@ -31,6 +31,15 @@
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS #ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS #define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif #endif
#ifndef RTOL
#define RTOL 1e-4
#endif
#ifndef ATOL
#define ATOL 1e-4
#endif
// clang-format on // clang-format on
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -50,7 +59,8 @@ template <typename optype, typename itype, typename otype> ...@@ -50,7 +59,8 @@ template <typename optype, typename itype, typename otype>
void check_auto_bcast( void check_auto_bcast(
const std::vector<std::vector<itype>>& inputs, const std::vector<std::vector<itype>>& inputs,
const std::vector<otype> output, const std::vector<otype> output,
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY)) const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY),
bool set_tolerance = false)
{ {
auto iet = element::from<itype>(); auto iet = element::from<itype>();
auto oet = element::from<otype>(); auto oet = element::from<otype>();
...@@ -79,7 +89,17 @@ void check_auto_bcast( ...@@ -79,7 +89,17 @@ void check_auto_bcast(
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b}); handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close(read_vector<otype>(result), output)); if (set_tolerance)
{
EXPECT_TRUE(test::all_close(read_vector<otype>(result),
output,
static_cast<otype>(RTOL),
static_cast<otype>(ATOL)));
}
else
{
EXPECT_TRUE(test::all_close(read_vector<otype>(result), output));
}
} }
NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise) NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise)
...@@ -96,7 +116,9 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise) ...@@ -96,7 +116,9 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise)
check_auto_bcast<op::Minimum, float, float>({{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, check_auto_bcast<op::Minimum, float, float>({{1, 2, 3, 4, 5, 6}, {1, 5, 8}},
{1, 2, 3, 1, 5, 6}); {1, 2, 3, 1, 5, 6});
check_auto_bcast<op::Power, float, float>({{1, 2, 3, 4, 5, 6}, {1, 2, 3}}, check_auto_bcast<op::Power, float, float>({{1, 2, 3, 4, 5, 6}, {1, 2, 3}},
{1, 4, 27, 4, 25, 216}); {1, 4, 27, 4, 25, 216},
op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY),
true);
check_auto_bcast<op::And, char, char>({{1, 0, 1, 0, 0, 1}, {1, 0, 1}}, {1, 0, 1, 0, 0, 1}); check_auto_bcast<op::And, char, char>({{1, 0, 1, 0, 0, 1}, {1, 0, 1}}, {1, 0, 1, 0, 0, 1});
check_auto_bcast<op::Or, char, char>({{1, 0, 1, 0, 1, 1}, {1, 0, 0}}, {1, 0, 1, 1, 1, 1}); check_auto_bcast<op::Or, char, char>({{1, 0, 1, 0, 1, 1}, {1, 0, 0}}, {1, 0, 1, 1, 1, 1});
...@@ -128,7 +150,7 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise_pdpd) ...@@ -128,7 +150,7 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise_pdpd)
check_auto_bcast<op::Minimum, float, float>( check_auto_bcast<op::Minimum, float, float>(
{{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 2, 3, 1, 5, 6}, autob); {{1, 2, 3, 4, 5, 6}, {1, 5, 8}}, {1, 2, 3, 1, 5, 6}, autob);
check_auto_bcast<op::Power, float, float>( check_auto_bcast<op::Power, float, float>(
{{1, 2, 3, 4, 5, 6}, {1, 2, 3}}, {1, 4, 27, 4, 25, 216}, autob); {{1, 2, 3, 4, 5, 6}, {1, 2, 3}}, {1, 4, 27, 4, 25, 216}, autob, true);
check_auto_bcast<op::And, char, char>( check_auto_bcast<op::And, char, char>(
{{1, 0, 1, 0, 0, 1}, {1, 0, 1}}, {1, 0, 1, 0, 0, 1}, autob); {{1, 0, 1, 0, 0, 1}, {1, 0, 1}}, {1, 0, 1, 0, 0, 1}, autob);
check_auto_bcast<op::Or, char, char>( check_auto_bcast<op::Or, char, char>(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment