decision_stump.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace decision_stump {
19 
33 template<typename MatType = arma::mat>
35 {
36  public:
46  DecisionStump(const MatType& data,
47  const arma::Row<size_t>& labels,
48  const size_t numClasses,
49  const size_t bucketSize = 10);
50 
62  DecisionStump(const DecisionStump<>& other,
63  const MatType& data,
64  const arma::Row<size_t>& labels,
65  const size_t numClasses,
66  const arma::rowvec& weights);
67 
73  DecisionStump();
74 
85  void Train(const MatType& data,
86  const arma::Row<size_t>& labels,
87  const size_t numClasses,
88  const size_t bucketSize);
89 
101  void Train(const MatType& data,
102  const arma::Row<size_t>& labels,
103  const arma::rowvec& weights,
104  const size_t numClasses,
105  const size_t bucketSize);
106 
115  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
116 
118  size_t SplitDimension() const { return splitDimension; }
120  size_t& SplitDimension() { return splitDimension; }
121 
123  const arma::vec& Split() const { return split; }
125  arma::vec& Split() { return split; }
126 
128  const arma::Col<size_t> BinLabels() const { return binLabels; }
130  arma::Col<size_t>& BinLabels() { return binLabels; }
131 
133  template<typename Archive>
134  void serialize(Archive& ar, const unsigned int /* version */);
135 
136  private:
138  size_t numClasses;
140  size_t bucketSize;
141 
143  size_t splitDimension;
145  arma::vec split;
147  arma::Col<size_t> binLabels;
148 
157  template<bool UseWeights, typename VecType>
158  double SetupSplitDimension(const VecType& dimension,
159  const arma::Row<size_t>& labels,
160  const arma::rowvec& weightD);
161 
169  template<typename VecType>
170  void TrainOnDim(const VecType& dimension,
171  const arma::Row<size_t>& labels);
172 
177  void MergeRanges();
178 
185  template<typename VecType>
186  double CountMostFreq(const VecType& subCols);
187 
193  template<typename VecType>
194  int IsDistinct(const VecType& featureRow);
195 
205  template<bool UseWeights, typename VecType, typename WeightVecType>
206  double CalculateEntropy(const VecType& labels,
207  const WeightVecType& weights);
208 
218  template<bool UseWeights>
219  void Train(const MatType& data,
220  const arma::Row<size_t>& labels,
221  const arma::rowvec& weights);
222 };
223 
224 } // namespace decision_stump
225 } // namespace mlpack
226 
227 #include "decision_stump_impl.hpp"
228 
229 #endif
void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t SplitDimension() const
Access the splitting dimension.
.hpp
Definition: add_to_po.hpp:21
prereqs.hpp: The core includes that mlpack expects: standard C++ includes and Armadillo.
DecisionStump()
Create a decision stump without training.
DecisionStump: This class implements a decision stump.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
arma::vec & Split()
Modify the splitting values (be careful!).
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
const arma::vec & Split() const
Access the splitting values.