Commit 736666d8 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Fix sum reference to handle corner cases with +-inf (#3412)

* Fix sum reference to handle corner cases with +-inf

* Review comments, and try to make Windows happy

* Update GPU unit_test.manifest

* More template grindery, to make macOS happy
parent 5fb9fd10
...@@ -100,6 +100,9 @@ all_2x2x3_eliminate_dims_1_2 ...@@ -100,6 +100,9 @@ all_2x2x3_eliminate_dims_1_2
all_2x2x3_eliminate_dims_0_1_2 all_2x2x3_eliminate_dims_0_1_2
all_dynamic all_dynamic
# Corner-case tests for sum with infs and -infs.
sum_inf
# GPU backend uses floats to implement these ops for int32 # GPU backend uses floats to implement these ops for int32
floor_int32 floor_int32
divide_int32 divide_int32
......
...@@ -27,6 +27,29 @@ namespace ngraph ...@@ -27,6 +27,29 @@ namespace ngraph
{ {
namespace reference namespace reference
{ {
// Windows doesn't seem to like it if we directly use std::isfinite on integer types,
// so we will roll our own thing here.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, bool>::type is_finite(T x)
{
return std::isfinite(x);
}
template <typename T>
typename std::enable_if<std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value,
bool>::type
is_finite(T x)
{
return std::isfinite(static_cast<float>(x));
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type is_finite(T x)
{
return true;
}
template <typename T> template <typename T>
void sum(const T* arg, void sum(const T* arg,
T* out, T* out,
...@@ -35,12 +58,12 @@ namespace ngraph ...@@ -35,12 +58,12 @@ namespace ngraph
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
{ {
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
std::vector<T> c(shape_size(out_shape)); std::vector<T> cs(shape_size(out_shape));
for (const Coordinate& output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
out[output_transform.index(output_coord)] = 0; out[output_transform.index(output_coord)] = 0;
c[output_transform.index(output_coord)] = 0; cs[output_transform.index(output_coord)] = 0;
} }
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
...@@ -48,12 +71,21 @@ namespace ngraph ...@@ -48,12 +71,21 @@ namespace ngraph
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = reduce(input_coord, reduction_axes); Coordinate output_coord = reduce(input_coord, reduction_axes);
T y = arg[input_transform.index(input_coord)] -
c[output_transform.index(output_coord)]; T x = arg[input_transform.index(input_coord)];
T t = out[output_transform.index(output_coord)] + y; T& z = out[output_transform.index(output_coord)];
c[output_transform.index(output_coord)] =
(t - out[output_transform.index(output_coord)]) - y; if (is_finite(x) && is_finite(z))
out[output_transform.index(output_coord)] = t; {
T& c = cs[output_transform.index(output_coord)];
T t = z + (x - c);
c = (t - z) - (x - c);
z = t;
}
else
{
z = z + x;
}
} }
} }
} }
......
...@@ -740,3 +740,39 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_dynamic) ...@@ -740,3 +740,39 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_dynamic)
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS)); ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
} }
} }
NGRAPH_TEST(${BACKEND_NAME}, sum_inf)
{
Shape shape{7, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Sum>(A, AxisSet{1}), ParameterVector{A});
auto infi = std::numeric_limits<float>::infinity();
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 2>({{-infi, 0, 0, infi},
{infi, 100, -100, -infi},
{infi, 0, 100, infi},
{-infi, -100, 0, -infi},
{infi, infi, infi, infi},
{infi, infi, infi, -infi},
{infi, std::nanf(""), 42, infi}})
.get_vector());
auto result = backend->create_tensor(element::f32, Shape{7});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
auto r = read_vector<float>(result);
ASSERT_EQ(r.size(), 7);
EXPECT_TRUE(isnan(r[0]));
EXPECT_TRUE(isnan(r[1]));
EXPECT_TRUE(r[2] > 0 && isinf(r[2]));
EXPECT_TRUE(r[3] < 0 && isinf(r[3]));
EXPECT_TRUE(r[4] > 0 && isinf(r[4]));
EXPECT_TRUE(isnan(r[5]));
EXPECT_TRUE(isnan(r[6]));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment