13 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP 14 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP 19 #include <boost/variant.hpp> 37 const arma::mat& delta,
41 template<
typename LayerType>
48 const arma::mat& input;
51 const arma::mat& delta;
62 typename std::enable_if<
63 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
64 !HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
65 LayerGradients(T* layer, arma::mat& input)
const;
70 typename std::enable_if<
71 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72 HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
73 LayerGradients(T* layer, arma::mat& input)
const;
77 template<
typename T,
typename P>
78 typename std::enable_if<
79 !HasGradientCheck<T, P&(T::*)()>::value,
void>::type
80 LayerGradients(T* layer, P& input)
const;
87 #include "gradient_visitor_impl.hpp"
Linear algebra utility functions, generally performed on matrices or vectors.
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameters.
boost::variant< Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *> MoreTypes
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
void operator()(LayerType *layer) const
Executes the Gradient() method.