lmnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LMNN_LMNN_HPP
13 #define MLPACK_METHODS_LMNN_LMNN_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 #include <ensmallen.hpp>
18 
19 #include "lmnn_function.hpp"
20 
21 namespace mlpack {
22 namespace lmnn {
23 
53 template<typename MetricType = metric::SquaredEuclideanDistance,
54  typename OptimizerType = ens::AMSGrad>
55 class LMNN
56 {
57  public:
68  LMNN(const arma::mat& dataset,
69  const arma::Row<size_t>& labels,
70  const size_t k,
71  const MetricType metric = MetricType());
72 
73 
82  void LearnDistance(arma::mat& outputMatrix);
83 
84 
86  const arma::mat& Dataset() const { return dataset; }
87 
89  const arma::Row<size_t>& Labels() const { return labels; }
90 
92  const double& Regularization() const { return regularization; }
94  double& Regularization() { return regularization; }
95 
97  const size_t& Range() const { return range; }
99  size_t& Range() { return range; }
100 
102  const size_t& K() const { return k; }
104  size_t K() { return k; }
105 
107  const OptimizerType& Optimizer() const { return optimizer; }
108  OptimizerType& Optimizer() { return optimizer; }
109 
110  private:
112  const arma::mat& dataset;
113 
115  const arma::Row<size_t>& labels;
116 
118  size_t k;
119 
121  double regularization;
122 
124  size_t range;
125 
127  MetricType metric;
128 
130  OptimizerType optimizer;
131 }; // class LMNN
132 
133 } // namespace lmnn
134 } // namespace mlpack
135 
136 // Include the implementation.
137 #include "lmnn_impl.hpp"
138 
139 #endif
const size_t & Range() const
Access the range value.
Definition: lmnn.hpp:97
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t K()
Modify the value of k.
Definition: lmnn.hpp:104
void LearnDistance(arma::mat &outputMatrix)
Perform Large Margin Nearest Neighbors metric learning.
size_t & Range()
Modify the range value.
Definition: lmnn.hpp:99
An implementation of the Large Margin Nearest Neighbors (LMNN) metric learning technique.
Definition: lmnn.hpp:55
LMetric< 2, false > SquaredEuclideanDistance
The squared Euclidean (L2) distance.
Definition: lmetric.hpp:107
OptimizerType & Optimizer()
Modify the optimizer.
Definition: lmnn.hpp:108
const OptimizerType & Optimizer() const
Get the optimizer.
Definition: lmnn.hpp:107
const double & Regularization() const
Access the regularization value.
Definition: lmnn.hpp:92
const arma::Row< size_t > & Labels() const
Get the labels reference.
Definition: lmnn.hpp:89
const size_t & K() const
Access the value of k.
Definition: lmnn.hpp:102
const arma::mat & Dataset() const
Get the dataset reference.
Definition: lmnn.hpp:86
LMNN(const arma::mat &dataset, const arma::Row< size_t > &labels, const size_t k, const MetricType metric=MetricType())
Initialize the LMNN object, passing a dataset (distance metric is learned using this dataset) and labels.
double & Regularization()
Modify the regularization value.
Definition: lmnn.hpp:94