15 #ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP 16 #define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP 21 namespace naive_bayes {
57 template<
typename ModelMatType = arma::mat>
62 typedef typename ModelMatType::elem_type
ElemType;
82 template<
typename MatType>
84 const arma::Row<size_t>& labels,
85 const size_t numClasses,
86 const bool incrementalVariance =
false);
95 const size_t numClasses = 0);
114 template<
typename MatType>
115 void Train(
const MatType& data,
116 const arma::Row<size_t>& labels,
117 const size_t numClasses,
118 const bool incremental =
true);
128 template<
typename VecType>
129 void Train(
const VecType& point,
const size_t label);
137 template<
typename VecType>
138 size_t Classify(
const VecType& point)
const;
150 template<
typename VecType,
typename ProbabilitiesVecType>
153 ProbabilitiesVecType& probabilities)
const;
169 template<
typename MatType>
171 arma::Row<size_t>& predictions)
const;
194 template<
typename MatType,
typename ProbabilitiesMatType>
196 arma::Row<size_t>& predictions,
197 ProbabilitiesMatType& probabilities)
const;
200 const ModelMatType&
Means()
const {
return means; }
202 ModelMatType&
Means() {
return means; }
205 const ModelMatType&
Variances()
const {
return variances; }
/**
 * Serialize the classifier.
 *
 * @param ar Archive to serialize to or from.
 * The version argument is unnamed because it is unused here.
 */
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
222 ModelMatType variances;
224 ModelMatType probabilities;
226 size_t trainingPoints;
236 template<
typename MatType>
237 void LogLikelihood(
const MatType& data,
238 ModelMatType& logLikelihoods)
const;
245 #include "naive_bayes_classifier_impl.hpp" NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incrementalVariance=false)
Initializes the classifier from the given input, then trains it by calculating the sample mean and variance.
ModelMatType & Probabilities()
Modify the prior probabilities for each class.
ModelMatType & Variances()
Modify the sample variances for each class.
The core includes that mlpack expects: standard C++ includes and Armadillo.
const ModelMatType & Variances() const
Get the sample variances for each class.
The simple Naive Bayes classifier.
size_t Classify(const VecType &point) const
Classify the given point, using the trained NaiveBayesClassifier model.
const ModelMatType & Probabilities() const
Get the prior probabilities for each class.
void serialize(Archive &ar, const unsigned int)
Serialize the classifier.
ModelMatType::elem_type ElemType
const ModelMatType & Means() const
Get the sample means for each class.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
ModelMatType & Means()
Modify the sample means for each class.