13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 21 #include <type_traits> 33 template<
typename FitnessFunction = GiniGain,
34 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
35 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
36 typename DimensionSelectionType = AllDimensionSelect,
37 typename ElemType = double,
38 bool NoRecursion =
false>
40 public NumericSplitType<FitnessFunction>::template
41 AuxiliarySplitInfo<ElemType>,
42 public CategoricalSplitType<FitnessFunction>::template
43 AuxiliarySplitInfo<ElemType>
66 template<
typename MatType,
typename LabelsType>
70 const size_t numClasses,
71 const size_t minimumLeafSize = 10,
72 const double minimumGainSplit = 1e-7);
86 template<
typename MatType,
typename LabelsType>
89 const size_t numClasses,
90 const size_t minimumLeafSize = 10,
91 const double minimumGainSplit = 1e-7);
107 template<
typename MatType,
typename LabelsType,
typename WeightsType>
111 const size_t numClasses,
112 WeightsType&& weights,
113 const size_t minimumLeafSize = 10,
114 const double minimumGainSplit = 1e-7,
116 typename std::remove_reference<WeightsType>::type>::value>*
132 template<
typename MatType,
typename LabelsType,
typename WeightsType>
135 const size_t numClasses,
136 WeightsType&& weights,
137 const size_t minimumLeafSize = 10,
138 const double minimumGainSplit = 1e-7,
140 typename std::remove_reference<WeightsType>::type>::value>*
202 template<
typename MatType,
typename LabelsType>
203 void Train(MatType&& data,
206 const size_t numClasses,
207 const size_t minimumLeafSize = 10,
208 const double minimumGainSplit = 1e-7);
223 template<
typename MatType,
typename LabelsType>
224 void Train(MatType&& data,
226 const size_t numClasses,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7);
245 template<
typename MatType,
typename LabelsType,
typename WeightsType>
246 void Train(MatType&& data,
249 const size_t numClasses,
250 WeightsType&& weights,
251 const size_t minimumLeafSize = 10,
252 const double minimumGainSplit = 1e-7,
254 std::remove_reference<WeightsType>::type>::value>* = 0);
269 template<
typename MatType,
typename LabelsType,
typename WeightsType>
270 void Train(MatType&& data,
272 const size_t numClasses,
273 WeightsType&& weights,
274 const size_t minimumLeafSize = 10,
275 const double minimumGainSplit = 1e-7,
277 std::remove_reference<WeightsType>::type>::value>* = 0);
285 template<
typename VecType>
286 size_t Classify(
const VecType& point)
const;
297 template<
typename VecType>
300 arma::vec& probabilities)
const;
309 template<
typename MatType>
311 arma::Row<size_t>& predictions)
const;
323 template<
typename MatType>
325 arma::Row<size_t>& predictions,
326 arma::mat& probabilities)
const;
331 template<
typename Archive>
332 void serialize(Archive& ar,
const unsigned int );
349 template<
typename VecType>
359 std::vector<DecisionTree*> children;
361 size_t splitDimension;
364 size_t dimensionTypeOrMajorityClass;
372 arma::vec classProbabilities;
377 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
378 NumericAuxiliarySplitInfo;
379 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
380 CategoricalAuxiliarySplitInfo;
385 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
386 void CalculateClassProbabilities(
const RowType& labels,
387 const size_t numClasses,
388 const WeightsRowType& weights);
405 template<
bool UseWeights,
typename MatType>
406 void Train(MatType& data,
410 arma::Row<size_t>& labels,
411 const size_t numClasses,
412 arma::rowvec& weights,
413 const size_t minimumLeafSize = 10,
414 const double minimumGainSplit = 1e-7);
430 template<
bool UseWeights,
typename MatType>
431 void Train(MatType& data,
434 arma::Row<size_t>& labels,
435 const size_t numClasses,
436 arma::rowvec& weights,
437 const size_t minimumLeafSize = 10,
438 const double minimumGainSplit = 1e-7);
444 template<
typename FitnessFunction =
GiniGain,
448 typename ElemType =
double>
451 CategoricalSplitType,
460 #include "decision_tree_impl.hpp" size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
void Train(MatType &&data, const data::DatasetInfo &datasetInfo, LabelsType &&labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Train the decision tree on the given data.
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree(MatType &&data, const data::DatasetInfo &datasetInfo, LabelsType &&labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
~DecisionTree()
Clean up memory.