[mlpack-svn] r13365 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 18:41:06 EDT 2012


Author: rcurtin
Date: 2012-08-07 18:41:05 -0400 (Tue, 07 Aug 2012)
New Revision: 13365

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
Log:
Use the L2 distance instead of the squared L2 distance, because the squared L2 distance does not satisfy the triangle inequality and therefore breaks cover trees.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2012-08-07 22:40:37 UTC (rev 13364)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2012-08-07 22:41:05 UTC (rev 13365)
@@ -187,12 +187,12 @@
     distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
 
     // Do the actual remapping.
-    if (CLI::GetParam<string>("query_file") != "")
+    if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
     {
       for (size_t i = 0; i < distancesOut.n_cols; ++i)
       {
-        // Map distances (copy a column).
-        distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        // Map distances (copy a column) and square root.
+        distances.col(oldFromNewQueries[i]) = sqrt(distancesOut.col(i));
 
         // Map indices of neighbors.
         for (size_t j = 0; j < distancesOut.n_rows; ++j)
@@ -202,12 +202,23 @@
         }
       }
     }
+    else if (singleMode)
+    {
+      // No remapping of queries is necessary.  So distances are the same.
+      distances = sqrt(distancesOut);
+
+      // The neighbor indices must be mapped.
+      for (size_t j = 0; j < neighborsOut.n_elem; ++j)
+      {
+        neighbors[j] = oldFromNewRefs[neighborsOut[j]];
+      }
+    }
     else
     {
       for (size_t i = 0; i < distancesOut.n_cols; ++i)
       {
         // Map distances (copy a column).
-        distances.col(oldFromNewRefs[i]) = distancesOut.col(i);
+        distances.col(oldFromNewRefs[i]) = sqrt(distancesOut.col(i));
 
         // Map indices of neighbors.
         for (size_t j = 0; j < distancesOut.n_rows; ++j)
@@ -228,14 +239,14 @@
     // Build our reference tree.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
-    CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
         QueryStat<NearestNeighborSort> > referenceTree(referenceData);
-    CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
         QueryStat<NearestNeighborSort> >* queryTree = NULL;
     Timer::Stop("tree_building");
 
-    NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
-        CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+    NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+        CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
         QueryStat<NearestNeighborSort> > >* allknn = NULL;
 
     // See if we have query data.
@@ -250,20 +261,20 @@
       {
         Log::Info << "Building query tree..." << endl;
         Timer::Start("tree_building");
-        queryTree = new CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
-            QueryStat<NearestNeighborSort> >(queryData);
+        queryTree = new CoverTree<metric::LMetric<2, true>,
+            tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData);
         Timer::Stop("tree_building");
       }
 
-      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
-          CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
           QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
           referenceData, queryData, singleMode);
     }
     else
     {
-      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
-          CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
           QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
           singleMode);
     }




More information about the mlpack-svn mailing list