[mlpack-svn] r10719 - mlpack/trunk/src/mlpack/methods/range_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 11:12:01 EST 2011


Author: rcurtin
Date: 2011-12-12 11:12:00 -0500 (Mon, 12 Dec 2011)
New Revision: 10719

Modified:
   mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
Log:
Stop using unsafe_col().


Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2011-12-12 10:47:10 UTC (rev 10718)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2011-12-12 16:12:00 UTC (rev 10719)
@@ -45,8 +45,8 @@
    * @param leafSize The leaf size to be used during tree construction.
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(const arma::mat& referenceSet,
-              const arma::mat& querySet,
+  RangeSearch(const typename TreeType::Mat& referenceSet,
+              const typename TreeType::Mat& querySet,
               const bool naive = false,
               const bool singleMode = false,
               const size_t leafSize = 20,
@@ -71,7 +71,7 @@
    * @param leafSize The leaf size to be used during tree construction.
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(const arma::mat& referenceSet,
+  RangeSearch(const typename TreeType::Mat& referenceSet,
               const bool naive = false,
               const bool singleMode = false,
               const size_t leafSize = 20,
@@ -108,8 +108,8 @@
    */
   RangeSearch(const TreeType* referenceTree,
               const TreeType* queryTree,
-              const arma::mat& referenceSet,
-              const arma::mat& querySet,
+              const typename TreeType::Mat& referenceSet,
+              const typename TreeType::Mat& querySet,
               const bool singleMode = false,
               const MetricType metric = MetricType());
 
@@ -141,7 +141,7 @@
    * @param metric Instantiated distance metric.
    */
   RangeSearch(const TreeType* referenceTree,
-              const arma::mat& referenceSet,
+              const typename TreeType::Mat& referenceSet,
               const bool singleMode = false,
               const MetricType metric = MetricType());
 
@@ -220,25 +220,28 @@
    *
    * @param referenceNode Reference node.
    * @param queryPoint Point to query for.
+   * @param queryIndex Index of the query point (column) in the query set.
    * @param range Range of distances to search for.
    * @param neighbors Object holding list of neighbors.
    * @param distances Object holding list of distances.
     */
+  template<typename VecType>
   void SingleTreeRecursion(const TreeType* referenceNode,
-                           const arma::vec& queryPoint,
+                           const VecType& queryPoint,
+                           const size_t queryIndex,
                            const math::Range& range,
                            std::vector<size_t>& neighbors,
                            std::vector<double>& distances);
 
   //! Copy of reference matrix; used when a tree is built internally.
-  arma::mat referenceCopy;
+  typename TreeType::Mat referenceCopy;
   //! Copy of query matrix; used when a tree is built internally.
-  arma::mat queryCopy;
+  typename TreeType::Mat queryCopy;
 
   //! Reference set (data should be accessed using this).
-  const arma::mat& referenceSet;
+  const typename TreeType::Mat& referenceSet;
   //! Query set (data should be accessed using this).
-  const arma::mat& querySet;
+  const typename TreeType::Mat& querySet;
 
   //! Reference tree.
   TreeType* referenceTree;

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	2011-12-12 10:47:10 UTC (rev 10718)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	2011-12-12 16:12:00 UTC (rev 10719)
@@ -14,12 +14,13 @@
 namespace range {
 
 template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
-                                               const arma::mat& querySet,
-                                               const bool naive,
-                                               const bool singleMode,
-                                               const size_t leafSize,
-                                               const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const typename TreeType::Mat& referenceSet,
+    const typename TreeType::Mat& querySet,
+    const bool naive,
+    const bool singleMode,
+    const size_t leafSize,
+    const MetricType metric) :
     referenceCopy(referenceSet),
     queryCopy(querySet),
     referenceSet(referenceCopy),
@@ -45,11 +46,12 @@
 }
 
 template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
-                                               const bool naive,
-                                               const bool singleMode,
-                                               const size_t leafSize,
-                                               const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const typename TreeType::Mat& referenceSet,
+    const bool naive,
+    const bool singleMode,
+    const size_t leafSize,
+    const MetricType metric) :
     referenceCopy(referenceSet),
     referenceSet(referenceCopy),
     querySet(referenceCopy),
@@ -72,12 +74,13 @@
 }
 
 template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
-                                               const TreeType* queryTree,
-                                               const arma::mat& referenceSet,
-                                               const arma::mat& querySet,
-                                               const bool singleMode,
-                                               const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const TreeType* referenceTree,
+    const TreeType* queryTree,
+    const typename TreeType::Mat& referenceSet,
+    const typename TreeType::Mat& querySet,
+    const bool singleMode,
+    const MetricType metric) :
     referenceSet(referenceSet),
     querySet(querySet),
     referenceTree(referenceTree),
@@ -93,10 +96,11 @@
 }
 
 template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
-                                               const arma::mat& referenceSet,
-                                               const bool singleMode,
-                                               const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const TreeType* referenceTree,
+    const typename TreeType::Mat& referenceSet,
+    const bool singleMode,
+    const MetricType metric) :
     referenceSet(referenceSet),
     querySet(referenceSet),
     referenceTree(referenceTree),
@@ -164,7 +168,7 @@
     // Loop over each of the query points.
     for (size_t i = 0; i < querySet.n_cols; i++)
     {
-      SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), range,
+      SingleTreeRecursion(referenceTree, querySet.col(i), i, range,
           (*neighborPtr)[i], (*distancePtr)[i]);
     }
   }
@@ -296,12 +300,11 @@
   for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
        queryIndex++)
   {
-    // Get the query point from the matrix.
-    arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+    double minDistance =
+        referenceNode->Bound().MinDistance(querySet.col(queryIndex));
+    double maxDistance =
+        referenceNode->Bound().MaxDistance(querySet.col(queryIndex));
 
-    double minDistance = referenceNode->Bound().MinDistance(queryPoint);
-    double maxDistance = referenceNode->Bound().MaxDistance(queryPoint);
-
     // Now see if any points could fall into the range.
     if (range.Contains(math::Range(minDistance, maxDistance)))
     {
@@ -312,10 +315,9 @@
         // We can't add points that are ourselves.
         if (referenceNode != queryNode || referenceIndex != queryIndex)
         {
-          arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+          double distance = metric.Evaluate(querySet.col(queryIndex),
+                                            referenceSet.col(referenceIndex));
 
-          double distance = metric.Evaluate(queryPoint, referencePoint);
-
           // If this lies in the range, add it.
           if (range.Contains(distance))
           {
@@ -384,9 +386,11 @@
 }
 
 template<typename MetricType, typename TreeType>
+template<typename VecType>
 void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
     const TreeType* referenceNode,
-    const arma::vec& queryPoint,
+    const VecType& queryPoint,
+    const size_t queryIndex,
     const math::Range& range,
     std::vector<size_t>& neighbors,
     std::vector<double>& distances)
@@ -399,12 +403,11 @@
          referenceNode->End(); referenceIndex++)
     {
       // Don't add this point if it is the same as the query point.
-      if (!(referenceSet.colptr(referenceIndex) == queryPoint.memptr()))
+      if ((&querySet != &referenceSet) || (referenceIndex != queryIndex))
       {
-        arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+        double distance = metric.Evaluate(queryPoint,
+                                          referenceSet.col(referenceIndex));

-        double distance = metric.Evaluate(queryPoint, referencePoint);
-
         // See if the point is in the range we are looking for.
         if (range.Contains(distance))
         {
@@ -425,8 +428,8 @@
     if (range.Contains(distanceLeft))
     {
       // The left may have points we want to recurse to.
-      SingleTreeRecursion(referenceNode->Left(), queryPoint, range, neighbors,
-          distances);
+      SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+          range, neighbors, distances);
     }
     else
     {
@@ -436,8 +439,8 @@
     if (range.Contains(distanceRight))
     {
       // The right may have points we want to recurse to.
-      SingleTreeRecursion(referenceNode->Right(), queryPoint, range, neighbors,
-          distances);
+      SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+          range, neighbors, distances);
     }
     else
     {




More information about the mlpack-svn mailing list