neighbor_search.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
14 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <vector>
18 #include <string>
19 
23 
24 #include "neighbor_search_stat.hpp"
27 
28 namespace mlpack {
29 // Neighbor-search routines. These include all-nearest-neighbors and
30 // all-furthest-neighbors searches.
31 namespace neighbor {
32 
33 // Forward declaration.
34 template<typename SortPolicy>
36 
39 {
44 };
45 
69 template<typename SortPolicy = NearestNeighborSort,
70  typename MetricType = mlpack::metric::EuclideanDistance,
71  typename MatType = arma::mat,
72  template<typename TreeMetricType,
73  typename TreeStatType,
74  typename TreeMatType> class TreeType = tree::KDTree,
75  template<typename RuleType> class DualTreeTraversalType =
76  TreeType<MetricType,
78  MatType>::template DualTreeTraverser,
79  template<typename RuleType> class SingleTreeTraversalType =
80  TreeType<MetricType,
81  NeighborSearchStat<SortPolicy>,
82  MatType>::template SingleTreeTraverser>
84 {
85  public:
//! Convenience typedef: the tree type used for both reference and query
//! trees, instantiated with NeighborSearchStat<SortPolicy> as the per-node
//! statistic and MatType as the dataset type.
87  typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
88 
/**
 * Initialize the NeighborSearch object, passing a reference dataset (this is
 * the dataset which is searched).
 *
 * @param referenceSet Reference dataset, taken by const reference.
 * @param mode Search mode to use (defaults to DUAL_TREE_MODE).
 * @param epsilon Relative error to be considered in approximate search
 *     (defaults to 0; presumably 0 means exact search -- TODO confirm).
 * @param metric Metric object to use (defaults to a default-constructed
 *     MetricType).
 *
 * NOTE(review): whether the data is copied or only referenced is decided in
 * neighbor_search_impl.hpp and is not visible here -- confirm before relying
 * on the caller's referenceSet outliving this object.
 */
106  NeighborSearch(const MatType& referenceSet,
107  const NeighborSearchMode mode = DUAL_TREE_MODE,
108  const double epsilon = 0,
109  const MetricType metric = MetricType());
110 
/**
 * Initialize the NeighborSearch object with a reference dataset (the dataset
 * which is searched) passed as an rvalue, so the data can be moved in rather
 * than copied. Presumably the object then owns the data (cf. the setOwner
 * member) -- TODO confirm in neighbor_search_impl.hpp.
 *
 * @param referenceSet Reference dataset to move from.
 * @param mode Search mode to use (defaults to DUAL_TREE_MODE).
 * @param epsilon Relative error to be considered in approximate search
 *     (defaults to 0).
 * @param metric Metric object to use (defaults to a default-constructed
 *     MetricType).
 */
128  NeighborSearch(MatType&& referenceSet,
129  const NeighborSearchMode mode = DUAL_TREE_MODE,
130  const double epsilon = 0,
131  const MetricType metric = MetricType());
132 
157  const Tree& referenceTree,
158  const NeighborSearchMode mode = DUAL_TREE_MODE,
159  const double epsilon = 0,
160  const MetricType metric = MetricType());
161 
187  Tree&& referenceTree,
188  const NeighborSearchMode mode = DUAL_TREE_MODE,
189  const double epsilon = 0,
190  const MetricType metric = MetricType());
191 
202  const double epsilon = 0,
203  const MetricType metric = MetricType());
204 
/**
 * Construct a copy of the given NeighborSearch object.
 *
 * @param other Object to copy.
 *
 * NOTE(review): whether the reference tree / reference set are deep-copied or
 * shared is decided in the implementation (not visible here) -- confirm.
 */
211  NeighborSearch(const NeighborSearch& other);
212 
220 
/**
 * Copy the given NeighborSearch object.
 *
 * @param other Object to copy from.
 * @return Reference to this object.
 */
226  NeighborSearch& operator=(const NeighborSearch& other);
227 
234 
/**
 * Delete the NeighborSearch object.
 *
 * NOTE(review): presumably frees referenceTree / referenceSet only when
 * treeOwner / setOwner are set -- confirm in neighbor_search_impl.hpp.
 */
239  ~NeighborSearch();
240 
/**
 * Set the reference set to a new reference set, and build a tree if
 * necessary.
 *
 * @param referenceSet New reference dataset, taken by const reference.
 */
249  void Train(const MatType& referenceSet);
250 
/**
 * Set the reference set to a new reference set passed as an rvalue, so the
 * data can be moved in rather than copied. Presumably a tree is built if the
 * search mode requires one -- TODO confirm in neighbor_search_impl.hpp.
 *
 * @param referenceSet New reference dataset to move from.
 */
259  void Train(MatType&& referenceSet);
260 
/**
 * Set the reference tree to the given pre-built tree.
 *
 * @param referenceTree Tree built on the new reference data.
 *
 * NOTE(review): whether the tree is copied or only referenced (and who owns
 * it) is not visible here -- confirm in neighbor_search_impl.hpp.
 */
269  void Train(const Tree& referenceTree);
270 
/**
 * Set the reference tree to the given tree, passed as an rvalue so it can be
 * moved in. Presumably the object then owns the tree (cf. the treeOwner
 * member) -- TODO confirm.
 *
 * @param referenceTree Tree to move from.
 */
278  void Train(Tree&& referenceTree);
279 
/**
 * For each point in the query set, compute the k best neighbors in the
 * reference set according to SortPolicy (nearest neighbors with the default
 * NearestNeighborSort) and store the output in the given matrices.
 *
 * @param querySet Set of query points.
 * @param k Number of neighbors to find.
 * @param neighbors Output matrix of neighbor indices.
 * @param distances Output matrix of the corresponding distances.
 *
 * NOTE(review): the output layout (e.g. k x number of query points) is not
 * visible here -- confirm in neighbor_search_impl.hpp.
 */
297  void Search(const MatType& querySet,
298  const size_t k,
299  arma::Mat<size_t>& neighbors,
300  arma::mat& distances);
301 
/**
 * Compute the k best neighbors (according to SortPolicy) for every point in
 * a pre-built query tree, and store the output in the given matrices.
 *
 * @param queryTree Tree built on the query points. It is taken by non-const
 *     reference, presumably because the search may modify it (e.g. node
 *     statistics) -- TODO confirm.
 * @param k Number of neighbors to find.
 * @param neighbors Output matrix of neighbor indices.
 * @param distances Output matrix of the corresponding distances.
 * @param sameSet Defaults to false; presumably true means the query tree was
 *     built on the reference set itself -- TODO confirm.
 */
322  void Search(Tree& queryTree,
323  const size_t k,
324  arma::Mat<size_t>& neighbors,
325  arma::mat& distances,
326  bool sameSet = false);
327 
/**
 * Compute the k best neighbors (according to SortPolicy) without a separate
 * query set. Presumably the reference set is used as the query set, i.e. a
 * monochromatic search -- TODO confirm in neighbor_search_impl.hpp.
 *
 * @param k Number of neighbors to find.
 * @param neighbors Output matrix of neighbor indices.
 * @param distances Output matrix of the corresponding distances.
 */
342  void Search(const size_t k,
343  arma::Mat<size_t>& neighbors,
344  arma::mat& distances);
345 
/**
 * Calculate the average relative error (effective error) between the
 * distances calculated and the true distances.
 *
 * @param foundDistances Distances returned by a (possibly approximate) search.
 * @param realDistances True distances.
 * @return The effective error.
 *
 * NOTE(review): both arguments appear to be read-only inputs but are taken
 * by non-const reference; switching to const& must be done together with the
 * matching definition in neighbor_search_impl.hpp.
 */
361  static double EffectiveError(arma::mat& foundDistances,
362  arma::mat& realDistances);
363 
/**
 * Calculate the recall (% of neighbors found) given the list of found
 * neighbors and the true set of neighbors.
 *
 * @param foundNeighbors Neighbors returned by a (possibly approximate) search.
 * @param realNeighbors True neighbors.
 * @return The recall.
 *
 * NOTE(review): both arguments appear to be read-only inputs but are taken
 * by non-const reference; switching to const& must be done together with the
 * matching definition in neighbor_search_impl.hpp.
 */
375  static double Recall(arma::Mat<size_t>& foundNeighbors,
376  arma::Mat<size_t>& realNeighbors);
377 
//! Return the total number of base case evaluations performed during the
//! last search.
380  size_t BaseCases() const { return baseCases; }
381 
//! Return the number of node combination scores during the last search.
383  size_t Scores() const { return scores; }
384 
//! Access the search mode.
386  NeighborSearchMode SearchMode() const { return searchMode; }
//! Modify the search mode (returns a mutable reference to the stored mode).
388  NeighborSearchMode& SearchMode() { return searchMode; }
389 
//! Access the relative error to be considered in approximate search.
391  double Epsilon() const { return epsilon; }
//! Modify the relative error to be considered in approximate search.
393  double& Epsilon() { return epsilon; }
394 
//! Access the reference dataset.
//! NOTE(review): dereferences referenceSet without a null check; confirm it
//! is always set after construction / Train().
396  const MatType& ReferenceSet() const { return *referenceSet; }
397 
//! Access the reference tree.
//! NOTE(review): dereferences referenceTree unconditionally; confirm it is
//! non-null in every search mode before calling.
399  const Tree& ReferenceTree() const { return *referenceTree; }
//! Modify the reference tree.
401  Tree& ReferenceTree() { return *referenceTree; }
402 
//! Serialize the NeighborSearch model. The version parameter is unnamed in
//! this declaration.
404  template<typename Archive>
405  void serialize(Archive& ar, const unsigned int /* version */);
406 
407  private:
//! Index mapping between reordered and original reference points; presumably
//! filled when building a tree rearranges the dataset -- TODO confirm.
409  std::vector<size_t> oldFromNewReferences;
//! Pointer to the reference tree (returned by ReferenceTree()).
411  Tree* referenceTree;
//! Pointer to the reference dataset (returned by ReferenceSet()).
413  const MatType* referenceSet;
414 
//! Presumably true when this object must free referenceTree -- confirm.
416  bool treeOwner;
//! Presumably true when this object must free referenceSet -- confirm.
418  bool setOwner;
419 
//! Search mode (exposed via SearchMode()).
421  NeighborSearchMode searchMode;
//! Relative error to be considered in approximate search.
423  double epsilon;
424 
//! Metric object given to the constructors.
426  MetricType metric;
427 
//! Number of base case evaluations performed during the last search.
429  size_t baseCases;
//! Number of node combination scores during the last search.
431  size_t scores;
432 
//! Presumably set when tree statistics must be reset before the next
//! search -- TODO confirm in neighbor_search_impl.hpp.
435  bool treeNeedsReset;
436 
//! TrainVisitor sets the reference set to a new reference set on the given
//! NSType, so it is granted access to the private members above.
438  template<typename SortPol>
439  friend class TrainVisitor;
440 }; // class NeighborSearch
441 
442 } // namespace neighbor
443 } // namespace mlpack
444 
445 // Include implementation.
446 #include "neighbor_search_impl.hpp"
447 
448 // Include convenience typedefs.
449 #include "typedef.hpp"
450 
451 #endif
const MatType & ReferenceSet() const
Access the reference dataset.
double Epsilon() const
Access the relative error to be considered in approximate search.
size_t Scores() const
Return the number of node combination scores during the last search.
const Tree & ReferenceTree() const
Access the reference tree.
.hpp
Definition: add_to_po.hpp:21
Extra data for each node in the tree.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The NeighborSearch class is a template class for performing distance-based neighbor searches.
void Train(const MatType &referenceSet)
Set the reference set to a new reference set, and build a tree if necessary.
static double EffectiveError(arma::mat &foundDistances, arma::mat &realDistances)
Calculate the average relative error (effective error) between the distances calculated and the true distances.
double & Epsilon()
Modify the relative error to be considered in approximate search.
Tree & ReferenceTree()
Modify the reference tree.
static double Recall(arma::Mat< size_t > &foundNeighbors, arma::Mat< size_t > &realNeighbors)
Calculate the recall (% of neighbors found) given the list of found neighbors and the true set of neighbors.
NeighborSearchMode & SearchMode()
Modify the search mode.
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
NeighborSearchMode SearchMode() const
Access the search mode.
NeighborSearch & operator=(const NeighborSearch &other)
Copy the given NeighborSearch object.
TreeType< MetricType, NeighborSearchStat< SortPolicy >, MatType > Tree
Convenience typedef.
TrainVisitor sets the reference set to a new reference set on the given NSType.
NeighborSearch(const MatType &referenceSet, const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
Initialize the NeighborSearch object, passing a reference dataset (this is the dataset which is searched).
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
size_t BaseCases() const
Return the total number of base case evaluations performed during the last search.
void serialize(Archive &ar, const unsigned int)
Serialize the NeighborSearch model.
BinarySpaceTree< MetricType, StatisticType, MatType, bound::HRectBound, MidpointSplit > KDTree
The standard midpoint-split kd-tree.
Definition: typedef.hpp:63
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
~NeighborSearch()
Delete the NeighborSearch object.