[mlpack-svn] r14377 - mlpack/trunk/src/mlpack/methods/rann

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Feb 22 18:32:37 EST 2013


Author: rcurtin
Date: 2013-02-22 18:32:37 -0500 (Fri, 22 Feb 2013)
New Revision: 14377

Modified:
   mlpack/trunk/src/mlpack/methods/rann/allkrann_main.cpp
Log:
Update to only use L2-squared distances (for now).


Modified: mlpack/trunk/src/mlpack/methods/rann/allkrann_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/rann/allkrann_main.cpp	2013-02-22 23:30:47 UTC (rev 14376)
+++ mlpack/trunk/src/mlpack/methods/rann/allkrann_main.cpp	2013-02-22 23:32:37 UTC (rev 14377)
@@ -75,7 +75,7 @@
 {
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
-  math::RandomSeed(time(NULL)); 
+  math::RandomSeed(time(NULL));
 
   // Get all the parameters.
   string referenceFile = CLI::GetParam<string>("reference_file");
@@ -111,13 +111,13 @@
     Log::Fatal << referenceData.n_cols << ")." << endl;
   }
 
-  // Sanity check on the value of 'tau' with respect to 'k' so that 
-  // 'k' neighbors are not requested from the top-'rank_error' neighbors 
+  // Sanity check on the value of 'tau' with respect to 'k' so that
+  // 'k' neighbors are not requested from the top-'rank_error' neighbors
   // where 'rank_error' <= 'k'.
-  size_t rank_error 
+  size_t rank_error
     = (size_t) ceil(tau * (double) referenceData.n_cols / 100.0);
   if (rank_error <= k)
-    Log::Fatal << "Invalid 'tau' (" << tau << ") - k (" << k << ") " << 
+    Log::Fatal << "Invalid 'tau' (" << tau << ") - k (" << k << ") " <<
       "combination. Increase 'tau' or decrease 'k'." << endl;
 
   // Sanity check on leaf size.
@@ -134,7 +134,7 @@
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  if (naive) 
+  if (naive)
   {
     AllkRANN* allkrann;
     if (CLI::GetParam<string>("query_file") != "")
@@ -143,7 +143,7 @@
 
       data::Load(queryFile.c_str(), queryData, true);
 
-      Log::Info << "Loaded query data from '" << queryFile << "' (" << 
+      Log::Info << "Loaded query data from '" << queryFile << "' (" <<
         queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
       allkrann = new AllkRANN(referenceData, queryData, naive);
@@ -151,19 +151,19 @@
     else
       allkrann = new AllkRANN(referenceData, naive);
 
-    Log::Info << "Computing " << k << " nearest neighbors " << "with " << 
+    Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
       tau << "% rank approximation..." << endl;
 
     allkrann->Search(k, neighbors, distances, tau, alpha);
-    
+
     Log::Info << "Neighbors computed." << endl;
 
     delete allkrann;
   }
   else
   {
-    // The results output by the AllkRANN class 
-    // shuffled because the tree construction shuffles the point sets. 
+    // The results output by the AllkRANN class are
+    // shuffled because the tree construction shuffles the point sets.
     arma::Mat<size_t> neighborsOut;
     arma::mat distancesOut;
 
@@ -180,10 +180,12 @@
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
 
-      BinarySpaceTree<bound::HRectBound<2>, RAQueryStat<NearestNeighborSort> >
-        refTree(referenceData, oldFromNewRefs, leafSize);
-      BinarySpaceTree<bound::HRectBound<2>, RAQueryStat<NearestNeighborSort> >*
-        queryTree = NULL; // Empty for now.
+      BinarySpaceTree<bound::HRectBound<2, false>,
+          RAQueryStat<NearestNeighborSort> >
+          refTree(referenceData, oldFromNewRefs, leafSize);
+      BinarySpaceTree<bound::HRectBound<2, false>,
+          RAQueryStat<NearestNeighborSort> >*
+          queryTree = NULL; // Empty for now.
 
       Timer::Stop("tree_building");
 
@@ -198,7 +200,7 @@
         if (naive && leafSize < queryData.n_cols)
           leafSize = queryData.n_cols;
 
-        Log::Info << "Loaded query data from '" << queryFile << "' (" << 
+        Log::Info << "Loaded query data from '" << queryFile << "' (" <<
           queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
         Log::Info << "Building query tree..." << endl;
@@ -207,9 +209,9 @@
         // NeighborSearch, it does not copy the matrix.
         Timer::Start("tree_building");
 
-        queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-          RAQueryStat<NearestNeighborSort> >
-          (queryData, oldFromNewQueries, leafSize);
+        queryTree = new BinarySpaceTree<bound::HRectBound<2, false>,
+            RAQueryStat<NearestNeighborSort> >
+            (queryData, oldFromNewQueries, leafSize);
         Timer::Stop("tree_building");
 
         allkrann = new AllkRANN(&refTree, queryTree, referenceData, queryData,
@@ -223,12 +225,12 @@
         Log::Info << "Trees built." << endl;
       }
 
-      Log::Info << "Computing " << k << " nearest neighbors " << "with " << 
+      Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
         tau << "% rank approximation..." << endl;
       allkrann->Search(k, neighborsOut, distancesOut,
-                       tau, alpha, sampleAtLeaves, 
+                       tau, alpha, sampleAtLeaves,
                        firstLeafExact, singleSampleLimit);
-    
+
       Log::Info << "Neighbors computed." << endl;
 
       // We have to map back to the original indices from before the tree
@@ -249,7 +251,7 @@
           // Map indices of neighbors.
           for (size_t j = 0; j < distancesOut.n_rows; ++j)
           {
-            neighbors(j, oldFromNewQueries[i]) 
+            neighbors(j, oldFromNewQueries[i])
               = oldFromNewRefs[neighborsOut(j, i)];
           }
         }
@@ -264,12 +266,12 @@
           // Map indices of neighbors.
           for (size_t j = 0; j < distancesOut.n_rows; ++j)
           {
-            neighbors(j, oldFromNewRefs[i]) 
+            neighbors(j, oldFromNewRefs[i])
               = oldFromNewRefs[neighborsOut(j, i)];
           }
         }
       }
-    
+
       // Clean up.
       if (queryTree)
         delete queryTree;
@@ -283,7 +285,7 @@
   }
 
   // Save output.
-  if (distancesFile != "") 
+  if (distancesFile != "")
     data::Save(distancesFile, distances);
 
   if (neighborsFile != "")




More information about the mlpack-svn mailing list