12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP 22 template<
typename FitnessFunction = GiniGain,
23 typename DimensionSelectionType = MultipleRandomDimensionSelect<>,
24 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
25 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
26 typename ElemType =
double>
31 typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
51 template<
typename MatType>
53 const arma::Row<size_t>& labels,
54 const size_t numClasses,
55 const size_t numTrees = 50,
56 const size_t minimumLeafSize = 20);
71 template<
typename MatType>
74 const arma::Row<size_t>& labels,
75 const size_t numClasses,
76 const size_t numTrees = 50,
77 const size_t minimumLeafSize = 20);
91 template<
typename MatType>
93 const arma::Row<size_t>& labels,
94 const size_t numClasses,
95 const arma::rowvec& weights,
96 const size_t numTrees = 50,
97 const size_t minimumLeafSize = 20);
113 template<
typename MatType>
116 const arma::Row<size_t>& labels,
117 const size_t numClasses,
118 const arma::rowvec& weights,
119 const size_t numTrees = 50,
120 const size_t minimumLeafSize = 20);
134 template<
typename MatType>
135 double Train(
const MatType& data,
136 const arma::Row<size_t>& labels,
137 const size_t numClasses,
138 const size_t numTrees = 50,
139 const size_t minimumLeafSize = 20);
155 template<
typename MatType>
156 double Train(
const MatType& data,
158 const arma::Row<size_t>& labels,
159 const size_t numClasses,
160 const size_t numTrees = 50,
161 const size_t minimumLeafSize = 20);
176 template<
typename MatType>
177 double Train(
const MatType& data,
178 const arma::Row<size_t>& labels,
179 const size_t numClasses,
180 const arma::rowvec& weights,
181 const size_t numTrees = 50,
182 const size_t minimumLeafSize = 20);
199 template<
typename MatType>
200 double Train(
const MatType& data,
202 const arma::Row<size_t>& labels,
203 const size_t numClasses,
204 const arma::rowvec& weights,
205 const size_t numTrees = 50,
206 const size_t minimumLeafSize = 20);
214 template<
typename VecType>
215 size_t Classify(
const VecType& point)
const;
226 template<
typename VecType>
229 arma::vec& probabilities)
const;
238 template<
typename MatType>
240 arma::Row<size_t>& predictions)
const;
251 template<
typename MatType>
253 arma::Row<size_t>& predictions,
254 arma::mat& probabilities)
const;
257 const DecisionTreeType&
Tree(
const size_t i)
const {
return trees[i]; }
259 DecisionTreeType&
Tree(
const size_t i) {
return trees[i]; }
/**
 * Serialize the random forest.
 *
 * @tparam Archive Archive type used for serialization.
 * @param ar Archive to serialize to or from.
 * The second, unnamed parameter is presumably the serialization version
 * (TODO confirm against random_forest_impl.hpp).
 */
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
288 template<
bool UseWeights,
bool UseDatasetInfo,
typename MatType>
289 double Train(
const MatType& data,
291 const arma::Row<size_t>& labels,
292 const size_t numClasses,
293 const arma::rowvec& weights,
294 const size_t numTrees,
295 const size_t minimumLeafSize);
298 std::vector<DecisionTreeType> trees;
305 #include "random_forest_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=50, const size_t minimumLeafSize=20)
Train the random forest on the given labeled training data with the given number of trees...
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.