decision_stump.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace decision_stump {
19 
37 template<typename MatType = arma::mat>
39 {
40  public:
50  mlpack_deprecated DecisionStump(const MatType& data,
51  const arma::Row<size_t>& labels,
52  const size_t numClasses,
53  const size_t bucketSize = 10);
54 
67  const MatType& data,
68  const arma::Row<size_t>& labels,
69  const size_t numClasses,
70  const arma::rowvec& weights);
71 
77  DecisionStump();
78 
90  mlpack_deprecated double Train(const MatType& data,
91  const arma::Row<size_t>& labels,
92  const size_t numClasses,
93  const size_t bucketSize);
94 
107  mlpack_deprecated double Train(const MatType& data,
108  const arma::Row<size_t>& labels,
109  const arma::rowvec& weights,
110  const size_t numClasses,
111  const size_t bucketSize);
112 
121  mlpack_deprecated void Classify(const MatType& test,
122  arma::Row<size_t>& predictedLabels);
123 
125  size_t SplitDimension() const { return splitDimension; }
127  size_t& SplitDimension() { return splitDimension; }
128 
130  const arma::vec& Split() const { return split; }
132  arma::vec& Split() { return split; }
133 
135  const arma::Col<size_t> BinLabels() const { return binLabels; }
137  arma::Col<size_t>& BinLabels() { return binLabels; }
138 
140  template<typename Archive>
141  void serialize(Archive& ar, const unsigned int /* version */);
142 
143  private:
145  size_t numClasses;
147  size_t bucketSize;
148 
150  size_t splitDimension;
152  arma::vec split;
154  arma::Col<size_t> binLabels;
155 
164  template<bool UseWeights, typename VecType>
165  double SetupSplitDimension(const VecType& dimension,
166  const arma::Row<size_t>& labels,
167  const arma::rowvec& weightD);
168 
176  template<typename VecType>
177  void TrainOnDim(const VecType& dimension,
178  const arma::Row<size_t>& labels);
179 
184  void MergeRanges();
185 
192  template<typename VecType>
193  double CountMostFreq(const VecType& subCols);
194 
200  template<typename VecType>
201  int IsDistinct(const VecType& featureRow);
202 
212  template<bool UseWeights, typename VecType, typename WeightVecType>
213  double CalculateEntropy(const VecType& labels,
214  const WeightVecType& weights);
215 
226  template<bool UseWeights>
227  double Train(const MatType& data,
228  const arma::Row<size_t>& labels,
229  const arma::rowvec& weights);
230 };
231 
232 } // namespace decision_stump
233 } // namespace mlpack
234 
235 #include "decision_stump_impl.hpp"
236 
237 #endif
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t SplitDimension() const
Access the splitting dimension.
strip_type.hpp
Definition: add_to_po.hpp:21
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
The core includes that mlpack expects; standard C++ includes and Armadillo.
DecisionStump()
Create a decision stump without training.
#define mlpack_deprecated
Definition: deprecated.hpp:22
This class implements a decision stump.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
mlpack_deprecated void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
arma::vec & Split()
Modify the splitting values (be careful!).
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
const arma::vec & Split() const
Access the splitting values.