Commit 9f0589a8 authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

ArgReduce 64 bit indices (#1862)

* Update ArgReduce to handle i64 indices.

* Formatting.

* Add throw for output types other than int32/64.

* Add output type to hash.

* Add type to throw.

* Interpreter doesn't currently support 64bit output indices for argmin/max and so disabling this test [JIRA:NGRAPH-3183].
parent ccfcf4f9
......@@ -167,9 +167,9 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
auto input_type = dtypes[0];
auto output_type = dtypes[1];
std::stringstream ss;
ss << "reduce_" << reduce_op << input_type.c_type_string() << "_reduction_mode_"
<< static_cast<int>(reduction_mode) << "_i" << join(input_shape, "_") << "_ra"
<< join(reduction_axes, "_");
ss << "reduce_" << reduce_op << "_" << input_type.c_type_string() << "_"
<< output_type.c_type_string() << "_reduction_mode_" << static_cast<int>(reduction_mode)
<< "_i" << join(input_shape, "_") << "_ra" << join(reduction_axes, "_");
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
......@@ -236,44 +236,81 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
case ReductionMode::ArgReduce:
{
// TODO: Issue #1782
if (output_type != element::i32)
if (output_type == element::i32 || output_type == element::i64)
{
size_t indices_size = shape_size(output_shape) * output_type.size();
size_t reduce_buffer_idx =
allocator.reserve_workspace(shape_size(output_shape) * input_type.size());
CUDNN_SAFE_CALL(cudnnSetReduceTensorDescriptor(desc,
reduce_op,
data_type,
CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
CUDNN_32BIT_INDICES));
if (output_type == element::i64)
{
size_t workspace_indices_idx =
allocator.reserve_workspace(shape_size(output_shape) * input_type.size());
auto& cuda_emitter = m_primitive_emitter->get_cuda_emitter();
auto convert_idx = cuda_emitter->build_elementwise<op::Convert>(
{element::i32.c_type_string(), element::i64.c_type_string()}, output_shape);
reduce.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
void* workspace_indices_ptr =
runtime::gpu::invoke_memory_primitive(m_ctx, workspace_indices_idx);
void* workspace_ptr =
runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
void* reduce_buffer =
runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
CUDNN_SAFE_CALL(cudnnReduceTensor(*m_ctx->cudnn_handle,
desc,
workspace_indices_ptr,
indices_size,
workspace_ptr,
workspace_size,
alpha,
input_desc,
inputs[0],
beta,
output_desc,
reduce_buffer));
gpu::invoke_primitive(m_ctx, convert_idx, &workspace_indices_ptr, outputs);
debug_sync();
}});
}
else
{
reduce.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
void* workspace_ptr =
runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
void* reduce_buffer =
runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
CUDNN_SAFE_CALL(cudnnReduceTensor(*m_ctx->cudnn_handle,
desc,
outputs[0],
indices_size,
workspace_ptr,
workspace_size,
alpha,
input_desc,
inputs[0],
beta,
output_desc,
reduce_buffer));
debug_sync();
}});
}
}
else
{
std::stringstream ss_er;
ss_er
<< "Unsupported Type: Only uint32 currently supported for indices in op ArgReduce ";
ss_er << "Unsupported Type: " << output_type.c_type_string()
<< ". Only uint32 & uint64 currently supported for indices in op "
"ArgReduce";
throw std::invalid_argument(ss_er.str());
}
size_t indices_size = shape_size(output_shape) * output_type.size();
size_t reduce_buffer_idx =
allocator.reserve_workspace(shape_size(output_shape) * input_type.size());
CUDNN_SAFE_CALL(cudnnSetReduceTensorDescriptor(desc,
reduce_op,
data_type,
CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
CUDNN_32BIT_INDICES));
reduce.reset(new gpu::primitive{[=, &desc, &input_desc, &output_desc](void** inputs,
void** outputs) {
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
void* reduce_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
CUDNN_SAFE_CALL(cudnnReduceTensor(*m_ctx->cudnn_handle,
desc,
outputs[0],
indices_size,
workspace_ptr,
workspace_size,
alpha,
input_desc,
inputs[0],
beta,
output_desc,
reduce_buffer));
debug_sync();
}});
break;
}
}
......
argmin_4D_axis_3_i64
batchnorm_bprop_n4c3h2w2
batchnorm_fprop_b1c2h2w2
batchnorm_fprop_b2c2h2w1
......
......@@ -4713,6 +4713,47 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3)
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i64), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int64_t, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment