Commit a20c710b authored by Fenglei's avatar Fenglei Committed by Scott Cyphers

cast to ArgMin/ArgMax based on reduce_op (#1851)

* add ArgMin, ArgMax

* change to scale

* format

* add exception
parent fb49e0c2
...@@ -176,15 +176,29 @@ void runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS) ...@@ -176,15 +176,29 @@ void runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS)
external_function, writer, node, args, out, reduce_op); external_function, writer, node, args, out, reduce_op);
} }
void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_mode) void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_op)
{ {
if (out[0].get_size() == 0) if (out[0].get_size() == 0)
{ {
return; return;
} }
auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
std::vector<size_t> axes{argmax->get_reduction_axis()}; size_t axis;
auto axis_set = AxisSet(axes); if (reduce_op == CUDNN_REDUCE_TENSOR_MIN)
{
auto argmin = static_cast<const ngraph::op::ArgMin*>(node);
axis = argmin->get_reduction_axis();
}
else if (reduce_op == CUDNN_REDUCE_TENSOR_MAX)
{
auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
axis = argmax->get_reduction_axis();
}
else
{
throw std::runtime_error("Not supported. Only Min/Max op are supported by ArgReduce.");
}
auto axis_set = AxisSet{axis};
std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()}; std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
...@@ -192,7 +206,7 @@ void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t ...@@ -192,7 +206,7 @@ void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t
{ {
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter(); auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_reduce_forward(reduce_mode, auto index = cudnn_emitter->build_reduce_forward(reduce_op,
dtypes, dtypes,
args[0].get_shape(), args[0].get_shape(),
axis_set, axis_set,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment