decision_stump.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace decision_stump {
19 
33 template<typename MatType = arma::mat>
35 {
36  public:
46  DecisionStump(const MatType& data,
47  const arma::Row<size_t>& labels,
48  const size_t numClasses,
49  const size_t bucketSize = 10);
50 
62  DecisionStump(const DecisionStump<>& other,
63  const MatType& data,
64  const arma::Row<size_t>& labels,
65  const size_t numClasses,
66  const arma::rowvec& weights);
67 
73  DecisionStump();
74 
86  double Train(const MatType& data,
87  const arma::Row<size_t>& labels,
88  const size_t numClasses,
89  const size_t bucketSize);
90 
103  double Train(const MatType& data,
104  const arma::Row<size_t>& labels,
105  const arma::rowvec& weights,
106  const size_t numClasses,
107  const size_t bucketSize);
108 
117  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
118 
120  size_t SplitDimension() const { return splitDimension; }
122  size_t& SplitDimension() { return splitDimension; }
123 
125  const arma::vec& Split() const { return split; }
127  arma::vec& Split() { return split; }
128 
130  const arma::Col<size_t> BinLabels() const { return binLabels; }
132  arma::Col<size_t>& BinLabels() { return binLabels; }
133 
135  template<typename Archive>
136  void serialize(Archive& ar, const unsigned int /* version */);
137 
138  private:
140  size_t numClasses;
142  size_t bucketSize;
143 
145  size_t splitDimension;
147  arma::vec split;
149  arma::Col<size_t> binLabels;
150 
159  template<bool UseWeights, typename VecType>
160  double SetupSplitDimension(const VecType& dimension,
161  const arma::Row<size_t>& labels,
162  const arma::rowvec& weightD);
163 
171  template<typename VecType>
172  void TrainOnDim(const VecType& dimension,
173  const arma::Row<size_t>& labels);
174 
179  void MergeRanges();
180 
187  template<typename VecType>
188  double CountMostFreq(const VecType& subCols);
189 
195  template<typename VecType>
196  int IsDistinct(const VecType& featureRow);
197 
207  template<bool UseWeights, typename VecType, typename WeightVecType>
208  double CalculateEntropy(const VecType& labels,
209  const WeightVecType& weights);
210 
221  template<bool UseWeights>
222  double Train(const MatType& data,
223  const arma::Row<size_t>& labels,
224  const arma::rowvec& weights);
225 };
226 
227 } // namespace decision_stump
228 } // namespace mlpack
229 
230 #include "decision_stump_impl.hpp"
231 
232 #endif
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t SplitDimension() const
Access the splitting dimension.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
DecisionStump()
Create a decision stump without training.
class DecisionStump
This class implements a decision stump.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
arma::vec & Split()
Modify the splitting values (be careful!).
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
const arma::vec & Split() const
Access the splitting values.