Commit 9d509515 authored by tsocha's avatar tsocha Committed by Robert Kimball

Add support for more types in gather op. (#2926)

* Add test for i32 gather

* Add support for ints to Gather op

* Move helper function to anonymous namespace

* Add more types

* Use static_cast instead of the old one

* Style fix

* Skip tests on GPU

* Add more tests

* Skip tests on gpu

* Change bool to char
parent 3d28d06a
...@@ -29,44 +29,42 @@ namespace ngraph ...@@ -29,44 +29,42 @@ namespace ngraph
{ {
namespace cpu namespace cpu
{ {
template <> namespace
void Builder::BUILDER_DECL(ngraph::op::Gather) {
template <typename T>
CPUKernelFunctor prepare_functor(const Node* node,
const vector<TensorViewWrapper>& args,
const vector<TensorViewWrapper>& out,
CPU_ExternalFunction* external_function)
{ {
auto& functors = external_function->get_functors();
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node); const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
CPUKernelFunctor functor; auto params_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto params_buffer_index = external_function->get_buffer_index(args[0].get_name()); auto indices_buffer_index =
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name()); external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64; bool is_int64 = args[1].get_element_type() == element::i64;
auto axis = gather->get_axis(); auto axis = gather->get_axis();
auto params_shape = args[0].get_shape(); auto params_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape(); auto indices_shape = args[1].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64) if (is_int64)
{ {
functor = [&, return
[&,
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
axis, axis,
params_buffer_index, params_buffer_index,
indices_buffer_index, indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
CPUExecutionContext* ectx) { ngraph::runtime::reference::gather<T, int64_t>(
ngraph::runtime::reference::gather<float, int64_t>( static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]), static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -75,19 +73,19 @@ namespace ngraph ...@@ -75,19 +73,19 @@ namespace ngraph
} }
else else
{ {
functor = [&, return
[&,
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
axis, axis,
params_buffer_index, params_buffer_index,
indices_buffer_index, indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
CPUExecutionContext* ectx) { ngraph::runtime::reference::gather<T, int32_t>(
ngraph::runtime::reference::gather<float, int32_t>( static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]), static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]), static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape, params_shape,
indices_shape, indices_shape,
out_shape, out_shape,
...@@ -95,50 +93,62 @@ namespace ngraph ...@@ -95,50 +93,62 @@ namespace ngraph
}; };
} }
} }
} // namespace
template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
{
auto& functors = external_function->get_functors();
CPUKernelFunctor functor;
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::f64) else if (element_type == element::f64)
{ {
if (is_int64) functor = prepare_functor<double>(node, args, out, external_function);
}
else if (element_type == element::i8)
{ {
functor = [&, functor = prepare_functor<int8_t>(node, args, out, external_function);
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int64_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
} }
else else if (element_type == element::i16)
{ {
functor = [&, functor = prepare_functor<int16_t>(node, args, out, external_function);
params_shape, }
indices_shape, else if (element_type == element::i32)
out_shape, {
axis, functor = prepare_functor<int32_t>(node, args, out, external_function);
params_buffer_index, }
indices_buffer_index, else if (element_type == element::i64)
out_buffer_index](CPURuntimeContext* ctx, {
CPUExecutionContext* ectx) { functor = prepare_functor<int64_t>(node, args, out, external_function);
ngraph::runtime::reference::gather<double, int32_t>( }
static_cast<double*>(ctx->buffer_data[params_buffer_index]), else if (element_type == element::u8)
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]), {
static_cast<double*>(ctx->buffer_data[out_buffer_index]), functor = prepare_functor<uint8_t>(node, args, out, external_function);
params_shape,
indices_shape,
out_shape,
axis);
};
} }
else if (element_type == element::u16)
{
functor = prepare_functor<uint16_t>(node, args, out, external_function);
}
else if (element_type == element::u32)
{
functor = prepare_functor<uint32_t>(node, args, out, external_function);
}
else if (element_type == element::u64)
{
functor = prepare_functor<uint64_t>(node, args, out, external_function);
}
else if (element_type == element::boolean)
{
functor = prepare_functor<char>(node, args, out, external_function);
} }
else else
{ {
...@@ -149,6 +159,6 @@ namespace ngraph ...@@ -149,6 +159,6 @@ namespace ngraph
} }
REGISTER_OP_BUILDER(Gather); REGISTER_OP_BUILDER(Gather);
} } // namespace cpu
} } // namespace runtime
} } // namespace ngraph
...@@ -165,3 +165,12 @@ scatter_add_1d_indices ...@@ -165,3 +165,12 @@ scatter_add_1d_indices
scatter_add_scalar_indices scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d scatter_nd_add_2d_to_3d
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
...@@ -79,3 +79,12 @@ scatter_add_scalar_indices ...@@ -79,3 +79,12 @@ scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d scatter_nd_add_2d_to_3d
zero_sized_erf zero_sized_erf
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment