atoms.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_CORE_OPTIMIZERS_FW_ATOMS_HPP
12 #define MLPACK_CORE_OPTIMIZERS_FW_ATOMS_HPP
13 
14 #include <mlpack/prereqs.hpp>
16 #include "func_sq.hpp"
17 
18 namespace mlpack {
19 namespace optimization {
20 
25 class Atoms
26 {
27  public:
28  Atoms(){ /* Nothing to do. */ }
29 
36  void AddAtom(const arma::vec& v, FuncSq& function, const double c = 0)
37  {
38  if (currentAtoms.is_empty())
39  {
40  CurrentAtoms() = v;
41  CurrentCoeffs().set_size(1);
42  CurrentCoeffs().fill(c);
43  atomSqTerm.set_size(1);
44  atomSqTerm(0) = std::pow(norm(function.MatrixA() * v, 2), 2);
45  }
46  else
47  {
48  currentAtoms.insert_cols(0, v);
49  arma::vec cVec(1);
50  cVec(0) = c;
51  currentCoeffs.insert_rows(0, cVec);
52  double tmp = std::pow(norm(function.MatrixA() * v, 2), 2);
53  arma::vec tmpVec(1);
54  tmpVec(0) = tmp;
55  atomSqTerm.insert_rows(0, tmpVec);
56  }
57  }
58 
59 
61  void RecoverVector(arma::mat& x)
62  {
63  x = currentAtoms * currentCoeffs;
64  }
65 
85  void PruneSupport(const double F, FuncSq& function)
86  {
87  arma::vec sqTerm = 0.5 * atomSqTerm % square(currentCoeffs);
88 
89  while (currentAtoms.n_cols > 1)
90  {
91  // Solve for current gradient.
92  arma::mat x;
93  RecoverVector(x);
94  arma::mat gradient(size(x));
95  function.Gradient(x, gradient);
96 
97  // Find possible atom to be deleted.
98  arma::vec gap = sqTerm -
99  currentCoeffs % trans(gradient.t() * currentAtoms);
100  arma::uword ind;
101  gap.min(ind);
102 
103  // Try deleting the atom.
104  arma::mat newAtoms(currentAtoms.n_rows, currentAtoms.n_cols - 1);
105  if (ind > 0)
106  newAtoms.cols(0, ind - 1) = currentAtoms.cols(0, ind - 1);
107  if (ind < (currentAtoms.n_cols - 1))
108  {
109  newAtoms.cols(ind, newAtoms.n_cols - 1) =
110  currentAtoms.cols(ind + 1, currentAtoms.n_cols - 1);
111  }
112 
113  // Reoptimize the coefficients, we brute-forcely reoptimize in the span,
114  // which would be used in UpdateSpan class. Alternatively, if you want to
115  // add an atom norm constraint, you could use projected gradient method,
116  // see the implementaton of ProjectedGradientEnhancement().
117  arma::vec newCoeffs =
118  solve(function.MatrixA() * newAtoms, function.Vectorb());
119 
120  // Evaluate the function again.
121  double Fnew = function.Evaluate(newAtoms * newCoeffs);
122 
123  if (Fnew > F)
124  // Should not delete the atom.
125  break;
126  else
127  {
128  // Delete the atom from current atoms.
129  currentAtoms = newAtoms;
130  currentCoeffs = newCoeffs;
131  atomSqTerm.shed_row(ind);
132  sqTerm.shed_row(ind);
133  } // else
134  } // while
135  }
136 
137 
166  double tau,
167  double stepSize,
168  size_t maxIteration = 100,
169  double tolerance = 1e-3)
170  {
171  arma::mat x;
172  RecoverVector(x);
173  double value = function.Evaluate(x);
174 
175  for (size_t iter = 1; iter<maxIteration; iter++)
176  {
177  // Update currentCoeffs with gradient descent method.
178  arma::mat g;
179  function.Gradient(x, g);
180  g = currentAtoms.t() * g;
181  currentCoeffs = currentCoeffs - stepSize * g;
182 
183  // Projection of currentCoeffs to satisfy the atom norm constraint.
184  Proximal::ProjectToL1Ball(currentCoeffs, tau);
185 
186  RecoverVector(x);
187  double valueNew = function.Evaluate(x);
188 
189  if ((value - valueNew) < tolerance)
190  break;
191 
192  value = valueNew;
193  }
194  }
195 
196 
198  const arma::vec& CurrentCoeffs() const { return currentCoeffs; }
200  arma::vec& CurrentCoeffs() { return currentCoeffs; }
201 
203  const arma::mat& CurrentAtoms() const { return currentAtoms; }
205  arma::mat& CurrentAtoms() { return currentAtoms; }
206 
207  private:
209  arma::vec currentCoeffs;
210 
212  arma::mat currentAtoms;
213 
216  arma::vec atomSqTerm;
217 }; // class Atoms
218 } // namespace optimization
219 } // namespace mlpack
220 
221 #endif
arma::vec & CurrentCoeffs()
Modify the current atom coefficients.
Definition: atoms.hpp:200
Square loss function.
Definition: func_sq.hpp:25
Class to hold the information and operations of current atoms in the solution space.
Definition: atoms.hpp:25
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
const arma::vec & CurrentCoeffs() const
Get the current atom coefficients.
Definition: atoms.hpp:198
void RecoverVector(arma::mat &x)
Recover the solution coordinate from the coefficients of current atoms.
Definition: atoms.hpp:61
void ProjectedGradientEnhancement(FuncSq &function, double tau, double stepSize, size_t maxIteration=100, double tolerance=1e-3)
Enhance the solution in the convex hull of current atoms with atom norm constraint tau...
Definition: atoms.hpp:165
void PruneSupport(const double F, FuncSq &function)
Prune the support, delete previous atoms if they don't contribute much.
Definition: atoms.hpp:85
static void ProjectToL1Ball(arma::vec &v, double tau)
Project the vector onto the l1 ball with norm tau.
void AddAtom(const arma::vec &v, FuncSq &function, const double c=0)
Add atom into the solution space.
Definition: atoms.hpp:36
const arma::mat & CurrentAtoms() const
Get the current atoms.
Definition: atoms.hpp:203
arma::mat & CurrentAtoms()
Modify the current atoms.
Definition: atoms.hpp:205