[mlpack-git] mlpack-1.0.x: Minor changes to get things to compile. Looks like I didn't do a perfect job of merging... (49d155c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:07:35 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 49d155cb6835d953218127cd94a20a2970ae19e4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Dec 8 03:47:36 2014 +0000

    Minor changes to get things to compile. Looks like I didn't do a perfect job of merging...


>---------------------------------------------------------------

49d155cb6835d953218127cd94a20a2970ae19e4
 src/mlpack/core.hpp                                |   2 +
 src/mlpack/methods/amf/init_rules/CMakeLists.txt   |   1 +
 .../simple_residue_termination.hpp                 |  26 ++--
 .../simple_tolerance_termination.hpp               |  43 ++++--
 .../methods/decision_stump/decision_stump.hpp      |  39 +++---
 .../methods/decision_stump/decision_stump_impl.hpp | 144 ++++++++++++++-------
 src/mlpack/tests/allknn_test.cpp                   |   1 -
 src/mlpack/tests/gmm_test.cpp                      |   8 +-
 src/mlpack/tests/svd_batch_test.cpp                |  35 ++---
 9 files changed, 195 insertions(+), 104 deletions(-)

diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp
index 08468b4..4800f8f 100644
--- a/src/mlpack/core.hpp
+++ b/src/mlpack/core.hpp
@@ -209,3 +209,5 @@
     #undef max
   #endif
 #endif
+
+#endif
diff --git a/src/mlpack/methods/amf/init_rules/CMakeLists.txt b/src/mlpack/methods/amf/init_rules/CMakeLists.txt
index a31d281..20ef23f 100644
--- a/src/mlpack/methods/amf/init_rules/CMakeLists.txt
+++ b/src/mlpack/methods/amf/init_rules/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
+  average_init.hpp
   random_init.hpp
   random_acol_init.hpp
 )
diff --git a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
index 0e15154..40c6671 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_residue_termination.hpp
@@ -2,6 +2,8 @@
  * @file simple_residue_termination.hpp
  * @author Sumedh Ghaisas
  *
+ * Termination policy used in AMF (Alternating Matrix Factorization).
+ *
  * This file is part of MLPACK 1.0.11.
  *
  * MLPACK is free software: you can redistribute it and/or modify it under the
@@ -51,18 +53,20 @@ class SimpleResidueTermination
                            const size_t maxIterations = 10000)
       : minResidue(minResidue), maxIterations(maxIterations) { }
 
+  /**
+   * Initializes the termination policy before stating the factorization.
+   *
+   * @param V Input matrix being factorized.
+   */
   template<typename MatType>
   void Initialize(const MatType& V)
   {
     // Initialize the things we keep track of.
     residue = DBL_MAX;
     iteration = 1;
+    nm = V.n_rows * V.n_cols;
+    // Remove history.
     normOld = 0;
-
-    const size_t n = V.n_rows;
-    const size_t m = V.n_cols;
-
-    nm = n * m;
   }
 
   /**
@@ -87,9 +91,8 @@ class SimpleResidueTermination
     return (residue < minResidue || iteration > maxIterations);
   }
 
-  const double& Index() { return residue; }
-  const size_t& Iteration() { return iteration; }
-  const size_t& MaxIterations() { return maxIterations; }
+  //! Get current value of residue
+  const double& Index() const { return residue; }
 
   //! Get current iteration count
   const size_t& Iteration() const { return iteration; }
@@ -102,12 +105,17 @@ class SimpleResidueTermination
   const double& MinResidue() const { return minResidue; }
   double& MinResidue() { return minResidue; }
 
- public:
+public:
+  //! residue threshold
   double minResidue;
+  //! iteration threshold
   size_t maxIterations;
 
+  //! current value of residue
   double residue;
+  //! current iteration count
   size_t iteration;
+  //! norm of previous iteration
   double normOld;
 
   size_t nm;
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 248ef82..94be641 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -2,6 +2,8 @@
  * @file simple_tolerance_termination.hpp
  * @author Sumedh Ghaisas
  *
+ * Termination policy used in AMF (Alternating Matrix Factorization).
+ *
  * This file is part of MLPACK 1.0.11.
  *
  * MLPACK is free software: you can redistribute it and/or modify it under the
@@ -39,6 +41,7 @@ template <class MatType>
 class SimpleToleranceTermination
 {
  public:
+  //! empty constructor
   SimpleToleranceTermination(const double tolerance = 1e-5,
                              const size_t maxIterations = 10000,
                              const size_t reverseStepTolerance = 3)
@@ -46,6 +49,11 @@ class SimpleToleranceTermination
               maxIterations(maxIterations),
               reverseStepTolerance(reverseStepTolerance) {}
 
+  /**
+   * Initializes the termination policy before stating the factorization.
+   *
+   * @param V Input matrix to be factorized.
+   */
   void Initialize(const MatType& V)
   {
     residueOld = DBL_MAX;
@@ -62,13 +70,19 @@ class SimpleToleranceTermination
     reverseStepCount = 0;
   }
 
+  /**
+   * Check if termination criterio is met.
+   *
+   * @param W Basis matrix of output.
+   * @param H Encoding matrix of output.
+   */
   bool IsConverged(arma::mat& W, arma::mat& H)
   {
-    // Calculate norm of WH after each iteration.
     arma::mat WH;
 
     WH = W * H;
 
+    // compute residue
     residueOld = residue;
     size_t n = V->n_rows;
     size_t m = V->n_cols;
@@ -91,37 +105,43 @@ class SimpleToleranceTermination
     residue = sum / count;
     residue = sqrt(residue);
 
-<<<<<<< .working
-    iteration++;
-
-=======
+    // increment iteration count
     iteration++;
 
     // if residue tolerance is not satisfied
->>>>>>> .merge-right.r17287
     if ((residueOld - residue) / residueOld < tolerance && iteration > 4)
     {
+      // check if this is a first of successive drops
       if (reverseStepCount == 0 && isCopy == false)
       {
+        // store a copy of W and H matrix
         isCopy = true;
         this->W = W;
         this->H = H;
+        // store residue values
         c_index = residue;
         c_indexOld = residueOld;
       }
+      // increase successive drop count
       reverseStepCount++;
     }
+    // if tolerance is satisfied
     else
     {
+      // initialize successive drop count
       reverseStepCount = 0;
+      // if residue is droped below minimum scrap stored values
       if(residue <= c_indexOld && isCopy == true)
       {
         isCopy = false;
       }
     }
 
+    // check if termination criterion is met
     if(reverseStepCount == reverseStepTolerance || iteration > maxIterations)
     {
+      // if stored values are present replace them with current value as they
+      // represent the minimum residue point
       if(isCopy)
       {
         W = this->W;
@@ -133,9 +153,8 @@ class SimpleToleranceTermination
     else return false;
   }
 
-  const double& Index() { return residue; }
-  const size_t& Iteration() { return iteration; }
-  const size_t& MaxIterations() { return maxIterations; }
+  //! Get current value of residue
+  const double& Index() const { return residue; }
 
   //! Get current iteration count
   const size_t& Iteration() const { return iteration; }
@@ -151,10 +170,13 @@ class SimpleToleranceTermination
  private:
   //! tolerance
   double tolerance;
+  //! iteration threshold
   size_t maxIterations;
 
+  //! pointer to matrix being factorized
   const MatType* V;
 
+  //! current iteration count
   size_t iteration;
 
   //! residue values
@@ -162,10 +184,13 @@ class SimpleToleranceTermination
   double residue;
   double normOld;
 
+  //! tolerance on successive residue drops
   size_t reverseStepTolerance;
   //! successive residue drops
   size_t reverseStepCount;
 
+  //! indicates whether a copy of information is available which corresponds to
+  //! minimum residue point
   bool isCopy;
 
   //! variables to store information of minimum residue poi
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 75f0e9a..bb3e87e 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -69,22 +69,21 @@ class DecisionStump
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
   /**
-   *
-   *
-   *
-   *
+   * Alternate constructor which copies parameters bucketSize and numClass from
+   * an already initiated decision stump, other. It appropriately sets the
+   * weight vector.
+   *
+   * @param other The other initiated Decision Stump object from
+   *      which we copy the values.
+   * @param data The data on which to train this object on.
+   * @param D Weight vector to use while training. For boosting purposes.
+   * @param labels The labels of data.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
-  DecisionStump(const DecisionStump<>& ds);
-
-  /**
-   *
-   *
-   *
-   *
-   *
-   *
-  ModifyData(MatType& data, const arma::Row<double>& D);
-  */
+  DecisionStump(const DecisionStump<>& other,
+                const MatType& data,
+                const arma::rowvec& weights,
+                const arma::Row<size_t>& labels);
 
   //! Access the splitting attribute.
   int SplitAttribute() const { return splitAttribute; }
@@ -123,9 +122,12 @@ class DecisionStump
    *
    * @param attribute A row from the training data, which might be a
    *     candidate for the splitting attribute.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
+  template <bool isWeight>
   double SetupSplitAttribute(const arma::rowvec& attribute,
-                             const arma::Row<size_t>& labels);
+                             const arma::Row<size_t>& labels,
+                             const arma::rowvec& weightD);
 
   /**
    * After having decided the attribute on which to split, train on that
@@ -149,7 +151,8 @@ class DecisionStump
    * @param subCols The vector in which to find the most frequently
    *     occurring element.
    */
-  template <typename rType> rType CountMostFreq(const arma::Row<rType>& subCols);
+  template <typename rType> rType CountMostFreq(const arma::Row<rType>&
+      subCols);
 
   /**
    * Returns 1 if all the values of featureRow are not same.
@@ -163,6 +166,7 @@ class DecisionStump
    *
    * @param attribute The attribute of which we calculate the entropy.
    * @param labels Corresponding labels of the attribute.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
   template <typename LabelType, bool isWeight>
   double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
@@ -178,6 +182,7 @@ class DecisionStump
   template <bool isWeight>
   void Train(const MatType& data, const arma::Row<size_t>& labels,
              const arma::rowvec& weightD);
+
 };
 
 }; // namespace decision_stump
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 6071937..f8190b5 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -26,9 +26,6 @@
 // In case it hasn't been included yet.
 #include "decision_stump.hpp"
 
-#include <set>
-#include <algorithm>
-
 namespace mlpack {
 namespace decision_stump {
 
@@ -49,11 +46,28 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
   numClass = classes;
   bucketSize = inpBucketSize;
 
+  arma::rowvec weightD;
+
+  Train<false>(data, labels, weightD);
+}
+
+/**
+ * Train the decision stump on the given data and labels.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for dataset.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
+ */
+template<typename MatType>
+template <bool isWeight>
+void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels,
+                                    const arma::rowvec& weightD)
+{
   // If classLabels are not all identical, proceed with training.
-  size_t bestAtt = 0;
+  int bestAtt = 0;
   double entropy;
-  const double rootEntropy = CalculateEntropy<size_t>(
-      labels.subvec(0, labels.n_elem - 1));
+  const double rootEntropy = CalculateEntropy<size_t, isWeight>(
+      labels.subvec(0, labels.n_elem - 1), 0, weightD);
 
   double gain, bestGain = 0.0;
   for (size_t i = 0; i < data.n_rows; i++)
@@ -63,9 +77,8 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
     {
       // For each attribute with non-identical values, treat it as a potential
       // splitting attribute and calculate entropy if split on it.
-      entropy = SetupSplitAttribute(data.row(i), labels);
+      entropy = SetupSplitAttribute<isWeight>(data.row(i), labels, weightD);
 
-      // Log::Debug << "Entropy for attribute " << i << " is " << entropy << ".\n";
       gain = rootEntropy - entropy;
       // Find the attribute with the best entropy so that the gain is
       // maximized.
@@ -119,48 +132,48 @@ void DecisionStump<MatType>::Classify(const MatType& test,
 }
 
 /**
+ * Alternate constructor which copies parameters bucketSize and numClass
+ * from an already initiated decision stump, other. It appropriately
+ * sets the Weight vector.
  *
- *
- *
- *
- *
+ * @param other The other initiated Decision Stump object from
+ *      which we copy the values from.
+ * @param data The data on which to train this object on.
+ * @param D Weight vector to use while training. For boosting purposes.
+ * @param labels The labels of data.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template <typename MatType>
-DecisionStump<MatType>::DecisionStump(const DecisionStump<>& ds)
+DecisionStump<MatType>::DecisionStump(
+                        const DecisionStump<>& other,
+                        const MatType& data,
+                        const arma::rowvec& weights,
+                        const arma::Row<size_t>& labels
+                        )
 {
-  numClass = ds.numClass;
-
-  splitAttribute = ds.splitAttribute;
-
-  bucketSize = ds.bucketSize;
+  numClass = other.numClass;
+  bucketSize = other.bucketSize;
 
-  split = ds.split;
+  // weightD = weights;
+  // tempD = weightD;
 
-  binLabels = ds.binLabels;
+  Train<true>(data, labels, weights);
 }
 
 /**
- *
- *
- *
- *
- *
- *
-template <typename MatType>
-DecisionStump<MatType>::ModifyData(MatType& data, const arma::Row<double>& D)
- */
-
-/**
  * Sets up attribute as if it were splitting on it and finds entropy when
  * splitting on attribute.
  *
  * @param attribute A row from the training data, which might be a candidate for
  *      the splitting attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template <typename MatType>
+template <bool isWeight>
 double DecisionStump<MatType>::SetupSplitAttribute(
     const arma::rowvec& attribute,
-    const arma::Row<size_t>& labels)
+    const arma::Row<size_t>& labels,
+    const arma::rowvec& weightD)
 {
   size_t i, count, begin, end;
   double entropy = 0.0;
@@ -175,9 +188,16 @@ double DecisionStump<MatType>::SetupSplitAttribute(
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
 
+  arma::rowvec tempD = arma::rowvec(weightD.n_cols);
+
   for (i = 0; i < attribute.n_elem; i++)
+  {
     sortedLabels(i) = labels(sortedIndexAtt(i));
 
+    if(isWeight)
+      tempD(i) = weightD(sortedIndexAtt(i));
+  }
+
   i = 0;
   count = 0;
 
@@ -196,8 +216,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       // Use ratioEl to calculate the ratio of elements in this split.
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t>(
-          sortedLabels.subvec(begin, end));
+      entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
+          sortedLabels.subvec(begin, end), begin, tempD);
       i++;
     }
     else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -223,8 +243,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       }
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t>(
-          sortedLabels.subvec(begin, end));
+      entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
+          sortedLabels.subvec(begin, end), begin, tempD);
 
       i = end + 1;
       count = 0;
@@ -368,7 +388,6 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
   rType mostFreq = it->first;
   size_t mostFreqCount = it->second;
   while (it != countMap.end())
->>>>>>> .merge-right.r17318
   {
     if (it->second >= mostFreqCount)
     {
@@ -403,10 +422,13 @@ int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
  *
  * @param attribute The attribute for which we calculate the entropy.
  * @param labels Corresponding labels of the attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template<typename MatType>
-template<typename AttType, typename LabelType>
-double DecisionStump<MatType>::CalculateEntropy(arma::subview_row<LabelType> labels)
+template<typename LabelType, bool isWeight>
+double DecisionStump<MatType>::CalculateEntropy(
+    arma::subview_row<LabelType> labels,
+    int begin, const arma::rowvec& tempD)
 {
   double entropy = 0.0;
   size_t j;
@@ -414,19 +436,43 @@ double DecisionStump<MatType>::CalculateEntropy(arma::subview_row<LabelType> lab
   arma::Row<size_t> numElem(numClass);
   numElem.fill(0);
 
-  // Populate numElem; they are used as helpers to calculate
-  // entropy.
-  for (j = 0; j < labels.n_elem; j++)
-    numElem(labels(j))++;
+  // Variable to accumulate the weight in this subview_row.
+  double accWeight = 0.0;
+  // Populate numElem; they are used as helpers to calculate entropy.
+
+  if (isWeight)
+  {
+    for (j = 0; j < labels.n_elem; j++)
+    {
+      numElem(labels(j)) += tempD(j + begin);
+      accWeight += tempD(j + begin);
+    }
+      // numElem(labels(j))++;
+
+    for (j = 0; j < numClass; j++)
+    {
+      const double p1 = ((double) numElem(j) / accWeight);
 
-  // The equation for entropy uses log2(), but log2() is from C99 and thus
-  // Visual Studio will not have it.  Therefore, we will use std::log(), and
-  // then apply the change-of-base formula at the end of the calculation.
-  for (j = 0; j < numClass; j++)
+      // Instead of using log2(), which is C99 and may not exist on some
+      // compilers, use std::log(), then use the change-of-base formula to make
+      // the result correct.
+      entropy += (p1 == 0) ? 0 : p1 * std::log(p1);
+    }
+  }
+  else
   {
-    const double p1 = ((double) numElem(j) / labels.n_elem);
+    for (j = 0; j < labels.n_elem; j++)
+      numElem(labels(j))++;
 
-    entropy += (p1 == 0) ? 0 : p1 * std::log(p1);
+    for (j = 0; j < numClass; j++)
+    {
+      const double p1 = ((double) numElem(j) / labels.n_elem);
+
+      // Instead of using log2(), which is C99 and may not exist on some
+      // compilers, use std::log(), then use the change-of-base formula to make
+      // the result correct.
+      entropy += (p1 == 0) ? 0 : p1 * std::log(p1);
+    }
   }
 
   return entropy / std::log(2.0);
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index cfe7988..2a178f3 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -1,4 +1,3 @@
-<<<<<<< .working
 /**
  * @file allknn_test.cpp
  *
diff --git a/src/mlpack/tests/gmm_test.cpp b/src/mlpack/tests/gmm_test.cpp
index 7fd1650..8e22360 100644
--- a/src/mlpack/tests/gmm_test.cpp
+++ b/src/mlpack/tests/gmm_test.cpp
@@ -442,7 +442,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
 
   for (size_t row = 0; row < 3; row++)
     for (size_t col = 0; col < 3; col++)
-      BOOST_REQUIRE_SMALL((g.Component(sortedIndices[0]).Covariance()(row, col)
+      BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[0]](row, col)
           - d4.Covariance()(row, col)), 0.7); // Big tolerance!  Lots of noise.
 
   // Second Gaussian (d1).
@@ -453,7 +453,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
 
   for (size_t row = 0; row < 3; row++)
     for (size_t col = 0; col < 3; col++)
-      BOOST_REQUIRE_SMALL((g.Component(sortedIndices[1]).Covariance()(row, col)
+      BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[1]](row, col)
           - d1.Covariance()(row, col)), 0.7); // Big tolerance!  Lots of noise.
 
   // Third Gaussian (d2).
@@ -464,7 +464,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
 
   for (size_t row = 0; row < 3; row++)
     for (size_t col = 0; col < 3; col++)
-      BOOST_REQUIRE_SMALL((g.Component(sortedIndices[2]).Covariance()(row, col)
+      BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[2]](row, col)
           - d2.Covariance()(row, col)), 0.7); // Big tolerance!  Lots of noise.
 
   // Fourth gaussian (d3).
@@ -475,7 +475,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
 
   for (size_t row = 0; row < 3; ++row)
     for (size_t col = 0; col < 3; ++col)
-      BOOST_REQUIRE_SMALL((g.Component(sortedIndices[3]).Covariance()(row, col)
+      BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[3]](row, col)
           - d3.Covariance()(row, col)), 0.7);
 }
 
diff --git a/src/mlpack/tests/svd_batch_test.cpp b/src/mlpack/tests/svd_batch_test.cpp
index 3716aaf..bf39c84 100644
--- a/src/mlpack/tests/svd_batch_test.cpp
+++ b/src/mlpack/tests/svd_batch_test.cpp
@@ -1,6 +1,26 @@
+/**
+ * @file svd_batch_test.cpp
+ * @author Sumedh Ghaisas
+ *
+ * This file is part of MLPACK 1.0.11.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
+ */
 #include <mlpack/core.hpp>
 #include <mlpack/methods/amf/amf.hpp>
 #include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
+#include <mlpack/methods/amf/init_rules/average_init.hpp>
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
 #include <mlpack/methods/amf/termination_policies/validation_RMSE_termination.hpp>
 #include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
@@ -17,21 +37,6 @@ using namespace arma;
 
 /**
  * Make sure the SVD Batch lerning is converging.
- *
- * This file is part of MLPACK 1.0.11.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
  */
 BOOST_AUTO_TEST_CASE(SVDBatchConvergenceElementTest)
 {



More information about the mlpack-git mailing list