[mlpack-svn] r16733 - in mlpack/trunk/src/mlpack: core/tree/rectangle_tree tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Jul 1 12:11:31 EDT 2014
Author: andrewmw94
Date: Tue Jul 1 12:11:31 2014
New Revision: 16733
Log:
Change the tree to store size_t point indices in the nodes and keep the dataset together. Other miscellaneous changes.
Modified:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp Tue Jul 1 12:11:31 2014
@@ -194,7 +194,7 @@
for(int j = i+1; j < tree.Count(); j++) {
double score = 1.0;
for(int k = 0; k < tree.Bound().Dim(); k++) {
- score *= std::abs(tree.Dataset().at(k, i) - tree.Dataset().at(k, j)); // Points are stored by column, but this function takes (row, col).
+ score *= std::abs(tree.Dataset().at(k, tree.Points()[i]) - tree.Dataset().at(k, tree.Points()[j])); // Points (in the dataset) are stored by column, but this function takes (row, col).
}
if(score > worstPairScore) {
worstPairScore = score;
@@ -264,16 +264,16 @@
treeOne->Count() = 0;
treeTwo->Count() = 0;
- treeOne->InsertPoint(oldTree->Dataset().col(intI));
- treeTwo->InsertPoint(oldTree->Dataset().col(intJ));
+ treeOne->InsertPoint(oldTree->Points()[intI]);
+ treeTwo->InsertPoint(oldTree->Points()[intJ]);
// If intJ is the last point in the tree, we need to switch the order so that we remove the correct points.
if(intI > intJ) {
- oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
- oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
+ oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
} else {
- oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
- oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+ oldTree->Points()[intJ] = oldTree->Points()[--end]; // decrement end
+ oldTree->Points()[intI] = oldTree->Points()[--end]; // decrement end
}
@@ -309,7 +309,7 @@
double newVolOne = 1.0;
double newVolTwo = 1.0;
for(int i = 0; i < oldTree->Bound().Dim(); i++) {
- double c = oldTree->Dataset().col(index)[i];
+ double c = oldTree->Dataset().col(oldTree->Points()[index])[i];
newVolOne *= treeOne->Bound()[i].Contains(c) ? treeOne->Bound()[i].Width() :
(c < treeOne->Bound()[i].Lo() ? (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
newVolTwo *= treeTwo->Bound()[i].Contains(c) ? treeTwo->Bound()[i].Width() :
@@ -335,26 +335,26 @@
// Assign the point that causes the least increase in volume
// to the appropriate rectangle.
if(bestRect == 1) {
- treeOne->InsertPoint(oldTree->Dataset().col(bestIndex));
+ treeOne->InsertPoint(oldTree->Points()[bestIndex]);
numAssignedOne++;
}
else {
- treeTwo->InsertPoint(oldTree->Dataset().col(bestIndex));
+ treeTwo->InsertPoint(oldTree->Points()[bestIndex]);
numAssignedTwo++;
}
- oldTree->Dataset().col(bestIndex) = oldTree->Dataset().col(--end); // decrement end.
+ oldTree->Points()[bestIndex] = oldTree->Points()[--end]; // decrement end.
}
// See if we need to satisfy the minimum fill.
if(end > 0) {
if(numAssignedOne < numAssignedTwo) {
for(int i = 0; i < end; i++) {
- treeOne->InsertPoint(oldTree->Dataset().col(i));
+ treeOne->InsertPoint(oldTree->Points()[i]);
}
} else {
for(int i = 0; i < end; i++) {
- treeTwo->InsertPoint(oldTree->Dataset().col(i));
+ treeTwo->InsertPoint(oldTree->Points()[i]);
}
}
}
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp Tue Jul 1 12:11:31 2014
@@ -71,7 +71,9 @@
//! The distance to the furthest descendant, cached to speed things up.
double furthestDescendantDistance;
//! The dataset.
- MatType* dataset;
+ MatType& dataset;
+ //! The mapping to the dataset
+ std::vector<size_t> points;
public:
//! So other classes can use TreeType::Mat.
@@ -138,9 +140,17 @@
* it may be passed many times before it actually reaches a leaf.
* @param point The index (size_t) into the dataset of the point to be inserted.
*/
- void InsertPoint(const arma::vec& point);
+ void InsertPoint(const size_t point);
/**
+ * Deletes a point in the tree. The point will be removed from the points vector
+ * of the leaf node where it is stored and the bounding rectangles will be updated.
+ * Returns true if the point is successfully removed and false if it is not
+ * (i.e., the point is not in the tree).
+ */
+ bool DeletePoint(const size_t point);
+
+ /**
* Find a node in this tree by its begin and count (const).
*
* Every node is uniquely identified by these two numbers.
@@ -205,10 +215,15 @@
RectangleTree*& Parent() { return parent; }
//! Get the dataset which the tree is built on.
- const arma::mat& Dataset() const { return *dataset; }
+ const arma::mat& Dataset() const { return dataset; }
//! Modify the dataset which the tree is built on. Be careful!
- arma::mat& Dataset() { return *dataset; }
-
+ arma::mat& Dataset() { return dataset; }
+
+ //! Get the points vector for this node.
+ const std::vector<size_t>& Points() const { return points; }
+ //! Modify the points vector for this node. Be careful!
+ std::vector<size_t>& Points() { return points; }
+
//! Get the metric which the tree uses.
typename HRectBound<>::MetricType Metric() const { return bound.Metric(); }
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp Tue Jul 1 12:11:31 2014
@@ -39,7 +39,8 @@
minLeafSize(minLeafSize),
bound(data.n_rows),
parentDistance(0),
- dataset(new MatType(data.n_rows, static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ dataset(data),
+ points(maxLeafSize+1) // Add one to make splitting the node simpler.
{
stat = StatisticType(*this);
@@ -47,8 +48,8 @@
RectangleTree* root = this;
//for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
- for(int i = firstDataIndex; i < data.n_cols; i++) {
- root->InsertPoint(data.col(i));
+ for(size_t i = firstDataIndex; i < data.n_cols; i++) {
+ root->InsertPoint(i);
}
}
@@ -69,7 +70,8 @@
minLeafSize(parentNode->MinLeafSize()),
bound(parentNode->Bound().Dim()),
parentDistance(0),
- dataset(new MatType(static_cast<int>(parentNode->Bound().Dim()), static_cast<int>(maxLeafSize)+1)) // Add one to make splitting the node simpler
+ dataset(parentNode->Dataset()),
+ points(maxLeafSize+1) // Add one to make splitting the node simpler.
{
stat = StatisticType(*this);
}
@@ -90,7 +92,7 @@
delete children[i];
}
//if(numChildren == 0)
- delete dataset;
+ //delete points;
}
@@ -125,7 +127,7 @@
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
NullifyData()
{
- dataset = NULL;
+ //points = NULL;
}
@@ -138,25 +140,25 @@
typename StatisticType,
typename MatType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
- InsertPoint(const arma::vec& point)
+ InsertPoint(const size_t point)
{
// Expand the bound regardless of whether it is a leaf node.
- bound |= point;
+ bound |= dataset.col(point);
// If this is a leaf node, we stop here and add the point.
if(numChildren == 0) {
- dataset->col(count++) = point;
+ points[count++] = point;
SplitNode();
return;
}
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
- double minScore = DescentType::EvalNode(children[0]->Bound(), point);
+ double minScore = DescentType::EvalNode(children[0]->Bound(), dataset.col(point));
int bestIndex = 0;
for(int i = 1; i < numChildren; i++) {
- double score = DescentType::EvalNode(children[i]->Bound(), point);
+ double score = DescentType::EvalNode(children[i]->Bound(), dataset.col(point));
if(score < minScore) {
minScore = score;
bestIndex = i;
@@ -165,6 +167,60 @@
children[bestIndex]->InsertPoint(point);
}
+/**
+ * Recurse through the tree to remove the point. Once we find the point, we
+ * shrink the rectangles if necessary.
+ */
+template<typename SplitType,
+ typename DescentType,
+ typename StatisticType,
+ typename MatType>
+bool RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+ DeletePoint(const size_t point)
+{
+ if(numChildren == 0) {
+ for(int i = 0; i < count; i++) {
+ if(points[i] == point) {
+ points[i] = points[--count];
+ for(int j = 0; j < bound.Dim(); j++) {
+ if(bound[j].Lo() == dataset.col(point)[j]) {
+ int loIndx = 0;
+ double lo = dataset(points[0])[j];
+ for(int k = 1; k < count; k++) {
+ if(dataset(points[k])[j] < lo) {
+ lo = dataset(points[k])[j];
+ loIndx = k;
+ }
+ }
+ bound[j].Lo() = lo;
+ } else if(bound[j].Hi() == dataset.col(point)[j]) {
+ int hiIndx = 0;
+ double hi = dataset(points[0])[j];
+ for(int k = 1; k < count; k++) {
+ if(dataset(points[k])[j] > hi) {
+ hi = dataset(points[k])[j];
+ hiIndx = k;
+ }
+ }
+ bound[j].Hi() = hi;
+ }
+ }
+ return true;
+ }
+ }
+ } else {
+ for(int i = 0; i < numChildren; i++) {
+ if(children[i].Bound().Contains(dataset.col(point))) {
+ if(children[i].DeletePoint(dataset.col(point))) {
+
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
template<typename SplitType,
typename DescentType,
typename StatisticType,
@@ -308,7 +364,7 @@
inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
Point(const size_t index) const
{
- return (begin + index);
+ return dataset(points[index]);
}
/**
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp Tue Jul 1 12:11:31 2014
@@ -49,8 +49,8 @@
// This is not a leaf node so we:
// Sort the children of this node by their scores.
- std::vector<RectangleTree*> nodes = new std::vector<RectangleTree*>(referenceNode.NumChildren());
- std::vector<double> scores = new std::vector<double>(referenceNode.NumChildren());
+ std::vector<RectangleTree*> nodes(referenceNode.NumChildren());
+ std::vector<double> scores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++) {
nodes[i] = referenceNode.Children()[i];
scores[i] = rule.Score(nodes[i]);
@@ -60,12 +60,12 @@
// Iterate through them starting with the best and stopping when we reach
// one that isn't good enough.
for(int i = 0; i < referenceNode.NumChildren(); i++) {
- if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
- Traverse(queryIndex, nodes[i]);
- else {
- numPrunes += referenceNode.NumChildren - i;
- return;
- }
+ if(rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
+ Traverse(queryIndex, nodes[i]);
+ else {
+ numPrunes += referenceNode.NumChildren - i;
+ return;
+ }
}
// We only get here if we couldn't prune any of them.
return;
Modified: mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/rectangle_tree_test.cpp Tue Jul 1 12:11:31 2014
@@ -70,7 +70,7 @@
}
} else {
for(size_t i = 0; i < tree.Count(); i++) {
- arma::vec* c = new arma::vec(tree.Dataset().col(i));
+ arma::vec* c = new arma::vec(tree.Dataset().col(tree.Points()[i]));
vec.push_back(c);
}
}
@@ -112,7 +112,7 @@
bool passed = true;
if(tree.NumChildren() == 0) {
for(size_t i = 0; i < tree.Count(); i++) {
- passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(i));
+ passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
}
} else {
for(size_t i = 0; i < tree.NumChildren(); i++) {
More information about the mlpack-svn
mailing list