Commit 84ba3a2a authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Enable Gather and ScatterAdd to use Eigen kernel for int8 type. (#3094)

* Enable Gather and ScatterAdd to use Eigen kernel for int8 type.

* Reduce number of supported ranks.

* Fix a bug.
parent d3016b24
......@@ -53,9 +53,11 @@ namespace ngraph
if (is_int64)
{
if (args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8)
if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
{
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel;
......@@ -111,9 +113,11 @@ namespace ngraph
else
{
if (args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8)
if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
{
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel;
......
......@@ -46,7 +46,9 @@ namespace ngraph
}
if (args[0].get_element_type() != element::f64 &&
args[0].get_element_type() != element::f32)
args[0].get_element_type() != element::f32 &&
args[0].get_element_type() != element::u8 &&
args[0].get_element_type() != element::i8)
{
throw ngraph_error("Unsupported type in CPU Builder for ScatterAdd");
}
......@@ -60,67 +62,81 @@ namespace ngraph
if (is_int64)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel;
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64);
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i64);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
ectx->arena);
};
functors.emplace_back(functor);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
ectx->arena);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel;
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel;
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32);
SELECT_KERNEL_BY_2RANKS(kernel,
args[0].get_element_type(),
inputs_shape.size(),
updates_shape.size(),
runtime::cpu::kernel::scatter_add_i32);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
ectx->arena);
};
functors.emplace_back(functor);
auto functor = [&,
kernel,
inputs_shape,
indices_shape,
updates_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(ctx->buffer_data[inputs_buffer_index],
ctx->buffer_data[indices_buffer_index],
ctx->buffer_data[updates_buffer_index],
ctx->buffer_data[out_buffer_index],
inputs_shape,
indices_shape,
updates_shape,
ectx->arena);
};
functors.emplace_back(functor);
}
else
{
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
}
REGISTER_OP_BUILDER(ScatterAdd);
......
......@@ -211,14 +211,6 @@
KV = K<ET, 2, R2>; \
else if (R1 == 3) \
KV = K<ET, 3, R2>; \
else if (R1 == 4) \
KV = K<ET, 4, R2>; \
else if (R1 == 5) \
KV = K<ET, 5, R2>; \
else if (R1 == 6) \
KV = K<ET, 6, R2>; \
else if (R1 == 7) \
KV = K<ET, 7, R2>; \
else \
throw ngraph_error("Unsupported first rank " + std::to_string(R1) + " for kernel " #K);
......@@ -235,22 +227,6 @@
{ \
SELECT_RANK1(KV, ET, R1, 3, K); \
} \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else if (R2 == 6) \
{ \
SELECT_RANK1(KV, ET, R1, 6, K); \
} \
else if (R2 == 7) \
{ \
SELECT_RANK1(KV, ET, R1, 7, K); \
} \
else \
{ \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
......@@ -270,6 +246,10 @@
{ \
SELECT_2RANKS(KV, uint8_t, R1, R2, K); \
} \
else if (ET == element::i8) \
{ \
SELECT_2RANKS(KV, int8_t, R1, R2, K); \
} \
else \
{ \
throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
......
......@@ -1834,8 +1834,9 @@ namespace ngraph
writer.block_begin();
if ((args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8) &&
gather->get_axis() == 0)
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 3)
{
writer << "cpu::kernel::gather<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", "
......@@ -1895,8 +1896,11 @@ namespace ngraph
}
writer.block_begin();
if (args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32)
if ((args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 3)
{
writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", "
......
......@@ -31,7 +31,7 @@ namespace ngraph
{
namespace kernel
{
// Calculate the indices from position 0 to rank-1.
// Calculate the indices for positions 0 to rank-1.
static void
get_indices(const Shape& shape, int index, std::vector<int>& indices, int rank)
{
......@@ -93,8 +93,11 @@ namespace ngraph
if (indices_rank == 0)
{
//TODO Enable this if compiler issue with CODEGEN is fixed or DEX needs it.
#if 0
#ifdef _OPENMP
#pragma omp parallel for
#endif
#endif
for (int i = 0; i < outer_loop_num; i++)
{
......@@ -142,7 +145,11 @@ namespace ngraph
}
else
{
auto num_indices = shape_size(indices_shape);
size_t num_indices = 1;
for (auto d : indices_shape)
{
num_indices *= d;
}
#ifdef _OPENMP
#pragma omp parallel for
......
......@@ -35,6 +35,7 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
#if 0
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
{
Shape ref_shape{3, 3, 3};
......@@ -122,13 +123,14 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{
Shape ref_shape{2, 3, 3};
Shape ref_shape{3};
Shape indices_shape{2, 2};
Shape updates_shape{2, 2, 3, 3};
Shape out_shape{2, 3, 3};
Shape updates_shape{2, 2};
Shape out_shape{3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
......@@ -140,20 +142,17 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
// Create some tensors for input/output
auto r = backend->create_tensor(element::f32, ref_shape);
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
copy_data(r, vector<float>{0, 1, 2});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0});
auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
copy_data(u, vector<float>{1, 2, 3, 4});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {r, i, u});
EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 3, 6, 9, 12, 15, 18, 21, 24, 3, 6, 9, 12, 15, 18, 21, 24, 27}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
(vector<float>{5, 6, 2}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_1d_indices)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment