13 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_HPP 14 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_HPP 18 #include <boost/variant.hpp> 33 template<
typename LayerType>
//! LayerGradients overload that participates in overload resolution only when
//! HasGradientCheck<T, arma::mat&(T::*)()>::value is true and
//! HasModelCheck<T>::value is false (selected via std::enable_if).
//! Takes a pointer to the layer and a reference to an arma::mat, returns a
//! size_t, and is declared as a const member function.
//! NOTE(review): the meaning of the returned size_t is not visible here --
//! presumably an element count/offset; confirm against
//! gradient_update_visitor_impl.hpp.
//! NOTE(review): the `template<typename T>` header that should precede this
//! declaration is not visible in this excerpt.
45 typename std::enable_if<
46 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
47 !HasModelCheck<T>::value,
size_t>::type
48 LayerGradients(T* layer, arma::mat& input)
const;
//! LayerGradients overload that participates in overload resolution only when
//! HasGradientCheck<T, arma::mat&(T::*)()>::value is false and
//! HasModelCheck<T>::value is true (selected via std::enable_if).
//! Takes a pointer to the layer and a reference to an arma::mat, returns a
//! size_t, and is declared as a const member function.
//! NOTE(review): how the layer is processed in this case is defined in the
//! implementation file, not here -- confirm against
//! gradient_update_visitor_impl.hpp.
//! NOTE(review): the `template<typename T>` header that should precede this
//! declaration is not visible in this excerpt.
52 typename std::enable_if<
53 !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
54 HasModelCheck<T>::value,
size_t>::type
55 LayerGradients(T* layer, arma::mat& input)
const;
//! LayerGradients overload that participates in overload resolution only when
//! both HasGradientCheck<T, arma::mat&(T::*)()>::value and
//! HasModelCheck<T>::value are true (selected via std::enable_if).
//! Takes a pointer to the layer and a reference to an arma::mat, returns a
//! size_t, and is declared as a const member function.
//! NOTE(review): how the two traits are combined when processing the layer is
//! defined in the implementation file -- confirm against
//! gradient_update_visitor_impl.hpp.
//! NOTE(review): the `template<typename T>` header that should precede this
//! declaration is not visible in this excerpt.
60 typename std::enable_if<
61 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
62 HasModelCheck<T>::value,
size_t>::type
63 LayerGradients(T* layer, arma::mat& input)
const;
//! LayerGradients overload that participates in overload resolution only when
//! both HasGradientCheck<T, P&(T::*)()>::value and HasModelCheck<T>::value
//! are false (selected via std::enable_if).
//! Unlike the other overloads in this excerpt, the input type is a separate
//! template parameter P, and the HasGradientCheck signature is expressed in
//! terms of P rather than arma::mat.
//! Takes a pointer to the layer and a reference to P, returns a size_t, and
//! is declared as a const member function.
//! NOTE(review): presumably this is the fallback for layers with nothing to
//! update -- confirm against gradient_update_visitor_impl.hpp.
67 template<
typename T,
typename P>
68 typename std::enable_if<
69 !HasGradientCheck<T, P&(T::*)()>::value &&
70 !HasModelCheck<T>::value,
size_t>::type
71 LayerGradients(T* layer, P& input)
const;
78 #include "gradient_update_visitor_impl.hpp"
GradientUpdateVisitor updates the gradient parameter given the gradient set.
size_t operator()(LayerType *layer) const
Update the gradient parameter.
GradientUpdateVisitor(arma::mat &&gradient, size_t offset=0)
Update the gradient parameter given the gradient set.