Commit 28e31004 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Fix a bug in Gather Eigen kernel. (#3743)

parent 5fa63dda
...@@ -86,7 +86,6 @@ namespace ngraph ...@@ -86,7 +86,6 @@ namespace ngraph
static_cast<ElementType*>(inputs), in_dims); static_cast<ElementType*>(inputs), in_dims);
auto indices_ptr = static_cast<IndicesType*>(indices); auto indices_ptr = static_cast<IndicesType*>(indices);
IndicesType index_value;
auto indices_rank = indices_shape.size(); auto indices_rank = indices_shape.size();
auto outer_loop_num = 1; auto outer_loop_num = 1;
for (int i = 0; i < axis; i++) for (int i = 0; i < axis; i++)
...@@ -126,7 +125,7 @@ namespace ngraph ...@@ -126,7 +125,7 @@ namespace ngraph
// at axis // at axis
in_extents[axis] = 1; in_extents[axis] = 1;
// at axis, get the value from indices arg // at axis, get the value from indices arg
index_value = indices_ptr[0]; IndicesType index_value = indices_ptr[0];
// take care of negative indices // take care of negative indices
in_offsets[axis] = in_offsets[axis] =
index_value >= 0 ? index_value : index_value + axis_length; index_value >= 0 ? index_value : index_value + axis_length;
...@@ -199,7 +198,7 @@ namespace ngraph ...@@ -199,7 +198,7 @@ namespace ngraph
} }
// at axis, get the value from indices arg // at axis, get the value from indices arg
int k = i % num_indices; int k = i % num_indices;
index_value = indices_ptr[k]; IndicesType index_value = indices_ptr[k];
// take care of negative indices // take care of negative indices
in_offsets[axis] = in_offsets[axis] =
index_value >= 0 ? index_value : index_value + axis_length; index_value >= 0 ? index_value : index_value + axis_length;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment