Commit 84ba3a2a authored by Amy Zhuang, committed by Scott Cyphers

Enable Gather and ScatterAdd to use Eigen kernel for int8 type. (#3094)

* Enable Gather and ScatterAdd to use Eigen kernel for int8 type.

* Reduce number of supported ranks.

* Fix a bug.
parent d3016b24
...@@ -53,9 +53,11 @@ namespace ngraph ...@@ -53,9 +53,11 @@ namespace ngraph
if (is_int64) if (is_int64)
{ {
if (args[0].get_element_type() == element::f32 || if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 || args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8) args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
{ {
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel; kernel;
...@@ -111,9 +113,11 @@ namespace ngraph ...@@ -111,9 +113,11 @@ namespace ngraph
else else
{ {
if (args[0].get_element_type() == element::f32 || if ((args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::f64 || args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8) args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
{ {
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel; kernel;
......
...@@ -46,7 +46,9 @@ namespace ngraph ...@@ -46,7 +46,9 @@ namespace ngraph
} }
if (args[0].get_element_type() != element::f64 && if (args[0].get_element_type() != element::f64 &&
args[0].get_element_type() != element::f32) args[0].get_element_type() != element::f32 &&
args[0].get_element_type() != element::u8 &&
args[0].get_element_type() != element::i8)
{ {
throw ngraph_error("Unsupported type in CPU Builder for ScatterAdd"); throw ngraph_error("Unsupported type in CPU Builder for ScatterAdd");
} }
...@@ -59,6 +61,8 @@ namespace ngraph ...@@ -59,6 +61,8 @@ namespace ngraph
auto element_type = args[0].get_element_type(); auto element_type = args[0].get_element_type();
if (is_int64) if (is_int64)
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
{ {
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel; kernel;
...@@ -91,6 +95,13 @@ namespace ngraph ...@@ -91,6 +95,13 @@ namespace ngraph
functors.emplace_back(functor); functors.emplace_back(functor);
} }
else else
{
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
else
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
{ {
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel; kernel;
...@@ -122,6 +133,11 @@ namespace ngraph ...@@ -122,6 +133,11 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
else
{
throw ngraph_error("Unsupported ranks in CPU Builder for ScatterAdd");
}
}
} }
REGISTER_OP_BUILDER(ScatterAdd); REGISTER_OP_BUILDER(ScatterAdd);
} }
......
...@@ -211,14 +211,6 @@ ...@@ -211,14 +211,6 @@
KV = K<ET, 2, R2>; \ KV = K<ET, 2, R2>; \
else if (R1 == 3) \ else if (R1 == 3) \
KV = K<ET, 3, R2>; \ KV = K<ET, 3, R2>; \
else if (R1 == 4) \
KV = K<ET, 4, R2>; \
else if (R1 == 5) \
KV = K<ET, 5, R2>; \
else if (R1 == 6) \
KV = K<ET, 6, R2>; \
else if (R1 == 7) \
KV = K<ET, 7, R2>; \
else \ else \
throw ngraph_error("Unsupported first rank " + std::to_string(R1) + " for kernel " #K); throw ngraph_error("Unsupported first rank " + std::to_string(R1) + " for kernel " #K);
...@@ -235,22 +227,6 @@ ...@@ -235,22 +227,6 @@
{ \ { \
SELECT_RANK1(KV, ET, R1, 3, K); \ SELECT_RANK1(KV, ET, R1, 3, K); \
} \ } \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else if (R2 == 6) \
{ \
SELECT_RANK1(KV, ET, R1, 6, K); \
} \
else if (R2 == 7) \
{ \
SELECT_RANK1(KV, ET, R1, 7, K); \
} \
else \ else \
{ \ { \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \ throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
...@@ -270,6 +246,10 @@ ...@@ -270,6 +246,10 @@
{ \ { \
SELECT_2RANKS(KV, uint8_t, R1, R2, K); \ SELECT_2RANKS(KV, uint8_t, R1, R2, K); \
} \ } \
else if (ET == element::i8) \
{ \
SELECT_2RANKS(KV, int8_t, R1, R2, K); \
} \
else \ else \
{ \ { \
throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \ throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
......
...@@ -1834,8 +1834,9 @@ namespace ngraph ...@@ -1834,8 +1834,9 @@ namespace ngraph
writer.block_begin(); writer.block_begin();
if ((args[0].get_element_type() == element::f64 || if ((args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32 || args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8) && args[0].get_element_type() == element::u8 ||
gather->get_axis() == 0) args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 3)
{ {
writer << "cpu::kernel::gather<" << args[0].get_type() << ", " writer << "cpu::kernel::gather<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", " << args[1].get_element_type().c_type_string() << ", "
...@@ -1895,8 +1896,11 @@ namespace ngraph ...@@ -1895,8 +1896,11 @@ namespace ngraph
} }
writer.block_begin(); writer.block_begin();
if (args[0].get_element_type() == element::f64 || if ((args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32) args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 3)
{ {
writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", " writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", " << args[1].get_element_type().c_type_string() << ", "
......
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
{ {
namespace kernel namespace kernel
{ {
// Calculate the indices from position 0 to rank-1. // Calculate the indices for positions 0 to rank-1.
static void static void
get_indices(const Shape& shape, int index, std::vector<int>& indices, int rank) get_indices(const Shape& shape, int index, std::vector<int>& indices, int rank)
{ {
...@@ -93,8 +93,11 @@ namespace ngraph ...@@ -93,8 +93,11 @@ namespace ngraph
if (indices_rank == 0) if (indices_rank == 0)
{ {
//TODO Enable this if compiler issue with CODEGEN is fixed or DEX needs it.
#if 0
#ifdef _OPENMP #ifdef _OPENMP
#pragma omp parallel for #pragma omp parallel for
#endif
#endif #endif
for (int i = 0; i < outer_loop_num; i++) for (int i = 0; i < outer_loop_num; i++)
{ {
...@@ -142,7 +145,11 @@ namespace ngraph ...@@ -142,7 +145,11 @@ namespace ngraph
} }
else else
{ {
auto num_indices = shape_size(indices_shape); size_t num_indices = 1;
for (auto d : indices_shape)
{
num_indices *= d;
}
#ifdef _OPENMP #ifdef _OPENMP
#pragma omp parallel for #pragma omp parallel for
......
...@@ -35,6 +35,7 @@ using namespace ngraph; ...@@ -35,6 +35,7 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}"; static string s_manifest = "${MANIFEST}";
#if 0
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices) NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
{ {
Shape ref_shape{3, 3, 3}; Shape ref_shape{3, 3, 3};
...@@ -122,13 +123,14 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices) ...@@ -122,13 +123,14 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
read_vector<float>(result), read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS)); MIN_FLOAT_TOLERANCE_BITS));
} }
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices) NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{ {
Shape ref_shape{2, 3, 3}; Shape ref_shape{3};
Shape indices_shape{2, 2}; Shape indices_shape{2, 2};
Shape updates_shape{2, 2, 3, 3}; Shape updates_shape{2, 2};
Shape out_shape{2, 3, 3}; Shape out_shape{3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape); auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape); auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape); auto U = make_shared<op::Parameter>(element::f32, updates_shape);
...@@ -140,20 +142,17 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices) ...@@ -140,20 +142,17 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
// Create some tensors for input/output // Create some tensors for input/output
auto r = backend->create_tensor(element::f32, ref_shape); auto r = backend->create_tensor(element::f32, ref_shape);
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9}); copy_data(r, vector<float>{0, 1, 2});
auto i = backend->create_tensor(element::i32, indices_shape); auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0}); copy_data(i, vector<int32_t>{0, 1, 1, 0});
auto u = backend->create_tensor(element::f32, updates_shape); auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, copy_data(u, vector<float>{1, 2, 3, 4});
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, out_shape); auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f); auto c = backend->compile(f);
c->call_with_validate({result}, {r, i, u}); c->call_with_validate({result}, {r, i, u});
EXPECT_TRUE(test::all_close_f( EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 3, 6, 9, 12, 15, 18, 21, 24, 3, 6, 9, 12, 15, 18, 21, 24, 27}), (vector<float>{5, 6, 2}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
} }
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_1d_indices) NGRAPH_TEST(${BACKEND_NAME}, scatter_add_1d_indices)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment