[mlpack-git] master: Adds a Train() function that only needs dataset statistics, not the dataset itself (09a90ee)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 08:29:02 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/b12e1517565652f62a551e1df41f2e51569bc9d9...e3ce1600893636f15be55d66c621398673b1faba

>---------------------------------------------------------------

commit 09a90ee81a8ae3b69be94d8f84ff32794af2b7d6
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Tue Jul 26 13:29:02 2016 +0100

    Adds a Train() function that only needs dataset statistics, not the dataset itself


>---------------------------------------------------------------

09a90ee81a8ae3b69be94d8f84ff32794af2b7d6
 src/mlpack/core/dists/gamma_distribution.cpp | 81 +++++++++++++++++++---------
 src/mlpack/core/dists/gamma_distribution.hpp | 15 ++++++
 src/mlpack/tests/distribution_test.cpp       | 22 ++++++++
 3 files changed, 92 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/core/dists/gamma_distribution.cpp b/src/mlpack/core/dists/gamma_distribution.cpp
index 8cfb040..4e90de0 100644
--- a/src/mlpack/core/dists/gamma_distribution.cpp
+++ b/src/mlpack/core/dists/gamma_distribution.cpp
@@ -43,9 +43,11 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
   using boost::math::trigamma;
   using std::log;
 
-  // Allocate space for alphas and betas (Assume independent rows).
-  alpha.set_size(rdata.n_rows);
-  beta.set_size(rdata.n_rows);
+  // Allocate temporary per-row results (rows are treated as independent).
+  // Temporaries are needed because the per-row call to Train(logMeanx,
+  // meanLogx, meanx, tol) below overwrites this object's alpha and beta.
+  arma::vec tempAlpha(rdata.n_rows);
+  arma::vec tempBeta(rdata.n_rows);
 
   // Calculate log(mean(x)) and mean(log(x)) of each dataset row.
   const arma::vec meanLogxVec = arma::mean(arma::log(rdata), 1);
@@ -55,36 +57,63 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
   // Treat each dimension (i.e. row) independently.
   for (size_t row = 0; row < rdata.n_rows; ++row)
   {
-    // Statistics for this row.
+    // Statistics for this row.
     const double meanLogx = meanLogxVec(row);
     const double meanx = meanxVec(row);
     const double logMeanx = logMeanxVec(row);
 
-    // Starting point for Generalized Newton.
-    double aEst = 0.5 / (logMeanx - meanLogx);
-    double aOld;
+    // Use the statistics-only function to fit this dimension.
+    Train(logMeanx, meanLogx, meanx, tol);
+
+    // That call resized alpha and beta to one element and stored this row's
+    // fit there. NOTE(review): if it throws, alpha/beta are left at size 1.
+    tempAlpha(row) = alpha(0);
+    tempBeta(row) = beta(0);
+  }
+  alpha = tempAlpha;
+  beta = tempBeta;
+
+}
+
+// Fits alpha and beta for one dimension from its precomputed statistics.
+void GammaDistribution::Train(const double logMeanx, const double meanLogx,
+                              const double meanx,
+                              const double tol)
+{
+  // Use boost's definitions of digamma and trigamma, and std::log.
+  using boost::math::digamma;
+  using boost::math::trigamma;
+  using std::log;
+
+  // This overload fits one dimension, so alpha and beta hold one element.
+  alpha.set_size(1);
+  beta.set_size(1);
 
-    // Newton's method: In each step, make an update to aEst. If value didn't
-    // change much (abs(aNew - aEst)/aEst < tol), then stop.
-    do
-    {
-      // Needed for convergence test.
-      aOld = aEst;
 
-      // Calculate new value for alpha.
-      double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
-      double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
-      assert (denominator != 0); // Protect against division by 0.
-      aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
+  // Newton starting point; NOTE(review): divides by 0 if logMeanx == meanLogx.
+  double aEst = 0.5 / (logMeanx - meanLogx);
+  double aOld;
 
-      // Protect against nan values (aEst will be passed to logarithm).
-      if (aEst <= 0)
-        throw std::logic_error("GammaDistribution::Train(): estimated invalid "
-            "negative value for parameter alpha!");
+  // Newton's method: In each step, make an update to aEst. If value didn't
+  // change much (abs(aNew - aEst)/aEst < tol), then stop.
+  do
+  {
+    // Needed for convergence test.
+    aOld = aEst;
 
-    } while (!Converged(aEst, aOld, tol));
+    // Calculate new value for alpha.
+    double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
+    double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
+    assert (denominator != 0); // Protect against division by 0.
+    aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
 
-    alpha(row) = aEst;
-    beta(row) = meanx / aEst;
-  }
+    // Protect against nan values (aEst will be passed to logarithm).
+    if (aEst <= 0)
+      throw std::logic_error("GammaDistribution::Train(): estimated invalid "
+          "negative value for parameter alpha!");
+
+  } while (!Converged(aEst, aOld, tol));
+
+  alpha(0) = aEst;
+  beta(0) = meanx / aEst;
 }
diff --git a/src/mlpack/core/dists/gamma_distribution.hpp b/src/mlpack/core/dists/gamma_distribution.hpp
index 698e7b4..e6c3a80 100644
--- a/src/mlpack/core/dists/gamma_distribution.hpp
+++ b/src/mlpack/core/dists/gamma_distribution.hpp
@@ -78,6 +78,21 @@ class GammaDistribution
      */
     void Train(const arma::mat& rdata, const double tol = 1e-8);
     
+    /**
+     * Fits alpha and beta to one dimension, given its precomputed statistics.
+     * Afterwards the distribution is 1-D: use Alpha(0) and Beta(0).
+     *
+     * @param logMeanx The dimension's logarithm of the mean, log(mean(x)).
+     * @param meanLogx The dimension's mean of the logarithms, mean(log(x)).
+     * @param meanx The dimension's mean, mean(x).
+     * @param tol Convergence tolerance on the *change* in the alpha estimate
+     *    between iterations; it is *not* an absolute error bound.
+     * @throws std::logic_error if an iteration produces alpha <= 0.
+     */
+    void Train(const double logMeanx, const double meanLogx, 
+               const double meanx,
+               const double tol = 1e-8);
+
     // Access to Gamma distribution parameters.
 
     //! Get the alpha parameter of the given dimension.
diff --git a/src/mlpack/tests/distribution_test.cpp b/src/mlpack/tests/distribution_test.cpp
index 68c5754..1283e38 100644
--- a/src/mlpack/tests/distribution_test.cpp
+++ b/src/mlpack/tests/distribution_test.cpp
@@ -517,4 +517,26 @@ BOOST_AUTO_TEST_CASE(GammaDistributionTrainConstructorTest)
   }
 }
 
+/**
+ * Test that Train() on a dataset and Train() on that dataset's precomputed
+ * statistics fit the same alpha and beta.
+ */
+BOOST_AUTO_TEST_CASE(GammaDistributionTrainStatisticsTest)
+{
+  const arma::mat data = arma::randu<arma::mat>(1, 500);
+
+  // Fit d1 directly from the (one-row) data.
+  GammaDistribution d1(data);
+
+  // Fit d2 from the row's statistics only, never seeing the data itself.
+  GammaDistribution d2;
+  const arma::vec meanLogx = arma::mean(arma::log(data), 1);
+  const arma::vec meanx = arma::mean(data, 1);
+  const arma::vec logMeanx = arma::log(meanx);
+  d2.Train(logMeanx(0), meanLogx(0), meanx(0));
+
+  BOOST_REQUIRE_CLOSE(d1.Alpha(0), d2.Alpha(0), 1e-5);
+  BOOST_REQUIRE_CLOSE(d1.Beta(0), d2.Beta(0), 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list