random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect<>,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
39 
51  template<typename MatType>
52  RandomForest(const MatType& dataset,
53  const arma::Row<size_t>& labels,
54  const size_t numClasses,
55  const size_t numTrees = 50,
56  const size_t minimumLeafSize = 20);
57 
71  template<typename MatType>
72  RandomForest(const MatType& dataset,
73  const data::DatasetInfo& datasetInfo,
74  const arma::Row<size_t>& labels,
75  const size_t numClasses,
76  const size_t numTrees = 50,
77  const size_t minimumLeafSize = 20);
78 
91  template<typename MatType>
92  RandomForest(const MatType& dataset,
93  const arma::Row<size_t>& labels,
94  const size_t numClasses,
95  const arma::rowvec& weights,
96  const size_t numTrees = 50,
97  const size_t minimumLeafSize = 20);
98 
113  template<typename MatType>
114  RandomForest(const MatType& dataset,
115  const data::DatasetInfo& datasetInfo,
116  const arma::Row<size_t>& labels,
117  const size_t numClasses,
118  const arma::rowvec& weights,
119  const size_t numTrees = 50,
120  const size_t minimumLeafSize = 20);
121 
134  template<typename MatType>
135  double Train(const MatType& data,
136  const arma::Row<size_t>& labels,
137  const size_t numClasses,
138  const size_t numTrees = 50,
139  const size_t minimumLeafSize = 20);
140 
155  template<typename MatType>
156  double Train(const MatType& data,
157  const data::DatasetInfo& datasetInfo,
158  const arma::Row<size_t>& labels,
159  const size_t numClasses,
160  const size_t numTrees = 50,
161  const size_t minimumLeafSize = 20);
162 
176  template<typename MatType>
177  double Train(const MatType& data,
178  const arma::Row<size_t>& labels,
179  const size_t numClasses,
180  const arma::rowvec& weights,
181  const size_t numTrees = 50,
182  const size_t minimumLeafSize = 20);
183 
199  template<typename MatType>
200  double Train(const MatType& data,
201  const data::DatasetInfo& datasetInfo,
202  const arma::Row<size_t>& labels,
203  const size_t numClasses,
204  const arma::rowvec& weights,
205  const size_t numTrees = 50,
206  const size_t minimumLeafSize = 20);
207 
214  template<typename VecType>
215  size_t Classify(const VecType& point) const;
216 
226  template<typename VecType>
227  void Classify(const VecType& point,
228  size_t& prediction,
229  arma::vec& probabilities) const;
230 
238  template<typename MatType>
239  void Classify(const MatType& data,
240  arma::Row<size_t>& predictions) const;
241 
251  template<typename MatType>
252  void Classify(const MatType& data,
253  arma::Row<size_t>& predictions,
254  arma::mat& probabilities) const;
255 
257  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
259  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
260 
262  size_t NumTrees() const { return trees.size(); }
263 
267  template<typename Archive>
268  void serialize(Archive& ar, const unsigned int /* version */);
269 
270  private:
288  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
289  double Train(const MatType& data,
290  const data::DatasetInfo& datasetInfo,
291  const arma::Row<size_t>& labels,
292  const size_t numClasses,
293  const arma::rowvec& weights,
294  const size_t numTrees,
295  const size_t minimumLeafSize);
296 
298  std::vector<DecisionTreeType> trees;
299 };
300 
301 } // namespace tree
302 } // namespace mlpack
303 
304 // Include implementation.
305 #include "random_forest_impl.hpp"
306 
307 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=50, const size_t minimumLeafSize=20)
Train the random forest on the given labeled training data with the given number of trees.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.