[mlpack-svn] r12778 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 24 23:34:28 EDT 2012
Author: rcurtin
Date: 2012-05-24 23:34:27 -0400 (Thu, 24 May 2012)
New Revision: 12778
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Halfway transitioned to newer implementation, which does not try so hard to
conserve memory.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-05-24 21:43:53 UTC (rev 12777)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-05-25 03:34:27 UTC (rev 12778)
@@ -258,6 +258,9 @@
//! The instantiated statistic.
StatisticType stat;
+ //! The instantiated metric. Either the user passes it in, or we build it.
+ MetricType* metric;
+
/**
* Fill the vector of distances with the distances between the point specified
* by pointIndex and each point in the indices array. The distances of the
@@ -316,6 +319,15 @@
const size_t childFarSetSize,
const size_t childUsedSetSize,
const size_t farSetSize);
+
+ void MoveToUsedSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize);
};
}; // namespace tree
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-05-24 21:43:53 UTC (rev 12777)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-05-25 03:34:27 UTC (rev 12778)
@@ -263,67 +263,64 @@
nearSetSize -= childUsedSetSize;
usedSetSize += childUsedSetSize;
- // If we used all the points in the near set, we don't need to continue.
- if (nearSetSize == 0)
- return;
-
// Now for each point in the near set, we need to make children. To save
// computation later, we'll create an array holding the points in the near
// set, and then after each run we'll check which of those (if any) were used
// and we will remove them. ...if that's faster. I think it is.
- arma::Col<size_t> nearSet = indices.rows(0, nearSetSize - 1);
- for (size_t i = 0; i < nearSet.n_elem; ++i)
+ while (nearSetSize > 0)
{
// If this point has been used, skip to the next one.
- if (nearSet[i] == dataset.n_cols)
- continue;
+ const size_t newPointIndex = indices[0]; // nearSet holds indices.
- const size_t newPointIndex = nearSet[i]; // nearSet holds indices.
+ // If there's only one point left, we don't need this crap.
+ if (nearSetSize == 1)
+ {
+ size_t childNearSetSize = 0;
+ children.push_back(new CoverTree(dataset, expansionConstant,
+ newPointIndex, nextScale, indices, distances, childNearSetSize,
+ farSetSize, usedSetSize));
- // We need to move this point into the used set. To do this we'll swap it
- // with the last value in the far set and then increment the counters
- // accordingly. We don't have to worry about the fact that the point we
- // swapped is actually in the far set but grouped with the near set, because
- // we're about to rebuild that anyway.
- size_t setIndex;
- for (size_t k = 0; k < nearSetSize + farSetSize; ++k)
- if (indices[k] == newPointIndex)
- setIndex = k;
+ // Move last point to the used set. Because nearSetSize == 1, we can
+ // simplify a little bit...
+ size_t tempIndex = indices[farSetSize];
+ double tempDist = distances[farSetSize];
- // Ensure we need to swap.
- if (setIndex != ((nearSetSize + farSetSize) - 1))
- {
- // Perform the swap.
- const size_t otherLocation = (nearSetSize + farSetSize) - 1;
- const double tmpDist = distances[setIndex];
+ indices[farSetSize] = indices[0];
+ distances[farSetSize] = distances[0];
- indices[setIndex] = indices[otherLocation];
- distances[setIndex] = distances[otherLocation];
+ indices[0] = tempIndex;
+ distances[0] = tempDist;
- indices[otherLocation] = newPointIndex;
- distances[otherLocation] = tmpDist;
+ ++usedSetSize;
+ --nearSetSize;
+
+ // And we're done.
+ break;
}
- // Update the near set size. The used set size is updated by the recursive
- // child constructor (but we have to add one for the point we are using,
- // because the child constructor will not count that).
- nearSetSize--;
- usedSetSize++;
+ // Create the near and far set indices and distance vectors.
+ arma::Col<size_t> childIndices(nearSetSize + farSetSize);
+ childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
+ nearSetSize + farSetSize - 1);
+ // Put the current point into the used set, so when we move our indices to
+ // the used set, this will be done for us.
+ childIndices(nearSetSize + farSetSize - 1) = indices[0];
+ arma::vec childDistances(nearSetSize + farSetSize - 1);
- // Rebuild the distances for this point.
- ComputeDistances(newPointIndex, indices, distances,
- nearSetSize + farSetSize);
+ // Build distances for the child.
+ ComputeDistances(newPointIndex, childIndices, childDistances, nearSetSize
+ + farSetSize - 1);
// Split into near and far sets for this point.
- childNearSetSize = SplitNearFar(indices, distances, bound, nearSetSize +
- farSetSize);
+ childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
+ nearSetSize + farSetSize - 1);
// Build this child (recursively).
- childUsedSetSize = 0;
- childFarSetSize = ((nearSetSize + farSetSize) - childNearSetSize);
+ childUsedSetSize = 1; // Mark self point as used.
+ childFarSetSize = ((nearSetSize + farSetSize - 1) - childNearSetSize);
children.push_back(new CoverTree(dataset, expansionConstant, newPointIndex,
- nextScale, indices, distances, childNearSetSize, childFarSetSize,
- childUsedSetSize));
+ nextScale, childIndices, childDistances, childNearSetSize,
+ childFarSetSize, childUsedSetSize));
// If we created an implicit node, take its self-child instead (this could
// happen multiple times).
@@ -343,42 +340,14 @@
delete old;
}
- // Now the arrays, in memory, look like this:
- // [ childFar | childUsed | used ]
- // So we don't really need to do anything to them to get ready for the next
- // round. We do need to look at the points that were used and update the
- // nearSet array. This double for loop is suboptimal, but the best way I
- // can think of to do this.
- size_t usedNearSetPoints = 0;
- for (size_t j = childFarSetSize; j < (childFarSetSize + childUsedSetSize);
- ++j)
- {
- for (size_t k = i + 1; k < nearSet.n_elem; ++k)
- {
- if (indices[j] == nearSet[k])
- {
- nearSet[k] = dataset.n_cols; // Invalid index to indicate it's used.
- usedNearSetPoints++;
- }
- }
- }
-
- // Now we update the count of far set points and near set points.
- farSetSize -= (childUsedSetSize - usedNearSetPoints);
- nearSetSize -= usedNearSetPoints;
-
- // Update number of used points.
- usedSetSize += childUsedSetSize;
+ // Now with the child created, it returns the childIndices and
+ // childDistances vectors in this form:
+ // [ childFar | childUsed ]
+ // For each point in the childUsed set, we must move that point to the used
+ // set in our own vector.
+ MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
+ childIndices, childFarSetSize, childUsedSetSize);
}
-
- // Now that this is all done, our memory looks like this:
- // [ childFar | childUsed | used ]
- // We need to rebuild the distances and the set so it looks like this:
- // [ far | used ]
- // because all the points in the near set should be used up.
- farSetSize = childFarSetSize;
-
- ComputeDistances(pointIndex, indices, distances, farSetSize);
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -555,6 +524,119 @@
return (childFarSetSize + farSetSize);
}
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize, // childNearSetSize is 0 in this case.
+ const size_t childUsedSetSize)
+{
+ // Loop across the set. We will swap points as we need. It should be noted
+ // that farSetSize and nearSetSize may change with each iteration of this loop
+ // (depending on if we make a swap or not).
+ size_t startChildUsedSet = 0; // Where to start in the child set.
+ for (size_t i = 0; i < nearSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i])
+ {
+ // We have found a point; a swap is necessary.
+ // Since this point is from the near set, to preserve the near set, we
+ // must do a three-way swap.
+
+ // First take the target point and put the used point there.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ // Now take the last point in the near set and put the temporary point
+ // (which is from the far set) there. Store the near set point as a
+ // temporary. If the near set only has one point, we don't need to do
+ // the second swap.
+ if ((nearSetSize - 1) != i)
+ {
+ size_t tempNearIndex = indices[nearSetSize - 1];
+ double tempNearDist = distances[nearSetSize - 1];
+
+ indices[nearSetSize - 1] = tempIndex;
+ distances[nearSetSize - 1] = tempDist;
+
+ indices[i] = tempNearIndex;
+ distances[i] = tempNearDist;
+ }
+ else
+ {
+ indices[i] = tempIndex;
+ distances[i] = tempDist;
+ }
+
+ // We don't need to do a complete preservation of the child index set,
+ // but we want to make sure we only loop over points we haven't seen.
+ // So increment the child counter by 1 and move a point if we need.
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --nearSetSize;
+ --i; // Since we moved a point out of the near set we must step back.
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Now loop over the far set. This loop is different because we only require
+ // a normal two-way swap instead of the three-way swap to preserve the near
+ // set / far set ordering.
+ for (size_t i = 0; i < farSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
+ {
+ // We have found a point to swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
+ distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
+
+ indices[nearSetSize + i] = tempIndex;
+ distances[nearSetSize + i] = tempDist;
+
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --farSetSize;
+ --i;
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Update used set size.
+ usedSetSize += childUsedSetSize;
+}
+
}; // namespace tree
}; // namespace mlpack
More information about the mlpack-svn mailing list