[mlpack-svn] r11440 - in mlpack/trunk/src/mlpack: methods/range_search tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 8 15:48:31 EST 2012


Author: mamidon
Date: 2012-02-08 15:48:30 -0500 (Wed, 08 Feb 2012)
New Revision: 11440

Added:
   mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
   mlpack/trunk/src/mlpack/tests/cli_test.cpp
Log:
Implemented range search, need to implement output functionality.
Also need to update the documentation to reflect realities of range search.



Modified: mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt	2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt	2012-02-08 20:48:30 UTC (rev 11440)
@@ -5,6 +5,7 @@
 set(SOURCES
   range_search.hpp
   range_search_impl.hpp
+	range_main.cpp
 )
 
 # Add directory name to sources.
@@ -15,3 +16,11 @@
 # Append sources (with directory name) to list of all MLPACK sources (used at
 # the parent scope).
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(rangesearch
+	range_main.cpp
+)
+
+target_link_libraries(rangesearch
+	mlpack
+)

Added: mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp	2012-02-08 20:48:30 UTC (rev 11440)
@@ -0,0 +1,223 @@
+/**
+ * @file range_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the range search executable.  Allows some number of
+ * standard options.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "range_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+// NOTE(review): the last paragraph (output layout) looks carried over from
+// AllkNN; revisit it once this program actually writes its output files.
+PROGRAM_INFO("Range Search",
+    "This program will find all the neighbors of a set of points whose "
+    "distances fall within a given range.  You may specify a separate set of "
+    "reference points and query points, or just a reference set which will be "
+    "used as both the reference and query set.\n\n"
+    "For example, the following will calculate the neighbors within 5 units of "
+    "each point in 'input.csv' and store the distances in 'distances.csv' and "
+    "the neighbors in the file 'neighbors.csv':\n\n"
+    "$ rangesearch --min=0 --max=5 --reference_file=input.csv\n"
+    "  --distances_file=distances.csv --neighbors_file=neighbors.csv"
+    "\n\n"
+    "The output files are organized such that row i and column j in the "
+    "neighbors output file corresponds to the index of the point in the "
+    "reference set which is the i'th nearest neighbor from the point in the "
+    "query set with index j.  Row i and column j in the distances output file "
+    "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+    "r");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+PARAM_DOUBLE_REQ("max", "Upper bound of the search range.", "A");
+PARAM_DOUBLE("min", "Lower bound of the search range.", "I", 0.0);
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "s");
+
+// Searcher type used by main(): squared Euclidean metric, binary space tree.
+typedef RangeSearch<metric::SquaredEuclideanDistance,
+    BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > AllInRange;
+
+/**
+ * Entry point of the rangesearch executable.  Loads the reference set (and the
+ * query set, if given), builds the trees by hand, and runs the range search
+ * over [min, max].  Writing the results to the output files is not implemented
+ * yet, so the computed neighbors and distances are currently discarded.
+ */
+int main(int argc, char *argv[])
+{
+  // Give CLI the command line parameters the user passed in.
+  CLI::ParseCommandLine(argc, argv);
+
+  // Get all the parameters.
+  string referenceFile = CLI::GetParam<string>("reference_file");
+  string distancesFile = CLI::GetParam<string>("distances_file");
+  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+  int lsInt = CLI::GetParam<int>("leaf_size");
+  // "max" and "min" are declared with PARAM_DOUBLE*, so fetch them as double.
+  double max = CLI::GetParam<double>("max");
+  double min = CLI::GetParam<double>("min");
+
+  bool naive = CLI::HasParam("naive");
+  bool singleMode = CLI::HasParam("single_mode");
+
+  arma::mat referenceData;
+  arma::mat queryData; // So it doesn't go out of scope.
+  if (!data::Load(referenceFile.c_str(), referenceData))
+    Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+  Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
+
+  // Sanity check on range value: max must be greater than min.
+  if (max <= min)
+    Log::Fatal << "Invalid range: max (" << max << ") must be greater than "
+        << "min (" << min << ")." << endl;
+
+  // Sanity check on leaf size.
+  if (lsInt < 0)
+    Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
+        "than or equal to 0." << endl;
+  size_t leafSize = lsInt;
+
+  // Naive mode overrides single mode.
+  if (singleMode && naive)
+    Log::Warn << "--single_mode ignored because --naive is present." << endl;
+
+  if (naive)
+    leafSize = referenceData.n_cols;
+
+  std::vector<std::vector<size_t> > neighbors;
+  std::vector<std::vector<double> > distances;
+
+  // Because we may construct it differently, we need a pointer.
+  AllInRange* rangeSearch = NULL;
+
+  // Mappings for when we build the tree.
+  std::vector<size_t> oldFromNewRefs;
+
+  // Build trees by hand, so we can save memory: if we pass a tree to
+  // NeighborSearch, it does not copy the matrix.
+  Log::Info << "Building reference tree..." << endl;
+  Timer::Start("tree_building");
+
+  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
+      refTree(referenceData, oldFromNewRefs, leafSize);
+  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
+      queryTree = NULL; // Empty for now.
+
+  Timer::Stop("tree_building");
+
+  std::vector<size_t> oldFromNewQueries;
+
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    string queryFile = CLI::GetParam<string>("query_file");
+
+    if (!data::Load(queryFile.c_str(), queryData))
+      Log::Fatal << "Query file " << queryFile << " not found" << endl;
+
+    // In naive mode, grow the leaf size if the query set is the larger one.
+    if (naive && leafSize < queryData.n_cols)
+      leafSize = queryData.n_cols;
+
+    Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
+    Log::Info << "Building query tree..." << endl;
+    // The query tree is heap-allocated; it is released at the end of main().
+    Timer::Start("tree_building");
+
+    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+        tree::EmptyStatistic >(queryData, oldFromNewQueries,
+        leafSize);
+
+    Timer::Stop("tree_building");
+
+    rangeSearch = new AllInRange(&refTree, queryTree, referenceData,
+        queryData, singleMode);
+
+    Log::Info << "Tree built." << endl;
+  }
+  else
+  {
+    rangeSearch = new AllInRange(&refTree, referenceData, singleMode);
+
+    Log::Info << "Trees built." << endl;
+  }
+
+  Log::Info << "Computing neighbors within [" << min << ", " << max << "]." << endl;
+
+  math::Range r = math::Range(min, max);
+  rangeSearch->Search(r, neighbors, distances);
+
+  Log::Info << "Neighbors computed." << endl;
+
+  // We have to map back to the original indices from before the tree
+  // construction.
+  Log::Info << "Re-mapping indices..." << endl;
+
+  // TODO: implement the output step.  The disabled code below was written for
+  // matrix results (n_rows / n_cols / col()), but neighbors and distances are
+  // vectors of vectors here, so it has to be rewritten before being enabled.
+//  arma::mat distancesOut(distances.n_rows, distances.n_cols);
+//  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+  /*
+  // Do the actual remapping.
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+  else
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+
+  // Save output.
+  data::Save(distancesFile, distances);
+  data::Save(neighborsFile, neighbors);
+  */
+
+  // Clean up; the searcher was given pointers to the trees, so it goes first.
+  delete rangeSearch;
+  delete queryTree;
+}

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2012-02-08 20:48:30 UTC (rev 11440)
@@ -106,8 +106,8 @@
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(const TreeType* referenceTree,
-              const TreeType* queryTree,
+  RangeSearch(TreeType* referenceTree,
+              TreeType* queryTree,
               const typename TreeType::Mat& referenceSet,
               const typename TreeType::Mat& querySet,
               const bool singleMode = false,
@@ -140,7 +140,7 @@
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(const TreeType* referenceTree,
+  RangeSearch(TreeType* referenceTree,
               const typename TreeType::Mat& referenceSet,
               const bool singleMode = false,
               const MetricType metric = MetricType());

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	2012-02-08 20:48:30 UTC (rev 11440)
@@ -75,8 +75,8 @@
 
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
-    const TreeType* referenceTree,
-    const TreeType* queryTree,
+    TreeType* referenceTree,
+    TreeType* queryTree,
     const typename TreeType::Mat& referenceSet,
     const typename TreeType::Mat& querySet,
     const bool singleMode,
@@ -97,7 +97,7 @@
 
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
-    const TreeType* referenceTree,
+    TreeType* referenceTree,
     const typename TreeType::Mat& referenceSet,
     const bool singleMode,
     const MetricType metric) :

Modified: mlpack/trunk/src/mlpack/tests/cli_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/cli_test.cpp	2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/tests/cli_test.cpp	2012-02-08 20:48:30 UTC (rev 11440)
@@ -192,7 +192,7 @@
   PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
 
   pss << "I have a precise number which is ";
-  pss << std::setw(6) << std::setfill('0') << 156;
+  pss << std::setw(6) << std::setfill('0') << (int)156;
 
   BOOST_REQUIRE_EQUAL(ss.str(),
       BASH_GREEN "[INFO ] " BASH_CLEAR




More information about the mlpack-svn mailing list