mean_shift.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
14 #define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
15 
16 #include <mlpack/prereqs.hpp>
20 #include <boost/utility.hpp>
21 
22 namespace mlpack {
23 namespace meanshift {
24 
48 template<bool UseKernel = false,
49  typename KernelType = kernel::GaussianKernel,
50  typename MatType = arma::mat>
51 class MeanShift
52 {
53  public:
65  MeanShift(const double radius = 0,
66  const size_t maxIterations = 1000,
67  const KernelType kernel = KernelType());
68 
75  double EstimateRadius(const MatType& data, const double ratio = 0.2);
76 
89  void Cluster(const MatType& data,
90  arma::Row<size_t>& assignments,
91  arma::mat& centroids,
92  bool forceConvergence = true,
93  bool useSeeds = true);
94 
96  size_t MaxIterations() const { return maxIterations; }
98  size_t& MaxIterations() { return maxIterations; }
99 
101  double Radius() const { return radius; }
103  void Radius(double radius);
104 
106  const KernelType& Kernel() const { return kernel; }
108  KernelType& Kernel() { return kernel; }
109 
110  private:
124  void GenSeeds(const MatType& data,
125  const double binSize,
126  const int minFreq,
127  MatType& seeds);
128 
137  template<bool ApplyKernel = UseKernel>
138  typename std::enable_if<ApplyKernel, bool>::type
139  CalculateCentroid(const MatType& data,
140  const std::vector<size_t>& neighbors,
141  const std::vector<double>& distances,
142  arma::colvec& centroid);
143 
152  template<bool ApplyKernel = UseKernel>
153  typename std::enable_if<!ApplyKernel, bool>::type
154  CalculateCentroid(const MatType& data,
155  const std::vector<size_t>& neighbors,
156  const std::vector<double>&, /*unused*/
157  arma::colvec& centroid);
158 
164  double radius;
165 
167  size_t maxIterations;
168 
170  KernelType kernel;
171 };
172 
173 } // namespace meanshift
174 } // namespace mlpack
175 
176 // Include implementation.
177 #include "mean_shift_impl.hpp"
178 
179 #endif // MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t & MaxIterations()
Set the maximum number of iterations.
Definition: mean_shift.hpp:98
This class implements mean shift clustering.
Definition: mean_shift.hpp:51
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimate of the radius based on the given dataset.
size_t MaxIterations() const
Get the maximum number of iterations.
Definition: mean_shift.hpp:96
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters with which mean shift will be run.
void Cluster(const MatType &data, arma::Row< size_t > &assignments, arma::mat &centroids, bool forceConvergence=true, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids...
The standard Gaussian kernel.
KernelType & Kernel()
Modify the kernel.
Definition: mean_shift.hpp:108
double Radius() const
Get the radius.
Definition: mean_shift.hpp:101
const KernelType & Kernel() const
Get the kernel.
Definition: mean_shift.hpp:106