[mlpack-git] master: A smarter strategy for checking positive-definiteness. (e08a8ff)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Dec 18 11:43:12 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5ba11bc90223b55eecd5da4cfbe86c8fc40637a5...df229e45a5bd7842fe019e9d49ed32f13beb6aaa
>---------------------------------------------------------------
commit e08a8ff534e2c8c153007c51bc55a5e0ac69c9a6
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Dec 18 15:47:43 2015 +0000
A smarter strategy for checking positive-definiteness.
>---------------------------------------------------------------
e08a8ff534e2c8c153007c51bc55a5e0ac69c9a6
src/mlpack/core/dists/gaussian_distribution.hpp | 5 +++
.../methods/gmm/positive_definite_constraint.hpp | 37 +++++++++++++++++++---
2 files changed, 38 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/core/dists/gaussian_distribution.hpp b/src/mlpack/core/dists/gaussian_distribution.hpp
index 2b50219..9ba9a5a 100644
--- a/src/mlpack/core/dists/gaussian_distribution.hpp
+++ b/src/mlpack/core/dists/gaussian_distribution.hpp
@@ -154,6 +154,11 @@ class GaussianDistribution
}
private:
+ /**
+ * Factor the covariance using arma::chol(). This assumes the matrix admits a
+ * Cholesky decomposition (i.e. it is positive definite); if it does not, a
+ * std::runtime_error will be thrown.
+ */
void FactorCovariance();
};
diff --git a/src/mlpack/methods/gmm/positive_definite_constraint.hpp b/src/mlpack/methods/gmm/positive_definite_constraint.hpp
index 45882bb..96497ff 100644
--- a/src/mlpack/methods/gmm/positive_definite_constraint.hpp
+++ b/src/mlpack/methods/gmm/positive_definite_constraint.hpp
@@ -25,19 +25,48 @@ class PositiveDefiniteConstraint
*/
static void ApplyConstraint(arma::mat& covariance)
{
- // TODO: make this more efficient.
- if (arma::det(covariance) <= 1e-50)
+ // Realistically, all we care about is that we can perform a Cholesky
+ // decomposition of the matrix, so that FactorCovariance() doesn't fail
+ // later. Therefore, that's what we'll do to check for positive
+ // definiteness...
+ //
+ // Note that other techniques like checking the determinant *could* work,
+ // but floating-point errors mean that various decompositions may start to
+ // fail when the matrix gets close to being indefinite. This is why we test
+ // with chol() and not something else, since that's what will be used later.
+ //
+ // We also need to make sure that the errors go to nowhere, so we have to
+ // call set_stream_err2(). 'guard' is declared after 'oss', so it restores the
+ // original stream on every exit path (even exceptions) before 'oss' dies.
+ struct ErrStreamGuard { std::ostream& original;
+ ~ErrStreamGuard() { arma::set_stream_err2(original); } };
+ std::ostringstream oss;
+ ErrStreamGuard guard = { arma::get_stream_err2() };
+ arma::set_stream_err2(oss); // Thus, errors won't be displayed.
+
+ arma::mat covLower;
+ #if (ARMA_VERSION_MAJOR < 4) || \
+ ((ARMA_VERSION_MAJOR == 4) && (ARMA_VERSION_MINOR < 500))
+ if (!arma::chol(covLower, covariance))
+ #else
+ if (!arma::chol(covLower, covariance, "lower"))
+ #endif
{
Log::Debug << "Covariance matrix is not positive definite. Adding "
<< "perturbation." << std::endl;
- double perturbation = 1e-30;
- while (arma::det(covariance) <= 1e-50)
+ double perturbation = 1e-15;
+ #if (ARMA_VERSION_MAJOR < 4) || \
+ ((ARMA_VERSION_MAJOR == 4) && (ARMA_VERSION_MINOR < 500))
+ while (!arma::chol(covLower, covariance))
+ #else
+ while (!arma::chol(covLower, covariance, "lower"))
+ #endif
{
covariance.diag() += perturbation;
perturbation *= 10;
}
}
}
//! Serialize the constraint (which stores nothing, so, nothing to do).
More information about the mlpack-git
mailing list