decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
20 #include "all_dimension_select.hpp"
21 #include <type_traits>
22 
23 namespace mlpack {
24 namespace tree {
25 
/**
 * This class implements a generic decision tree learner.
 *
 * The class publicly inherits from the AuxiliarySplitInfo<ElemType> member
 * templates of both NumericSplitType<FitnessFunction> and
 * CategoricalSplitType<FitnessFunction>.
 *
 * @tparam FitnessFunction Type handed to both split type templates
 *     (defaults to GiniGain).
 * @tparam NumericSplitType Template instantiated with FitnessFunction to form
 *     the numeric split type (defaults to BestBinaryNumericSplit).
 * @tparam CategoricalSplitType Template instantiated with FitnessFunction to
 *     form the categorical split type (defaults to AllCategoricalSplit).
 * @tparam DimensionSelectionType Dimension selection type (defaults to
 *     AllDimensionSelect).
 * @tparam ElemType Element type passed to AuxiliarySplitInfo (defaults to
 *     double).
 * @tparam NoRecursion Defaults to false.  Its effect is implemented outside
 *     this header; presumably it stops training from recursing into
 *     children -- TODO confirm in decision_tree_impl.hpp.
 */
33 template<typename FitnessFunction = GiniGain,
34  template<typename> class NumericSplitType = BestBinaryNumericSplit,
35  template<typename> class CategoricalSplitType = AllCategoricalSplit,
36  typename DimensionSelectionType = AllDimensionSelect,
37  typename ElemType = double,
38  bool NoRecursion = false>
39 class DecisionTree :
40  public NumericSplitType<FitnessFunction>::template
41  AuxiliarySplitInfo<ElemType>,
42  public CategoricalSplitType<FitnessFunction>::template
43  AuxiliarySplitInfo<ElemType>
44 {
45  public:
 //! Allow access to the numeric split type.
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
 //! Allow access to the categorical split type.
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
 //! Allow access to the dimension selection type.
51  typedef DimensionSelectionType DimensionSelection;
52 
 /**
  * Construct the decision tree on the given data and labels, where the data
  * can be both numeric and categorical.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param datasetInfo Auxiliary information about the dataset.
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
66  template<typename MatType, typename LabelsType>
67  DecisionTree(MatType&& data,
68  const data::DatasetInfo& datasetInfo,
69  LabelsType&& labels,
70  const size_t numClasses,
71  const size_t minimumLeafSize = 10,
72  const double minimumGainSplit = 1e-7);
73 
 /**
  * Construct the decision tree on the given data and labels, without a
  * data::DatasetInfo argument.
  * NOTE(review): presumably every dimension is then treated as numeric --
  * confirm against the implementation.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
86  template<typename MatType, typename LabelsType>
87  DecisionTree(MatType&& data,
88  LabelsType&& labels,
89  const size_t numClasses,
90  const size_t minimumLeafSize = 10,
91  const double minimumGainSplit = 1e-7);
92 
 /**
  * Construct the decision tree on the given data, dataset information,
  * labels and per-point weights.
  *
  * The trailing unnamed pointer parameter carries no data: it removes this
  * overload from consideration unless
  * arma::is_arma_type<std::remove_reference<WeightsType>::type>::value is
  * true.  Presumably this keeps a non-Armadillo fifth argument from
  * selecting the weighted overload -- TODO confirm.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param datasetInfo Auxiliary information about the dataset.
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param weights Weights (forwarding reference; must be an Armadillo type).
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
107  template<typename MatType, typename LabelsType, typename WeightsType>
108  DecisionTree(MatType&& data,
109  const data::DatasetInfo& datasetInfo,
110  LabelsType&& labels,
111  const size_t numClasses,
112  WeightsType&& weights,
113  const size_t minimumLeafSize = 10,
114  const double minimumGainSplit = 1e-7,
115  const std::enable_if_t<arma::is_arma_type<
116  typename std::remove_reference<WeightsType>::type>::value>*
117  = 0);
118 
 /**
  * Construct the decision tree on the given data, labels and per-point
  * weights, without a data::DatasetInfo argument.
  *
  * As with the other weighted constructor, the trailing unnamed pointer
  * parameter only enables this overload when WeightsType (with references
  * removed) is an Armadillo type.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param weights Weights (forwarding reference; must be an Armadillo type).
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
132  template<typename MatType, typename LabelsType, typename WeightsType>
133  DecisionTree(MatType&& data,
134  LabelsType&& labels,
135  const size_t numClasses,
136  WeightsType&& weights,
137  const size_t minimumLeafSize = 10,
138  const double minimumGainSplit = 1e-7,
139  const std::enable_if_t<arma::is_arma_type<
140  typename std::remove_reference<WeightsType>::type>::value>*
141  = 0);
142 
143 
 /**
  * Construct a tree from a class count only; no data is passed.  Because
  * numClasses has a default, this also serves as the default constructor.
  * NOTE(review): the constructor is not explicit, so a size_t converts
  * implicitly to a DecisionTree -- confirm that this is intended.
  *
  * @param numClasses Number of classes (defaults to 1).
  */
150  DecisionTree(const size_t numClasses = 1);
151 
 /**
  * Copy another tree.
  *
  * @param other Tree to copy.
  */
158  DecisionTree(const DecisionTree& other);
159 
 /**
  * Construct from an rvalue tree (move constructor).
  *
  * @param other Tree to move from.
  */
165  DecisionTree(DecisionTree&& other);
166 
 /**
  * Copy another tree.
  *
  * @param other Tree to copy.
  */
173  DecisionTree& operator=(const DecisionTree& other);
174 
 // NOTE(review): a move constructor is declared above, but no move assignment
 // operator is visible in this listing (the next listing line is empty).
 // Verify against the original header whether a declaration was lost here.
181 
 //! Clean up memory.
185  ~DecisionTree();
186 
 /**
  * Train the decision tree on the given data.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param datasetInfo Auxiliary information about the dataset.
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
202  template<typename MatType, typename LabelsType>
203  void Train(MatType&& data,
204  const data::DatasetInfo& datasetInfo,
205  LabelsType&& labels,
206  const size_t numClasses,
207  const size_t minimumLeafSize = 10,
208  const double minimumGainSplit = 1e-7);
209 
 /**
  * Train the decision tree on the given data, without a data::DatasetInfo
  * argument.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
223  template<typename MatType, typename LabelsType>
224  void Train(MatType&& data,
225  LabelsType&& labels,
226  const size_t numClasses,
227  const size_t minimumLeafSize = 10,
228  const double minimumGainSplit = 1e-7);
229 
 /**
  * Train the decision tree on the given data, dataset information, labels
  * and weights.  The trailing unnamed pointer parameter only enables this
  * overload when WeightsType (with references removed) is an Armadillo type.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param datasetInfo Auxiliary information about the dataset.
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param weights Weights (forwarding reference; must be an Armadillo type).
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
245  template<typename MatType, typename LabelsType, typename WeightsType>
246  void Train(MatType&& data,
247  const data::DatasetInfo& datasetInfo,
248  LabelsType&& labels,
249  const size_t numClasses,
250  WeightsType&& weights,
251  const size_t minimumLeafSize = 10,
252  const double minimumGainSplit = 1e-7,
253  const std::enable_if_t<arma::is_arma_type<typename
254  std::remove_reference<WeightsType>::type>::value>* = 0);
255 
 /**
  * Train the decision tree on the given data, labels and weights, without a
  * data::DatasetInfo argument.  The trailing unnamed pointer parameter only
  * enables this overload when WeightsType (with references removed) is an
  * Armadillo type.
  *
  * @param data Dataset to train on (forwarding reference).
  * @param labels Labels for the dataset (forwarding reference).
  * @param numClasses Number of classes.
  * @param weights Weights (forwarding reference; must be an Armadillo type).
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
269  template<typename MatType, typename LabelsType, typename WeightsType>
270  void Train(MatType&& data,
271  LabelsType&& labels,
272  const size_t numClasses,
273  WeightsType&& weights,
274  const size_t minimumLeafSize = 10,
275  const double minimumGainSplit = 1e-7,
276  const std::enable_if_t<arma::is_arma_type<typename
277  std::remove_reference<WeightsType>::type>::value>* = 0);
278 
 /**
  * Classify the given point, using the entire tree.
  *
  * @param point Point to classify.
  * @return A size_t; presumably the predicted class label -- TODO confirm.
  */
285  template<typename VecType>
286  size_t Classify(const VecType& point) const;
287 
 /**
  * Classify the given point, writing the results into the two non-const
  * reference parameters instead of returning them.
  *
  * @param point Point to classify.
  * @param prediction Output: set by the call.
  * @param probabilities Output: set by the call (presumably one entry per
  *     class -- TODO confirm).
  */
297  template<typename VecType>
298  void Classify(const VecType& point,
299  size_t& prediction,
300  arma::vec& probabilities) const;
301 
 /**
  * Classify a whole dataset, writing into the predictions row.
  *
  * @param data Dataset to classify.
  * @param predictions Output: set by the call.
  */
309  template<typename MatType>
310  void Classify(const MatType& data,
311  arma::Row<size_t>& predictions) const;
312 
 /**
  * Classify a whole dataset, writing predictions and a probabilities matrix.
  *
  * @param data Dataset to classify.
  * @param predictions Output: set by the call.
  * @param probabilities Output: set by the call.
  */
323  template<typename MatType>
324  void Classify(const MatType& data,
325  arma::Row<size_t>& predictions,
326  arma::mat& probabilities) const;
327 
 //! Serialize the tree.  The version argument is unnamed and unused here.
331  template<typename Archive>
332  void serialize(Archive& ar, const unsigned int /* version */);
333 
 //! Get the number of children.
335  size_t NumChildren() const { return children.size(); }
336 
 //! Get the child of the given index.  The index is not bounds-checked.
338  const DecisionTree& Child(const size_t i) const { return *children[i]; }
 //! Modify the child of the given index (be careful!).  The index is not
 //! bounds-checked.
340  DecisionTree& Child(const size_t i) { return *children[i]; }
341 
 /**
  * Given a point and that this node is not a leaf, calculate the index of
  * the child node for this point.
  *
  * @param point Point to compute the direction for.
  */
349  template<typename VecType>
350  size_t CalculateDirection(const VecType& point) const;
351 
 //! Get the number of classes in the tree.
355  size_t NumClasses() const;
356 
357  private:
 //! Child nodes, held as raw pointers.  Presumably owned by this node and
 //! released by the destructor -- TODO confirm in decision_tree_impl.hpp.
359  std::vector<DecisionTree*> children;
 //! Dimension this node splits on.
361  size_t splitDimension;
 //! NOTE(review): the name suggests this holds a dimension type in some
 //! nodes and a majority class in others -- confirm against the
 //! implementation.
364  size_t dimensionTypeOrMajorityClass;
 //! Class probabilities vector.  The name of the first
 //! CalculateClassProbabilities() parameter suggests it is filled from
 //! labels -- TODO confirm.
372  arma::vec classProbabilities;
373 
 //! Shorthand for the AuxiliarySplitInfo<ElemType> types of the two split
 //! types; these are the same types as the two base classes of this class.
377  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
378  NumericAuxiliarySplitInfo;
379  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
380  CategoricalAuxiliarySplitInfo;
381 
 /**
  * Calculate class probabilities from the given labels.
  *
  * @tparam UseWeights Compile-time switch; presumably selects whether the
  *     weights argument is used -- TODO confirm.
  * @param labels Labels.
  * @param numClasses Number of classes.
  * @param weights Weights.
  */
385  template<bool UseWeights, typename RowType, typename WeightsRowType>
386  void CalculateClassProbabilities(const RowType& labels,
387  const size_t numClasses,
388  const WeightsRowType& weights);
389 
 /**
  * Private training overload taking dataset information.  Data, labels and
  * weights are taken by non-const lvalue reference, so the call may modify
  * them.
  *
  * @tparam UseWeights Compile-time switch; presumably selects whether the
  *     weights argument is used -- TODO confirm.
  * @param data Dataset.
  * @param begin Presumably the index of the first point handled by this
  *     node -- TODO confirm.
  * @param count Presumably the number of points handled by this node --
  *     TODO confirm.
  * @param datasetInfo Auxiliary information about the dataset.
  * @param labels Labels.
  * @param numClasses Number of classes.
  * @param weights Weights.
  * @param minimumLeafSize Defaults to 10.
  * @param minimumGainSplit Defaults to 1e-7.
  */
405  template<bool UseWeights, typename MatType>
406  void Train(MatType& data,
407  const size_t begin,
408  const size_t count,
409  const data::DatasetInfo& datasetInfo,
410  arma::Row<size_t>& labels,
411  const size_t numClasses,
412  arma::rowvec& weights,
413  const size_t minimumLeafSize = 10,
414  const double minimumGainSplit = 1e-7);
415 
 /**
  * Private training overload without dataset information.  Parameters are
  * the same as for the overload above, minus datasetInfo.
  */
430  template<bool UseWeights, typename MatType>
431  void Train(MatType& data,
432  const size_t begin,
433  const size_t count,
434  arma::Row<size_t>& labels,
435  const size_t numClasses,
436  arma::rowvec& weights,
437  const size_t minimumLeafSize = 10,
438  const double minimumGainSplit = 1e-7);
439 };
440 
/**
 * Convenience alias for DecisionTree with the same first five template
 * parameters (and the same defaults) and the sixth argument, NoRecursion,
 * fixed to false.
 *
 * NOTE(review): false is also DecisionTree's default for NoRecursion, so with
 * default arguments this alias names exactly the same type as
 * DecisionTree<>.  The name "stump" suggests a tree that does not recurse,
 * which one would expect to need NoRecursion = true -- confirm the intended
 * value against decision_tree_impl.hpp and the callers before changing it,
 * since changing it alters the aliased type.
 */
444 template<typename FitnessFunction = GiniGain,
445  template<typename> class NumericSplitType = BestBinaryNumericSplit,
446  template<typename> class CategoricalSplitType = AllCategoricalSplit,
447  typename DimensionSelectType = AllDimensionSelect,
448  typename ElemType = double>
449 using DecisionStump = DecisionTree<FitnessFunction,
450  NumericSplitType,
451  CategoricalSplitType,
452  DimensionSelectType,
453  ElemType,
454  false>;
455 
456 } // namespace tree
457 } // namespace mlpack
458 
459 // Include implementation.
460 #include "decision_tree_impl.hpp"
461 
462 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
void Train(MatType &&data, const data::DatasetInfo &datasetInfo, LabelsType &&labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Train the decision tree on the given data.
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go to.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree(MatType &&data, const data::DatasetInfo &datasetInfo, LabelsType &&labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
~DecisionTree()
Clean up memory.