[mlpack-svn] r17387 - mlpack/trunk/src/mlpack/methods/amf/termination_policies
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 19 12:08:48 EST 2014
Author: rcurtin
Date: Wed Nov 19 12:08:48 2014
New Revision: 17387
Log:
Refactor for cleaner code and avoid storing WH explicitly if possible.
Modified:
mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp Wed Nov 19 12:08:48 2014
@@ -13,21 +13,30 @@
namespace amf {
/**
- * This class implements simple residue based termination policy. Termination
- * decision depends on two factors, value of residue and number of iteration.
- * If the current value of residue drops below the threshold or the number of
- * iterations goes above the threshold, positive termination signal is passed
- * to AMF.
+ * This class implements a simple residue-based termination policy. The
+ * termination decision depends on two factors: the value of the residue (the
+ * difference between the norm of WH this iteration and the previous iteration),
+ * and the number of iterations. If the current value of residue drops below
+ * the threshold or the number of iterations goes above the iteration limit,
+ * IsConverged() will return true. This class is meant for use with the AMF
+ * (alternating matrix factorization) class.
*
* @see AMF
*/
class SimpleResidueTermination
{
public:
- //! empty constructor
+ /**
+ * Construct the SimpleResidueTermination object with the given minimum
+ * residue (or the default) and the given maximum number of iterations (or the
+ * default). 0 indicates no iteration limit.
+ *
+ * @param minResidue Minimum residue for termination.
+ * @param maxIterations Maximum number of iterations.
+ */
SimpleResidueTermination(const double minResidue = 1e-10,
const size_t maxIterations = 10000)
- : minResidue(minResidue), maxIterations(maxIterations) { }
+ : minResidue(minResidue), maxIterations(maxIterations) { }
/**
* Initializes the termination policy before stating the factorization.
@@ -37,57 +46,46 @@
template<typename MatType>
void Initialize(const MatType& V)
{
- // set resisue to minimum value
- residue = minResidue;
- // set iteration to minimum value
+ // Initialize the things we keep track of.
+ residue = DBL_MAX;
iteration = 1;
- // remove history
+ nm = V.n_rows * V.n_cols;
+ // Remove history.
normOld = 0;
-
- // initialize required variables
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
- nm = n * m;
}
/**
- * Check if termination criterio is met.
+ * Check if termination criterion is met.
*
* @param W Basis matrix of output.
* @param H Encoding matrix of output.
*/
bool IsConverged(arma::mat& W, arma::mat& H)
{
- // Calculate norm of WH after each iteration.
- arma::mat WH;
-
- // calculate the norm and compute the residue
- WH = W * H;
- double norm = sqrt(accu(WH % WH) / nm);
- residue = fabs(normOld - norm);
- residue /= normOld;
+ // Calculate the norm and compute the residue
+ const double norm = arma::norm(W * H, "fro");
+ residue = fabs(normOld - norm) / normOld;
- // store the residue into history
+ // Store the norm.
normOld = norm;
-
- // increment iteration count
+
+ // Increment iteration count
iteration++;
-
- // check if termination criterion is met
- if(residue < minResidue || iteration > maxIterations) return true;
- else return false;
+
+ // Check if termination criterion is met.
+ return (residue < minResidue || iteration > maxIterations);
}
//! Get current value of residue
const double& Index() const { return residue; }
- //! Get current iteration count
+ //! Get current iteration count
const size_t& Iteration() const { return iteration; }
-
+
//! Access max iteration count
const size_t& MaxIterations() const { return maxIterations; }
size_t& MaxIterations() { return maxIterations; }
-
+
//! Access minimum residue value
const double& MinResidue() const { return minResidue; }
double& MinResidue() { return minResidue; }
More information about the mlpack-svn mailing list