Commit 38a389d6 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Use Eigen kernel for more cases for Gather and ScatterAdd. (#3268)

parent 0818dabc
......@@ -57,7 +57,7 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
params_shape.size() <= 3 && out_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::gather_i64<float, 2, 2>)>
kernel;
......@@ -117,7 +117,7 @@ namespace ngraph
args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
params_shape.size() <= 3 && out_shape.size() <= 3)
params_shape.size() <= 3 && out_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::gather_i32<float, 2, 2>)>
kernel;
......
......@@ -62,7 +62,7 @@ namespace ngraph
if (is_int64)
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i64<float, 2, 2>)>
kernel;
......@@ -101,7 +101,7 @@ namespace ngraph
}
else
{
if (inputs_shape.size() <= 3 && updates_shape.size() <= 3)
if (inputs_shape.size() <= 3 && updates_shape.size() <= 5)
{
std::function<decltype(runtime::cpu::kernel::scatter_add_i32<float, 2, 2>)>
kernel;
......
......@@ -227,6 +227,14 @@
{ \
SELECT_RANK1(KV, ET, R1, 3, K); \
} \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else \
{ \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
......
......@@ -1833,7 +1833,7 @@ namespace ngraph
args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 3)
args[0].get_shape().size() <= 3 && out[0].get_shape().size() <= 5)
{
writer << "cpu::kernel::gather<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", "
......@@ -1897,7 +1897,7 @@ namespace ngraph
args[0].get_element_type() == element::f32 ||
args[0].get_element_type() == element::u8 ||
args[0].get_element_type() == element::i8) &&
args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 3)
args[0].get_shape().size() <= 3 && args[2].get_shape().size() <= 5)
{
writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", "
......
......@@ -88,6 +88,7 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
{
......@@ -123,7 +124,6 @@ NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment