[mlpack-svn] r16804 - in mlpack/trunk/src/mlpack: methods/amf methods/amf/termination_policies methods/amf/update_rules tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 18:14:53 EDT 2014


Author: sumedhghaisas
Date: Wed Jul  9 18:14:52 2014
New Revision: 16804

Log:
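* Added local-minimum storing to the termination policies. IsConverged() now
  takes W and H (amf_impl.hpp passes them in).
  - simple_tolerance_termination.hpp and validation_RMSE_termination.hpp save
    a copy of W and H when a run of non-improving iterations begins.
  - They discard that copy once the residue/RMSE drops back to or below the
    value recorded before the run.
  - If a copy is still held when convergence is reported, they restore W and H
    from it.
* Renamed update_rules/svd_batchlearning.hpp to svd_batch_learning.hpp.
* Renamed tests/svd_test.cpp to svd_batch_test.cpp and removed its debug
  printing of the result matrix.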


Added:
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
      - copied unchanged from r16803, /mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
   mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp
      - copied, changed from r16803, /mlpack/trunk/src/mlpack/tests/svd_test.cpp
Removed:
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
   mlpack/trunk/src/mlpack/tests/svd_test.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt

Modified: mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp	Wed Jul  9 18:14:52 2014
@@ -50,7 +50,7 @@
   update.Initialize(V, r);
   terminationPolicy.Initialize(V);
 
-  while (!terminationPolicy.IsConverged())
+  while (!terminationPolicy.IsConverged(W, H))
   {
     // Update the values of W and H based on the update rules provided.
     update.WUpdate(V, W, H);

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp	Wed Jul  9 18:14:52 2014
@@ -26,9 +26,9 @@
     iteration = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    return t_policy.IsConverged();
+    return t_policy.IsConverged(W, H);
   }
 
   void Step(const arma::mat& W, const arma::mat& H)

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp	Wed Jul  9 18:14:52 2014
@@ -30,8 +30,10 @@
     nm = n * m;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
+    (void)W;
+    (void)H;
     if(residue < minResidue || iteration > maxIterations) return true;
     else return false;
   }

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp	Wed Jul  9 18:14:52 2014
@@ -30,14 +30,39 @@
     this->V = &V;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
     if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_index = residue;
+        c_indexOld = residueOld;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(residue <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
     if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        residue = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -89,6 +114,12 @@
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 }; // class SimpleToleranceTermination
 
 }; // namespace amf

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp	Wed Jul  9 18:14:52 2014
@@ -54,14 +54,39 @@
     reverseStepCount = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4) 
+    if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_indexOld = rmseOld;
+        c_index = rmse;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(rmse <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
-    if(reverseStepCount == reverseStepTolerance || iteration > maxIterations) 
+    if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        rmse = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -115,6 +140,12 @@
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 };
 
 } // namespace amf

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt	Wed Jul  9 18:14:52 2014
@@ -4,7 +4,7 @@
   nmf_als.hpp
   nmf_mult_dist.hpp
   nmf_mult_div.hpp
-  svd_batchlearning.hpp
+  svd_batch_learning.hpp
   svd_incremental_learning.hpp
 )
 

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	Wed Jul  9 18:14:52 2014
@@ -50,7 +50,7 @@
   tree_test.cpp
   tree_traits_test.cpp
   union_find_test.cpp
-  svd_test.cpp
+  svd_batch_test.cpp
 )
 # Link dependencies of test executable.
 target_link_libraries(mlpack_test

Copied: mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp (from r16803, /mlpack/trunk/src/mlpack/tests/svd_test.cpp)
==============================================================================
--- /mlpack/trunk/src/mlpack/tests/svd_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp	Wed Jul  9 18:14:52 2014
@@ -1,6 +1,6 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/amf/amf.hpp>
-#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
 #include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
 #include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -145,8 +145,6 @@
 
   arma::mat result = m1 * m2;
 
-  std::cout << result << std::endl;
-
   for(size_t i = 0;i < 3;i++)
   {
     for(size_t j = 0;j < 3;j++)



More information about the mlpack-svn mailing list