13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 21 #include <type_traits> 33 template<
typename FitnessFunction = GiniGain,
34 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
35 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
36 typename DimensionSelectionType = AllDimensionSelect,
37 typename ElemType = double,
38 bool NoRecursion =
false>
40 public NumericSplitType<FitnessFunction>::template
41 AuxiliarySplitInfo<ElemType>,
42 public CategoricalSplitType<FitnessFunction>::template
43 AuxiliarySplitInfo<ElemType>
69 template<
typename MatType,
typename LabelsType>
73 const size_t numClasses,
74 const size_t minimumLeafSize = 10,
75 const double minimumGainSplit = 1e-7,
76 DimensionSelectionType dimensionSelector =
77 DimensionSelectionType());
94 template<
typename MatType,
typename LabelsType>
97 const size_t numClasses,
98 const size_t minimumLeafSize = 10,
99 const double minimumGainSplit = 1e-7,
100 DimensionSelectionType dimensionSelector =
101 DimensionSelectionType());
121 template<
typename MatType,
typename LabelsType,
typename WeightsType>
125 const size_t numClasses,
127 const size_t minimumLeafSize = 10,
128 const double minimumGainSplit = 1e-7,
129 DimensionSelectionType dimensionSelector =
130 DimensionSelectionType(),
132 typename std::remove_reference<WeightsType>::type>::value>*
152 template<
typename MatType,
typename LabelsType,
typename WeightsType>
155 const size_t numClasses,
157 const size_t minimumLeafSize = 10,
158 const double minimumGainSplit = 1e-7,
159 DimensionSelectionType dimensionSelector =
160 DimensionSelectionType(),
162 typename std::remove_reference<WeightsType>::type>::value>*
228 template<
typename MatType,
typename LabelsType>
229 double Train(MatType data,
232 const size_t numClasses,
233 const size_t minimumLeafSize = 10,
234 const double minimumGainSplit = 1e-7,
235 DimensionSelectionType dimensionSelector =
236 DimensionSelectionType());
255 template<
typename MatType,
typename LabelsType>
256 double Train(MatType data,
258 const size_t numClasses,
259 const size_t minimumLeafSize = 10,
260 const double minimumGainSplit = 1e-7,
261 DimensionSelectionType dimensionSelector =
262 DimensionSelectionType());
284 template<
typename MatType,
typename LabelsType,
typename WeightsType>
285 double Train(MatType data,
288 const size_t numClasses,
290 const size_t minimumLeafSize = 10,
291 const double minimumGainSplit = 1e-7,
292 DimensionSelectionType dimensionSelector =
293 DimensionSelectionType(),
295 std::remove_reference<WeightsType>::type>::value>* = 0);
315 template<
typename MatType,
typename LabelsType,
typename WeightsType>
316 double Train(MatType data,
318 const size_t numClasses,
320 const size_t minimumLeafSize = 10,
321 const double minimumGainSplit = 1e-7,
322 DimensionSelectionType dimensionSelector =
323 DimensionSelectionType(),
325 std::remove_reference<WeightsType>::type>::value>* = 0);
333 template<
typename VecType>
334 size_t Classify(
const VecType& point)
const;
345 template<
typename VecType>
348 arma::vec& probabilities)
const;
357 template<
typename MatType>
359 arma::Row<size_t>& predictions)
const;
371 template<
typename MatType>
373 arma::Row<size_t>& predictions,
374 arma::mat& probabilities)
const;
379 template<
typename Archive>
380 void serialize(Archive& ar,
const unsigned int );
401 template<
typename VecType>
411 std::vector<DecisionTree*> children;
413 size_t splitDimension;
416 size_t dimensionTypeOrMajorityClass;
424 arma::vec classProbabilities;
429 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
430 NumericAuxiliarySplitInfo;
431 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
432 CategoricalAuxiliarySplitInfo;
437 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
438 void CalculateClassProbabilities(
const RowType& labels,
439 const size_t numClasses,
440 const WeightsRowType& weights);
458 template<
bool UseWeights,
typename MatType>
459 double Train(MatType& data,
463 arma::Row<size_t>& labels,
464 const size_t numClasses,
465 arma::rowvec& weights,
466 const size_t minimumLeafSize,
467 const double minimumGainSplit,
468 DimensionSelectionType& dimensionSelector);
485 template<
bool UseWeights,
typename MatType>
486 double Train(MatType& data,
489 arma::Row<size_t>& labels,
490 const size_t numClasses,
491 arma::rowvec& weights,
492 const size_t minimumLeafSize,
493 const double minimumGainSplit,
494 DimensionSelectionType& dimensionSelector);
500 template<
typename FitnessFunction =
GiniGain,
504 typename ElemType =
double>
507 CategoricalSplitType,
516 #include "decision_tree_impl.hpp"
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming that this node is not a leaf, calculate the index of the child node this point would go towards.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
~DecisionTree()
Clean up memory.