information_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
26 {
27  public:
38  template<bool UseWeights>
39  static double Evaluate(const arma::Row<size_t>& labels,
40  const size_t numClasses,
41  const arma::Row<double>& weights)
42  {
43  // Edge case: if there are no elements, the gain is zero.
44  if (labels.n_elem == 0)
45  return 0.0;
46 
47  // Calculate the information gain.
48  double gain = 0.0;
49 
50  // Count the number of elements in each class. Use four auxiliary vectors
51  // to exploit SIMD instructions if possible.
52  arma::vec countSpace(4 * numClasses, arma::fill::zeros);
53  arma::vec counts(countSpace.memptr(), numClasses, false, true);
54  arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
55  true);
56  arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
57  true);
58  arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
59  true);
60 
61  if (UseWeights)
62  {
63  // Sum all the weights up.
64  double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
65 
66  // SIMD loop: add counts for four elements simultaneously (if the compiler
67  // manages to vectorize the loop).
68  for (size_t i = 3; i < labels.n_elem; i += 4)
69  {
70  const double weight1 = weights[i - 3];
71  const double weight2 = weights[i - 2];
72  const double weight3 = weights[i - 1];
73  const double weight4 = weights[i];
74 
75  counts[labels[i - 3]] += weight1;
76  counts2[labels[i - 2]] += weight2;
77  counts3[labels[i - 1]] += weight3;
78  counts4[labels[i]] += weight4;
79 
80  accWeights[0] += weight1;
81  accWeights[1] += weight2;
82  accWeights[2] += weight3;
83  accWeights[3] += weight4;
84  }
85 
86  // Handle leftovers.
87  if (labels.n_elem % 4 == 1)
88  {
89  const double weight1 = weights[labels.n_elem - 1];
90  counts[labels[labels.n_elem - 1]] += weight1;
91  accWeights[0] += weight1;
92  }
93  else if (labels.n_elem % 4 == 2)
94  {
95  const double weight1 = weights[labels.n_elem - 2];
96  const double weight2 = weights[labels.n_elem - 1];
97 
98  counts[labels[labels.n_elem - 2]] += weight1;
99  counts2[labels[labels.n_elem - 1]] += weight2;
100 
101  accWeights[0] += weight1;
102  accWeights[1] += weight2;
103  }
104  else if (labels.n_elem % 4 == 3)
105  {
106  const double weight1 = weights[labels.n_elem - 3];
107  const double weight2 = weights[labels.n_elem - 2];
108  const double weight3 = weights[labels.n_elem - 1];
109 
110  counts[labels[labels.n_elem - 3]] += weight1;
111  counts2[labels[labels.n_elem - 2]] += weight2;
112  counts3[labels[labels.n_elem - 1]] += weight3;
113 
114  accWeights[0] += weight1;
115  accWeights[1] += weight2;
116  accWeights[2] += weight3;
117  }
118 
119  accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
120  counts += counts2 + counts3 + counts4;
121 
122  // Corner case: return 0 if no weight.
123  if (accWeights[0] == 0.0)
124  return 0.0;
125 
126  for (size_t i = 0; i < numClasses; ++i)
127  {
128  const double f = ((double) counts[i] / (double) accWeights[0]);
129  if (f > 0.0)
130  gain += f * std::log2(f);
131  }
132  }
133  else
134  {
135  // SIMD loop: add counts for four elements simultaneously (if the compiler
136  // manages to vectorize the loop).
137  for (size_t i = 3; i < labels.n_elem; i += 4)
138  {
139  counts[labels[i - 3]]++;
140  counts2[labels[i - 2]]++;
141  counts3[labels[i - 1]]++;
142  counts4[labels[i]]++;
143  }
144 
145  // Handle leftovers.
146  if (labels.n_elem % 4 == 1)
147  {
148  counts[labels[labels.n_elem - 1]]++;
149  }
150  else if (labels.n_elem % 4 == 2)
151  {
152  counts[labels[labels.n_elem - 2]]++;
153  counts2[labels[labels.n_elem - 1]]++;
154  }
155  else if (labels.n_elem % 4 == 3)
156  {
157  counts[labels[labels.n_elem - 3]]++;
158  counts2[labels[labels.n_elem - 2]]++;
159  counts3[labels[labels.n_elem - 1]]++;
160  }
161 
162  counts += counts2 + counts3 + counts4;
163 
164  for (size_t i = 0; i < numClasses; ++i)
165  {
166  const double f = ((double) counts[i] / (double) labels.n_elem);
167  if (f > 0.0)
168  gain += f * std::log2(f);
169  }
170  }
171 
172  return gain;
173  }
174 
/**
 * Width of the interval of values that Evaluate() can return for a problem
 * with the given number of classes.  A pure set scores 0, while a uniform
 * distribution scores n * ((1/n) * log2(1/n)) = -log2(n); the distance between
 * these extremes is log2(numClasses).
 *
 * @param numClasses Number of classes in the problem.
 * @return log2(numClasses) (note: -inf when numClasses is 0).
 */
static double Range(const size_t numClasses)
{
  // std::log2 on an integer argument converts it to double; do so explicitly.
  const double classCount = static_cast<double>(numClasses);
  return std::log2(classCount);
}
189 };
190 
191 } // namespace tree
192 } // namespace mlpack
193 
194 #endif
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const arma::Row< size_t > &labels, const size_t numClasses, const arma::Row< double > &weights)
Given a set of labels, calculate the information gain of those labels.
The standard information gain criterion, used for calculating gain in decision trees.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.