[mlpack-git] master: Changes Train overload to accept multi-dimensional statistics (253c493)

gitdub at mlpack.org gitdub at mlpack.org
Wed Aug 3 06:26:02 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/b12e1517565652f62a551e1df41f2e51569bc9d9...e3ce1600893636f15be55d66c621398673b1faba

>---------------------------------------------------------------

commit 253c493c4043feb9565c644c46f25ae6d0b24fb7
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Wed Aug 3 11:26:02 2016 +0100

    Changes Train overload to accept multi-dimensional statistics


>---------------------------------------------------------------

253c493c4043feb9565c644c46f25ae6d0b24fb7
 src/mlpack/core/dists/gamma_distribution.cpp | 100 +++++++++++++--------------
 src/mlpack/core/dists/gamma_distribution.hpp |   5 +-
 src/mlpack/tests/distribution_test.cpp       |   2 +-
 3 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/src/mlpack/core/dists/gamma_distribution.cpp b/src/mlpack/core/dists/gamma_distribution.cpp
index 4e90de0..7b3e000 100644
--- a/src/mlpack/core/dists/gamma_distribution.cpp
+++ b/src/mlpack/core/dists/gamma_distribution.cpp
@@ -43,41 +43,20 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
   using boost::math::trigamma;
   using std::log;
 
-  // Allocate temporary space for alphas and betas (Assume independent rows).
-  // We need temporary vectors since we are calling Train(logMeanx, meanLogx,
-  // meanx) which modifies the object's own alphas and betas.
-  arma::vec tempAlpha(rdata.n_rows);
-  arma::vec tempBeta(rdata.n_rows);
-
   // Calculate log(mean(x)) and mean(log(x)) of each dataset row.
   const arma::vec meanLogxVec = arma::mean(arma::log(rdata), 1);
   const arma::vec meanxVec = arma::mean(rdata, 1);
   const arma::vec logMeanxVec = arma::log(meanxVec);
 
-  // Treat each dimension (i.e. row) independently.
-  for (size_t row = 0; row < rdata.n_rows; ++row)
-  {
-    // Statistics for this Vrow.
-    const double meanLogx = meanLogxVec(row);
-    const double meanx = meanxVec(row);
-    const double logMeanx = logMeanxVec(row);
-
-    // Use the statistics-only function to fit this dimension.
-    Train(logMeanx, meanLogx, meanx, tol);
-
-    // The function above modifies the state of this object. Get the parameters
-    // it fit. (This is not very good design...).
-    tempAlpha(row) = alpha(0);
-    tempBeta(row) = beta(0);
-  }
-  alpha = tempAlpha;
-  beta = tempBeta;
-
+  // Delegate to the statistics-only GammaDistribution::Train() overload,
+  // which fits alpha and beta independently for every dimension (row).
+  Train(logMeanxVec, meanLogxVec, meanxVec, tol);
 }
 
 // Fits an alpha and beta parameter to each dimension of the data.
-void GammaDistribution::Train(const double logMeanx, const double meanLogx,
-                              const double meanx,
+void GammaDistribution::Train(const arma::vec& logMeanxVec, 
+                              const arma::vec& meanLogxVec,
+                              const arma::vec& meanxVec,
                               const double tol)
 {
   // Use boost's definitions of digamma and tgamma, and std::log.
@@ -85,35 +64,56 @@ void GammaDistribution::Train(const double logMeanx, const double meanLogx,
   using boost::math::trigamma;
   using std::log;
 
-  // Allocate space for alphas and betas (Assume independent rows).
-  alpha.set_size(1);
-  beta.set_size(1);
+  // Number of dimensions of gamma distribution.
+  size_t ndim = logMeanxVec.n_rows;
   
+  // Sanity check: all statistic vectors must have the same size.
+  if (logMeanxVec.n_rows != meanLogxVec.n_rows ||
+      logMeanxVec.n_rows != meanxVec.n_rows)
+    throw std::runtime_error("Statistic vectors must be of the same size.");
 
-  // Starting point for Generalized Newton.
-  double aEst = 0.5 / (logMeanx - meanLogx);
-  double aOld;
+  // Allocate space for alphas and betas (Assume independent rows).
+  alpha.set_size(ndim);
+  beta.set_size(ndim);
 
-  // Newton's method: In each step, make an update to aEst. If value didn't
-  // change much (abs(aNew - aEst)/aEst < tol), then stop.
-  do
+  // Treat each dimension (i.e. row) independently.
+  for (size_t row = 0; row < ndim; ++row)
   {
-    // Needed for convergence test.
-    aOld = aEst;
+    // Statistics for this row.
+    const double meanLogx = meanLogxVec(row);
+    const double meanx = meanxVec(row);
+    const double logMeanx = logMeanxVec(row);
 
-    // Calculate new value for alpha.
-    double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
-    double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
-    assert (denominator != 0); // Protect against division by 0.
-    aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
+    // Starting point for Generalized Newton.
+    double aEst = 0.5 / (logMeanx - meanLogx);
+    double aOld;
 
-    // Protect against nan values (aEst will be passed to logarithm).
-    if (aEst <= 0)
-      throw std::logic_error("GammaDistribution::Train(): estimated invalid "
-          "negative value for parameter alpha!");
+    // Newton's method: in each step, update aEst. Stop once the change between
+    // the previous estimate (aOld) and the new one (aEst) is below tol.
+    do
+    {
+      // Needed for convergence test.
+      aOld = aEst;
 
-  } while (!Converged(aEst, aOld, tol));
+      // Calculate new value for alpha.
+      double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
+      double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
 
-  alpha(0) = aEst;
-  beta(0) = meanx / aEst;
+      // Protect against division by 0.
+      if (denominator == 0)
+        throw std::logic_error("GammaDistribution::Train() attempted division" 
+            " by 0.");
+      
+      aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
+
+      // Protect against nan values (aEst will be passed to logarithm).
+      if (aEst <= 0)
+        throw std::logic_error("GammaDistribution::Train(): estimated invalid "
+            "negative value for parameter alpha!");
+
+    } while (!Converged(aEst, aOld, tol));
+    
+    alpha(row) = aEst;
+    beta(row) = meanx / aEst;
+  }
 }
diff --git a/src/mlpack/core/dists/gamma_distribution.hpp b/src/mlpack/core/dists/gamma_distribution.hpp
index e6c3a80..b358298 100644
--- a/src/mlpack/core/dists/gamma_distribution.hpp
+++ b/src/mlpack/core/dists/gamma_distribution.hpp
@@ -89,8 +89,9 @@ class GammaDistribution
      *    It will stop the approximation once the *change* in the value is 
      *    smaller than tol.
      */
-    void Train(const double logMeanx, const double meanLogx, 
-               const double meanx,
+    void Train(const arma::vec& logMeanx, 
+               const arma::vec& meanLogx, 
+               const arma::vec& meanx,
                const double tol = 1e-8);
 
     // Access to Gamma distribution parameters.
diff --git a/src/mlpack/tests/distribution_test.cpp b/src/mlpack/tests/distribution_test.cpp
index 1283e38..44c1156 100644
--- a/src/mlpack/tests/distribution_test.cpp
+++ b/src/mlpack/tests/distribution_test.cpp
@@ -533,7 +533,7 @@ BOOST_AUTO_TEST_CASE(GammaDistributionTrainStatisticsTest)
   const arma::vec meanLogx = arma::mean(arma::log(data), 1);
   const arma::vec meanx = arma::mean(data, 1);
   const arma::vec logMeanx = arma::log(meanx);
-  d2.Train(logMeanx(0), meanLogx(0), meanx(0));
+  d2.Train(logMeanx, meanLogx, meanx);
 
   BOOST_REQUIRE_CLOSE(d1.Alpha(0), d2.Alpha(0), 1e-5);
   BOOST_REQUIRE_CLOSE(d1.Beta(0), d2.Beta(0), 1e-5);




More information about the mlpack-git mailing list