gini_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
15 
16 #include <mlpack/core.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
27 class GiniGain
28 {
29  public:
43  template<bool UseWeights, typename RowType, typename WeightVecType>
44  static double Evaluate(const RowType& labels,
45  const size_t numClasses,
46  const WeightVecType& weights)
47  {
48  // Corner case: if there are no elements, the impurity is zero.
49  if (labels.n_elem == 0)
50  return 0.0;
51 
52  // Count the number of elements in each class. Use four auxiliary vectors
53  // to exploit SIMD instructions if possible.
54  arma::vec countSpace(4 * numClasses, arma::fill::zeros);
55  arma::vec counts(countSpace.memptr(), numClasses, false, true);
56  arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
57  true);
58  arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
59  true);
60  arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
61  true);
62 
63  // Calculate the Gini impurity of the un-split node.
64  double impurity = 0.0;
65 
66  if (UseWeights)
67  {
68  // Sum all the weights up.
69  double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
70 
71  // SIMD loop: add counts for four elements simultaneously (if the compiler
72  // manages to vectorize the loop).
73  for (size_t i = 3; i < labels.n_elem; i += 4)
74  {
75  const double weight1 = weights[i - 3];
76  const double weight2 = weights[i - 2];
77  const double weight3 = weights[i - 1];
78  const double weight4 = weights[i];
79 
80  counts[labels[i - 3]] += weight1;
81  counts2[labels[i - 2]] += weight2;
82  counts3[labels[i - 1]] += weight3;
83  counts4[labels[i]] += weight4;
84 
85  accWeights[0] += weight1;
86  accWeights[1] += weight2;
87  accWeights[2] += weight3;
88  accWeights[3] += weight4;
89  }
90 
91  // Handle leftovers.
92  if (labels.n_elem % 4 == 1)
93  {
94  const double weight1 = weights[labels.n_elem - 1];
95  counts[labels[labels.n_elem - 1]] += weight1;
96  accWeights[0] += weight1;
97  }
98  else if (labels.n_elem % 4 == 2)
99  {
100  const double weight1 = weights[labels.n_elem - 2];
101  const double weight2 = weights[labels.n_elem - 1];
102 
103  counts[labels[labels.n_elem - 2]] += weight1;
104  counts2[labels[labels.n_elem - 1]] += weight2;
105 
106  accWeights[0] += weight1;
107  accWeights[1] += weight2;
108  }
109  else if (labels.n_elem % 4 == 3)
110  {
111  const double weight1 = weights[labels.n_elem - 3];
112  const double weight2 = weights[labels.n_elem - 2];
113  const double weight3 = weights[labels.n_elem - 1];
114 
115  counts[labels[labels.n_elem - 3]] += weight1;
116  counts2[labels[labels.n_elem - 2]] += weight2;
117  counts3[labels[labels.n_elem - 1]] += weight3;
118 
119  accWeights[0] += weight1;
120  accWeights[1] += weight2;
121  accWeights[2] += weight3;
122  }
123 
124  accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
125  counts += counts2 + counts3 + counts4;
126 
127  // Catch edge case: if there are no weights, the impurity is zero.
128  if (accWeights[0] == 0.0)
129  return 0.0;
130 
131  for (size_t i = 0; i < numClasses; ++i)
132  {
133  const double f = ((double) counts[i] / (double) accWeights[0]);
134  impurity += f * (1.0 - f);
135  }
136  }
137  else
138  {
139  // SIMD loop: add counts for four elements simultaneously (if the compiler
140  // manages to vectorize the loop).
141  for (size_t i = 3; i < labels.n_elem; i += 4)
142  {
143  counts[labels[i - 3]]++;
144  counts2[labels[i - 2]]++;
145  counts3[labels[i - 1]]++;
146  counts4[labels[i]]++;
147  }
148 
149  // Handle leftovers.
150  if (labels.n_elem % 4 == 1)
151  {
152  counts[labels[labels.n_elem - 1]]++;
153  }
154  else if (labels.n_elem % 4 == 2)
155  {
156  counts[labels[labels.n_elem - 2]]++;
157  counts2[labels[labels.n_elem - 1]]++;
158  }
159  else if (labels.n_elem % 4 == 3)
160  {
161  counts[labels[labels.n_elem - 3]]++;
162  counts2[labels[labels.n_elem - 2]]++;
163  counts3[labels[labels.n_elem - 1]]++;
164  }
165 
166  counts += counts2 + counts3 + counts4;
167 
168  for (size_t i = 0; i < numClasses; ++i)
169  {
170  const double f = ((double) counts[i] / (double) labels.n_elem);
171  impurity += f * (1.0 - f);
172  }
173  }
174 
175  return -impurity;
176  }
177 
185  static double Range(const size_t numClasses)
186  {
187  // The best possible case is that only one class exists, which gives a Gini
188  // impurity of 0. The worst possible case is that the classes are evenly
189  // distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
190  return 1.0 - (1.0 / double(numClasses));
191  }
192 };
193 
194 } // namespace tree
195 } // namespace mlpack
196 
197 #endif
.hpp
Definition: add_to_po.hpp:21
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
Definition: gini_gain.hpp:44
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Definition: gini_gain.hpp:185