[mlpack-svn] r16804 - in mlpack/trunk/src/mlpack: methods/amf methods/amf/termination_policies methods/amf/update_rules tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 18:14:53 EDT 2014
Author: sumedhghaisas
Date: Wed Jul 9 18:14:52 2014
New Revision: 16804
Log:
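* Added local-minimum storing to the termination policies. IsConverged() now takes W and H. The simple-tolerance and validation-RMSE policies save W and H when the residue/RMSE first stops improving. They restore those saved values on termination unless the residue/RMSE later falls back to or below its pre-stall value.
* Renamed svd_batchlearning.hpp to svd_batch_learning.hpp and svd_test.cpp to svd_batch_test.cpp.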
Added:
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
- copied unchanged from r16803, /mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp
- copied, changed from r16803, /mlpack/trunk/src/mlpack/tests/svd_test.cpp
Removed:
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
mlpack/trunk/src/mlpack/tests/svd_test.cpp
Modified:
mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
mlpack/trunk/src/mlpack/tests/CMakeLists.txt
Modified: mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp Wed Jul 9 18:14:52 2014
@@ -50,7 +50,7 @@
update.Initialize(V, r);
terminationPolicy.Initialize(V);
- while (!terminationPolicy.IsConverged())
+ while (!terminationPolicy.IsConverged(W, H))
{
// Update the values of W and H based on the update rules provided.
update.WUpdate(V, W, H);
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp Wed Jul 9 18:14:52 2014
@@ -26,9 +26,9 @@
iteration = 0;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
- return t_policy.IsConverged();
+ return t_policy.IsConverged(W, H);
}
void Step(const arma::mat& W, const arma::mat& H)
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp Wed Jul 9 18:14:52 2014
@@ -30,8 +30,10 @@
nm = n * m;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
+ (void)W;
+ (void)H;
if(residue < minResidue || iteration > maxIterations) return true;
else return false;
}
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp Wed Jul 9 18:14:52 2014
@@ -30,14 +30,39 @@
this->V = &V;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+ {
+ if(reverseStepCount == 0 && isCopy == false)
+ {
+ isCopy = true;
+ this->W = W;
+ this->H = H;
+ c_index = residue;
+ c_indexOld = residueOld;
+ }
reverseStepCount++;
- else reverseStepCount = 0;
+ }
+ else
+ {
+ reverseStepCount = 0;
+ if(residue <= c_indexOld && isCopy == true)
+ {
+ isCopy = false;
+ }
+ }
if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ {
+ if(isCopy)
+ {
+ W = this->W;
+ H = this->H;
+ residue = c_index;
+ }
return true;
+ }
else return false;
}
@@ -89,6 +114,12 @@
size_t reverseStepTolerance;
size_t reverseStepCount;
+
+ bool isCopy;
+ arma::mat W;
+ arma::mat H;
+ double c_indexOld;
+ double c_index;
}; // class SimpleToleranceTermination
}; // namespace amf
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp Wed Jul 9 18:14:52 2014
@@ -54,14 +54,39 @@
reverseStepCount = 0;
}
- bool IsConverged()
+ bool IsConverged(arma::mat& W, arma::mat& H)
{
- if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+ if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+ {
+ if(reverseStepCount == 0 && isCopy == false)
+ {
+ isCopy = true;
+ this->W = W;
+ this->H = H;
+ c_indexOld = rmseOld;
+ c_index = rmse;
+ }
reverseStepCount++;
- else reverseStepCount = 0;
+ }
+ else
+ {
+ reverseStepCount = 0;
+ if(rmse <= c_indexOld && isCopy == true)
+ {
+ isCopy = false;
+ }
+ }
- if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ {
+ if(isCopy)
+ {
+ W = this->W;
+ H = this->H;
+ rmse = c_index;
+ }
return true;
+ }
else return false;
}
@@ -115,6 +140,12 @@
size_t reverseStepTolerance;
size_t reverseStepCount;
+
+ bool isCopy;
+ arma::mat W;
+ arma::mat H;
+ double c_indexOld;
+ double c_index;
};
} // namespace amf
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt Wed Jul 9 18:14:52 2014
@@ -4,7 +4,7 @@
nmf_als.hpp
nmf_mult_dist.hpp
nmf_mult_div.hpp
- svd_batchlearning.hpp
+ svd_batch_learning.hpp
svd_incremental_learning.hpp
)
Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt Wed Jul 9 18:14:52 2014
@@ -50,7 +50,7 @@
tree_test.cpp
tree_traits_test.cpp
union_find_test.cpp
- svd_test.cpp
+ svd_batch_test.cpp
)
# Link dependencies of test executable.
target_link_libraries(mlpack_test
Copied: mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp (from r16803, /mlpack/trunk/src/mlpack/tests/svd_test.cpp)
==============================================================================
--- /mlpack/trunk/src/mlpack/tests/svd_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/svd_batch_test.cpp Wed Jul 9 18:14:52 2014
@@ -1,6 +1,6 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/amf/amf.hpp>
-#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
#include <mlpack/methods/amf/init_rules/random_init.hpp>
#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -145,8 +145,6 @@
arma::mat result = m1 * m2;
- std::cout << result << std::endl;
-
for(size_t i = 0;i < 3;i++)
{
for(size_t j = 0;j < 3;j++)
More information about the mlpack-svn mailing list