Commit 78c145e6 authored by Robert Kimball, committed by omarkanawi

Performance optimize reference TopK op (#3552)

* wip

* r50 test

* update test

* R50 topk calls

* Faster topk

* Finish topk implementation

* cleanup

* checkpoint

* new test working

* more unit test

* style

* wip

* fix tests

* add needed header

* Add new TopK test to plaid manifest

* Fix Windows build error
parent 429776d2
...@@ -54,13 +54,13 @@ namespace ngraph ...@@ -54,13 +54,13 @@ namespace ngraph
/// supported /// supported
/// \param k Number of top indices to compute. Compute all indices if k = 0 /// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE /// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k = 0, size_t k = 0,
bool compute_max = true, bool compute_max = true,
SortType sort = SortType::NONE); SortType sort = SortType::SORT_VALUES);
/// \brief Constructs a TopK operation. /// \brief Constructs a TopK operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
...@@ -69,13 +69,13 @@ namespace ngraph ...@@ -69,13 +69,13 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are /// \param index_element_type produce indices. Currently, only int64 or int32 are
/// supported /// supported
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE /// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
const Output<Node>& k, const Output<Node>& k,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
bool compute_max = true, bool compute_max = true,
SortType sort = SortType::NONE); SortType sort = SortType::SORT_VALUES);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -52,6 +52,7 @@ namespace ngraph ...@@ -52,6 +52,7 @@ namespace ngraph
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto k = topk->get_k(); auto k = topk->get_k();
auto compute_max = topk->get_compute_max(); auto compute_max = topk->get_compute_max();
auto sort = topk->get_sort();
auto element_type = args[0].get_element_type(); auto element_type = args[0].get_element_type();
if (element_type == element::f32) if (element_type == element::f32)
...@@ -64,6 +65,7 @@ namespace ngraph ...@@ -64,6 +65,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -76,7 +78,8 @@ namespace ngraph ...@@ -76,7 +78,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -87,6 +90,7 @@ namespace ngraph ...@@ -87,6 +90,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -99,7 +103,8 @@ namespace ngraph ...@@ -99,7 +103,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
...@@ -113,6 +118,7 @@ namespace ngraph ...@@ -113,6 +118,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -125,7 +131,8 @@ namespace ngraph ...@@ -125,7 +131,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -136,6 +143,7 @@ namespace ngraph ...@@ -136,6 +143,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -148,7 +156,8 @@ namespace ngraph ...@@ -148,7 +156,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
...@@ -162,6 +171,7 @@ namespace ngraph ...@@ -162,6 +171,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -174,7 +184,8 @@ namespace ngraph ...@@ -174,7 +184,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -185,6 +196,7 @@ namespace ngraph ...@@ -185,6 +196,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -197,7 +209,8 @@ namespace ngraph ...@@ -197,7 +209,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
......
...@@ -1734,7 +1734,8 @@ private: ...@@ -1734,7 +1734,8 @@ private:
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
topk->get_k(), topk->get_k(),
topk->get_compute_max()); topk->get_compute_max(),
topk->get_sort());
} }
else if (node.get_output_element_type(0) == element::i32) else if (node.get_output_element_type(0) == element::i32)
{ {
...@@ -1745,7 +1746,8 @@ private: ...@@ -1745,7 +1746,8 @@ private:
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
topk->get_k(), topk->get_k(),
topk->get_compute_max()); topk->get_compute_max(),
topk->get_sort());
} }
else else
{ {
......
...@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK ...@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK topk_5d_max_partial # No plans to implement TopK
topk_1d_i32_max_all # No plans to implement TopK topk_1d_i32_max_all # No plans to implement TopK
topk_resnet50 # No plans to implement TopK
topk_max_sort_none # No plans to implement TopK
topk_min_sort_none # No plans to implement TopK
topk_max_sort_value # No plans to implement TopK
topk_min_sort_value # No plans to implement TopK
topk_max_sort_index # No plans to implement TopK
topk_min_sort_index # No plans to implement TopK
topk_2d_max_one_with_equal_values # No plans to implement TopK topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK model_top_k # No plans to implement TopK
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <numeric> #include <numeric>
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/topk.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -46,15 +47,28 @@ namespace ngraph ...@@ -46,15 +47,28 @@ namespace ngraph
#if defined(__GNUC__) #if defined(__GNUC__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
return a > b; return a > b;
} }
template <typename T, typename U> template <typename T, typename U>
inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b) inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{ {
return a < b; return a < b;
} }
/// \brief Comparator over (value, index) tuples that inspects only the index element.
///
/// Returns true when the index stored in \p a is strictly smaller than the index stored
/// in \p b. Used as a std::sort predicate it therefore arranges entries by increasing
/// index, regardless of what the function name suggests.
///
/// \param a Left-hand tuple; element 1 is the index.
/// \param b Right-hand tuple; element 1 is the index.
/// \return true if index(a) < index(b).
template <typename T, typename U>
inline bool sort_indices_descending(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
    const U& lhs_index = std::get<1>(a);
    const U& rhs_index = std::get<1>(b);
    return lhs_index < rhs_index;
}
/// \brief Comparator over (value, index) tuples that inspects only the index element.
///
/// Returns true when the index stored in \p a is strictly greater than the index stored
/// in \p b. Used as a std::sort predicate it therefore arranges entries by decreasing
/// index, regardless of what the function name suggests.
///
/// \param a Left-hand tuple; element 1 is the index.
/// \param b Right-hand tuple; element 1 is the index.
/// \return true if index(a) > index(b).
template <typename T, typename U>
inline bool sort_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
    const U& lhs_index = std::get<1>(a);
    const U& rhs_index = std::get<1>(b);
    return lhs_index > rhs_index;
}
template <typename T, typename U> template <typename T, typename U>
void topk(const T* arg, void topk(const T* arg,
U* out_indices, U* out_indices,
...@@ -63,7 +77,8 @@ namespace ngraph ...@@ -63,7 +77,8 @@ namespace ngraph
const Shape& out_shape, const Shape& out_shape,
size_t axis, size_t axis,
size_t k, size_t k,
bool compute_max) bool compute_max,
op::TopK::SortType sort = op::TopK::SortType::NONE)
{ {
using namespace std; using namespace std;
// reorder source axis visit order and make "axis" inner most // reorder source axis visit order and make "axis" inner most
...@@ -103,13 +118,49 @@ namespace ngraph ...@@ -103,13 +118,49 @@ namespace ngraph
// Sort the temp vector // Sort the temp vector
if (compute_max) if (compute_max)
{ {
sort(workspace.begin(), workspace.end(), compare_max<T, U>); nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_max<T, U>);
} }
else else
{ {
sort(workspace.begin(), workspace.end(), compare_min<T, U>); nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_min<T, U>);
} }
// Write temp vector to output // Write temp vector to output
if (compute_max)
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_descending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
break;
}
}
else
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_ascending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
break;
}
}
for (size_t j = 0; j < k; j++) for (size_t j = 0; j < k; j++)
{ {
tuple<T, U> entry = workspace[j]; tuple<T, U> entry = workspace[j];
......
...@@ -18,15 +18,16 @@ ...@@ -18,15 +18,16 @@
#include <cinttypes> #include <cinttypes>
#include <cmath> #include <cmath>
#include <cstdlib> #include <cstdlib>
#include <numeric>
#include <random> #include <random>
#include <string> #include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/op/get_output_element.hpp"
#include "util/all_close.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/topk.hpp"
#include "util/all_close_f.hpp" #include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -35,6 +36,396 @@ using namespace ngraph; ...@@ -35,6 +36,396 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}"; static string s_manifest = "${MANIFEST}";
/// \brief Checks whether two vectors hold the same elements, ignoring their order.
///
/// Duplicates are significant: every element of \p a has to be matched by a distinct,
/// equal element of \p b, and \p b must not contain anything beyond those matches. A pure
/// containment check would accept a \p b that carries extra elements, which is why the
/// sizes are compared explicitly.
///
/// \param a Expected elements.
/// \param b Actual elements (taken by value to keep the established signature).
/// \return true if \p a and \p b are permutations of each other, false otherwise.
template <typename T>
bool compare_set(const std::vector<T>& a, std::vector<T> b)
{
    // The three-iterator form of std::is_permutation assumes both ranges have the same
    // length, so the size comparison must come first.
    return a.size() == b.size() && std::is_permutation(a.begin(), a.end(), b.begin());
}
// Runs TopK (max) with k = 5 and with k = 1 along axis 1 of the same 128x1000 f32 input
// and checks, order-insensitively over each whole output tensor, that the largest values
// and their indices are returned.
NGRAPH_TEST(${BACKEND_NAME}, topk_resnet50)
{
    Shape shape{128, 1000};
    Shape rshape5{128, 5};
    Shape rshape1{128, 1};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto top5 = make_shared<op::TopK>(input, 1, element::i32, 5, true);
    auto top1 = make_shared<op::TopK>(input, 1, element::i32, 1, true);
    // Output element 1 is read back as f32 values, output element 0 as i32 indices.
    auto top5_values = make_shared<op::GetOutputElement>(top5, 1);
    auto top5_indices = make_shared<op::GetOutputElement>(top5, 0);
    auto top1_values = make_shared<op::GetOutputElement>(top1, 1);
    auto top1_indices = make_shared<op::GetOutputElement>(top1, 0);
    auto func = make_shared<Function>(
        NodeVector{top5_values, top5_indices, top1_values, top1_indices}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result5_value = backend->create_tensor(element::f32, rshape5);
    auto result5_index = backend->create_tensor(element::i32, rshape5);
    auto result1_value = backend->create_tensor(element::f32, rshape1);
    auto result1_index = backend->create_tensor(element::i32, rshape1);

    auto exec = backend->compile(func);
    exec->call({result5_value, result5_index, result1_value, result1_index}, {input_tensor});

    auto actual5_value = read_vector<float>(result5_value);
    auto actual5_index = read_vector<int32_t>(result5_index);
    auto actual1_value = read_vector<float>(result1_value);
    auto actual1_index = read_vector<int32_t>(result1_index);

    // For a ramp input the largest entries are the last columns of each row.
    vector<float> expected5_value;
    vector<int32_t> expected5_index;
    for (size_t row = 0; row < rshape5[0]; ++row)
    {
        for (size_t rank = 0; rank < rshape5[1]; ++rank)
        {
            const size_t column = shape[1] - rank - 1;
            expected5_value.push_back(static_cast<float>(column));
            expected5_index.push_back(static_cast<int32_t>(column));
        }
    }

    vector<float> expected1_value;
    vector<int32_t> expected1_index;
    for (size_t row = 0; row < rshape1[0]; ++row)
    {
        for (size_t rank = 0; rank < rshape1[1]; ++rank)
        {
            const size_t column = shape[1] - rank - 1;
            expected1_value.push_back(static_cast<float>(column));
            expected1_index.push_back(static_cast<int32_t>(column));
        }
    }

    EXPECT_TRUE(compare_set<float>(expected5_value, actual5_value));
    EXPECT_TRUE(compare_set<int32_t>(expected5_index, actual5_index));
    EXPECT_TRUE(compare_set<float>(expected1_value, actual1_value));
    EXPECT_TRUE(compare_set<int32_t>(expected1_index, actual1_index));
}
// TopK (max, k = 5, SortType::NONE) along axis 1 of a 128x1000 ramp input. Each output row
// is checked for membership only; the order of the entries inside a row is not verified.
NGRAPH_TEST(${BACKEND_NAME}, topk_max_sort_none)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto top5 = make_shared<op::TopK>(input, 1, element::i32, 5, true, op::TopK::SortType::NONE);
    auto values_out = make_shared<op::GetOutputElement>(top5, 1);
    auto indices_out = make_shared<op::GetOutputElement>(top5, 0);
    auto func = make_shared<Function>(NodeVector{values_out, indices_out}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(func);
    exec->call({result_value, result_index}, {input_tensor});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // The expectation is identical for every row: the last rshape[1] columns.
    const size_t k = rshape[1];
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t rank = 0; rank < k; ++rank)
    {
        expected_value.push_back(static_cast<float>(shape[1] - rank - 1));
        expected_index.push_back(static_cast<int32_t>(shape[1] - rank - 1));
    }

    for (size_t row = 0; row < rshape[0]; ++row)
    {
        const auto value_begin = actual_value.begin() + row * k;
        const auto index_begin = actual_index.begin() + row * k;
        vector<float> row_values(value_begin, value_begin + k);
        vector<int32_t> row_indices(index_begin, index_begin + k);
        EXPECT_TRUE(compare_set<float>(expected_value, row_values));
        EXPECT_TRUE(compare_set<int32_t>(expected_index, row_indices));
    }
}
// TopK (min, k = 5, SortType::NONE) along axis 1 of a 128x1000 ramp input. Each output row
// is checked for membership only; the order of the entries inside a row is not verified.
NGRAPH_TEST(${BACKEND_NAME}, topk_min_sort_none)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto bottom5 =
        make_shared<op::TopK>(input, 1, element::i32, 5, false, op::TopK::SortType::NONE);
    auto values_out = make_shared<op::GetOutputElement>(bottom5, 1);
    auto indices_out = make_shared<op::GetOutputElement>(bottom5, 0);
    auto func = make_shared<Function>(NodeVector{values_out, indices_out}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(func);
    exec->call({result_value, result_index}, {input_tensor});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // The expectation is identical for every row: the first rshape[1] columns.
    const size_t k = rshape[1];
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t column = 0; column < k; ++column)
    {
        expected_value.push_back(static_cast<float>(column));
        expected_index.push_back(static_cast<int32_t>(column));
    }

    for (size_t row = 0; row < rshape[0]; ++row)
    {
        const auto value_begin = actual_value.begin() + row * k;
        const auto index_begin = actual_index.begin() + row * k;
        vector<float> row_values(value_begin, value_begin + k);
        vector<int32_t> row_indices(index_begin, index_begin + k);
        EXPECT_TRUE(compare_set<float>(expected_value, row_values));
        EXPECT_TRUE(compare_set<int32_t>(expected_index, row_indices));
    }
}
// TopK (max, k = 5, SortType::SORT_VALUES) along axis 1 of a 128x1000 ramp input. The
// outputs are compared element by element, so the order within each row is verified too:
// largest value first.
NGRAPH_TEST(${BACKEND_NAME}, topk_max_sort_value)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto top5 =
        make_shared<op::TopK>(input, 1, element::i32, 5, true, op::TopK::SortType::SORT_VALUES);
    auto values_out = make_shared<op::GetOutputElement>(top5, 1);
    auto indices_out = make_shared<op::GetOutputElement>(top5, 0);
    auto func = make_shared<Function>(NodeVector{values_out, indices_out}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(func);
    exec->call({result_value, result_index}, {input_tensor});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // Build the expectation for one row, then repeat it for every row of the output.
    vector<float> row_value;
    vector<int32_t> row_index;
    for (size_t rank = 0; rank < rshape[1]; ++rank)
    {
        row_value.push_back(static_cast<float>(shape[1] - rank - 1));
        row_index.push_back(static_cast<int32_t>(shape[1] - rank - 1));
    }
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t row = 0; row < rshape[0]; ++row)
    {
        expected_value.insert(expected_value.end(), row_value.begin(), row_value.end());
        expected_index.insert(expected_index.end(), row_index.begin(), row_index.end());
    }

    EXPECT_TRUE(test::all_close_f(expected_value, actual_value));
    EXPECT_EQ(expected_index, actual_index);
}
// TopK (min, k = 5, SortType::SORT_VALUES) along axis 1 of a 128x1000 ramp input.
// The outputs are compared element by element so that the ordering requested through
// SORT_VALUES is verified as well: smallest value first. A per-row set comparison would
// accept any permutation and therefore could not detect a missing sort.
NGRAPH_TEST(${BACKEND_NAME}, topk_min_sort_value)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::TopK>(A, 1, element::i32, 5, false, op::TopK::SortType::SORT_VALUES);
    auto out_value = make_shared<op::GetOutputElement>(B, 1);
    auto out_index = make_shared<op::GetOutputElement>(B, 0);
    auto f = make_shared<Function>(NodeVector{out_value, out_index}, ParameterVector{A});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Create some tensors for input/output; every input row is the ramp 0 .. shape[1] - 1
    auto a = backend->create_tensor(element::f32, shape);
    vector<float> data;
    for (size_t i = 0; i < shape[0]; i++)
    {
        for (size_t j = 0; j < shape[1]; j++)
        {
            data.push_back(static_cast<float>(j));
        }
    }
    copy_data(a, data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(f);
    exec->call({result_value, result_index}, {a});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // The k smallest entries of a ramp are its first k columns, in increasing order.
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t i = 0; i < rshape[0]; i++)
    {
        for (size_t j = 0; j < rshape[1]; j++)
        {
            expected_value.push_back(static_cast<float>(j));
            expected_index.push_back(static_cast<int32_t>(j));
        }
    }

    EXPECT_TRUE(test::all_close_f(expected_value, actual_value));
    EXPECT_EQ(expected_index, actual_index);
}
// TopK (max, k = 5, SortType::SORT_INDICES) along axis 1 of a 128x1000 ramp input. Each
// output row is checked for membership only; the order of the entries inside a row is not
// verified.
NGRAPH_TEST(${BACKEND_NAME}, topk_max_sort_index)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto top5 =
        make_shared<op::TopK>(input, 1, element::i32, 5, true, op::TopK::SortType::SORT_INDICES);
    auto values_out = make_shared<op::GetOutputElement>(top5, 1);
    auto indices_out = make_shared<op::GetOutputElement>(top5, 0);
    auto func = make_shared<Function>(NodeVector{values_out, indices_out}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(func);
    exec->call({result_value, result_index}, {input_tensor});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // The expectation is identical for every row: the last rshape[1] columns.
    const size_t k = rshape[1];
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t rank = 0; rank < k; ++rank)
    {
        expected_value.push_back(static_cast<float>(shape[1] - rank - 1));
        expected_index.push_back(static_cast<int32_t>(shape[1] - rank - 1));
    }

    for (size_t row = 0; row < rshape[0]; ++row)
    {
        const auto value_begin = actual_value.begin() + row * k;
        const auto index_begin = actual_index.begin() + row * k;
        vector<float> row_values(value_begin, value_begin + k);
        vector<int32_t> row_indices(index_begin, index_begin + k);
        EXPECT_TRUE(compare_set<float>(expected_value, row_values));
        EXPECT_TRUE(compare_set<int32_t>(expected_index, row_indices));
    }
}
// TopK (min, k = 5, SortType::SORT_INDICES) along axis 1 of a 128x1000 ramp input. Each
// output row is checked for membership only; the order of the entries inside a row is not
// verified.
NGRAPH_TEST(${BACKEND_NAME}, topk_min_sort_index)
{
    Shape shape{128, 1000};
    Shape rshape{128, 5};

    auto input = make_shared<op::Parameter>(element::f32, shape);
    auto bottom5 =
        make_shared<op::TopK>(input, 1, element::i32, 5, false, op::TopK::SortType::SORT_INDICES);
    auto values_out = make_shared<op::GetOutputElement>(bottom5, 1);
    auto indices_out = make_shared<op::GetOutputElement>(bottom5, 0);
    auto func = make_shared<Function>(NodeVector{values_out, indices_out}, ParameterVector{input});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Every input row holds the ramp 0, 1, ..., shape[1] - 1.
    auto input_tensor = backend->create_tensor(element::f32, shape);
    vector<float> ramp(shape[1]);
    iota(ramp.begin(), ramp.end(), 0.0f);
    vector<float> input_data;
    for (size_t row = 0; row < shape[0]; ++row)
    {
        input_data.insert(input_data.end(), ramp.begin(), ramp.end());
    }
    copy_data(input_tensor, input_data);

    auto result_value = backend->create_tensor(element::f32, rshape);
    auto result_index = backend->create_tensor(element::i32, rshape);

    auto exec = backend->compile(func);
    exec->call({result_value, result_index}, {input_tensor});

    auto actual_value = read_vector<float>(result_value);
    auto actual_index = read_vector<int32_t>(result_index);

    // The expectation is identical for every row: the first rshape[1] columns.
    const size_t k = rshape[1];
    vector<float> expected_value;
    vector<int32_t> expected_index;
    for (size_t column = 0; column < k; ++column)
    {
        expected_value.push_back(static_cast<float>(column));
        expected_index.push_back(static_cast<int32_t>(column));
    }

    for (size_t row = 0; row < rshape[0]; ++row)
    {
        const auto value_begin = actual_value.begin() + row * k;
        const auto index_begin = actual_index.begin() + row * k;
        vector<float> row_values(value_begin, value_begin + k);
        vector<int32_t> row_indices(index_begin, index_begin + k);
        EXPECT_TRUE(compare_set<float>(expected_value, row_values));
        EXPECT_TRUE(compare_set<int32_t>(expected_index, row_indices));
    }
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all) NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all)
{ {
Shape shape{6}; Shape shape{6};
......
...@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok) ...@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32); ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE); ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::SORT_VALUES);
} }
TEST(type_prop, topk_rank_dynamic_result_et_dynamic) TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment