[mlpack-svn] r10791 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 12:45:46 EST 2011


Author: rcurtin
Date: 2011-12-14 12:45:45 -0500 (Wed, 14 Dec 2011)
New Revision: 10791

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
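Fix #174 and another bug in the allkfn and allknn main executables (the query tree was built with the reference-point mapping instead of the query-point mapping).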


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-14 17:41:37 UTC (rev 10790)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-14 17:45:45 UTC (rev 10791)
@@ -74,7 +74,7 @@
   if (!data::Load(referenceFile.c_str(), referenceData))
     Log::Fatal << "Reference file " << referenceFile << "not found." << endl;
 
-  Log::Info << "Loaded reference data from " << referenceFile << endl;
+  Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
 
   // Sanity check on k value: must be greater than 0, must be less than the
   // number of reference points.
@@ -108,12 +108,12 @@
   // Build trees by hand, so we can save memory: if we pass a tree to
   // NeighborSearch, it does not copy the matrix.
   Log::Info << "Building reference tree..." << endl;
-  Timer::Start("tree_building");
+  Timer::Start("reference_tree_building");
 
   BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
       refTree(referenceData, oldFromNewRefs);
 
-  Timer::Stop("tree_building");
+  Timer::Stop("reference_tree_building");
 
   std::vector<size_t> oldFromNewQueries;
 
@@ -123,20 +123,20 @@
     arma::mat queryData;
 
     if (!data::Load(queryFile.c_str(), queryData))
-      Log::Fatal << "Query file " << queryFile << " not found" << endl;
+      Log::Fatal << "Query file " << queryFile << " not found." << endl;
 
-    Log::Info << "Query data loaded from " << queryFile << endl;
+    Log::Info << "Query data loaded from '" << queryFile << "'." << endl;
 
     Log::Info << "Building query tree..." << endl;
 
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
-    Timer::Start("tree_building");
+    Timer::Start("query_tree_building");
 
     BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
-        queryTree(queryData, oldFromNewRefs);
+        queryTree(queryData, oldFromNewQueries);
 
-    Timer::Stop("tree_building");
+    Timer::Stop("query_tree_building");
 
     allkfn = new AllkFN(referenceData, queryData, naive, singleMode, 20,
         &refTree, &queryTree);
@@ -163,15 +163,32 @@
   arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
 
   // Do the actual remapping.
-  for (size_t i = 0; i < distances.n_cols; i++)
+  if (CLI::GetParam<string>("query_file") != "")
   {
-    // Map distances (copy a column).
-    distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
 
-    // Map indices of neighbors.
-    for (size_t j = 0; j < distances.n_rows; j++)
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+  else
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
     {
-      neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
     }
   }
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-14 17:41:37 UTC (rev 10790)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-14 17:45:45 UTC (rev 10791)
@@ -74,7 +74,7 @@
   if (!data::Load(referenceFile.c_str(), referenceData))
     Log::Fatal << "Reference file " << referenceFile << "not found." << endl;
 
-  Log::Info << "Loaded reference data from " << referenceFile << endl;
+  Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
 
   // Sanity check on k value: must be greater than 0, must be less than the
   // number of reference points.
@@ -136,7 +136,7 @@
     Timer::Start("tree_building");
 
     BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
-        queryTree(queryData, oldFromNewRefs, leafSize);
+        queryTree(queryData, oldFromNewQueries, leafSize);
 
     Timer::Stop("tree_building");
 
@@ -165,15 +165,32 @@
   arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
 
   // Do the actual remapping.
-  for (size_t i = 0; i < distances.n_cols; i++)
+  if (CLI::GetParam<string>("query_file") != "")
   {
-    // Map distances (copy a column).
-    distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
 
-    // Map indices of neighbors.
-    for (size_t j = 0; j < distances.n_rows; j++)
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+  else
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
     {
-      neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
     }
   }
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-14 17:41:37 UTC (rev 10790)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-14 17:45:45 UTC (rev 10791)
@@ -23,8 +23,8 @@
                TreeType* referenceTree,
                TreeType* queryTree,
                const MetricType metric) :
-    referenceCopy(referenceTree ? 0 : referenceSet),
-    queryCopy(queryTree ? 0 : querySet),
+    referenceCopy(referenceTree ? arma::mat() : referenceSet),
+    queryCopy(queryTree ? arma::mat() : querySet),
     referenceSet(referenceTree ? referenceSet : referenceCopy),
     querySet(queryTree ? querySet : queryCopy),
     naive(naive),
@@ -78,7 +78,7 @@
                const size_t leafSize,
                TreeType* referenceTree,
                const MetricType metric) :
-    referenceCopy(referenceTree ? 0 : referenceSet),
+    referenceCopy(referenceTree ? arma::mat() : referenceSet),
     referenceSet(referenceTree ? referenceSet : referenceCopy),
     querySet(referenceTree ? referenceSet : referenceCopy),
     naive(naive),
@@ -578,13 +578,13 @@
  * @param distance Distance from query point to reference point.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-InsertNeighbor(const size_t queryIndex,
-               const size_t pos,
-               const size_t neighbor,
-               const double distance,
-               arma::Mat<size_t>& neighbors,
-               arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+    const size_t queryIndex,
+    const size_t pos,
+    const size_t neighbor,
+    const double distance,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (distances.n_rows - 1))




More information about the mlpack-svn mailing list