adaboost_model.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
13 #define MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 // Use forward declaration instead of include to accelerate compilation.
18 class AdaBoost;
19 
20 namespace mlpack {
21 namespace adaboost {
22 
27 {
28  public:
30  {
33  };
34 
35  private:
37  arma::Col<size_t> mappings;
39  size_t weakLearnerType;
45  size_t dimensionality;
46 
47  public:
49  AdaBoostModel();
50 
52  AdaBoostModel(const arma::Col<size_t>& mappings,
53  const size_t weakLearnerType);
54 
56  AdaBoostModel(const AdaBoostModel& other);
57 
60 
62  AdaBoostModel& operator=(const AdaBoostModel& other);
63 
66 
//! Get the mappings (read-only reference to the stored `mappings` vector).
68  const arma::Col<size_t>& Mappings() const { return mappings; }
//! Modify the mappings (mutable reference; writes go directly to the member).
70  arma::Col<size_t>& Mappings() { return mappings; }
71 
//! Get the weak learner type (returned by value).
73  size_t WeakLearnerType() const { return weakLearnerType; }
//! Modify the weak learner type (mutable reference to the stored value).
//! NOTE(review): serialize() compares this value against the
//! WeakLearnerTypes enumerators, so presumably only those values are
//! meaningful here -- confirm against the enum definition.
75  size_t& WeakLearnerType() { return weakLearnerType; }
76 
//! Get the dimensionality of the model (returned by value).
78  size_t Dimensionality() const { return dimensionality; }
//! Modify the dimensionality of the model (mutable reference to the member).
80  size_t& Dimensionality() { return dimensionality; }
81 
83  void Train(const arma::mat& data,
84  const arma::Row<size_t>& labels,
85  const size_t numClasses,
86  const size_t iterations,
87  const double tolerance);
88 
90  void Classify(const arma::mat& testData, arma::Row<size_t>& predictions);
91 
93  template<typename Archive>
94  void serialize(Archive& ar, const unsigned int /* version */)
95  {
96  if (Archive::is_loading::value)
97  {
98  if (dsBoost)
99  delete dsBoost;
100  if (pBoost)
101  delete pBoost;
102 
103  dsBoost = NULL;
104  pBoost = NULL;
105  }
106 
107  ar & BOOST_SERIALIZATION_NVP(mappings);
108  ar & BOOST_SERIALIZATION_NVP(weakLearnerType);
109  if (weakLearnerType == WeakLearnerTypes::DECISION_STUMP)
110  ar & BOOST_SERIALIZATION_NVP(dsBoost);
111  else if (weakLearnerType == WeakLearnerTypes::PERCEPTRON)
112  ar & BOOST_SERIALIZATION_NVP(pBoost);
113  ar & BOOST_SERIALIZATION_NVP(dimensionality);
114  }
115 };
116 
117 } // namespace adaboost
118 } // namespace mlpack
119 
120 #endif
~AdaBoostModel()
Clean up memory.
void Classify(const arma::mat &testData, arma::Row< size_t > &predictions)
Classify test points.
mlpack
Definition: add_to_po.hpp:21
The AdaBoost class.
Definition: adaboost.hpp:81
size_t & Dimensionality()
Modify the dimensionality of the model.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The model to save to disk.
arma::Col< size_t > & Mappings()
Modify the mappings.
void Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t iterations, const double tolerance)
Train the model.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
size_t WeakLearnerType() const
Get the weak learner type.
AdaBoostModel()
Create an empty AdaBoost model.
const arma::Col< size_t > & Mappings() const
Get the mappings.
AdaBoostModel & operator=(const AdaBoostModel &other)
Copy assignment operator.
size_t Dimensionality() const
Get the dimensionality of the model.
size_t & WeakLearnerType()
Modify the weak learner type.