[mlpack-svn] r16795 - in mlpack/trunk/src/mlpack: core/tree/rectangle_tree tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 14:43:10 EDT 2014


Author: andrewmw94
Date: Wed Jul  9 14:43:09 2014
New Revision: 16795

Log:
R tree now has dataset and indices

Modified:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
   mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp	Wed Jul  9 14:43:09 2014
@@ -197,7 +197,7 @@
     for(int j = i+1; j < tree.Count(); j++) {
       double score = 1.0;
       for(int k = 0; k < tree.Bound().Dim(); k++) {
-	score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
+	score *= std::abs(tree.LocalDataset().at(k, i) - tree.LocalDataset().at(k, j)); // Points (in the dataset) are stored by column, but this function takes (row, col).
       }
       if(score > worstPairScore) {
 	worstPairScore = score;
@@ -312,7 +312,7 @@
       double newVolOne = 1.0;
       double newVolTwo = 1.0;
       for(int i = 0; i < oldTree->Bound().Dim(); i++) {
-	double c = oldTree->Dataset().col(oldTree->Points()[index])[i];      
+	double c = oldTree->LocalDataset().col(index)[i];      
 	newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
 	  (c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
 	newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -347,6 +347,7 @@
     }
 
     oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
+    oldTree->LocalDataset().col(bestIndex) = oldTree->LocalDataset().col(end);
   }
   
   // See if we need to satisfy the minimum fill.

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp	Wed Jul  9 14:43:09 2014
@@ -74,6 +74,8 @@
   MatType& dataset;
   //! The mapping to the dataset
   std::vector<size_t> points;
+  //! The local dataset
+  MatType* localDataset;
 
  public:
   //! So other classes can use TreeType::Mat.
@@ -226,6 +228,11 @@
   const std::vector<size_t>& Points() const { return points; }
   //! Modify the points vector for this node.  Be careful!
   std::vector<size_t>& Points() { return points; }
+  
+  //! Get the local dataset of this node (dereferences localDataset, so it must be non-NULL).
+  const MatType& LocalDataset() const { return *localDataset; }
+  //! Modify the local dataset of this node (typed as MatType to match the localDataset member).
+  MatType& LocalDataset() { return *localDataset; }
 
   //! Get the metric which the tree uses.
   typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp	Wed Jul  9 14:43:09 2014
@@ -40,7 +40,8 @@
     bound(data.n_rows),
     parentDistance(0),
     dataset(data),
-    points(maxLeafSize+1) // Add one to make splitting the node simpler.
+    points(maxLeafSize+1), // Add one to make splitting the node simpler.
+    localDataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
 {
   stat = StatisticType(*this);
 
@@ -71,7 +72,8 @@
   bound(parentNode->Bound().Dim()),
   parentDistance(0),
   dataset(parentNode->Dataset()),
-  points(maxLeafSize+1) // Add one to make splitting the node simpler.
+  points(maxLeafSize+1), // Add one to make splitting the node simpler.
+  localDataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
 {
   stat = StatisticType(*this);
 }
@@ -92,7 +94,7 @@
     delete children[i];
   }
   //if(numChildren == 0)
-  //delete points;
+  delete localDataset;
 }
 
 
@@ -127,7 +129,7 @@
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     NullifyData()
 {
-  //points = NULL;
+  localDataset = NULL;
 }
 
 
@@ -148,6 +150,7 @@
   // If this is a leaf node, we stop here and add the point.
   if(numChildren == 0) {
     points[count++] = point;
+    localDataset->col(count - 1) = dataset.col(point); // count was already advanced above; write the slot paired with points[count - 1].
     SplitNode();
     return;
   }

Modified: mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp	Wed Jul  9 14:43:09 2014
@@ -128,7 +128,7 @@
 
 BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
 {
-    arma::mat dataset;
+  arma::mat dataset;
   dataset.randu(8, 1000); // 1000 points in 8 dimensions.
   
   RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -138,6 +138,37 @@
   assert(checkContainment(tree) == true);
 }
 
+// Returns true iff, in every leaf under tree, LocalDataset() column i equals the Dataset() column named by Points()[i].
+bool checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+               tree::RTreeDescentHeuristic,
+               NeighborSearchStat<NearestNeighborSort>,
+               arma::mat>& tree) {
+  if(tree.IsLeaf()) {
+    for(size_t i = 0; i < tree.Count(); i++) {
+      for(size_t j = 0; j < tree.LocalDataset().n_rows; j++)
+        if(tree.LocalDataset().col(i)[j] != tree.Dataset().col(tree.Points()[i])[j])
+          return false; // First mismatching coordinate: the local copy is stale or misplaced.
+    }
+    return true;
+  }
+  // Child entries are assumed to be pointers (the destructor deletes children[i]); dereference so the stored child itself is checked.
+  for(size_t i = 0; i < tree.NumChildren(); i++)
+    if(!checkSync(*tree.Children()[i]))
+      return false;
+  return true;
+}
+
+BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  // Build the tree from the random data, then verify each leaf's local copy against the shared dataset.
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+  BOOST_REQUIRE(checkSync(tree)); // Unlike assert(), this is still evaluated under NDEBUG and is reported through Boost.Test.
+}
+
 BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
 {
   arma::mat dataset;
@@ -174,5 +205,4 @@
   }
 }
 
-
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list