Commit 78c145e6 authored by Robert Kimball, committed by omarkanawi

Performance optimize reference TopK op (#3552)

* wip

* r50 test

* update test

* R50 topk calls

* Faster topk

* Finish topk implementation

* cleanup

* checkpoint

* new test working

* more unit test

* style

* wip

* fix tests

* add needed header

* Add new TopK test to plaid manifest

* Fix Windows build error
parent 429776d2
......@@ -54,13 +54,13 @@ namespace ngraph
/// supported
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
/// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k = 0,
bool compute_max = true,
SortType sort = SortType::NONE);
SortType sort = SortType::SORT_VALUES);
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
......@@ -69,13 +69,13 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are
/// supported
/// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
/// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max = true,
SortType sort = SortType::NONE);
SortType sort = SortType::SORT_VALUES);
void validate_and_infer_types() override;
......
......@@ -52,6 +52,7 @@ namespace ngraph
auto out_shape = out[0].get_shape();
auto k = topk->get_k();
auto compute_max = topk->get_compute_max();
auto sort = topk->get_sort();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
......@@ -64,6 +65,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -76,7 +78,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
else
......@@ -87,6 +90,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -99,7 +103,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
}
......@@ -113,6 +118,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -125,7 +131,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
else
......@@ -136,6 +143,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -148,7 +156,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
}
......@@ -162,6 +171,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -174,7 +184,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
else
......@@ -185,6 +196,7 @@ namespace ngraph
axis,
k,
compute_max,
sort,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
......@@ -197,7 +209,8 @@ namespace ngraph
out_shape,
axis,
k,
compute_max);
compute_max,
sort);
};
}
}
......
......@@ -1734,7 +1734,8 @@ private:
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
topk->get_compute_max(),
topk->get_sort());
}
else if (node.get_output_element_type(0) == element::i32)
{
......@@ -1745,7 +1746,8 @@ private:
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
topk->get_compute_max(),
topk->get_sort());
}
else
{
......
......@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
topk_1d_i32_max_all # No plans to implement TopK
topk_resnet50 # No plans to implement TopK
topk_max_sort_none # No plans to implement TopK
topk_min_sort_none # No plans to implement TopK
topk_max_sort_value # No plans to implement TopK
topk_min_sort_value # No plans to implement TopK
topk_max_sort_index # No plans to implement TopK
topk_min_sort_index # No plans to implement TopK
topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK
......
......@@ -21,6 +21,7 @@
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/topk.hpp"
namespace ngraph
{
......@@ -46,15 +47,28 @@ namespace ngraph
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
return a > b;
}
/// \brief Strict-weak-ordering predicate that orders tuples smallest first.
///
/// Delegates to std::tuple's lexicographic operator<: element 0 is compared
/// first, and element 1 only decides the result when element 0 ties.
///
/// \param a Left-hand tuple
/// \param b Right-hand tuple
/// \return true when \p a orders strictly before \p b
template <typename T, typename U>
inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
    const bool a_orders_first = (a < b);
    return a_orders_first;
}
/// \brief Predicate that compares two tuples by their second element only.
///
/// Element 0 is ignored. The result is true when the second element of
/// \p a is less than the second element of \p b.
///
/// NOTE(review): despite "descending" in the name, this is a less-than
/// predicate, so std::sort with it arranges the second elements in
/// increasing order. Confirm with the callers whether the name or the
/// ordering is the intended one before changing either.
///
/// \param a Left-hand tuple
/// \param b Right-hand tuple
/// \return true when std::get<1>(a) < std::get<1>(b)
template <typename T, typename U>
inline bool sort_indices_descending(const std::tuple<T, U>& a,
                                    const std::tuple<T, U>& b)
{
    const U& index_of_a = std::get<1>(a);
    const U& index_of_b = std::get<1>(b);
    return index_of_a < index_of_b;
}
/// \brief Predicate that compares two tuples by their second element only.
///
/// Element 0 is ignored. The result is true when the second element of
/// \p a is greater than the second element of \p b.
///
/// NOTE(review): despite "ascending" in the name, this is a greater-than
/// predicate, so std::sort with it arranges the second elements in
/// decreasing order. Confirm with the callers whether the name or the
/// ordering is the intended one before changing either.
///
/// \param a Left-hand tuple
/// \param b Right-hand tuple
/// \return true when std::get<1>(a) > std::get<1>(b)
template <typename T, typename U>
inline bool sort_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
    const U& index_of_a = std::get<1>(a);
    const U& index_of_b = std::get<1>(b);
    return index_of_a > index_of_b;
}
template <typename T, typename U>
void topk(const T* arg,
U* out_indices,
......@@ -63,7 +77,8 @@ namespace ngraph
const Shape& out_shape,
size_t axis,
size_t k,
bool compute_max)
bool compute_max,
op::TopK::SortType sort = op::TopK::SortType::NONE)
{
using namespace std;
// reorder source axis visit order and make "axis" inner most
......@@ -103,13 +118,49 @@ namespace ngraph
// Sort the temp vector
if (compute_max)
{
sort(workspace.begin(), workspace.end(), compare_max<T, U>);
nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_max<T, U>);
}
else
{
sort(workspace.begin(), workspace.end(), compare_min<T, U>);
nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_min<T, U>);
}
// Write temp vector to output
if (compute_max)
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_descending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
break;
}
}
else
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_ascending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
break;
}
}
for (size_t j = 0; j < k; j++)
{
tuple<T, U> entry = workspace[j];
......
This diff is collapsed.
......@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE);
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::SORT_VALUES);
}
TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment