12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP 13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP 18 namespace decision_stump {
33 template<
typename MatType = arma::mat>
47 const arma::Row<size_t>& labels,
48 const size_t numClasses,
49 const size_t bucketSize = 10);
64 const arma::Row<size_t>& labels,
65 const size_t numClasses,
66 const arma::rowvec& weights);
86 double Train(
const MatType& data,
87 const arma::Row<size_t>& labels,
88 const size_t numClasses,
89 const size_t bucketSize);
103 double Train(
const MatType& data,
104 const arma::Row<size_t>& labels,
105 const arma::rowvec& weights,
106 const size_t numClasses,
107 const size_t bucketSize);
117 void Classify(
const MatType& test, arma::Row<size_t>& predictedLabels);
125 const arma::vec&
Split()
const {
return split; }
127 arma::vec&
Split() {
return split; }
130 const arma::Col<size_t>
BinLabels()
const {
return binLabels; }
/**
 * Serialize the decision stump.
 *
 * @tparam Archive Archive type used for serialization.
 * @param ar Archive to serialize with.
 * The second (version) parameter is unnamed and therefore unused by this
 * declaration's signature.
 */
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
145 size_t splitDimension;
149 arma::Col<size_t> binLabels;
159 template<
bool UseWeights,
typename VecType>
160 double SetupSplitDimension(
const VecType& dimension,
161 const arma::Row<size_t>& labels,
162 const arma::rowvec& weightD);
171 template<
typename VecType>
172 void TrainOnDim(
const VecType& dimension,
173 const arma::Row<size_t>& labels);
/**
 * Find the most frequent entry in subCols.
 *
 * NOTE(review): the double return type suggests this returns the most
 * frequent value itself rather than its count -- confirm in
 * decision_stump_impl.hpp.
 *
 * @tparam VecType Type of the input vector.
 * @param subCols Values to examine.
 */
template<typename VecType>
double CountMostFreq(const VecType& subCols);
/**
 * Check whether featureRow contains distinct values.
 *
 * Returns an int used as a flag (kept as int for interface compatibility).
 * NOTE(review): presumably nonzero iff not all entries are equal -- confirm.
 *
 * @tparam VecType Type of the input vector.
 * @param featureRow Values of one feature to examine.
 */
template<typename VecType>
int IsDistinct(const VecType& featureRow);
/**
 * Calculate the entropy of the given labels.
 *
 * @tparam UseWeights Whether weights are taken into account.
 * @tparam VecType Type of the labels vector.
 * @tparam WeightVecType Type of the weights vector.
 * @param labels Labels to compute the entropy of.
 * @param weights Weight of each label. NOTE(review): presumably ignored when
 *     UseWeights is false -- confirm.
 * @return The entropy as a double.
 */
template<bool UseWeights, typename VecType, typename WeightVecType>
double CalculateEntropy(const VecType& labels, const WeightVecType& weights);
221 template<
bool UseWeights>
222 double Train(
const MatType& data,
223 const arma::Row<size_t>& labels,
224 const arma::rowvec& weights);
230 #include "decision_stump_impl.hpp" void serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t SplitDimension() const
Access the splitting dimension.
The core includes that mlpack expects: the standard C++ headers and Armadillo.
DecisionStump()
Create a decision stump without training.
This class implements a decision stump: a single-level decision tree that splits the data on one dimension.
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t bucketSize)
Train the decision stump on the given data.
arma::vec & Split()
Modify the splitting values (be careful!).
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
const arma::vec & Split() const
Access the splitting values.