11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP 12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP 34 typename InitializationRuleType,
35 typename DataType = arma::mat,
36 typename PolicyType = BinaryRBM
60 RBM(arma::Mat<ElemType> predictors,
61 InitializationRuleType initializeRule,
62 const size_t visibleSize,
63 const size_t hiddenSize,
64 const size_t batchSize = 1,
65 const size_t numSteps = 1,
66 const size_t negSteps = 1,
67 const size_t poolSize = 2,
68 const ElemType slabPenalty = 8,
69 const ElemType radius = 1,
70 const bool persistence =
false);
73 template<
typename Policy = PolicyType>
74 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
78 template<
typename Policy = PolicyType>
79 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
/**
 * Train the RBM on the given input data.
 *
 * @tparam OptimizerType Type of the optimizer driving the training.
 * @param optimizer The optimizer instance to use.
 * @return A double produced by the training run.
 *     NOTE(review): presumably the final objective value; confirm in
 *     rbm_impl.hpp.
 */
template<typename OptimizerType>
double Train(OptimizerType& optimizer);
103 double Evaluate(
const arma::Mat<ElemType>& parameters,
105 const size_t batchSize);
114 template<
typename Policy = PolicyType>
115 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
double>::type
128 template<
typename Policy = PolicyType>
129 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
139 template<
typename Policy = PolicyType>
140 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
141 Phase(DataType&& input, DataType&& gradient);
149 template<
typename Policy = PolicyType>
150 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
151 Phase(DataType&& input, DataType&& gradient);
160 template<
typename Policy = PolicyType>
161 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
162 SampleHidden(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
174 template<
typename Policy = PolicyType>
175 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
176 SampleHidden(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
185 template<
typename Policy = PolicyType>
186 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
187 SampleVisible(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
199 template<
typename Policy = PolicyType>
200 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
201 SampleVisible(arma::Mat<ElemType>&& input, arma::Mat<ElemType>&& output);
209 template<
typename Policy = PolicyType>
210 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
221 template<
typename Policy = PolicyType>
222 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
231 template<
typename Policy = PolicyType>
232 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
233 HiddenMean(DataType&& input, DataType&& output);
245 template<
typename Policy = PolicyType>
246 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
247 HiddenMean(DataType&& input, DataType&& output);
257 template<
typename Policy = PolicyType>
258 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
259 SpikeMean(DataType&& visible, DataType&& spikeMean);
266 template<
typename Policy = PolicyType>
267 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
268 SampleSpike(DataType&& spikeMean, DataType&& spike);
279 template<
typename Policy = PolicyType>
280 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
281 SlabMean(DataType&& visible, DataType&& spike, DataType&& slabMean);
293 template<
typename Policy = PolicyType>
294 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
295 SampleSlab(DataType&& slabMean, DataType&& slab);
304 void Gibbs(arma::Mat<ElemType>&& input,
305 arma::Mat<ElemType>&& output,
306 const size_t steps = SIZE_MAX);
316 void Gradient(
const arma::Mat<ElemType>& parameters,
318 arma::Mat<ElemType>& gradient,
319 const size_t batchSize);
334 const arma::Mat<ElemType>&
Parameters()
const {
return parameter; }
339 arma::Cube<ElemType>
const&
Weight()
const {
return weight; }
341 arma::Cube<ElemType>&
Weight() {
return weight; }
371 size_t const&
PoolSize()
const {
return poolSize; }
/**
 * Serialize the model.
 *
 * @param ar The archive.
 */
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
379 arma::Mat<ElemType> parameter;
381 arma::Mat<ElemType> predictors;
383 InitializationRuleType initializeRule;
385 arma::Mat<ElemType> state;
403 arma::Cube<ElemType> weight;
405 DataType visibleBias;
409 DataType preActivation;
413 DataType visiblePenalty;
415 DataType visibleMean;
419 DataType spikeSamples;
423 ElemType slabPenalty;
427 arma::Mat<ElemType> hiddenReconstruction;
429 arma::Mat<ElemType> visibleReconstruction;
431 arma::Mat<ElemType> negativeSamples;
433 arma::Mat<ElemType> negativeGradient;
435 arma::Mat<ElemType> tempNegativeGradient;
437 arma::Mat<ElemType> positiveGradient;
439 arma::Mat<ElemType> gibbsTemporary;
449 #include "rbm_impl.hpp" 450 #include "spike_slab_rbm_impl.hpp" DataType const & HiddenBias() const
Return the hidden bias of the network.
void Shuffle()
Shuffle the order of function visitation.
DataType & VisibleBias()
Modify the visible bias of the network.
void Gibbs(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(DataType &&spikeMean, DataType &&spike)
The function samples the spike function using Bernoulli distribution.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(DataType &&input, DataType &&output)
The function calculates the mean for the visible layer.
size_t const & VisibleSize() const
Get the visible size.
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
DataType::elem_type ElemType
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
DataType & HiddenBias()
Modify the hidden bias of the network.
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(DataType &&slabMean, DataType &&slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: $h_i*^{-1...
The implementation of the RBM module.
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(DataType &&input, DataType &&gradient)
Calculates the gradient of the RBM network on the provided input.
DataType const & VisibleBias() const
Return the visible bias of the network.
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output)
This function samples the hidden layer given the visible layer using Bernoulli function.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(DataType &&input, DataType &&output)
The function calculates the mean for the hidden layer.
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(arma::Mat< ElemType > &&input)
This function calculates the free energy of the BinaryRBM.
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
double Train(OptimizerType &optimizer)
Train the RBM on the given input data.
size_t const & HiddenSize() const
Get the hidden size.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(DataType &&visible, DataType &&spikeMean)
The function calculates the mean of the distribution P(h|v), where mean is given by: $sigm(v^T*W_i*^{...
size_t const & PoolSize() const
Get the pool size.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &&input, arma::Mat< ElemType > &&output)
This function samples the visible layer given the hidden layer using Bernoulli function.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(DataType &&visible, DataType &&spike, DataType &&slabMean)
The function calculates the mean of Normal distribution of P(s|v, h), where the mean is given by: $h_...