rbm.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP
12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP
13 
14 #include <mlpack/core.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
33 template<
34  typename InitializationRuleType,
35  typename DataType = arma::mat,
36  typename PolicyType = BinaryRBM
37 >
38 class RBM
39 {
40  public:
42  typedef typename DataType::elem_type ElemType;
43 
60  RBM(arma::Mat<ElemType> predictors,
61  InitializationRuleType initializeRule,
62  const size_t visibleSize,
63  const size_t hiddenSize,
64  const size_t batchSize = 1,
65  const size_t numSteps = 1,
66  const size_t negSteps = 1,
67  const size_t poolSize = 2,
68  const ElemType slabPenalty = 8,
69  const ElemType radius = 1,
70  const bool persistence = false);
71 
72  // Reset the network.
73  template<typename Policy = PolicyType>
74  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
75  Reset();
76 
77  // Reset the network.
78  template<typename Policy = PolicyType>
79  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
80  Reset();
81 
92  template<typename OptimizerType>
93  double Train(OptimizerType& optimizer);
94 
103  double Evaluate(const arma::Mat<ElemType>& parameters,
104  const size_t i,
105  const size_t batchSize);
106 
114  template<typename Policy = PolicyType>
115  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, double>::type
116  FreeEnergy(arma::Mat<ElemType>&& input);
117 
128  template<typename Policy = PolicyType>
129  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
130  double>::type
131  FreeEnergy(arma::Mat<ElemType>&& input);
132 
139  template<typename Policy = PolicyType>
140  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
141  Phase(DataType&& input, DataType&& gradient);
142 
149  template<typename Policy = PolicyType>
150  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
151  Phase(DataType&& input, DataType&& gradient);
152 
160  template<typename Policy = PolicyType>
161  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
162  SampleHidden(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
163 
174  template<typename Policy = PolicyType>
175  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
176  SampleHidden(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
177 
185  template<typename Policy = PolicyType>
186  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
187  SampleVisible(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
188 
199  template<typename Policy = PolicyType>
200  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
201  SampleVisible(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
202 
209  template<typename Policy = PolicyType>
210  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
211  VisibleMean(DataType&& input, DataType&& output);
212 
221  template<typename Policy = PolicyType>
222  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
223  VisibleMean(DataType&& input, DataType&& output);
224 
231  template<typename Policy = PolicyType>
232  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
233  HiddenMean(DataType&& input, DataType&& output);
234 
245  template<typename Policy = PolicyType>
246  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
247  HiddenMean(DataType&& input, DataType&& output);
248 
257  template<typename Policy = PolicyType>
258  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
259  SpikeMean(DataType&& visible, DataType&& spikeMean);
260 
266  template<typename Policy = PolicyType>
267  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
268  SampleSpike(DataType&& spikeMean, DataType&& spike);
269 
279  template<typename Policy = PolicyType>
280  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
281  SlabMean(DataType&& visible, DataType&& spike, DataType&& slabMean);
282 
293  template<typename Policy = PolicyType>
294  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
295  SampleSlab(DataType&& slabMean, DataType&& slab);
296 
304  void Gibbs(arma::Mat<ElemType>&& input,
305  arma::Mat<ElemType>&& output,
306  const size_t steps = SIZE_MAX);
307 
316  void Gradient(const arma::Mat<ElemType>& parameters,
317  const size_t i,
318  arma::Mat<ElemType>& gradient,
319  const size_t batchSize);
320 
325  void Shuffle();
326 
328  size_t NumFunctions() const { return numFunctions; }
329 
331  size_t NumSteps() const { return numSteps; }
332 
334  const arma::Mat<ElemType>& Parameters() const { return parameter; }
336  arma::Mat<ElemType>& Parameters() { return parameter; }
337 
339  arma::Cube<ElemType> const& Weight() const { return weight; }
341  arma::Cube<ElemType>& Weight() { return weight; }
342 
344  DataType const& VisibleBias() const { return visibleBias; }
346  DataType& VisibleBias() { return visibleBias; }
347 
349  DataType const& HiddenBias() const { return hiddenBias; }
351  DataType& HiddenBias() { return hiddenBias; }
352 
354  DataType const& SpikeBias() const { return spikeBias; }
356  DataType& SpikeBias() { return spikeBias; }
357 
359  ElemType const& SlabPenalty() const { return 1.0 / slabPenalty; }
360 
362  DataType const& VisiblePenalty() const { return visiblePenalty; }
364  DataType& VisiblePenalty() { return visiblePenalty; }
365 
367  size_t const& VisibleSize() const { return visibleSize; }
369  size_t const& HiddenSize() const { return hiddenSize; }
371  size_t const& PoolSize() const { return poolSize; }
372 
374  template<typename Archive>
375  void serialize(Archive& ar, const unsigned int /* version */);
376 
377  private:
379  arma::Mat<ElemType> parameter;
381  arma::Mat<ElemType> predictors;
382  // Initializer for initializing the weights of the network.
383  InitializationRuleType initializeRule;
385  arma::Mat<ElemType> state;
387  size_t numFunctions;
389  size_t visibleSize;
391  size_t hiddenSize;
393  size_t batchSize;
395  size_t numSteps;
397  size_t negSteps;
399  size_t poolSize;
401  size_t steps;
403  arma::Cube<ElemType> weight;
405  DataType visibleBias;
407  DataType hiddenBias;
409  DataType preActivation;
411  DataType spikeBias;
413  DataType visiblePenalty;
415  DataType visibleMean;
417  DataType spikeMean;
419  DataType spikeSamples;
421  DataType slabMean;
423  ElemType slabPenalty;
425  ElemType radius;
427  arma::Mat<ElemType> hiddenReconstruction;
429  arma::Mat<ElemType> visibleReconstruction;
431  arma::Mat<ElemType> negativeSamples;
433  arma::Mat<ElemType> negativeGradient;
435  arma::Mat<ElemType> tempNegativeGradient;
437  arma::Mat<ElemType> positiveGradient;
439  arma::Mat<ElemType> gibbsTemporary;
441  bool persistence;
443  bool reset;
444 };
445 
446 } // namespace ann
447 } // namespace mlpack
448 
449 #include "rbm_impl.hpp"
450 #include "spike_slab_rbm_impl.hpp"
451 
452 #endif
DataType const & HiddenBias() const
Return the hidden bias of the network.
Definition: rbm.hpp:349
void Shuffle()
Shuffle the order of function visitation.
DataType & VisibleBias()
Modify the visible bias of the network.
Definition: rbm.hpp:346
void Gibbs(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
.hpp
Definition: add_to_po.hpp:21
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(DataType &&spikeMean, DataType &&spike)
The function samples the spike function using Bernoulli distribution.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(DataType &&input, DataType &&output)
The function calculates the mean for the visible layer.
size_t const & VisibleSize() const
Get the visible size.
Definition: rbm.hpp:367
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
Definition: rbm.hpp:356
DataType::elem_type ElemType
Definition: rbm.hpp:42
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
DataType & HiddenBias()
Modify the hidden bias of the network.
Definition: rbm.hpp:351
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
Definition: rbm.hpp:341
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(DataType &&slabMean, DataType &&slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: $h_i \alpha_i^{-1} W_i^T v$.
The implementation of the RBM module.
Definition: rbm.hpp:38
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
Definition: rbm.hpp:354
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(DataType &&input, DataType &&gradient)
Calculates the gradient of the RBM network on the provided input.
DataType const & VisibleBias() const
Return the visible bias of the network.
Definition: rbm.hpp:344
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
Definition: rbm.hpp:364
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
Definition: rbm.hpp:336
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output)
This function samples the hidden layer given the visible layer using Bernoulli function.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
Definition: rbm.hpp:334
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables (returned as the reciprocal of the stored penalty, 1 / slabPenalty).
Definition: rbm.hpp:359
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
Definition: rbm.hpp:339
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(DataType &&input, DataType &&output)
The function calculates the mean for the hidden layer.
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
Definition: rbm.hpp:362
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(arma::Mat< ElemType > &&input)
This function calculates the free energy of the BinaryRBM.
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
Definition: rbm.hpp:331
double Train(OptimizerType &optimizer)
Train the RBM on the given input data.
size_t const & HiddenSize() const
Get the hidden size.
Definition: rbm.hpp:369
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(DataType &&visible, DataType &&spikeMean)
The function calculates the mean of the distribution P(h|v), where the mean is given by: $\mathrm{sigm}(0.5 \, v^T W_i \alpha_i^{-1} W_i^T v + b_i)$.
size_t const & PoolSize() const
Get the pool size.
Definition: rbm.hpp:371
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output)
This function samples the visible layer given the hidden layer using Bernoulli function.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rbm.hpp:328
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(DataType &&visible, DataType &&spike, DataType &&slabMean)
The function calculates the mean of the Normal distribution of P(s|v, h), where the mean is given by: $h_i \alpha_i^{-1} W_i^T v$.