Commit 9d509515 authored by tsocha's avatar tsocha Committed by Robert Kimball

Add support for more types in gather op. (#2926)

* Add test for i32 gather

* Add support for ints to Gather op

* Move helper function to anonymous namespace

* Add more types

* Use static_cast instead of the old one

* Style fix

* Skip tests on GPU

* Add more tests

* Skip tests on gpu

* Change bool to char
parent 3d28d06a
...@@ -29,44 +29,42 @@ namespace ngraph ...@@ -29,44 +29,42 @@ namespace ngraph
{ {
namespace cpu namespace cpu
{ {
template <> namespace
void Builder::BUILDER_DECL(ngraph::op::Gather) {
template <typename T>
CPUKernelFunctor prepare_functor(const Node* node,
const vector<TensorViewWrapper>& args,
const vector<TensorViewWrapper>& out,
CPU_ExternalFunction* external_function)
{ {
auto& functors = external_function->get_functors();
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node); const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
CPUKernelFunctor functor; auto params_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto params_buffer_index = external_function->get_buffer_index(args[0].get_name()); auto indices_buffer_index =
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name()); external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64; bool is_int64 = args[1].get_element_type() == element::i64;
auto axis = gather->get_axis(); auto axis = gather->get_axis();
auto params_shape = args[0].get_shape(); auto params_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape(); auto indices_shape = args[1].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64) if (is_int64)
{ {
functor = [&, return
[&,
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
axis, axis,
params_buffer_index, params_buffer_index,
indices_buffer_index, indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
CPUExecutionContext* ectx) { ngraph::runtime::reference::gather<T, int64_t>(
ngraph::runtime::reference::gather<float, int64_t>( static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -75,19 +73,19 @@ namespace ngraph ...@@ -75,19 +73,19 @@ namespace ngraph
} }
else else
{ {
functor = [&, return
[&,
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
axis, axis,
params_buffer_index, params_buffer_index,
indices_buffer_index, indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
CPUExecutionContext* ectx) { ngraph::runtime::reference::gather<T, int32_t>(
ngraph::runtime::reference::gather<float, int32_t>( static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -95,50 +93,62 @@ namespace ngraph ...@@ -95,50 +93,62 @@ namespace ngraph
}; };
} }
} }
} // namespace
template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
{
auto& functors = external_function->get_functors();
CPUKernelFunctor functor;
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::f64) else if (element_type == element::f64)
{ {
if (is_int64) functor = prepare_functor<double>(node, args, out, external_function);
}
else if (element_type == element::i8)
{ {
functor = [&, functor = prepare_functor<int8_t>(node, args, out, external_function);
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int64_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
} }
else else if (element_type == element::i16)
{ {
functor = [&, functor = prepare_functor<int16_t>(node, args, out, external_function);
params_shape, }
indices_shape, else if (element_type == element::i32)
out_shape, {
axis, functor = prepare_functor<int32_t>(node, args, out, external_function);
params_buffer_index, }
indices_buffer_index, else if (element_type == element::i64)
out_buffer_index](CPURuntimeContext* ctx, {
CPUExecutionContext* ectx) { functor = prepare_functor<int64_t>(node, args, out, external_function);
ngraph::runtime::reference::gather<double, int32_t>( }
static_cast<double*>(ctx->buffer_data[params_buffer_index]), else if (element_type == element::u8)
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]), {
static_cast<double*>(ctx->buffer_data[out_buffer_index]), functor = prepare_functor<uint8_t>(node, args, out, external_function);
params_shape,
indices_shape,
out_shape,
axis);
};
} }
else if (element_type == element::u16)
{
functor = prepare_functor<uint16_t>(node, args, out, external_function);
}
else if (element_type == element::u32)
{
functor = prepare_functor<uint32_t>(node, args, out, external_function);
}
else if (element_type == element::u64)
{
functor = prepare_functor<uint64_t>(node, args, out, external_function);
}
else if (element_type == element::boolean)
{
functor = prepare_functor<char>(node, args, out, external_function);
} }
else else
{ {
...@@ -149,6 +159,6 @@ namespace ngraph ...@@ -149,6 +159,6 @@ namespace ngraph
} }
REGISTER_OP_BUILDER(Gather); REGISTER_OP_BUILDER(Gather);
} } // namespace cpu
} } // namespace runtime
} } // namespace ngraph
...@@ -165,3 +165,12 @@ scatter_add_1d_indices ...@@ -165,3 +165,12 @@ scatter_add_1d_indices
scatter_add_scalar_indices scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d scatter_nd_add_2d_to_3d
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
...@@ -79,3 +79,12 @@ scatter_add_scalar_indices ...@@ -79,3 +79,12 @@ scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d scatter_nd_add_2d_to_3d
zero_sized_erf zero_sized_erf
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "util/all_close_f.hpp" #include "util/all_close_f.hpp"
#include "util/ndarray.hpp" #include "util/ndarray.hpp"
#include "util/random.hpp" #include "util/random.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -413,3 +414,237 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_2d_from_3d) ...@@ -413,3 +414,237 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_2d_from_3d)
read_vector<float>(result), read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS)); MIN_FLOAT_TOLERANCE_BITS));
} }
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_int8)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::i8, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::i8, params_shape);
copy_data(p, vector<int8_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::i8, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<int8_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<int8_t>(result),
static_cast<int8_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_int16)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::i16, params_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::i16, params_shape);
copy_data(p, vector<int16_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i64, indices_shape);
copy_data(i, vector<int64_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::i16, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<int16_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<int16_t>(result),
static_cast<int16_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_int32)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::i32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::i32, params_shape);
copy_data(p, vector<int32_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::i32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<int32_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<int32_t>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_int64)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::i64, params_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::i64, params_shape);
copy_data(p, vector<int64_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i64, indices_shape);
copy_data(i, vector<int64_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::i64, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<int64_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<int64_t>(result),
static_cast<int64_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_uint8)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::u8, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::u8, params_shape);
copy_data(p, vector<uint8_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::u8, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<uint8_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<uint8_t>(result),
static_cast<uint8_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_uint16)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::u16, params_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::u16, params_shape);
copy_data(p, vector<uint16_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i64, indices_shape);
copy_data(i, vector<int64_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::u16, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<uint16_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<uint16_t>(result),
static_cast<uint16_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_uint32)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::u32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::u32, params_shape);
copy_data(p, vector<uint32_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::u32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<uint32_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<uint32_t>(result),
static_cast<uint32_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_uint64)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::u64, params_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::u64, params_shape);
copy_data(p, vector<uint64_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i64, indices_shape);
copy_data(i, vector<int64_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::u64, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<uint64_t>{10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<uint64_t>(result),
static_cast<uint64_t> MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_bool)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::boolean, params_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto G = make_shared<op::Gather>(P, I);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::boolean, params_shape);
copy_data(p, vector<char>{1, 1, 1, 0, 0, 1});
auto i = backend->create_tensor(element::i64, indices_shape);
copy_data(i, vector<int64_t>{0, 1, 1, 2});
auto result = backend->create_tensor(element::boolean, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close((vector<char>{1, 1, 1, 0, 1, 0, 0, 1}),
read_vector<char>(result),
static_cast<char> MIN_FLOAT_TOLERANCE_BITS));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment