decision_stump.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace decision_stump {
19 
38 template<typename MatType = arma::mat>
39 class DecisionStump
40 {
41  public:
//! Deprecated constructor (tagged mlpack_deprecated) taking a dataset.
//! NOTE(review): presumably trains the stump on data/labels immediately; the
//! implementation is in decision_stump_impl.hpp -- confirm there.
//!
//! @param data Input data matrix.
//! @param labels Labels for the data (presumably one per point -- TODO confirm).
//! @param numClasses Number of classes.
//! @param bucketSize Defaults to 10. NOTE(review): looks like a minimum bin
//!     size used when splitting -- confirm against the implementation.
51  mlpack_deprecated DecisionStump(const MatType& data,
52  const arma::Row<size_t>& labels,
53  const size_t numClasses,
54  const size_t bucketSize = 10);
55 
69  const MatType& data,
70  const arma::Row<size_t>& labels,
71  const size_t numClasses,
72  const arma::rowvec& weights);
73 
//! Create a decision stump without training.
79  DecisionStump();
80 
//! Train the decision stump on the given data.  Tagged mlpack_deprecated.
//!
//! @param data Input data matrix.
//! @param labels Labels for the data.
//! @param numClasses Number of classes.
//! @param bucketSize NOTE(review): presumably the minimum bin size -- confirm.
//! @return A double whose meaning is not visible in this header (presumably an
//!     entropy-style score of the chosen split -- TODO confirm).
92  mlpack_deprecated double Train(const MatType& data,
93  const arma::Row<size_t>& labels,
94  const size_t numClasses,
95  const size_t bucketSize);
96 
//! Overload of Train() that additionally takes a row vector of weights.
//! Tagged mlpack_deprecated.
//! NOTE(review): presumably the weights are per-point instance weights --
//! confirm against decision_stump_impl.hpp.
//!
//! @param data Input data matrix.
//! @param labels Labels for the data.
//! @param weights Weights (note: this comes BEFORE numClasses in this overload).
//! @param numClasses Number of classes.
//! @param bucketSize NOTE(review): presumably the minimum bin size -- confirm.
//! @return A double; meaning not visible in this header -- TODO confirm.
109  mlpack_deprecated double Train(const MatType& data,
110  const arma::Row<size_t>& labels,
111  const arma::rowvec& weights,
112  const size_t numClasses,
113  const size_t bucketSize);
114 
//! Classification function.  Tagged mlpack_deprecated.
//!
//! @param test Data to classify.
//! @param predictedLabels Output parameter (non-const reference) that receives
//!     the predictions (presumably one label per test point -- TODO confirm).
123  mlpack_deprecated void Classify(const MatType& test,
124  arma::Row<size_t>& predictedLabels);
125 
//! Access the splitting dimension.
127  size_t SplitDimension() const { return splitDimension; }
//! Modify the splitting dimension (be careful!).
129  size_t& SplitDimension() { return splitDimension; }
130 
//! Access the splitting values.
132  const arma::vec& Split() const { return split; }
//! Modify the splitting values (be careful!).
134  arma::vec& Split() { return split; }
135 
//! Access the labels for each split bin.  Returned by const reference (like
//! the const Split() accessor) so that reading the labels does not copy the
//! whole vector on every call.
137  const arma::Col<size_t>& BinLabels() const { return binLabels; }
//! Modify the labels for each split bin (be careful!).
139  arma::Col<size_t>& BinLabels() { return binLabels; }
140 
//! Serialize the decision stump.
//!
//! @param ar Archive to serialize to or from.
//! The second (unnamed) parameter is the archive version; it is unused here.
142  template<typename Archive>
143  void serialize(Archive& ar, const unsigned int /* version */);
144 
145  private:
//! Number of classes (same name as the numClasses training parameter).
147  size_t numClasses;
//! NOTE(review): presumably the minimum number of points per bin, matching the
//! bucketSize training parameter -- confirm.
149  size_t bucketSize;
150 
//! The splitting dimension, exposed through SplitDimension().
152  size_t splitDimension;
//! The splitting values, exposed through Split().
154  arma::vec split;
//! The labels for each split bin, exposed through BinLabels().
156  arma::Col<size_t> binLabels;
157 
//! Private helper operating on a single dimension of the data.
//! NOTE(review): presumably evaluates how good a split on this dimension would
//! be and returns that score; implementation is not visible here -- confirm.
//!
//! @tparam UseWeights Compile-time switch; presumably whether weightD is used.
//! @tparam VecType Type of the dimension vector.
//! @param dimension Values of the dimension under consideration.
//! @param labels Labels for the data.
//! @param weightD Weights (presumably ignored when UseWeights is false).
//! @return A double; meaning not visible in this header -- TODO confirm.
166  template<bool UseWeights, typename VecType>
167  double SetupSplitDimension(const VecType& dimension,
168  const arma::Row<size_t>& labels,
169  const arma::rowvec& weightD);
170 
//! Private helper that trains using a single dimension and its labels.
//! NOTE(review): presumably fills in `split` and `binLabels` for the chosen
//! dimension -- confirm against decision_stump_impl.hpp.
//!
//! @tparam VecType Type of the dimension vector.
//! @param dimension Values of the dimension to train on.
//! @param labels Labels for the data.
178  template<typename VecType>
179  void TrainOnDim(const VecType& dimension,
180  const arma::Row<size_t>& labels);
181 
//! Private helper taking no arguments and returning nothing.
//! NOTE(review): judging by the name, presumably merges adjacent bins in
//! `split`/`binLabels` that share a label -- confirm against the implementation.
186  void MergeRanges();
187 
//! Private helper over a vector of values.
//! NOTE(review): judging by the name, presumably returns the most frequently
//! occurring value in subCols -- confirm against the implementation.
//!
//! @tparam VecType Type of the input vector.
//! @param subCols Values to examine.
//! @return A double; presumably the most frequent value -- TODO confirm.
194  template<typename VecType>
195  double CountMostFreq(const VecType& subCols);
196 
//! Private helper over a single feature row.
//! NOTE(review): judging by the name, presumably reports whether featureRow
//! contains more than one distinct value; the int return encoding is not
//! visible in this header -- confirm against the implementation.
//!
//! @tparam VecType Type of the input vector.
//! @param featureRow Values of one feature.
202  template<typename VecType>
203  int IsDistinct(const VecType& featureRow);
204 
//! Private helper that computes an entropy value from labels and weights.
//! NOTE(review): the exact entropy definition and sign convention are not
//! visible in this header -- confirm against decision_stump_impl.hpp.
//!
//! @tparam UseWeights Compile-time switch; presumably whether weights are used.
//! @tparam VecType Type of the labels container.
//! @tparam WeightVecType Type of the weights container.
//! @param labels Labels to compute the entropy of.
//! @param weights Weights (presumably ignored when UseWeights is false).
//! @return The computed entropy as a double.
214  template<bool UseWeights, typename VecType, typename WeightVecType>
215  double CalculateEntropy(const VecType& labels,
216  const WeightVecType& weights);
217 
//! Private Train() overload parameterized on whether weights are used.
//! NOTE(review): presumably the shared implementation behind the public
//! Train() overloads -- confirm against decision_stump_impl.hpp.
//!
//! @tparam UseWeights Compile-time switch; presumably whether weights are used.
//! @param data Input data matrix.
//! @param labels Labels for the data.
//! @param weights Weights (presumably ignored when UseWeights is false).
//! @return A double; meaning not visible in this header -- TODO confirm.
228  template<bool UseWeights>
229  double Train(const MatType& data,
230  const arma::Row<size_t>& labels,
231  const arma::rowvec& weights);
232 };
233 
234 } // namespace decision_stump
235 } // namespace mlpack
236 
237 #include "decision_stump_impl.hpp"
238 
239 #endif
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t SplitDimension() const
Access the splitting dimension.
Linear algebra utility functions, generally performed on matrices or vectors.
mlpack_deprecated double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
The core includes that mlpack expects; standard C++ includes and Armadillo.
DecisionStump()
Create a decision stump without training.
#define mlpack_deprecated
Definition: deprecated.hpp:22
This class implements a decision stump.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
mlpack_deprecated void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
arma::vec & Split()
Modify the splitting values (be careful!).
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
const arma::vec & Split() const
Access the splitting values.