/**
 * @file k_fold_cv.hpp
 *
 * k-fold cross-validation for regression and classification algorithms.
 */
12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP
13 #define MLPACK_CORE_CV_K_FOLD_CV_HPP

#include <mlpack/core/cv/meta_info_extractor.hpp>
#include <mlpack/core/cv/cv_base.hpp>

18 namespace mlpack {
19 namespace cv {
20 
53 template<typename MLAlgorithm,
54  typename Metric,
55  typename MatType = arma::mat,
56  typename PredictionsType =
58  typename WeightsType =
59  typename MetaInfoExtractor<MLAlgorithm, MatType,
60  PredictionsType>::WeightsType>
61 class KFoldCV
62 {
63  public:
73  KFoldCV(const size_t k,
74  const MatType& xs,
75  const PredictionsType& ys);
76 
85  KFoldCV(const size_t k,
86  const MatType& xs,
87  const PredictionsType& ys,
88  const size_t numClasses);
89 
100  KFoldCV(const size_t k,
101  const MatType& xs,
102  const data::DatasetInfo& datasetInfo,
103  const PredictionsType& ys,
104  const size_t numClasses);
105 
116  KFoldCV(const size_t k,
117  const MatType& xs,
118  const PredictionsType& ys,
119  const WeightsType& weights);
120 
131  KFoldCV(const size_t k,
132  const MatType& xs,
133  const PredictionsType& ys,
134  const size_t numClasses,
135  const WeightsType& weights);
136 
148  KFoldCV(const size_t k,
149  const MatType& xs,
150  const data::DatasetInfo& datasetInfo,
151  const PredictionsType& ys,
152  const size_t numClasses,
153  const WeightsType& weights);
154 
161  template<typename... MLAlgorithmArgs>
162  double Evaluate(const MLAlgorithmArgs& ...args);
163 
165  MLAlgorithm& Model();
166 
167  private:
170 
172  Base base;
173 
175  const size_t k;
176 
178  MatType xs;
180  PredictionsType ys;
182  WeightsType weights;
183 
185  size_t binSize;
186 
188  size_t trainingSubsetSize;
189 
191  std::unique_ptr<MLAlgorithm> modelPtr;
192 
197  KFoldCV(Base&& base,
198  const size_t k,
199  const MatType& xs,
200  const PredictionsType& ys);
201 
206  KFoldCV(Base&& base,
207  const size_t k,
208  const MatType& xs,
209  const PredictionsType& ys,
210  const WeightsType& weights);
211 
216  template<typename DataType>
217  void InitKFoldCVMat(const DataType& source, DataType& destination);
218 
222  template<typename...MLAlgorithmArgs,
223  bool Enabled = !Base::MIE::SupportsWeights,
224  typename = typename std::enable_if<Enabled>::type>
225  double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
226 
230  template<typename...MLAlgorithmArgs,
231  bool Enabled = Base::MIE::SupportsWeights,
232  typename = typename std::enable_if<Enabled>::type,
233  typename = void>
234  double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
235 
242  inline size_t ValidationSubsetFirstCol(const size_t i);
243 
247  template<typename ElementType>
248  inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
249  const size_t i);
250 
254  template<typename ElementType>
255  inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
256  const size_t i);
257 
261  template<typename ElementType>
262  inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
263  const size_t i);
264 
268  template<typename ElementType>
269  inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
270  const size_t i);
271 };
272 
273 } // namespace cv
274 } // namespace mlpack
275 
276 // Include implementation
277 #include "k_fold_cv_impl.hpp"
278 
279 #endif
double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
.hpp
Definition: add_to_po.hpp:21
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
Definition: k_fold_cv.hpp:61
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys)
This constructor can be used for regression algorithms and for binary classification algorithms...