bernoulli_distribution.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
13 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "../activation_functions/logistic_function.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
// Multiple independent Bernoulli distributions, parameterised on the matrix
// type used to hold the probabilities and logits (arma::mat by default).
// NOTE(review): listing line 34 (the class-head) is absent from this extracted
// listing; the closing comment at the end of the block names the class
// BernoulliDistribution.
33 template <typename DataType = arma::mat>
35 {
36  public:
// NOTE(review): the member index for this file lists a default constructor,
// BernoulliDistribution(), which creates a distribution with zero dimension.
// Its declaration text is not visible in this listing (presumably it sat on
// listing line 42) -- confirm against the original header.
42 
// Construct a distribution from the given parameter matrix.
// The semantics of applyLogistic and eps are defined in
// bernoulli_distribution_impl.hpp, which is not visible here; presumably
// param is passed through the logistic function when applyLogistic is true
// and eps is a small numerical-stability constant -- TODO confirm.
63  BernoulliDistribution(const DataType& param,
64  const bool applyLogistic = true,
65  const double eps = 1e-10);
66 
// Return the probabilities of the given matrix of observations.
// Computed as the exponential of LogProbability(observation).
72  double Probability(const DataType& observation) const
73  {
74  return std::exp(LogProbability(observation));
75  }
76 
// Return the log probabilities of the given matrix of observations.
// Defined in the implementation file.
82  double LogProbability(const DataType& observation) const;
83 
// Store the gradient of the log probabilities of the observations in the
// output matrix. Defined in the implementation file.
91  void LogProbBackward(const DataType& observation, DataType& output) const;
92 
// Return a matrix of randomly generated samples drawn according to the
// probability distributions. Defined in the implementation file.
99  DataType Sample() const;
100 
// Return the probability matrix (read-only reference).
102  const DataType& Probability() const { return probability; }
103 
// Return a modifiable reference to the probability matrix.
105  DataType& Probability() { return probability; }
106 
// Return the logits matrix (read-only reference).
108  const DataType& Logits() const { return logits; }
109 
// Return a modifiable reference to the logits matrix.
111  DataType& Logits() { return logits; }
112 
// Serialize the distribution: all four data members are archived under
// their own names via BOOST_SERIALIZATION_NVP. The version argument is unused.
116  template<typename Archive>
117  void serialize(Archive& ar, const unsigned int /* version */)
118  {
119  // We just need to serialize each of the members.
120  ar & BOOST_SERIALIZATION_NVP(probability);
121  ar & BOOST_SERIALIZATION_NVP(logits);
122  ar & BOOST_SERIALIZATION_NVP(applyLogistic);
123  ar & BOOST_SERIALIZATION_NVP(eps);
124  }
125 
126  private:
// Probability matrix exposed through Probability().
128  DataType probability;
129 
// Logits matrix exposed through Logits().
132  DataType logits;
133 
// Presumably mirrors the constructor's applyLogistic argument -- it is set
// in the implementation file, which is not visible here; confirm.
135  bool applyLogistic;
136 
// Presumably mirrors the constructor's eps argument -- it is set in the
// implementation file, which is not visible here; confirm.
138  double eps;
139 }; // class BernoulliDistribution
140 
141 } // namespace ann
142 } // namespace mlpack
143 
144 // Include implementation.
145 #include "bernoulli_distribution_impl.hpp"
146 
147 #endif
const DataType & Probability() const
Return the probability matrix.
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
const DataType & Logits() const
Return the logits matrix.
Multiple independent Bernoulli distributions.
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.
double Probability(const DataType &observation) const
Return the probabilities of the given matrix of observations.
DataType & Logits()
Return a modifiable reference to the logits (pre-probability) matrix.
DataType & Probability()
Return a modifiable reference to the probability matrix.
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
DataType Sample() const
Return a matrix of randomly generated samples according to the probability distributions defined by this object.