[mlpack-svn] r12778 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 24 23:34:28 EDT 2012


Author: rcurtin
Date: 2012-05-24 23:34:27 -0400 (Thu, 24 May 2012)
New Revision: 12778

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Partway through the transition to a newer implementation, which does not try
as hard to conserve memory.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-05-24 21:43:53 UTC (rev 12777)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-05-25 03:34:27 UTC (rev 12778)
@@ -258,6 +258,9 @@
   //! The instantiated statistic.
   StatisticType stat;
 
+  //! The instantiated metric.  Either the user passes it in, or we build it.
+  MetricType* metric;
+
   /**
    * Fill the vector of distances with the distances between the point specified
    * by pointIndex and each point in the indices array.  The distances of the
@@ -316,6 +319,15 @@
                       const size_t childFarSetSize,
                       const size_t childUsedSetSize,
                       const size_t farSetSize);
+
+  void MoveToUsedSet(arma::Col<size_t>& indices,
+                     arma::vec& distances,
+                     size_t& nearSetSize,
+                     size_t& farSetSize,
+                     size_t& usedSetSize,
+                     arma::Col<size_t>& childIndices,
+                     const size_t childFarSetSize,
+                     const size_t childUsedSetSize);
 };
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-05-24 21:43:53 UTC (rev 12777)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-05-25 03:34:27 UTC (rev 12778)
@@ -263,67 +263,64 @@
   nearSetSize -= childUsedSetSize;
   usedSetSize += childUsedSetSize;
 
-  // If we used all the points in the near set, we don't need to continue.
-  if (nearSetSize == 0)
-    return;
-
   // Now for each point in the near set, we need to make children.  To save
   // computation later, we'll create an array holding the points in the near
   // set, and then after each run we'll check which of those (if any) were used
   // and we will remove them.  ...if that's faster.  I think it is.
-  arma::Col<size_t> nearSet = indices.rows(0, nearSetSize - 1);
-  for (size_t i = 0; i < nearSet.n_elem; ++i)
+  while (nearSetSize > 0)
   {
     // If this point has been used, skip to the next one.
-    if (nearSet[i] == dataset.n_cols)
-      continue;
+    const size_t newPointIndex = indices[0]; // nearSet holds indices.
 
-    const size_t newPointIndex = nearSet[i]; // nearSet holds indices.
+    // If there's only one point left, we don't need this crap.
+    if (nearSetSize == 1)
+    {
+      size_t childNearSetSize = 0;
+      children.push_back(new CoverTree(dataset, expansionConstant,
+          newPointIndex, nextScale, indices, distances, childNearSetSize,
+          farSetSize, usedSetSize));
 
-    // We need to move this point into the used set.  To do this we'll swap it
-    // with the last value in the far set and then increment the counters
-    // accordingly.  We don't have to worry about the fact that the point we
-    // swapped is actually in the far set but grouped with the near set, because
-    // we're about to rebuild that anyway.
-    size_t setIndex;
-    for (size_t k = 0; k < nearSetSize + farSetSize; ++k)
-      if (indices[k] == newPointIndex)
-        setIndex = k;
+      // Move last point to the used set.  Because nearSetSize == 1, we can
+      // simplify a little bit...
+      size_t tempIndex = indices[farSetSize];
+      double tempDist = distances[farSetSize];
 
-    // Ensure we need to swap.
-    if (setIndex != ((nearSetSize + farSetSize) - 1))
-    {
-      // Perform the swap.
-      const size_t otherLocation = (nearSetSize + farSetSize) - 1;
-      const double tmpDist = distances[setIndex];
+      indices[farSetSize] = indices[0];
+      distances[farSetSize] = distances[0];
 
-      indices[setIndex] = indices[otherLocation];
-      distances[setIndex] = distances[otherLocation];
+      indices[0] = tempIndex;
+      distances[0] = tempDist;
 
-      indices[otherLocation] = newPointIndex;
-      distances[otherLocation] = tmpDist;
+      ++usedSetSize;
+      --nearSetSize;
+
+      // And we're done.
+      break;
     }
 
-    // Update the near set size.  The used set size is updated by the recursive
-    // child constructor (but we have to add one for the point we are using,
-    // because the child constructor will not count that).
-    nearSetSize--;
-    usedSetSize++;
+    // Create the near and far set indices and distance vectors.
+    arma::Col<size_t> childIndices(nearSetSize + farSetSize);
+    childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
+        nearSetSize + farSetSize - 1);
+    // Put the current point into the used set, so when we move our indices to
+    // the used set, this will be done for us.
+    childIndices(nearSetSize + farSetSize - 1) = indices[0];
+    arma::vec childDistances(nearSetSize + farSetSize - 1);
 
-    // Rebuild the distances for this point.
-    ComputeDistances(newPointIndex, indices, distances,
-        nearSetSize + farSetSize);
+    // Build distances for the child.
+    ComputeDistances(newPointIndex, childIndices, childDistances, nearSetSize
+        + farSetSize - 1);
 
     // Split into near and far sets for this point.
-    childNearSetSize = SplitNearFar(indices, distances, bound, nearSetSize +
-        farSetSize);
+    childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
+        nearSetSize + farSetSize - 1);
 
     // Build this child (recursively).
-    childUsedSetSize = 0;
-    childFarSetSize = ((nearSetSize + farSetSize) - childNearSetSize);
+    childUsedSetSize = 1; // Mark self point as used.
+    childFarSetSize = ((nearSetSize + farSetSize - 1) - childNearSetSize);
     children.push_back(new CoverTree(dataset, expansionConstant, newPointIndex,
-        nextScale, indices, distances, childNearSetSize, childFarSetSize,
-        childUsedSetSize));
+        nextScale, childIndices, childDistances, childNearSetSize,
+        childFarSetSize, childUsedSetSize));
 
     // If we created an implicit node, take its self-child instead (this could
     // happen multiple times).
@@ -343,42 +340,14 @@
       delete old;
     }
 
-    // Now the arrays, in memory, look like this:
-    // [ childFar | childUsed | used ]
-    // So we don't really need to do anything to them to get ready for the next
-    // round.  We do need to look at the points that were used and update the
-    // nearSet array.  This double for loop is suboptimal, but the best way I
-    // can think of to do this.
-    size_t usedNearSetPoints = 0;
-    for (size_t j = childFarSetSize; j < (childFarSetSize + childUsedSetSize);
-         ++j)
-    {
-      for (size_t k = i + 1; k < nearSet.n_elem; ++k)
-      {
-        if (indices[j] == nearSet[k])
-        {
-          nearSet[k] = dataset.n_cols; // Invalid index to indicate it's used.
-          usedNearSetPoints++;
-        }
-      }
-    }
-
-    // Now we update the count of far set points and near set points.
-    farSetSize -= (childUsedSetSize - usedNearSetPoints);
-    nearSetSize -= usedNearSetPoints;
-
-    // Update number of used points.
-    usedSetSize += childUsedSetSize;
+    // Now with the child created, it returns the childIndices and
+    // childDistances vectors in this form:
+    // [ childFar | childUsed ]
+    // For each point in the childUsed set, we must move that point to the used
+    // set in our own vector.
+    MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
+        childIndices, childFarSetSize, childUsedSetSize);
   }
-
-  // Now that this is all done, our memory looks like this:
-  // [ childFar | childUsed | used ]
-  // We need to rebuild the distances and the set so it looks like this:
-  // [ far | used ]
-  // because all the points in the near set should be used up.
-  farSetSize = childFarSetSize;
-
-  ComputeDistances(pointIndex, indices, distances, farSetSize);
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -555,6 +524,119 @@
   return (childFarSetSize + farSetSize);
 }
 
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
+    arma::Col<size_t>& indices,
+    arma::vec& distances,
+    size_t& nearSetSize,
+    size_t& farSetSize,
+    size_t& usedSetSize,
+    arma::Col<size_t>& childIndices,
+    const size_t childFarSetSize, // childNearSetSize is 0 in this case.
+    const size_t childUsedSetSize)
+{
+  // On entry indices/distances are laid out as [ near | far | used ] and
+  // childIndices as [ childFar | childUsed ].  Each childUsed point found in
+  // our near or far set is moved to the last slot of (near + far), i.e. the
+  // front of our used set, and nearSetSize or farSetSize shrinks by one.
+  //
+  // Entries childIndices[childFarSetSize + j] with j < startChildUsedSet
+  // have already been matched; only the rest of childUsed is searched.  The
+  // order of the childUsed entries is not preserved.
+  size_t startChildUsedSet = 0;
+
+  // Near set.  Removing a near point takes a three-way rotation so that the
+  // near, far and used sets stay contiguous:
+  //   slot i         <- last near point
+  //   last near slot <- last far point
+  //   last far slot  <- the used point
+  // The last near point is read before anything overwrites it, so this is
+  // also correct when the far set is empty (last near slot == last far
+  // slot); the rotation then reduces to a plain swap.
+  size_t i = 0;
+  while ((i < nearSetSize) && (startChildUsedSet < childUsedSetSize))
+  {
+    // Look for this point among the child's unmatched used points.
+    size_t j = startChildUsedSet;
+    while ((j < childUsedSetSize) &&
+        (childIndices[childFarSetSize + j] != indices[i]))
+      ++j;
+
+    if (j == childUsedSetSize)
+    {
+      ++i; // The child did not use this point; it stays in the near set.
+      continue;
+    }
+
+    const size_t lastNear = nearSetSize - 1;
+    const size_t lastFar = nearSetSize + farSetSize - 1;
+
+    const size_t usedIndex = indices[i];
+    const double usedDist = distances[i];
+
+    indices[i] = indices[lastNear];
+    distances[i] = distances[lastNear];
+
+    indices[lastNear] = indices[lastFar];
+    distances[lastNear] = distances[lastFar];
+
+    indices[lastFar] = usedIndex;
+    distances[lastFar] = usedDist;
+
+    // Retire child entry j: overwrite it with the first unmatched entry and
+    // shrink the search window from the front.
+    childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+        startChildUsedSet];
+    ++startChildUsedSet;
+    --nearSetSize;
+
+    // i is not advanced: slot i now holds the old last near point, which
+    // still has to be checked, so no unsigned "--i" wraparound is needed.
+  }
+
+  // Far set.  The near set is not touched here, so a plain swap with the
+  // last far slot is enough.
+  i = 0;
+  while ((i < farSetSize) && (startChildUsedSet < childUsedSetSize))
+  {
+    const size_t farSlot = nearSetSize + i;
+
+    // Look for this point among the child's unmatched used points.
+    size_t j = startChildUsedSet;
+    while ((j < childUsedSetSize) &&
+        (childIndices[childFarSetSize + j] != indices[farSlot]))
+      ++j;
+
+    if (j == childUsedSetSize)
+    {
+      ++i; // The child did not use this point; it stays in the far set.
+      continue;
+    }
+
+    const size_t lastFar = nearSetSize + farSetSize - 1;
+
+    const size_t usedIndex = indices[farSlot];
+    const double usedDist = distances[farSlot];
+
+    indices[farSlot] = indices[lastFar];
+    distances[farSlot] = distances[lastFar];
+
+    indices[lastFar] = usedIndex;
+    distances[lastFar] = usedDist;
+
+    // Retire child entry j, exactly as in the near-set loop.
+    childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+        startChildUsedSet];
+    ++startChildUsedSet;
+    --farSetSize;
+
+    // i is not advanced: slot farSlot now holds the old last far point.
+  }
+
+  // Update used set size.
+  usedSetSize += childUsedSetSize;
+}
+
 }; // namespace tree
 }; // namespace mlpack
 




More information about the mlpack-svn mailing list