information_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
26 {
27  public:
31  template<bool UseWeights, typename CountType>
32  static double EvaluatePtr(const CountType* counts,
33  const size_t countLength,
34  const CountType totalCount)
35  {
36  double gain = 0.0;
37 
38  for (size_t i = 0; i < countLength; ++i)
39  {
40  const double f = ((double) counts[i] / (double) totalCount);
41  if (f > 0.0)
42  gain += f * std::log2(f);
43  }
44 
45  return gain;
46  }
47 
  /**
   * Given a set of labels, calculate the information gain of those labels:
   * sum over classes of f * log2(f), where f is the (possibly weighted)
   * fraction of points in that class.  The result is in [-log2(numClasses),
   * 0], with 0 meaning a pure set.
   *
   * When UseWeights is true, each point's contribution to its class count is
   * its weight; otherwise every point counts as 1.  The loops are manually
   * unrolled by 4 (four independent accumulator vectors aliasing one buffer)
   * so the compiler can vectorize them.
   *
   * @tparam UseWeights Whether to weight each point by the weights row.
   * @param labels Class label (< numClasses) of each point.
   * @param numClasses Number of possible classes.
   * @param weights Per-point weights; only read when UseWeights is true.
   *     NOTE(review): assumed nonnegative — negative weights could make the
   *     class fractions fall outside [0, 1]; confirm with callers.
   * @return The information gain (a value <= 0).
   */
  template<bool UseWeights>
  static double Evaluate(const arma::Row<size_t>& labels,
                         const size_t numClasses,
                         const arma::Row<double>& weights)
  {
    // Edge case: if there are no elements, the gain is zero.
    if (labels.n_elem == 0)
      return 0.0;

    // Calculate the information gain.
    double gain = 0.0;

    // Count the number of elements in each class.  Use four auxiliary vectors
    // to exploit SIMD instructions if possible.  counts..counts4 are
    // non-owning views into quarters of countSpace, so they stay contiguous.
    arma::vec countSpace(4 * numClasses, arma::fill::zeros);
    arma::vec counts(countSpace.memptr(), numClasses, false, true);
    arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
        true);
    arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
        true);
    arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
        true);

    if (UseWeights)
    {
      // Sum all the weights up.  Four partial sums, one per unrolled lane.
      double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };

      // SIMD loop: add counts for four elements simultaneously (if the
      // compiler manages to vectorize the loop).  Starting at i = 3 and
      // stepping by 4 processes elements [0, 4*floor(n/4)).
      for (size_t i = 3; i < labels.n_elem; i += 4)
      {
        const double weight1 = weights[i - 3];
        const double weight2 = weights[i - 2];
        const double weight3 = weights[i - 1];
        const double weight4 = weights[i];

        counts[labels[i - 3]] += weight1;
        counts2[labels[i - 2]] += weight2;
        counts3[labels[i - 1]] += weight3;
        counts4[labels[i]] += weight4;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
        accWeights[2] += weight3;
        accWeights[3] += weight4;
      }

      // Handle leftovers: the 1-3 trailing elements the unrolled loop missed.
      if (labels.n_elem % 4 == 1)
      {
        const double weight1 = weights[labels.n_elem - 1];
        counts[labels[labels.n_elem - 1]] += weight1;
        accWeights[0] += weight1;
      }
      else if (labels.n_elem % 4 == 2)
      {
        const double weight1 = weights[labels.n_elem - 2];
        const double weight2 = weights[labels.n_elem - 1];

        counts[labels[labels.n_elem - 2]] += weight1;
        counts2[labels[labels.n_elem - 1]] += weight2;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
      }
      else if (labels.n_elem % 4 == 3)
      {
        const double weight1 = weights[labels.n_elem - 3];
        const double weight2 = weights[labels.n_elem - 2];
        const double weight3 = weights[labels.n_elem - 1];

        counts[labels[labels.n_elem - 3]] += weight1;
        counts2[labels[labels.n_elem - 2]] += weight2;
        counts3[labels[labels.n_elem - 1]] += weight3;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
        accWeights[2] += weight3;
      }

      // Reduce the four lanes into lane 0 / the first counts vector.
      accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
      counts += counts2 + counts3 + counts4;

      // Corner case: return 0 if no weight.
      if (accWeights[0] == 0.0)
        return 0.0;

      // Accumulate f * log2(f) over classes with nonzero weight (the f > 0
      // check also skips empty classes, since lim_{f->0} f*log2(f) = 0).
      for (size_t i = 0; i < numClasses; ++i)
      {
        const double f = ((double) counts[i] / (double) accWeights[0]);
        if (f > 0.0)
          gain += f * std::log2(f);
      }
    }
    else
    {
      // SIMD loop: add counts for four elements simultaneously (if the
      // compiler manages to vectorize the loop).
      for (size_t i = 3; i < labels.n_elem; i += 4)
      {
        counts[labels[i - 3]]++;
        counts2[labels[i - 2]]++;
        counts3[labels[i - 1]]++;
        counts4[labels[i]]++;
      }

      // Handle leftovers: the 1-3 trailing elements the unrolled loop missed.
      if (labels.n_elem % 4 == 1)
      {
        counts[labels[labels.n_elem - 1]]++;
      }
      else if (labels.n_elem % 4 == 2)
      {
        counts[labels[labels.n_elem - 2]]++;
        counts2[labels[labels.n_elem - 1]]++;
      }
      else if (labels.n_elem % 4 == 3)
      {
        counts[labels[labels.n_elem - 3]]++;
        counts2[labels[labels.n_elem - 2]]++;
        counts3[labels[labels.n_elem - 1]]++;
      }

      // Reduce the four lanes into the first counts vector.
      counts += counts2 + counts3 + counts4;

      // Accumulate f * log2(f); the f > 0 check skips empty classes.
      for (size_t i = 0; i < numClasses; ++i)
      {
        const double f = ((double) counts[i] / (double) labels.n_elem);
        if (f > 0.0)
          gain += f * std::log2(f);
      }
    }

    return gain;
  }
194 
202  static double Range(const size_t numClasses)
203  {
204  // The best possible case gives an information gain of 0. The worst
205  // possible case is even distribution, which gives n * (1/n * log2(1/n)) =
206  // log2(1/n) = -log2(n). So, the range is log2(n).
207  return std::log2(numClasses);
208  }
209 };
210 
211 } // namespace tree
212 } // namespace mlpack
213 
214 #endif
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const arma::Row< size_t > &labels, const size_t numClasses, const arma::Row< double > &weights)
Given a set of labels, calculate the information gain of those labels.
The standard information gain criterion, used for calculating gain in decision trees.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the information gain given a vector of class weight counts.