Commit 78c145e6 authored by Robert Kimball's avatar Robert Kimball Committed by omarkanawi

Performance optimize reference TopK op (#3552)

* wip

* r50 test

* update test

* R50 topk calls

* Faster topk

* Finish topk implementation

* cleanup

* checkpoint

* new test working

* more unit test

* style

* wip

* fix tests

* add needed header

* Add new TopK test to plaid manifest

* Fix windows build error:
parent 429776d2
...@@ -54,13 +54,13 @@ namespace ngraph ...@@ -54,13 +54,13 @@ namespace ngraph
/// supported /// supported
/// \param k Number of top indices to compute. Compute all indices if k = 0 /// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE /// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k = 0, size_t k = 0,
bool compute_max = true, bool compute_max = true,
SortType sort = SortType::NONE); SortType sort = SortType::SORT_VALUES);
/// \brief Constructs a TopK operation. /// \brief Constructs a TopK operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
...@@ -69,13 +69,13 @@ namespace ngraph ...@@ -69,13 +69,13 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are /// \param index_element_type produce indices. Currently, only int64 or int32 are
/// supported /// supported
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE /// \param sort SortType for sorting results, default - SORT_VALUES
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
const Output<Node>& k, const Output<Node>& k,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
bool compute_max = true, bool compute_max = true,
SortType sort = SortType::NONE); SortType sort = SortType::SORT_VALUES);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -52,6 +52,7 @@ namespace ngraph ...@@ -52,6 +52,7 @@ namespace ngraph
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto k = topk->get_k(); auto k = topk->get_k();
auto compute_max = topk->get_compute_max(); auto compute_max = topk->get_compute_max();
auto sort = topk->get_sort();
auto element_type = args[0].get_element_type(); auto element_type = args[0].get_element_type();
if (element_type == element::f32) if (element_type == element::f32)
...@@ -64,6 +65,7 @@ namespace ngraph ...@@ -64,6 +65,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -76,7 +78,8 @@ namespace ngraph ...@@ -76,7 +78,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -87,6 +90,7 @@ namespace ngraph ...@@ -87,6 +90,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -99,7 +103,8 @@ namespace ngraph ...@@ -99,7 +103,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
...@@ -113,6 +118,7 @@ namespace ngraph ...@@ -113,6 +118,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -125,7 +131,8 @@ namespace ngraph ...@@ -125,7 +131,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -136,6 +143,7 @@ namespace ngraph ...@@ -136,6 +143,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -148,7 +156,8 @@ namespace ngraph ...@@ -148,7 +156,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
...@@ -162,6 +171,7 @@ namespace ngraph ...@@ -162,6 +171,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -174,7 +184,8 @@ namespace ngraph ...@@ -174,7 +184,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
else else
...@@ -185,6 +196,7 @@ namespace ngraph ...@@ -185,6 +196,7 @@ namespace ngraph
axis, axis,
k, k,
compute_max, compute_max,
sort,
arg_buffer_index, arg_buffer_index,
out_indices_buffer_index, out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx, out_values_buffer_index](CPURuntimeContext* ctx,
...@@ -197,7 +209,8 @@ namespace ngraph ...@@ -197,7 +209,8 @@ namespace ngraph
out_shape, out_shape,
axis, axis,
k, k,
compute_max); compute_max,
sort);
}; };
} }
} }
......
...@@ -1734,7 +1734,8 @@ private: ...@@ -1734,7 +1734,8 @@ private:
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
topk->get_k(), topk->get_k(),
topk->get_compute_max()); topk->get_compute_max(),
topk->get_sort());
} }
else if (node.get_output_element_type(0) == element::i32) else if (node.get_output_element_type(0) == element::i32)
{ {
...@@ -1745,7 +1746,8 @@ private: ...@@ -1745,7 +1746,8 @@ private:
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
topk->get_k(), topk->get_k(),
topk->get_compute_max()); topk->get_compute_max(),
topk->get_sort());
} }
else else
{ {
......
...@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK ...@@ -39,6 +39,13 @@ topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK topk_5d_max_partial # No plans to implement TopK
topk_1d_i32_max_all # No plans to implement TopK topk_1d_i32_max_all # No plans to implement TopK
topk_resnet50 # No plans to implement TopK
topk_max_sort_none # No plans to implement TopK
topk_min_sort_none # No plans to implement TopK
topk_max_sort_value # No plans to implement TopK
topk_min_sort_value # No plans to implement TopK
topk_max_sort_index # No plans to implement TopK
topk_min_sort_index # No plans to implement TopK
topk_2d_max_one_with_equal_values # No plans to implement TopK topk_2d_max_one_with_equal_values # No plans to implement TopK
model_top_k # No plans to implement TopK model_top_k # No plans to implement TopK
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <numeric> #include <numeric>
#include "ngraph/coordinate_transform.hpp" #include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/topk.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -46,15 +47,28 @@ namespace ngraph ...@@ -46,15 +47,28 @@ namespace ngraph
#if defined(__GNUC__) #if defined(__GNUC__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
return a > b; return a > b;
} }
template <typename T, typename U> template <typename T, typename U>
inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b) inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{ {
return a < b; return a < b;
} }
template <typename T, typename U>
inline bool sort_indices_descending(const std::tuple<T, U>& a,
const std::tuple<T, U>& b)
{
return std::get<1>(a) < std::get<1>(b);
}
template <typename T, typename U>
inline bool sort_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b)
{
return std::get<1>(a) > std::get<1>(b);
}
template <typename T, typename U> template <typename T, typename U>
void topk(const T* arg, void topk(const T* arg,
U* out_indices, U* out_indices,
...@@ -63,7 +77,8 @@ namespace ngraph ...@@ -63,7 +77,8 @@ namespace ngraph
const Shape& out_shape, const Shape& out_shape,
size_t axis, size_t axis,
size_t k, size_t k,
bool compute_max) bool compute_max,
op::TopK::SortType sort = op::TopK::SortType::NONE)
{ {
using namespace std; using namespace std;
// reorder source axis visit order and make "axis" inner most // reorder source axis visit order and make "axis" inner most
...@@ -103,13 +118,49 @@ namespace ngraph ...@@ -103,13 +118,49 @@ namespace ngraph
// Sort the temp vector // Sort the temp vector
if (compute_max) if (compute_max)
{ {
sort(workspace.begin(), workspace.end(), compare_max<T, U>); nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_max<T, U>);
} }
else else
{ {
sort(workspace.begin(), workspace.end(), compare_min<T, U>); nth_element(workspace.begin(),
workspace.begin() + k,
workspace.end(),
compare_min<T, U>);
} }
// Write temp vector to output // Write temp vector to output
if (compute_max)
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_descending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
break;
}
}
else
{
switch (sort)
{
case op::TopK::SortType::NONE: break;
case op::TopK::SortType::SORT_INDICES:
std::sort(workspace.begin(),
workspace.begin() + k,
sort_indices_ascending<T, U>);
break;
case op::TopK::SortType::SORT_VALUES:
std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
break;
}
}
for (size_t j = 0; j < k; j++) for (size_t j = 0; j < k; j++)
{ {
tuple<T, U> entry = workspace[j]; tuple<T, U> entry = workspace[j];
......
This diff is collapsed.
...@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok) ...@@ -117,7 +117,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32); ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE); ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::SORT_VALUES);
} }
TEST(type_prop, topk_rank_dynamic_result_et_dynamic) TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment