dtree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DET_DTREE_HPP
14 #define MLPACK_METHODS_DET_DTREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace det {
20 
44 template<typename MatType = arma::mat,
45  typename TagType = int>
46 class DTree
47 {
48  public:
50  typedef typename MatType::elem_type ElemType;
52  typedef typename MatType::vec_type VecType;
54  typedef typename arma::Col<ElemType> StatType;
55 
59  DTree();
60 
66  DTree(const DTree& obj);
67 
73  DTree& operator=(const DTree& obj);
74 
80  DTree(DTree&& obj);
81 
87  DTree& operator=(DTree&& obj);
88 
97  DTree(const StatType& maxVals,
98  const StatType& minVals,
99  const size_t totalPoints);
100 
109  DTree(MatType& data);
110 
123  DTree(const StatType& maxVals,
124  const StatType& minVals,
125  const size_t start,
126  const size_t end,
127  const double logNegError);
128 
140  DTree(const StatType& maxVals,
141  const StatType& minVals,
142  const size_t totalPoints,
143  const size_t start,
144  const size_t end);
145 
147  ~DTree();
148 
159  double Grow(MatType& data,
160  arma::Col<size_t>& oldFromNew,
161  const bool useVolReg = false,
162  const size_t maxLeafSize = 10,
163  const size_t minLeafSize = 5);
164 
173  double PruneAndUpdate(const double oldAlpha,
174  const size_t points,
175  const bool useVolReg = false);
176 
182  double ComputeValue(const VecType& query) const;
183 
193  TagType TagTree(const TagType& tag = 0, bool everyNode = false);
194 
195 
202  TagType FindBucket(const VecType& query) const;
203 
204 
210  void ComputeVariableImportance(arma::vec& importances) const;
211 
218  double LogNegativeError(const size_t totalPoints) const;
219 
223  bool WithinRange(const VecType& query) const;
224 
225  private:
226  // The indices in the complete set of points
227  // (after all forms of swapping in the original data
228  // matrix to align all the points in a node
229  // consecutively in the matrix. The 'old_from_new' array
230  // maps the points back to their original indices.
231 
234  size_t start;
237  size_t end;
238 
240  StatType maxVals;
242  StatType minVals;
243 
245  size_t splitDim;
246 
248  ElemType splitValue;
249 
251  double logNegError;
252 
254  double subtreeLeavesLogNegError;
255 
257  size_t subtreeLeaves;
258 
260  bool root;
261 
263  double ratio;
264 
266  double logVolume;
267 
269  TagType bucketTag;
270 
272  double alphaUpper;
273 
275  DTree* left;
277  DTree* right;
278 
279  public:
281  size_t Start() const { return start; }
283  size_t End() const { return end; }
285  size_t SplitDim() const { return splitDim; }
287  ElemType SplitValue() const { return splitValue; }
289  double LogNegError() const { return logNegError; }
291  double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
293  size_t SubtreeLeaves() const { return subtreeLeaves; }
296  double Ratio() const { return ratio; }
298  double LogVolume() const { return logVolume; }
300  DTree* Left() const { return left; }
302  DTree* Right() const { return right; }
304  bool Root() const { return root; }
306  double AlphaUpper() const { return alphaUpper; }
308  TagType BucketTag() const { return bucketTag; }
310  size_t NumChildren() const { return !left ? 0 : 2; }
311 
318  DTree& Child(const size_t child) const { return !child ? *left : *right; }
319 
320  DTree*& ChildPtr(const size_t child) { return (!child) ? left : right; }
321 
323  const StatType& MaxVals() const { return maxVals; }
324 
326  const StatType& MinVals() const { return minVals; }
327 
331  template<typename Archive>
332  void serialize(Archive& ar, const unsigned int /* version */);
333 
334  private:
335  // Utility methods.
336 
340  bool FindSplit(const MatType& data,
341  size_t& splitDim,
342  ElemType& splitValue,
343  double& leftError,
344  double& rightError,
345  const size_t minLeafSize = 5) const;
346 
350  size_t SplitData(MatType& data,
351  const size_t splitDim,
352  const ElemType splitValue,
353  arma::Col<size_t>& oldFromNew) const;
354 
355  void FillMinMax(const StatType& mins,
356  const StatType& maxs);
357 };
358 
359 } // namespace det
360 } // namespace mlpack
361 
362 #include "dtree_impl.hpp"
363 
364 #endif // MLPACK_METHODS_DET_DTREE_HPP
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:293
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
DTree & operator=(const DTree &obj)
Copy the given tree.
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:285
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
~DTree()
Clean up memory allocated by the tree.
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:281
.hpp
Definition: add_to_po.hpp:21
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:50
void serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
arma::Col< ElemType > StatType
The statistic type we are holding.
Definition: dtree.hpp:54
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:289
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:296
DTree * Left() const
Return the left child.
Definition: dtree.hpp:300
DTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: dtree.hpp:318
size_t NumChildren() const
Return the number of children in this node.
Definition: dtree.hpp:310
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:283
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
ElemType SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:287
const StatType & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:323
DTree * Right() const
Return the right child.
Definition: dtree.hpp:302
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:298
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:291
MatType::vec_type VecType
The type of vector we are using.
Definition: dtree.hpp:52
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset...
const StatType & MinVals() const
Return the minimum values.
Definition: dtree.hpp:326
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:306
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:308
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:304
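DTree *& ChildPtr(const size_t child)
Return a modifiable reference to the pointer of the specified child (0 will be left, 1 will be right).
Definition: dtree.hpp:320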
DTree()
Create an empty density estimation tree.