13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 22 #include <type_traits> 34 template<
typename FitnessFunction = GiniGain,
35 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
36 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
37 typename DimensionSelectionType = AllDimensionSelect,
38 typename ElemType = double,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::template
42 AuxiliarySplitInfo<ElemType>,
43 public CategoricalSplitType<FitnessFunction>::template
44 AuxiliarySplitInfo<ElemType>
71 template<
typename MatType,
typename LabelsType>
75 const size_t numClasses,
76 const size_t minimumLeafSize = 10,
77 const double minimumGainSplit = 1e-7,
78 const size_t maximumDepth = 0,
79 DimensionSelectionType dimensionSelector =
80 DimensionSelectionType());
98 template<
typename MatType,
typename LabelsType>
101 const size_t numClasses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
127 template<
typename MatType,
typename LabelsType,
typename WeightsType>
132 const size_t numClasses,
134 const size_t minimumLeafSize = 10,
135 const double minimumGainSplit = 1e-7,
136 const size_t maximumDepth = 0,
137 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
139 typename std::remove_reference<WeightsType>::type>::value>* = 0);
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
165 const size_t numClasses,
167 const size_t minimumLeafSize = 10,
168 const double minimumGainSplit = 1e-7,
170 typename std::remove_reference<WeightsType>::type>::value>* = 0);
189 template<
typename MatType,
typename LabelsType,
typename WeightsType>
193 const size_t numClasses,
195 const size_t minimumLeafSize = 10,
196 const double minimumGainSplit = 1e-7,
197 const size_t maximumDepth = 0,
198 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
200 typename std::remove_reference<WeightsType>::type>::value>* = 0);
220 template<
typename MatType,
typename LabelsType,
typename WeightsType>
225 const size_t numClasses,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7,
229 const size_t maximumDepth = 0,
230 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
232 typename std::remove_reference<WeightsType>::type>::value>* = 0);
297 template<
typename MatType,
typename LabelsType>
298 double Train(MatType data,
301 const size_t numClasses,
302 const size_t minimumLeafSize = 10,
303 const double minimumGainSplit = 1e-7,
304 const size_t maximumDepth = 0,
305 DimensionSelectionType dimensionSelector =
306 DimensionSelectionType());
326 template<
typename MatType,
typename LabelsType>
327 double Train(MatType data,
329 const size_t numClasses,
330 const size_t minimumLeafSize = 10,
331 const double minimumGainSplit = 1e-7,
332 const size_t maximumDepth = 0,
333 DimensionSelectionType dimensionSelector =
334 DimensionSelectionType());
357 template<
typename MatType,
typename LabelsType,
typename WeightsType>
358 double Train(MatType data,
361 const size_t numClasses,
363 const size_t minimumLeafSize = 10,
364 const double minimumGainSplit = 1e-7,
365 const size_t maximumDepth = 0,
366 DimensionSelectionType dimensionSelector =
367 DimensionSelectionType(),
369 std::remove_reference<WeightsType>::type>::value>* = 0);
390 template<
typename MatType,
typename LabelsType,
typename WeightsType>
391 double Train(MatType data,
393 const size_t numClasses,
395 const size_t minimumLeafSize = 10,
396 const double minimumGainSplit = 1e-7,
397 const size_t maximumDepth = 0,
398 DimensionSelectionType dimensionSelector =
399 DimensionSelectionType(),
401 std::remove_reference<WeightsType>::type>::value>* = 0);
// Classify the given point, using the entire tree.
//
// point is taken by const reference and the method is const, so neither the
// point nor the tree is modified by the call. The result is returned as a
// size_t.
// NOTE(review): the returned value is presumably the predicted class label
// for the point; confirm against the implementation.
409 template<
typename VecType>
410 size_t Classify(
const VecType& point)
const;
421 template<
typename VecType>
424 arma::vec& probabilities)
const;
433 template<
typename MatType>
435 arma::Row<size_t>& predictions)
const;
447 template<
typename MatType>
449 arma::Row<size_t>& predictions,
450 arma::mat& probabilities)
const;
// Serialize the tree.
//
// ar is the archive object, taken by non-const reference. The second
// parameter is unnamed in this declaration.
// NOTE(review): the unnamed unsigned int is presumably the archive/class
// version number that serialization frameworks pass; confirm.
455 template<
typename Archive>
456 void serialize(Archive& ar,
const unsigned int );
477 template<
typename VecType>
// Child nodes of this node, stored as raw DecisionTree pointers.
// NOTE(review): ownership is not visible here; presumably these are released
// by ~DecisionTree() (documented as "Clean up memory"). Confirm.
487 std::vector<DecisionTree*> children;
// Dimension this node splits on.
// NOTE(review): the SplitDimension() accessor is documented as only
// meaningful for a non-leaf node in a trained tree; presumably it returns
// this field. Confirm.
489 size_t splitDimension;
// NOTE(review): the name suggests a dual-purpose field (dimension type for a
// splitting node, majority class for a leaf); verify against the
// implementation before relying on either meaning.
492 size_t dimensionTypeOrMajorityClass;
// NOTE(review): presumably the per-class probabilities filled in by
// CalculateClassProbabilities(); the meaning for non-leaf nodes is not
// visible here. Confirm.
500 arma::vec classProbabilities;
// Shorthand for the AuxiliarySplitInfo<ElemType> member template of
// NumericSplit (i.e. NumericSplitType<FitnessFunction>). This is the same
// type the class lists as one of its public base classes.
505 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
506 NumericAuxiliarySplitInfo;
// Shorthand for the AuxiliarySplitInfo<ElemType> member template of
// CategoricalSplit (i.e. CategoricalSplitType<FitnessFunction>). This is the
// same type the class lists as its other public base class.
507 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
508 CategoricalAuxiliarySplitInfo;
// Declaration of the class-probability helper; returns nothing.
//
// Template parameters:
//   UseWeights     - compile-time flag.
//                    NOTE(review): presumably selects whether `weights` is
//                    consulted; confirm in the implementation.
//   RowType        - type of the labels argument.
//   WeightsRowType - type of the weights argument.
//
// Parameters:
//   labels     - labels, read-only.
//   numClasses - number of classes.
//   weights    - weights, read-only.
//
// NOTE(review): presumably stores its result in the classProbabilities
// member (the method is non-const and returns void); confirm.
513 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
514 void CalculateClassProbabilities(
const RowType& labels,
515 const size_t numClasses,
516 const WeightsRowType& weights);
535 template<
bool UseWeights,
typename MatType>
536 double Train(MatType& data,
540 arma::Row<size_t>& labels,
541 const size_t numClasses,
542 arma::rowvec& weights,
543 const size_t minimumLeafSize,
544 const double minimumGainSplit,
545 const size_t maximumDepth,
546 DimensionSelectionType& dimensionSelector);
564 template<
bool UseWeights,
typename MatType>
565 double Train(MatType& data,
568 arma::Row<size_t>& labels,
569 const size_t numClasses,
570 arma::rowvec& weights,
571 const size_t minimumLeafSize,
572 const double minimumGainSplit,
573 const size_t maximumDepth,
574 DimensionSelectionType& dimensionSelector);
580 template<
typename FitnessFunction =
GiniGain,
584 typename ElemType =
double>
587 CategoricalSplitType,
606 #include "decision_tree_impl.hpp" size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.