Commit d734d6d6 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Adam Straw

TopK: return the lower index for equivalent values (#2710)

TopK: return the lower index for equivalent values
parent 0890cc86
......@@ -30,6 +30,7 @@ topk_1d_min_one
topk_1d_min_partial
topk_2d_max_all
topk_2d_max_one
topk_2d_max_one_with_equal_values
topk_2d_max_partial
topk_2d_min_all
topk_2d_min_one
......
......@@ -33,6 +33,16 @@ namespace ngraph
template <typename T, typename U>
static bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
// this is intentional to be able to compare floats directly
// without using relative or absolute tolerance
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
if (std::get<0>(a) == std::get<0>(b))
{
return std::get<1>(a) < std::get<1>(b);
}
#pragma clang diagnostic pop
return a > b;
}
template <typename T, typename U>
......
......@@ -522,6 +522,32 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_one)
(vector<float>{12, 11, 10}), read_vector<float>(result1), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_one_with_equal_values)
{
Shape shape{2, 4};
Shape rshape{2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 1, true);
auto f0 = make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), ParameterVector{A});
auto f1 = make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 3, 2, 4, 1, 3, 3, 2});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
auto h0 = backend->compile(f0);
h0->call_with_validate({result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 1}), read_vector<int32_t>(result0));
auto h1 = backend->compile(f1);
h1->call_with_validate({result1}, {a});
EXPECT_TRUE(test::all_close_f(
(vector<float>{4, 3}), read_vector<float>(result1), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_all)
{
Shape shape{4, 3};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment