Commit 66f6331b authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Use Eigen kernel for Gather for any axis value. (#3025)

parent 94df1977
......@@ -53,10 +53,9 @@ namespace ngraph
if (is_int64)
{
if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8) &&
axis == 0)
if (args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8)
{
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel;
......@@ -72,6 +71,7 @@ namespace ngraph
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -82,6 +82,7 @@ namespace ngraph
params_shape,
indices_shape,
out_shape,
axis,
ectx->arena);
};
}
......@@ -110,10 +111,9 @@ namespace ngraph
else
{
if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8) &&
axis == 0)
if (args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8)
{
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel;
......@@ -129,6 +129,7 @@ namespace ngraph
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
......@@ -139,6 +140,7 @@ namespace ngraph
params_shape,
indices_shape,
out_shape,
axis,
ectx->arena);
};
}
......
......@@ -1846,6 +1846,7 @@ namespace ngraph
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " " << gather->get_axis() << ",\n";
writer << " 0);\n";
}
else
......
......@@ -247,6 +247,7 @@ namespace ngraph
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& output_shape,
size_t axis,
int arena);
template <typename ElementType,
......
This diff is collapsed.
......@@ -159,7 +159,6 @@ erf
zero_sized_erf
model_erf
model_erf_int32
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
......@@ -171,12 +170,15 @@ gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_nd_single_indices
gather_scalar_indices
gather_scalar_indices_no_axis
gather_2d_indices_no_axis
gather_3d_indices_no_axis
gather_4d_indices_no_axis
gather_4d_indices_no_axis_uint8
gather_scalar_indices_axis_1_2d_input
gather_1d_indices_axis_2_4d_input
gather_2d_indices_axis_1_2d_input
gather_scalar_indices_no_axis_2d_input
gather_1d_indices_no_axis_1d_input
gather_2d_indices_no_axis_2d_input
gather_3d_indices_no_axis_2d_input
gather_4d_indices_no_axis_2d_input
gemm
gemm_broadcast_input_C
model_hardmax
......
......@@ -44,7 +44,6 @@ pad_reflect_2d_with_neg
batch_mat_mul_forward
backwards_batchmatmul_tensor2_tensor2
erf
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
......@@ -56,12 +55,15 @@ gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_nd_single_indices
gather_scalar_indices
gather_scalar_indices_no_axis
gather_2d_indices_no_axis
gather_3d_indices_no_axis
gather_4d_indices_no_axis
gather_4d_indices_no_axis_uint8
gather_scalar_indices_axis_1_2d_input
gather_1d_indices_axis_2_4d_input
gather_2d_indices_axis_1_2d_input
gather_scalar_indices_no_axis_2d_input
gather_1d_indices_no_axis_1d_input
gather_2d_indices_no_axis_2d_input
gather_3d_indices_no_axis_2d_input
gather_4d_indices_no_axis_2d_input
gemm
gemm_broadcast_input_C
normalize_across_chw_scalar_scale_4d
......
......@@ -103,7 +103,6 @@ embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
floor_int32
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
......@@ -115,12 +114,15 @@ gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_nd_single_indices
gather_scalar_indices
gather_scalar_indices_no_axis
gather_2d_indices_no_axis
gather_3d_indices_no_axis
gather_4d_indices_no_axis
gather_4d_indices_no_axis_uint8
gather_scalar_indices_axis_1_2d_input
gather_1d_indices_axis_2_4d_input
gather_2d_indices_axis_1_2d_input
gather_scalar_indices_no_axis_2d_input
gather_1d_indices_no_axis_1d_input
gather_2d_indices_no_axis_2d_input
gather_3d_indices_no_axis_2d_input
gather_4d_indices_no_axis_2d_input
scatter_add_4d_indices
scatter_add_3d_indices
scatter_add_2d_indices
......
......@@ -69,7 +69,7 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_no_axis_uint8)
read_vector<uint8_t>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_no_axis)
NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_no_axis_2d_input)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2, 3, 4};
......@@ -105,7 +105,7 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_4d_indices_no_axis)
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_3d_indices_no_axis)
NGRAPH_TEST(${BACKEND_NAME}, gather_3d_indices_no_axis_2d_input)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 3, 4};
......@@ -136,7 +136,7 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_3d_indices_no_axis)
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_no_axis)
NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_no_axis_2d_input)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
......@@ -162,7 +162,32 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_no_axis)
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices_no_axis)
NGRAPH_TEST(${BACKEND_NAME}, gather_1d_indices_no_axis_1d_input)
{
Shape params_shape{3};
Shape indices_shape{2};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 2.0f, 3.0f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{2.0f, 1.0f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices_no_axis_2d_input)
{
Shape params_shape{3, 2};
Shape indices_shape{};
......@@ -187,7 +212,7 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices_no_axis)
(vector<float>{2.0f, 2.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather)
NGRAPH_TEST(${BACKEND_NAME}, gather_2d_indices_axis_1_2d_input)
{
Shape params_shape{3, 3};
Shape indices_shape{1, 2};
......@@ -213,7 +238,38 @@ NGRAPH_TEST(${BACKEND_NAME}, gather)
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices)
NGRAPH_TEST(${BACKEND_NAME}, gather_1d_indices_axis_2_4d_input)
{
Shape params_shape{2, 2, 3, 3};
Shape indices_shape{2};
Shape out_shape{2, 2, 2, 3};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I, 2);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 2});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f,
1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f, 1.0f, 1.1f, 1.2f, 3.0f, 3.1f, 3.2f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices_axis_1_2d_input)
{
Shape params_shape{3, 3};
Shape indices_shape{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment