[mlpack-git] master: Refactor tree test. (8998cd7)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:07 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 8998cd7f1fcb055e904961dccaf68e335d70dd7b
Author: ryan <ryan at ratml.org>
Date: Wed Apr 22 16:53:35 2015 -0400
Refactor tree test.
>---------------------------------------------------------------
8998cd7f1fcb055e904961dccaf68e335d70dd7b
src/mlpack/tests/tree_test.cpp | 84 +++++++++++++++++++++---------------------
1 file changed, 41 insertions(+), 43 deletions(-)
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index e842d93..69265c7 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1060,17 +1060,21 @@ BOOST_AUTO_TEST_CASE(CheckDataset)
// Leaf size of 1.
BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
- BOOST_REQUIRE_EQUAL(&rootNode.Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(), &dataset);
+ arma::mat* rootDataset = &rootNode.Dataset();
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(),
+ rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(),
+ rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(),
+ rootDataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(),
+ rootDataset);
}
// Ensure FurthestDescendantDistance() works.
@@ -1220,8 +1224,8 @@ BOOST_AUTO_TEST_CASE(ParentDistanceTestWithMapping)
}
// Forward declaration of methods we need for the next test.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType& node, const MatType& data);
+template<typename TreeType>
+bool CheckPointBounds(TreeType& node);
template<typename TreeType>
void GenerateVectorOfTree(TreeType* node,
@@ -1230,9 +1234,7 @@ void GenerateVectorOfTree(TreeType* node,
template<int t_pow>
bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t ia,
- size_t ib);
+ HRectBound<t_pow>& b);
/**
* Exhaustive kd-tree test based on #125.
@@ -1262,7 +1264,6 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
size_t size = maxPoints;
arma::mat dataset = arma::mat(dimensions, size);
- arma::mat datacopy; // Used to test mappings.
// Mappings for post-sort verification of data.
std::vector<size_t> newToOld;
@@ -1270,26 +1271,26 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
// Generate data.
dataset.randu();
- datacopy = dataset; // Save a copy.
// Build the tree itself.
TreeType root(dataset, newToOld, oldToNew);
+ const arma::mat& treeset = root.Dataset();
// Ensure the size of the tree is correct.
BOOST_REQUIRE_EQUAL(root.Count(), size);
// Check the forward and backward mappings for correctness.
- for(size_t i = 0; i < size; i++)
+ for (size_t i = 0; i < size; i++)
{
- for(size_t j = 0; j < dimensions; j++)
+ for (size_t j = 0; j < dimensions; j++)
{
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
}
}
// Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(root, dataset);
+ CheckPointBounds(root);
// Now check that no peers overlap.
std::vector<TreeType*> v;
@@ -1303,14 +1304,13 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound()));
depth *= 2;
}
}
- arma::mat dataset = arma::mat(25, 1000);
+ arma::mat dataset(25, 1000);
for (size_t col = 0; col < dataset.n_cols; ++col)
for (size_t row = 0; row < dataset.n_rows; ++row)
dataset(row, col) = row + col;
@@ -1323,17 +1323,17 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
}
// Recursively checks that each node contains all points that it claims to have.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType& node, const MatType& data)
+template<typename TreeType>
+bool CheckPointBounds(TreeType& node)
{
// Check that each point which this tree claims is actually inside the tree.
for (size_t index = 0; index < node.NumDescendants(); index++)
- if (!node.Bound().Contains(data.col(node.Descendant(index))))
+ if (!node.Bound().Contains(node.Dataset().col(node.Descendant(index))))
return false;
bool result = true;
for (size_t child = 0; child < node.NumChildren(); ++child)
- result &= CheckPointBounds(node.Child(child), data);
+ result &= CheckPointBounds(node.Child(child));
return result;
}
@@ -1371,10 +1371,10 @@ BOOST_AUTO_TEST_CASE(BallTreeTest)
// Generate data.
dataset.randu();
- datacopy = dataset; // Save a copy.
// Build the tree itself.
TreeType root(dataset, newToOld, oldToNew);
+ const arma::mat& treeset = root.Dataset();
// Ensure the size of the tree is correct.
BOOST_REQUIRE_EQUAL(root.NumDescendants(), size);
@@ -1384,21 +1384,19 @@ BOOST_AUTO_TEST_CASE(BallTreeTest)
{
for(size_t j = 0; j < dimensions; j++)
{
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
}
}
// Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(root, dataset);
+ CheckPointBounds(root);
}
}
template<int t_pow>
bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t /* ia */,
- size_t /* ib */)
+ HRectBound<t_pow>& b)
{
size_t dimensionality = a.Dim();
@@ -1464,7 +1462,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
size_t pointIncrements = 200; // Range is from 200 points to 400.
// We use the default leaf size of 20.
- for(size_t run = 0; run < maxRuns; run++)
+ for (size_t run = 0; run < maxRuns; run++)
{
size_t dimensions = run + 2;
size_t maxPoints = (run + 1) * pointIncrements;
@@ -1483,6 +1481,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
// Build the tree itself.
TreeType root(dataset, newToOld, oldToNew);
+ const arma::sp_mat& treeset = root.Dataset();
// Ensure the size of the tree is correct.
BOOST_REQUIRE_EQUAL(root.Count(), size);
@@ -1492,13 +1491,13 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
{
for(size_t j = 0; j < dimensions; j++)
{
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i));
}
}
// Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(root, dataset);
+ CheckPointBounds(root);
// Now check that no peers overlap.
std::vector<TreeType*> v;
@@ -1512,8 +1511,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound()));
depth *= 2;
}
More information about the mlpack-git mailing list