backward_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
15 
18 
19 #include <boost/variant.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
28 class BackwardVisitor : public boost::static_visitor<void>
29 {
30  public:
33  BackwardVisitor(arma::mat&& input, arma::mat&& error, arma::mat&& delta);
34 
36  BackwardVisitor(arma::mat&& input, arma::mat&& error, arma::mat&& delta,
37  const size_t index);
38 
40  template<typename LayerType>
41  void operator()(LayerType* layer) const;
42 
43  void operator()(MoreTypes layer) const;
44 
45  private:
47  arma::mat&& input;
48 
50  arma::mat&& error;
51 
53  arma::mat&& delta;
54 
56  size_t index;
57 
59  bool hasIndex;
60 
63  template<typename T>
64  typename std::enable_if<
65  !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
66  LayerBackward(T* layer, arma::mat& input) const;
67 
69  template<typename T>
70  typename std::enable_if<
71  HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
72  LayerBackward(T* layer, arma::mat& input) const;
73 };
74 
75 } // namespace ann
76 } // namespace mlpack
77 
78 // Include implementation.
79 #include "backward_visitor_impl.hpp"
80 
81 #endif
boost::variant< Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *> MoreTypes
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
strip_type.hpp
Definition: add_to_po.hpp:21
BackwardVisitor(arma::mat &&input, arma::mat &&error, arma::mat &&delta)
Execute the Backward() function given the input, error and delta parameter.
void operator()(LayerType *layer) const
Execute the Backward() function.