/**
 * Evaluate the Gini impurity of a set of labels, optionally weighting each
 * point.
 *
 * The value accumulated into `impurity` is sum_i f_i * (1 - f_i), where f_i
 * is the (weighted) fraction of points carrying label i.
 *
 * @tparam UseWeights Flag naming whether per-point weights are used.
 *     NOTE(review): the branch that tests this flag is not visible in this
 *     listing -- confirm it selects between the weighted and unweighted
 *     loops below.
 * @tparam RowType Label container; must provide n_elem and operator[].
 * @tparam WeightVecType Weight container, indexed in parallel with labels.
 * @param labels Class label of each point. Each value is used directly as an
 *     index into a numClasses-long count vector, so it must lie in
 *     [0, numClasses).
 * @param numClasses Number of classes (length of each count vector).
 * @param weights Per-point weights. Only the weighted loops read them, and
 *     there they must have at least labels.n_elem entries.
 *
 * NOTE(review): this listing omits the declarator line (function name and
 * the labels parameter), the braces, the statements guarded by the two
 * corner-case tests, and the final return statement. Whether the caller
 * receives `impurity` itself or a value derived from it cannot be confirmed
 * from here.
 */
13 #ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP 43 template<
bool UseWeights,
typename RowType,
typename WeightVecType>
45 const size_t numClasses,
46 const WeightVecType& weights)
// Corner case: no labels at all.
// NOTE(review): the statement guarded by this test is not visible in this
// listing (presumably an early return) -- confirm.
49 if (labels.n_elem == 0)
// One zero-filled buffer of 4 * numClasses doubles, carved into four
// numClasses-long count vectors. Each vector aliases its slice of
// countSpace rather than copying it (copy_aux_mem = false), so the single
// zero fill initializes all four. The four separate accumulators let
// consecutive labels update different vectors, presumably to shorten
// dependency chains and help the compiler vectorize. They are summed
// together after counting.
54 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
55 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
56 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
58 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
60 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
// Running sum of f * (1 - f) over classes.
64 double impurity = 0.0;
// Weighted counting: one running weight total per count-vector lane.
69 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
// Main loop: four labels per iteration, each added (with its weight) to its
// own lane. i indexes the last of the four, so the loop stops before reading
// past n_elem.
73 for (
size_t i = 3; i < labels.n_elem; i += 4)
75 const double weight1 = weights[i - 3];
76 const double weight2 = weights[i - 2];
77 const double weight3 = weights[i - 1];
78 const double weight4 = weights[i];
80 counts[labels[i - 3]] += weight1;
81 counts2[labels[i - 2]] += weight2;
82 counts3[labels[i - 1]] += weight3;
83 counts4[labels[i]] += weight4;
85 accWeights[0] += weight1;
86 accWeights[1] += weight2;
87 accWeights[2] += weight3;
88 accWeights[3] += weight4;
// Leftovers: the final 1-3 labels when n_elem is not a multiple of 4.
92 if (labels.n_elem % 4 == 1)
94 const double weight1 = weights[labels.n_elem - 1];
95 counts[labels[labels.n_elem - 1]] += weight1;
96 accWeights[0] += weight1;
98 else if (labels.n_elem % 4 == 2)
100 const double weight1 = weights[labels.n_elem - 2];
101 const double weight2 = weights[labels.n_elem - 1];
103 counts[labels[labels.n_elem - 2]] += weight1;
104 counts2[labels[labels.n_elem - 1]] += weight2;
106 accWeights[0] += weight1;
107 accWeights[1] += weight2;
109 else if (labels.n_elem % 4 == 3)
111 const double weight1 = weights[labels.n_elem - 3];
112 const double weight2 = weights[labels.n_elem - 2];
113 const double weight3 = weights[labels.n_elem - 1];
115 counts[labels[labels.n_elem - 3]] += weight1;
116 counts2[labels[labels.n_elem - 2]] += weight2;
117 counts3[labels[labels.n_elem - 1]] += weight3;
119 accWeights[0] += weight1;
120 accWeights[1] += weight2;
121 accWeights[2] += weight3;
// Fold the four lanes together: accWeights[0] becomes the total weight and
// counts the per-class weight totals.
124 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
125 counts += counts2 + counts3 + counts4;
// A total weight of zero would make the fractions below divide by zero.
// NOTE(review): the statement guarded by this test is not visible in this
// listing (presumably an early return) -- confirm.
128 if (accWeights[0] == 0.0)
// Accumulate f * (1 - f), where f is each class's share of the total weight.
131 for (
size_t i = 0; i < numClasses; ++i)
133 const double f = ((double) counts[i] / (
double) accWeights[0]);
134 impurity += f * (1.0 - f);
// Unweighted counting: same four-lane scheme, each label counted once.
// NOTE(review): the branch structure separating this path from the weighted
// one above is not visible in this listing.
141 for (
size_t i = 3; i < labels.n_elem; i += 4)
143 counts[labels[i - 3]]++;
144 counts2[labels[i - 2]]++;
145 counts3[labels[i - 1]]++;
146 counts4[labels[i]]++;
// Leftovers: the final 1-3 labels when n_elem is not a multiple of 4.
150 if (labels.n_elem % 4 == 1)
152 counts[labels[labels.n_elem - 1]]++;
154 else if (labels.n_elem % 4 == 2)
156 counts[labels[labels.n_elem - 2]]++;
157 counts2[labels[labels.n_elem - 1]]++;
159 else if (labels.n_elem % 4 == 3)
161 counts[labels[labels.n_elem - 3]]++;
162 counts2[labels[labels.n_elem - 2]]++;
163 counts3[labels[labels.n_elem - 1]]++;
// Fold the four lanes into per-class totals.
166 counts += counts2 + counts3 + counts4;
// Accumulate f * (1 - f), where f is each class's share of all labels.
168 for (
size_t i = 0; i < numClasses; ++i)
170 const double f = ((double) counts[i] / (
double) labels.n_elem);
171 impurity += f * (1.0 - f);
// NOTE(review): the function's return statement and closing brace are not
// visible in this listing.
/**
 * Width of the interval of values the Gini impurity can take for a problem
 * with numClasses classes.
 *
 * The impurity sum_i f_i * (1 - f_i) is 0 when every point shares one class.
 * It peaks at 1 - 1/numClasses when the classes are evenly represented, so
 * the range is 1 - 1/numClasses. This is 0 for a single class. Passing
 * numClasses == 0 divides by zero, which under IEEE arithmetic yields
 * -infinity.
 *
 * @param numClasses Number of classes.
 * @return 1 - 1 / numClasses.
 */
static double Range(const size_t numClasses)
{
  const double reciprocal = 1.0 / static_cast<double>(numClasses);
  return 1.0 - reciprocal;
}
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.