decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
#include <mlpack/prereqs.hpp>
#include "gini_gain.hpp"
#include "best_binary_numeric_split.hpp"
#include "all_categorical_split.hpp"
#include "all_dimension_select.hpp"
#include <type_traits>
22 
23 namespace mlpack {
24 namespace tree {
25 
33 template<typename FitnessFunction = GiniGain,
34  template<typename> class NumericSplitType = BestBinaryNumericSplit,
35  template<typename> class CategoricalSplitType = AllCategoricalSplit,
36  typename DimensionSelectionType = AllDimensionSelect,
37  typename ElemType = double,
38  bool NoRecursion = false>
39 class DecisionTree :
40  public NumericSplitType<FitnessFunction>::template
41  AuxiliarySplitInfo<ElemType>,
42  public CategoricalSplitType<FitnessFunction>::template
43  AuxiliarySplitInfo<ElemType>
44 {
45  public:
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
51  typedef DimensionSelectionType DimensionSelection;
52 
68  template<typename MatType, typename LabelsType>
69  DecisionTree(MatType data,
70  const data::DatasetInfo& datasetInfo,
71  LabelsType labels,
72  const size_t numClasses,
73  const size_t minimumLeafSize = 10,
74  const double minimumGainSplit = 1e-7);
75 
90  template<typename MatType, typename LabelsType>
91  DecisionTree(MatType data,
92  LabelsType labels,
93  const size_t numClasses,
94  const size_t minimumLeafSize = 10,
95  const double minimumGainSplit = 1e-7);
96 
114  template<typename MatType, typename LabelsType, typename WeightsType>
115  DecisionTree(MatType data,
116  const data::DatasetInfo& datasetInfo,
117  LabelsType labels,
118  const size_t numClasses,
119  WeightsType weights,
120  const size_t minimumLeafSize = 10,
121  const double minimumGainSplit = 1e-7,
122  const std::enable_if_t<arma::is_arma_type<
123  typename std::remove_reference<WeightsType>::type>::value>*
124  = 0);
125 
142  template<typename MatType, typename LabelsType, typename WeightsType>
143  DecisionTree(MatType data,
144  LabelsType labels,
145  const size_t numClasses,
146  WeightsType weights,
147  const size_t minimumLeafSize = 10,
148  const double minimumGainSplit = 1e-7,
149  const std::enable_if_t<arma::is_arma_type<
150  typename std::remove_reference<WeightsType>::type>::value>*
151  = 0);
152 
153 
160  DecisionTree(const size_t numClasses = 1);
161 
168  DecisionTree(const DecisionTree& other);
169 
175  DecisionTree(DecisionTree&& other);
176 
183  DecisionTree& operator=(const DecisionTree& other);
184 
191 
195  ~DecisionTree();
196 
214  template<typename MatType, typename LabelsType>
215  void Train(MatType data,
216  const data::DatasetInfo& datasetInfo,
217  LabelsType labels,
218  const size_t numClasses,
219  const size_t minimumLeafSize = 10,
220  const double minimumGainSplit = 1e-7);
221 
237  template<typename MatType, typename LabelsType>
238  void Train(MatType data,
239  LabelsType labels,
240  const size_t numClasses,
241  const size_t minimumLeafSize = 10,
242  const double minimumGainSplit = 1e-7);
243 
262  template<typename MatType, typename LabelsType, typename WeightsType>
263  void Train(MatType data,
264  const data::DatasetInfo& datasetInfo,
265  LabelsType labels,
266  const size_t numClasses,
267  WeightsType weights,
268  const size_t minimumLeafSize = 10,
269  const double minimumGainSplit = 1e-7,
270  const std::enable_if_t<arma::is_arma_type<typename
271  std::remove_reference<WeightsType>::type>::value>* = 0);
272 
289  template<typename MatType, typename LabelsType, typename WeightsType>
290  void Train(MatType data,
291  LabelsType labels,
292  const size_t numClasses,
293  WeightsType weights,
294  const size_t minimumLeafSize = 10,
295  const double minimumGainSplit = 1e-7,
296  const std::enable_if_t<arma::is_arma_type<typename
297  std::remove_reference<WeightsType>::type>::value>* = 0);
298 
305  template<typename VecType>
306  size_t Classify(const VecType& point) const;
307 
317  template<typename VecType>
318  void Classify(const VecType& point,
319  size_t& prediction,
320  arma::vec& probabilities) const;
321 
329  template<typename MatType>
330  void Classify(const MatType& data,
331  arma::Row<size_t>& predictions) const;
332 
343  template<typename MatType>
344  void Classify(const MatType& data,
345  arma::Row<size_t>& predictions,
346  arma::mat& probabilities) const;
347 
351  template<typename Archive>
352  void serialize(Archive& ar, const unsigned int /* version */);
353 
355  size_t NumChildren() const { return children.size(); }
356 
358  const DecisionTree& Child(const size_t i) const { return *children[i]; }
360  DecisionTree& Child(const size_t i) { return *children[i]; }
361 
369  template<typename VecType>
370  size_t CalculateDirection(const VecType& point) const;
371 
375  size_t NumClasses() const;
376 
377  private:
379  std::vector<DecisionTree*> children;
381  size_t splitDimension;
384  size_t dimensionTypeOrMajorityClass;
392  arma::vec classProbabilities;
393 
397  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
398  NumericAuxiliarySplitInfo;
399  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
400  CategoricalAuxiliarySplitInfo;
401 
405  template<bool UseWeights, typename RowType, typename WeightsRowType>
406  void CalculateClassProbabilities(const RowType& labels,
407  const size_t numClasses,
408  const WeightsRowType& weights);
409 
425  template<bool UseWeights, typename MatType>
426  void Train(MatType& data,
427  const size_t begin,
428  const size_t count,
429  const data::DatasetInfo& datasetInfo,
430  arma::Row<size_t>& labels,
431  const size_t numClasses,
432  arma::rowvec& weights,
433  const size_t minimumLeafSize = 10,
434  const double minimumGainSplit = 1e-7);
435 
450  template<bool UseWeights, typename MatType>
451  void Train(MatType& data,
452  const size_t begin,
453  const size_t count,
454  arma::Row<size_t>& labels,
455  const size_t numClasses,
456  arma::rowvec& weights,
457  const size_t minimumLeafSize = 10,
458  const double minimumGainSplit = 1e-7);
459 };
460 
464 template<typename FitnessFunction = GiniGain,
465  template<typename> class NumericSplitType = BestBinaryNumericSplit,
466  template<typename> class CategoricalSplitType = AllCategoricalSplit,
467  typename DimensionSelectType = AllDimensionSelect,
468  typename ElemType = double>
469 using DecisionStump = DecisionTree<FitnessFunction,
470  NumericSplitType,
471  CategoricalSplitType,
472  DimensionSelectType,
473  ElemType,
474  false>;
475 
476 } // namespace tree
477 } // namespace mlpack
478 
479 // Include implementation.
480 #include "decision_tree_impl.hpp"
481 
482 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go toward.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
void Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.