constr_lpball.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_OPTIMIZERS_FW_CONSTR_LPBALL_HPP
13 #define MLPACK_CORE_OPTIMIZERS_FW_CONSTR_LPBALL_HPP
14 
#include <cmath>
#include <limits>

#include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace optimization {
19 
56 {
57  public:
64  ConstrLpBallSolver(const double p) : p(p)
65  { /* Do nothing. */ }
66 
74  ConstrLpBallSolver(const double p, const arma::vec lambda) :
75  p(p), regFlag(true), lambda(lambda)
76  { /* Do nothing. */ }
77 
78 
85  void Optimize(const arma::mat& v,
86  arma::mat& s)
87  {
88  if (p == std::numeric_limits<double>::infinity())
89  {
90  // l-inf ball.
91  s = -sign(v);
92  if (regFlag)
93  s = s / lambda; // element-wise division.
94  }
95  else if (p > 1.0)
96  {
97  // lp ball with 1<p<inf.
98  if (regFlag)
99  s = v / lambda;
100  else
101  s = v;
102 
103  double q = 1 / (1.0 - 1.0 / p);
104  s = - sign(v) % pow(abs(s), q - 1); // element-wise multiplication.
105  s = arma::normalise(s, p);
106 
107  if (regFlag)
108  s = s / lambda;
109  }
110  else if (p == 1.0)
111  {
112  // l1 ball, also used in OMP.
113  if (regFlag)
114  s = arma::abs(v / lambda);
115  else
116  s = arma::abs(v);
117 
118  arma::uword k = 0;
119  s.max(k); // k is the linear index of the largest element.
120  s.zeros();
121  s(k) = - mlpack::math::Sign(v(k));
122 
123  if (regFlag)
124  s = s / lambda;
125  }
126  else
127  {
128  Log::Fatal << "Wrong norm p!" << std::endl;
129  }
130 
131  return;
132  }
133 
135  double P() const { return p; }
137  double& P() { return p;}
138 
140  bool RegFlag() const {return regFlag;}
142  bool& RegFlag() {return regFlag;}
143 
145  arma::vec Lambda() const {return lambda;}
147  arma::vec& Lambda() {return lambda;}
148 
149  private:
152  double p;
153 
155  bool regFlag = false;
156 
158  arma::vec lambda;
159 };
160 
161 } // namespace optimization
162 } // namespace mlpack
163 
164 #endif
double & P()
Modify the p-norm.
void Optimize(const arma::mat &v, arma::mat &s)
Optimizer of the linear constrained problem for FrankWolfe.
.hpp
Definition: add_to_po.hpp:21
T Sign(const T x)
Signum function.
Definition: lin_alg.hpp:135
arma::vec & Lambda()
Modify the regularization parameter.
double P() const
Get the p-norm.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool RegFlag() const
Get regularization flag.
ConstrLpBallSolver(const double p, const arma::vec lambda)
Construct the solver of the constrained problem, with regularization parameter lambda here...
arma::vec Lambda() const
Get the regularization parameter.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
bool & RegFlag()
Modify regularization flag.
LinearConstrSolver for the FrankWolfe algorithm.
ConstrLpBallSolver(const double p)
Construct the solver of the constrained problem.