random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
39 
55  template<typename MatType>
56  RandomForest(const MatType& dataset,
57  const arma::Row<size_t>& labels,
58  const size_t numClasses,
59  const size_t numTrees = 20,
60  const size_t minimumLeafSize = 1,
61  const double minimumGainSplit = 1e-7,
62  DimensionSelectionType dimensionSelector =
63  DimensionSelectionType());
64 
82  template<typename MatType>
83  RandomForest(const MatType& dataset,
84  const data::DatasetInfo& datasetInfo,
85  const arma::Row<size_t>& labels,
86  const size_t numClasses,
87  const size_t numTrees = 20,
88  const size_t minimumLeafSize = 1,
89  const double minimumGainSplit = 1e-7,
90  DimensionSelectionType dimensionSelector =
91  DimensionSelectionType());
92 
105  template<typename MatType>
106  RandomForest(const MatType& dataset,
107  const arma::Row<size_t>& labels,
108  const size_t numClasses,
109  const arma::rowvec& weights,
110  const size_t numTrees = 20,
111  const size_t minimumLeafSize = 1,
112  const double minimumGainSplit = 1e-7,
113  DimensionSelectionType dimensionSelector =
114  DimensionSelectionType());
115 
134  template<typename MatType>
135  RandomForest(const MatType& dataset,
136  const data::DatasetInfo& datasetInfo,
137  const arma::Row<size_t>& labels,
138  const size_t numClasses,
139  const arma::rowvec& weights,
140  const size_t numTrees = 20,
141  const size_t minimumLeafSize = 1,
142  const double minimumGainSplit = 1e-7,
143  DimensionSelectionType dimensionSelector =
144  DimensionSelectionType());
145 
162  template<typename MatType>
163  double Train(const MatType& data,
164  const arma::Row<size_t>& labels,
165  const size_t numClasses,
166  const size_t numTrees = 20,
167  const size_t minimumLeafSize = 1,
168  const double minimumGainSplit = 1e-7,
169  DimensionSelectionType dimensionSelector =
170  DimensionSelectionType());
171 
191  template<typename MatType>
192  double Train(const MatType& data,
193  const data::DatasetInfo& datasetInfo,
194  const arma::Row<size_t>& labels,
195  const size_t numClasses,
196  const size_t numTrees = 20,
197  const size_t minimumLeafSize = 1,
198  const double minimumGainSplit = 1e-7,
199  DimensionSelectionType dimensionSelector =
200  DimensionSelectionType());
201 
219  template<typename MatType>
220  double Train(const MatType& data,
221  const arma::Row<size_t>& labels,
222  const size_t numClasses,
223  const arma::rowvec& weights,
224  const size_t numTrees = 20,
225  const size_t minimumLeafSize = 1,
226  const double minimumGainSplit = 1e-7,
227  DimensionSelectionType dimensionSelector =
228  DimensionSelectionType());
229 
249  template<typename MatType>
250  double Train(const MatType& data,
251  const data::DatasetInfo& datasetInfo,
252  const arma::Row<size_t>& labels,
253  const size_t numClasses,
254  const arma::rowvec& weights,
255  const size_t numTrees = 20,
256  const size_t minimumLeafSize = 1,
257  const double minimumGainSplit = 1e-7,
258  DimensionSelectionType dimensionSelector =
259  DimensionSelectionType());
260 
267  template<typename VecType>
268  size_t Classify(const VecType& point) const;
269 
279  template<typename VecType>
280  void Classify(const VecType& point,
281  size_t& prediction,
282  arma::vec& probabilities) const;
283 
291  template<typename MatType>
292  void Classify(const MatType& data,
293  arma::Row<size_t>& predictions) const;
294 
304  template<typename MatType>
305  void Classify(const MatType& data,
306  arma::Row<size_t>& predictions,
307  arma::mat& probabilities) const;
308 
310  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
312  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
313 
315  size_t NumTrees() const { return trees.size(); }
316 
320  template<typename Archive>
321  void serialize(Archive& ar, const unsigned int /* version */);
322 
323  private:
343  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
344  double Train(const MatType& data,
345  const data::DatasetInfo& datasetInfo,
346  const arma::Row<size_t>& labels,
347  const size_t numClasses,
348  const arma::rowvec& weights,
349  const size_t numTrees,
350  const size_t minimumLeafSize,
351  const double minimumGainSplit,
352  DimensionSelectionType& dimensionSelector);
353 
355  std::vector<DecisionTreeType> trees;
356 };
357 
358 } // namespace tree
359 } // namespace mlpack
360 
361 // Include implementation.
362 #include "random_forest_impl.hpp"
363 
364 #endif
data::DatasetInfo
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
.hpp
Definition: add_to_po.hpp:21
DecisionTree
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.