Commit e6ab0ff7 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Robert Kimball

fix to kahan summation in ref kernel (#2140)

parent 6d7e24c8
...@@ -35,21 +35,24 @@ namespace ngraph ...@@ -35,21 +35,24 @@ namespace ngraph
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
{ {
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
std::vector<T> c(shape_size(out_shape));
for (const Coordinate& output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
out[output_transform.index(output_coord)] = 0; out[output_transform.index(output_coord)] = 0;
c[output_transform.index(output_coord)] = 0;
} }
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
T c = 0;
for (const Coordinate& input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = reduce(input_coord, reduction_axes); Coordinate output_coord = reduce(input_coord, reduction_axes);
T y = arg[input_transform.index(input_coord)] - c; T y = arg[input_transform.index(input_coord)] -
c[output_transform.index(output_coord)];
T t = out[output_transform.index(output_coord)] + y; T t = out[output_transform.index(output_coord)] + y;
c = (t - out[output_transform.index(output_coord)]) - y; c[output_transform.index(output_coord)] =
(t - out[output_transform.index(output_coord)]) - y;
out[output_transform.index(output_coord)] = t; out[output_transform.index(output_coord)] = t;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment