[mlpack-git] master: Refactor tree test. (8998cd7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:07 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 8998cd7f1fcb055e904961dccaf68e335d70dd7b
Author: ryan <ryan at ratml.org>
Date:   Wed Apr 22 16:53:35 2015 -0400

    Refactor tree test.


>---------------------------------------------------------------

8998cd7f1fcb055e904961dccaf68e335d70dd7b
 src/mlpack/tests/tree_test.cpp | 84 +++++++++++++++++++++---------------------
 1 file changed, 41 insertions(+), 43 deletions(-)

diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index e842d93..69265c7 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1060,17 +1060,21 @@ BOOST_AUTO_TEST_CASE(CheckDataset)
   // Leaf size of 1.
   BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
 
-  BOOST_REQUIRE_EQUAL(&rootNode.Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(), &dataset);
-  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(), &dataset);
+  arma::mat* rootDataset = &rootNode.Dataset();
+  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(),
+      rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(),
+      rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(),
+      rootDataset);
+  BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(),
+      rootDataset);
 }
 
 // Ensure FurthestDescendantDistance() works.
@@ -1220,8 +1224,8 @@ BOOST_AUTO_TEST_CASE(ParentDistanceTestWithMapping)
 }
 
 // Forward declaration of methods we need for the next test.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType& node, const MatType& data);
+template<typename TreeType>
+bool CheckPointBounds(TreeType& node);
 
 template<typename TreeType>
 void GenerateVectorOfTree(TreeType* node,
@@ -1230,9 +1234,7 @@ void GenerateVectorOfTree(TreeType* node,
 
 template<int t_pow>
 bool DoBoundsIntersect(HRectBound<t_pow>& a,
-                       HRectBound<t_pow>& b,
-                       size_t ia,
-                       size_t ib);
+                       HRectBound<t_pow>& b);
 
 /**
  * Exhaustive kd-tree test based on #125.
@@ -1262,7 +1264,6 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
 
     size_t size = maxPoints;
     arma::mat dataset = arma::mat(dimensions, size);
-    arma::mat datacopy; // Used to test mappings.
 
     // Mappings for post-sort verification of data.
     std::vector<size_t> newToOld;
@@ -1270,26 +1271,26 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
 
     // Generate data.
     dataset.randu();
-    datacopy = dataset; // Save a copy.
 
     // Build the tree itself.
     TreeType root(dataset, newToOld, oldToNew);
+    const arma::mat& treeset = root.Dataset();
 
     // Ensure the size of the tree is correct.
     BOOST_REQUIRE_EQUAL(root.Count(), size);
 
     // Check the forward and backward mappings for correctness.
-    for(size_t i = 0; i < size; i++)
+    for (size_t i = 0; i < size; i++)
     {
-      for(size_t j = 0; j < dimensions; j++)
+      for (size_t j = 0; j < dimensions; j++)
       {
-        BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
-        BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+        BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+        BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
       }
     }
 
     // Now check that each point is contained inside of all bounds above it.
-    CheckPointBounds(root, dataset);
+    CheckPointBounds(root);
 
     // Now check that no peers overlap.
     std::vector<TreeType*> v;
@@ -1303,14 +1304,13 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
       for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
         for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
           if (v[i] != NULL && v[j] != NULL)
-            BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
-                  i, j));
+            BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound()));
 
       depth *= 2;
     }
   }
 
-  arma::mat dataset = arma::mat(25, 1000);
+  arma::mat dataset(25, 1000);
   for (size_t col = 0; col < dataset.n_cols; ++col)
     for (size_t row = 0; row < dataset.n_rows; ++row)
       dataset(row, col) = row + col;
@@ -1323,17 +1323,17 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
 }
 
 // Recursively checks that each node contains all points that it claims to have.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType& node, const MatType& data)
+template<typename TreeType>
+bool CheckPointBounds(TreeType& node)
 {
   // Check that each point which this tree claims is actually inside the tree.
   for (size_t index = 0; index < node.NumDescendants(); index++)
-    if (!node.Bound().Contains(data.col(node.Descendant(index))))
+    if (!node.Bound().Contains(node.Dataset().col(node.Descendant(index))))
       return false;
 
   bool result = true;
   for (size_t child = 0; child < node.NumChildren(); ++child)
-    result &= CheckPointBounds(node.Child(child), data);
+    result &= CheckPointBounds(node.Child(child));
   return result;
 }
 
@@ -1371,10 +1371,10 @@ BOOST_AUTO_TEST_CASE(BallTreeTest)
 
     // Generate data.
     dataset.randu();
-    datacopy = dataset; // Save a copy.
 
     // Build the tree itself.
     TreeType root(dataset, newToOld, oldToNew);
+    const arma::mat& treeset = root.Dataset();
 
     // Ensure the size of the tree is correct.
     BOOST_REQUIRE_EQUAL(root.NumDescendants(), size);
@@ -1384,21 +1384,19 @@ BOOST_AUTO_TEST_CASE(BallTreeTest)
     {
       for(size_t j = 0; j < dimensions; j++)
       {
-        BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
-        BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+        BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+        BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
       }
     }
 
     // Now check that each point is contained inside of all bounds above it.
-    CheckPointBounds(root, dataset);
+    CheckPointBounds(root);
   }
 }
 
 template<int t_pow>
 bool DoBoundsIntersect(HRectBound<t_pow>& a,
-                       HRectBound<t_pow>& b,
-                       size_t /* ia */,
-                       size_t /* ib */)
+                       HRectBound<t_pow>& b)
 {
   size_t dimensionality = a.Dim();
 
@@ -1464,7 +1462,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
   size_t pointIncrements = 200; // Range is from 200 points to 400.
 
   // We use the default leaf size of 20.
-  for(size_t run = 0; run < maxRuns; run++)
+  for (size_t run = 0; run < maxRuns; run++)
   {
     size_t dimensions = run + 2;
     size_t maxPoints = (run + 1) * pointIncrements;
@@ -1483,6 +1481,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
 
     // Build the tree itself.
     TreeType root(dataset, newToOld, oldToNew);
+    const arma::sp_mat& treeset = root.Dataset();
 
     // Ensure the size of the tree is correct.
     BOOST_REQUIRE_EQUAL(root.Count(), size);
@@ -1492,13 +1491,13 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
     {
       for(size_t j = 0; j < dimensions; j++)
       {
-        BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
-        BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+        BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+        BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
       }
     }
 
     // Now check that each point is contained inside of all bounds above it.
-    CheckPointBounds(root, dataset);
+    CheckPointBounds(root);
 
     // Now check that no peers overlap.
     std::vector<TreeType*> v;
@@ -1512,8 +1511,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
       for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
         for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
           if (v[i] != NULL && v[j] != NULL)
-            BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
-                  i, j));
+            BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound()));
 
       depth *= 2;
     }



More information about the mlpack-git mailing list