[mlpack-git] master, mlpack-1.0.x: R tree now has dataset and indices (3643362)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:52:20 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 36433627b45dca8897fc08a12943dd85d00dc53f
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Wed Jul 9 18:43:09 2014 +0000

    R tree now has dataset and indices


>---------------------------------------------------------------

36433627b45dca8897fc08a12943dd85d00dc53f
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |  5 ++--
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  7 +++++
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 11 ++++---
 src/mlpack/tests/rectangle_tree_test.cpp           | 34 ++++++++++++++++++++--
 4 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 6e587d0..ad4c51b 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -197,7 +197,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::GetPointSeeds(
     for(int j = i+1; j < tree.Count(); j++) {
       double score = 1.0;
       for(int k = 0; k < tree.Bound().Dim(); k++) {
-	score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
+	score *= std::abs(tree.LocalDataset().at(k, i) - tree.LocalDataset().at(k, j)); // Points (in the dataset) are stored by column, but this function takes (row, col).
       }
       if(score > worstPairScore) {
 	worstPairScore = score;
@@ -312,7 +312,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
       double newVolOne = 1.0;
       double newVolTwo = 1.0;
       for(int i = 0; i < oldTree->Bound().Dim(); i++) {
-	double c = oldTree->Dataset().col(oldTree->Points()[index])[i];      
+	double c = oldTree->LocalDataset().col(index)[i];      
 	newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
 	  (c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
 	newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -347,6 +347,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
     }
 
     oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
+    oldTree->LocalDataset().col(bestIndex) = oldTree->LocalDataset().col(end);
   }
   
   // See if we need to satisfy the minimum fill.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 550cf74..76f879e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -74,6 +74,8 @@ class RectangleTree
   MatType& dataset;
   //! The mapping to the dataset
   std::vector<size_t> points;
+  //! Local copy of this node's points, one column per point; allocated with new (NOTE(review): raw owning pointer -- confirm nodes are never copied).
+  MatType* localDataset;
 
  public:
   //! So other classes can use TreeType::Mat.
@@ -227,6 +229,11 @@ class RectangleTree
   //! Modify the points vector for this node.  Be careful!
   std::vector<size_t>& Points() { return points; }
   
+  //! Get the local dataset of this node (same matrix type as the full dataset).
+  const MatType& LocalDataset() const { return *localDataset; }
+  //! Modify the local dataset of this node.  Be careful to keep it in sync with Points()!
+  MatType& LocalDataset() { return *localDataset; }
+
   //! Get the metric which the tree uses.
   typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
 
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index ad6f038..8df89ec 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -40,7 +40,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
     bound(data.n_rows),
     parentDistance(0),
     dataset(data),
-    points(maxLeafSize+1) // Add one to make splitting the node simpler.
+    points(maxLeafSize+1), // Add one to make splitting the node simpler.
+    localDataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
 {
   stat = StatisticType(*this);
 
@@ -71,7 +72,8 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
   bound(parentNode->Bound().Dim()),
   parentDistance(0),
   dataset(parentNode->Dataset()),
-  points(maxLeafSize+1) // Add one to make splitting the node simpler.
+  points(maxLeafSize+1), // Add one to make splitting the node simpler.
+  localDataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
 {
   stat = StatisticType(*this);
 }
@@ -92,7 +94,7 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     delete children[i];
   }
   //if(numChildren == 0)
-  //delete points;
+  delete localDataset;
 }
 
 
@@ -127,7 +129,7 @@ template<typename SplitType,
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     NullifyData()
 {
-  //points = NULL;
+  localDataset = NULL;
 }
 
 
@@ -148,6 +150,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
   // If this is a leaf node, we stop here and add the point.
   if(numChildren == 0) {
     points[count++] = point;
+    localDataset->col(count - 1) = dataset.col(point); // count was already incremented above, so the point just stored lives in column count - 1.
     SplitNode();
     return;
   }
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index c6bb651..28a14fd 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -128,7 +128,7 @@ bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeu
 
 BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
 {
-    arma::mat dataset;
+  arma::mat dataset;
   dataset.randu(8, 1000); // 1000 points in 8 dimensions.
   
   RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -138,6 +138,37 @@ BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
   assert(checkContainment(tree) == true);
 }
 
+bool checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+               tree::RTreeDescentHeuristic,
+               NeighborSearchStat<NearestNeighborSort>,
+               arma::mat>& tree) {
+  // True iff every leaf's LocalDataset() column equals the Dataset() column its Points() entry indexes.
+  if(!tree.IsLeaf()) {
+    for(size_t child = 0; child < tree.NumChildren(); child++)
+      if(!checkSync(tree.Children()[child]))
+        return false;
+    return true;
+  }
+  for(size_t pt = 0; pt < tree.Count(); pt++) {
+    for(size_t dim = 0; dim < tree.LocalDataset().n_rows; dim++) {
+      if(tree.LocalDataset().col(pt)[dim] != tree.Dataset().col(tree.Points()[pt])[dim])
+        return false;
+    }
+  }
+  return true;
+}
+
+BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) {
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  // Build an R tree over the data; every leaf's local copy must then match the dataset.
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+  BOOST_REQUIRE(checkSync(tree)); // Unlike assert(), this is not compiled out under NDEBUG.
+}
+
 BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
 {
   arma::mat dataset;
@@ -174,5 +205,4 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
   }
 }
 
-
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list