Commit dd8c9ed7 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Enable Gather with negative indices. (#3701)

* Enable Gather with negative indices.

* Address PR feedback.

* Remove GOE from Gather unit tests.
parent b8266cab
...@@ -69,6 +69,8 @@ namespace ngraph ...@@ -69,6 +69,8 @@ namespace ngraph
Eigen::array<Eigen::Index, Rank1> in_dims; Eigen::array<Eigen::Index, Rank1> in_dims;
Eigen::array<Eigen::Index, Rank2> out_dims; Eigen::array<Eigen::Index, Rank2> out_dims;
auto axis_length = inputs_shape[axis];
for (int i = 0; i < Rank1; i++) for (int i = 0; i < Rank1; i++)
{ {
in_dims[i] = inputs_shape[i]; in_dims[i] = inputs_shape[i];
...@@ -84,6 +86,7 @@ namespace ngraph ...@@ -84,6 +86,7 @@ namespace ngraph
static_cast<ElementType*>(inputs), in_dims); static_cast<ElementType*>(inputs), in_dims);
auto indices_ptr = static_cast<IndicesType*>(indices); auto indices_ptr = static_cast<IndicesType*>(indices);
IndicesType index_value;
auto indices_rank = indices_shape.size(); auto indices_rank = indices_shape.size();
auto outer_loop_num = 1; auto outer_loop_num = 1;
for (int i = 0; i < axis; i++) for (int i = 0; i < axis; i++)
...@@ -123,7 +126,10 @@ namespace ngraph ...@@ -123,7 +126,10 @@ namespace ngraph
// at axis // at axis
in_extents[axis] = 1; in_extents[axis] = 1;
// at axis, get the value from indices arg // at axis, get the value from indices arg
in_offsets[axis] = indices_ptr[0]; index_value = indices_ptr[0];
// take care of negative indices
in_offsets[axis] =
index_value >= 0 ? index_value : index_value + axis_length;
// before axis // before axis
for (int r = 0; r < axis; r++) for (int r = 0; r < axis; r++)
...@@ -193,7 +199,10 @@ namespace ngraph ...@@ -193,7 +199,10 @@ namespace ngraph
} }
// at axis, get the value from indices arg // at axis, get the value from indices arg
int k = i % num_indices; int k = i % num_indices;
in_offsets[axis] = indices_ptr[k]; index_value = indices_ptr[k];
// take care of negative indices
in_offsets[axis] =
index_value >= 0 ? index_value : index_value + axis_length;
// indices_from_indices_arg depends on indices_shape and k. // indices_from_indices_arg depends on indices_shape and k.
// suppose the inputs has shape {3, 3, 3}, indices has shape {2, 2}, and // suppose the inputs has shape {3, 3, 3}, indices has shape {2, 2}, and
......
...@@ -64,6 +64,7 @@ gather_2d_indices_axis_1_2d_input ...@@ -64,6 +64,7 @@ gather_2d_indices_axis_1_2d_input
gather_scalar_indices_no_axis_2d_input gather_scalar_indices_no_axis_2d_input
gather_1d_indices_no_axis_1d_input gather_1d_indices_no_axis_1d_input
gather_2d_indices_no_axis_2d_input gather_2d_indices_no_axis_2d_input
gather_2d_negative_and_positive_indices_no_axis_2d_input
gather_3d_indices_no_axis_2d_input gather_3d_indices_no_axis_2d_input
gather_4d_indices_no_axis_2d_input gather_4d_indices_no_axis_2d_input
gemm gemm
......
...@@ -196,6 +196,7 @@ gather_4d_indices_no_axis_uint8 ...@@ -196,6 +196,7 @@ gather_4d_indices_no_axis_uint8
gather_4d_indices_no_axis_2d_input gather_4d_indices_no_axis_2d_input
gather_3d_indices_no_axis_2d_input gather_3d_indices_no_axis_2d_input
gather_2d_indices_no_axis_2d_input gather_2d_indices_no_axis_2d_input
gather_2d_negative_and_positive_indices_no_axis_2d_input
gather_1d_indices_no_axis_1d_input gather_1d_indices_no_axis_1d_input
gather_scalar_indices_no_axis_2d_input gather_scalar_indices_no_axis_2d_input
gather_2d_indices_axis_1_2d_input gather_2d_indices_axis_1_2d_input
......
...@@ -83,6 +83,8 @@ namespace ngraph ...@@ -83,6 +83,8 @@ namespace ngraph
for (size_t i = 0; i < slice_rank; i++) for (size_t i = 0; i < slice_rank; i++)
{ {
U index = indices[indices_index]; U index = indices[indices_index];
// take care of negative indices
index = index >= 0 ? index : index + params_shape[i];
params_start_corner[i] = index; params_start_corner[i] = index;
params_end_corner[i] = index + 1; params_end_corner[i] = index + 1;
indices_index++; indices_index++;
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment