Commit 6f30b32b authored by Ayan Moitra, committed by Scott Cyphers

Support ArgMin and ArgMax for NVGPU Backend (#1737)

* Project initialization commit

* Added unit tests for 3D tensors for argmax

* Refactored reduce so it can be shared by argmax/argmin; argmax/argmin still has some issues (WIP)

* [WIP] First working version of ArgMax/ArgMin

* added reduce buffer for the cudnn api calls

* added reduce buffer for the cudnn api calls

* Further modifications. Using rvalues to pass enums to build reduce method

* more unit tests added

* Incorporate Fenglei's comments

* Incorporating Chris's first set of comments

* small change to test file

* Resolving clang issue that was causing argmin test to fail

* Incorporate Chris's comments

* clang format issue
parent 2f49032f
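In short, this commit changes CUDNNEmitter::build_reduce_forward to take the input and output element types together with a ReductionMode selector, so the same cuDNN reduction path can serve both value reductions and ArgMin/ArgMax. A minimal caller-side sketch, assuming cudnn_emitter, args, out, input_shape, and axes are in scope as they are at the call sites in the diff below:

    // Plain reduction (Max/Min/Sum/Product): values only, no indices.
    std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
    auto reduce_index = cudnn_emitter->build_reduce_forward(
        CUDNN_REDUCE_TENSOR_MAX, dtypes, input_shape, axes, CUDNNEmitter::ReductionMode::Reduce);

    // ArgMax/ArgMin: cuDNN is asked for flattened 32-bit indices; the reduced values go to a
    // scratch workspace buffer and only the index tensor is written to the op's output.
    auto argmax_index = cudnn_emitter->build_reduce_forward(
        CUDNN_REDUCE_TENSOR_MAX, dtypes, input_shape, axes, CUDNNEmitter::ReductionMode::ArgReduce);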
@@ -1783,8 +1783,9 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::Softmax* node)
     auto output_type = out[0].get_element_type().c_type_string();
     auto exp_index = build_elementwise<ngraph::op::Exp>({input_type, output_type}, input_shape);
+    std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
     auto reduce_index = cudnn_emitter->build_reduce_forward(
-        CUDNN_REDUCE_TENSOR_ADD, output_type, input_shape, axes);
+        CUDNN_REDUCE_TENSOR_ADD, dtypes, input_shape, axes, CUDNNEmitter::ReductionMode::Reduce);
     size_t divide_index = build_softmax_divide(
         std::vector<std::string>(3, output_type), input_shape, reduced_shape, axes_flag);
......
@@ -153,13 +153,22 @@ cudnnDataType_t runtime::gpu::CUDNNEmitter::get_cudnn_datatype(std::string dtype
     return p->second;
 }
 
+cudnnDataType_t runtime::gpu::CUDNNEmitter::get_cudnn_datatype(const element::Type& dtype)
+{
+    return get_cudnn_datatype(dtype.c_type_string());
+}
+
 size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorOp_t& reduce_op,
-                                                        const std::string& dtype,
+                                                        const std::vector<element::Type>& dtypes,
                                                         const Shape& input_shape,
-                                                        const AxisSet& reduction_axes)
+                                                        const AxisSet& reduction_axes,
+                                                        const ReductionMode& reduction_mode)
 {
+    auto input_type = dtypes[0];
+    auto output_type = dtypes[1];
     std::stringstream ss;
-    ss << "reduce_op_" << reduce_op << "_dtype_" << dtype << "_i" << join(input_shape, "_") << "_ra"
+    ss << "reduce_" << reduce_op << input_type.c_type_string() << "_reduction_mode_"
+       << static_cast<int>(reduction_mode) << "_i" << join(input_shape, "_") << "_ra"
        << join(reduction_axes, "_");
     std::string hash = ss.str();
@@ -171,7 +180,7 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
     }
 
     auto& desc = m_descriptors.build<cudnnReduceTensorDescriptor_t>();
-    cudnnDataType_t data_type = get_cudnn_datatype(dtype);
+    cudnnDataType_t data_type = get_cudnn_datatype(input_type);
     cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
     auto& input_desc = tensor_descriptor_from_shape(input_shape, data_type, tensor_format);
     Shape output_shape = input_shape;
@@ -188,19 +197,24 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
     CUDNN_SAFE_CALL(cudnnGetReductionWorkspaceSize(
         *m_ctx->cudnn_handle, desc, input_desc, output_desc, &workspace_size));
     size_t workspace_idx = allocator.reserve_workspace(workspace_size);
     void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
     void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
-    // emit reduce operation
-    std::unique_ptr<gpu::primitive> reduce(
-        new gpu::primitive{[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
+    std::unique_ptr<gpu::primitive> reduce;
+    switch (reduction_mode)
+    {
+    case ReductionMode::Reduce:
+    {
         CUDNN_SAFE_CALL(cudnnSetReduceTensorDescriptor(desc,
                                                        reduce_op,
                                                        data_type,
                                                        CUDNN_NOT_PROPAGATE_NAN,
                                                        CUDNN_REDUCE_TENSOR_NO_INDICES,
                                                        CUDNN_32BIT_INDICES));
+        // emit reduce operation
+        reduce.reset(new gpu::primitive{
+            [=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
                 void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
                 CUDNN_SAFE_CALL(cudnnReduceTensor(*m_ctx->cudnn_handle,
                                                   desc,
@@ -217,6 +231,53 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
                 debug_sync();
             }});
+        break;
+    }
+    case ReductionMode::ArgReduce:
+    {
+        // TODO: Issue #1782
+        if (output_type != element::i32)
+        {
+            std::stringstream ss_er;
+            ss_er
+                << "Unsupported Type: Only uint32 currently supported for indices in op ArgReduce ";
+            throw std::invalid_argument(ss_er.str());
+        }
+        size_t indices_size = shape_size(output_shape) * output_type.size();
+        size_t reduce_buffer_idx =
+            allocator.reserve_workspace(shape_size(output_shape) * input_type.size());
+        CUDNN_SAFE_CALL(cudnnSetReduceTensorDescriptor(desc,
+                                                       reduce_op,
+                                                       data_type,
+                                                       CUDNN_NOT_PROPAGATE_NAN,
+                                                       CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
+                                                       CUDNN_32BIT_INDICES));
+        reduce.reset(new gpu::primitive{[=, &desc, &input_desc, &output_desc](void** inputs,
+                                                                              void** outputs) {
+            void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
+            void* reduce_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
+            CUDNN_SAFE_CALL(cudnnReduceTensor(*m_ctx->cudnn_handle,
+                                              desc,
+                                              outputs[0],
+                                              indices_size,
+                                              workspace_ptr,
+                                              workspace_size,
+                                              alpha,
+                                              input_desc,
+                                              inputs[0],
+                                              beta,
+                                              output_desc,
+                                              reduce_buffer));
+            debug_sync();
+        }});
+        break;
+    }
+    }
     return this->m_primitive_emitter->register_primitive(reduce, hash);
 }
@@ -822,10 +883,10 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Max* node)
     auto input_size = shape_size(input_shape);
     auto output_size = shape_size(output_shape);
     auto output_element_size = out[0].get_element_type().size();
-    auto output_type = out[0].get_element_type().c_type_string();
+    auto output_type = out[0].get_element_type();
     std::stringstream ss;
-    ss << "max_" << output_type << "_i" << join(input_shape, "_") << "_ra"
+    ss << "max_" << output_type.c_type_string() << "_i" << join(input_shape, "_") << "_ra"
        << join(node->get_reduction_axes(), "_");
     std::string hash = ss.str();
@@ -861,9 +922,13 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Max* node)
     }
     else
     {
+        std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
         auto& cudnn_emitter = m_primitive_emitter->get_cudnn_emitter();
-        auto max_index = cudnn_emitter->build_reduce_forward(
-            CUDNN_REDUCE_TENSOR_MAX, output_type, input_shape, node->get_reduction_axes());
+        auto max_index = cudnn_emitter->build_reduce_forward(CUDNN_REDUCE_TENSOR_MAX,
+                                                             dtypes,
+                                                             input_shape,
+                                                             node->get_reduction_axes(),
+                                                             ReductionMode::Reduce);
         kernel_launch.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
             gpu::invoke_primitive(m_ctx, max_index, inputs, outputs);
         }});
@@ -881,10 +946,10 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
     auto input_size = shape_size(input_shape);
     auto output_size = shape_size(output_shape);
     auto output_element_size = out[0].get_element_type().size();
-    auto output_type = out[0].get_element_type().c_type_string();
+    auto output_type = out[0].get_element_type();
     std::stringstream ss;
-    ss << "min_" << output_type << "_i" << join(input_shape, "_") << "_ra"
+    ss << "min_" << output_type.c_type_string() << "_i" << join(input_shape, "_") << "_ra"
        << join(node->get_reduction_axes(), "_");
     std::string hash = ss.str();
@@ -920,9 +985,13 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
     }
     else
     {
+        std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
         auto& cudnn_emitter = m_primitive_emitter->get_cudnn_emitter();
-        auto min_index = cudnn_emitter->build_reduce_forward(
-            CUDNN_REDUCE_TENSOR_MIN, output_type, input_shape, node->get_reduction_axes());
+        auto min_index = cudnn_emitter->build_reduce_forward(CUDNN_REDUCE_TENSOR_MIN,
+                                                             dtypes,
+                                                             input_shape,
+                                                             node->get_reduction_axes(),
+                                                             ReductionMode::Reduce);
         kernel_launch.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
             gpu::invoke_primitive(m_ctx, min_index, inputs, outputs);
         }});
......
@@ -72,6 +72,12 @@ namespace ngraph
                    Backward
                };
 
+                enum class ReductionMode
+                {
+                    Reduce,
+                    ArgReduce
+                };
+
                enum class algo_search
                {
                    HEURISTIC,
@@ -109,9 +115,10 @@ namespace ngraph
                                      const algo_search find_algo = algo_search::NONE);
 
                size_t build_reduce_forward(const cudnnReduceTensorOp_t& reduce_op,
-                                            const std::string& dtype,
+                                            const std::vector<element::Type>& dtypes,
                                             const Shape& input_shape,
-                                            const AxisSet& reduction_axes);
+                                            const AxisSet& reduction_axes,
+                                            const ReductionMode& reduction_mode);
 
                size_t build_tensor_op(const cudnnOpTensorOp_t& tensor_op,
                                       const std::string& dtype,
@@ -163,6 +170,7 @@ namespace ngraph
                void* get_data_by_type(cudnnDataType_t data_type, double value);
 
                cudnnDataType_t get_cudnn_datatype(std::string dtype);
+                cudnnDataType_t get_cudnn_datatype(const element::Type& dtype);
 
                cudnnTensorDescriptor_t&
                    tensor_descriptor_from_shape(const Shape& shape,
......
@@ -164,12 +164,45 @@ void runtime::gpu::GPU_Emitter::emit_And(EMIT_ARGS)
 
 void runtime::gpu::GPU_Emitter::emit_ArgMax(EMIT_ARGS)
 {
-    throw unsupported_op("Unsupported op '" + node->description() + "'");
+    cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MAX;
+    runtime::gpu::GPU_Emitter::emit_ArgReduce(
+        external_function, writer, node, args, out, reduce_op);
 }
 
 void runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS)
 {
-    throw unsupported_op("Unsupported op '" + node->description() + "'");
+    cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MIN;
+    runtime::gpu::GPU_Emitter::emit_ArgReduce(
+        external_function, writer, node, args, out, reduce_op);
+}
+
+void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_mode)
+{
+    if (out[0].get_size() == 0)
+    {
+        return;
+    }
+
+    auto argmax = static_cast<const ngraph::op::ArgMax*>(node);
+    std::vector<size_t> axes{argmax->get_reduction_axis()};
+    auto axis_set = AxisSet(axes);
+
+    std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
+
+    writer.block_begin();
+    {
+        auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
+        auto index = cudnn_emitter->build_reduce_forward(reduce_mode,
+                                                         dtypes,
+                                                         args[0].get_shape(),
+                                                         axis_set,
+                                                         CUDNNEmitter::ReductionMode::ArgReduce);
+
+        writer << "void* input[] = {" << node_names(args) << "};\n";
+        writer << "void* output[] = {" << node_names(out) << "};\n";
+        writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
+    }
+    writer.block_end();
 }
 
 void runtime::gpu::GPU_Emitter::emit_Asin(EMIT_ARGS)
@@ -856,6 +889,7 @@ void runtime::gpu::GPU_Emitter::emit_Power(EMIT_ARGS)
 void runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS)
 {
     const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node);
+
     writer.block_begin();
     {
         if (out[0].get_size() != 0)
@@ -877,12 +911,16 @@ void runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS)
            // descriptors for tensors with <= 4 dimensions
            else
            {
+                std::vector<element::Type> dtypes{args[0].get_element_type(),
+                                                  out[0].get_element_type()};
                auto& cudnn_emitter =
                    external_function->get_primitive_emitter()->get_cudnn_emitter();
-                auto index = cudnn_emitter->build_reduce_forward(CUDNN_REDUCE_TENSOR_MUL,
-                                                                 out[0].get_type(),
-                                                                 args[0].get_shape(),
-                                                                 product->get_reduction_axes());
+                auto index =
+                    cudnn_emitter->build_reduce_forward(CUDNN_REDUCE_TENSOR_MUL,
+                                                        dtypes,
+                                                        args[0].get_shape(),
+                                                        product->get_reduction_axes(),
+                                                        CUDNNEmitter::ReductionMode::Reduce);
 
                writer << "void* input[] = {" << node_names(args) << "};\n";
                writer << "void* output[] = {" << node_names(out) << "};\n";
@@ -971,14 +1009,16 @@ void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS)
                    reduce_tensor_op = f_ptr->second;
                }
            }
+            std::vector<element::Type> dtypes{args[0].get_element_type(),
+                                              out[0].get_element_type()};
            auto& cudnn_emitter =
                external_function->get_primitive_emitter()->get_cudnn_emitter();
            auto reduce_index =
                cudnn_emitter->build_reduce_forward(reduce_tensor_op,
-                                                    out[0].get_type(),
+                                                    dtypes,
                                                     args[0].get_shape(),
-                                                    reduce_op->get_reduction_axes());
+                                                    reduce_op->get_reduction_axes(),
+                                                    CUDNNEmitter::ReductionMode::Reduce);
 
            writer << "void* input[] = {" << node_names(args) << "};\n";
            writer << "void* output[] = {" << node_names(out) << "};\n";
......
@@ -75,6 +75,8 @@ namespace ngraph
                    writer.block_end();
                }
 
+                static void emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t);
+
            private:
                /// \brief Create a list of node names for each arg in args
                /// \param args list of tensor arguments
......
@@ -31,8 +31,6 @@ backwards_avgpool_n1_c1_hw4x4
 backwards_avgpool_n2_c2_hw4x4
 max_pool_3d
 avg_pool_3d
-argmin_trivial
-argmax_trivial
 topk_1d_max_all
 topk_1d_max_partial
 topk_1d_max_one
......
@@ -9483,9 +9483,51 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
     EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
 }
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int>(result));
}
 NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
 {
-    Shape shape{4, 3};
+    Shape shape{4, 3}; // HW -> (0,1)
     Shape rshape{3};
     auto A = make_shared<op::Parameter>(element::f32, shape);
     auto f =
@@ -9502,6 +9544,168 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
     EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
 }
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{4, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 1}, //r0
{0, 0}, //r1
{1, 0}, //r2
{2, 1}}) //r3
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_1) // Along Height
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{1, 1}, //
{1, 3}, //
{3, 3}})
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_2) // Along Width
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 2, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 0, 1, 1}, //
{1, 0, 0, 1}, //
{0, 0, 0, 0}}) //
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0, 1, 0, 2, 1}, // img 0 ch 0
{0, 3, 2, 0, 0},
{2, 0, 0, 0, 1},
{2, 0, 1, 1, 2},
{0, 2, 1, 0, 0}},
{{0, 0, 0, 2, 0}, // img 0 ch 1
{0, 2, 3, 0, 1},
{2, 0, 1, 0, 2},
{3, 1, 0, 0, 0},
{2, 0, 0, 0, 0}}},
{{{0, 2, 1, 1, 0}, // img 1 ch 0
{0, 0, 2, 0, 1},
{0, 0, 1, 2, 3},
{2, 0, 0, 3, 0},
{0, 0, 0, 0, 0}},
{{2, 1, 0, 0, 1}, // img 1 ch 1
{0, 2, 0, 0, 0},
{1, 1, 2, 0, 2},
{1, 1, 1, 0, 1},
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{3, 1, 0, 0, 1}, {3, 2, 0, 0, 0}}, //ch0
{{1, 2, 4, 3, 0}, {0, 1, 2, 0, 4}}}) //ch1
.get_vector()),
read_vector<int>(result));
}
 NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all)
 {
     Shape shape{6};
......