Commit a20c710b authored by Fenglei's avatar Fenglei Committed by Scott Cyphers

cast to ArgMin/ArgMax based on reduce_op (#1851)

* add ArgMin, ArgMax

* change to scale

* format

* add exception
parent fb49e0c2
......@@ -176,15 +176,29 @@ void runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS)
external_function, writer, node, args, out, reduce_op);
}
void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_mode)
void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_op)
{
if (out[0].get_size() == 0)
{
return;
}
auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
std::vector<size_t> axes{argmax->get_reduction_axis()};
auto axis_set = AxisSet(axes);
size_t axis;
if (reduce_op == CUDNN_REDUCE_TENSOR_MIN)
{
auto argmin = static_cast<const ngraph::op::ArgMin*>(node);
axis = argmin->get_reduction_axis();
}
else if (reduce_op == CUDNN_REDUCE_TENSOR_MAX)
{
auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
axis = argmax->get_reduction_axis();
}
else
{
throw std::runtime_error("Not supported. Only Min/Max op are supported by ArgReduce.");
}
auto axis_set = AxisSet{axis};
std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
......@@ -192,7 +206,7 @@ void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_reduce_forward(reduce_mode,
auto index = cudnn_emitter->build_reduce_forward(reduce_op,
dtypes,
args[0].get_shape(),
axis_set,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment