12 #ifndef MLPACK_CORE_MATH_SHUFFLE_DATA_HPP 13 #define MLPACK_CORE_MATH_SHUFFLE_DATA_HPP 27 template<
typename MatType,
typename LabelsType>
29 const LabelsType& inputLabels,
30 MatType& outputPoints,
31 LabelsType& outputLabels,
36 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
37 inputPoints.n_cols - 1, inputPoints.n_cols));
39 outputPoints = inputPoints.cols(ordering);
40 outputLabels = inputLabels.cols(ordering);
50 template<
typename MatType,
typename LabelsType>
52 const LabelsType& inputLabels,
53 MatType& outputPoints,
54 LabelsType& outputLabels,
59 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
60 inputPoints.n_cols - 1, inputPoints.n_cols));
63 arma::umat locations(2, inputPoints.n_nonzero);
64 arma::Col<typename MatType::elem_type> values(
65 const_cast<typename MatType::elem_type*>(inputPoints.values),
66 inputPoints.n_nonzero,
false,
true);
67 typename MatType::const_iterator it = inputPoints.begin();
69 while (it != inputPoints.end())
71 locations(0, index) = it.row();
72 locations(1, index) = ordering[it.col()];
77 if (&inputPoints == &outputPoints || &inputLabels == &outputLabels)
79 MatType newOutputPoints(locations, values, inputPoints.n_rows,
80 inputPoints.n_cols,
true);
81 LabelsType newOutputLabels(inputLabels.n_elem);
82 newOutputLabels.cols(ordering) = inputLabels;
84 outputPoints = std::move(newOutputPoints);
85 outputLabels = std::move(newOutputLabels);
89 outputPoints = MatType(locations, values, inputPoints.n_rows,
90 inputPoints.n_cols,
true);
91 outputLabels.set_size(inputLabels.n_elem);
92 outputLabels.cols(ordering) = inputLabels;
103 template<
typename MatType,
typename LabelsType>
105 const LabelsType& inputLabels,
106 MatType& outputPoints,
107 LabelsType& outputLabels,
113 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
114 inputPoints.n_cols - 1, inputPoints.n_cols));
118 MatType* outputPointsPtr = &outputPoints;
119 LabelsType* outputLabelsPtr = &outputLabels;
120 if (&inputPoints == &outputPoints)
121 outputPointsPtr =
new MatType();
122 if (&inputLabels == &outputLabels)
123 outputLabelsPtr =
new LabelsType();
125 outputPointsPtr->set_size(inputPoints.n_rows, inputPoints.n_cols,
126 inputPoints.n_slices);
127 outputLabelsPtr->set_size(inputLabels.n_rows, inputLabels.n_cols,
128 inputLabels.n_slices);
129 for (
size_t i = 0; i < ordering.n_elem; ++i)
131 outputPointsPtr->tube(0, ordering[i], outputPointsPtr->n_rows - 1,
132 ordering[i]) = inputPoints.tube(0, i, inputPoints.n_rows - 1, i);
133 outputLabelsPtr->tube(0, ordering[i], outputLabelsPtr->n_rows - 1,
134 ordering[i]) = inputLabels.tube(0, i, inputLabels.n_rows - 1, i);
138 if (&inputPoints == &outputPoints)
140 outputPoints = std::move(*outputPointsPtr);
141 delete outputPointsPtr;
144 if (&inputLabels == &outputLabels)
146 outputLabels = std::move(*outputLabelsPtr);
147 delete outputLabelsPtr;
typename enable_if< B, T >::type enable_if_t
The core includes that mlpack expects; standard C++ includes and Armadillo.
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).