[mlpack-svn] r13210 - mlpack/trunk/src/mlpack/methods/nmf

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 11 18:21:40 EDT 2012


Author: rmohan
Date: 2012-07-11 18:21:40 -0400 (Wed, 11 Jul 2012)
New Revision: 13210

Added:
   mlpack/trunk/src/mlpack/methods/nmf/alsupdate.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp
   mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp
   mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp
   mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp
   mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp
Log:
Added the Alternating Least Squares method to NMF. Requires some testing.

Modified: mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt	2012-07-11 22:21:40 UTC (rev 13210)
@@ -5,7 +5,9 @@
 set(SOURCES
   mdistupdate.hpp
   mdivupdate.hpp
+  alsupdate.hpp
   randominit.hpp
+  randomacolinit.hpp
   nmf.hpp
   nmf_impl.hpp
 )

Added: mlpack/trunk/src/mlpack/methods/nmf/alsupdate.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/alsupdate.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/nmf/alsupdate.hpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -0,0 +1,104 @@
+/**
+ * @file alsupdate.hpp
+ * @author Mohan Rajendran
+ *
+ * Update rules for the Non-negative Matrix Factorization. This follows a method
+ * titled 'Alternating Least Squares' described in the paper 'Positive Matrix
+ * Factorization: A Non-negative Factor Model with Optimal Utilization of
+ * Error Estimates of Data Values' by P. Paatero and U. Tapper. It uses the
+ * least squares projection formula to reduce the error value
+ * \f$ \sqrt{\sum_i \sum_j (V-WH)_{ij}^2} \f$ by alternately calculating W and
+ * H while holding the other matrix constant.
+ *
+ */
+
+#ifndef __MLPACK_METHODS_NMF_ALSUPDATE_HPP
+#define __MLPACK_METHODS_NMF_ALSUPDATE_HPP
+
+#include <limits>
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * The update rule for the basis matrix W. Holding H fixed, W is set to the
+ * unconstrained least-squares solution of \f$ \min_W \|V - WH\|_F^2 \f$:
+ * \f[
+ * W = V H^T (H H^T)^{-1}
+ * \f]
+ * after which any negative entries are replaced by machine epsilon so that W
+ * remains non-negative.
+ */
+class AlternatingLeastSquareW
+{
+ public:
+  // Empty constructor required for the WUpdateRule template
+  AlternatingLeastSquareW() { }
+
+  /**
+   * The update function that actually updates the W matrix. The function takes
+   * in all the salient matrices and only changes the value of the W matrix.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix; overwritten with the updated basis.
+   * @param H Encoding matrix; held constant during this update.
+   */
+  inline static void Update(const arma::mat& V,
+                            arma::mat& W,
+                            const arma::mat& H)
+  {
+    // Least-squares solution W = V H^T (H H^T)^{-1}. pinv() is used rather
+    // than inv() so that a singular Gram matrix H H^T (e.g. when H has
+    // linearly dependent rows) does not abort the factorization.
+    W = V * H.t() * pinv(H * H.t());
+
+    // Project back onto the non-negative set: negative entries become the
+    // machine epsilon for double (a scalar, unlike arma's element-wise eps()).
+    const double machineEps = std::numeric_limits<double>::epsilon();
+    for (size_t i = 0; i < W.n_elem; i++)
+    {
+      if (W(i) < 0.0)
+        W(i) = machineEps;
+    }
+  }
+}; // Class AlternatingLeastSquareW
+
+/**
+ * The update rule for the encoding matrix H. Holding W fixed, H is set to the
+ * unconstrained least-squares solution of \f$ \min_H \|V - WH\|_F^2 \f$:
+ * \f[
+ * H = (W^T W)^{-1} W^T V
+ * \f]
+ * after which any negative entries are replaced by machine epsilon so that H
+ * remains non-negative.
+ */
+class AlternatingLeastSquareH
+{
+ public:
+  // Empty constructor required for the HUpdateRule template
+  AlternatingLeastSquareH() { }
+
+  /**
+   * The update function that actually updates the H matrix. The function takes
+   * in all the salient matrices and only changes the value of the H matrix.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix; held constant during this update.
+   * @param H Encoding matrix; overwritten with the updated encoding.
+   */
+  inline static void Update(const arma::mat& V,
+                            const arma::mat& W,
+                            arma::mat& H)
+  {
+    // Least-squares solution H = (W^T W)^{-1} W^T V. pinv() is used rather
+    // than inv() so that a singular Gram matrix W^T W (e.g. when W has
+    // linearly dependent columns) does not abort the factorization.
+    H = pinv(W.t() * W) * W.t() * V;
+
+    // Project back onto the non-negative set: negative entries become the
+    // machine epsilon for double (a scalar, unlike arma's element-wise eps()).
+    const double machineEps = std::numeric_limits<double>::epsilon();
+    for (size_t i = 0; i < H.n_elem; i++)
+    {
+      if (H(i) < 0.0)
+        H(i) = machineEps;
+    }
+  }
+}; // Class AlternatingLeastSquareH
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -27,7 +27,7 @@
  * {\sum_{\nu} H_{a\nu}}
  * \f]
  */
-class MultiplicativeDistanceW
+class MultiplicativeDivergenceW
 {
  public:
   // Empty constructor required for the WUpdateRule template
@@ -42,7 +42,7 @@
    * @param H Encoding matrix to output
    */
 
-  inline static void Init(const arma::mat& V,
+  inline static void Update(const arma::mat& V,
                      arma::mat& W, 
                      const arma::mat& H)
   {
@@ -56,7 +56,7 @@
       for(size_t j=0;j<W.n_cols;j++)
       {
         t2 = H.row(j)%V.row(i)/t1.row(i);
-        W(i,j) = W(i,j)*sum(t2)/sum(H.row(i));
+        W(i,j) = W(i,j)*sum(t2)/sum(H.row(j));
       }
     }
 
@@ -74,7 +74,7 @@
 {
  public:
   // Empty constructor required for the HUpdateRule template
-  MultiplicativeDistanceH() { }
+  MultiplicativeDivergenceH() { }
 
   /**
    * The update function that actually updates the H matrix. The function takes
@@ -99,7 +99,7 @@
       for(size_t j=0;j<H.n_cols;j++)
       {
         t2 = W.col(i)%V.col(j)/t1.col(j);
-        H(i,j) = H(i,j)*sum(t2)/sum(H.col(i));
+        H(i,j) = H(i,j)*sum(t2)/sum(W.col(i));
       }
     }
 

Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -69,7 +69,7 @@
    *    the H vector has states that it needs to store.
    */
   NMF(const size_t maxIterations = 10000,
-      const double maxResidue = 1e-10,
+      const double maxResidue = 1e-5,
       const InitializeRule Initialize = InitializeRule(),
       const WUpdateRule WUpdate = WUpdateRule(),
       const HUpdateRule HUpdate = HUpdateRule());

Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -5,8 +5,6 @@
  * Implementation of NMF class to perform Non-Negative Matrix Factorization
  * on the given matrix.
  */
-#include "nmf.hpp"
-#include <iostream>
 
 namespace mlpack {
 namespace nmf {
@@ -38,6 +36,8 @@
         << "1e-10.\n";
     this->maxResidue = 1e-10;
   } 
+
+  math::RandomSeed((size_t) std::time(NULL));    
 }
 
 /**
@@ -59,40 +59,54 @@
   size_t n = V.n_rows;
   size_t m = V.n_cols;
 
-  // old and new product WH for residue checking
-  arma::mat WHold,WH,diff;
-  
   // Intialize W and H
   Initialize.Init(V,W,H,r);
 
-  // Store the original calculated value for residue checking
-  WHold = W*H;
+  //Log::Debug << "Initialized W and H." << std::endl;
+
+  size_t iteration = 0;
+  size_t nm = n*m;
+  double residue = maxResidue;
+  double normOld,norm;
+  arma::mat WH;    
   
-  size_t iteration = 0;
-  double residue;
-  double sqrRes = maxResidue*maxResidue;
-
-  do
+  while (residue >= maxResidue  && iteration != maxIterations)
   {
     // Update step.
     // Update the value of W and H based on the Update Rules provided
     WUpdate.Update(V,W,H);
     HUpdate.Update(V,W,H);
 
-    // Calculate square of residue after iteration
+    // Calculate norm of WH after each iteration
     WH = W*H;
-    diff = WHold-WH;
+    norm = sqrt(accu(WH%WH)/nm);
+    
+    if(iteration!=0)
+    {
+      residue = fabs(normOld-norm);
+      if(normOld > 1.0)
+      {
+        residue /= normOld;
+      }
+    }
+
+    normOld = norm;
+
+    /*
+      WH = W*H;
+      diff = WHold-WH;
     diff = diff%diff;
     residue = accu(diff)/(double)(n*m);
-    WHold = WH;
-    Log::Debug << "Iteration: " << iteration << " Residue: " 
-          << residue << std::endl;
+    WHold = WH;*/
 
+    //Log::Debug << "Iteration: " << iteration << " Residue: " 
+    //      << sqrt(residue) << std::endl;
+
     iteration++;
-  
-  } while (residue >= sqrRes  && iteration != maxIterations);
+      
+  }
 
-  Log::Debug << "Iterations: " << iteration << std::endl;
+  //Log::Debug << "Iterations: " << iteration << std::endl;
 }
 
 }; // namespace nmf

Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -26,14 +26,21 @@
 PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
 PARAM_INT("max_iterations", "Number of iterations before NMF terminates", 
     "m", 10000);
+PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 PARAM_DOUBLE("max_residue", "The maximum root mean square allowed below which "
-    "the program termiates", "e", 1e-10);
+    "the program termiates", "e", 1e-5);
 
 int main(int argc, char** argv)
 {
   // Parse commandline.
   CLI::ParseCommandLine(argc, argv);
 
+  // Initialize random seed.
+  if (CLI::GetParam<int>("seed") != 0)
+    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+  else
+    math::RandomSeed((size_t) std::time(NULL));
+
   // Load input dataset.
   string inputFile = CLI::GetParam<string>("input_file");
   arma::mat V;

Modified: mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp	2012-07-11 20:54:40 UTC (rev 13209)
+++ mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp	2012-07-11 22:21:40 UTC (rev 13210)
@@ -27,6 +27,7 @@
                      const size_t& r)
   {
     // Simple inplementation. This can be left here.
+
     size_t n = V.n_rows;
     size_t m = V.n_cols;
   




More information about the mlpack-svn mailing list