Commit 9d509515 authored by tsocha's avatar tsocha Committed by Robert Kimball

Add support for more types in gather op. (#2926)

* Add test for i32 gather

* Add support for ints to Gather op

* Move helper function to anonymous namespace

* Add more types

* Use static_cast instead of the old one

* Style fix

* Skip tests on GPU

* Add more tests

* Skip tests on gpu

* Change bool to char
parent 3d28d06a
......@@ -29,44 +29,42 @@ namespace ngraph
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
namespace
{
template <typename T>
CPUKernelFunctor prepare_functor(const Node* node,
const vector<TensorViewWrapper>& args,
const vector<TensorViewWrapper>& out,
CPU_ExternalFunction* external_function)
{
auto& functors = external_function->get_functors();
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
CPUKernelFunctor functor;
auto params_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto params_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64;
auto axis = gather->get_axis();
auto params_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&,
return
[&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int64_t>(
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<T, int64_t>(
static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
......@@ -75,19 +73,19 @@ namespace ngraph
}
else
{
functor = [&,
return
[&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int32_t>(
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<T, int32_t>(
static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
......@@ -95,50 +93,62 @@ namespace ngraph
};
}
}
} // namespace
template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
{
auto& functors = external_function->get_functors();
CPUKernelFunctor functor;
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::f64)
{
if (is_int64)
functor = prepare_functor<double>(node, args, out, external_function);
}
else if (element_type == element::i8)
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int64_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
functor = prepare_functor<int8_t>(node, args, out, external_function);
}
else
else if (element_type == element::i16)
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int32_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
functor = prepare_functor<int16_t>(node, args, out, external_function);
}
else if (element_type == element::i32)
{
functor = prepare_functor<int32_t>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::u8)
{
functor = prepare_functor<uint8_t>(node, args, out, external_function);
}
else if (element_type == element::u16)
{
functor = prepare_functor<uint16_t>(node, args, out, external_function);
}
else if (element_type == element::u32)
{
functor = prepare_functor<uint32_t>(node, args, out, external_function);
}
else if (element_type == element::u64)
{
functor = prepare_functor<uint64_t>(node, args, out, external_function);
}
else if (element_type == element::boolean)
{
functor = prepare_functor<char>(node, args, out, external_function);
}
else
{
......@@ -149,6 +159,6 @@ namespace ngraph
}
REGISTER_OP_BUILDER(Gather);
}
}
}
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -165,3 +165,12 @@ scatter_add_1d_indices
scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
......@@ -79,3 +79,12 @@ scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
zero_sized_erf
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment