decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
20 #include "all_dimension_select.hpp"
21 #include <type_traits>
22 
23 namespace mlpack {
24 namespace tree {
25 
33 template<typename FitnessFunction = GiniGain,
34  template<typename> class NumericSplitType = BestBinaryNumericSplit,
35  template<typename> class CategoricalSplitType = AllCategoricalSplit,
36  typename DimensionSelectionType = AllDimensionSelect,
37  typename ElemType = double,
38  bool NoRecursion = false>
39 class DecisionTree :
40  public NumericSplitType<FitnessFunction>::template
41  AuxiliarySplitInfo<ElemType>,
42  public CategoricalSplitType<FitnessFunction>::template
43  AuxiliarySplitInfo<ElemType>
44 {
45  public:
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
51  typedef DimensionSelectionType DimensionSelection;
52 
69  template<typename MatType, typename LabelsType>
70  DecisionTree(MatType data,
71  const data::DatasetInfo& datasetInfo,
72  LabelsType labels,
73  const size_t numClasses,
74  const size_t minimumLeafSize = 10,
75  const double minimumGainSplit = 1e-7,
76  DimensionSelectionType dimensionSelector =
77  DimensionSelectionType());
78 
94  template<typename MatType, typename LabelsType>
95  DecisionTree(MatType data,
96  LabelsType labels,
97  const size_t numClasses,
98  const size_t minimumLeafSize = 10,
99  const double minimumGainSplit = 1e-7,
100  DimensionSelectionType dimensionSelector =
101  DimensionSelectionType());
102 
121  template<typename MatType, typename LabelsType, typename WeightsType>
122  DecisionTree(MatType data,
123  const data::DatasetInfo& datasetInfo,
124  LabelsType labels,
125  const size_t numClasses,
126  WeightsType weights,
127  const size_t minimumLeafSize = 10,
128  const double minimumGainSplit = 1e-7,
129  DimensionSelectionType dimensionSelector =
130  DimensionSelectionType(),
131  const std::enable_if_t<arma::is_arma_type<
132  typename std::remove_reference<WeightsType>::type>::value>*
133  = 0);
134 
152  template<typename MatType, typename LabelsType, typename WeightsType>
153  DecisionTree(MatType data,
154  LabelsType labels,
155  const size_t numClasses,
156  WeightsType weights,
157  const size_t minimumLeafSize = 10,
158  const double minimumGainSplit = 1e-7,
159  DimensionSelectionType dimensionSelector =
160  DimensionSelectionType(),
161  const std::enable_if_t<arma::is_arma_type<
162  typename std::remove_reference<WeightsType>::type>::value>*
163  = 0);
164 
165 
172  DecisionTree(const size_t numClasses = 1);
173 
180  DecisionTree(const DecisionTree& other);
181 
187  DecisionTree(DecisionTree&& other);
188 
195  DecisionTree& operator=(const DecisionTree& other);
196 
203 
207  ~DecisionTree();
208 
228  template<typename MatType, typename LabelsType>
229  double Train(MatType data,
230  const data::DatasetInfo& datasetInfo,
231  LabelsType labels,
232  const size_t numClasses,
233  const size_t minimumLeafSize = 10,
234  const double minimumGainSplit = 1e-7,
235  DimensionSelectionType dimensionSelector =
236  DimensionSelectionType());
237 
255  template<typename MatType, typename LabelsType>
256  double Train(MatType data,
257  LabelsType labels,
258  const size_t numClasses,
259  const size_t minimumLeafSize = 10,
260  const double minimumGainSplit = 1e-7,
261  DimensionSelectionType dimensionSelector =
262  DimensionSelectionType());
263 
284  template<typename MatType, typename LabelsType, typename WeightsType>
285  double Train(MatType data,
286  const data::DatasetInfo& datasetInfo,
287  LabelsType labels,
288  const size_t numClasses,
289  WeightsType weights,
290  const size_t minimumLeafSize = 10,
291  const double minimumGainSplit = 1e-7,
292  DimensionSelectionType dimensionSelector =
293  DimensionSelectionType(),
294  const std::enable_if_t<arma::is_arma_type<typename
295  std::remove_reference<WeightsType>::type>::value>* = 0);
296 
315  template<typename MatType, typename LabelsType, typename WeightsType>
316  double Train(MatType data,
317  LabelsType labels,
318  const size_t numClasses,
319  WeightsType weights,
320  const size_t minimumLeafSize = 10,
321  const double minimumGainSplit = 1e-7,
322  DimensionSelectionType dimensionSelector =
323  DimensionSelectionType(),
324  const std::enable_if_t<arma::is_arma_type<typename
325  std::remove_reference<WeightsType>::type>::value>* = 0);
326 
333  template<typename VecType>
334  size_t Classify(const VecType& point) const;
335 
345  template<typename VecType>
346  void Classify(const VecType& point,
347  size_t& prediction,
348  arma::vec& probabilities) const;
349 
357  template<typename MatType>
358  void Classify(const MatType& data,
359  arma::Row<size_t>& predictions) const;
360 
371  template<typename MatType>
372  void Classify(const MatType& data,
373  arma::Row<size_t>& predictions,
374  arma::mat& probabilities) const;
375 
379  template<typename Archive>
380  void serialize(Archive& ar, const unsigned int /* version */);
381 
383  size_t NumChildren() const { return children.size(); }
384 
386  const DecisionTree& Child(const size_t i) const { return *children[i]; }
388  DecisionTree& Child(const size_t i) { return *children[i]; }
389 
392  size_t SplitDimension() const { return splitDimension; }
393 
401  template<typename VecType>
402  size_t CalculateDirection(const VecType& point) const;
403 
407  size_t NumClasses() const;
408 
409  private:
411  std::vector<DecisionTree*> children;
413  size_t splitDimension;
416  size_t dimensionTypeOrMajorityClass;
424  arma::vec classProbabilities;
425 
429  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
430  NumericAuxiliarySplitInfo;
431  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
432  CategoricalAuxiliarySplitInfo;
433 
437  template<bool UseWeights, typename RowType, typename WeightsRowType>
438  void CalculateClassProbabilities(const RowType& labels,
439  const size_t numClasses,
440  const WeightsRowType& weights);
441 
458  template<bool UseWeights, typename MatType>
459  double Train(MatType& data,
460  const size_t begin,
461  const size_t count,
462  const data::DatasetInfo& datasetInfo,
463  arma::Row<size_t>& labels,
464  const size_t numClasses,
465  arma::rowvec& weights,
466  const size_t minimumLeafSize,
467  const double minimumGainSplit,
468  DimensionSelectionType& dimensionSelector);
469 
485  template<bool UseWeights, typename MatType>
486  double Train(MatType& data,
487  const size_t begin,
488  const size_t count,
489  arma::Row<size_t>& labels,
490  const size_t numClasses,
491  arma::rowvec& weights,
492  const size_t minimumLeafSize,
493  const double minimumGainSplit,
494  DimensionSelectionType& dimensionSelector);
495 };
496 
500 template<typename FitnessFunction = GiniGain,
501  template<typename> class NumericSplitType = BestBinaryNumericSplit,
502  template<typename> class CategoricalSplitType = AllCategoricalSplit,
503  typename DimensionSelectType = AllDimensionSelect,
504  typename ElemType = double>
505 using DecisionStump = DecisionTree<FitnessFunction,
506  NumericSplitType,
507  CategoricalSplitType,
508  DimensionSelectType,
509  ElemType,
510  false>;
511 
512 } // namespace tree
513 } // namespace mlpack
514 
515 // Include implementation.
516 #include "decision_tree_impl.hpp"
517 
518 #endif
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
~DecisionTree()
Clean up memory.