12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP 13 #define MLPACK_CORE_CV_K_FOLD_CV_HPP 53 template<
typename MLAlgorithm,
55 typename MatType = arma::mat,
56 typename PredictionsType =
58 typename WeightsType =
59 typename MetaInfoExtractor<MLAlgorithm, MatType,
60 PredictionsType>::WeightsType>
75 const PredictionsType& ys);
87 const PredictionsType& ys,
88 const size_t numClasses);
103 const PredictionsType& ys,
104 const size_t numClasses);
118 const PredictionsType& ys,
119 const WeightsType& weights);
133 const PredictionsType& ys,
134 const size_t numClasses,
135 const WeightsType& weights);
151 const PredictionsType& ys,
152 const size_t numClasses,
153 const WeightsType& weights);
161 template<
typename... MLAlgorithmArgs>
162 double Evaluate(
const MLAlgorithmArgs& ...args);
165 MLAlgorithm&
Model();
188 size_t trainingSubsetSize;
191 std::unique_ptr<MLAlgorithm> modelPtr;
200 const PredictionsType& ys);
209 const PredictionsType& ys,
210 const WeightsType& weights);
216 template<
typename DataType>
217 void InitKFoldCVMat(
const DataType& source, DataType& destination);
222 template<
typename...MLAlgorithmArgs,
224 typename =
typename std::enable_if<Enabled>::type>
225 double TrainAndEvaluate(
const MLAlgorithmArgs& ...mlAlgorithmArgs);
230 template<
typename...MLAlgorithmArgs,
232 typename =
typename std::enable_if<Enabled>::type,
234 double TrainAndEvaluate(
const MLAlgorithmArgs& ...mlAlgorithmArgs);
242 inline size_t ValidationSubsetFirstCol(
const size_t i);
247 template<
typename ElementType>
248 inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
254 template<
typename ElementType>
255 inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
261 template<
typename ElementType>
262 inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
268 template<
typename ElementType>
269 inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
277 #include "k_fold_cv_impl.hpp" double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
An auxiliary class for cross-validation.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys)
This constructor can be used for regression algorithms and for binary classification algorithms...