shuffle_data.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
13 #define MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace math {
19 
27 template<typename MatType, typename LabelsType>
28 void ShuffleData(const MatType& inputPoints,
29  const LabelsType& inputLabels,
30  MatType& outputPoints,
31  LabelsType& outputLabels,
32  const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
33  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
34 {
35  // Generate ordering.
36  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
37  inputPoints.n_cols - 1, inputPoints.n_cols));
38 
39  outputPoints = inputPoints.cols(ordering);
40  outputLabels = inputLabels.cols(ordering);
41 }
42 
50 template<typename MatType, typename LabelsType>
51 void ShuffleData(const MatType& inputPoints,
52  const LabelsType& inputLabels,
53  MatType& outputPoints,
54  LabelsType& outputLabels,
55  const std::enable_if_t<arma::is_SpMat<MatType>::value>* = 0,
56  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
57 {
58  // Generate ordering.
59  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
60  inputPoints.n_cols - 1, inputPoints.n_cols));
61 
62  // Extract coordinate list representation.
63  arma::umat locations(2, inputPoints.n_nonzero);
64  arma::Col<typename MatType::elem_type> values(
65  const_cast<typename MatType::elem_type*>(inputPoints.values),
66  inputPoints.n_nonzero, false, true);
67  typename MatType::const_iterator it = inputPoints.begin();
68  size_t index = 0;
69  while (it != inputPoints.end())
70  {
71  locations(0, index) = it.row();
72  locations(1, index) = ordering[it.col()];
73  ++it;
74  ++index;
75  }
76 
77  if (&inputPoints == &outputPoints || &inputLabels == &outputLabels)
78  {
79  MatType newOutputPoints(locations, values, inputPoints.n_rows,
80  inputPoints.n_cols, true);
81  LabelsType newOutputLabels(inputLabels.n_elem);
82  newOutputLabels.cols(ordering) = inputLabels;
83 
84  outputPoints = std::move(newOutputPoints);
85  outputLabels = std::move(newOutputLabels);
86  }
87  else
88  {
89  outputPoints = MatType(locations, values, inputPoints.n_rows,
90  inputPoints.n_cols, true);
91  outputLabels.set_size(inputLabels.n_elem);
92  outputLabels.cols(ordering) = inputLabels;
93  }
94 }
95 
103 template<typename MatType, typename LabelsType>
104 void ShuffleData(const MatType& inputPoints,
105  const LabelsType& inputLabels,
106  MatType& outputPoints,
107  LabelsType& outputLabels,
108  const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
109  const std::enable_if_t<arma::is_Cube<MatType>::value>* = 0,
110  const std::enable_if_t<arma::is_Cube<LabelsType>::value>* = 0)
111 {
112  // Generate ordering.
113  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
114  inputPoints.n_cols - 1, inputPoints.n_cols));
115 
116  // Properly handle the case where the input and output data are the same
117  // object.
118  MatType* outputPointsPtr = &outputPoints;
119  LabelsType* outputLabelsPtr = &outputLabels;
120  if (&inputPoints == &outputPoints)
121  outputPointsPtr = new MatType();
122  if (&inputLabels == &outputLabels)
123  outputLabelsPtr = new LabelsType();
124 
125  outputPointsPtr->set_size(inputPoints.n_rows, inputPoints.n_cols,
126  inputPoints.n_slices);
127  outputLabelsPtr->set_size(inputLabels.n_rows, inputLabels.n_cols,
128  inputLabels.n_slices);
129  for (size_t i = 0; i < ordering.n_elem; ++i)
130  {
131  outputPointsPtr->tube(0, ordering[i], outputPointsPtr->n_rows - 1,
132  ordering[i]) = inputPoints.tube(0, i, inputPoints.n_rows - 1, i);
133  outputLabelsPtr->tube(0, ordering[i], outputLabelsPtr->n_rows - 1,
134  ordering[i]) = inputLabels.tube(0, i, inputLabels.n_rows - 1, i);
135  }
136 
137  // Clean up memory if needed.
138  if (&inputPoints == &outputPoints)
139  {
140  outputPoints = std::move(*outputPointsPtr);
141  delete outputPointsPtr;
142  }
143 
144  if (&inputLabels == &outputLabels)
145  {
146  outputLabels = std::move(*outputLabelsPtr);
147  delete outputLabelsPtr;
148  }
149 }
150 
151 } // namespace math
152 } // namespace mlpack
153 
154 #endif
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).