[mlpack-git] master, mlpack-1.0.x: R tree traversal test code. (4fc56e7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:51:39 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 4fc56e72e4e88200a95eca8ffab9cf9495c6fd44
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Mon Jul 7 20:45:00 2014 +0000

    R tree traversal test code.


>---------------------------------------------------------------

4fc56e72e4e88200a95eca8ffab9cf9495c6fd44
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |  7 ++--
 src/mlpack/tests/rectangle_tree_test.cpp           | 37 +++++++++++++++++++++-
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index ab9c55a..6e587d0 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -115,6 +115,9 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   int j = 0;
   GetBoundSeeds(*tree, &i, &j);
   
+
+  if(i == j)
+    std::cout << i << ", " << j << "; " << tree->NumChildren() << std::endl;
   assert(i != j);
   
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* treeOne = new 
@@ -187,7 +190,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
   // Here we want to find the pair of points that it is worst to place in the same
   // node.  Because we are just using points, we will simply choose the two that would
   // create the most voluminous hyperrectangle.
-  double worstPairScore = 0.0;
+  double worstPairScore = -1.0;
   int worstI = 0;
   int worstJ = 0;
   for(int i = 0; i < tree.Count(); i++) {
@@ -221,7 +224,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetBoundSeeds(
   int* iRet,
   int* jRet)
 {
-  double worstPairScore = 0.0;
+  double worstPairScore = -1.0;
   int worstI = 0;
   int worstJ = 0;
   for(int i = 0; i < tree.NumChildren(); i++) {
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 5e00d75..c6bb651 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -54,7 +54,6 @@ BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest)
                       NeighborSearchStat<NearestNeighborSort>,
                       arma::mat> tree(dataset, 20, 6, 5, 2, 0);
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000);
-  std::cout << tree.ToString() << std::endl;
 }
 
 std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -139,5 +138,41 @@ BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
   assert(checkContainment(tree) == true);
 }
 
+// Check that all-k-nearest-neighbor search driven by an R tree returns the
+// same neighbors and distances as a naive (brute-force) AllkNN search.
+BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  arma::Mat<size_t> neighbors1, neighbors2;
+  arma::mat distances1, distances2;
+
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                tree::RTreeDescentHeuristic,
+                NeighborSearchStat<NearestNeighborSort>,
+                arma::mat> RTree(dataset, 20, 6, 5, 2, 0);
+
+  // Nearest neighbor search with the R tree.
+  mlpack::neighbor::NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                    tree::RTreeDescentHeuristic,
+                    NeighborSearchStat<NearestNeighborSort>,
+                    arma::mat> > allknn1(&RTree, dataset, true);
+  allknn1.Search(5, neighbors1, distances1);
+
+  // Nearest neighbor search the naive way; this is the reference result.
+  mlpack::neighbor::AllkNN allknn2(dataset, true, true);
+  allknn2.Search(5, neighbors2, distances2);
+
+  // Use Boost checks, not assert(): assert() is compiled out under NDEBUG, and
+  // distances are floating-point, so compare them with a (percent) tolerance.
+  BOOST_REQUIRE_EQUAL(neighbors1.n_elem, neighbors2.n_elem);
+  for (size_t i = 0; i < neighbors1.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+    BOOST_REQUIRE_CLOSE(distances1[i], distances2[i], 1e-5);
+  }
+}
+
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list