KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType > Class Template Reference

The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.

Public Member Functions

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys)
This constructor can be used for regression algorithms and for binary classification algorithms.

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses)
This constructor can be used for multiclass classification algorithms.

KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights)
This constructor can be used for regression and binary classification algorithms that support weighted learning.

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights)
This constructor can be used for multiclass classification algorithms that support weighted learning.

KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights)
This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.

template<typename... MLAlgorithmArgs>
double Evaluate (const MLAlgorithmArgs &...args)
Run k-fold cross-validation.

MLAlgorithm & Model ()
Access and modify the model trained in the last run of k-fold cross-validation.

Detailed Description


template<typename MLAlgorithm, typename Metric, typename MatType = arma::mat, typename PredictionsType = typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType, typename WeightsType = typename MetaInfoExtractor<MLAlgorithm, MatType, PredictionsType>::WeightsType>
class mlpack::cv::KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType >

The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.

To construct a KFoldCV object, pass the number of folds k and the arguments that specify the data. For example, 10-fold cross-validation for SoftmaxRegression can be run as follows:

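#include <mlpack/core.hpp>
#include <mlpack/core/cv/k_fold_cv.hpp>
#include <mlpack/core/cv/metrics/accuracy.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>

using namespace mlpack::cv;
using namespace mlpack::regression;

// 100-point 5-dimensional random dataset.
arma::mat data = arma::randu<arma::mat>(5, 100);
// Random labels in the [0, 4] interval.
arma::Row<size_t> labels =
    arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4));
size_t numClasses = 5;
KFoldCV<SoftmaxRegression<>, Accuracy> cv(10, data, labels, numClasses);
double lambda = 0.1;
double softmaxAccuracy = cv.Evaluate(lambda);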
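The procedure that KFoldCV automates can be sketched in plain C++ without mlpack. In this illustration the model (MeanRegressor, which predicts the mean training response), the metric (MeanSquaredError), and the contiguous fold layout are hypothetical stand-ins chosen for brevity, not mlpack's internals. The sketch summarizes the run by averaging the k per-fold scores, which is the usual way to report k-fold results.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical stand-in for MLAlgorithm: predicts the mean training response.
struct MeanRegressor
{
  double mean;

  explicit MeanRegressor(const std::vector<double>& ys)
  {
    assert(!ys.empty());
    mean = std::accumulate(ys.begin(), ys.end(), 0.0) / ys.size();
  }
};

// Hypothetical metric: mean squared error on held-out responses.
double MeanSquaredError(const MeanRegressor& model,
                        const std::vector<double>& ys)
{
  double sum = 0.0;
  for (double y : ys)
    sum += (model.mean - y) * (model.mean - y);
  return sum / ys.size();
}

// k-fold cross-validation with contiguous folds: each fold is held out once,
// a model is trained on the remaining points, and the k scores are averaged.
double KFoldMeanSquaredError(const std::vector<double>& ys, std::size_t k)
{
  const std::size_t n = ys.size();
  assert(k >= 2 && k <= n);

  double total = 0.0;
  for (std::size_t fold = 0; fold < k; ++fold)
  {
    // The first n % k folds receive one extra point.
    const std::size_t begin = fold * (n / k) + std::min(fold, n % k);
    const std::size_t end = begin + n / k + (fold < n % k ? 1 : 0);

    std::vector<double> train, validation;
    for (std::size_t i = 0; i < n; ++i)
      (i >= begin && i < end ? validation : train).push_back(ys[i]);

    total += MeanSquaredError(MeanRegressor(train), validation);
  }
  return total / k;
}
```

With k equal to the number of points, this reduces to leave-one-out cross-validation.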
Template Parameters
MLAlgorithm: A machine learning algorithm.
Metric: A metric to assess the quality of a trained model.
MatType: The type of data.
PredictionsType: The type of predictions (should be passed when the predictions type is a template parameter in the Train methods of MLAlgorithm).
WeightsType: The type of weights (should be passed when weighted learning is supported and the weights type is a template parameter in the Train methods of MLAlgorithm).
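The Metric template parameter is a type, not an object. A minimal sketch of that policy-class style, assuming a metric exposes a static Evaluate(model, data, labels) function; ToyAccuracy, ThresholdClassifier, and Score are hypothetical names, not mlpack classes:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical metric policy: all behavior lives in a static function, so
// generic code only needs the type, never an instance.
struct ToyAccuracy
{
  template<typename Model>
  static double Evaluate(const Model& model,
                         const std::vector<double>& xs,
                         const std::vector<int>& ys)
  {
    std::size_t correct = 0;
    for (std::size_t i = 0; i < xs.size(); ++i)
      if (model.Classify(xs[i]) == ys[i])
        ++correct;
    return static_cast<double>(correct) / xs.size();
  }
};

// Hypothetical model: labels a point 1 if it reaches the threshold, else 0.
struct ThresholdClassifier
{
  double threshold;
  int Classify(double x) const { return x >= threshold ? 1 : 0; }
};

// Scores a model with a metric policy chosen at compile time.
template<typename Metric, typename Model>
double Score(const Model& model,
             const std::vector<double>& xs,
             const std::vector<int>& ys)
{
  assert(!xs.empty() && xs.size() == ys.size());
  return Metric::Evaluate(model, xs, ys);
}
```

Swapping the metric then only means changing a template argument, e.g. Score<SomeOtherMetric>(model, xs, ys).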

Definition at line 61 of file k_fold_cv.hpp.

Constructor & Destructor Documentation

◆ KFoldCV() [1/6]

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys)

This constructor can be used for regression algorithms and for binary classification algorithms.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Predictions (labels for classification algorithms and responses for regression algorithms) for each data point.
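To see why k must be at least 2: each fold is held out once while the model trains on the other k - 1 folds, so k = 1 would leave no training data. The hypothetical FoldSizes helper below splits n points into k folds whose sizes differ by at most one; the even-split policy and the k <= n check are assumptions of this sketch, and mlpack may assign points to folds differently.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Splits n points into k folds whose sizes differ by at most one.
// With k >= 2, every training set (n minus one fold) is non-empty.
std::vector<std::size_t> FoldSizes(std::size_t n, std::size_t k)
{
  assert(k >= 2 && k <= n);
  std::vector<std::size_t> sizes(k, n / k);
  for (std::size_t i = 0; i < n % k; ++i)
    ++sizes[i];  // Spread the remainder over the first n % k folds.
  return sizes;
}
```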

◆ KFoldCV() [2/6]

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses)

This constructor can be used for multiclass classification algorithms.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
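One plausible reason numClasses is passed explicitly rather than inferred from the labels: when labels are grouped, a training subset can miss a class entirely, so counting the classes seen during training would undercount. The hypothetical ClassesInTrainingSet function below, assuming a contiguous held-out range [begin, end), makes that concrete.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Counts the distinct classes among the training points, i.e. every point
// outside the held-out range [begin, end).
std::size_t ClassesInTrainingSet(const std::vector<std::size_t>& labels,
                                 std::size_t numClasses,
                                 std::size_t begin, std::size_t end)
{
  std::vector<bool> seen(numClasses, false);
  std::size_t distinct = 0;
  for (std::size_t i = 0; i < labels.size(); ++i)
  {
    if (i >= begin && i < end)
      continue;  // Skip the validation fold.
    assert(labels[i] < numClasses);  // Labels must lie in [0, numClasses).
    if (!seen[labels[i]])
    {
      seen[labels[i]] = true;
      ++distinct;
    }
  }
  return distinct;
}
```

With labels {0, 0, 1, 1, 2, 2} and three folds, every training set sees only two of the three classes.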

◆ KFoldCV() [3/6]

KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses)

This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
datasetInfo: Type information for each dimension of the dataset.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
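datasetInfo supplies type information for each dimension, which lets an algorithm treat some dimensions as categorical. As a rough illustration only (this is not the data::DatasetInfo API), the hypothetical ToyDatasetInfo class below tags each dimension as numeric or categorical and maps categorical strings to integer codes so they can be stored in a numeric matrix:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-dimension type tag.
enum class DimensionType { Numeric, Categorical };

// Hypothetical stand-in for dataset type information: one type per dimension
// plus, for categorical dimensions, a string-to-code mapping.
class ToyDatasetInfo
{
 public:
  explicit ToyDatasetInfo(std::vector<DimensionType> dimensionTypes)
    : types(std::move(dimensionTypes)), codes(types.size()) {}

  DimensionType Type(std::size_t dim) const { return types.at(dim); }

  // Returns the numeric code for a categorical value, assigning the next
  // unused code the first time a value is seen in that dimension.
  double Encode(std::size_t dim, const std::string& value)
  {
    assert(types.at(dim) == DimensionType::Categorical);
    auto& dimCodes = codes.at(dim);
    auto it = dimCodes.find(value);
    if (it == dimCodes.end())
      it = dimCodes.emplace(value, dimCodes.size()).first;
    return static_cast<double>(it->second);
  }

  std::size_t NumCategories(std::size_t dim) const
  {
    return codes.at(dim).size();
  }

 private:
  std::vector<DimensionType> types;
  std::vector<std::map<std::string, std::size_t>> codes;
};
```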

◆ KFoldCV() [4/6]

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const WeightsType &weights)

This constructor can be used for regression and binary classification algorithms that support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Predictions (labels for classification algorithms and responses for regression algorithms) for each data point.
weights: Observation weights (for boosting).
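With weighted learning, the weights must be split into folds using exactly the same indices as the data points and labels, so each weight stays attached to its own point. A minimal sketch, assuming a hypothetical weighted-mean model and a contiguous held-out range [begin, end):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical weighted model: the weighted mean of the training responses,
// where the training set is every point outside [begin, end). Responses and
// weights are subset with the same indices.
double WeightedTrainingMean(const std::vector<double>& ys,
                            const std::vector<double>& weights,
                            std::size_t begin, std::size_t end)
{
  assert(ys.size() == weights.size());
  double weightedSum = 0.0, totalWeight = 0.0;
  for (std::size_t i = 0; i < ys.size(); ++i)
  {
    if (i >= begin && i < end)
      continue;  // Points in the validation fold are not used for training.
    weightedSum += weights[i] * ys[i];
    totalWeight += weights[i];
  }
  assert(totalWeight > 0.0);
  return weightedSum / totalWeight;
}
```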

◆ KFoldCV() [5/6]

KFoldCV (const size_t k, const MatType &xs, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights)

This constructor can be used for multiclass classification algorithms that support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
weights: Observation weights (for boosting).

◆ KFoldCV() [6/6]

KFoldCV (const size_t k, const MatType &xs, const data::DatasetInfo &datasetInfo, const PredictionsType &ys, const size_t numClasses, const WeightsType &weights)

This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.

Parameters
k: Number of folds (should be at least 2).
xs: Data points to cross-validate on.
datasetInfo: Type information for each dimension of the dataset.
ys: Labels for each data point.
numClasses: Number of classes in the dataset.
weights: Observation weights (for boosting).

Member Function Documentation

◆ Evaluate()

template<typename... MLAlgorithmArgs>
double Evaluate (const MLAlgorithmArgs &...args)

Run k-fold cross-validation.

Parameters
args: Arguments for MLAlgorithm (in addition to those passed to the constructor).
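Evaluate() accepts extra arguments that are combined with the ones given to the constructor, as in the class example where lambda is passed to Evaluate(). Assuming each fold's model is built through a constructor that takes the stored data followed by those extra arguments, the forwarding can be sketched with a variadic template; RidgeMean and TrainWithArgs are hypothetical names:

```cpp
#include <cassert>
#include <vector>

// Hypothetical model: the ridge-penalized constant fit, which minimizes
// sum (y - b)^2 + lambda * b^2, giving b = sum(y) / (n + lambda).
struct RidgeMean
{
  double estimate;

  RidgeMean(const std::vector<double>& ys, double lambda)
  {
    assert(!ys.empty() && lambda >= 0.0);
    double sum = 0.0;
    for (double y : ys)
      sum += y;
    estimate = sum / (ys.size() + lambda);  // Larger lambda shrinks toward 0.
  }
};

// Builds a model from the stored data plus whatever extra arguments the
// caller supplies; the same args would be reused for every fold.
template<typename Model, typename... Args>
Model TrainWithArgs(const std::vector<double>& ys, const Args&... args)
{
  return Model(ys, args...);
}
```

Because the same extra arguments are reused for every fold, calling Evaluate once per candidate value (for example, several lambda values) is a simple way to compare hyperparameter settings.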

◆ Model()

MLAlgorithm& Model ()

Access and modify the model trained in the last run of k-fold cross-validation.


The documentation for this class was generated from the following file: