// Include guard for this header, followed by a fragment of a templated
// routine that accumulates f * log2(f) over an array of counts.  The routine's
// name and its first parameter (the `counts` array indexed below) are not
// visible in this excerpt.
// NOTE(review): the leading integers look like line numbers embedded by an
// extraction step, and some original lines are absent (braces, the declaration
// of `gain`, any return statement) -- confirm against the upstream file.
13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 31 template<
bool UseWeights,
typename CountType>
// countLength bounds the loop below; totalCount is the divisor applied to each
// count.  UseWeights is not referenced by any line visible in this fragment.
33 const size_t countLength,
34 const CountType totalCount)
// Visit each of the countLength entries of `counts`.
38 for (
size_t i = 0; i < countLength; ++i)
// f is counts[i] divided by totalCount, with both operands cast to double.
40 const double f = ((double) counts[i] / (
double) totalCount);
// Add this entry's f * log2(f) term to the running total.
// NOTE(review): the embedded numbering jumps from 40 to 42, so a statement
// (possibly a guard on f) may be missing here.  If unguarded, f == 0 would
// evaluate 0 * log2(0) = 0 * -inf, which is NaN -- verify upstream.
42 gain += f * std::log2(f);
// Evaluate: accumulates sum_c f_c * log2(f_c) over the numClasses classes,
// where f_c is class c's share of the total, computed from `labels`.
//
// Two accumulation passes are visible: a weighted one that adds weights[i] to
// the count of labels[i] and divides by the summed weight, and an unweighted
// one that increments the count and divides by labels.n_elem.
// NOTE(review): the branch choosing between the two passes is not visible in
// this excerpt; presumably it tests the UseWeights template parameter --
// confirm.  The declaration of `gain`, all braces and all return statements
// are likewise missing from this view.
//
// labels     - one class index per element; used directly as an index into the
//              count vectors (assumes every label < numClasses -- TODO confirm
//              with callers; no bounds check is visible here).
// numClasses - length of each per-class count vector.
// weights    - one weight per element, read at the same indices as `labels`
//              in the weighted pass.
58 template<
bool UseWeights>
59 static double Evaluate(
const arma::Row<size_t>& labels,
60 const size_t numClasses,
61 const arma::Row<double>& weights)
// Empty-input check; the action taken for it is not visible in this excerpt.
64 if (labels.n_elem == 0)
// One zero-filled buffer of 4 * numClasses doubles.  counts, counts2, counts3
// and counts4 are each built from a pointer into consecutive numClasses-long
// slices of that buffer, with `false` as the third constructor argument.
// NOTE(review): with Armadillo's auxiliary-memory constructor that `false`
// means "do not copy", so the four vectors would alias countSpace -- confirm
// against the Armadillo documentation.  Only the first constructor's final
// argument (`true`) is visible; the others are cut off.
72 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
73 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
74 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
76 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
78 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
// Weighted pass.  accWeights holds four running weight totals, one per count
// vector.
84 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
// Main loop, four elements per iteration: elements i-3, i-2, i-1 and i update
// counts, counts2, counts3 and counts4 respectively.  Starting at i = 3 with a
// stride of 4 means only complete groups of four are handled here.
// NOTE(review): keeping four separate accumulators is presumably meant to let
// the four updates proceed independently -- TODO confirm the intent upstream.
88 for (
size_t i = 3; i < labels.n_elem; i += 4)
90 const double weight1 = weights[i - 3];
91 const double weight2 = weights[i - 2];
92 const double weight3 = weights[i - 1];
93 const double weight4 = weights[i];
95 counts[labels[i - 3]] += weight1;
96 counts2[labels[i - 2]] += weight2;
97 counts3[labels[i - 1]] += weight3;
98 counts4[labels[i]] += weight4;
100 accWeights[0] += weight1;
101 accWeights[1] += weight2;
102 accWeights[2] += weight3;
103 accWeights[3] += weight4;
// Tail handling: the 1, 2 or 3 trailing elements left over when n_elem is not
// a multiple of four.
107 if (labels.n_elem % 4 == 1)
109 const double weight1 = weights[labels.n_elem - 1];
110 counts[labels[labels.n_elem - 1]] += weight1;
111 accWeights[0] += weight1;
113 else if (labels.n_elem % 4 == 2)
115 const double weight1 = weights[labels.n_elem - 2];
116 const double weight2 = weights[labels.n_elem - 1];
118 counts[labels[labels.n_elem - 2]] += weight1;
119 counts2[labels[labels.n_elem - 1]] += weight2;
121 accWeights[0] += weight1;
122 accWeights[1] += weight2;
124 else if (labels.n_elem % 4 == 3)
126 const double weight1 = weights[labels.n_elem - 3];
127 const double weight2 = weights[labels.n_elem - 2];
128 const double weight3 = weights[labels.n_elem - 1];
130 counts[labels[labels.n_elem - 3]] += weight1;
131 counts2[labels[labels.n_elem - 2]] += weight2;
132 counts3[labels[labels.n_elem - 1]] += weight3;
134 accWeights[0] += weight1;
135 accWeights[1] += weight2;
136 accWeights[2] += weight3;
// Fold the four partial results together: accWeights[0] becomes the total
// weight and counts becomes the per-class weighted totals.
139 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
140 counts += counts2 + counts3 + counts4;
// Zero-total-weight check; the action taken for it is not visible here.  It
// precedes the division by accWeights[0] below.
143 if (accWeights[0] == 0.0)
// f is class i's share of the total weight; add f * log2(f) to gain.
146 for (
size_t i = 0; i < numClasses; ++i)
148 const double f = ((double) counts[i] / (
double) accWeights[0]);
// NOTE(review): embedded numbering jumps from 148 to 150, so a statement
// (possibly a guard on f) may be missing; unguarded, a class with zero weight
// would contribute 0 * log2(0) = NaN -- verify upstream.
150 gain += f * std::log2(f);
// Unweighted pass: same four-at-a-time layout as above, but each element adds
// 1 to its class count.
157 for (
size_t i = 3; i < labels.n_elem; i += 4)
159 counts[labels[i - 3]]++;
160 counts2[labels[i - 2]]++;
161 counts3[labels[i - 1]]++;
162 counts4[labels[i]]++;
// Tail handling for the 1, 2 or 3 trailing elements.
166 if (labels.n_elem % 4 == 1)
168 counts[labels[labels.n_elem - 1]]++;
170 else if (labels.n_elem % 4 == 2)
172 counts[labels[labels.n_elem - 2]]++;
173 counts2[labels[labels.n_elem - 1]]++;
175 else if (labels.n_elem % 4 == 3)
177 counts[labels[labels.n_elem - 3]]++;
178 counts2[labels[labels.n_elem - 2]]++;
179 counts3[labels[labels.n_elem - 1]]++;
// Fold the four partial count vectors into counts.
182 counts += counts2 + counts3 + counts4;
// f is class i's share of the labels.n_elem elements; add f * log2(f) to gain.
184 for (
size_t i = 0; i < numClasses; ++i)
186 const double f = ((double) counts[i] / (
double) labels.n_elem);
// NOTE(review): embedded numbering jumps from 186 to 188, so a statement
// (possibly a guard on f) may be missing; unguarded, an absent class would
// contribute 0 * log2(0) = NaN -- verify upstream.
188 gain += f * std::log2(f);
// Range: returns the base-2 logarithm of numClasses.
// With the f * log2(f) accumulation used elsewhere in this file, numClasses
// equally frequent classes sum to numClasses * (1/n) * log2(1/n) =
// -log2(numClasses), so this value is the magnitude of that extreme.
// NOTE(review): numClasses == 0 would yield -infinity; no guard is visible
// here -- confirm callers never pass 0.
202 static double Range(
const size_t numClasses)
207 return std::log2(numClasses);
// The core includes that mlpack expects; standard C++ includes and Armadillo.