[mlpack-svn] r16747 - in mlpack/trunk/src/mlpack/methods/amf: . termination_policies update_rules
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 2 17:07:43 EDT 2014
Author: sumedhghaisas
Date: Wed Jul 2 17:07:42 2014
New Revision: 16747
Log:
* faster implementation of SVDBatchWithMomentum
* tolerance termination policy is modified according to new policy
* test point selection in validation RMSE termination is shifted to constructor
Added:
mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
Modified:
mlpack/trunk/src/mlpack/methods/amf/amf.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
Modified: mlpack/trunk/src/mlpack/methods/amf/amf.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf.hpp Wed Jul 2 17:07:42 2014
@@ -82,9 +82,9 @@
*/
template<typename MatType>
double Apply(const MatType& V,
- const size_t r,
- arma::mat& W,
- arma::mat& H);
+ const size_t r,
+ arma::mat& W,
+ arma::mat& H);
private:
//! termination policy
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt Wed Jul 2 17:07:42 2014
@@ -3,6 +3,7 @@
set(SOURCES
simple_residue_termination.hpp
simple_tolerance_termination.hpp
+ validation_rmse_termination.hpp
)
# Add directory name to sources.
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp Wed Jul 2 17:07:42 2014
@@ -18,7 +18,7 @@
: minResidue(minResidue), maxIterations(maxIterations) { }
template<typename MatType>
- void Initialize(MatType& V)
+ void Initialize(const MatType& V)
{
residue = minResidue;
iteration = 1;
@@ -36,8 +36,7 @@
else return false;
}
- template<typename MatType>
- void Step(const MatType& W, const MatType& H)
+ void Step(const arma::mat& W, const arma::mat& H)
{
// Calculate norm of WH after each iteration.
arma::mat WH;
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp Wed Jul 2 17:07:42 2014
@@ -10,6 +10,7 @@
namespace mlpack {
namespace amf {
+template <class MatType>
class SimpleToleranceTermination
{
public:
@@ -17,44 +18,50 @@
const size_t maxIterations = 10000)
: tolerance(tolerance), maxIterations(maxIterations) {}
- template<typename MatType>
- void Initialize(MatType& V)
+ void Initialize(const MatType& V)
{
residueOld = DBL_MAX;
iteration = 1;
- normOld = 0;
residue = DBL_MIN;
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
-
- nm = n * m;
+ this->V = &V;
}
bool IsConverged()
{
- if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
+ if(((residueOld - residue) / residueOld < tolerance && iteration > 4)
|| iteration > maxIterations) return true;
else return false;
}
- template<typename MatType>
- void Step(const MatType& W, const MatType& H)
+ void Step(const arma::mat& W, const arma::mat& H)
{
// Calculate norm of WH after each iteration.
arma::mat WH;
WH = W * H;
- double norm = sqrt(accu(WH % WH) / nm);
- if (iteration != 0)
+ residueOld = residue;
+ size_t n = V->n_rows;
+ size_t m = V->n_cols;
+ double sum = 0;
+ size_t count = 0;
+ for(size_t i = 0;i < n;i++)
{
- residueOld = residue;
- residue = fabs(normOld - norm);
- residue /= normOld;
+ for(size_t j = 0;j < m;j++)
+ {
+ double temp = 0;
+ if((temp = (*V)(i,j)) != 0)
+ {
+ temp = (temp - WH(i, j));
+ temp = temp * temp;
+ sum += temp;
+ count++;
+ }
+ }
}
-
- normOld = norm;
+ residue = sum / count;
+ residue = sqrt(residue);
iteration++;
}
@@ -66,15 +73,16 @@
double tolerance;
size_t maxIterations;
+ const MatType* V;
+
size_t iteration;
double residueOld;
double residue;
double normOld;
-
- size_t nm;
}; // class SimpleToleranceTermination
}; // namespace amf
}; // namespace mlpack
#endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
+
Added: mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp Wed Jul 2 17:07:42 2014
@@ -0,0 +1,119 @@
+#ifndef VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+#define VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack
+{
+namespace amf
+{
+template <class MatType>
+class ValidationRMSETermination
+{
+ public:
+ ValidationRMSETermination(MatType& V,
+ size_t num_test_points,
+ double tolerance = 1e-5,
+ size_t maxIterations = 10000)
+ : tolerance(tolerance),
+ maxIterations(maxIterations),
+ num_test_points(num_test_points)
+ {
+ size_t n = V.n_rows;
+ size_t m = V.n_cols;
+
+ test_points.zeros(num_test_points, 3);
+
+ for(size_t i = 0; i < num_test_points; i++)
+ {
+ double t_val;
+ size_t t_row;
+ size_t t_col;
+ do
+ {
+ t_row = rand() % n;
+ t_col = rand() % m;
+ } while((t_val = V(t_row, t_col)) == 0);
+
+ test_points(i, 0) = t_row;
+ test_points(i, 1) = t_col;
+ test_points(i, 2) = t_val;
+ V(t_row, t_col) = 0;
+ }
+ }
+
+ void Initialize(const MatType& V)
+ {
+ iteration = 1;
+
+ rmse = DBL_MAX;
+ rmseOld = DBL_MAX;
+ t_count = 0;
+ }
+
+ bool IsConverged()
+ {
+ if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4) t_count++;
+ else t_count = 0;
+
+ if(t_count == 3 || iteration > maxIterations) return true;
+ else return false;
+ }
+
+ void Step(const arma::mat& W, const arma::mat& H)
+ {
+ // Calculate norm of WH after each iteration.
+ arma::mat WH;
+
+ WH = W * H;
+
+ if (iteration != 0)
+ {
+ rmseOld = rmse;
+ rmse = 0;
+ for(size_t i = 0; i < num_test_points; i++)
+ {
+ size_t t_row = test_points(i, 0);
+ size_t t_col = test_points(i, 1);
+ double t_val = test_points(i, 2);
+ double temp = (t_val - WH(t_row, t_col));
+ temp *= temp;
+ rmse += temp;
+ }
+ rmse /= num_test_points;
+ rmse = sqrt(rmse);
+ }
+
+ iteration++;
+ }
+
+ const double& Index()
+ {
+ return rmse;
+ }
+
+ const size_t& Iteration()
+ {
+ return iteration;
+ }
+
+ private:
+ double tolerance;
+ size_t maxIterations;
+ size_t num_test_points;
+ size_t iteration;
+
+ arma::Mat<double> test_points;
+
+ double rmseOld;
+ double rmse;
+
+ size_t t_count;
+};
+
+} // namespace amf
+} // namespace mlpack
+
+
+#endif // VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
+
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp Wed Jul 2 17:07:42 2014
@@ -1,3 +1,7 @@
+/**
+ * @file svd_batchlearning.hpp
+ * @author Sumedh Ghaisas
+ */
#ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
#define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
@@ -9,113 +13,123 @@
{
class SVDBatchLearning
{
-public:
- SVDBatchLearning(double u = 0.000001,
- double kw = 0,
- double kh = 0,
- double momentum = 0.2,
- double min = -DBL_MIN,
- double max = DBL_MAX)
+ public:
+ SVDBatchLearning(double u = 0.0002,
+ double kw = 0,
+ double kh = 0,
+ double momentum = 0.5,
+ double min = -DBL_MIN,
+ double max = DBL_MAX)
: u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
{}
- template<typename MatType>
- void Initialize(const MatType& dataset, const size_t rank)
- {
- const size_t n = dataset.n_rows;
- const size_t m = dataset.n_cols;
-
- mW.zeros(n, rank);
- mH.zeros(rank, m);
- }
-
- /**
- * The update rule for the basis matrix W.
- * The function takes in all the matrices and only changes the
- * value of the W matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be updated.
- * @param H Encoding matrix.
- */
- template<typename MatType>
- inline void WUpdate(const MatType& V,
- arma::mat& W,
- const arma::mat& H)
- {
- size_t n = V.n_rows;
- size_t m = V.n_cols;
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ const size_t n = dataset.n_rows;
+ const size_t m = dataset.n_cols;
+
+ mW.zeros(n, rank);
+ mH.zeros(rank, m);
+ }
+
+ /**
+ * The update rule for the basis matrix W.
+ * The function takes in all the matrices and only changes the
+ * value of the W matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix to be updated.
+ * @param H Encoding matrix.
+ */
+ template<typename MatType>
+ inline void WUpdate(const MatType& V,
+ arma::mat& W,
+ const arma::mat& H)
+ {
+ size_t n = V.n_rows;
- size_t r = W.n_cols;
+ size_t r = W.n_cols;
- mW = momentum * mW;
+ mW = momentum * mW;
- arma::mat deltaW(n, r);
- deltaW.zeros();
+ arma::mat deltaW(n, r);
+ deltaW.zeros();
- for(size_t i = 0; i < n; i++)
- {
- for(size_t j = 0; j < m; j++)
- if(V(i,j) != 0) deltaW.row(i) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(H.col(j));
- deltaW.row(i) -= kw * W.row(i);
- }
-
- mW += u * deltaW;
- W += mW;
+ for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
+ {
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(H.col(col));
}
- /**
- * The update rule for the encoding matrix H.
- * The function takes in all the matrices and only changes the
- * value of the H matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix.
- * @param H Encoding matrix to be updated.
- */
- template<typename MatType>
- inline void HUpdate(const MatType& V,
- const arma::mat& W,
- arma::mat& H)
+ if(kw != 0) for(size_t i = 0; i < n; i++)
{
- size_t n = V.n_rows;
- size_t m = V.n_cols;
+ deltaW.row(i) -= kw * W.row(i);
+ }
- size_t r = W.n_cols;
+ mW += u * deltaW;
+ W += mW;
+ }
+
+ /**
+ * The update rule for the encoding matrix H.
+ * The function takes in all the matrices and only changes the
+ * value of the H matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix.
+ * @param H Encoding matrix to be updated.
+ */
+ template<typename MatType>
+ inline void HUpdate(const MatType& V,
+ const arma::mat& W,
+ arma::mat& H)
+ {
+ size_t m = V.n_cols;
- mH = momentum * mH;
+ size_t r = W.n_cols;
- arma::mat deltaH(r, m);
- deltaH.zeros();
+ mH = momentum * mH;
- for(size_t j = 0; j < m; j++)
- {
- for(size_t i = 0; i < n; i++)
- if(V(i,j) != 0) deltaH.col(j) += (V(i,j) - Predict(W.row(i), H.col(j))) * arma::trans(W.row(i));
- deltaH.col(j) -= kh * H.col(j);
- }
+ arma::mat deltaH(r, m);
+ deltaH.zeros();
- mH += u*deltaH;
- H += mH;
+ for(typename MatType::const_iterator it = V.begin();it != V.end();it++)
+ {
+ size_t row = it.row();
+ size_t col = it.col();
+ deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(W.row(row));
}
-private:
- double Predict(const arma::mat& wi, const arma::mat& hj) const
+ if(kh != 0) for(size_t j = 0; j < m; j++)
{
- arma::mat temp = (wi * hj);
- double out = temp(0,0);
- return out;
+ deltaH.col(j) -= kh * H.col(j);
}
- double u;
- double kw;
- double kh;
- double min;
- double max;
- double momentum;
+ mH += u*deltaH;
+ H += mH;
+ }
+
+ private:
+ double Predict(const arma::mat& wi, const arma::mat& hj) const
+ {
+ arma::mat temp = (wi * hj);
+ double out = temp(0,0);
+ return out;
+ }
+
+ double u;
+ double kw;
+ double kh;
+ double min;
+ double max;
+ double momentum;
- arma::mat mW;
- arma::mat mH;
+ arma::mat mW;
+ arma::mat mH;
};
} // namespace amf
} // namespace mlpack
@@ -123,3 +137,4 @@
#endif
+
More information about the mlpack-svn
mailing list