[mlpack-svn] r16795 - in mlpack/trunk/src/mlpack: core/tree/rectangle_tree tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 14:43:10 EDT 2014
Author: andrewmw94
Date: Wed Jul 9 14:43:09 2014
New Revision: 16795
Log:
R tree now has dataset and indices
Modified:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp Wed Jul 9 14:43:09 2014
@@ -197,7 +197,7 @@
for(int j = i+1; j < tree.Count(); j++) {
double score = 1.0;
for(int k = 0; k < tree.Bound().Dim(); k++) {
- score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
+ score *= std::abs(tree.LocalDataset().at(k, i) - tree.LocalDataset().at(k, j)); // Points (in the dataset) are stored by column, but this function takes (row, col).
}
if(score > worstPairScore) {
worstPairScore = score;
@@ -312,7 +312,7 @@
double newVolOne = 1.0;
double newVolTwo = 1.0;
for(int i = 0; i < oldTree->Bound().Dim(); i++) {
- double c = oldTree->Dataset().col(oldTree->Points()[index])[i];
+ double c = oldTree->LocalDataset().col(index)[i];
newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
(c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -347,6 +347,7 @@
}
oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
+ oldTree->LocalDataset().col(bestIndex) = oldTree->LocalDataset().col(end);
}
// See if we need to satisfy the minimum fill.
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp Wed Jul 9 14:43:09 2014
@@ -74,6 +74,8 @@
MatType& dataset;
//! The mapping to the dataset
std::vector<size_t> points;
+ //! The local dataset
+ MatType* localDataset;
public:
//! So other classes can use TreeType::Mat.
@@ -226,6 +228,11 @@
const std::vector<size_t>& Points() const { return points; }
//! Modify the points vector for this node. Be careful!
std::vector<size_t>& Points() { return points; }
+
+ //! Get the local dataset of this node.
+ const arma::mat& LocalDataset() const { return *localDataset; }
+ //! Modify the local dataset of this node.
+ arma::mat& LocalDataset() { return *localDataset; }
//! Get the metric which the tree uses.
typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp Wed Jul 9 14:43:09 2014
@@ -40,7 +40,8 @@
bound(data.n_rows),
parentDistance(0),
dataset(data),
- points(maxLeafSize+1) // Add one to make splitting the node simpler.
+ points(maxLeafSize+1), // Add one to make splitting the node simpler.
+ localDataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
@@ -71,7 +72,8 @@
bound(parentNode->Bound().Dim()),
parentDistance(0),
dataset(parentNode->Dataset()),
- points(maxLeafSize+1) // Add one to make splitting the node simpler.
+ points(maxLeafSize+1), // Add one to make splitting the node simpler.
+ localDataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
{
stat = StatisticType(*this);
}
@@ -92,7 +94,7 @@
delete children[i];
}
//if(numChildren == 0)
- //delete points;
+ delete localDataset;
}
@@ -127,7 +129,7 @@
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
NullifyData()
{
- //points = NULL;
+ localDataset = NULL;
}
@@ -148,6 +150,7 @@
// If this is a leaf node, we stop here and add the point.
if(numChildren == 0) {
points[count++] = point;
+ localDataset->col(count) = dataset.col(point); // NOTE(review): count was already incremented by the line above, so this writes column `count`, one past the slot just used in points - confirm whether col(count - 1) was intended.
SplitNode();
return;
}
Modified: mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp Wed Jul 9 14:43:09 2014
@@ -128,7 +128,7 @@
BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
{
- arma::mat dataset;
+ arma::mat dataset;
dataset.randu(8, 1000); // 1000 points in 8 dimensions.
RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
@@ -138,6 +138,37 @@
assert(checkContainment(tree) == true);
}
+bool checkSync(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>, // Returns false as soon as a LocalDataset() column differs from the Dataset() column named by Points(); true otherwise.
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>& tree) {
+ if(tree.IsLeaf()) { // Leaf: compare each of the Count() locally stored columns with the main dataset.
+ for(size_t i = 0; i < tree.Count(); i++) {
+ for(size_t j = 0; j < tree.LocalDataset().n_rows; j++) { // j walks the rows of column i.
+ if(tree.LocalDataset().col(i)[j] != tree.Dataset().col(tree.Points()[i])[j]) // Exact comparison, no tolerance.
+ return false;
+ }
+ }
+ } else { // Non-leaf: recurse into each of the NumChildren() children.
+ for(size_t i = 0; i < tree.NumChildren(); i++) {
+ if(!checkSync(tree.Children()[i])) // NOTE(review): the destructor hunk does `delete children[i]`, so these elements look like pointers while checkSync takes a reference - confirm whether a dereference is needed here.
+ return false;
+ }
+ }
+ return true;
+}
+
+BOOST_AUTO_TEST_CASE(TreeLocalDatasetInSync) { // Builds a tree over random data and runs checkSync on it.
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> tree(dataset, 20, 6, 5, 2, 0); // presumably (maxLeafSize, minLeafSize, maxNumChildren, minNumChildren, firstDataIndex) - TODO confirm against the constructor declaration.
+ assert(checkSync(tree) == true); // Plain assert: compiled out when NDEBUG is defined, so the check disappears in such builds.
+}
+
BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
{
arma::mat dataset;
@@ -174,5 +205,4 @@
}
}
-
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list