[mlpack-git] master, mlpack-1.0.x: * added svd incomplete incremental learning tests * combined functions IsConverged and Step of termination policies into IsConverged (c247ee6)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:53:21 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit c247ee6ac8441317175e78a0b96c74c4dda7ca7f
Author: sumedhghaisas <sumedhghaisas at gmail.com>
Date: Wed Jul 16 11:14:27 2014 +0000
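* Added SVD incomplete incremental learning tests (re-enabled SVDIncrementalRegularizationTest, which now asserts that the regularized RMSE is lower)
* Combined the IsConverged() and Step() functions of the termination policies into a single IsConverged()
* Fixed SVDIncrementalLearning::WUpdate so the regularization term is subtracted only from row i of deltaW instead of from the whole matrix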
>---------------------------------------------------------------
c247ee6ac8441317175e78a0b96c74c4dda7ca7f
src/mlpack/methods/amf/amf_impl.hpp | 4 +-
.../incomplete_incremental_termination.hpp | 9 ++--
.../simple_residue_termination.hpp | 11 ++--
.../simple_tolerance_termination.hpp | 61 ++++++++++------------
.../validation_RMSE_termination.hpp | 51 +++++++++---------
.../amf/update_rules/svd_incremental_learning.hpp | 4 +-
src/mlpack/tests/svd_incremental_test.cpp | 9 ++--
7 files changed, 66 insertions(+), 83 deletions(-)
diff --git a/src/mlpack/methods/amf/amf_impl.hpp b/src/mlpack/methods/amf/amf_impl.hpp
index a887931..d99cf57 100644
--- a/src/mlpack/methods/amf/amf_impl.hpp
+++ b/src/mlpack/methods/amf/amf_impl.hpp
@@ -55,11 +55,9 @@ Apply(const MatType& V,
// Update the values of W and H based on the update rules provided.
update.WUpdate(V, W, H);
update.HUpdate(V, W, H);
-
- terminationPolicy.Step(W, H);
}
- const double residue = sqrt(terminationPolicy.Index());
+ const double residue = terminationPolicy.Index();
const size_t iteration = terminationPolicy.Iteration();
Log::Info << "AMF converged to residue of " << residue << " in "
diff --git a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
index cfa499e..d53b8b7 100644
--- a/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
@@ -28,13 +28,10 @@ class IncompleteIncrementalTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
- return t_policy.IsConverged(W, H);
- }
-
- void Step(const arma::mat& W, const arma::mat& H)
- {
- if(iteration % incrementalIndex == 0) t_policy.Step(W, H);
iteration++;
+ if(iteration % incrementalIndex == 0)
+ return t_policy.IsConverged(W, H);
+ else return false;
}
const double& Index()
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index b5c4fb5..3e5f7b8 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -32,14 +32,6 @@ class SimpleResidueTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
- (void)W;
- (void)H;
- if(residue < minResidue || iteration > maxIterations) return true;
- else return false;
- }
-
- void Step(const arma::mat& W, const arma::mat& H)
- {
// Calculate norm of WH after each iteration.
arma::mat WH;
@@ -55,6 +47,9 @@ class SimpleResidueTermination
normOld = norm;
iteration++;
+
+ if(residue < minResidue || iteration > maxIterations) return true;
+ else return false;
}
const double& Index() { return residue; }
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index dfc78bd..8976c14 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -32,6 +32,35 @@ class SimpleToleranceTermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
+ // Calculate norm of WH after each iteration.
+ arma::mat WH;
+
+ WH = W * H;
+
+ residueOld = residue;
+ size_t n = V->n_rows;
+ size_t m = V->n_cols;
+ double sum = 0;
+ size_t count = 0;
+ for(size_t i = 0;i < n;i++)
+ {
+ for(size_t j = 0;j < m;j++)
+ {
+ double temp = 0;
+ if((temp = (*V)(i,j)) != 0)
+ {
+ temp = (temp - WH(i, j));
+ temp = temp * temp;
+ sum += temp;
+ count++;
+ }
+ }
+ }
+ residue = sum / count;
+ residue = sqrt(residue);
+
+ iteration++;
+
if((residueOld - residue) / residueOld < tolerance && iteration > 4)
{
if(reverseStepCount == 0 && isCopy == false)
@@ -66,38 +95,6 @@ class SimpleToleranceTermination
else return false;
}
- void Step(const arma::mat& W, const arma::mat& H)
- {
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- WH = W * H;
-
- residueOld = residue;
- size_t n = V->n_rows;
- size_t m = V->n_cols;
- double sum = 0;
- size_t count = 0;
- for(size_t i = 0;i < n;i++)
- {
- for(size_t j = 0;j < m;j++)
- {
- double temp = 0;
- if((temp = (*V)(i,j)) != 0)
- {
- temp = (temp - WH(i, j));
- temp = temp * temp;
- sum += temp;
- count++;
- }
- }
- }
- residue = sum / count;
- residue = sqrt(residue);
-
- iteration++;
- }
-
const double& Index() { return residue; }
const size_t& Iteration() { return iteration; }
const size_t& MaxIterations() { return maxIterations; }
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 7f03954..a437ce5 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -56,6 +56,30 @@ class ValidationRMSETermination
bool IsConverged(arma::mat& W, arma::mat& H)
{
+ // Calculate norm of WH after each iteration.
+ arma::mat WH;
+
+ WH = W * H;
+
+ if (iteration != 0)
+ {
+ rmseOld = rmse;
+ rmse = 0;
+ for(size_t i = 0; i < num_test_points; i++)
+ {
+ size_t t_row = test_points(i, 0);
+ size_t t_col = test_points(i, 1);
+ double t_val = test_points(i, 2);
+ double temp = (t_val - WH(t_row, t_col));
+ temp *= temp;
+ rmse += temp;
+ }
+ rmse /= num_test_points;
+ rmse = sqrt(rmse);
+ }
+
+ iteration++;
+
if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
{
if(reverseStepCount == 0 && isCopy == false)
@@ -90,33 +114,6 @@ class ValidationRMSETermination
else return false;
}
- void Step(const arma::mat& W, const arma::mat& H)
- {
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- WH = W * H;
-
- if (iteration != 0)
- {
- rmseOld = rmse;
- rmse = 0;
- for(size_t i = 0; i < num_test_points; i++)
- {
- size_t t_row = test_points(i, 0);
- size_t t_col = test_points(i, 1);
- double t_val = test_points(i, 2);
- double temp = (t_val - WH(t_row, t_col));
- temp *= temp;
- rmse += temp;
- }
- rmse /= num_test_points;
- rmse = sqrt(rmse);
- }
-
- iteration++;
- }
-
const double& Index() { return rmse; }
const size_t& Iteration() { return iteration; }
diff --git a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
index a0388c1..4cc7053 100644
--- a/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
@@ -49,7 +49,7 @@ class SVDIncrementalLearning
if((val = V(i, currentUserIndex)) != 0)
deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
arma::trans(H.col(currentUserIndex));
- if(kw != 0) deltaW -= kw * W.row(i);
+ if(kw != 0) deltaW.row(i) -= kw * W.row(i);
}
W += u*deltaW;
@@ -112,7 +112,7 @@ inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
size_t i = it.row();
deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
arma::trans(H.col(currentUserIndex));
- if(kw != 0) deltaW -= kw * W.row(i);
+ if(kw != 0) deltaW.row(i) -= kw * W.row(i);
}
W += u*deltaW;
diff --git a/src/mlpack/tests/svd_incremental_test.cpp b/src/mlpack/tests/svd_incremental_test.cpp
index 00b55d3..682039c 100644
--- a/src/mlpack/tests/svd_incremental_test.cpp
+++ b/src/mlpack/tests/svd_incremental_test.cpp
@@ -4,6 +4,7 @@
#include <mlpack/methods/amf/init_rules/random_init.hpp>
#include <mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp>
#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
+#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -35,7 +36,7 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalConvergenceTest)
amf.TerminationPolicy().MaxIterations());
}
-/*
+
BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
{
mat dataset;
@@ -78,14 +79,12 @@ BOOST_AUTO_TEST_CASE(SVDIncrementalRegularizationTest)
RandomInitialization,
SVDIncrementalLearning> amf_2(vrt2,
RandomInitialization(),
- SVDIncrementalLearning(0.001, 1e-5, 2e-5));
+ SVDIncrementalLearning(0.001, 0.01, 0.01));
mat m3, m4;
double RMSE_2 = amf_2.Apply(cleanedData2, 2, m3, m4);
- // RMSE_2 should be less than RMSE_1
- std::cout << RMSE_1 << " " << RMSE_2 << std::endl;
+ BOOST_REQUIRE_LT(RMSE_2, RMSE_1);
}
-*/
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list