Commit a2d8a9fd authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

add int32 support to argmin/argmax (#2288)

parent 1df4cc36
...@@ -108,6 +108,34 @@ namespace ngraph ...@@ -108,6 +108,34 @@ namespace ngraph
}; };
} }
} }
else if (element_type == element::i32)
{
if (is_int64)
{
std::function<decltype(runtime::cpu::kernel::argmax<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
std::function<decltype(runtime::cpu::kernel::argmax<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
else else
{ {
throw ngraph_error("Unsupported type in CPU Builder for ArgMax"); throw ngraph_error("Unsupported type in CPU Builder for ArgMax");
......
...@@ -108,6 +108,34 @@ namespace ngraph ...@@ -108,6 +108,34 @@ namespace ngraph
}; };
} }
} }
else if (element_type == element::i32)
{
if (is_int64)
{
std::function<decltype(runtime::cpu::kernel::argmin<int, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, int, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
std::function<decltype(runtime::cpu::kernel::argmin<int, int, 1>)> kernel;
SELECT_RANK2(
kernel, int, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
else else
{ {
throw ngraph_error("Unsupported type in CPU Builder for ArgMin"); throw ngraph_error("Unsupported type in CPU Builder for ArgMin");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment