12 #ifndef MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP 13 #define MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP 16 #include <ensmallen.hpp> 21 namespace regression {
72 const size_t numClasses = 0,
73 const bool fitIntercept =
false);
89 template<
typename OptimizerType = ens::L_BFGS>
91 const arma::Row<size_t>& labels,
92 const size_t numClasses,
93 const double lambda = 0.0001,
94 const bool fitIntercept =
false,
95 OptimizerType optimizer = OptimizerType());
114 template<
typename OptimizerType,
typename... CallbackTypes>
116 const arma::Row<size_t>& labels,
117 const size_t numClasses,
119 const bool fitIntercept,
120 OptimizerType optimizer,
121 CallbackTypes&&... callbacks);
/**
 * Classify the given points, returning the predicted labels for each point.
 *
 * @param dataset Points to classify.
 * @param labels Output: predicted label for each point in `dataset`.
 */
130 void Classify(
const arma::mat& dataset, arma::Row<size_t>& labels)
const;
/**
 * Classify a single point.
 *
 * @tparam VecType Type of the point passed in.
 * @param point Point to classify.
 * @return A size_t; presumably the predicted class label for `point` -- TODO
 *     confirm against the implementation.
 */
138 template<
typename VecType>
139 size_t Classify(
const VecType& point)
const;
/**
 * Overload of Classify() taking a dataset and two writable outputs.
 *
 * Presumably fills `labels` with the predicted label of each point and
 * `probabilites` with the class probabilities of each point -- TODO confirm
 * against the implementation.
 *
 * NOTE(review): the third parameter is spelled `probabilites` here, while the
 * (dataset, probabilities) overload spells it `probabilities`.
 *
 * @param dataset Points to classify.
 * @param labels Writable row of size_t, one entry per point (assumed).
 * @param probabilites Writable matrix receiving the probabilities (assumed).
 */
152 void Classify(
const arma::mat& dataset,
153 arma::Row<size_t>& labels,
154 arma::mat& probabilites)
const;
/**
 * Overload of Classify() that takes a dataset and a single writable matrix.
 *
 * Presumably stores the class probabilities of each point in `probabilities`
 * without producing labels -- TODO confirm against the implementation.
 *
 * @param dataset Points to classify.
 * @param probabilities Writable matrix receiving the result.
 */
162 void Classify(
const arma::mat& dataset,
163 arma::mat& probabilities)
const;
174 const arma::Row<size_t>& labels)
const;
/**
 * Train the softmax regression with the given training data.
 *
 * @tparam OptimizerType Optimizer class to use; defaults to ens::L_BFGS.
 * @param data Training data.
 * @param labels Labels for the training data.
 * @param numClasses Number of classes.
 * @param optimizer Optimizer instance; a default-constructed OptimizerType is
 *     used when the argument is omitted.
 * @return A double; presumably the final objective value reached by the
 *     optimizer -- TODO confirm against the implementation.
 */
185 template<
typename OptimizerType = ens::L_BFGS>
186 double Train(
const arma::mat& data,
187 const arma::Row<size_t>& labels,
188 const size_t numClasses,
189 OptimizerType optimizer = OptimizerType());
/**
 * Train the softmax regression with the given training data, additionally
 * accepting a variadic pack of callbacks.
 *
 * Unlike the overload without callbacks, `optimizer` has no default argument
 * here and must be passed explicitly.
 *
 * @tparam OptimizerType Optimizer class to use; defaults to ens::L_BFGS.
 * @tparam CallbackTypes Types of the callback arguments.
 * @param data Training data.
 * @param labels Labels for the training data.
 * @param numClasses Number of classes.
 * @param optimizer Optimizer instance to use.
 * @param callbacks Callback objects, taken as forwarding references;
 *     presumably handed to the optimizer during training -- TODO confirm.
 * @return A double; presumably the final objective value reached by the
 *     optimizer -- TODO confirm against the implementation.
 */
203 template<
typename OptimizerType = ens::L_BFGS,
typename... CallbackTypes>
204 double Train(
const arma::mat& data,
205 const arma::Row<size_t>& labels,
206 const size_t numClasses,
207 OptimizerType optimizer,
208 CallbackTypes&&... callbacks);
230 {
return fitIntercept ? parameters.n_cols - 1:
236 template<
typename Archive>
239 ar & BOOST_SERIALIZATION_NVP(parameters);
240 ar & BOOST_SERIALIZATION_NVP(numClasses);
241 ar & BOOST_SERIALIZATION_NVP(lambda);
242 ar & BOOST_SERIALIZATION_NVP(fitIntercept);
/// Matrix of model parameters; written to the archive by serialize().
247 arma::mat parameters;
260 #include "softmax_regression_impl.hpp" SoftmaxRegression(const size_t inputSize=0, const size_t numClasses=0, const bool fitIntercept=false)
Initialize the SoftmaxRegression without performing training.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Lambda() const
Gets the regularization parameter.
bool FitIntercept() const
Gets the intercept term flag. This flag cannot be changed after training.
size_t NumClasses() const
Gets the number of classes.
double Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer=OptimizerType())
Train the softmax regression with the given training data.
Softmax Regression is a classifier which can be used for classification when the data available can take two or more class values.
arma::mat & Parameters()
Get the model parameters.
void serialize(Archive &ar, const unsigned int)
Serialize the SoftmaxRegression model.
double ComputeAccuracy(const arma::mat &testData, const arma::Row< size_t > &labels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
double & Lambda()
Sets the regularization parameter.
size_t FeatureSize() const
Gets the feature size of the training data.
size_t & NumClasses()
Sets the number of classes.
void Classify(const arma::mat &dataset, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
const arma::mat & Parameters() const
Get the model parameters.