constr_structure_group.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_OPTIMIZERS_FW_CONSTR_STRUCTURE_GROUP_HPP
15 #define MLPACK_CORE_OPTIMIZERS_FW_CONSTR_STRUCTURE_GROUP_HPP
16 
#include <mlpack/prereqs.hpp>
#include "constr_lpball.hpp"

#include <cmath>
#include <limits>
#include <utility>
19 
20 namespace mlpack {
21 namespace optimization {
22 
67 template<typename GroupType>
69 {
70  public:
77  ConstrStructGroupSolver(GroupType& groupExtractor) :
78  groupExtractor(groupExtractor)
79  { /* Nothing to do */ }
80 
87  void Optimize(const arma::mat& v, arma::mat& s)
88  {
89  size_t nGroups = groupExtractor.NumGroups();
90  double dualNorm = 0;
91  size_t optimalGroup = 1;
92 
93  // Find the optimal group.
94  for (size_t i = 1; i <= nGroups; ++i)
95  {
96  arma::vec y;
97  groupExtractor.ProjectToGroup(v, i, y);
98  double newNorm = groupExtractor.DualNorm(y, i);
99 
100  // Find the group with largest dual norm.
101  if (newNorm > dualNorm)
102  {
103  optimalGroup = i;
104  dualNorm = newNorm;
105  }
106  }
107 
108  groupExtractor.OptimalFromGroup(v, optimalGroup, s);
109  }
110 
111  private:
113  GroupType& groupExtractor;
114 };
115 
122 {
123  public:
131  GroupLpBall(const double p,
132  const size_t dimOrig,
133  std::vector<arma::uvec> groupIndicesList):
134  p(p), numGroups(groupIndicesList.size()),
135  dimOrig(dimOrig),
136  groupIndicesList(groupIndicesList),
137  lpBallSolver(p)
138  {/* Nothing to do. */}
139 
147  void ProjectToGroup(const arma::mat& v, const size_t groupId, arma::vec& y)
148  {
149  arma::uvec& indList = groupIndicesList[groupId - 1];
150  size_t dim = indList.n_elem;
151  y.set_size(dim);
152 
153  for (size_t i = 0; i < dim; ++i)
154  y(i) = v(indList(i));
155  }
156 
165  void OptimalFromGroup(const arma::mat& v, const size_t groupId, arma::mat& s)
166  {
167  // Project v to group.
168  arma::vec yk;
169  ProjectToGroup(v, groupId, yk);
170 
171  // Optimize in this group.
172  arma::vec sProj(yk.n_elem);
173  lpBallSolver.Optimize(yk, sProj);
174 
175  // Recover s to the original dimension.
176  arma::uvec& indList = groupIndicesList[groupId - 1];
177  size_t dim = indList.n_elem; // dimension of the group.
178  s.zeros(dimOrig, 1);
179 
180  for (size_t i = 0; i < dim; ++i)
181  s(indList(i)) = sProj(i);
182  }
183 
185  size_t NumGroups() const {return numGroups;}
187  size_t& NumGroups() {return numGroups;}
188 
195  double DualNorm(const arma::vec& yk, const int groupId)
196  {
197  if (p == std::numeric_limits<double>::infinity())
198  {
199  // inf-norm, return 1-norm
200  return arma::norm(yk, 1);
201  }
202  else if (p > 1.0)
203  {
204  // p norm, return q-norm
205  double q = 1.0 / (1.0 - 1.0/p);
206  return arma::norm(yk, q);
207  }
208  else if (p == 1.0)
209  {
210  // 1-norm, return inf-norm
211  return arma::norm(yk, "inf");
212  }
213  else
214  {
215  Log::Fatal << "Wrong norm p!" << std::endl;
216  return 0.0;
217  }
218  }
219 
220  private:
223  double p;
224 
226  size_t numGroups;
227 
229  size_t dimOrig;
230 
232  std::vector<arma::uvec> groupIndicesList;
233 
235  ConstrLpBallSolver lpBallSolver;
236 };
237 
238 
239 } // namespace optimization
240 } // namespace mlpack
241 
242 
243 
244 #endif
void ProjectToGroup(const arma::mat &v, const size_t groupId, arma::vec &y)
Projection to specific group.
GroupLpBall(const double p, const size_t dimOrig, std::vector< arma::uvec > groupIndicesList)
Construct the lp ball group extractor class.
void OptimalFromGroup(const arma::mat &v, const size_t groupId, arma::mat &s)
Get optimal atom, which belongs to specific group.
.hpp
Definition: add_to_po.hpp:21
double DualNorm(const arma::vec &yk, const int groupId)
Compute the q-norm of yk, where 1/p + 1/q = 1 (the dual norm of the p-norm).
void Optimize(const arma::mat &v, arma::mat &s)
Optimizer for the structured-group-ball-constrained problem in Frank-Wolfe.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Linear Constrained Solver for FrankWolfe.
size_t NumGroups() const
Get the number of groups.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
size_t & NumGroups()
Modify the number of groups.
LinearConstrSolver for FrankWolfe algorithm.
Implementation of Structured Group.
ConstrStructGroupSolver(GroupType &groupExtractor)
Construct the structured-group optimization solver.