[mlpack-svn] r16747 - in mlpack/trunk/src/mlpack/methods/amf: . termination_policies update_rules

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 2 17:07:43 EDT 2014


Author: sumedhghaisas
Date: Wed Jul  2 17:07:42 2014
New Revision: 16747

Log:
* Faster implementation of SVDBatchLearning (batch SVD learning with momentum).
* SimpleToleranceTermination now measures the residue as the RMSE between V
  and W * H over the nonzero entries of V.
* Test-point selection in ValidationRMSETermination is moved to the constructor.


Added:
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/amf/amf.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp

Modified: mlpack/trunk/src/mlpack/methods/amf/amf.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf.hpp	Wed Jul  2 17:07:42 2014
@@ -82,9 +82,9 @@
    */
   template<typename MatType>
   double Apply(const MatType& V,
-             const size_t r,
-             arma::mat& W,
-             arma::mat& H);
+               const size_t r,
+               arma::mat& W,
+               arma::mat& H);
 
  private:
   //! termination policy

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt	Wed Jul  2 17:07:42 2014
@@ -3,6 +3,7 @@
 set(SOURCES
   simple_residue_termination.hpp
   simple_tolerance_termination.hpp
+  validation_RMSE_termination.hpp
 )
 
 # Add directory name to sources.

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	Wed Jul  2 17:07:42 2014
@@ -18,7 +18,7 @@
         : minResidue(minResidue), maxIterations(maxIterations) { }
 
   template<typename MatType>
-  void Initialize(MatType& V)
+  void Initialize(const MatType& V)
   {
     residue = minResidue;
     iteration = 1;
@@ -36,8 +36,7 @@
     else return false;
   }
 
-  template<typename MatType>
-  void Step(const MatType& W, const MatType& H)
+  void Step(const arma::mat& W, const arma::mat& H)
   {
     // Calculate norm of WH after each iteration.
     arma::mat WH;

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp	Wed Jul  2 17:07:42 2014
@@ -10,6 +10,7 @@
 namespace mlpack {
 namespace amf {
 
+template <class MatType>
 class SimpleToleranceTermination
 {
  public:
@@ -17,44 +18,50 @@
                              const size_t maxIterations = 10000)
             : tolerance(tolerance), maxIterations(maxIterations) {}
 
-  template<typename MatType>
-  void Initialize(MatType& V)
+  void Initialize(const MatType& V)
   {
     residueOld = DBL_MAX;
     iteration = 1;
-    normOld = 0;
     residue = DBL_MIN;
 
-    const size_t n = V.n_rows;
-    const size_t m = V.n_cols;
-
-    nm = n * m;
+    this->V = &V;  // Non-owning; V must outlive all later Step() calls.
   }
 
   bool IsConverged()
   {
-    if(((residueOld - residue) / residueOld < tolerance && iteration > 4) 
+    if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
         || iteration > maxIterations) return true;
     else return false;
   }
 
-  template<typename MatType>
-  void Step(const MatType& W, const MatType& H)
+  void Step(const arma::mat& W, const arma::mat& H)
   {
-    // Calculate norm of WH after each iteration.
+    // Residue is the RMSE between V and WH over the nonzero entries of V.
     arma::mat WH;
 
     WH = W * H;
-    double norm = sqrt(accu(WH % WH) / nm);
 
-    if (iteration != 0)
+    residueOld = residue;
+    const size_t n = V->n_rows;
+    const size_t m = V->n_cols;
+    double sum = 0;
+    size_t count = 0;
+    for (size_t j = 0; j < m; j++)  // Armadillo matrices are column-major.
     {
-      residueOld = residue;
-      residue = fabs(normOld - norm);
-      residue /= normOld;
+      for (size_t i = 0; i < n; i++)
+      {
+        const double value = (*V)(i, j);
+        if (value != 0)
+        {
+          // Zero entries of V are skipped; accumulate the squared error.
+          const double diff = value - WH(i, j);
+          sum += diff * diff;
+          count++;
+        }
+      }
     }
-
-    normOld = norm;
+    // An all-zero V leaves count at 0; report 0 instead of NaN from 0 / 0.
+    residue = (count == 0) ? 0 : sqrt(sum / count);
 
     iteration++;
   }
@@ -66,15 +73,16 @@
   double tolerance;
   size_t maxIterations;
 
+  const MatType* V;
+
   size_t iteration;
   double residueOld;
   double residue;
   double normOld;
-
-  size_t nm;
 }; // class SimpleToleranceTermination
 
 }; // namespace amf
 }; // namespace mlpack
 
 #endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
+

Added: mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp	Wed Jul  2 17:07:42 2014
@@ -0,0 +1,191 @@
+/**
+ * @file validation_RMSE_termination.hpp
+ * @author Sumedh Ghaisas
+ *
+ * Termination policy for AMF that holds out randomly chosen nonzero entries
+ * of V and stops when the RMSE on those entries stops improving.
+ */
+#ifndef _MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+#define _MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack
+{
+namespace amf
+{
+
+/**
+ * At construction, num_test_points randomly chosen nonzero entries of V are
+ * recorded and then set to zero in V.  Each Step() computes the RMSE of
+ * W * H on those entries.  IsConverged() returns true once the relative RMSE
+ * decrease has stayed below the tolerance (with the iteration counter above
+ * 4) for three consecutive calls, or once the counter exceeds maxIterations.
+ *
+ * @tparam MatType Type of the input matrix V.
+ */
+template <class MatType>
+class ValidationRMSETermination
+{
+ public:
+  /**
+   * Select the test points and zero them out in V.
+   *
+   * @param V Input matrix; the selected entries are set to zero in place.
+   * @param num_test_points Number of nonzero entries to hold out.  Must be
+   *     positive and at most the number of nonzero entries of V; otherwise
+   *     an error is reported through Log::Fatal.
+   * @param tolerance Relative RMSE-improvement threshold.
+   * @param maxIterations Maximum number of iterations.
+   */
+  ValidationRMSETermination(MatType& V,
+                            size_t num_test_points,
+                            double tolerance = 1e-5,
+                            size_t maxIterations = 10000)
+        : tolerance(tolerance),
+          maxIterations(maxIterations),
+          num_test_points(num_test_points)
+  {
+    const size_t n = V.n_rows;
+    const size_t m = V.n_cols;
+
+    // Every selected entry is zeroed, so the rejection sampling below would
+    // loop forever if more points were requested than V has nonzeros.
+    const size_t nonzeros = CountNonzeros(V);
+    if (num_test_points == 0 || num_test_points > nonzeros)
+    {
+      Log::Fatal << "ValidationRMSETermination: cannot hold out "
+          << num_test_points << " test points from a matrix with "
+          << nonzeros << " nonzero entries." << std::endl;
+    }
+
+    test_points.zeros(num_test_points, 3);
+
+    for (size_t i = 0; i < num_test_points; i++)
+    {
+      double t_val;
+      size_t t_row;
+      size_t t_col;
+      // Draw random positions until a (still) nonzero entry is found.
+      do
+      {
+        t_row = rand() % n;
+        t_col = rand() % m;
+      } while ((t_val = V(t_row, t_col)) == 0);
+
+      test_points(i, 0) = t_row;
+      test_points(i, 1) = t_col;
+      test_points(i, 2) = t_val;
+      // Zero the entry in V; presumably so the factorization treats it as
+      // unobserved and it serves purely as validation data.
+      V(t_row, t_col) = 0;
+    }
+  }
+
+  /**
+   * Reset the iteration counter and the RMSE history.
+   *
+   * @param V Input matrix (unused; test points are chosen in the
+   *     constructor).
+   */
+  void Initialize(const MatType& /* V */)
+  {
+    iteration = 1;
+
+    rmse = DBL_MAX;
+    rmseOld = DBL_MAX;
+    t_count = 0;
+  }
+
+  /**
+   * Return true when the factorization should stop.  Each call updates the
+   * count of consecutive calls whose relative RMSE decrease was below the
+   * tolerance.
+   */
+  bool IsConverged()
+  {
+    if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+      t_count++;
+    else
+      t_count = 0;
+
+    return t_count == 3 || iteration > maxIterations;
+  }
+
+  /**
+   * Compute the RMSE of W * H on the test points.
+   *
+   * @param W Basis matrix.
+   * @param H Encoding matrix.
+   */
+  void Step(const arma::mat& W, const arma::mat& H)
+  {
+    // Only the test entries are needed, so evaluate each one as a dot
+    // product instead of forming the full n x m product W * H.
+    rmseOld = rmse;
+    rmse = 0;
+    for (size_t i = 0; i < num_test_points; i++)
+    {
+      const size_t t_row = static_cast<size_t>(test_points(i, 0));
+      const size_t t_col = static_cast<size_t>(test_points(i, 1));
+      const double t_val = test_points(i, 2);
+      const double diff = t_val - arma::dot(W.row(t_row), H.col(t_col));
+      rmse += diff * diff;
+    }
+    rmse /= num_test_points;
+    rmse = sqrt(rmse);
+
+    iteration++;
+  }
+
+  //! Get the test-point RMSE computed by the most recent Step().
+  const double& Index() const
+  {
+    return rmse;
+  }
+
+  //! Get the current iteration number.
+  const size_t& Iteration() const
+  {
+    return iteration;
+  }
+
+ private:
+  //! Count the entries of V that are nonzero.
+  static size_t CountNonzeros(const MatType& V)
+  {
+    size_t nonzeros = 0;
+    for (typename MatType::const_iterator it = V.begin(); it != V.end(); ++it)
+    {
+      if (*it != 0)
+        ++nonzeros;
+    }
+    return nonzeros;
+  }
+
+  //! Relative RMSE-improvement threshold.
+  double tolerance;
+  //! Maximum number of iterations.
+  size_t maxIterations;
+  //! Number of held-out test points.
+  size_t num_test_points;
+  //! Current iteration number.
+  size_t iteration;
+
+  //! One row per test point: (row index, column index, original value).
+  arma::Mat<double> test_points;
+
+  //! RMSE from the previous Step().
+  double rmseOld;
+  //! RMSE from the most recent Step().
+  double rmse;
+
+  //! Consecutive IsConverged() calls with improvement below tolerance.
+  size_t t_count;
+}; // class ValidationRMSETermination
+
+} // namespace amf
+} // namespace mlpack
+
+#endif // _MLPACK_METHODS_AMF_VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	Wed Jul  2 17:07:42 2014
@@ -1,3 +1,7 @@
+/**
+ * @file svd_batchlearning.hpp
+ * @author Sumedh Ghaisas
+ */
 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
 
@@ -9,113 +13,155 @@
 {
+/**
+ * Batch update rules for AMF with momentum.  For every entry of V visited by
+ * MatType::const_iterator, WUpdate() adds (V(i, j) - W.row(i) * H.col(j))
+ * times H.col(j)^T to row i of the update (HUpdate() adds the same residual
+ * times W.row(i)^T to column j), subtracts kw * W (kh * H) when nonzero, and
+ * applies the step with momentum: mW = momentum * mW + u * deltaW; W += mW.
+ */
 class SVDBatchLearning
 {
-public:
-    SVDBatchLearning(double u = 0.000001,
-                     double kw = 0,
-                     double kh = 0,
-                     double momentum = 0.2,
-                     double min = -DBL_MIN,
-                     double max = DBL_MAX)
+ public:
+  /**
+   * Construct the update rules.
+   *
+   * @param u Learning rate applied to the accumulated update.
+   * @param kw Coefficient of the kw * W term subtracted in WUpdate().
+   * @param kh Coefficient of the kh * H term subtracted in HUpdate().
+   * @param momentum Fraction of the previous step carried into the next one.
+   * @param min Stored but not read by this class.  NOTE(review): -DBL_MIN is
+   *     about -2.2e-308 (DBL_MIN is the smallest positive normalized double),
+   *     so -DBL_MAX may have been intended.
+   * @param max Stored but not read by this class.
+   */
+  SVDBatchLearning(double u = 0.0002,
+                   double kw = 0,
+                   double kh = 0,
+                   double momentum = 0.5,
+                   double min = -DBL_MIN,
+                   double max = DBL_MAX)
         : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
     {}
 
-    template<typename MatType>
-    void Initialize(const MatType& dataset, const size_t rank)
-    {
-        const size_t n = dataset.n_rows;
-        const size_t m = dataset.n_cols;
-
-        mW.zeros(n, rank);
-        mH.zeros(rank, m);
-    }
-
-    /**
-    * The update rule for the basis matrix W.
-    * The function takes in all the matrices and only changes the
-    * value of the W matrix.
-    *
-    * @param V Input matrix to be factorized.
-    * @param W Basis matrix to be updated.
-    * @param H Encoding matrix.
-    */
-    template<typename MatType>
-    inline void WUpdate(const MatType& V,
-                               arma::mat& W,
-                               const arma::mat& H)
-    {
-        size_t n = V.n_rows;
-        size_t m = V.n_cols;
+  /**
+   * Reset the momentum matrices mW and mH to zero.
+   *
+   * @param dataset Input matrix; only its dimensions are used.
+   * @param rank Rank of the factorization.
+   */
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+    const size_t n = dataset.n_rows;
+    const size_t m = dataset.n_cols;
+
+    mW.zeros(n, rank);
+    mH.zeros(rank, m);
+  }
+
+  /**
+   * The update rule for the basis matrix W.
+   * The function takes in all the matrices and only changes the
+   * value of the W matrix.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix to be updated.
+   * @param H Encoding matrix.
+   */
+  template<typename MatType>
+  inline void WUpdate(const MatType& V,
+                      arma::mat& W,
+                      const arma::mat& H)
+  {
+    const size_t n = V.n_rows;
 
-        size_t r = W.n_cols;
+    const size_t r = W.n_cols;
 
-        mW = momentum * mW;
+    mW *= momentum;
 
-        arma::mat deltaW(n, r);
-        deltaW.zeros();
+    arma::mat deltaW(n, r);
+    deltaW.zeros();
 
-        for(size_t i = 0; i < n; i++)
-        {
-            for(size_t j = 0; j < m; j++)
-                if(V(i,j) != 0) deltaW.row(i) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(H.col(j));
-            deltaW.row(i) -= kw * W.row(i);
-        }
-
-        mW += u * deltaW;
-        W += mW;
+    // Accumulate the update over the entries visited by the iterator.
+    // NOTE(review): it.row()/it.col() are provided by Armadillo's sparse
+    // iterators; a dense arma::mat iterator is a plain pointer, so this
+    // presumably only compiles for a sparse MatType such as arma::sp_mat.
+    for (typename MatType::const_iterator it = V.begin(); it != V.end(); ++it)
+    {
+      const size_t row = it.row();
+      const size_t col = it.col();
+      deltaW.row(row) += (*it - arma::dot(W.row(row), H.col(col))) *
+          arma::trans(H.col(col));
     }
 
-    /**
-    * The update rule for the encoding matrix H.
-    * The function takes in all the matrices and only changes the
-    * value of the H matrix.
-    *
-    * @param V Input matrix to be factorized.
-    * @param W Basis matrix.
-    * @param H Encoding matrix to be updated.
-    */
-    template<typename MatType>
-    inline void HUpdate(const MatType& V,
-                               const arma::mat& W,
-                               arma::mat& H)
+    if (kw != 0) for (size_t i = 0; i < n; i++)
     {
-        size_t n = V.n_rows;
-        size_t m = V.n_cols;
+      deltaW.row(i) -= kw * W.row(i);
+    }
 
-        size_t r = W.n_cols;
+    mW += u * deltaW;
+    W += mW;
+  }
+
+  /**
+   * The update rule for the encoding matrix H.
+   * The function takes in all the matrices and only changes the
+   * value of the H matrix.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix.
+   * @param H Encoding matrix to be updated.
+   */
+  template<typename MatType>
+  inline void HUpdate(const MatType& V,
+                      const arma::mat& W,
+                      arma::mat& H)
+  {
+    const size_t m = V.n_cols;
 
-        mH = momentum * mH;
+    const size_t r = W.n_cols;
 
-        arma::mat deltaH(r, m);
-        deltaH.zeros();
+    mH *= momentum;
 
-        for(size_t j = 0; j < m; j++)
-        {
-            for(size_t i = 0; i < n; i++)
-                if(V(i,j) != 0) deltaH.col(j) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(W.row(i));
-            deltaH.col(j) -= kh * H.col(j);
-        }
+    arma::mat deltaH(r, m);
+    deltaH.zeros();
 
-        mH += u*deltaH;
-        H += mH;
+    // Same accumulation as in WUpdate(), for the columns of H; the same
+    // sparse-iterator requirement applies.
+    for (typename MatType::const_iterator it = V.begin(); it != V.end(); ++it)
+    {
+      const size_t row = it.row();
+      const size_t col = it.col();
+      deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
+          arma::trans(W.row(row));
     }
-private:
 
-    double Predict(const arma::mat& wi, const arma::mat& hj) const
+    if (kh != 0) for (size_t j = 0; j < m; j++)
     {
-        arma::mat temp = (wi * hj);
-        double out = temp(0,0);
-        return out;
+      deltaH.col(j) -= kh * H.col(j);
     }
 
-    double u;
-    double kw;
-    double kh;
-    double min;
-    double max;
-    double momentum;
+    mH += u * deltaH;
+    H += mH;
+  }
+
+ private:
+  //! Learning rate.
+  double u;
+  //! Coefficient of the W term subtracted in WUpdate().
+  double kw;
+  //! Coefficient of the H term subtracted in HUpdate().
+  double kh;
+  //! Stored but not read by this class.
+  double min;
+  //! Stored but not read by this class.
+  double max;
+  //! Momentum coefficient.
+  double momentum;
 
-    arma::mat mW;
-    arma::mat mH;
+  //! Momentum term (previous step) for W.
+  arma::mat mW;
+  //! Momentum term (previous step) for H.
+  arma::mat mH;
 };
 } // namespace amf
 } // namespace mlpack
@@ -123,3 +168,4 @@
 
 #endif
 
+



More information about the mlpack-svn mailing list