13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 21 #include <type_traits> 33 template<
typename FitnessFunction = GiniGain,
34 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
35 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
36 typename DimensionSelectionType = AllDimensionSelect,
37 typename ElemType = double,
38 bool NoRecursion =
false>
40 public NumericSplitType<FitnessFunction>::template
41 AuxiliarySplitInfo<ElemType>,
42 public CategoricalSplitType<FitnessFunction>::template
43 AuxiliarySplitInfo<ElemType>
70 template<
typename MatType,
typename LabelsType>
74 const size_t numClasses,
75 const size_t minimumLeafSize = 10,
76 const double minimumGainSplit = 1e-7,
77 const size_t maximumDepth = 0,
78 DimensionSelectionType dimensionSelector =
79 DimensionSelectionType());
97 template<
typename MatType,
typename LabelsType>
100 const size_t numClasses,
101 const size_t minimumLeafSize = 10,
102 const double minimumGainSplit = 1e-7,
103 const size_t maximumDepth = 0,
104 DimensionSelectionType dimensionSelector =
105 DimensionSelectionType());
126 template<
typename MatType,
typename LabelsType,
typename WeightsType>
130 const size_t numClasses,
132 const size_t minimumLeafSize = 10,
133 const double minimumGainSplit = 1e-7,
134 const size_t maximumDepth = 0,
135 DimensionSelectionType dimensionSelector =
136 DimensionSelectionType(),
138 typename std::remove_reference<WeightsType>::type>::value>*
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
162 const size_t numClasses,
164 const size_t minimumLeafSize = 10,
165 const double minimumGainSplit = 1e-7,
166 const size_t maximumDepth = 0,
167 DimensionSelectionType dimensionSelector =
168 DimensionSelectionType(),
170 typename std::remove_reference<WeightsType>::type>::value>*
237 template<
typename MatType,
typename LabelsType>
238 double Train(MatType data,
241 const size_t numClasses,
242 const size_t minimumLeafSize = 10,
243 const double minimumGainSplit = 1e-7,
244 const size_t maximumDepth = 0,
245 DimensionSelectionType dimensionSelector =
246 DimensionSelectionType());
266 template<
typename MatType,
typename LabelsType>
267 double Train(MatType data,
269 const size_t numClasses,
270 const size_t minimumLeafSize = 10,
271 const double minimumGainSplit = 1e-7,
272 const size_t maximumDepth = 0,
273 DimensionSelectionType dimensionSelector =
274 DimensionSelectionType());
297 template<
typename MatType,
typename LabelsType,
typename WeightsType>
298 double Train(MatType data,
301 const size_t numClasses,
303 const size_t minimumLeafSize = 10,
304 const double minimumGainSplit = 1e-7,
305 const size_t maximumDepth = 0,
306 DimensionSelectionType dimensionSelector =
307 DimensionSelectionType(),
309 std::remove_reference<WeightsType>::type>::value>* = 0);
330 template<
typename MatType,
typename LabelsType,
typename WeightsType>
331 double Train(MatType data,
333 const size_t numClasses,
335 const size_t minimumLeafSize = 10,
336 const double minimumGainSplit = 1e-7,
337 const size_t maximumDepth = 0,
338 DimensionSelectionType dimensionSelector =
339 DimensionSelectionType(),
341 std::remove_reference<WeightsType>::type>::value>* = 0);
/**
 * Classify the given point, using the entire tree.
 *
 * @tparam VecType Type of the point to classify.
 * @param point Point to classify.
 * @return The predicted class, as a size_t label.
 */
349 template<
typename VecType>
350 size_t Classify(
const VecType& point)
const;
361 template<
typename VecType>
364 arma::vec& probabilities)
const;
373 template<
typename MatType>
375 arma::Row<size_t>& predictions)
const;
387 template<
typename MatType>
389 arma::Row<size_t>& predictions,
390 arma::mat& probabilities)
const;
/**
 * Serialize the tree.
 *
 * @tparam Archive Archive type used for (de)serialization.
 * @param ar Archive to serialize to or from.
 * (The second, unnamed parameter is unused by name here; it is presumably the
 * archive version number -- confirm against the implementation.)
 */
395 template<
typename Archive>
396 void serialize(Archive& ar,
const unsigned int );
417 template<
typename VecType>
//! Pointers to the child nodes of this node (see Child() / NumChildren()).
//! NOTE(review): presumably owned by this node and released in the
//! destructor -- confirm against the implementation.
427 std::vector<DecisionTree*> children;
//! The split dimension; presumably what SplitDimension() returns, and so only
//! meaningful if this is a non-leaf in a trained tree.
429 size_t splitDimension;
//! NOTE(review): the name suggests this holds the type of the split dimension
//! for internal nodes and the majority class for leaves -- confirm.
432 size_t dimensionTypeOrMajorityClass;
//! Vector of class probabilities stored in this node.
//! NOTE(review): presumably filled by CalculateClassProbabilities() -- confirm.
440 arma::vec classProbabilities;
//! Shorthand for the auxiliary split information type of the numeric split
//! policy, instantiated for ElemType.  This is the same
//! AuxiliarySplitInfo<ElemType> type the class lists as a base class.
445 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
446 NumericAuxiliarySplitInfo;
//! Shorthand for the auxiliary split information type of the categorical
//! split policy, instantiated for ElemType (also a base class of this class).
447 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
448 CategoricalAuxiliarySplitInfo;
/**
 * Calculate the class probabilities of the given labels.
 *
 * NOTE(review): the body is not visible here; presumably the result is stored
 * in this node's class probability vector and the weights are only used when
 * UseWeights is true -- confirm against decision_tree_impl.hpp.
 *
 * @tparam UseWeights Compile-time flag; presumably selects weighted counting.
 * @tparam RowType Type of the labels container.
 * @tparam WeightsRowType Type of the weights container.
 * @param labels Labels of the points.
 * @param numClasses Number of classes.
 * @param weights Weights associated with the labels.
 */
453 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
454 void CalculateClassProbabilities(
const RowType& labels,
455 const size_t numClasses,
456 const WeightsRowType& weights);
475 template<
bool UseWeights,
typename MatType>
476 double Train(MatType& data,
480 arma::Row<size_t>& labels,
481 const size_t numClasses,
482 arma::rowvec& weights,
483 const size_t minimumLeafSize,
484 const double minimumGainSplit,
485 const size_t maximumDepth,
486 DimensionSelectionType& dimensionSelector);
504 template<
bool UseWeights,
typename MatType>
505 double Train(MatType& data,
508 arma::Row<size_t>& labels,
509 const size_t numClasses,
510 arma::rowvec& weights,
511 const size_t minimumLeafSize,
512 const double minimumGainSplit,
513 const size_t maximumDepth,
514 DimensionSelectionType& dimensionSelector);
520 template<
typename FitnessFunction =
GiniGain,
524 typename ElemType =
double>
527 CategoricalSplitType,
536 #include "decision_tree_impl.hpp" size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.