[mlpack-git] master: Adds a Train() function that only needs dataset statistics, not the dataset itself (09a90ee)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 08:29:02 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/b12e1517565652f62a551e1df41f2e51569bc9d9...e3ce1600893636f15be55d66c621398673b1faba
>---------------------------------------------------------------
commit 09a90ee81a8ae3b69be94d8f84ff32794af2b7d6
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Tue Jul 26 13:29:02 2016 +0100
Adds a Train() function that only needs dataset statistics, not the dataset itself
>---------------------------------------------------------------
09a90ee81a8ae3b69be94d8f84ff32794af2b7d6
src/mlpack/core/dists/gamma_distribution.cpp | 81 +++++++++++++++++++---------
src/mlpack/core/dists/gamma_distribution.hpp | 15 ++++++
src/mlpack/tests/distribution_test.cpp | 22 ++++++++
3 files changed, 92 insertions(+), 26 deletions(-)
diff --git a/src/mlpack/core/dists/gamma_distribution.cpp b/src/mlpack/core/dists/gamma_distribution.cpp
index 8cfb040..4e90de0 100644
--- a/src/mlpack/core/dists/gamma_distribution.cpp
+++ b/src/mlpack/core/dists/gamma_distribution.cpp
@@ -43,9 +43,11 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
using boost::math::trigamma;
using std::log;
- // Allocate space for alphas and betas (Assume independent rows).
- alpha.set_size(rdata.n_rows);
- beta.set_size(rdata.n_rows);
+ // Allocate temporary space for alphas and betas (Assume independent rows).
+ // We need temporary vectors since we are calling Train(logMeanx, meanLogx,
+ // meanx) which modifies the object's own alphas and betas.
+ arma::vec tempAlpha(rdata.n_rows);
+ arma::vec tempBeta(rdata.n_rows);
// Calculate log(mean(x)) and mean(log(x)) of each dataset row.
const arma::vec meanLogxVec = arma::mean(arma::log(rdata), 1);
@@ -55,36 +57,63 @@ void GammaDistribution::Train(const arma::mat& rdata, const double tol)
// Treat each dimension (i.e. row) independently.
for (size_t row = 0; row < rdata.n_rows; ++row)
{
- // Statistics for this row.
+ // Statistics for this row.
const double meanLogx = meanLogxVec(row);
const double meanx = meanxVec(row);
const double logMeanx = logMeanxVec(row);
- // Starting point for Generalized Newton.
- double aEst = 0.5 / (logMeanx - meanLogx);
- double aOld;
+ // Use the statistics-only function to fit this dimension.
+ Train(logMeanx, meanLogx, meanx, tol);
+
+ // The function above modifies the state of this object. Get the parameters
+ // it fit. (This is not very good design...).
+ tempAlpha(row) = alpha(0);
+ tempBeta(row) = beta(0);
+ }
+ alpha = tempAlpha;
+ beta = tempBeta;
+
+}
+
+// Fits an alpha and beta parameter to each dimension of the data.
+void GammaDistribution::Train(const double logMeanx, const double meanLogx,
+ const double meanx,
+ const double tol)
+{
+ // Use boost's definitions of digamma and tgamma, and std::log.
+ using boost::math::digamma;
+ using boost::math::trigamma;
+ using std::log;
+
+ // Allocate space for alphas and betas (Assume independent rows).
+ alpha.set_size(1);
+ beta.set_size(1);
- // Newton's method: In each step, make an update to aEst. If value didn't
- // change much (abs(aNew - aEst)/aEst < tol), then stop.
- do
- {
- // Needed for convergence test.
- aOld = aEst;
- // Calculate new value for alpha.
- double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
- double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
- assert (denominator != 0); // Protect against division by 0.
- aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
+ // Starting point for Generalized Newton.
+ double aEst = 0.5 / (logMeanx - meanLogx);
+ double aOld;
- // Protect against nan values (aEst will be passed to logarithm).
- if (aEst <= 0)
- throw std::logic_error("GammaDistribution::Train(): estimated invalid "
- "negative value for parameter alpha!");
+ // Newton's method: In each step, make an update to aEst. If value didn't
+ // change much (abs(aNew - aEst)/aEst < tol), then stop.
+ do
+ {
+ // Needed for convergence test.
+ aOld = aEst;
- } while (!Converged(aEst, aOld, tol));
+ // Calculate new value for alpha.
+ double nominator = meanLogx - logMeanx + log(aEst) - digamma(aEst);
+ double denominator = pow(aEst, 2) * (1 / aEst - trigamma(aEst));
+ assert (denominator != 0); // Protect against division by 0.
+ aEst = 1.0 / ((1.0 / aEst) + nominator / denominator);
- alpha(row) = aEst;
- beta(row) = meanx / aEst;
- }
+ // Protect against nan values (aEst will be passed to logarithm).
+ if (aEst <= 0)
+ throw std::logic_error("GammaDistribution::Train(): estimated invalid "
+ "negative value for parameter alpha!");
+
+ } while (!Converged(aEst, aOld, tol));
+
+ alpha(0) = aEst;
+ beta(0) = meanx / aEst;
}
diff --git a/src/mlpack/core/dists/gamma_distribution.hpp b/src/mlpack/core/dists/gamma_distribution.hpp
index 698e7b4..e6c3a80 100644
--- a/src/mlpack/core/dists/gamma_distribution.hpp
+++ b/src/mlpack/core/dists/gamma_distribution.hpp
@@ -78,6 +78,21 @@ class GammaDistribution
*/
void Train(const arma::mat& rdata, const double tol = 1e-8);
+ /**
+ * This function trains (fits distribution parameters) to a 1-dimensional
+ * dataset with pre-computed statistics logMeanx, meanLogx, meanx.
+ *
+ * @param logMeanx Is the dimension's logarithm of the mean (log(mean(x))).
+ * @param meanLogx Is the dimension's mean of logarithms (mean(log(x))).
+ * @param meanx Is the dimension's mean (mean(x)).
+ * @param tol Convergence tolerance. This is *not* an absolute measure:
+ * It will stop the approximation once the *change* in the value is
+ * smaller than tol.
+ */
+ void Train(const double logMeanx, const double meanLogx,
+ const double meanx,
+ const double tol = 1e-8);
+
// Access to Gamma distribution parameters.
//! Get the alpha parameter of the given dimension.
diff --git a/src/mlpack/tests/distribution_test.cpp b/src/mlpack/tests/distribution_test.cpp
index 68c5754..1283e38 100644
--- a/src/mlpack/tests/distribution_test.cpp
+++ b/src/mlpack/tests/distribution_test.cpp
@@ -517,4 +517,26 @@ BOOST_AUTO_TEST_CASE(GammaDistributionTrainConstructorTest)
}
}
+/**
+ * Test that Train() with a dataset and Train() with dataset statistics return
+ * the same results.
+ */
+BOOST_AUTO_TEST_CASE(GammaDistributionTrainStatisticsTest)
+{
+ const arma::mat data = arma::randu<arma::mat>(1, 500);
+
+ // Train object d1 with the data.
+ GammaDistribution d1(data);
+
+ // Train object d2 with the data's statistics.
+ GammaDistribution d2;
+ const arma::vec meanLogx = arma::mean(arma::log(data), 1);
+ const arma::vec meanx = arma::mean(data, 1);
+ const arma::vec logMeanx = arma::log(meanx);
+ d2.Train(logMeanx(0), meanLogx(0), meanx(0));
+
+ BOOST_REQUIRE_CLOSE(d1.Alpha(0), d2.Alpha(0), 1e-5);
+ BOOST_REQUIRE_CLOSE(d1.Beta(0), d2.Beta(0), 1e-5);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list