12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 22 template<
typename FitnessFunction = GiniGain,
23 typename DimensionSelectionType = MultipleRandomDimensionSelect,
24 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
25 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
26 typename ElemType =
double>
31 typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
56 template<
typename MatType>
58 const arma::Row<size_t>& labels,
59 const size_t numClasses,
60 const size_t numTrees = 20,
61 const size_t minimumLeafSize = 1,
62 const double minimumGainSplit = 1e-7,
63 const size_t maximumDepth = 0,
64 DimensionSelectionType dimensionSelector =
65 DimensionSelectionType());
85 template<
typename MatType>
88 const arma::Row<size_t>& labels,
89 const size_t numClasses,
90 const size_t numTrees = 20,
91 const size_t minimumLeafSize = 1,
92 const double minimumGainSplit = 1e-7,
93 const size_t maximumDepth = 0,
94 DimensionSelectionType dimensionSelector =
95 DimensionSelectionType());
112 template<
typename MatType>
114 const arma::Row<size_t>& labels,
115 const size_t numClasses,
116 const arma::rowvec& weights,
117 const size_t numTrees = 20,
118 const size_t minimumLeafSize = 1,
119 const double minimumGainSplit = 1e-7,
120 const size_t maximumDepth = 0,
121 DimensionSelectionType dimensionSelector =
122 DimensionSelectionType());
143 template<
typename MatType>
146 const arma::Row<size_t>& labels,
147 const size_t numClasses,
148 const arma::rowvec& weights,
149 const size_t numTrees = 20,
150 const size_t minimumLeafSize = 1,
151 const double minimumGainSplit = 1e-7,
152 const size_t maximumDepth = 0,
153 DimensionSelectionType dimensionSelector =
154 DimensionSelectionType());
173 template<
typename MatType>
174 double Train(
const MatType& data,
175 const arma::Row<size_t>& labels,
176 const size_t numClasses,
177 const size_t numTrees = 20,
178 const size_t minimumLeafSize = 1,
179 const double minimumGainSplit = 1e-7,
180 const size_t maximumDepth = 0,
181 DimensionSelectionType dimensionSelector =
182 DimensionSelectionType());
204 template<
typename MatType>
205 double Train(
const MatType& data,
207 const arma::Row<size_t>& labels,
208 const size_t numClasses,
209 const size_t numTrees = 20,
210 const size_t minimumLeafSize = 1,
211 const double minimumGainSplit = 1e-7,
212 const size_t maximumDepth = 0,
213 DimensionSelectionType dimensionSelector =
214 DimensionSelectionType());
234 template<
typename MatType>
235 double Train(
const MatType& data,
236 const arma::Row<size_t>& labels,
237 const size_t numClasses,
238 const arma::rowvec& weights,
239 const size_t numTrees = 20,
240 const size_t minimumLeafSize = 1,
241 const double minimumGainSplit = 1e-7,
242 const size_t maximumDepth = 0,
243 DimensionSelectionType dimensionSelector =
244 DimensionSelectionType());
266 template<
typename MatType>
267 double Train(
const MatType& data,
269 const arma::Row<size_t>& labels,
270 const size_t numClasses,
271 const arma::rowvec& weights,
272 const size_t numTrees = 20,
273 const size_t minimumLeafSize = 1,
274 const double minimumGainSplit = 1e-7,
275 const size_t maximumDepth = 0,
276 DimensionSelectionType dimensionSelector =
277 DimensionSelectionType());
285 template<
typename VecType>
286 size_t Classify(
const VecType& point)
const;
297 template<
typename VecType>
300 arma::vec& probabilities)
const;
309 template<
typename MatType>
311 arma::Row<size_t>& predictions)
const;
322 template<
typename MatType>
324 arma::Row<size_t>& predictions,
325 arma::mat& probabilities)
const;
328 const DecisionTreeType&
Tree(
const size_t i)
const {
return trees[i]; }
330 DecisionTreeType&
Tree(
const size_t i) {
return trees[i]; }
338 template<
typename Archive>
339 void serialize(Archive& ar,
const unsigned int );
362 template<
bool UseWeights,
bool UseDatasetInfo,
typename MatType>
363 double Train(
const MatType& data,
365 const arma::Row<size_t>& labels,
366 const size_t numClasses,
367 const arma::rowvec& weights,
368 const size_t numTrees,
369 const size_t minimumLeafSize,
370 const double minimumGainSplit,
371 const size_t maximumDepth,
372 DimensionSelectionType& dimensionSelector);
375 std::vector<DecisionTreeType> trees;
382 #include "random_forest_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
This class implements a generic decision tree learner.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...