Commit 38a389d6, authored by Amy Zhuang, committed by Scott Cyphers

Use Eigen kernel for more cases for Gather and ScatterAdd. (#3268)

parent 0818dabc
...@@ -57,7 +57,7 @@ namespace ngraph ...@@ -57,7 +57,7 @@ namespace ngraph
args[0].get_element_type() == element::f64 || args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 || args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) && args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3) params_shape.size() <= 3 && out_shape.size() <= 5)
{ {
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel; kernel;
...@@ -117,7 +117,7 @@ namespace ngraph ...@@ -117,7 +117,7 @@ namespace ngraph
args[0].get_element_type() == element::f64 || args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 || args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) && args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3) params_shape.size() <= 3 && out_shape.size() <= 5)
{ {
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel; kernel;
......
...@@ -62,7 +62,7 @@ namespace ngraph ...@@ -62,7 +62,7 @@ namespace ngraph
if (is_int64) if (is_int64)
{ {
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3) if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{ {
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel; kernel;
...@@ -101,7 +101,7 @@ namespace ngraph ...@@ -101,7 +101,7 @@ namespace ngraph
} }
else else
{ {
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3) if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{ {
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)> std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel; kernel;
......
...@@ -227,6 +227,14 @@ ...@@ -227,6 +227,14 @@
{ \ { \
SELECT_RANK1(KV, ET, R1, 3, K); \ SELECT_RANK1(KV, ET, R1, 3, K); \
} \ } \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else \ else \
{ \ { \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \ throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
......
...@@ -1833,7 +1833,7 @@ namespace ngraph ...@@ -1833,7 +1833,7 @@ namespace ngraph
args[0].get_element_type() == element::f32 || args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 || args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) && args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 3) args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 5)
{ {
writer << "cpu::kernel::gather<" << args[0].get_type() << ", " writer << "cpu::kernel::gather<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", " << args[1].get_element_type().c_type_string() << ", "
...@@ -1897,7 +1897,7 @@ namespace ngraph ...@@ -1897,7 +1897,7 @@ namespace ngraph
args[0].get_element_type() == element::f32 || args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 || args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) && args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 3) args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 5)
{ {
writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", " writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", " << args[1].get_element_type().c_type_string() << ", "
......
...@@ -88,6 +88,7 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices) ...@@ -88,6 +88,7 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
read_vector<float>(result), read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS)); MIN_FLOAT_TOLERANCE_BITS));
} }
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices) NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
{ {
...@@ -123,7 +124,6 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices) ...@@ -123,7 +124,6 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
read_vector<float>(result), read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS)); MIN_FLOAT_TOLERANCE_BITS));
} }
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices) NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment