decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
#include <mlpack/prereqs.hpp>
#include "gini_gain.hpp"
#include "best_binary_numeric_split.hpp"
#include "all_categorical_split.hpp"
#include "all_dimension_select.hpp"
#include <type_traits>
22 
23 namespace mlpack {
24 namespace tree {
25 
33 template<typename FitnessFunction = GiniGain,
34  template<typename> class NumericSplitType = BestBinaryNumericSplit,
35  template<typename> class CategoricalSplitType = AllCategoricalSplit,
36  typename DimensionSelectionType = AllDimensionSelect,
37  typename ElemType = double,
38  bool NoRecursion = false>
39 class DecisionTree :
40  public NumericSplitType<FitnessFunction>::template
41  AuxiliarySplitInfo<ElemType>,
42  public CategoricalSplitType<FitnessFunction>::template
43  AuxiliarySplitInfo<ElemType>
44 {
45  public:
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
51  typedef DimensionSelectionType DimensionSelection;
52 
68  template<typename MatType, typename LabelsType>
69  DecisionTree(MatType data,
70  const data::DatasetInfo& datasetInfo,
71  LabelsType labels,
72  const size_t numClasses,
73  const size_t minimumLeafSize = 10,
74  const double minimumGainSplit = 1e-7);
75 
90  template<typename MatType, typename LabelsType>
91  DecisionTree(MatType data,
92  LabelsType labels,
93  const size_t numClasses,
94  const size_t minimumLeafSize = 10,
95  const double minimumGainSplit = 1e-7);
96 
114  template<typename MatType, typename LabelsType, typename WeightsType>
115  DecisionTree(MatType data,
116  const data::DatasetInfo& datasetInfo,
117  LabelsType labels,
118  const size_t numClasses,
119  WeightsType weights,
120  const size_t minimumLeafSize = 10,
121  const double minimumGainSplit = 1e-7,
122  const std::enable_if_t<arma::is_arma_type<
123  typename std::remove_reference<WeightsType>::type>::value>*
124  = 0);
125 
142  template<typename MatType, typename LabelsType, typename WeightsType>
143  DecisionTree(MatType data,
144  LabelsType labels,
145  const size_t numClasses,
146  WeightsType weights,
147  const size_t minimumLeafSize = 10,
148  const double minimumGainSplit = 1e-7,
149  const std::enable_if_t<arma::is_arma_type<
150  typename std::remove_reference<WeightsType>::type>::value>*
151  = 0);
152 
153 
160  DecisionTree(const size_t numClasses = 1);
161 
168  DecisionTree(const DecisionTree& other);
169 
175  DecisionTree(DecisionTree&& other);
176 
183  DecisionTree& operator=(const DecisionTree& other);
184 
191 
195  ~DecisionTree();
196 
215  template<typename MatType, typename LabelsType>
216  double Train(MatType data,
217  const data::DatasetInfo& datasetInfo,
218  LabelsType labels,
219  const size_t numClasses,
220  const size_t minimumLeafSize = 10,
221  const double minimumGainSplit = 1e-7);
222 
239  template<typename MatType, typename LabelsType>
240  double Train(MatType data,
241  LabelsType labels,
242  const size_t numClasses,
243  const size_t minimumLeafSize = 10,
244  const double minimumGainSplit = 1e-7);
245 
265  template<typename MatType, typename LabelsType, typename WeightsType>
266  double Train(MatType data,
267  const data::DatasetInfo& datasetInfo,
268  LabelsType labels,
269  const size_t numClasses,
270  WeightsType weights,
271  const size_t minimumLeafSize = 10,
272  const double minimumGainSplit = 1e-7,
273  const std::enable_if_t<arma::is_arma_type<typename
274  std::remove_reference<WeightsType>::type>::value>* = 0);
275 
293  template<typename MatType, typename LabelsType, typename WeightsType>
294  double Train(MatType data,
295  LabelsType labels,
296  const size_t numClasses,
297  WeightsType weights,
298  const size_t minimumLeafSize = 10,
299  const double minimumGainSplit = 1e-7,
300  const std::enable_if_t<arma::is_arma_type<typename
301  std::remove_reference<WeightsType>::type>::value>* = 0);
302 
309  template<typename VecType>
310  size_t Classify(const VecType& point) const;
311 
321  template<typename VecType>
322  void Classify(const VecType& point,
323  size_t& prediction,
324  arma::vec& probabilities) const;
325 
333  template<typename MatType>
334  void Classify(const MatType& data,
335  arma::Row<size_t>& predictions) const;
336 
347  template<typename MatType>
348  void Classify(const MatType& data,
349  arma::Row<size_t>& predictions,
350  arma::mat& probabilities) const;
351 
355  template<typename Archive>
356  void serialize(Archive& ar, const unsigned int /* version */);
357 
359  size_t NumChildren() const { return children.size(); }
360 
362  const DecisionTree& Child(const size_t i) const { return *children[i]; }
364  DecisionTree& Child(const size_t i) { return *children[i]; }
365 
373  template<typename VecType>
374  size_t CalculateDirection(const VecType& point) const;
375 
379  size_t NumClasses() const;
380 
381  private:
383  std::vector<DecisionTree*> children;
385  size_t splitDimension;
388  size_t dimensionTypeOrMajorityClass;
396  arma::vec classProbabilities;
397 
401  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
402  NumericAuxiliarySplitInfo;
403  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
404  CategoricalAuxiliarySplitInfo;
405 
409  template<bool UseWeights, typename RowType, typename WeightsRowType>
410  void CalculateClassProbabilities(const RowType& labels,
411  const size_t numClasses,
412  const WeightsRowType& weights);
413 
430  template<bool UseWeights, typename MatType>
431  double Train(MatType& data,
432  const size_t begin,
433  const size_t count,
434  const data::DatasetInfo& datasetInfo,
435  arma::Row<size_t>& labels,
436  const size_t numClasses,
437  arma::rowvec& weights,
438  const size_t minimumLeafSize = 10,
439  const double minimumGainSplit = 1e-7);
440 
456  template<bool UseWeights, typename MatType>
457  double Train(MatType& data,
458  const size_t begin,
459  const size_t count,
460  arma::Row<size_t>& labels,
461  const size_t numClasses,
462  arma::rowvec& weights,
463  const size_t minimumLeafSize = 10,
464  const double minimumGainSplit = 1e-7);
465 };
466 
470 template<typename FitnessFunction = GiniGain,
471  template<typename> class NumericSplitType = BestBinaryNumericSplit,
472  template<typename> class CategoricalSplitType = AllCategoricalSplit,
473  typename DimensionSelectType = AllDimensionSelect,
474  typename ElemType = double>
475 using DecisionStump = DecisionTree<FitnessFunction,
476  NumericSplitType,
477  CategoricalSplitType,
478  DimensionSelectType,
479  ElemType,
480  false>;
481 
482 } // namespace tree
483 } // namespace mlpack
484 
485 // Include implementation.
486 #include "decision_tree_impl.hpp"
487 
488 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go towards.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Train the decision tree on the given data.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
~DecisionTree()
Clean up memory.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.