linear_svm.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
13 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <ensmallen.hpp>
17 
18 #include "linear_svm_function.hpp"
19 
20 namespace mlpack {
21 namespace svm {
22 
79 template <typename MatType = arma::mat>
80 class LinearSVM
81 {
82  public:
97  template <typename OptimizerType = ens::L_BFGS>
98  LinearSVM(const MatType& data,
99  const arma::Row<size_t>& labels,
100  const size_t numClasses = 2,
101  const double lambda = 0.0001,
102  const double delta = 1.0,
103  const bool fitIntercept = false,
104  OptimizerType optimizer = OptimizerType());
105 
117  LinearSVM(const size_t inputSize,
118  const size_t numClasses = 0,
119  const double lambda = 0.0001,
120  const double delta = 1.0,
121  const bool fitIntercept = false);
122 
132  void Classify(const MatType& data,
133  arma::Row<size_t>& labels) const;
134 
146  void Classify(const MatType& data,
147  arma::Row<size_t>& labels,
148  arma::mat& scores) const;
149 
156  void Classify(const MatType& data,
157  arma::mat& scores) const;
158 
167  template<typename VecType>
168  size_t Classify(const VecType& point) const;
169 
179  double ComputeAccuracy(const MatType& testData,
180  const arma::Row<size_t>& testLabels) const;
181 
193  template <typename OptimizerType = ens::L_BFGS>
194  double Train(const MatType& data,
195  const arma::Row<size_t>& labels,
196  const size_t numClasses = 2,
197  OptimizerType optimizer = OptimizerType());
198 
199 
201  size_t& NumClasses() { return numClasses; }
203  size_t NumClasses() const { return numClasses; }
204 
206  double& Lambda() { return lambda; }
208  double Lambda() const { return lambda; }
209 
211  arma::mat& Parameters() { return parameters; }
213  const arma::mat& Parameters() const { return parameters; }
214 
216  size_t FeatureSize() const
217  { return fitIntercept ? parameters.n_rows - 1 :
218  parameters.n_rows; }
219 
223  template<typename Archive>
224  void serialize(Archive& ar, const unsigned int /* version */)
225  {
226  ar & BOOST_SERIALIZATION_NVP(parameters);
227  ar & BOOST_SERIALIZATION_NVP(numClasses);
228  ar & BOOST_SERIALIZATION_NVP(lambda);
229  ar & BOOST_SERIALIZATION_NVP(fitIntercept);
230  }
231 
232  private:
234  arma::mat parameters;
236  size_t numClasses;
238  double lambda;
240  double delta;
242  bool fitIntercept;
243 };
244 
245 } // namespace svm
246 } // namespace mlpack
247 
248 // Include implementation.
249 #include "linear_svm_impl.hpp"
250 
251 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
.hpp
Definition: add_to_po.hpp:21
arma::mat & Parameters()
Sets the model parameters.
Definition: linear_svm.hpp:211
The core includes that mlpack expects; standard C++ includes and Armadillo.
double & Lambda()
Sets the regularization parameter.
Definition: linear_svm.hpp:206
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses=2, OptimizerType optimizer=OptimizerType())
Train the Linear SVM with the given training data.
size_t & NumClasses()
Sets the number of classes.
Definition: linear_svm.hpp:201
size_t FeatureSize() const
Gets the feature size of the training data.
Definition: linear_svm.hpp:216
const arma::mat & Parameters() const
Get the model parameters.
Definition: linear_svm.hpp:213
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses=2, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false, OptimizerType optimizer=OptimizerType())
Construct the LinearSVM class with the provided data and labels.
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
The LinearSVM class implements an L2-regularized support vector machine model, and supports training with multiple optimizers and classification.
Definition: linear_svm.hpp:80
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
double Lambda() const
Gets the regularization parameter.
Definition: linear_svm.hpp:208
void serialize(Archive &ar, const unsigned int)
Serialize the LinearSVM model.
Definition: linear_svm.hpp:224
size_t NumClasses() const
Gets the number of classes.
Definition: linear_svm.hpp:203