13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 22 #include <type_traits> 34 template<
typename FitnessFunction = GiniGain,
35 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
36 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
37 typename DimensionSelectionType = AllDimensionSelect,
38 typename ElemType = double,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::template
42 AuxiliarySplitInfo<ElemType>,
43 public CategoricalSplitType<FitnessFunction>::template
44 AuxiliarySplitInfo<ElemType>
71 template<
typename MatType,
typename LabelsType>
75 const size_t numClasses,
76 const size_t minimumLeafSize = 10,
77 const double minimumGainSplit = 1e-7,
78 const size_t maximumDepth = 0,
79 DimensionSelectionType dimensionSelector =
80 DimensionSelectionType());
98 template<
typename MatType,
typename LabelsType>
101 const size_t numClasses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
127 template<
typename MatType,
typename LabelsType,
typename WeightsType>
132 const size_t numClasses,
134 const size_t minimumLeafSize = 10,
135 const double minimumGainSplit = 1e-7,
136 const size_t maximumDepth = 0,
137 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
139 typename std::remove_reference<WeightsType>::type>::value>* = 0);
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
165 const size_t numClasses,
167 const size_t minimumLeafSize = 10,
168 const double minimumGainSplit = 1e-7,
170 typename std::remove_reference<WeightsType>::type>::value>* = 0);
189 template<
typename MatType,
typename LabelsType,
typename WeightsType>
193 const size_t numClasses,
195 const size_t minimumLeafSize = 10,
196 const double minimumGainSplit = 1e-7,
197 const size_t maximumDepth = 0,
198 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
200 typename std::remove_reference<WeightsType>::type>::value>* = 0);
220 template<
typename MatType,
typename LabelsType,
typename WeightsType>
225 const size_t numClasses,
227 const size_t minimumLeafSize = 10,
228 const double minimumGainSplit = 1e-7,
229 const size_t maximumDepth = 0,
230 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
232 typename std::remove_reference<WeightsType>::type>::value>* = 0);
296 template<
typename MatType,
typename LabelsType>
297 double Train(MatType data,
300 const size_t numClasses,
301 const size_t minimumLeafSize = 10,
302 const double minimumGainSplit = 1e-7,
303 const size_t maximumDepth = 0,
304 DimensionSelectionType dimensionSelector =
305 DimensionSelectionType());
324 template<
typename MatType,
typename LabelsType>
325 double Train(MatType data,
327 const size_t numClasses,
328 const size_t minimumLeafSize = 10,
329 const double minimumGainSplit = 1e-7,
330 const size_t maximumDepth = 0,
331 DimensionSelectionType dimensionSelector =
332 DimensionSelectionType());
355 template<
typename MatType,
typename LabelsType,
typename WeightsType>
356 double Train(MatType data,
359 const size_t numClasses,
361 const size_t minimumLeafSize = 10,
362 const double minimumGainSplit = 1e-7,
363 const size_t maximumDepth = 0,
364 DimensionSelectionType dimensionSelector =
365 DimensionSelectionType(),
367 std::remove_reference<WeightsType>::type>::value>* = 0);
388 template<
typename MatType,
typename LabelsType,
typename WeightsType>
389 double Train(MatType data,
391 const size_t numClasses,
393 const size_t minimumLeafSize = 10,
394 const double minimumGainSplit = 1e-7,
395 const size_t maximumDepth = 0,
396 DimensionSelectionType dimensionSelector =
397 DimensionSelectionType(),
399 std::remove_reference<WeightsType>::type>::value>* = 0);
407 template<
typename VecType>
408 size_t Classify(
const VecType& point)
const;
419 template<
typename VecType>
422 arma::vec& probabilities)
const;
431 template<
typename MatType>
433 arma::Row<size_t>& predictions)
const;
445 template<
typename MatType>
447 arma::Row<size_t>& predictions,
448 arma::mat& probabilities)
const;
453 template<
typename Archive>
454 void serialize(Archive& ar,
const unsigned int );
475 template<
typename VecType>
485 std::vector<DecisionTree*> children;
487 size_t splitDimension;
490 size_t dimensionTypeOrMajorityClass;
498 arma::vec classProbabilities;
503 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504 NumericAuxiliarySplitInfo;
505 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506 CategoricalAuxiliarySplitInfo;
511 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
512 void CalculateClassProbabilities(
const RowType& labels,
513 const size_t numClasses,
514 const WeightsRowType& weights);
533 template<
bool UseWeights,
typename MatType>
534 double Train(MatType& data,
538 arma::Row<size_t>& labels,
539 const size_t numClasses,
540 arma::rowvec& weights,
541 const size_t minimumLeafSize,
542 const double minimumGainSplit,
543 const size_t maximumDepth,
544 DimensionSelectionType& dimensionSelector);
562 template<
bool UseWeights,
typename MatType>
563 double Train(MatType& data,
566 arma::Row<size_t>& labels,
567 const size_t numClasses,
568 arma::rowvec& weights,
569 const size_t minimumLeafSize,
570 const double minimumGainSplit,
571 const size_t maximumDepth,
572 DimensionSelectionType& dimensionSelector);
578 template<
typename FitnessFunction =
GiniGain,
582 typename ElemType =
double>
585 CategoricalSplitType,
604 #include "decision_tree_impl.hpp" size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Linear algebra utility functions, generally performed on matrices or vectors.
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.