random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect<>,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
39 
51  template<typename MatType>
52  RandomForest(const MatType& dataset,
53  const arma::Row<size_t>& labels,
54  const size_t numClasses,
55  const size_t numTrees = 50,
56  const size_t minimumLeafSize = 20);
57 
71  template<typename MatType>
72  RandomForest(const MatType& dataset,
73  const data::DatasetInfo& datasetInfo,
74  const arma::Row<size_t>& labels,
75  const size_t numClasses,
76  const size_t numTrees = 50,
77  const size_t minimumLeafSize = 20);
78 
91  template<typename MatType>
92  RandomForest(const MatType& dataset,
93  const arma::Row<size_t>& labels,
94  const size_t numClasses,
95  const arma::rowvec& weights,
96  const size_t numTrees = 50,
97  const size_t minimumLeafSize = 20);
98 
113  template<typename MatType>
114  RandomForest(const MatType& dataset,
115  const data::DatasetInfo& datasetInfo,
116  const arma::Row<size_t>& labels,
117  const size_t numClasses,
118  const arma::rowvec& weights,
119  const size_t numTrees = 50,
120  const size_t minimumLeafSize = 20);
121 
133  template<typename MatType>
134  void Train(const MatType& data,
135  const arma::Row<size_t>& labels,
136  const size_t numClasses,
137  const size_t numTrees = 50,
138  const size_t minimumLeafSize = 20);
139 
153  template<typename MatType>
154  void Train(const MatType& data,
155  const data::DatasetInfo& datasetInfo,
156  const arma::Row<size_t>& labels,
157  const size_t numClasses,
158  const size_t numTrees = 50,
159  const size_t minimumLeafSize = 20);
160 
173  template<typename MatType>
174  void Train(const MatType& data,
175  const arma::Row<size_t>& labels,
176  const size_t numClasses,
177  const arma::rowvec& weights,
178  const size_t numTrees = 50,
179  const size_t minimumLeafSize = 20);
180 
195  template<typename MatType>
196  void Train(const MatType& data,
197  const data::DatasetInfo& datasetInfo,
198  const arma::Row<size_t>& labels,
199  const size_t numClasses,
200  const arma::rowvec& weights,
201  const size_t numTrees = 50,
202  const size_t minimumLeafSize = 20);
203 
210  template<typename VecType>
211  size_t Classify(const VecType& point) const;
212 
222  template<typename VecType>
223  void Classify(const VecType& point,
224  size_t& prediction,
225  arma::vec& probabilities) const;
226 
234  template<typename MatType>
235  void Classify(const MatType& data,
236  arma::Row<size_t>& predictions) const;
237 
247  template<typename MatType>
248  void Classify(const MatType& data,
249  arma::Row<size_t>& predictions,
250  arma::mat& probabilities) const;
251 
253  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
255  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
256 
258  size_t NumTrees() const { return trees.size(); }
259 
263  template<typename Archive>
264  void serialize(Archive& ar, const unsigned int /* version */);
265 
266  private:
283  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
284  void Train(const MatType& data,
285  const data::DatasetInfo& datasetInfo,
286  const arma::Row<size_t>& labels,
287  const size_t numClasses,
288  const arma::rowvec& weights,
289  const size_t numTrees,
290  const size_t minimumLeafSize);
291 
293  std::vector<DecisionTreeType> trees;
294 };
295 
296 } // namespace tree
297 } // namespace mlpack
298 
299 // Include implementation.
300 #include "random_forest_impl.hpp"
301 
302 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=50, const size_t minimumLeafSize=20)
Train the random forest on the given labeled training data with the given number of trees...
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.