[mlpack-svn] r14260 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Feb 9 20:35:52 EST 2013


Author: rcurtin
Date: 2013-02-09 20:35:52 -0500 (Sat, 09 Feb 2013)
New Revision: 14260

Added:
   mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
Log:
Abstract away the unmapping of points.  Also fix a few bugs in AllkFN because
that executable was really not kept up to date with the AllkNN executable.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	2013-02-10 01:35:52 UTC (rev 14260)
@@ -12,6 +12,8 @@
   sort_policies/furthest_neighbor_sort.cpp
   sort_policies/furthest_neighbor_sort_impl.hpp
   typedef.hpp
+  unmap.hpp
+  unmap.cpp
 )
 
 # Add directory name to sources.

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2013-02-10 01:35:52 UTC (rev 14260)
@@ -172,36 +172,17 @@
   arma::mat distancesOut(distances.n_rows, distances.n_cols);
   arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
 
-  // Do the actual remapping.
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    for (size_t i = 0; i < distances.n_cols; ++i)
-    {
-      // Map distances (copy a column).
-      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
-
-      // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; ++j)
-      {
-        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
-      }
-    }
-  }
+  // Map results back to the original point order (see Unmap() in unmap.hpp).
+  if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+    Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries,
+        neighborsOut, distancesOut, true);
+  else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+    Unmap(neighbors, distances, oldFromNewRefs, neighborsOut,
+        distancesOut, true);
   else
-  {
-    for (size_t i = 0; i < distances.n_cols; ++i)
-    {
-      // Map distances (copy a column).
-      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+    Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs,
+        neighborsOut, distancesOut, true);
 
-      // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; ++j)
-      {
-        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
-      }
-    }
-  }
-
   // Clean up.
   if (queryTree)
     delete queryTree;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2013-02-10 01:35:52 UTC (rev 14260)
@@ -13,6 +13,7 @@
 #include <iostream>
 
 #include "neighbor_search.hpp"
+#include "unmap.hpp"
 
 using namespace std;
 using namespace mlpack;
@@ -53,20 +54,29 @@
 PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
-    "dual-tree search.", "s");
+    "dual-tree search).", "s");
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+    "random orthogonal basis.", "R");
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "", 0);
 
 int main(int argc, char *argv[])
 {
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
 
+  if (CLI::GetParam<int>("seed") != 0)
+    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+  else
+    math::RandomSeed((size_t) std::time(NULL));
+
   // Get all the parameters.
-  string referenceFile = CLI::GetParam<string>("reference_file");
+  const string referenceFile = CLI::GetParam<string>("reference_file");
+  const string queryFile = CLI::GetParam<string>("query_file");
 
-  string distancesFile = CLI::GetParam<string>("distances_file");
-  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+  const string distancesFile = CLI::GetParam<string>("distances_file");
+  const string neighborsFile = CLI::GetParam<string>("neighbors_file");
 
   int lsInt = CLI::GetParam<int>("leaf_size");
 
@@ -74,14 +84,22 @@
 
   bool naive = CLI::HasParam("naive");
   bool singleMode = CLI::HasParam("single_mode");
+  const bool randomBasis = CLI::HasParam("random_basis");
 
   arma::mat referenceData;
   arma::mat queryData; // So it doesn't go out of scope.
-  data::Load(referenceFile.c_str(), referenceData, true);
+  data::Load(referenceFile, referenceData, true);
 
   Log::Info << "Loaded reference data from '" << referenceFile << "' ("
       << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
 
+  if (queryFile != "")
+  {
+    data::Load(queryFile, queryData, true);
+    Log::Info << "Loaded query data from '" << queryFile << "' ("
+      << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+  }
+
   // Sanity check on k value: must be greater than 0, must be less than the
   // number of reference points.
   if (k > referenceData.n_cols)
@@ -108,6 +126,43 @@
   if (naive)
     leafSize = referenceData.n_cols;
 
+  // See if we want to project onto a random basis.
+  if (randomBasis)
+  {
+    // Draw random orthogonal matrices until one is a proper rotation.
+    while (true)
+    {
+      // [Q, R] = qr(randn(d, d));
+      // Q = Q * diag(sign(diag(R)));
+      arma::mat q, r;
+      if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
+          referenceData.n_rows)))
+      {
+        arma::vec rDiag(r.n_rows);
+        for (size_t i = 0; i < rDiag.n_elem; ++i)
+        {
+          if (r(i, i) < 0)
+            rDiag(i) = -1;
+          else if (r(i, i) > 0)
+            rDiag(i) = 1;
+          else
+            rDiag(i) = 0;
+        }
+
+        q *= arma::diagmat(rDiag);
+
+        // Keep only proper rotations; det == 0 would mean a degenerate basis.
+        if (arma::det(q) > 0)
+        {
+          referenceData = q * referenceData;
+          if (queryFile != "")
+            queryData = q * queryData;
+          break;
+        }
+      }
+    }
+  }
+
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
@@ -125,9 +180,9 @@
     Timer::Start("tree_building");
 
     BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
-      refTree(referenceData, oldFromNewRefs, leafSize);
+        refTree(referenceData, oldFromNewRefs, leafSize);
     BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
-      queryTree = NULL; // Empty for now.
+        queryTree = NULL; // Empty for now.
 
     Timer::Stop("tree_building");
 
@@ -135,15 +190,11 @@
 
     if (CLI::GetParam<string>("query_file") != "")
     {
-      string queryFile = CLI::GetParam<string>("query_file");
-
-      data::Load(queryFile.c_str(), queryData, true);
-
       if (naive && leafSize < queryData.n_cols)
         leafSize = queryData.n_cols;
 
       Log::Info << "Loaded query data from '" << queryFile << "' ("
-        << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
       Log::Info << "Building query tree..." << endl;
 
@@ -184,51 +235,17 @@
     // construction.
     Log::Info << "Re-mapping indices..." << endl;
 
-    neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
-    distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-
-    // Do the actual remapping.
+    // Map the results back to the correct places.
     if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
-    {
-      for (size_t i = 0; i < distancesOut.n_cols; ++i)
-      {
-        // Map distances (copy a column) and square root.
-        distances.col(oldFromNewQueries[i]) = sqrt(distancesOut.col(i));
-
-        // Map indices of neighbors.
-        for (size_t j = 0; j < distancesOut.n_rows; ++j)
-        {
-          neighbors(j, oldFromNewQueries[i]) =
-              oldFromNewRefs[neighborsOut(j, i)];
-        }
-      }
-    }
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
+          neighbors, distances, true);
     else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-    {
-      // No remapping of queries is necessary.  So distances are the same.
-      distances = sqrt(distancesOut);
-
-      // The neighbor indices must be mapped.
-      for (size_t j = 0; j < neighborsOut.n_elem; ++j)
-      {
-        neighbors[j] = oldFromNewRefs[neighborsOut[j]];
-      }
-    }
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances,
+          true);
     else
-    {
-      for (size_t i = 0; i < distancesOut.n_cols; ++i)
-      {
-        // Map distances (copy a column).
-        distances.col(oldFromNewRefs[i]) = sqrt(distancesOut.col(i));
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
+          neighbors, distances, true);
 
-        // Map indices of neighbors.
-        for (size_t j = 0; j < distancesOut.n_rows; ++j)
-        {
-          neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
-        }
-      }
-    }
-
     // Clean up.
     if (queryTree)
       delete queryTree;
@@ -256,10 +273,6 @@
     // See if we have query data.
     if (CLI::HasParam("query_file"))
     {
-      string queryFile = CLI::GetParam<string>("query_file");
-
-      data::Load(queryFile, queryData, true);
-
       // Build query tree.
       if (!singleMode)
       {

Added: mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp	2013-02-10 01:35:52 UTC (rev 14260)
@@ -0,0 +1,64 @@
+/**
+ * @file unmap.cpp
+ * @author Ryan Curtin
+ *
+ * Auxiliary function to unmap neighbor search results.
+ */
+#include "unmap.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+// Dual-tree variant: both the query set and the reference set were permuted.
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           const std::vector<size_t>& queryMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot)
+{
+  // Allocate the outputs with the same shape as the inputs.
+  neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+  distancesOut.set_size(distances.n_rows, distances.n_cols);
+
+  // Column i of the results belongs to original query point queryMap[i].
+  for (size_t i = 0; i < distances.n_cols; ++i)
+  {
+    const size_t originalQuery = queryMap[i];
+    // The branches have different Armadillo expression types, so no ?: here.
+    if (squareRoot)
+      distancesOut.col(originalQuery) = sqrt(distances.col(i));
+    else
+      distancesOut.col(originalQuery) = distances.col(i);
+
+    // Translate each neighbor index back to the original reference order.
+    for (size_t j = 0; j < distances.n_rows; ++j)
+      neighborsOut(j, originalQuery) = referenceMap[neighbors(j, i)];
+  }
+}
+
+// Single-tree variant: only the reference set was permuted.
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot)
+{
+  // The queries were not reordered, so distances keep their columns and only
+  // need the optional square root.
+  if (squareRoot)
+    distancesOut = sqrt(distances);
+  else
+    distancesOut = distances;
+
+  // Neighbor indices refer to the permuted reference set, so translate each
+  // one back through referenceMap.
+  neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+  for (size_t j = 0; j < neighbors.n_elem; ++j)
+    neighborsOut[j] = referenceMap[neighbors[j]];
+}
+
+} // namespace neighbor
+} // namespace mlpack

Added: mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp	2013-02-10 01:35:52 UTC (rev 14260)
@@ -0,0 +1,61 @@
+/**
+ * @file unmap.hpp
+ * @author Ryan Curtin
+ *
+ * Convenience methods to unmap results.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Assuming that the datasets have been mapped using the referenceMap and the
+ * queryMap (such as during kd-tree construction), unmap the columns of the
+ * distances and neighbors matrices into distancesOut and neighborsOut, and
+ * also unmap the entries in each row of neighbors.  This is useful for the
+ * dual-tree case.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap referenceMap[i] is the original index of reference i.
+ * @param queryMap queryMap[i] is the original index of query point i.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store unmapped distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           const std::vector<size_t>& queryMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot = false);
+
+/**
+ * Assuming only the reference set has been mapped using referenceMap (such as
+ * during kd-tree construction), unmap every entry of neighbors into
+ * neighborsOut.  Columns are not reordered, because the queries were not
+ * mapped.  This is useful for the single-tree case.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap referenceMap[i] is the original index of reference i.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store (unreordered) distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot = false);
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list