[mlpack-svn] r17387 - mlpack/trunk/src/mlpack/methods/amf/termination_policies

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 19 12:08:48 EST 2014


Author: rcurtin
Date: Wed Nov 19 12:08:48 2014
New Revision: 17387

Log:
Refactor for cleaner code and avoid storing WH explicitly if possible.


Modified:
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	Wed Nov 19 12:08:48 2014
@@ -13,21 +13,30 @@
 namespace amf {
 
 /**
- * This class implements simple residue based termination policy. Termination 
- * decision depends on two factors, value of residue and number of iteration. 
- * If the current value of residue drops below the threshold or the number of 
- * iterations goes above the threshold, positive termination signal is passed 
- * to AMF.
+ * This class implements a simple residue-based termination policy. The
+ * termination decision depends on two factors: the value of the residue (the
+ * difference between the norm of WH this iteration and the previous iteration),
+ * and the number of iterations.  If the current value of residue drops below
+ * the threshold or the number of iterations goes above the iteration limit,
+ * IsConverged() will return true.  This class is meant for use with the AMF
+ * (alternating matrix factorization) class.
  *
  * @see AMF
  */
 class SimpleResidueTermination
 {
  public:
-  //! empty constructor
+  /**
+   * Construct the SimpleResidueTermination object with the given minimum
+   * residue (or the default) and the given maximum number of iterations (or the
+   * default).  0 indicates no iteration limit.
+   *
+   * @param minResidue Minimum residue for termination.
+   * @param maxIterations Maximum number of iterations.
+   */
   SimpleResidueTermination(const double minResidue = 1e-10,
                            const size_t maxIterations = 10000)
-        : minResidue(minResidue), maxIterations(maxIterations) { }
+      : minResidue(minResidue), maxIterations(maxIterations) { }
 
   /**
    * Initializes the termination policy before stating the factorization.
@@ -37,57 +46,46 @@
   template<typename MatType>
   void Initialize(const MatType& V)
   {
-    // set resisue to minimum value
-    residue = minResidue;
-    // set iteration to minimum value
+    // Initialize the things we keep track of.
+    residue = DBL_MAX;
     iteration = 1;
-    // remove history
+    nm = V.n_rows * V.n_cols;
+    // Remove history.
     normOld = 0;
-
-    // initialize required variables
-    const size_t n = V.n_rows;
-    const size_t m = V.n_cols;
-    nm = n * m;
   }
 
   /**
-   * Check if termination criterio is met.
+   * Check if termination criterion is met.
    *
    * @param W Basis matrix of output.
    * @param H Encoding matrix of output.
    */
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    // Calculate norm of WH after each iteration.
-    arma::mat WH;
-
-    // calculate the norm and compute the residue 
-    WH = W * H;
-    double norm = sqrt(accu(WH % WH) / nm);
-    residue = fabs(normOld - norm);
-    residue /= normOld;
+    // Calculate the norm and compute the residue
+    const double norm = arma::norm(W * H, "fro");
+    residue = fabs(normOld - norm) / normOld;
 
-    // store the residue into history
+    // Store the norm.
     normOld = norm;
-    
-    // increment iteration count
+
+    // Increment iteration count
     iteration++;
-    
-    // check if termination criterion is met
-    if(residue < minResidue || iteration > maxIterations) return true;
-    else return false;
+
+    // Check if termination criterion is met.
+    return (residue < minResidue || iteration > maxIterations);
   }
 
   //! Get current value of residue
   const double& Index() const { return residue; }
 
-  //! Get current iteration count  
+  //! Get current iteration count
   const size_t& Iteration() const { return iteration; }
-  
+
   //! Access max iteration count
   const size_t& MaxIterations() const { return maxIterations; }
   size_t& MaxIterations() { return maxIterations; }
-  
+
   //! Access minimum residue value
   const double& MinResidue() const { return minResidue; }
   double& MinResidue() { return minResidue; }



More information about the mlpack-svn mailing list