[mlpack-git] master, mlpack-1.0.x: * Added local-minimum storing to termination policies (5607d5b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:52:38 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 5607d5bf030886b4f97824c590012266fce2a7e9
Author: sumedhghaisas <sumedhghaisas at gmail.com>
Date:   Wed Jul 9 22:14:52 2014 +0000

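    * Added local-minimum storing to the termination policies (SimpleToleranceTermination, ValidationRMSETermination): when the relative improvement first drops below the tolerance, the current W and H are saved; if the residue/RMSE never gets back to its pre-stall value, the saved W and H are restored when the policy reports convergence.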


>---------------------------------------------------------------

5607d5bf030886b4f97824c590012266fce2a7e9
 src/mlpack/methods/amf/amf_impl.hpp                |  2 +-
 .../incomplete_incremental_termination.hpp         |  4 +--
 .../simple_residue_termination.hpp                 |  4 ++-
 .../simple_tolerance_termination.hpp               | 35 ++++++++++++++++++++--
 .../validation_RMSE_termination.hpp                | 35 ++++++++++++++++++++--
 src/mlpack/methods/amf/update_rules/CMakeLists.txt |  2 +-
 ...vd_batchlearning.hpp => svd_batch_learning.hpp} |  0
 src/mlpack/tests/CMakeLists.txt                    |  2 +-
 .../tests/{svd_test.cpp => svd_batch_test.cpp}     |  4 +--
 9 files changed, 75 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index ce9a2aa..a887931 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -50,7 +50,7 @@ Apply(const MatType& V,
   update.Initialize(V, r);
   terminationPolicy.Initialize(V);
 
-  while (!terminationPolicy.IsConverged())
+  while (!terminationPolicy.IsConverged(W, H))
   {
     // Update the values of W and H based on the update rules provided.
     update.WUpdate(V, W, H);
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index d24d571..91424db 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -26,9 +26,9 @@ class IncompleteIncrementalTermination
     iteration = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    return t_policy.IsConverged();
+    return t_policy.IsConverged(W, H);
   }
 
   void Step(const arma::mat& W, const arma::mat& H)
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index a47ce23..cbae4ba 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -30,8 +30,10 @@ class SimpleResidueTermination
     nm = n * m;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
+    (void)W;
+    (void)H;
     if(residue < minResidue || iteration > maxIterations) return true;
     else return false;
   }
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 777e38a..8518b54 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -30,14 +30,39 @@ class SimpleToleranceTermination
     this->V = &V;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
     if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_index = residue;
+        c_indexOld = residueOld;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(residue <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
     if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        residue = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -89,6 +114,12 @@ class SimpleToleranceTermination
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 }; // class SimpleToleranceTermination
 
 }; // namespace amf
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 49b0509..ceb7b0c 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -54,14 +54,39 @@ class ValidationRMSETermination
     reverseStepCount = 0;
   }
 
-  bool IsConverged()
+  bool IsConverged(arma::mat& W, arma::mat& H)
   {
     if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+    {
+      if(reverseStepCount == 0 && isCopy == false)
+      {
+        isCopy = true;
+        this->W = W;
+        this->H = H;
+        c_indexOld = rmseOld;
+        c_index = rmse;
+      }
       reverseStepCount++;
-    else reverseStepCount = 0;
+    }
+    else
+    {
+      reverseStepCount = 0;
+      if(rmse <= c_indexOld && isCopy == true)
+      {
+        isCopy = false;
+      }
+    }
 
     if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+    {
+      if(isCopy)
+      {
+        W = this->W;
+        H = this->H;
+        rmse = c_index;
+      }
       return true;
+    }
     else return false;
   }
 
@@ -115,6 +140,12 @@ class ValidationRMSETermination
 
   size_t reverseStepTolerance;
   size_t reverseStepCount;
+  
+  bool isCopy;
+  arma::mat W;
+  arma::mat H;
+  double c_indexOld;
+  double c_index;
 };
 
 } // namespace amf
diff --git a/src/mlpack/methods/amf/update_rules/CMakeLists.txt b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
index b7bde1c..baa942f 100644
--- a/src/mlpack/methods/amf/update_rules/CMakeLists.txt
+++ b/src/mlpack/methods/amf/update_rules/CMakeLists.txt
@@ -4,7 +4,7 @@ set(SOURCES
   nmf_als.hpp
   nmf_mult_dist.hpp
   nmf_mult_div.hpp
-  svd_batchlearning.hpp
+  svd_batch_learning.hpp
   svd_incremental_learning.hpp
 )
 
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
similarity index 100%
rename from src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
rename to src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 3be664e..5a8794c 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -50,7 +50,7 @@ add_executable(mlpack_test
   tree_test.cpp
   tree_traits_test.cpp
   union_find_test.cpp
-  svd_test.cpp
+  svd_batch_test.cpp
 )
 # Link dependencies of test executable.
 target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/svd_test.cpp b/src/mlpack/tests/svd_batch_test.cpp
similarity index 97%
rename from src/mlpack/tests/svd_test.cpp
rename to src/mlpack/tests/svd_batch_test.cpp
index 7160645..af14247 100644
--- a/src/mlpack/tests/svd_test.cpp
+++ b/src/mlpack/tests/svd_batch_test.cpp
@@ -1,6 +1,6 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/amf/amf.hpp>
-#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
 #include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
 #include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -145,8 +145,6 @@ BOOST_AUTO_TEST_CASE(SVDNegativeElementTest)
 
   arma::mat result = m1 * m2;
 
-  std::cout << result << std::endl;
-
   for(size_t i = 0;i < 3;i++)
   {
     for(size_t j = 0;j < 3;j++)



More information about the mlpack-git mailing list