| Return type | Member function | Description |
|---|---|---|
| | `KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys)` | Constructor for regression algorithms and binary classification algorithms. |
| | `KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const size_t numClasses)` | Constructor for multiclass classification algorithms. |
| | `KFoldCV(const size_t k, const MatType& xs, const data::DatasetInfo& datasetInfo, const PredictionsType& ys, const size_t numClasses)` | Constructor for multiclass classification algorithms that accept a data::DatasetInfo parameter. |
| | `KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const WeightsType& weights)` | Constructor for regression and binary classification algorithms that support weighted learning. |
| | `KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const size_t numClasses, const WeightsType& weights)` | Constructor for multiclass classification algorithms that support weighted learning. |
| | `KFoldCV(const size_t k, const MatType& xs, const data::DatasetInfo& datasetInfo, const PredictionsType& ys, const size_t numClasses, const WeightsType& weights)` | Constructor for multiclass classification algorithms that accept a data::DatasetInfo parameter and support weighted learning. |
| `template<typename... MLAlgorithmArgs> double` | `Evaluate(const MLAlgorithmArgs&... args)` | Run k-fold cross-validation and return the metric averaged over the k folds. |
| `MLAlgorithm&` | `Model()` | Access and modify the model from the last run of k-fold cross-validation. |
template<typename MLAlgorithm, typename Metric, typename MatType = arma::mat, typename PredictionsType = typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType, typename WeightsType = typename MetaInfoExtractor<MLAlgorithm, MatType, PredictionsType>::WeightsType>
class mlpack::cv::KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType >
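The class KFoldCV implements k-fold cross-validation for regression and classification algorithms. The data are split into k folds, and each fold is held out once for evaluation while the model is trained on the remaining k - 1 folds.
To construct a KFoldCV object, pass the number of folds k followed by the data: the points xs, the responses ys, and, depending on the algorithm, the number of classes, a data::DatasetInfo object, and instance weights. For example, 10-fold cross-validation of SoftmaxRegression with the Accuracy metric can be run as follows: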
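```cpp
// 100 random 5-dimensional points, one point per column.
arma::mat data = arma::randu<arma::mat>(5, 100);
// Random labels in the interval [0, 4].
arma::Row<size_t> labels =
    arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4));
size_t numClasses = 5;
KFoldCV<SoftmaxRegression, Accuracy> cv(10, data, labels, numClasses);
// Extra arguments to Evaluate() are passed on to SoftmaxRegression as
// hyperparameters; here lambda is the regularization parameter.
double lambda = 0.1;
double softmaxAccuracy = cv.Evaluate(lambda);
```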
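To make the semantics of Evaluate() and Model() concrete, here is a minimal, self-contained sketch of k-fold cross-validation that does not use mlpack. `ToyKFoldCV` and `ShrunkMean` are hypothetical names invented for this illustration. To stay short and dependency-free, the sketch uses a toy regression model and mean squared error instead of SoftmaxRegression and Accuracy. It assigns contiguous folds without shuffling, and it forwards extra Evaluate() arguments to the model constructor in the same way `lambda` is forwarded in the example above. mlpack's own fold assignment and internals may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy regressor (not part of mlpack): it ignores the features and
// predicts sum(y) / (n + lambda), i.e. the training mean shrunk toward zero.
class ShrunkMean
{
 public:
  ShrunkMean(const std::vector<double>& ys, const double lambda)
  {
    double sum = 0.0;
    for (const double y : ys)
      sum += y;
    prediction = sum / (static_cast<double>(ys.size()) + lambda);
  }

  double Predict() const { return prediction; }

 private:
  double prediction;
};

// Hypothetical k-fold evaluator. Fold i is a contiguous block of points, and
// the first n % k folds receive one extra point.
template<typename MLAlgorithm>
class ToyKFoldCV
{
 public:
  ToyKFoldCV(const std::size_t folds, std::vector<double> labels) :
      k(folds), ys(std::move(labels))
  {
    assert(k > 0 && k <= ys.size());  // every fold must be non-empty
  }

  // Extra arguments (e.g. lambda) are forwarded to the MLAlgorithm
  // constructor on every fold, so they act as fixed hyperparameters.
  template<typename... Args>
  double Evaluate(const Args&... args)
  {
    const std::size_t n = ys.size();
    double total = 0.0;
    std::size_t start = 0;
    for (std::size_t i = 0; i < k; ++i)
    {
      const std::size_t size = n / k + (i < n % k ? 1 : 0);
      const std::size_t end = start + size;

      // Train a fresh model on the k - 1 folds outside [start, end).
      std::vector<double> train;
      for (std::size_t j = 0; j < n; ++j)
        if (j < start || j >= end)
          train.push_back(ys[j]);
      model.reset(new MLAlgorithm(train, args...));

      // Score the held-out fold with mean squared error.
      double sse = 0.0;
      for (std::size_t j = start; j < end; ++j)
      {
        const double err = model->Predict() - ys[j];
        sse += err * err;
      }
      total += sse / static_cast<double>(size);
      start = end;
    }
    return total / static_cast<double>(k);  // average over the k folds
  }

  // Model trained in the last fold of the most recent Evaluate() call;
  // Evaluate() must have been called at least once.
  MLAlgorithm& Model() { return *model; }

 private:
  std::size_t k;
  std::vector<double> ys;
  std::unique_ptr<MLAlgorithm> model;
};
```

For instance, with labels 1 through 10 and k = 5, `cv.Evaluate(0.0)` trains five models on eight points each and returns the mean of the five per-fold errors. Afterwards, `cv.Model()` holds the model trained without the last fold.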
Template Parameters

| Parameter | Description |
|---|---|
| MLAlgorithm | A machine learning algorithm. |
| Metric | A metric to assess the quality of a trained model. |
| MatType | The type of data (arma::mat by default). |
| PredictionsType | The type of predictions. Pass it explicitly when the predictions type is a template parameter of MLAlgorithm's Train() methods; otherwise the default obtained through MetaInfoExtractor is used. |
| WeightsType | The type of weights. Pass it explicitly when weighted learning is supported and the weights type is a template parameter of MLAlgorithm's Train() methods; otherwise the default obtained through MetaInfoExtractor is used. |
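The PredictionsType and WeightsType rows depend on whether a type can be read off a Train() signature. The sketch below illustrates that idea with a small type trait. It is a simplified illustration, not how mlpack's MetaInfoExtractor is implemented, and `PlainModel`, `TemplatedModel`, and `SecondArgOf` are hypothetical names.

```cpp
#include <type_traits>
#include <vector>

// Hypothetical model whose Train() is an ordinary member function: the labels
// type appears directly in its signature.
struct PlainModel
{
  void Train(const std::vector<double>& /* xs */,
             const std::vector<int>& /* ys */) {}
};

// Hypothetical model whose Train() is a member function template: there is
// no single signature to inspect, so decltype(&TemplatedModel::Train) does
// not compile and the labels type must be named explicitly.
struct TemplatedModel
{
  template<typename LabelsType>
  void Train(const std::vector<double>& /* xs */,
             const LabelsType& /* ys */) {}
};

// Primary template left undefined: only member-function pointers match.
template<typename T>
struct SecondArgOf;

// Yields the second parameter type of a member function, with const and
// reference qualifiers removed.
template<typename C, typename R, typename A1, typename A2, typename... Rest>
struct SecondArgOf<R (C::*)(A1, A2, Rest...)>
{
  using type = typename std::decay<A2>::type;
};

// Deduced automatically for PlainModel: std::vector<int>.
using PlainPredictions = SecondArgOf<decltype(&PlainModel::Train)>::type;
static_assert(std::is_same<PlainPredictions, std::vector<int>>::value,
              "labels type read off PlainModel::Train()");
```

`&TemplatedModel::Train` names a member function template rather than a single function, so a trait like this has nothing concrete to inspect. In that case the type has to be supplied as a template argument.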
Definition at line 61 of file k_fold_cv.hpp.