Commit 892c088e authored by peng xiao's avatar peng xiao

Some modifications to sortByKey API.

Add documentation.
parent a6d55804
...@@ -481,4 +481,40 @@ Performs generalized matrix multiplication. ...@@ -481,4 +481,40 @@ Performs generalized matrix multiplication.
* **GEMM_1_T** transpose ``src1`` * **GEMM_1_T** transpose ``src1``
* **GEMM_2_T** transpose ``src2`` * **GEMM_2_T** transpose ``src2``
.. seealso:: :ocv:func:`gemm` .. seealso:: :ocv:func:`gemm`
\ No newline at end of file
ocl::sortByKey
------------------
Returns void
.. ocv:function:: void ocl::transpose(oclMat& keys, oclMat& values, int method, bool isGreaterThan = false)
:param keys: The keys to be used as sorting indices.
:param values: The array of values.
:param isGreaterThan: Determine sorting order.
:param method: supported sorting methods:
* **SORT_BITONIC** bitonic sort, only support power-of-2 buffer size
* **SORT_SELECTION** selection sort, currently cannot sort duplicate keys
* **SORT_MERGE** merge sort
* **SORT_RADIX** radix sort, only support signed int/float keys(``CV_32S``/``CV_32F``)
Returns the sorted result of all the elements in values based on equivalent keys.
The element unit in the values to be sorted is determined from the data type,
i.e., a ``CV_32FC2`` input ``{a1a2, b1b2}`` will be considered as two elements, regardless its matrix dimension.
Both keys and values will be sorted inplace.
Keys needs to be a **single** channel `oclMat`.
Example::
input -
keys = {2, 3, 1} (CV_8UC1)
values = {10,5, 4,3, 6,2} (CV_8UC2)
sortByKey(keys, values, SORT_SELECTION, false);
output -
keys = {1, 2, 3} (CV_8UC1)
values = {6,2, 10,5, 4,3} (CV_8UC2)
...@@ -1679,7 +1679,7 @@ namespace cv ...@@ -1679,7 +1679,7 @@ namespace cv
SORT_BITONIC, // only support power-of-2 buffer size SORT_BITONIC, // only support power-of-2 buffer size
SORT_SELECTION, // cannot sort duplicate keys SORT_SELECTION, // cannot sort duplicate keys
SORT_MERGE, SORT_MERGE,
SORT_RADIX // only support signed int/float keys SORT_RADIX // only support signed int/float keys(CV_32S/CV_32F)
}; };
//! Returns the sorted result of all the elements in input based on equivalent keys. //! Returns the sorted result of all the elements in input based on equivalent keys.
// //
...@@ -1688,18 +1688,16 @@ namespace cv ...@@ -1688,18 +1688,16 @@ namespace cv
// matrix dimension. // matrix dimension.
// both keys and values will be sorted inplace // both keys and values will be sorted inplace
// Key needs to be single channel oclMat. // Key needs to be single channel oclMat.
// TODO(pengx): add supported types for values
// //
// Example: // Example:
// input - // input -
// keys = {2, 3, 1} (CV_8UC1) // keys = {2, 3, 1} (CV_8UC1)
// values = {10,5, 4,3, 6,2} (CV_8UC2) // values = {10,5, 4,3, 6,2} (CV_8UC2)
// sort_by_key(keys, values, SORT_SELECTION, false); // sortByKey(keys, values, SORT_SELECTION, false);
// output - // output -
// keys = {1, 2, 3} (CV_8UC1) // keys = {1, 2, 3} (CV_8UC1)
// values = {6,2, 10,5, 4,3} (CV_8UC2) // values = {6,2, 10,5, 4,3} (CV_8UC2)
void CV_EXPORTS sort_by_key(oclMat& keys, oclMat& values, int method, bool isGreaterThan = false); void CV_EXPORTS sortByKey(oclMat& keys, oclMat& values, int method, bool isGreaterThan = false);
void CV_EXPORTS sort_by_key(oclMat& keys, oclMat& values, size_t vecSize, int method, bool isGreaterThan = false);
} }
} }
#if defined _MSC_VER && _MSC_VER >= 1200 #if defined _MSC_VER && _MSC_VER >= 1200
......
...@@ -55,6 +55,8 @@ extern const char * kernel_sort_by_key; ...@@ -55,6 +55,8 @@ extern const char * kernel_sort_by_key;
extern const char * kernel_stablesort_by_key; extern const char * kernel_stablesort_by_key;
extern const char * kernel_radix_sort_by_key; extern const char * kernel_radix_sort_by_key;
void sortByKey(oclMat& keys, oclMat& vals, size_t vecSize, int method, bool isGreaterThan);
//TODO(pengx17): change this value depending on device other than a constant //TODO(pengx17): change this value depending on device other than a constant
const static unsigned int GROUP_SIZE = 256; const static unsigned int GROUP_SIZE = 256;
...@@ -85,7 +87,7 @@ inline bool isSizePowerOf2(size_t size) ...@@ -85,7 +87,7 @@ inline bool isSizePowerOf2(size_t size)
namespace bitonic_sort namespace bitonic_sort
{ {
static void sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan) static void sortByKey(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan)
{ {
CV_Assert(isSizePowerOf2(vecSize)); CV_Assert(isSizePowerOf2(vecSize));
...@@ -125,7 +127,7 @@ namespace selection_sort ...@@ -125,7 +127,7 @@ namespace selection_sort
{ {
// FIXME: // FIXME:
// This function cannot sort arrays with duplicated keys // This function cannot sort arrays with duplicated keys
static void sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan) static void sortByKey(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan)
{ {
CV_Error(-1, "This function is incorrect at the moment."); CV_Error(-1, "This function is incorrect at the moment.");
Context * cxt = Context::getContext(); Context * cxt = Context::getContext();
...@@ -193,7 +195,7 @@ void static naive_scan_addition_cpu(oclMat& input, oclMat& output) ...@@ -193,7 +195,7 @@ void static naive_scan_addition_cpu(oclMat& input, oclMat& output)
//radix sort ported from Bolt //radix sort ported from Bolt
static void sort_by_key(oclMat& keys, oclMat& vals, size_t origVecSize, bool isGreaterThan) static void sortByKey(oclMat& keys, oclMat& vals, size_t origVecSize, bool isGreaterThan)
{ {
CV_Assert(keys.depth() == CV_32S || keys.depth() == CV_32F); // we assume keys are 4 bytes CV_Assert(keys.depth() == CV_32S || keys.depth() == CV_32F); // we assume keys are 4 bytes
...@@ -336,7 +338,7 @@ static void sort_by_key(oclMat& keys, oclMat& vals, size_t origVecSize, bool isG ...@@ -336,7 +338,7 @@ static void sort_by_key(oclMat& keys, oclMat& vals, size_t origVecSize, bool isG
namespace merge_sort namespace merge_sort
{ {
static void sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan) static void sortByKey(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreaterThan)
{ {
Context * cxt = Context::getContext(); Context * cxt = Context::getContext();
...@@ -421,7 +423,7 @@ static void sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreat ...@@ -421,7 +423,7 @@ static void sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, bool isGreat
} /* namespace cv { namespace ocl */ } /* namespace cv { namespace ocl */
void cv::ocl::sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, int method, bool isGreaterThan) void cv::ocl::sortByKey(oclMat& keys, oclMat& vals, size_t vecSize, int method, bool isGreaterThan)
{ {
CV_Assert( keys.rows == 1 ); // we only allow one dimensional input CV_Assert( keys.rows == 1 ); // we only allow one dimensional input
CV_Assert( keys.channels() == 1 ); // we only allow one channel keys CV_Assert( keys.channels() == 1 ); // we only allow one channel keys
...@@ -429,25 +431,24 @@ void cv::ocl::sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, int method ...@@ -429,25 +431,24 @@ void cv::ocl::sort_by_key(oclMat& keys, oclMat& vals, size_t vecSize, int method
switch(method) switch(method)
{ {
case SORT_BITONIC: case SORT_BITONIC:
bitonic_sort::sort_by_key(keys, vals, vecSize, isGreaterThan); bitonic_sort::sortByKey(keys, vals, vecSize, isGreaterThan);
break; break;
case SORT_SELECTION: case SORT_SELECTION:
selection_sort::sort_by_key(keys, vals, vecSize, isGreaterThan); selection_sort::sortByKey(keys, vals, vecSize, isGreaterThan);
break; break;
case SORT_RADIX: case SORT_RADIX:
radix_sort::sort_by_key(keys, vals, vecSize, isGreaterThan); radix_sort::sortByKey(keys, vals, vecSize, isGreaterThan);
break; break;
case SORT_MERGE: case SORT_MERGE:
merge_sort::sort_by_key(keys, vals, vecSize, isGreaterThan); merge_sort::sortByKey(keys, vals, vecSize, isGreaterThan);
break; break;
} }
} }
void cv::ocl::sortByKey(oclMat& keys, oclMat& vals, int method, bool isGreaterThan)
void cv::ocl::sort_by_key(oclMat& keys, oclMat& vals, int method, bool isGreaterThan)
{ {
CV_Assert( keys.size() == vals.size() ); CV_Assert( keys.size() == vals.size() );
CV_Assert( keys.rows == 1 ); // we only allow one dimensional input CV_Assert( keys.rows == 1 ); // we only allow one dimensional input
size_t vecSize = static_cast<size_t>(keys.cols); size_t vecSize = static_cast<size_t>(keys.cols);
sort_by_key(keys, vals, vecSize, method, isGreaterThan); sortByKey(keys, vals, vecSize, method, isGreaterThan);
} }
...@@ -235,7 +235,7 @@ TEST_P(SortByKey, Accuracy) ...@@ -235,7 +235,7 @@ TEST_P(SortByKey, Accuracy)
ocl::oclMat oclmat_key(mat_key); ocl::oclMat oclmat_key(mat_key);
ocl::oclMat oclmat_val(mat_val); ocl::oclMat oclmat_val(mat_val);
ocl::sort_by_key(oclmat_key, oclmat_val, method, is_gt); ocl::sortByKey(oclmat_key, oclmat_val, method, is_gt);
SortByKey_STL::sort(mat_key, mat_val, is_gt); SortByKey_STL::sort(mat_key, mat_val, is_gt);
EXPECT_MAT_NEAR(mat_key, oclmat_key, 0.0); EXPECT_MAT_NEAR(mat_key, oclmat_key, 0.0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment