cover_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
18 #include "../statistic.hpp"
19 #include "first_point_is_root.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
95 template<typename MetricType = metric::LMetric<2, true>,
96  typename StatisticType = EmptyStatistic,
97  typename MatType = arma::mat,
98  typename RootPointPolicy = FirstPointIsRoot>
99 class CoverTree
100 {
101  public:
103  typedef MatType Mat;
105  typedef typename MatType::elem_type ElemType;
106 
117  CoverTree(const MatType& dataset,
118  const ElemType base = 2.0,
119  MetricType* metric = NULL);
120 
130  CoverTree(const MatType& dataset,
131  MetricType& metric,
132  const ElemType base = 2.0);
133 
141  CoverTree(MatType&& dataset,
142  const ElemType base = 2.0);
143 
152  CoverTree(MatType&& dataset,
153  MetricType& metric,
154  const ElemType base = 2.0);
155 
187  CoverTree(const MatType& dataset,
188  const ElemType base,
189  const size_t pointIndex,
190  const int scale,
191  CoverTree* parent,
192  const ElemType parentDistance,
193  arma::Col<size_t>& indices,
194  arma::vec& distances,
195  size_t nearSetSize,
196  size_t& farSetSize,
197  size_t& usedSetSize,
198  MetricType& metric = NULL);
199 
216  CoverTree(const MatType& dataset,
217  const ElemType base,
218  const size_t pointIndex,
219  const int scale,
220  CoverTree* parent,
221  const ElemType parentDistance,
222  const ElemType furthestDescendantDistance,
223  MetricType* metric = NULL);
224 
231  CoverTree(const CoverTree& other);
232 
239  CoverTree(CoverTree&& other);
240 
246  CoverTree& operator=(const CoverTree& other);
247 
253  CoverTree& operator=(CoverTree&& other);
254 
258  template<typename Archive>
259  CoverTree(
260  Archive& ar,
262 
266  ~CoverTree();
267 
270  template<typename RuleType>
272 
274  template<typename RuleType>
276 
277  template<typename RuleType>
279 
281  const MatType& Dataset() const { return *dataset; }
282 
284  size_t Point() const { return point; }
286  size_t Point(const size_t) const { return point; }
287 
288  bool IsLeaf() const { return (children.size() == 0); }
289  size_t NumPoints() const { return 1; }
290 
292  const CoverTree& Child(const size_t index) const { return *children[index]; }
294  CoverTree& Child(const size_t index) { return *children[index]; }
295 
296  CoverTree*& ChildPtr(const size_t index) { return children[index]; }
297 
299  size_t NumChildren() const { return children.size(); }
300 
302  const std::vector<CoverTree*>& Children() const { return children; }
304  std::vector<CoverTree*>& Children() { return children; }
305 
307  size_t NumDescendants() const;
308 
310  size_t Descendant(const size_t index) const;
311 
313  int Scale() const { return scale; }
315  int& Scale() { return scale; }
316 
318  ElemType Base() const { return base; }
320  ElemType& Base() { return base; }
321 
323  const StatisticType& Stat() const { return stat; }
325  StatisticType& Stat() { return stat; }
326 
331  template<typename VecType>
332  size_t GetNearestChild(
333  const VecType& point,
335 
340  template<typename VecType>
341  size_t GetFurthestChild(
342  const VecType& point,
344 
349  size_t GetNearestChild(const CoverTree& queryNode);
350 
355  size_t GetFurthestChild(const CoverTree& queryNode);
356 
358  ElemType MinDistance(const CoverTree& other) const;
359 
362  ElemType MinDistance(const CoverTree& other, const ElemType distance) const;
363 
365  ElemType MinDistance(const arma::vec& other) const;
366 
369  ElemType MinDistance(const arma::vec& other, const ElemType distance) const;
370 
372  ElemType MaxDistance(const CoverTree& other) const;
373 
376  ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
377 
379  ElemType MaxDistance(const arma::vec& other) const;
380 
383  ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
384 
387 
391  const ElemType distance) const;
392 
394  math::RangeType<ElemType> RangeDistance(const arma::vec& other) const;
395 
398  math::RangeType<ElemType> RangeDistance(const arma::vec& other,
399  const ElemType distance) const;
400 
402  CoverTree* Parent() const { return parent; }
404  CoverTree*& Parent() { return parent; }
405 
407  ElemType ParentDistance() const { return parentDistance; }
409  ElemType& ParentDistance() { return parentDistance; }
410 
412  ElemType FurthestPointDistance() const { return 0.0; }
413 
415  ElemType FurthestDescendantDistance() const
416  { return furthestDescendantDistance; }
419  ElemType& FurthestDescendantDistance() { return furthestDescendantDistance; }
420 
423  ElemType MinimumBoundDistance() const { return furthestDescendantDistance; }
424 
426  void Center(arma::vec& center) const
427  {
428  center = arma::vec(dataset->col(point));
429  }
430 
432  MetricType& Metric() const { return *metric; }
433 
434  private:
436  const MatType* dataset;
438  size_t point;
440  std::vector<CoverTree*> children;
442  int scale;
444  ElemType base;
446  StatisticType stat;
448  size_t numDescendants;
450  CoverTree* parent;
452  ElemType parentDistance;
454  ElemType furthestDescendantDistance;
456  bool localMetric;
458  bool localDataset;
460  MetricType* metric;
461 
465  void CreateChildren(arma::Col<size_t>& indices,
466  arma::vec& distances,
467  size_t nearSetSize,
468  size_t& farSetSize,
469  size_t& usedSetSize);
470 
482  void ComputeDistances(const size_t pointIndex,
483  const arma::Col<size_t>& indices,
484  arma::vec& distances,
485  const size_t pointSetSize);
500  size_t SplitNearFar(arma::Col<size_t>& indices,
501  arma::vec& distances,
502  const ElemType bound,
503  const size_t pointSetSize);
504 
524  size_t SortPointSet(arma::Col<size_t>& indices,
525  arma::vec& distances,
526  const size_t childFarSetSize,
527  const size_t childUsedSetSize,
528  const size_t farSetSize);
529 
530  void MoveToUsedSet(arma::Col<size_t>& indices,
531  arma::vec& distances,
532  size_t& nearSetSize,
533  size_t& farSetSize,
534  size_t& usedSetSize,
535  arma::Col<size_t>& childIndices,
536  const size_t childFarSetSize,
537  const size_t childUsedSetSize);
538  size_t PruneFarSet(arma::Col<size_t>& indices,
539  arma::vec& distances,
540  const ElemType bound,
541  const size_t nearSetSize,
542  const size_t pointSetSize);
543 
548  void RemoveNewImplicitNodes();
549 
550  protected:
557  CoverTree();
558 
560  friend class boost::serialization::access;
561 
562  public:
566  template<typename Archive>
567  void serialize(Archive& ar, const unsigned int /* version */);
568 
569  size_t DistanceComps() const { return distanceComps; }
570  size_t& DistanceComps() { return distanceComps; }
571 
572  private:
573  size_t distanceComps;
574 };
575 
576 } // namespace tree
577 } // namespace mlpack
578 
579 // Include implementation.
580 #include "cover_tree_impl.hpp"
581 
582 // Include the rest of the pieces, if necessary.
583 #include "../cover_tree.hpp"
584 
585 #endif
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
size_t DistanceComps() const
Definition: cover_tree.hpp:569
CoverTree & operator=(const CoverTree &other)
Copy the given Cover Tree.
size_t NumPoints() const
Definition: cover_tree.hpp:289
MatType Mat
So that other classes can access the matrix type.
Definition: cover_tree.hpp:103
void Center(arma::vec &center) const
Get the center of the node and store it in the given vector.
Definition: cover_tree.hpp:426
A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
Definition: cover_tree.hpp:275
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
ElemType Base() const
Get the base.
Definition: cover_tree.hpp:318
.hpp
Definition: add_to_po.hpp:21
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:284
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
ElemType & FurthestDescendantDistance()
Modify the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:419
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:302
int & Scale()
Modify the scale of this node. Be careful...
Definition: cover_tree.hpp:315
StatisticType & Stat()
Modify the statistic for this node.
Definition: cover_tree.hpp:325
CoverTree()
A default constructor.
CoverTree *& Parent()
Modify the parent node.
Definition: cover_tree.hpp:404
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:402
std::vector< CoverTree * > & Children()
Modify the children manually (maybe not a great idea).
Definition: cover_tree.hpp:304
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:313
~CoverTree()
Delete this cover tree node and its children.
const StatisticType & Stat() const
Get the statistic for this node.
Definition: cover_tree.hpp:323
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
CoverTree *& ChildPtr(const size_t index)
Definition: cover_tree.hpp:296
A single-tree cover tree traverser; see single_tree_traverser.hpp for implementation.
Definition: cover_tree.hpp:271
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:407
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
size_t Point(const size_t) const
For compatibility with other trees; the argument is ignored.
Definition: cover_tree.hpp:286
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:281
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:299
ElemType FurthestPointDistance() const
Get the distance to the furthest point. This is always 0 for cover trees.
Definition: cover_tree.hpp:412
ElemType & Base()
Modify the base; don't do this, you'll break everything.
Definition: cover_tree.hpp:320
Definition of the Range class, which represents a simple range with a lower and upper bound...
CoverTree & Child(const size_t index)
Modify a particular child node.
Definition: cover_tree.hpp:294
ElemType MinimumBoundDistance() const
Get the minimum distance from the center to any bound edge (this is the same as furthestDescendantDistance).
Definition: cover_tree.hpp:423
MetricType & Metric() const
Get the instantiated metric.
Definition: cover_tree.hpp:432
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:415
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
ElemType & ParentDistance()
Modify the distance to the parent.
Definition: cover_tree.hpp:409
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:292
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t NumDescendants() const
Get the number of descendant points.