[mlpack-git] master: Handle the case where a node has only two points correctly. (5bd1e7c)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Aug 7 16:54:44 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/b8de1faab0e58ff734dba4826c3a9607bcb6c84e...5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c
>---------------------------------------------------------------
commit 5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Aug 7 16:50:58 2016 -0400
Handle the case where a node has only two points correctly.
>---------------------------------------------------------------
5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c
src/mlpack/core/tree/cosine_tree/cosine_tree.cpp | 36 ++++++++++++++++--------
src/mlpack/tests/quic_svd_test.cpp | 11 +++++++-
2 files changed, 35 insertions(+), 12 deletions(-)
diff --git a/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp b/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
index 6080d96..01acbb7 100644
--- a/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
+++ b/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
@@ -95,7 +95,20 @@ CosineTree::CosineTree(const arma::mat& dataset,
currentNode = treeQueue.top();
treeQueue.pop();
- // Split the node into left and right children.
+ // If the priority is 0, we can't improve anything, and we can assume that
+ // we've done the best we can.
+ if (currentNode->L2Error() == 0.0)
+ {
+ Log::Warn << "CosineTree::CosineTree(): could not build tree to "
+ << "desired relative error " << epsilon << "; failing with estimated "
+ << "relative error " << (monteCarloError / root.FrobNormSquared())
+ << "." << std::endl;
+ break;
+ }
+
+ // Split the node into left and right children. We assume that this cannot
+ // fail; it might fail if L2Error() is 0, but we have already avoided that
+ // case.
currentNode->CosineNodeSplit();
// Obtain pointers to the left and right children of the current node.
@@ -277,14 +290,16 @@ void CosineTree::ConstructBasis(CosineNodeQueue& treeQueue)
void CosineTree::CosineNodeSplit()
{
- //! If less than two nodes, splitting does not make sense.
- if (numColumns < 3) return;
+ // If fewer than two points, splitting does not make sense---there is nothing
+ // to split.
+ if (numColumns < 2)
+ return;
- //! Calculate cosines with respect to the splitting point.
+ // Calculate cosines with respect to the splitting point.
arma::vec cosines;
CalculateCosines(cosines);
- //! Compute maximum and minimum cosine values.
+ // Compute maximum and minimum cosine values.
double cosineMax, cosineMin;
cosineMax = arma::max(cosines % (cosines < 1));
cosineMin = arma::min(cosines);
@@ -293,17 +308,16 @@ void CosineTree::CosineNodeSplit()
// Split columns into left and right children. The splitting condition for the
// column to be in the left child is as follows:
- // cos_max - cos(i) <= cos(i) - cos_min
+ // cos_max - cos(i) < cos(i) - cos_min
+ // We deviate from the paper here and use < instead of <= in order to handle
+ // the edge case where cosineMax == cosineMin, and force there to be at least
+ // one point in the right node.
for (size_t i = 0; i < numColumns; i++)
{
- if (cosineMax - cosines(i) <= cosines(i) - cosineMin)
- {
+ if (cosineMax - cosines(i) < cosines(i) - cosineMin)
leftIndices.push_back(i);
- }
else
- {
rightIndices.push_back(i);
- }
}
// Split the node into left and right children.
diff --git a/src/mlpack/tests/quic_svd_test.cpp b/src/mlpack/tests/quic_svd_test.cpp
index b444739..7e09feb 100644
--- a/src/mlpack/tests/quic_svd_test.cpp
+++ b/src/mlpack/tests/quic_svd_test.cpp
@@ -41,7 +41,7 @@ BOOST_AUTO_TEST_CASE(QUICSVDReconstructionError)
/**
* The singular value error of the obtained SVD should be small.
*/
-BOOST_AUTO_TEST_CASE(QUICSVDSigularValueError)
+BOOST_AUTO_TEST_CASE(QUICSVDSingularValueError)
{
arma::mat U = arma::randn<arma::mat>(3, 20);
arma::mat V = arma::randn<arma::mat>(10, 3);
@@ -69,4 +69,13 @@ BOOST_AUTO_TEST_CASE(QUICSVDSigularValueError)
BOOST_REQUIRE_SMALL(error, 0.05);
}
+BOOST_AUTO_TEST_CASE(QUICSVDSameDimensionTest)
+{
+ arma::mat dataset = arma::randn<arma::mat>(10, 10);
+
+ // Obtain the SVD using default parameters.
+ arma::mat u, v, sigma;
+ svd::QUIC_SVD quicsvd(dataset, u, v, sigma);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list