[mlpack-git] master: Handle the case where a node has only two points correctly. (5bd1e7c)

gitdub at mlpack.org gitdub at mlpack.org
Sun Aug 7 16:54:44 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/b8de1faab0e58ff734dba4826c3a9607bcb6c84e...5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c

>---------------------------------------------------------------

commit 5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Aug 7 16:50:58 2016 -0400

    Handle the case where a node has only two points correctly.


>---------------------------------------------------------------

5bd1e7cf7e0cdc545fe06aef98ce1a1f051b976c
 src/mlpack/core/tree/cosine_tree/cosine_tree.cpp | 36 ++++++++++++++++--------
 src/mlpack/tests/quic_svd_test.cpp               | 11 +++++++-
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp b/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
index 6080d96..01acbb7 100644
--- a/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
+++ b/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
@@ -95,7 +95,20 @@ CosineTree::CosineTree(const arma::mat& dataset,
     currentNode = treeQueue.top();
     treeQueue.pop();
 
-    // Split the node into left and right children.
+    // If the priority is 0, we can't improve anything, and we can assume that
+    // we've done the best we can.
+    if (currentNode->L2Error() == 0.0)
+    {
+      Log::Warn << "CosineTree::CosineTree(): could not build tree to "
+          << "desired relative error " << epsilon << "; failing with estimated "
+          << "relative error " << (monteCarloError / root.FrobNormSquared())
+          << "." << std::endl;
+      break;
+    }
+
+    // Split the node into left and right children.  We assume that this cannot
+    // fail; it might fail if L2Error() is 0, but we have already avoided that
+    // case.
     currentNode->CosineNodeSplit();
 
     // Obtain pointers to the left and right children of the current node.
@@ -277,14 +290,16 @@ void CosineTree::ConstructBasis(CosineNodeQueue& treeQueue)
 
 void CosineTree::CosineNodeSplit()
 {
-  //! If less than two nodes, splitting does not make sense.
-  if (numColumns < 3) return;
+  // If fewer than two points, splitting does not make sense---there is nothing
+  // to split.
+  if (numColumns < 2)
+    return;
 
-  //! Calculate cosines with respect to the splitting point.
+  // Calculate cosines with respect to the splitting point.
   arma::vec cosines;
   CalculateCosines(cosines);
 
-  //! Compute maximum and minimum cosine values.
+  // Compute maximum and minimum cosine values.
   double cosineMax, cosineMin;
   cosineMax = arma::max(cosines % (cosines < 1));
   cosineMin = arma::min(cosines);
@@ -293,17 +308,16 @@ void CosineTree::CosineNodeSplit()
 
   // Split columns into left and right children. The splitting condition for the
   // column to be in the left child is as follows:
-  //       cos_max - cos(i) <= cos(i) - cos_min
+  //       cos_max - cos(i) < cos(i) - cos_min
+  // We deviate from the paper here and use < instead of <= in order to handle
+  // the edge case where cosineMax == cosineMin, and force there to be at least
+  // one point in the right node.
   for (size_t i = 0; i < numColumns; i++)
   {
-    if (cosineMax - cosines(i) <= cosines(i) - cosineMin)
-    {
+    if (cosineMax - cosines(i) < cosines(i) - cosineMin)
       leftIndices.push_back(i);
-    }
     else
-    {
       rightIndices.push_back(i);
-    }
   }
 
   // Split the node into left and right children.
diff --git a/src/mlpack/tests/quic_svd_test.cpp b/src/mlpack/tests/quic_svd_test.cpp
index b444739..7e09feb 100644
--- a/src/mlpack/tests/quic_svd_test.cpp
+++ b/src/mlpack/tests/quic_svd_test.cpp
@@ -41,7 +41,7 @@ BOOST_AUTO_TEST_CASE(QUICSVDReconstructionError)
 /**
  * The singular value error of the obtained SVD should be small.
  */
-BOOST_AUTO_TEST_CASE(QUICSVDSigularValueError)
+BOOST_AUTO_TEST_CASE(QUICSVDSingularValueError)
 {
   arma::mat U = arma::randn<arma::mat>(3, 20);
   arma::mat V = arma::randn<arma::mat>(10, 3);
@@ -69,4 +69,13 @@ BOOST_AUTO_TEST_CASE(QUICSVDSigularValueError)
   BOOST_REQUIRE_SMALL(error, 0.05);
 }
 
+BOOST_AUTO_TEST_CASE(QUICSVDSameDimensionTest)
+{
+  arma::mat dataset = arma::randn<arma::mat>(10, 10);
+
+  // Obtain the SVD using default parameters.
+  arma::mat u, v, sigma;
+  svd::QUIC_SVD quicsvd(dataset, u, v, sigma);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list