12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 22 template<
typename FitnessFunction = GiniGain,
23 typename DimensionSelectionType = MultipleRandomDimensionSelect,
24 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
25 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
26 typename ElemType =
double>
31 typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
55 template<
typename MatType>
57 const arma::Row<size_t>& labels,
58 const size_t numClasses,
59 const size_t numTrees = 20,
60 const size_t minimumLeafSize = 1,
61 const double minimumGainSplit = 1e-7,
62 DimensionSelectionType dimensionSelector =
63 DimensionSelectionType());
82 template<
typename MatType>
85 const arma::Row<size_t>& labels,
86 const size_t numClasses,
87 const size_t numTrees = 20,
88 const size_t minimumLeafSize = 1,
89 const double minimumGainSplit = 1e-7,
90 DimensionSelectionType dimensionSelector =
91 DimensionSelectionType());
105 template<
typename MatType>
107 const arma::Row<size_t>& labels,
108 const size_t numClasses,
109 const arma::rowvec& weights,
110 const size_t numTrees = 20,
111 const size_t minimumLeafSize = 1,
112 const double minimumGainSplit = 1e-7,
113 DimensionSelectionType dimensionSelector =
114 DimensionSelectionType());
134 template<
typename MatType>
137 const arma::Row<size_t>& labels,
138 const size_t numClasses,
139 const arma::rowvec& weights,
140 const size_t numTrees = 20,
141 const size_t minimumLeafSize = 1,
142 const double minimumGainSplit = 1e-7,
143 DimensionSelectionType dimensionSelector =
144 DimensionSelectionType());
/**
 * Train the random forest on the given labeled training data with the given
 * number of trees.
 *
 * This overload takes neither per-point weights nor dataset metadata; sibling
 * overloads in this class accept those.
 *
 * @tparam MatType Type of the training data matrix.
 * @param data Training data.
 * @param labels Label for each training point (arma::Row<size_t>).
 * @param numClasses Number of classes in the dataset.
 * @param numTrees Number of trees in the forest (defaults to 20).
 * @param minimumLeafSize Defaults to 1; presumably the minimum number of
 *     points allowed in a leaf -- confirm in random_forest_impl.hpp.
 * @param minimumGainSplit Defaults to 1e-7; presumably the minimum gain
 *     required to split a node -- confirm in random_forest_impl.hpp.
 * @param dimensionSelector Dimension selection policy instance, taken by
 *     value; default-constructed when omitted.
 * @return A double; its meaning is not visible in this declaration.
 *     NOTE(review): document it from the implementation.
 */
162 template<
typename MatType>
163 double Train(
const MatType& data,
164 const arma::Row<size_t>& labels,
165 const size_t numClasses,
166 const size_t numTrees = 20,
167 const size_t minimumLeafSize = 1,
168 const double minimumGainSplit = 1e-7,
169 DimensionSelectionType dimensionSelector =
170 DimensionSelectionType());
191 template<
typename MatType>
192 double Train(
const MatType& data,
194 const arma::Row<size_t>& labels,
195 const size_t numClasses,
196 const size_t numTrees = 20,
197 const size_t minimumLeafSize = 1,
198 const double minimumGainSplit = 1e-7,
199 DimensionSelectionType dimensionSelector =
200 DimensionSelectionType());
/**
 * Train the random forest on labeled training data, additionally taking a
 * `weights` row vector.
 *
 * The parameter list matches the unweighted Train() overload except for the
 * extra `weights` argument placed after `numClasses`.
 *
 * @tparam MatType Type of the training data matrix.
 * @param data Training data.
 * @param labels Label for each training point (arma::Row<size_t>).
 * @param numClasses Number of classes in the dataset.
 * @param weights Row vector of weights; presumably one weight per training
 *     point -- confirm in random_forest_impl.hpp.
 * @param numTrees Number of trees in the forest (defaults to 20).
 * @param minimumLeafSize Defaults to 1; presumably the minimum number of
 *     points allowed in a leaf -- confirm in random_forest_impl.hpp.
 * @param minimumGainSplit Defaults to 1e-7; presumably the minimum gain
 *     required to split a node -- confirm in random_forest_impl.hpp.
 * @param dimensionSelector Dimension selection policy instance, taken by
 *     value; default-constructed when omitted.
 * @return A double; its meaning is not visible in this declaration.
 *     NOTE(review): document it from the implementation.
 */
219 template<
typename MatType>
220 double Train(
const MatType& data,
221 const arma::Row<size_t>& labels,
222 const size_t numClasses,
223 const arma::rowvec& weights,
224 const size_t numTrees = 20,
225 const size_t minimumLeafSize = 1,
226 const double minimumGainSplit = 1e-7,
227 DimensionSelectionType dimensionSelector =
228 DimensionSelectionType());
249 template<
typename MatType>
250 double Train(
const MatType& data,
252 const arma::Row<size_t>& labels,
253 const size_t numClasses,
254 const arma::rowvec& weights,
255 const size_t numTrees = 20,
256 const size_t minimumLeafSize = 1,
257 const double minimumGainSplit = 1e-7,
258 DimensionSelectionType dimensionSelector =
259 DimensionSelectionType());
/**
 * Predict the class of the given point.
 *
 * Declared const, so classification does not modify the forest.
 *
 * @tparam VecType Type of the point being classified.
 * @param point Point to classify.
 * @return The predicted class, as a size_t.
 */
267 template<
typename VecType>
268 size_t Classify(
const VecType& point)
const;
279 template<
typename VecType>
282 arma::vec& probabilities)
const;
291 template<
typename MatType>
293 arma::Row<size_t>& predictions)
const;
304 template<
typename MatType>
306 arma::Row<size_t>& predictions,
307 arma::mat& probabilities)
const;
// Access a tree in the forest (read-only).  Returns a const reference to
// trees[i]; the index is passed straight to operator[] with no range check
// in this accessor.
310 const DecisionTreeType&
Tree(
const size_t i)
const {
return trees[i]; }
// Modify a tree in the forest (be careful!).  Returns a mutable reference to
// trees[i]; the index is passed straight to operator[] with no range check
// in this accessor.
312 DecisionTreeType&
Tree(
const size_t i) {
return trees[i]; }
/**
 * Serialize the random forest.
 *
 * @tparam Archive Type of the archive object.
 * @param ar Archive, passed by non-const reference.
 * @param (unnamed) Unsigned int left unnamed in this declaration; presumably
 *     the serialization version number -- confirm against the archive
 *     library in use.
 */
320 template<
typename Archive>
321 void serialize(Archive& ar,
const unsigned int );
343 template<
bool UseWeights,
bool UseDatasetInfo,
typename MatType>
344 double Train(
const MatType& data,
346 const arma::Row<size_t>& labels,
347 const size_t numClasses,
348 const arma::rowvec& weights,
349 const size_t numTrees,
350 const size_t minimumLeafSize,
351 const double minimumGainSplit,
352 DimensionSelectionType& dimensionSelector);
// The decision trees held by the forest; Tree(i) returns trees[i].
355 std::vector<DecisionTreeType> trees;
362 #include "random_forest_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.