he_init.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
17 #define MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
18 
19 #include <mlpack/prereqs.hpp>
21 
22 namespace mlpack {
23 namespace ann {
24 
46 {
47  public:
52  {
53  // Nothing to do here.
54  }
55 
64  void Initialize(arma::mat& W, const size_t rows, const size_t cols)
65  {
66  // He initialization rule says to initialize weights with random
67  // values taken from a gaussian distribution with mean = 0 and
68  // standard deviation = sqrt(2/rows), i.e. variance = (2/rows).
69  const double variance = 2.0 / (double)rows;
70 
71  if (W.is_empty())
72  {
73  W.set_size(rows, cols);
74  }
75 
76  // Multipling a random variable X with variance V(X) by some factor c,
77  // then the variance V(cX) = (c^2) * V(X).
78  W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
79  }
80 
90  void Initialize(arma::cube & W,
91  const size_t rows,
92  const size_t cols,
93  const size_t slices)
94  {
95  if (W.is_empty())
96  W.set_size(rows, cols, slices);
97 
98  for (size_t i = 0; i < slices; i++)
99  Initialize(W.slice(i), rows, cols);
100  }
101 }; // class HeInitialization
102 
103 } // namespace ann
104 } // namespace mlpack
105 
106 #endif
HeInitialization()
Initialize the HeInitialization object.
Definition: he_init.hpp:51
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(arma::mat &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the He initialization rule.
Definition: he_init.hpp:64
void Initialize(arma::cube &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight 3rd order tensor with the He initialization rule...
Definition: he_init.hpp:90
This class is used to initialize the weight matrix with the He initialization rule given by He et...
Definition: he_init.hpp:45
Miscellaneous math random-related routines.