// _kdtree.hpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2008, Xavier Delacour, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

// 2008-05-13, Xavier Delacour <xavier.delacour@gmail.com>

#ifndef __cv_kdtree_h__
#define __cv_kdtree_h__

#include "precomp.hpp"

#include <vector>
#include <algorithm>
#include <limits>
#include <iostream>
#include "assert.h"
#include "math.h"

// J.S. Beis and D.G. Lowe. Shape indexing using approximate nearest-neighbor search
// in high-dimensional spaces. In Proc. IEEE Conf. Comp. Vision Patt. Recog.,
// pages 1000--1006, 1997. http://citeseer.ist.psu.edu/beis97shape.html
#undef __deref
#undef __valuetype

template < class __valuetype, class __deref >
class CvKDTree {
public:
  typedef __deref deref_type;
  typedef typename __deref::scalar_type scalar_type;
  typedef typename __deref::accum_type accum_type;

private:
  // One tree record. Internal nodes (dim >= 0) split on `dim` at `boundary`.
  // Leaves (dim == -1) hold one point in `value` and link to the next point
  // of the same bin through `right`.
  struct node {
    int dim;              // split dimension; >=0 for nodes, -1 for leaves
    __valuetype value;    // if leaf, value of leaf
    int left, right;      // node: indices of left/right branches; leaf: right = next leaf in bin
    scalar_type boundary; // left if deref(value,dim)<=boundary, otherwise right
  };
  typedef std::vector < node > node_array;

  __deref deref;      // requires operator() (__valuetype lhs,int dim)

  node_array nodes;   // node storage
  int point_dim;      // dimension of points (the k in kd-tree)
  int root_node;      // index of root node, -1 if empty tree

  // For the points in [first,last), return the coordinate index with the
  // highest variance. On ties the later index wins (comparison is >=).
  // Precondition: the range is non-empty.
  template < class __instype, class __valuector >
  int dimension_of_highest_variance(__instype * first, __instype * last,
                                    __valuector ctor) {
    assert(last - first > 0);

    accum_type maxvar = -std::numeric_limits < accum_type >::max();
    int maxj = -1;
    for (int j = 0; j < point_dim; ++j) {
      accum_type mean = 0;
      for (__instype * k = first; k < last; ++k)
        mean += deref(ctor(*k), j);
      mean /= last - first;
      accum_type var = 0;
      for (__instype * k = first; k < last; ++k) {
        accum_type diff = accum_type(deref(ctor(*k), j)) - mean;
        var += diff * diff;
      }
      var /= last - first;

      assert(maxj != -1 || var >= maxvar);

      if (var >= maxvar) {
        maxvar = var;
        maxj = j;
      }
    }

    return maxj;
  }

  // Reorder [first,last) in place around the middle element along `dim` and
  // return a pointer to it (the median). Afterwards, points in [first,median)
  // are <= the median value and points in (median,last) are >= it. Equal
  // values may land on either side, which insert() handles explicitly.
  // Implemented as quickselect; expected linear time.
  template < class __instype, class __valuector >
  __instype * median_partition(__instype * first, __instype * last,
                               int dim, __valuector ctor) {
    assert(last - first > 0);
    __instype *k = first + (last - first) / 2;
    median_partition(first, last, k, dim, ctor);
    return k;
  }

  // Partition predicate: true when lhs is <= the pivot along `dim`.
  template < class __instype, class __valuector >
  struct median_pr {
    const __instype & pivot;
    int dim;
    __deref deref;
    __valuector ctor;
    median_pr(const __instype & _pivot, int _dim, __deref _deref, __valuector _ctor)
      : pivot(_pivot), dim(_dim), deref(_deref), ctor(_ctor) {
    }
    bool operator() (const __instype & lhs) const {
      return deref(ctor(lhs), dim) <= deref(ctor(pivot), dim);
    }
  private:
    median_pr& operator=(const median_pr&);
  };

  // Quickselect step: narrow [first,last) until the element at k is in its
  // sorted position along `dim`. This is written as a loop rather than tail
  // recursion because the recursive form went O(n) deep when many values
  // along `dim` were equal (each step then shrinks the range by one).
  // Precondition: first <= k < last.
  template < class __instype, class __valuector >
  void median_partition(__instype * first, __instype * last,
                        __instype * k, int dim, __valuector ctor) {
    for (;;) {
      int pivot = (int)((last - first) / 2);

      // Park the pivot at the end so std::partition never moves it.
      std::swap(first[pivot], last[-1]);
      __instype *middle = std::partition(first, last - 1,
                                         median_pr < __instype, __valuector >
                                         (last[-1], dim, deref, ctor));
      std::swap(*middle, last[-1]);

      if (middle < k)
        first = middle + 1;
      else if (middle > k)
        last = middle;
      else
        return;
    }
  }

  // Build a subtree from the points in [first,last), which are reordered in
  // place. Returns the index of the created node, or -1 for an empty range.
  template < class __instype, class __valuector >
  int insert(__instype * first, __instype * last, __valuector ctor) {
    if (first == last)
      return -1;

    int dim = dimension_of_highest_variance(first, last, ctor);
    __instype *median = median_partition(first, last, dim, ctor);

    // Points equal to the median along `dim` must go left (deref <= boundary),
    // so advance `split` past every value equal to the median's.
    __instype *split = median;
    for (; split != last && deref(ctor(*split), dim) ==
           deref(ctor(*median), dim); ++split) {}

    if (split == last) { // leaf: all remaining points form one chained bin
      int nexti = -1;
      // Walk backwards without ever forming a pointer before `first`.
      while (split != first) {
        --split;
        int i = (int)nodes.size();
        nodes.push_back(node());
        node & n = nodes.back();
        n.dim = -1;
        n.value = ctor(*split);
        n.left = -1;
        n.right = nexti;
        nexti = i;
      }
      return nexti;
    }

    // internal node
    int i = (int)nodes.size();
    nodes.push_back(node());
    // Recursive insert may reallocate `nodes`, so the node is addressed by
    // index from here on rather than through a reference.
    nodes[i].dim = dim;
    nodes[i].boundary = deref(ctor(*median), dim);

    int left = insert(first, split, ctor);
    nodes[i].left = left;
    int right = insert(split, last, ctor);
    nodes[i].right = right;

    return i;
  }

  // Descend to the leaf bin that p belongs to and search the bin linearly.
  // On success, unlink the matching leaf. While unwinding, replace any
  // internal node whose two children are both empty with -1.
  // Returns true if p was found (compared with operator==).
  bool remove(int *i, const __valuetype & p) {
    if (*i == -1)
      return false;
    node & n = nodes[*i];   // stable: remove() never resizes `nodes`
    bool r;

    if (n.dim >= 0) { // node
      if (deref(p, n.dim) <= n.boundary) // left
        r = remove(&n.left, p);
      else // right
        r = remove(&n.right, p);

      // if terminal, remove this node
      if (n.left == -1 && n.right == -1)
        *i = -1;

      return r;
    } else { // leaf
      if (n.value == p) {
        *i = n.right;
        return true;
      } else
        return remove(&n.right, p);
    }
  }

public:
  // Value constructor that passes points through unchanged.
  struct identity_ctor {
    const __valuetype & operator() (const __valuetype & rhs) const {
      return rhs;
    }
  };

  // initialize an empty tree
  CvKDTree(__deref _deref = __deref())
    : deref(_deref), point_dim(0), root_node(-1) {
  }
  // given points, initialize a balanced tree ([first,last) is reordered in place)
  CvKDTree(__valuetype * first, __valuetype * last, int _point_dim,
           __deref _deref = __deref())
    : deref(_deref) {
    set_data(first, last, _point_dim, identity_ctor());
  }
  // given points, initialize a balanced tree; ctor maps each input to a __valuetype
  template < class __instype, class __valuector >
  CvKDTree(__instype * first, __instype * last, int _point_dim,
           __valuector ctor, __deref _deref = __deref())
    : deref(_deref) {
    set_data(first, last, _point_dim, ctor);
  }

  void set_deref(__deref _deref) {
    deref = _deref;
  }

  void set_data(__valuetype * first, __valuetype * last, int _point_dim) {
    set_data(first, last, _point_dim, identity_ctor());
  }
  // Discard the current tree and build a new one from [first,last), which is
  // reordered in place. Pointers previously returned through bbf_nn::p
  // become invalid.
  template < class __instype, class __valuector >
  void set_data(__instype * first, __instype * last, int _point_dim,
                __valuector ctor) {
    point_dim = _point_dim;
    nodes.clear();
    // One leaf per point plus fewer than one internal node per point.
    nodes.reserve(2 * (last - first));
    root_node = insert(first, last, ctor);
  }

  int dims() const {
    return point_dim;
  }

  // remove the given point; returns true if it was found
  bool remove(const __valuetype & p) {
    return remove(&root_node, p);
  }

  // Debug dump of the whole tree to std::cout.
  void print() const {
    print(root_node);
  }
  void print(int i, int indent = 0) const {
    if (i == -1)
      return;
    for (int j = 0; j < indent; ++j)
      std::cout << " ";
    const node & n = nodes[i];
    if (n.dim >= 0) {
      std::cout << "node " << i << ", left " << nodes[i].left << ", right " <<
        nodes[i].right << ", dim " << nodes[i].dim << ", boundary " <<
        nodes[i].boundary << std::endl;
      print(n.left, indent + 3);
      print(n.right, indent + 3);
    } else
      std::cout << "leaf " << i << ", value = " << nodes[i].value << std::endl;
  }

  ////////////////////////////////////////////////////////////////////////////////////////
  // bbf search
public:
  struct bbf_nn {         // info on found neighbors (approx k nearest)
    const __valuetype *p; // nearest neighbor (points into tree storage)
    accum_type dist;      // distance from p to the query point
    bbf_nn(const __valuetype & _p, accum_type _dist)
      : p(&_p), dist(_dist) {
    }
    bool operator<(const bbf_nn & rhs) const {
      return dist < rhs.dist;
    }
  };
  typedef std::vector < bbf_nn > bbf_nn_pqueue;
private:
  struct bbf_node {       // info on branches not taken
    int node;             // corresponding node
    accum_type dist;      // distance from the splitting plane to the query point
    bbf_node(int _node, accum_type _dist)
      : node(_node), dist(_dist) {
    }
    // Inverted so std heap functions keep the smallest dist at the front.
    bool operator<(const bbf_node & rhs) const {
      return dist > rhs.dist;
    }
  };
  typedef std::vector < bbf_node > bbf_pqueue;
  // Scratch queue reused across searches. find_nn_bbf() is const but mutates
  // this member, so concurrent searches on the same tree object are not safe.
  mutable bbf_pqueue tmp_pq;

  // Called for branches not taken as bbf walks to a leaf: queue the
  // alternate branch keyed by its distance from the query point. The
  // distance is taken as accum_type so it is not narrowed to scalar_type.
  void pq_alternate(int alt_n, bbf_pqueue & pq, accum_type dist) const {
    if (alt_n == -1)
      return;

    // add bbf_node for alternate branch in priority queue
    pq.push_back(bbf_node(alt_n, dist));
    std::push_heap(pq.begin(), pq.end());
  }

  // called by bbf to walk to leaf;
  // takes one step down the tree towards query point d
  template < class __desctype >
  int bbf_branch(int i, const __desctype * d, bbf_pqueue & pq) const {
    const node & n = nodes[i];
    // push bbf_node with bounds of alternate branch, then branch
    if (d[n.dim] <= n.boundary) { // left
      pq_alternate(n.right, pq, n.boundary - d[n.dim]);
      return n.left;
    } else {                      // right
      pq_alternate(n.left, pq, d[n.dim] - n.boundary);
      return n.right;
    }
  }

  // compute euclidean distance between two points
  template < class __desctype >
  accum_type distance(const __desctype * d, const __valuetype & p) const {
    accum_type dist = 0;
    for (int j = 0; j < point_dim; ++j) {
      accum_type diff = accum_type(d[j]) - accum_type(deref(p, j));
      dist += diff * diff;
    }
    return (accum_type) sqrt(dist);
  }

  // Called per candidate nearest neighbor. Adds a bbf_nn for p to the
  // max-heap nn_pq (largest dist at [0]). Once the heap holds k entries, a
  // new candidate replaces the farthest one only if it is strictly closer.
  template < class __desctype >
  void bbf_new_nn(bbf_nn_pqueue & nn_pq, int k,
                  const __desctype * d, const __valuetype & p) const {
    bbf_nn nn(p, distance(d, p));
    if ((int) nn_pq.size() < k) {
      nn_pq.push_back(nn);
      std::push_heap(nn_pq.begin(), nn_pq.end());
    } else if (nn_pq[0].dist > nn.dist) {
      std::pop_heap(nn_pq.begin(), nn_pq.end());
      nn_pq.end()[-1] = nn;
      std::push_heap(nn_pq.begin(), nn_pq.end());
    }
    assert(nn_pq.size() < 2 || nn_pq[0].dist >= nn_pq[1].dist);
  }

public:
  // Finds (with high probability) the k nearest neighbors of d, visiting at
  // most emax leaves/bins. On return ret_nn_pq holds at most k neighbors in
  // max-heap order (farthest at [0]), not sorted order. Its p pointers refer
  // into this tree's storage. Returns the number of neighbors found.
  template < class __desctype >
  int find_nn_bbf(const __desctype * d, int k, int emax, bbf_nn_pqueue & ret_nn_pq) const
  {
    assert(k > 0);
    ret_nn_pq.clear();

    if (root_node == -1)
      return 0;

    // add root_node to bbf_node priority queue;
    // iterate while queue non-empty and emax>0
    tmp_pq.clear();
    tmp_pq.push_back(bbf_node(root_node, 0));
    while (tmp_pq.size() && emax > 0)
    {
      // from node nearest query point d, run to leaf
      std::pop_heap(tmp_pq.begin(), tmp_pq.end());
      bbf_node bbf(tmp_pq.end()[-1]);
      tmp_pq.erase(tmp_pq.end() - 1);

      int i;
      for (i = bbf.node; i != -1 && nodes[i].dim >= 0; i = bbf_branch(i, d, tmp_pq)) {}

      if (i != -1)
      {
        // add points in leaf/bin to ret_nn_pq
        do {
          bbf_new_nn(ret_nn_pq, k, d, nodes[i].value);
        } while (-1 != (i = nodes[i].right));

        --emax;
      }
    }

    tmp_pq.clear();
    return (int)ret_nn_pq.size();
  }

  ////////////////////////////////////////////////////////////////////////////////////////
  // orthogonal range search
private:
  // True when every coordinate of v lies in [bounds_min[j], bounds_max[j]].
  bool in_ortho_range(const __valuetype & v, const scalar_type * bounds_min,
                      const scalar_type * bounds_max) const {
    for (int j = 0; j < point_dim; ++j) {
      if (deref(v, j) < bounds_min[j] || bounds_max[j] < deref(v, j))
        return false;
    }
    return true;
  }

  // Append the points of subtree i that lie inside the bounds to inbounds.
  void find_ortho_range(int i, scalar_type * bounds_min,
                        scalar_type * bounds_max,
                        std::vector < __valuetype > &inbounds) const {
    if (i == -1)
      return;
    const node & n = nodes[i];
    if (n.dim >= 0) { // node
      if (bounds_min[n.dim] <= n.boundary)
        find_ortho_range(n.left, bounds_min, bounds_max, inbounds);
      if (bounds_max[n.dim] > n.boundary)
        find_ortho_range(n.right, bounds_min, bounds_max, inbounds);
    } else { // leaf
      // Reaching a bin only means it overlaps the query along the split
      // dimensions, so each point must still be checked individually.
      do {
        const __valuetype & v = nodes[i].value;
        if (in_ortho_range(v, bounds_min, bounds_max))
          inbounds.push_back(v);
      } while (-1 != (i = nodes[i].right));
    }
  }
public:
  // Return all points that lie within the given (inclusive) bounds;
  // inbounds is cleared first. Returns the number of points found.
  int find_ortho_range(scalar_type * bounds_min,
                       scalar_type * bounds_max,
                       std::vector < __valuetype > &inbounds) const {
    inbounds.clear();
    find_ortho_range(root_node, bounds_min, bounds_max, inbounds);
    return (int)inbounds.size();
  }
};

#endif // __cv_kdtree_h__

// Local Variables:
// mode:C++
// End: