orthogonal_init.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
13 #define MLPACK_METHODS_ANN_INIT_RULES_ORTHOGONAL_INIT_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
25 {
26  public:
32  OrthogonalInitialization(const double gain = 1.0) : gain(gain) { }
33 
42  template<typename eT>
43  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
44  {
45  arma::Mat<eT> V;
46  arma::Col<eT> s;
47 
48  arma::svd_econ(W, s, V, arma::randu<arma::Mat<eT> >(rows, cols));
49  W *= gain;
50  }
51 
61  template<typename eT>
62  void Initialize(arma::Cube<eT>& W,
63  const size_t rows,
64  const size_t cols,
65  const size_t slices)
66  {
67  W = arma::Cube<eT>(rows, cols, slices);
68 
69  for (size_t i = 0; i < slices; i++)
70  Initialize(W.slice(i), rows, cols);
71  }
72 
73  private:
75  double gain;
76 }; // class OrthogonalInitialization
77 
78 
79 } // namespace ann
80 } // namespace mlpack
81 
82 #endif
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified third-order weight tensor with the orthogonal matrix initialization method.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the specified weight matrix with the orthogonal matrix initialization method.
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class is used to initialize a weight matrix with the orthogonal matrix initialization method.
OrthogonalInitialization(const double gain=1.0)
Initialize the orthogonal matrix initialization rule with the given gain.