[mlpack-git] master, mlpack-1.0.x: * modified termination policies * fast SVDBatch implementation (a128b7f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:51:19 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit a128b7fd12036ae9b0de4b634571aaedd72a423c
Author: sumedhghaisas <sumedhghaisas at gmail.com>
Date: Sun Jul 6 13:40:30 2014 +0000
* modified termination policies
* fast SVDBatch implementation
>---------------------------------------------------------------
a128b7fd12036ae9b0de4b634571aaedd72a423c
.../simple_tolerance_termination.hpp | 36 +++--
.../validation_RMSE_termination.hpp | 20 ++-
.../methods/amf/update_rules/svd_batchlearning.hpp | 110 +++++++++++----
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/svd_test.cpp | 156 +++++++++++++++++++++
5 files changed, 275 insertions(+), 48 deletions(-)
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 17bf24e..777e38a 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -15,8 +15,11 @@ class SimpleToleranceTermination
{
public:
SimpleToleranceTermination(const double tolerance = 1e-5,
- const size_t maxIterations = 10000)
- : tolerance(tolerance), maxIterations(maxIterations) {}
+ const size_t maxIterations = 10000,
+ const size_t reverseStepTolerance = 3)
+ : tolerance(tolerance),
+ maxIterations(maxIterations),
+ reverseStepTolerance(reverseStepTolerance) {}
void Initialize(const MatType& V)
{
@@ -29,8 +32,12 @@ class SimpleToleranceTermination
bool IsConverged()
{
- if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
- || iteration > maxIterations) return true;
+ if((residueOld - residue) / residueOld < tolerance && iteration > 4)
+ reverseStepCount++;
+ else reverseStepCount = 0;
+
+ if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ return true;
else return false;
}
@@ -48,17 +55,17 @@ class SimpleToleranceTermination
size_t count = 0;
for(size_t i = 0;i < n;i++)
{
- for(size_t j = 0;j < m;j++)
- {
- double temp = 0;
- if((temp = (*V)(i,j)) != 0)
+ for(size_t j = 0;j < m;j++)
{
- temp = (temp - WH(i, j));
- temp = temp * temp;
- sum += temp;
- count++;
+ double temp = 0;
+ if((temp = (*V)(i,j)) != 0)
+ {
+ temp = (temp - WH(i, j));
+ temp = temp * temp;
+ sum += temp;
+ count++;
+ }
}
- }
}
residue = sum / count;
residue = sqrt(residue);
@@ -79,6 +86,9 @@ class SimpleToleranceTermination
double residueOld;
double residue;
double normOld;
+
+ size_t reverseStepTolerance;
+ size_t reverseStepCount;
}; // class SimpleToleranceTermination
}; // namespace amf
diff --git a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
index 297a35b..49b0509 100644
--- a/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
@@ -14,10 +14,12 @@ class ValidationRMSETermination
ValidationRMSETermination(MatType& V,
size_t num_test_points,
double tolerance = 1e-5,
- size_t maxIterations = 10000)
+ size_t maxIterations = 10000,
+ size_t reverseStepTolerance = 3)
: tolerance(tolerance),
maxIterations(maxIterations),
- num_test_points(num_test_points)
+ num_test_points(num_test_points),
+ reverseStepTolerance(reverseStepTolerance)
{
size_t n = V.n_rows;
size_t m = V.n_cols;
@@ -44,19 +46,22 @@ class ValidationRMSETermination
void Initialize(const MatType& V)
{
+ (void)V;
iteration = 1;
rmse = DBL_MAX;
rmseOld = DBL_MAX;
- t_count = 0;
+ reverseStepCount = 0;
}
bool IsConverged()
{
- if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4) t_count++;
- else t_count = 0;
+ if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
+ reverseStepCount++;
+ else reverseStepCount = 0;
- if(t_count == 3 || iteration > maxIterations) return true;
+ if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
+ return true;
else return false;
}
@@ -108,7 +113,8 @@ class ValidationRMSETermination
double rmseOld;
double rmse;
- size_t t_count;
+ size_t reverseStepTolerance;
+ size_t reverseStepCount;
};
} // namespace amf
diff --git a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
index 9658d83..267d651 100644
--- a/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
@@ -17,7 +17,7 @@ class SVDBatchLearning
SVDBatchLearning(double u = 0.0002,
double kw = 0,
double kh = 0,
- double momentum = 0.5,
+ double momentum = 0.9,
double min = -DBL_MIN,
double max = DBL_MAX)
: u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
@@ -48,6 +48,7 @@ class SVDBatchLearning
const arma::mat& H)
{
size_t n = V.n_rows;
+ size_t m = V.n_cols;
size_t r = W.n_cols;
@@ -56,17 +57,16 @@ class SVDBatchLearning
arma::mat deltaW(n, r);
deltaW.zeros();
- for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
+ for(size_t i = 0;i < n;i++)
{
- size_t row = it.row();
- size_t col = it.col();
- deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
- arma::trans(H.col(col));
- }
-
- if(kw != 0) for(size_t i = 0; i < n; i++)
- {
- deltaW.row(i) -= kw * W.row(i);
+ for(size_t j = 0;j < m;j++)
+ {
+ double val;
+ if((val = V(i, j)) != 0)
+ deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
+ arma::trans(H.col(j));
+ }
+ if(kw != 0) deltaW.row(i) -= kw * W.row(i);
}
mW += u * deltaW;
@@ -87,6 +87,7 @@ class SVDBatchLearning
const arma::mat& W,
arma::mat& H)
{
+ size_t n = V.n_rows;
size_t m = V.n_cols;
size_t r = W.n_cols;
@@ -96,17 +97,16 @@ class SVDBatchLearning
arma::mat deltaH(r, m);
deltaH.zeros();
- for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
+ for(size_t j = 0;j < m;j++)
{
- size_t row = it.row();
- size_t col = it.col();
- deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
- arma::trans(W.row(row));
- }
-
- if(kh != 0) for(size_t j = 0; j < m; j++)
- {
- deltaH.col(j) -= kh * H.col(j);
+ for(size_t i = 0;i < n;i++)
+ {
+ double val;
+ if((val = V(i, j)) != 0)
+ deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
+ arma::trans(W.row(i));
+ }
+ if(kh != 0) deltaH.col(j) -= kh * H.col(j);
}
mH += u*deltaH;
@@ -114,13 +114,6 @@ class SVDBatchLearning
}
private:
- double Predict(const arma::mat& wi, const arma::mat& hj) const
- {
- arma::mat temp = (wi * hj);
- double out = temp(0,0);
- return out;
- }
-
double u;
double kw;
double kh;
@@ -131,6 +124,67 @@ class SVDBatchLearning
arma::mat mW;
arma::mat mH;
};
+
+template<>
+inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
+ arma::mat& W,
+ const arma::mat& H)
+{
+ size_t n = V.n_rows;
+
+ size_t r = W.n_cols;
+
+ mW = momentum * mW;
+
+ arma::mat deltaW(n, r);
+ deltaW.zeros();
+
+ for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
+ {
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(H.col(col));
+ }
+
+ if(kw != 0) for(size_t i = 0; i < n; i++)
+ {
+ deltaW.row(i) -= kw * W.row(i);
+ }
+
+ mW += u * deltaW;
+ W += mW;
+}
+
+template<>
+inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
+ const arma::mat& W,
+ arma::mat& H)
+{
+ size_t m = V.n_cols;
+
+ size_t r = W.n_cols;
+
+ mH = momentum * mH;
+
+ arma::mat deltaH(r, m);
+ deltaH.zeros();
+
+ for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
+ {
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(W.row(row));
+ }
+
+ if(kh != 0) for(size_t j = 0; j < m; j++)
+ {
+ deltaH.col(j) -= kh * H.col(j);
+ }
+
+ mH += u*deltaH;
+ H += mH;
+}
+
} // namespace amf
} // namespace mlpack
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 41c7867..3be664e 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -50,6 +50,7 @@ add_executable(mlpack_test
tree_test.cpp
tree_traits_test.cpp
union_find_test.cpp
+ svd_test.cpp
)
# Link dependencies of test executable.
target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/svd_test.cpp b/src/mlpack/tests/svd_test.cpp
new file mode 100644
index 0000000..ab2e9d3
--- /dev/null
+++ b/src/mlpack/tests/svd_test.cpp
@@ -0,0 +1,156 @@
+#include <mlpack/core.hpp>
+#include <mlpack/methods/amf/amf.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batchlearning.hpp>
+#include <mlpack/methods/amf/init_rules/random_init.hpp>
+#include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
+#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(SVDBatchTest);
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::amf;
+using namespace arma;
+
+/**
+ * Make sure the momentum is working okay.
+ */
+BOOST_AUTO_TEST_CASE(SVDMomentumTest)
+{
+ mat dataset;
+ data::Load("GroupLens100k.csv", dataset);
+
+ // Generate list of locations for batch insert constructor for sparse
+ // matrices.
+ arma::umat locations(2, dataset.n_cols);
+ arma::vec values(dataset.n_cols);
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
+ // We have to transpose it because items are rows, and users are columns.
+ locations(0, i) = ((arma::uword) dataset(0, i));
+ locations(1, i) = ((arma::uword) dataset(1, i));
+ values(i) = dataset(2, i);
+ }
+
+ // Find maximum user and item IDs.
+ const size_t maxUserID = (size_t) max(locations.row(0)) + 1;
+ const size_t maxItemID = (size_t) max(locations.row(1)) + 1;
+
+ // Fill sparse matrix.
+ sp_mat cleanedData = arma::sp_mat(locations, values, maxUserID, maxItemID);
+
+ math::RandomSeed(10);
+ ValidationRMSETermination<sp_mat> vrt(cleanedData, 2000);
+ AMF<ValidationRMSETermination<sp_mat>,
+ RandomInitialization,
+ SVDBatchLearning> amf_1(vrt,
+ RandomInitialization(),
+ SVDBatchLearning(0.0009, 0, 0, 0));
+
+ mat m1,m2;
+ size_t RMSE_1 = amf_1.Apply(cleanedData, 2, m1, m2);
+ size_t iter_1 = amf_1.TPolicy().Iteration();
+
+ math::RandomSeed(10);
+ AMF<ValidationRMSETermination<sp_mat>,
+ RandomInitialization,
+ SVDBatchLearning> amf_2(vrt,
+ RandomInitialization(),
+ SVDBatchLearning(0.0009, 0, 0, 0.8));
+
+ size_t RMSE_2 = amf_2.Apply(cleanedData, 2, m1, m2);
+ size_t iter_2 = amf_2.TPolicy().Iteration();
+
+ BOOST_REQUIRE_LE(RMSE_2, RMSE_1);
+ BOOST_REQUIRE_LE(iter_2, iter_1);
+}
+
+/**
+ * Make sure the regularization is working okay.
+ */
+BOOST_AUTO_TEST_CASE(SVDRegularizationTest)
+{
+ mat dataset;
+ data::Load("GroupLens100k.csv", dataset);
+
+ // Generate list of locations for batch insert constructor for sparse
+ // matrices.
+ arma::umat locations(2, dataset.n_cols);
+ arma::vec values(dataset.n_cols);
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
+ // We have to transpose it because items are rows, and users are columns.
+ locations(0, i) = ((arma::uword) dataset(0, i));
+ locations(1, i) = ((arma::uword) dataset(1, i));
+ values(i) = dataset(2, i);
+ }
+
+ // Find maximum user and item IDs.
+ const size_t maxUserID = (size_t) max(locations.row(0)) + 1;
+ const size_t maxItemID = (size_t) max(locations.row(1)) + 1;
+
+ // Fill sparse matrix.
+ sp_mat cleanedData = arma::sp_mat(locations, values, maxUserID, maxItemID);
+
+ math::RandomSeed(10);
+ ValidationRMSETermination<sp_mat> vrt(cleanedData, 2000);
+ AMF<ValidationRMSETermination<sp_mat>,
+ RandomInitialization,
+ SVDBatchLearning> amf_1(vrt,
+ RandomInitialization(),
+ SVDBatchLearning(0.0009, 0, 0, 0));
+
+ mat m1,m2;
+ size_t RMSE_1 = amf_1.Apply(cleanedData, 2, m1, m2);
+
+ math::RandomSeed(10);
+ AMF<ValidationRMSETermination<sp_mat>,
+ RandomInitialization,
+ SVDBatchLearning> amf_2(vrt,
+ RandomInitialization(),
+ SVDBatchLearning(0.0009, 0.5, 0.5, 0.8));
+
+ size_t RMSE_2 = amf_2.Apply(cleanedData, 2, m1, m2);
+
+ BOOST_REQUIRE_LE(RMSE_2, RMSE_1);
+}
+
+/**
+ * Make sure the SVD can factorize matrices with negative entries.
+ */
+BOOST_AUTO_TEST_CASE(SVDNegativeElementTest)
+{
+ mat test;
+ test.zeros(3,3);
+ test(0, 0) = 1;
+ test(0, 1) = -2;
+ test(0, 2) = 3;
+ test(1, 0) = 2;
+ test(1, 1) = -1;
+ test(1, 2) = 1;
+ test(2, 0) = 2;
+ test(2, 1) = 2;
+ test(2, 2) = 2;
+
+ AMF<SimpleToleranceTermination<mat>,
+ RandomInitialization,
+ SVDBatchLearning> amf(SimpleToleranceTermination<mat>(),
+ RandomInitialization(),
+ SVDBatchLearning(0.3, 0.001, 0.001, 0));
+ mat m1, m2;
+ amf.Apply(test, 2, m1, m2);
+
+ arma::mat result = m1 * m2;
+ for(size_t i = 0;i < 3;i++)
+ {
+ for(size_t j = 0;j < 3;j++)
+ {
+ BOOST_REQUIRE_LE(abs(test(i,j) - result(i,j)), 0.5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list