13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 38 template<
bool UseWeights>
39 static double Evaluate(
const arma::Row<size_t>& labels,
40 const size_t numClasses,
41 const arma::Row<double>& weights)
44 if (labels.n_elem == 0)
52 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
53 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
54 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
56 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
58 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
64 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
68 for (
size_t i = 3; i < labels.n_elem; i += 4)
70 const double weight1 = weights[i - 3];
71 const double weight2 = weights[i - 2];
72 const double weight3 = weights[i - 1];
73 const double weight4 = weights[i];
75 counts[labels[i - 3]] += weight1;
76 counts2[labels[i - 2]] += weight2;
77 counts3[labels[i - 1]] += weight3;
78 counts4[labels[i]] += weight4;
80 accWeights[0] += weight1;
81 accWeights[1] += weight2;
82 accWeights[2] += weight3;
83 accWeights[3] += weight4;
87 if (labels.n_elem % 4 == 1)
89 const double weight1 = weights[labels.n_elem - 1];
90 counts[labels[labels.n_elem - 1]] += weight1;
91 accWeights[0] += weight1;
93 else if (labels.n_elem % 4 == 2)
95 const double weight1 = weights[labels.n_elem - 2];
96 const double weight2 = weights[labels.n_elem - 1];
98 counts[labels[labels.n_elem - 2]] += weight1;
99 counts2[labels[labels.n_elem - 1]] += weight2;
101 accWeights[0] += weight1;
102 accWeights[1] += weight2;
104 else if (labels.n_elem % 4 == 3)
106 const double weight1 = weights[labels.n_elem - 3];
107 const double weight2 = weights[labels.n_elem - 2];
108 const double weight3 = weights[labels.n_elem - 1];
110 counts[labels[labels.n_elem - 3]] += weight1;
111 counts2[labels[labels.n_elem - 2]] += weight2;
112 counts3[labels[labels.n_elem - 1]] += weight3;
114 accWeights[0] += weight1;
115 accWeights[1] += weight2;
116 accWeights[2] += weight3;
119 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
120 counts += counts2 + counts3 + counts4;
123 if (accWeights[0] == 0.0)
126 for (
size_t i = 0; i < numClasses; ++i)
128 const double f = ((double) counts[i] / (
double) accWeights[0]);
130 gain += f * std::log2(f);
137 for (
size_t i = 3; i < labels.n_elem; i += 4)
139 counts[labels[i - 3]]++;
140 counts2[labels[i - 2]]++;
141 counts3[labels[i - 1]]++;
142 counts4[labels[i]]++;
146 if (labels.n_elem % 4 == 1)
148 counts[labels[labels.n_elem - 1]]++;
150 else if (labels.n_elem % 4 == 2)
152 counts[labels[labels.n_elem - 2]]++;
153 counts2[labels[labels.n_elem - 1]]++;
155 else if (labels.n_elem % 4 == 3)
157 counts[labels[labels.n_elem - 3]]++;
158 counts2[labels[labels.n_elem - 2]]++;
159 counts3[labels[labels.n_elem - 1]]++;
162 counts += counts2 + counts3 + counts4;
164 for (
size_t i = 0; i < numClasses; ++i)
166 const double f = ((double) counts[i] / (
double) labels.n_elem);
168 gain += f * std::log2(f);
/**
 * Return log2(numClasses).  A sum of f * log2(f) over a distribution on
 * numClasses classes lies between -log2(numClasses) (all classes equally
 * likely) and 0 (a single class), so this is the width of that interval.
 *
 * @param numClasses Number of classes; 0 yields negative infinity.
 * @return Base-2 logarithm of numClasses.
 */
static double Range(const size_t numClasses)
{
  // Convert explicitly so the floating-point overload of std::log2 is chosen
  // on purpose rather than through an implicit integer promotion.
  return std::log2(static_cast<double>(numClasses));
}
// The core includes that mlpack expects: the standard C++ headers and Armadillo.