[mlpack-git] master: Refactor for cleaner code and avoid storing WH explicitly if possible. (8872fb7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:03:37 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 8872fb790bbf565574f7adb091c03fe59f69cd17
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Nov 19 17:08:48 2014 +0000

    Refactor for cleaner code and avoid storing WH explicitly if possible.


>---------------------------------------------------------------

8872fb790bbf565574f7adb091c03fe59f69cd17
 .../simple_residue_termination.hpp                 | 58 +++++++++++-----------
 1 file changed, 28 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index 66d7930..3b7f18e 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -13,21 +13,30 @@ namespace mlpack {
 namespace amf {
 
 /**
- * This class implements simple residue based termination policy. Termination 
- * decision depends on two factors, value of residue and number of iteration. 
- * If the current value of residue drops below the threshold or the number of 
- * iterations goes above the threshold, positive termination signal is passed 
- * to AMF.
+ * This class implements a simple residue-based termination policy. The
+ * termination decision depends on two factors: the value of the residue (the
+ * difference between the norm of WH this iteration and the previous iteration),
+ * and the number of iterations.  If the current value of residue drops below
+ * the threshold or the number of iterations goes above the iteration limit,
+ * IsConverged() will return true.  This class is meant for use with the AMF
+ * (alternating matrix factorization) class.
  *
  * @see AMF
  */
 class SimpleResidueTermination
 {
  public:
-  //! empty constructor
+  /**
+   * Construct the SimpleResidueTermination object with the given minimum
+   * residue (or the default) and the given maximum number of iterations (or the
+   * default).  0 indicates no iteration limit.
+   *
+   * @param minResidue Minimum residue for termination.
+   * @param maxIterations Maximum number of iterations.
+   */
   SimpleResidueTermination(const double minResidue = 1e-10,
                            const size_t maxIterations = 10000)
-        : minResidue(minResidue), maxIterations(maxIterations) { }
+      : minResidue(minResidue), maxIterations(maxIterations) { }
 
   /**
    * Initializes the termination policy before stating the factorization.
@@ -37,45 +46,34 @@ class SimpleResidueTermination
   template<typename MatType>
   void Initialize(const MatType& V)
   {
-    // set resisue to minimum value
-    residue = minResidue;
-    // set iteration to minimum value
+    // Initialize the things we keep track of.
+    residue = DBL_MAX;
     iteration = 1;
-    // remove history
+    nm = V.n_rows * V.n_cols;
+    // Remove history.
     normOld = 0;
-
-    // initialize required variables
-    const size_t n = V.n_rows;
-    const size_t m = V.n_cols;
-    nm = n * m;
   }
 
   /**
-   * Check if termination criterio is met.
+   * Check if termination criterion is met.
    *
    * @param W Basis matrix of output.
    * @param H Encoding matrix of output.
    */
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    // Calculate norm of WH after each iteration.
-    arma::mat WH;
-
-    // calculate the norm and compute the residue 
-    WH = W * H;
-    double norm = sqrt(accu(WH % WH) / nm);
-    residue = fabs(normOld - norm);
-    residue /= normOld;
+    // Calculate the norm and compute the residue
+    const double norm = arma::norm(W * H, "fro");
+    residue = fabs(normOld - norm) / normOld;
 
-    // store the residue into history
+    // Store the norm.
     normOld = norm;
 
-    // increment iteration count
+    // Increment iteration count
     iteration++;
 
-    // check if termination criterion is met
-    if(residue < minResidue || iteration > maxIterations) return true;
-    else return false;
+    // Check if termination criterion is met.
+    return (residue < minResidue || iteration > maxIterations);
   }
 
   //! Get current value of residue



More information about the mlpack-git mailing list