random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
39 
56  template<typename MatType>
57  RandomForest(const MatType& dataset,
58  const arma::Row<size_t>& labels,
59  const size_t numClasses,
60  const size_t numTrees = 20,
61  const size_t minimumLeafSize = 1,
62  const double minimumGainSplit = 1e-7,
63  const size_t maximumDepth = 0,
64  DimensionSelectionType dimensionSelector =
65  DimensionSelectionType());
66 
85  template<typename MatType>
86  RandomForest(const MatType& dataset,
87  const data::DatasetInfo& datasetInfo,
88  const arma::Row<size_t>& labels,
89  const size_t numClasses,
90  const size_t numTrees = 20,
91  const size_t minimumLeafSize = 1,
92  const double minimumGainSplit = 1e-7,
93  const size_t maximumDepth = 0,
94  DimensionSelectionType dimensionSelector =
95  DimensionSelectionType());
96 
112  template<typename MatType>
113  RandomForest(const MatType& dataset,
114  const arma::Row<size_t>& labels,
115  const size_t numClasses,
116  const arma::rowvec& weights,
117  const size_t numTrees = 20,
118  const size_t minimumLeafSize = 1,
119  const double minimumGainSplit = 1e-7,
120  const size_t maximumDepth = 0,
121  DimensionSelectionType dimensionSelector =
122  DimensionSelectionType());
123 
143  template<typename MatType>
144  RandomForest(const MatType& dataset,
145  const data::DatasetInfo& datasetInfo,
146  const arma::Row<size_t>& labels,
147  const size_t numClasses,
148  const arma::rowvec& weights,
149  const size_t numTrees = 20,
150  const size_t minimumLeafSize = 1,
151  const double minimumGainSplit = 1e-7,
152  const size_t maximumDepth = 0,
153  DimensionSelectionType dimensionSelector =
154  DimensionSelectionType());
155 
173  template<typename MatType>
174  double Train(const MatType& data,
175  const arma::Row<size_t>& labels,
176  const size_t numClasses,
177  const size_t numTrees = 20,
178  const size_t minimumLeafSize = 1,
179  const double minimumGainSplit = 1e-7,
180  const size_t maximumDepth = 0,
181  DimensionSelectionType dimensionSelector =
182  DimensionSelectionType());
183 
204  template<typename MatType>
205  double Train(const MatType& data,
206  const data::DatasetInfo& datasetInfo,
207  const arma::Row<size_t>& labels,
208  const size_t numClasses,
209  const size_t numTrees = 20,
210  const size_t minimumLeafSize = 1,
211  const double minimumGainSplit = 1e-7,
212  const size_t maximumDepth = 0,
213  DimensionSelectionType dimensionSelector =
214  DimensionSelectionType());
215 
234  template<typename MatType>
235  double Train(const MatType& data,
236  const arma::Row<size_t>& labels,
237  const size_t numClasses,
238  const arma::rowvec& weights,
239  const size_t numTrees = 20,
240  const size_t minimumLeafSize = 1,
241  const double minimumGainSplit = 1e-7,
242  const size_t maximumDepth = 0,
243  DimensionSelectionType dimensionSelector =
244  DimensionSelectionType());
245 
266  template<typename MatType>
267  double Train(const MatType& data,
268  const data::DatasetInfo& datasetInfo,
269  const arma::Row<size_t>& labels,
270  const size_t numClasses,
271  const arma::rowvec& weights,
272  const size_t numTrees = 20,
273  const size_t minimumLeafSize = 1,
274  const double minimumGainSplit = 1e-7,
275  const size_t maximumDepth = 0,
276  DimensionSelectionType dimensionSelector =
277  DimensionSelectionType());
278 
285  template<typename VecType>
286  size_t Classify(const VecType& point) const;
287 
297  template<typename VecType>
298  void Classify(const VecType& point,
299  size_t& prediction,
300  arma::vec& probabilities) const;
301 
309  template<typename MatType>
310  void Classify(const MatType& data,
311  arma::Row<size_t>& predictions) const;
312 
322  template<typename MatType>
323  void Classify(const MatType& data,
324  arma::Row<size_t>& predictions,
325  arma::mat& probabilities) const;
326 
328  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
330  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
331 
333  size_t NumTrees() const { return trees.size(); }
334 
338  template<typename Archive>
339  void serialize(Archive& ar, const unsigned int /* version */);
340 
341  private:
362  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
363  double Train(const MatType& data,
364  const data::DatasetInfo& datasetInfo,
365  const arma::Row<size_t>& labels,
366  const size_t numClasses,
367  const arma::rowvec& weights,
368  const size_t numTrees,
369  const size_t minimumLeafSize,
370  const double minimumGainSplit,
371  const size_t maximumDepth,
372  DimensionSelectionType& dimensionSelector);
373 
375  std::vector<DecisionTreeType> trees;
376 };
377 
378 } // namespace tree
379 } // namespace mlpack
380 
381 // Include implementation.
382 #include "random_forest_impl.hpp"
383 
384 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...