[mlpack-svn] r10499 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 1 21:59:07 EST 2011


Author: rcurtin
Date: 2011-12-01 21:59:07 -0500 (Thu, 01 Dec 2011)
New Revision: 10499

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp
Log:
Overhaul API for NeighborSearch (it probably is not complete) and remove
dependence on CLI to help fix #150.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-02 02:59:07 UTC (rev 10499)
@@ -16,6 +16,7 @@
 using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
+using namespace mlpack::tree;
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Furthest-Neighbors",
@@ -25,107 +26,175 @@
     "and query set."
     "\n\n"
-    "For example, the following will calculate the 5 furthest neighbors of each"
-    "point in 'input.csv' and store the results in 'output.csv':"
+    "For example, the following will calculate the 5 furthest neighbors of each "
+    "point in 'input.csv' and store the distances in 'distances.csv' and the "
+    "neighbors in 'neighbors.csv':"
     "\n\n"
-    "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
-    "  --output_file=output.csv", "neighbor_search");
+    "$ allkfn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+    "  --neighbors_file=neighbors.csv"
+    "\n\n"
+    "The output files are organized such that row i and column j in the "
+    "neighbors output file corresponds to the index of the point in the "
+    "reference set which is the i'th furthest neighbor from the point in the "
+    "query set with index j.  Row i and column j in the distances output file "
+    "corresponds to the distance between those two points.", "");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
     "");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
-    "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
-    "");
+PARAM_STRING("query_file", "File containing query points (optional).", "", "");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "");
 
+PARAM_INT("leaf_size", "Leaf size for tree building.", "", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "");
+PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "");
+
 int main(int argc, char *argv[])
 {
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
 
-  string reference_file = CLI::GetParam<string>("reference_file");
-  string output_file = CLI::GetParam<string>("output_file");
+  // Get all the parameters.
+  string referenceFile = CLI::GetParam<string>("reference_file");
 
-  arma::mat reference_data;
+  string distancesFile = CLI::GetParam<string>("distances_file");
+  string neighborsFile = CLI::GetParam<string>("neighbors_file");
 
-  arma::Mat<size_t> neighbors;
-  arma::mat distances;
+  int leafSize = CLI::GetParam<int>("leaf_size");
 
-  if (!data::Load(reference_file.c_str(), reference_data))
-    Log::Fatal << "Reference file " << reference_file << "not found." << endl;
+  size_t k = CLI::GetParam<int>("k");
 
-  Log::Info << "Loaded reference data from " << reference_file << endl;
+  bool naive = CLI::HasParam("naive");
+  bool singleMode = CLI::HasParam("single_mode");
 
+  arma::mat referenceData;
+  if (!data::Load(referenceFile.c_str(), referenceData))
+    Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+  Log::Info << "Loaded reference data from " << referenceFile << endl;
+
   // Sanity check on k value: must be greater than 0, must be less than the
   // number of reference points.
-  size_t k = CLI::GetParam<int>("neighbor_search/k");
-  if ((k <= 0) || (k >= reference_data.n_cols))
+  if ((k <= 0) || (k >= referenceData.n_cols))
   {
     Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
     Log::Fatal << "than the number of reference points (";
-    Log::Fatal << reference_data.n_cols << ")." << endl;
+    Log::Fatal << referenceData.n_cols << ")." << endl;
   }
 
   // Sanity check on leaf size.
-  if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+  if (leafSize < 0)
   {
-    Log::Fatal << "Invalid leaf size: "
-        << CLI::GetParam<int>("allknn/leaf_size") << endl;
+    Log::Fatal << "Invalid leaf size: " << leafSize << ".  Must be greater "
+        "than or equal to 0." << endl;
   }
 
+  // Naive mode overrides single mode.
+  if (singleMode && naive)
+  {
+    Log::Warn << "--single_mode ignored because --naive is present." << endl;
+  }
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
   AllkFN* allkfn = NULL;
 
+  // The tree type used by AllkFN.
+  typedef BinarySpaceTree<bound::HRectBound<2>,
+      QueryStat<FurthestNeighborSort> > TreeType;
+
+  // Naive mode works on a one-node tree (the root holds every point), so in
+  // that case the leaf size is raised to the number of points.
+  const size_t refLeafSize = naive ? size_t(referenceData.n_cols) :
+      size_t(leafSize);
+
+  std::vector<size_t> oldFromNewRefs;
+
+  // Build trees by hand, so we can save memory: if we pass a tree to
+  // NeighborSearch, it does not copy the matrix.
+  Log::Info << "Building reference tree..." << endl;
+  Timers::StartTimer("neighbor_search/tree_building");
+
+  TreeType refTree(referenceData, oldFromNewRefs, refLeafSize);
+
+  Timers::StopTimer("neighbor_search/tree_building");
+
+  // AllkFN keeps pointers to the given trees (and does not copy their
+  // matrices), so the query data and tree must live at function scope.
+  std::vector<size_t> oldFromNewQueries;
+  arma::mat queryData;
+  TreeType* queryTree = NULL;
+
   if (CLI::GetParam<string>("query_file") != "")
   {
-    string query_file = CLI::GetParam<string>("query_file");
-    arma::mat query_data;
+    string queryFile = CLI::GetParam<string>("query_file");
 
-    if (!data::Load(query_file.c_str(), query_data))
-      Log::Fatal << "Query file " << query_file << " not found" << endl;
+    if (!data::Load(queryFile.c_str(), queryData))
+      Log::Fatal << "Query file " << queryFile << " not found" << endl;
 
-    Log::Info << "Query data loaded from " << query_file << endl;
+    Log::Info << "Query data loaded from " << queryFile << endl;
 
-    Log::Info << "Building query and reference trees..." << endl;
-    allkfn = new AllkFN(query_data, reference_data);
+    Log::Info << "Building query tree..." << endl;
 
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // NeighborSearch, it does not copy the matrix.
+    Timers::StartTimer("neighbor_search/tree_building");
+
+    const size_t queryLeafSize = naive ? size_t(queryData.n_cols) :
+        size_t(leafSize);
+    queryTree = new TreeType(queryData, oldFromNewQueries, queryLeafSize);
+
+    Timers::StopTimer("neighbor_search/tree_building");
+
+    allkfn = new AllkFN(referenceData, queryData, naive, singleMode, leafSize,
+        &refTree, queryTree);
+
+    Log::Info << "Trees built." << endl;
   }
   else
   {
-    Log::Info << "Building reference tree..." << endl;
-    allkfn = new AllkFN(reference_data);
+    allkfn = new AllkFN(referenceData, naive, singleMode, leafSize, &refTree);
+
+    Log::Info << "Tree built." << endl;
   }
 
-  Log::Info << "Tree(s) built." << endl;
-
-  Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-  allkfn->ComputeNeighbors(neighbors, distances);
+  Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+  allkfn->ComputeNeighbors(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
-  Log::Info << "Exporting results..." << endl;
 
-  // Should be using data::Save or a related function instead of being written
-  // by hand.
-  try
+  // We have to map back to the original indices from before the tree
+  // construction.
+  Log::Info << "Re-mapping indices..." << endl;
+
+  // Without a query set, the reference set is also the query set, so its
+  // permutation applies to the query (column) indices too.
+  const std::vector<size_t>& queryMapping = (queryTree != NULL) ?
+      oldFromNewQueries : oldFromNewRefs;
+
+  arma::mat distancesOut(distances.n_rows, distances.n_cols);
+  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+  // Do the actual remapping.
+  for (size_t i = 0; i < distances.n_cols; i++)
   {
-    ofstream out(output_file.c_str());
+    // Map distances (copy a column).
+    distancesOut.col(queryMapping[i]) = distances.col(i);
 
-    for (size_t col = 0; col < neighbors.n_cols; col++)
+    // Map indices of neighbors.
+    for (size_t j = 0; j < distances.n_rows; j++)
     {
-      out << col << ", ";
-      for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
-      {
-        out << neighbors(j, col) << ", " << distances(j, col) << ", ";
-      }
-      out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+      neighborsOut(j, queryMapping[i]) = oldFromNewRefs[neighbors(j, i)];
     }
-
-    out.close();
   }
-  catch (exception& e)
-  {
-    Log::Fatal << "Error while opening " << output_file << ": " << e.what()
-        << endl;
-  }
 
+  // Save output.
+  data::Save(distancesFile, distancesOut);
+  data::Save(neighborsFile, neighborsOut);
+
   delete allkfn;
+  delete queryTree;
 }

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-02 02:59:07 UTC (rev 10499)
@@ -16,6 +16,7 @@
 using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
+using namespace mlpack::tree;
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Nearest-Neighbors",
@@ -25,107 +26,177 @@
     "and query set."
     "\n\n"
-    "For example, the following will calculate the 5 nearest neighbors of each"
-    "point in 'input.csv' and store the results in 'output.csv':"
+    "For example, the following will calculate the 5 nearest neighbors of each "
+    "point in 'input.csv' and store the distances in 'distances.csv' and the "
+    "neighbors in 'neighbors.csv':"
     "\n\n"
-    "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
-    "  --output_file=output.csv", "neighbor_search");
+    "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+    "  --neighbors_file=neighbors.csv"
+    "\n\n"
+    "The output files are organized such that row i and column j in the "
+    "neighbors output file corresponds to the index of the point in the "
+    "reference set which is the i'th nearest neighbor from the point in the "
+    "query set with index j.  Row i and column j in the distances output file "
+    "corresponds to the distance between those two points.", "");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
     "");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
-    "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
-    "");
+PARAM_STRING("query_file", "File containing query points (optional).", "", "");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "");
 
+PARAM_INT("leaf_size", "Leaf size for tree building.", "", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "");
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "");
+
 int main(int argc, char *argv[])
 {
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
 
-  string reference_file = CLI::GetParam<string>("reference_file");
-  string output_file = CLI::GetParam<string>("output_file");
+  // Get all the parameters.
+  string referenceFile = CLI::GetParam<string>("reference_file");
 
-  arma::mat reference_data;
+  string distancesFile = CLI::GetParam<string>("distances_file");
+  string neighborsFile = CLI::GetParam<string>("neighbors_file");
 
-  arma::Mat<size_t> neighbors;
-  arma::mat distances;
+  int leafSize = CLI::GetParam<int>("leaf_size");
 
-  if (!data::Load(reference_file.c_str(), reference_data))
-    Log::Fatal << "Reference file " << reference_file << " not found." << endl;
+  size_t k = CLI::GetParam<int>("k");
 
-  Log::Info << "Loaded reference data from " << reference_file << endl;
+  bool naive = CLI::HasParam("naive");
+  bool singleMode = CLI::HasParam("single_mode");
 
+  arma::mat referenceData;
+  if (!data::Load(referenceFile.c_str(), referenceData))
+    Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+  Log::Info << "Loaded reference data from " << referenceFile << endl;
+
   // Sanity check on k value: must be greater than 0, must be less than the
   // number of reference points.
-  size_t k = CLI::GetParam<int>("neighbor_search/k");
-  if ((k <= 0) || (k >= reference_data.n_cols))
+  if ((k <= 0) || (k >= referenceData.n_cols))
   {
     Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
     Log::Fatal << "than the number of reference points (";
-    Log::Fatal << reference_data.n_cols << ")." << endl;
+    Log::Fatal << referenceData.n_cols << ")." << endl;
   }
 
   // Sanity check on leaf size.
-  if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+  if (leafSize < 0)
   {
-    Log::Fatal << "Invalid leaf size: "
-        << CLI::GetParam<int>("allknn/leaf_size") << endl;
+    Log::Fatal << "Invalid leaf size: " << leafSize << ".  Must be greater "
+        "than or equal to 0." << endl;
   }
 
+  // Naive mode overrides single mode.
+  if (singleMode && naive)
+  {
+    Log::Warn << "--single_mode ignored because --naive is present." << endl;
+  }
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  // Because we may construct it differently, we need a pointer.
   AllkNN* allknn = NULL;
 
+  // The tree type used by AllkNN.
+  typedef BinarySpaceTree<bound::HRectBound<2>,
+      QueryStat<NearestNeighborSort> > TreeType;
+
+  // Naive mode works on a one-node tree (the root holds every point), so in
+  // that case the leaf size is raised to the number of points.
+  const size_t refLeafSize = naive ? size_t(referenceData.n_cols) :
+      size_t(leafSize);
+
+  // Mappings for when we build the tree.
+  std::vector<size_t> oldFromNewRefs;
+
+  // Build trees by hand, so we can save memory: if we pass a tree to
+  // NeighborSearch, it does not copy the matrix.
+  Log::Info << "Building reference tree..." << endl;
+  Timers::StartTimer("neighbor_search/tree_building");
+
+  TreeType refTree(referenceData, oldFromNewRefs, refLeafSize);
+
+  Timers::StopTimer("neighbor_search/tree_building");
+
+  // AllkNN keeps pointers to the given trees (and does not copy their
+  // matrices), so the query data and tree must live at function scope.
+  std::vector<size_t> oldFromNewQueries;
+  arma::mat queryData;
+  TreeType* queryTree = NULL;
+
   if (CLI::GetParam<string>("query_file") != "")
   {
-    string query_file = CLI::GetParam<string>("query_file");
-    arma::mat query_data;
+    string queryFile = CLI::GetParam<string>("query_file");
 
-    if (!data::Load(query_file.c_str(), query_data))
-      Log::Fatal << "Query file " << query_file << " not found" << endl;
+    if (!data::Load(queryFile.c_str(), queryData))
+      Log::Fatal << "Query file " << queryFile << " not found" << endl;
 
-    Log::Info << "Query data loaded from " << query_file << endl;
+    Log::Info << "Query data loaded from " << queryFile << endl;
 
-    Log::Info << "Building query and reference trees..." << endl;
-    allknn = new AllkNN(query_data, reference_data);
+    Log::Info << "Building query tree..." << endl;
 
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // NeighborSearch, it does not copy the matrix.
+    Timers::StartTimer("neighbor_search/tree_building");
+
+    const size_t queryLeafSize = naive ? size_t(queryData.n_cols) :
+        size_t(leafSize);
+    queryTree = new TreeType(queryData, oldFromNewQueries, queryLeafSize);
+
+    Timers::StopTimer("neighbor_search/tree_building");
+
+    allknn = new AllkNN(referenceData, queryData, naive, singleMode, leafSize,
+        &refTree, queryTree);
+
+    Log::Info << "Trees built." << endl;
   }
   else
   {
-    Log::Info << "Building reference tree..." << endl;
-    allknn = new AllkNN(reference_data);
+    allknn = new AllkNN(referenceData, naive, singleMode, leafSize, &refTree);
+
+    Log::Info << "Tree built." << endl;
   }
 
-  Log::Info << "Tree(s) built." << endl;
-
   Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-  allknn->ComputeNeighbors(neighbors, distances);
+  allknn->ComputeNeighbors(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
-  Log::Info << "Exporting results..." << endl;
 
-  // Should be using data::Save or a related function instead of being written
-  // by hand.
-  try
+  // We have to map back to the original indices from before the tree
+  // construction.
+  Log::Info << "Re-mapping indices..." << endl;
+
+  // Without a query set, the reference set is also the query set, so its
+  // permutation applies to the query (column) indices too.
+  const std::vector<size_t>& queryMapping = (queryTree != NULL) ?
+      oldFromNewQueries : oldFromNewRefs;
+
+  arma::mat distancesOut(distances.n_rows, distances.n_cols);
+  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+  // Do the actual remapping.
+  for (size_t i = 0; i < distances.n_cols; i++)
   {
-    ofstream out(output_file.c_str());
+    // Map distances (copy a column).
+    distancesOut.col(queryMapping[i]) = distances.col(i);
 
-    for (size_t col = 0; col < neighbors.n_cols; col++)
+    // Map indices of neighbors.
+    for (size_t j = 0; j < distances.n_rows; j++)
     {
-      out << col << ", ";
-      for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
-      {
-        out << neighbors(j, col) << ", " << distances(j, col) << ", ";
-      }
-      out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+      neighborsOut(j, queryMapping[i]) = oldFromNewRefs[neighbors(j, i)];
     }
-
-    out.close();
   }
-  catch (exception& e)
-  {
-    Log::Fatal << "Error while opening " << output_file << ": " << e.what()
-        << endl;
-  }
 
+  // Save output.
+  data::Save(distancesFile, distancesOut);
+  data::Save(neighborsFile, neighborsOut);
+
   delete allknn;
+  delete queryTree;
 }

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-02 02:59:07 UTC (rev 10499)
@@ -22,15 +22,30 @@
                     * all-nearest-neighbors and all-furthest-neighbors
                     * searches. */ {
 
-// Define CLI parameters for the NeighborSearch class.
-PARAM_MODULE("neighbor_search",
-    "Parameters for the distance-based neighbor search.");
-PARAM_INT("k", "Number of neighbors to search for.", "neighbor_search", 5);
-PARAM_FLAG("single_mode", "If set, use single-tree mode (instead of "
-    "dual-tree).", "neighbor_search");
-PARAM_FLAG("naive_mode", "If set, use naive computations (no trees).  This "
-    "overrides the single_mode flag.", "neighbor_search");
+/**
+ * Extra data for each node in the tree.  For neighbor searches, each node only
+ * needs to store a bound on neighbor distances.
+ */
+template<typename SortPolicy>
+class QueryStat
+{
+ private:
+  //! The bound on the node's neighbor distances.
+  double bound;
 
+ public:
+  /**
+   * Initialize the statistic with the worst possible distance according to
+   * our sorting policy.
+   */
+  QueryStat() : bound(SortPolicy::WorstDistance()) { }
+
+  //! Get the bound (by value; a const-qualified return type is meaningless).
+  double Bound() const { return bound; }
+  //! Modify the bound.
+  double& Bound() { return bound; }
+};
+
 /**
  * The NeighborSearch class is a template class for performing distance-based
  * neighbor searches.  It takes a query dataset and a reference dataset (or just
@@ -40,120 +55,88 @@
  * reference dataset, and if that constructor is used, the given reference
  * dataset is also used as the query dataset.
  *
- * The template parameters Kernel and SortPolicy define the distance function
- * used and the sort function used.  More information on those classes can be
- * found in the kernel::ExampleKernel class and the NearestNeighborSort class.
+ * The template parameters SortPolicy and MetricType define the sort function
+ * used and the metric (distance function) used.  More information on those
+ * classes can be found in the NearestNeighborSort class and the
+ * kernel::ExampleKernel class.
  *
- * This class has several parameters configurable by the CLI interface:
- *
- * @param neighbor_search/k Parameters for the distance-based neighbor search.
- * @param neighbor_search/single_mode If set, single-tree mode will be used
- *     (instead of dual-tree mode).
- * @param neighbor_search/naive_mode If set, naive computation will be used (no
- *     trees). This overrides the single_mode flag.
- *
- * @tparam Kernel The kernel function; see kernel::ExampleKernel.
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ * @tparam MetricType The metric to use for computation.
+ * @tparam TreeType The tree type to use.
  */
-template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
-         typename SortPolicy = NearestNeighborSort>
+template<typename SortPolicy = NearestNeighborSort,
+         typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+                                                   QueryStat<SortPolicy> > >
 class NeighborSearch
 {
-  /**
-   * Extra data for each node in the tree.  For neighbor searches, each node
-   * only needs to store a bound on neighbor distances.
-   */
-  class QueryStat
-  {
-   public:
-    //! The bound on the node's neighbor distances.
-    double bound_;
-
-    /**
-     * Initialize the statistic with the worst possible distance according to
-     * our sorting policy.
-     */
-    QueryStat() : bound_(SortPolicy::WorstDistance()) { }
-  };
-
-  /**
-   * Simple typedef for the trees, which use a bound and a QueryStat (to store
-   * distances for each node).  The bound should be configurable...
-   */
-  typedef tree::BinarySpaceTree<bound::HRectBound<2>, QueryStat> TreeType;
-
- private:
-  //! Reference dataset.
-  arma::mat references_;
-  //! Query dataset (may not be given).
-  arma::mat queries_;
-
-  //! Instantiation of kernel.
-  MetricType kernel_;
-
-  //! Pointer to the root of the reference tree.
-  TreeType* reference_tree_;
-  //! Pointer to the root of the query tree (might not exist).
-  TreeType* query_tree_;
-
-  //! Permutations of query points during tree building.
-  std::vector<size_t> old_from_new_queries_;
-  //! Permutations of reference points during tree building.
-  std::vector<size_t> old_from_new_references_;
-
-  //! Indicates if O(n^2) naive search is being used.
-  bool naive_;
-  //! Indicates if dual-tree search is being used (opposed to single-tree).
-  bool dual_mode_;
-
-  //! Number of points in a leaf.
-  size_t leaf_size_;
-
-  //! Number of neighbors to compute.
-  size_t knns_;
-
-  //! Total number of pruned nodes during the neighbor search.
-  size_t number_of_prunes_;
-
-  //! The distance to the candidate nearest neighbor for each query
-  arma::mat neighbor_distances_;
-
-  //! The indices of the candidate nearest neighbor for each query
-  arma::Mat<size_t> neighbor_indices_;
-
  public:
   /**
    * Initialize the NeighborSearch object, passing both a query and reference
-   * dataset.  An initialized distance metric can be given, for cases where the
-   * metric has internal data (i.e. the distance::MahalanobisDistance class).
+   * dataset.  Optionally, already-built trees can be passed, for the case where
+   * a special tree-building procedure is needed.  If referenceTree is given, it
+   * is assumed that the points in referenceTree correspond to the points in
+   * referenceSet.  The same is true for queryTree and querySet.  An initialized
+   * distance metric can be given, for cases where the metric has internal data
+   * (e.g. the metric::MahalanobisDistance class).
    *
-   * @param queries_in Set of query points.
-   * @param references_in Set of reference points.
-   * @param alias_matrix If true, alias the passed matrices instead of copying
-   *     them.  While this lowers memory footprint and computational load, the
-   *     points in the matrices will be rearranged during the tree-building
-   *     process!  Defaults to false.
-   * @param kernel An optional instance of the Kernel class.
+   * If naive mode is being used and a pre-built tree is given, it may not work:
+   * naive mode operates by building a one-node tree (the root node holds all
+   * the points).  If that condition is not satisfied with the pre-built tree,
+   * then naive mode will not work.
+   *
+   * @param referenceSet Set of reference points.
+   * @param querySet Set of query points.
+   * @param naive If true, O(n^2) naive search will be used (as opposed to
+   *      dual-tree search).  This overrides singleMode (if it is set to true).
+   * @param singleMode If true, single-tree search will be used (as opposed to
+   *      dual-tree search).
+   * @param leafSize Leaf size for tree construction (ignored if tree is given).
+   * @param referenceTree Optionally pass a pre-built tree for the reference
+   *      set.
+   * @param queryTree Optionally pass a pre-built tree for the query set.
+   * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(arma::mat& queries_in, arma::mat& references_in,
-                 bool alias_matrix = false, MetricType kernel = MetricType());
+  NeighborSearch(const arma::mat& referenceSet,
+                 const arma::mat& querySet,
+                 const bool naive = false,
+                 const bool singleMode = false,
+                 const size_t leafSize = 20,
+                 TreeType* referenceTree = NULL,
+                 TreeType* queryTree = NULL,
+                 const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object, passing only one dataset.  In this
-   * case, the query dataset is equivalent to the reference dataset, with one
-   * caveat: for any point, the returned list of neighbors will not include
-   * itself.  An initialized distance metric can be given, for cases where the
-   * metric has internal data (i.e. the distance::MahalanobisDistance class).
+   * Initialize the NeighborSearch object, passing only one dataset, which is
+   * used as both the query and the reference dataset.  Optionally, an
+   * already-built tree can be passed, for the case where a special
+   * tree-building procedure is needed.  If referenceTree is given, it is
+   * assumed that the points in referenceTree correspond to the points in
+   * referenceSet.  An initialized distance metric can be given, for cases where
+   * the metric has internal data (e.g. the metric::MahalanobisDistance
+   * class).
    *
-   * @param references_in Set of reference points.
-   * @param alias_matrix If true, alias the passed matrices instead of copying
-   *     them.  While this lowers memory footprint and computational load, the
-   *     points in the matrices will be rearranged during the tree-building
-   *     process!  Defaults to false.
-   * @param kernel An optional instance of the Kernel class.
+   * If naive mode is being used and a pre-built tree is given, it may not work:
+   * naive mode operates by building a one-node tree (the root node holds all
+   * the points).  If that condition is not satisfied with the pre-built tree,
+   * then naive mode will not work.
+   *
+   * @param referenceSet Set of reference points.
+   * @param naive If true, O(n^2) naive search will be used (as opposed to
+   *      dual-tree search).  This overrides singleMode (if it is set to true).
+   * @param singleMode If true, single-tree search will be used (as opposed to
+   *      dual-tree search).
+   * @param leafSize Leaf size for tree construction (ignored if tree is given).
+   * @param referenceTree Optionally pass a pre-built tree for the reference
+   *      set.
+   * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(arma::mat& references_in, bool alias_matrix = false,
-                 MetricType kernel = MetricType());
+  NeighborSearch(const arma::mat& referenceSet,
+                 const bool naive = false,
+                 const bool singleMode = false,
+                 const size_t leafSize = 20,
+                 TreeType* referenceTree = NULL,
+                 const MetricType metric = MetricType());
 
   /**
    * Delete the NeighborSearch object. The tree is the only member we are
@@ -167,70 +150,123 @@
    * number of points in the query dataset and k is the number of neighbors
    * being searched for.
    *
-   * The parameter k is set through the CLI interface, not in the arguments to
-   * this method; this allows users to specify k on the command line
-   * ("--neighbor_search/k").  See the CLI documentation for further information
-   * on how to use this functionality.
-   *
-   * @param resulting_neighbors Matrix storing lists of neighbors for each query
+   * @param k Number of neighbors to search for.
+   * @param resultingNeighbors Matrix storing lists of neighbors for each query
    *     point.
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
+   *
+   * Both output matrices are resized to k rows and one column per query point.
+   * When this object built the trees itself, the permutation introduced by
+   * tree building is undone, so column i (and the stored indices) refer to the
+   * original ordering of the datasets.
    */
-  void ComputeNeighbors(arma::Mat<size_t>& resulting_neighbors,
+  void ComputeNeighbors(const size_t k,
+                        arma::Mat<size_t>& resultingNeighbors,
                         arma::mat& distances);
 
  private:
   /**
-   * Perform exhaustive computation between two leaves, comparing every node in
-   * the leaf to the other leaf to find the furthest neighbor.  The
-   * neighbor_indices_ and neighbor_distances_ matrices will be updated with the
-   * changed information.
+   * Perform exhaustive computation between two leaves, comparing every point in
+   * the query leaf to every point in the reference leaf to find the best
+   * neighbors (as defined by SortPolicy).  The neighbors and distances matrices
+   * will be updated with the changed information, and the query node's bound
+   * is set to the worst k-th best distance among its points.  When both nodes
+   * are the same, a point is never paired with itself.
    *
-   * @param query_node Node in query tree.  This should be a leaf
+   * @param queryNode Node in query tree.  This should be a leaf
    *     (bottom-level).
-   * @param reference_node Node in reference tree.  This should be a leaf
+   * @param referenceNode Node in reference tree.  This should be a leaf
    *     (bottom-level).
+   * @param neighbors Matrix of neighbor indices (one column per query point).
+   * @param distances Matrix of neighbor distances (one column per query
+   *     point).
    */
-  void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node);
+  void ComputeBaseCase(TreeType* queryNode,
+                       TreeType* referenceNode,
+                       arma::Mat<size_t>& neighbors,
+                       arma::mat& distances);
 
   /**
    * Recurse down the trees, computing base case computations when the leaves
    * are reached.
    *
-   * @param query_node Node in query tree.
-   * @param reference_node Node in reference tree.
-   * @param lower_bound The lower bound; if above this, we can prune.
+   * @param queryNode Node in query tree.
+   * @param referenceNode Node in reference tree.
+   * @param lowerBound Best possible distance between queryNode and
+   *      referenceNode; if the query node's current bound is already better
+   *      than this (per SortPolicy::IsBetter()), the pair is pruned.
+   * @param neighbors Matrix of neighbor indices (one column per query point).
+   * @param distances Matrix of neighbor distances (one column per query
+   *      point).
    */
-  void ComputeDualNeighborsRecursion_(TreeType* query_node,
-                                      TreeType* reference_node,
-                                      double lower_bound);
+  void ComputeDualNeighborsRecursion(TreeType* queryNode,
+                                     TreeType* referenceNode,
+                                     const double lowerBound,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances);
 
   /**
    * Perform a recursion only on the reference tree; the query point is given.
-   * This method is similar to ComputeBaseCase_().
+   * This method is similar to ComputeBaseCase().
    *
-   * @param point_id Index of query point.
+   * @param pointId Index (column) of the query point in the query set.
    * @param point The query point.
-   * @param reference_node Reference node.
-   * @param best_dist_so_far Best distance to a node so far -- used for pruning.
+   * @param referenceNode Reference node.
+   * @param bestDistSoFar Best distance to a node so far -- used for pruning.
+   *      Callers start it at SortPolicy::WorstDistance().
+   * @param neighbors Matrix of neighbor indices (one column per query point).
+   * @param distances Matrix of neighbor distances (one column per query
+   *      point).
    */
-  void ComputeSingleNeighborsRecursion_(size_t point_id, arma::vec& point,
-                                        TreeType* reference_node,
-                                        double& best_dist_so_far);
+  void ComputeSingleNeighborsRecursion(const size_t pointId,
+                                       const arma::vec& point,
+                                       TreeType* referenceNode,
+                                       double& bestDistSoFar,
+                                       arma::Mat<size_t>& neighbors,
+                                       arma::mat& distances);
 
   /**
    * Insert a point into the neighbors and distances matrices; this is a helper
    * function.
    *
-   * @param query_index Index of point whose neighbors we are inserting into.
+   * @param queryIndex Index (column) of the query point whose neighbor list we
+   *      are inserting into.
    * @param pos Position in list to insert into.
    * @param neighbor Index of reference point which is being inserted.
    * @param distance Distance from query point to reference point.
+   * @param neighbors Matrix of neighbor indices (one column per query point).
+   * @param distances Matrix of neighbor distances (one column per query
+   *      point).
    */
-  void InsertNeighbor(size_t query_index, size_t pos, size_t neighbor,
-                      double distance);
+  void InsertNeighbor(const size_t queryIndex,
+                      const size_t pos,
+                      const size_t neighbor,
+                      const double distance,
+                      arma::Mat<size_t>& neighbors,
+                      arma::mat& distances);
 
+  //! Copy of reference dataset (if we need it, because tree building modifies
+  //! it).
+  arma::mat referenceCopy;
+  //! Copy of query dataset (if we need it, because tree building modifies it).
+  arma::mat queryCopy;
+
+  // NOTE: referenceCopy and queryCopy must stay declared before referenceSet
+  // and querySet.  The constructors may bind those references to the copies,
+  // and members are initialized in declaration order.
+
+  //! Reference dataset (either the user's matrix or referenceCopy).
+  const arma::mat& referenceSet;
+  //! Query dataset.  When no separate query set is given, this refers to the
+  //! same matrix as referenceSet.
+  const arma::mat& querySet;
+
+  //! Indicates if O(n^2) naive search is being used.
+  bool naive;
+  //! Indicates if single-tree search is being used (opposed to dual-tree).
+  bool singleMode;
+
+  //! Pointer to the root of the reference tree.
+  TreeType* referenceTree;
+  //! Pointer to the root of the query tree (NULL when no separate query set is
+  //! used).
+  TreeType* queryTree;
+
+  //! Indicates if we should free the reference tree at deletion time.
+  bool ownReferenceTree;
+  //! Indicates if we should free the query tree at deletion time.
+  bool ownQueryTree;
+
+  //! Instantiation of the distance metric (a copy of the one given to the
+  //! constructor).
+  MetricType metric;
+
+  //! Permutations of reference points during tree building.
+  std::vector<size_t> oldFromNewReferences;
+  //! Permutations of query points during tree building.
+  std::vector<size_t> oldFromNewQueries;
+
+  //! Total number of pruned nodes during the neighbor search.
+  size_t numberOfPrunes;
 }; // class NeighborSearch
 
 }; // namespace neighbor

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-02 02:59:07 UTC (rev 10499)
@@ -12,499 +12,594 @@
 
 using namespace mlpack::neighbor;
 
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::NeighborSearch(arma::mat& queries_in,
-                                                   arma::mat& references_in,
-                                                   bool alias_matrix,
-                                                   MetricType kernel) :
-    references_(references_in.memptr(), references_in.n_rows,
-        references_in.n_cols, !alias_matrix),
-    queries_(queries_in.memptr(), queries_in.n_rows, queries_in.n_cols,
-        !alias_matrix),
-    kernel_(kernel),
-    naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
-    dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
-    number_of_prunes_(0)
+// Construct the object with separate query and reference sets.  Any tree that
+// is not passed in is built here, on an internal copy of its dataset (tree
+// building modifies the matrix), and is owned (later deleted) by this object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const arma::mat& referenceSet,
+               const arma::mat& querySet,
+               const bool naive,
+               const bool singleMode,
+               const size_t leafSize,
+               TreeType* referenceTree,
+               TreeType* queryTree,
+               const MetricType metric) :
+    // Only copy a dataset if we will build its tree ourselves.  The unused
+    // branch is an explicitly empty matrix rather than the literal 0: 0 is not
+    // a matrix, so it would need an implicit conversion (e.g. through a
+    // pointer-taking constructor) to make the conditional type-check.
+    referenceCopy(referenceTree ? arma::mat() : referenceSet),
+    queryCopy(queryTree ? arma::mat() : querySet),
+    // These bind to the copies when we build the trees (so indices match the
+    // trees), and to the user's matrices otherwise.
+    referenceSet(referenceTree ? referenceSet : referenceCopy),
+    querySet(queryTree ? querySet : queryCopy),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    referenceTree(referenceTree),
+    queryTree(queryTree),
+    ownReferenceTree(!referenceTree), // False if a tree was passed.
+    ownQueryTree(!queryTree), // False if a tree was passed.
+    metric(metric),
+    numberOfPrunes(0)
 {
   // C++11 will allow us to call out to other constructors so we can avoid this
   // copypasta problem.
 
-  // Get the leaf size; naive ensures that the entire tree is one node
-  if (naive_)
-    CLI::GetParam<int>("tree/leaf_size") =
-        std::max(queries_.n_cols, references_.n_cols);
+  // We'll time tree building, but only if we are building trees.
+  if (!referenceTree || !queryTree)
+    Timers::StartTimer("neighbor_search/tree_building");
 
-  // K-nearest neighbors initialization
-  knns_ = CLI::GetParam<int>("neighbor_search/k");
+  if (!referenceTree)
+  {
+    // Construct as a naive object if we need to.
+    if (naive)
+      this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+          referenceSet.n_cols /* everything in one leaf */);
+    else
+      this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+          leafSize);
+  }
 
-  // Initialize the list of nearest neighbor candidates
-  neighbor_indices_.set_size(knns_, queries_.n_cols);
+  if (!queryTree)
+  {
+    // Construct as a naive object if we need to.
+    if (naive)
+      this->queryTree = new TreeType(queryCopy, oldFromNewQueries,
+          querySet.n_cols /* everything in one leaf */);
+    else
+      this->queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+  }
 
-  // Initialize the vector of upper bounds for each point.
-  neighbor_distances_.set_size(knns_, queries_.n_cols);
-  neighbor_distances_.fill(SortPolicy::WorstDistance());
+  // Stop the timer we started above (if we need to).
+  if (!referenceTree || !queryTree)
+    Timers::StopTimer("neighbor_search/tree_building");
+}
 
-  // We'll time tree building
-  Timers::StartTimer("neighbor_search/tree_building");
+// Construct the object with a single dataset, which is used as both the query
+// and the reference set.  If no tree is passed in, one is built here on an
+// internal copy of the dataset and is owned (later deleted) by this object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const arma::mat& referenceSet,
+               const bool naive,
+               const bool singleMode,
+               const size_t leafSize,
+               TreeType* referenceTree,
+               const MetricType metric) :
+    // Only copy the dataset if we will build the tree ourselves.  The unused
+    // branch is an explicitly empty matrix rather than the literal 0: 0 is not
+    // a matrix, so it would need an implicit conversion (e.g. through a
+    // pointer-taking constructor) to make the conditional type-check.
+    referenceCopy(referenceTree ? arma::mat() : referenceSet),
+    referenceSet(referenceTree ? referenceSet : referenceCopy),
+    querySet(referenceTree ? referenceSet : referenceCopy),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    referenceTree(referenceTree),
+    queryTree(NULL),
+    ownReferenceTree(!referenceTree),
+    ownQueryTree(false), // Since it will be the same as referenceTree.
+    metric(metric),
+    numberOfPrunes(0)
+{
+  // We'll time tree building, but only if we are building trees.
+  if (!referenceTree)
+  {
+    Timers::StartTimer("neighbor_search/tree_building");
 
-  // This call makes each tree from a matrix, leaf size, and two arrays
-  // that record the permutation of the data points
-  query_tree_ = new TreeType(queries_, old_from_new_queries_);
-  reference_tree_ = new TreeType(references_, old_from_new_references_);
+    // Construct as a naive object if we need to.
+    if (naive)
+      this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+          referenceSet.n_cols /* everything in one leaf */);
+    else
+      this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+          leafSize);
 
-  // Stop the timer we started above
-  Timers::StopTimer("neighbor_search/tree_building");
+    // Stop the timer we started above.
+    Timers::StopTimer("neighbor_search/tree_building");
+  }
 }
 
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::NeighborSearch(arma::mat& references_in,
-                                                   bool alias_matrix,
-                                                   MetricType kernel) :
-    references_(references_in.memptr(), references_in.n_rows,
-        references_in.n_cols, !alias_matrix),
-    queries_(references_.memptr(), references_.n_rows, references_.n_cols,
-        false),
-    kernel_(kernel),
-    naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
-    dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
-    number_of_prunes_(0)
+/**
+ * Destroy the object.  Only trees that this object built itself are freed;
+ * trees handed in by the user are left alone, and the datasets (plain members)
+ * clean up after themselves.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
 {
-  // Get the leaf size from the module
-  if (naive_)
-    CLI::GetParam<int>("tree/leaf_size") =
-        std::max(queries_.n_cols, references_.n_cols);
+  // ownQueryTree is false whenever there is no separate query tree, so the
+  // two deletions never target the same object.
+  if (ownQueryTree)
+    delete queryTree;
+  if (ownReferenceTree)
+    delete referenceTree;
+}
 
-  // K-nearest neighbors initialization
-  knns_ = CLI::GetParam<int>("neighbor_search/k");
+/**
+ * Computes the best neighbors and stores them in resultingNeighbors and
+ * distances.  Both matrices end up with k rows and one column per query point.
+ * If this object built a tree itself, the search runs on temporary matrices
+ * whose contents are then mapped back to the original (unpermuted) ordering.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeNeighbors(
+    const size_t k,
+    arma::Mat<size_t>& resultingNeighbors,
+    arma::mat& distances)
+{
+  Timers::StartTimer("neighbor_search/computing_neighbors");
 
-  // Initialize the list of nearest neighbor candidates
-  neighbor_indices_.set_size(knns_, references_.n_cols);
+  // If we have built the trees ourselves, then we will have to map all the
+  // indices back to their original indices when this computation is finished.
+  // To avoid an extra copy, we will store the neighbors and distances in a
+  // separate matrix.  The temporaries are locals rather than heap allocations,
+  // so they are released on every path (each mapping branch, and exceptions)
+  // without matching delete calls; they simply stay empty when unused.
+  arma::Mat<size_t> tmpNeighbors;
+  arma::mat tmpDistances;
+  arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+  arma::mat* distancePtr = &distances;
 
-  // Initialize the vector of upper bounds for each point.
-  neighbor_distances_.set_size(knns_, references_.n_cols);
-  neighbor_distances_.fill(SortPolicy::WorstDistance());
+  if (ownQueryTree || (ownReferenceTree && !queryTree))
+    distancePtr = &tmpDistances; // Query indices need to be mapped.
+  if (ownReferenceTree || ownQueryTree)
+    neighborPtr = &tmpNeighbors; // All indices need mapping.
 
-  // We'll time tree building
-  Timers::StartTimer("neighbor_search/tree_building");
+  // Set the size of the neighbor and distance matrices.
+  neighborPtr->set_size(k, querySet.n_cols);
+  distancePtr->set_size(k, querySet.n_cols);
+  distancePtr->fill(SortPolicy::WorstDistance());
 
-  // This call makes each tree from a matrix, leaf size, and two arrays
-  // that record the permutation of the data points
-  // Instead of NULL, it is possible to specify an array new_from_old_
-  query_tree_ = NULL;
-  reference_tree_ = new TreeType(references_, old_from_new_references_);
+  if (naive)
+  {
+    // Run the base case computation on all nodes
+    if (queryTree)
+      ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
+    else
+      ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+  }
+  else
+  {
+    if (singleMode)
+    {
+      // Do one tenth of the query set at a time.
+      size_t chunk = querySet.n_cols / 10;
 
-  // Stop the timer we started above
-  Timers::StopTimer("neighbor_search/tree_building");
-}
+      for (size_t i = 0; i < 10; i++)
+      {
+        for (size_t j = 0; j < chunk; j++)
+        {
+          double worstDistance = SortPolicy::WorstDistance();
+          ComputeSingleNeighborsRecursion(i * chunk + j,
+              querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
+              *neighborPtr, *distancePtr);
+        }
+      }
 
-/**
- * The tree is the only member we are responsible for deleting.  The others will
- * take care of themselves.
- */
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::~NeighborSearch()
-{
-  if (reference_tree_ != query_tree_)
-    delete reference_tree_;
-  if (query_tree_ != NULL)
-    delete query_tree_;
-}
+      // The last tenth is differently sized...
+      for (size_t i = 0; i < querySet.n_cols % 10; i++)
+      {
+        size_t ind = (querySet.n_cols / 10) * 10 + i;
+        double worstDistance = SortPolicy::WorstDistance();
+        ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
+            referenceTree, worstDistance, *neighborPtr, *distancePtr);
+      }
+    }
+    else // Dual-tree recursion.
+    {
+      // Start on the root of each tree.
+      if (queryTree)
+      {
+        ComputeDualNeighborsRecursion(queryTree, referenceTree,
+            SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
+            *neighborPtr, *distancePtr);
+      }
+      else
+      {
+        ComputeDualNeighborsRecursion(referenceTree, referenceTree,
+            SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
+            *neighborPtr, *distancePtr);
+      }
+    }
+  }
 
+  Timers::StopTimer("neighbor_search/computing_neighbors");
+
+  // Now, do we need to do mapping of indices?
+  if (!ownReferenceTree && !ownQueryTree)
+  {
+    // No mapping needed.  We are done.
+    return;
+  }
+  else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+  {
+    // Set size of output matrices correctly.
+    resultingNeighbors.set_size(k, querySet.n_cols);
+    distances.set_size(k, querySet.n_cols);
+
+    for (size_t i = 0; i < distances.n_cols; i++)
+    {
+      // Map distances (copy a column).
+      distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; j++)
+      {
+        resultingNeighbors(j, oldFromNewQueries[i]) =
+            oldFromNewReferences[(*neighborPtr)(j, i)];
+      }
+    }
+  }
+  else if (ownReferenceTree)
+  {
+    if (!queryTree) // No query tree -- map both references and queries.
+    {
+      resultingNeighbors.set_size(k, querySet.n_cols);
+      distances.set_size(k, querySet.n_cols);
+
+      for (size_t i = 0; i < distances.n_cols; i++)
+      {
+        // Map distances (copy a column).
+        distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+
+        // Map indices of neighbors.
+        for (size_t j = 0; j < distances.n_rows; j++)
+        {
+          resultingNeighbors(j, oldFromNewReferences[i]) =
+              oldFromNewReferences[(*neighborPtr)(j, i)];
+        }
+      }
+    }
+    else // Map only references.
+    {
+      // Set size of neighbor indices matrix correctly.
+      resultingNeighbors.set_size(k, querySet.n_cols);
+
+      // Map indices of neighbors.
+      for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+      {
+        for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+        {
+          resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+        }
+      }
+    }
+  }
+  else if (ownQueryTree)
+  {
+    // Set size of matrices correctly.
+    resultingNeighbors.set_size(k, querySet.n_cols);
+    distances.set_size(k, querySet.n_cols);
+
+    for (size_t i = 0; i < distances.n_cols; i++)
+    {
+      // Map distances (copy a column).
+      distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+      // Map indices of neighbors.
+      resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+    }
+  }
+} // ComputeNeighbors
+
 /**
  * Performs exhaustive computation between two leaves.
+ *
+ * Each query point in queryNode is compared against every point in
+ * referenceNode (a point is not paired with itself when both nodes are the
+ * same node).  Candidates are inserted into the query point's column of the
+ * neighbors/distances matrices, and queryNode's bound is then set to the worst
+ * k-th best distance among its points.
  */
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeBaseCase_(
-      TreeType* query_node,
-      TreeType* reference_node)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
+      TreeType* queryNode,
+      TreeType* referenceNode,
+      arma::Mat<size_t>& neighbors,
+      arma::mat& distances)
 {
   // Used to find the query node's new upper bound.
-  double query_worst_distance = SortPolicy::BestDistance();
+  double queryWorstDistance = SortPolicy::BestDistance();
 
-  // node->begin() is the index of the first point in the node,
-  // node->end is one past the last index.
-  for (size_t query_index = query_node->Begin();
-      query_index < query_node->End(); query_index++)
+  // node->Begin() is the index of the first point in the node,
+  // node->End() is one past the last index.
+  for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
+       queryIndex++)
   {
     // Get the query point from the matrix.
-    arma::vec query_point = queries_.unsafe_col(query_index);
+    arma::vec queryPoint = querySet.unsafe_col(queryIndex);
 
-    double query_to_node_distance =
-      SortPolicy::BestPointToNodeDistance(query_point, reference_node);
+    double queryToNodeDistance =
+      SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
 
-    if (SortPolicy::IsBetter(query_to_node_distance,
-        neighbor_distances_(knns_ - 1, query_index)))
+    // Only scan the reference points if the reference node could contain a
+    // point better than this query point's current k-th (last-row) candidate.
+    if (SortPolicy::IsBetter(queryToNodeDistance,
+        distances(distances.n_rows - 1, queryIndex)))
     {
       // We'll do the same for the references.
-      for (size_t reference_index = reference_node->Begin();
-          reference_index < reference_node->End(); reference_index++)
+      for (size_t referenceIndex = referenceNode->Begin();
+          referenceIndex < referenceNode->End(); referenceIndex++)
       {
         // Confirm that points do not identify themselves as neighbors
         // in the monochromatic case.
-        if (reference_node != query_node || reference_index != query_index)
+        if (referenceNode != queryNode || referenceIndex != queryIndex)
         {
-          arma::vec reference_point = references_.unsafe_col(reference_index);
+          arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
 
-          double distance = kernel_.Evaluate(query_point, reference_point);
+          double distance = metric.Evaluate(queryPoint, referencePoint);
 
           // If the reference point is closer than any of the current
           // candidates, add it to the list.
-          arma::vec query_dist = neighbor_distances_.unsafe_col(query_index);
-          size_t insert_position =  SortPolicy::SortDistance(query_dist,
+          arma::vec queryDist = distances.unsafe_col(queryIndex);
+          size_t insertPosition = SortPolicy::SortDistance(queryDist,
               distance);
 
-          if (insert_position != (size_t() - 1))
-            InsertNeighbor(query_index, insert_position, reference_index,
-                distance);
+          // size_t() - 1 (the largest size_t value) is the sentinel checked
+          // here for "not good enough to enter the list".
+          if (insertPosition != (size_t() - 1))
+            InsertNeighbor(queryIndex, insertPosition, referenceIndex,
+                distance, neighbors, distances);
         }
       }
     }
 
     // We need to find the upper bound distance for this query node
-    if (SortPolicy::IsBetter(query_worst_distance,
-        neighbor_distances_(knns_ - 1, query_index)))
-      query_worst_distance = neighbor_distances_(knns_ - 1, query_index);
+    if (SortPolicy::IsBetter(queryWorstDistance,
+        distances(distances.n_rows - 1, queryIndex)))
+      queryWorstDistance = distances(distances.n_rows - 1, queryIndex);
   }
 
-  // Update the upper bound for the query_node
-  query_node->Stat().bound_ = query_worst_distance;
+  // Update the upper bound for the queryNode
+  queryNode->Stat().Bound() = queryWorstDistance;
 
-} // ComputeBaseCase_
+} // ComputeBaseCase()
 
 /**
  * The recursive function for dual tree.
+ *
+ * If queryNode's current bound is already better than lowerBound, the pair is
+ * pruned.  Otherwise the recursion descends (visiting reference children in
+ * best-first order) until both nodes are leaves, where ComputeBaseCase() runs.
+ * On the way back up, queryNode's bound becomes the worse of its children's
+ * bounds.
  */
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeDualNeighborsRecursion_(
-      TreeType* query_node,
-      TreeType* reference_node,
-      double lower_bound)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeDualNeighborsRecursion(
+    TreeType* queryNode,
+    TreeType* referenceNode,
+    const double lowerBound,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
 {
-  if (SortPolicy::IsBetter(query_node->Stat().bound_, lower_bound))
+  if (SortPolicy::IsBetter(queryNode->Stat().Bound(), lowerBound))
   {
-    number_of_prunes_++; // Pruned by distance; the nodes cannot be any closer
-    return;              // than the already established lower bound.
+    numberOfPrunes++; // Pruned by distance; the nodes cannot be any closer
+    return;           // than the already established lower bound.
   }
 
-  if (query_node->IsLeaf() && reference_node->IsLeaf())
+  if (queryNode->IsLeaf() && referenceNode->IsLeaf())
   {
-    ComputeBaseCase_(query_node, reference_node); // Base case: both are leaves.
+    // Base case: both are leaves.
+    ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
     return;
   }
 
-  if (query_node->IsLeaf())
+  if (queryNode->IsLeaf())
   {
     // We must keep descending down the reference node to get to a leaf.
 
     // We'll order the computation by distance; descend in the direction of less
     // distance first.
-    double left_distance = SortPolicy::BestNodeToNodeDistance(query_node,
-        reference_node->Left());
-    double right_distance = SortPolicy::BestNodeToNodeDistance(query_node,
-        reference_node->Right());
+    double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
+        referenceNode->Left());
+    double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
+        referenceNode->Right());
 
-    if (SortPolicy::IsBetter(left_distance, right_distance))
+    if (SortPolicy::IsBetter(leftDistance, rightDistance))
     {
-      ComputeDualNeighborsRecursion_(query_node, reference_node->Left(),
-          left_distance);
-      ComputeDualNeighborsRecursion_(query_node, reference_node->Right(),
-          right_distance);
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
+          leftDistance, neighbors, distances);
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
+          rightDistance, neighbors, distances);
     }
     else
     {
-      ComputeDualNeighborsRecursion_(query_node, reference_node->Right(),
-          right_distance);
-      ComputeDualNeighborsRecursion_(query_node, reference_node->Left(),
-          left_distance);
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
+          rightDistance, neighbors, distances);
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
+          leftDistance, neighbors, distances);
     }
     return;
   }
 
-  if (reference_node->IsLeaf())
+  if (referenceNode->IsLeaf())
   {
     // We must descend down the query node to get to a leaf.
-    double left_distance = SortPolicy::BestNodeToNodeDistance(
-        query_node->Left(), reference_node);
-    double right_distance = SortPolicy::BestNodeToNodeDistance(
-        query_node->Right(), reference_node);
+    double leftDistance = SortPolicy::BestNodeToNodeDistance(
+        queryNode->Left(), referenceNode);
+    double rightDistance = SortPolicy::BestNodeToNodeDistance(
+        queryNode->Right(), referenceNode);
 
-    ComputeDualNeighborsRecursion_(query_node->Left(), reference_node,
-        left_distance);
-    ComputeDualNeighborsRecursion_(query_node->Right(), reference_node,
-        right_distance);
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
+        leftDistance, neighbors, distances);
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
+        rightDistance, neighbors, distances);
 
-    // We need to update the upper bound based on the new upper bounds of
-    // the children
-    double left_bound = query_node->Left()->Stat().bound_;
-    double right_bound = query_node->Right()->Stat().bound_;
+    // We need to update the upper bound based on the new upper bounds of the
+    // children.
+    double leftBound = queryNode->Left()->Stat().Bound();
+    double rightBound = queryNode->Right()->Stat().Bound();
 
-    if (SortPolicy::IsBetter(left_bound, right_bound))
-      query_node->Stat().bound_ = right_bound;
+    // The node's bound is the worse of its two children's bounds.
+    if (SortPolicy::IsBetter(leftBound, rightBound))
+      queryNode->Stat().Bound() = rightBound;
     else
-      query_node->Stat().bound_ = left_bound;
+      queryNode->Stat().Bound() = leftBound;
 
     return;
   }
 
   // Neither side is a leaf; so we recurse on all combinations of both.  The
   // calculations are ordered by distance.
-  double left_distance = SortPolicy::BestNodeToNodeDistance(query_node->Left(),
-      reference_node->Left());
-  double right_distance = SortPolicy::BestNodeToNodeDistance(query_node->Left(),
-      reference_node->Right());
+  double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
+      referenceNode->Left());
+  double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
+      referenceNode->Right());
 
-  // Recurse on query_node->left() first.
-  if (SortPolicy::IsBetter(left_distance, right_distance))
+  // Recurse on queryNode->left() first.
+  if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-        left_distance);
-    ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-        right_distance);
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
+        leftDistance, neighbors, distances);
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
+        rightDistance, neighbors, distances);
   }
   else
   {
-    ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-        right_distance);
-    ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-        left_distance);
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
+        rightDistance, neighbors, distances);
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
+        leftDistance, neighbors, distances);
   }
 
-  left_distance = SortPolicy::BestNodeToNodeDistance(query_node->Right(),
-      reference_node->Left());
-  right_distance = SortPolicy::BestNodeToNodeDistance(query_node->Right(),
-      reference_node->Right());
+  leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
+      referenceNode->Left());
+  rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
+      referenceNode->Right());
 
-  // Now recurse on query_node->right().
-  if (SortPolicy::IsBetter(left_distance, right_distance))
+  // Now recurse on queryNode->right().
+  if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-        left_distance);
-    ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-        right_distance);
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
+        leftDistance, neighbors, distances);
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
+        rightDistance, neighbors, distances);
   }
   else
   {
-    ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-        right_distance);
-    ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-        left_distance);
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
+        rightDistance, neighbors, distances);
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
+        leftDistance, neighbors, distances);
   }
 
   // Update the upper bound as above
-  double left_bound = query_node->Left()->Stat().bound_;
-  double right_bound = query_node->Right()->Stat().bound_;
+  double leftBound = queryNode->Left()->Stat().Bound();
+  double rightBound = queryNode->Right()->Stat().Bound();
 
-  if (SortPolicy::IsBetter(left_bound, right_bound))
-    query_node->Stat().bound_ = right_bound;
+  // The node's bound is the worse of its two children's bounds.
+  if (SortPolicy::IsBetter(leftBound, rightBound))
+    queryNode->Stat().Bound() = rightBound;
   else
-    query_node->Stat().bound_ = left_bound;
+    queryNode->Stat().Bound() = leftBound;
 
-} // ComputeDualNeighborsRecursion_
+} // ComputeDualNeighborsRecursion()
 
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeSingleNeighborsRecursion_(
-      size_t point_id,
-      arma::vec& point,
-      TreeType* reference_node,
-      double& best_dist_so_far)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeSingleNeighborsRecursion(const size_t pointId,
+                                const arma::vec& point,
+                                TreeType* referenceNode,
+                                double& bestDistSoFar,
+                                arma::Mat<size_t>& neighbors,
+                                arma::mat& distances)
 {
-  if (reference_node->IsLeaf())
+  if (referenceNode->IsLeaf())
   {
     // Base case: reference node is a leaf.
-    for (size_t reference_index = reference_node->Begin();
-        reference_index < reference_node->End(); reference_index++)
+    for (size_t referenceIndex = referenceNode->Begin();
+        referenceIndex < referenceNode->End(); referenceIndex++)
     {
       // Confirm that points do not identify themselves as neighbors
       // in the monochromatic case
-      if (!(references_.memptr() == queries_.memptr() &&
-            reference_index == point_id))
+      if (!(referenceSet.memptr() == querySet.memptr() &&
+            referenceIndex == pointId))
       {
-        arma::vec reference_point = references_.unsafe_col(reference_index);
+        arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
 
-        double distance = kernel_.Evaluate(point, reference_point);
+        double distance = metric.Evaluate(point, referencePoint);
 
         // If the reference point is better than any of the current candidates,
         // insert it into the list correctly.
-        arma::vec query_dist = neighbor_distances_.unsafe_col(point_id);
-        size_t insert_position = SortPolicy::SortDistance(query_dist,
-            distance);
+        arma::vec queryDist = distances.unsafe_col(pointId);
+        size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
 
-        if (insert_position != (size_t() - 1))
-          InsertNeighbor(point_id, insert_position, reference_index, distance);
+        if (insertPosition != (size_t() - 1))
+          InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
+              neighbors, distances);
       }
-    } // for reference_index
+    } // for referenceIndex
 
-    best_dist_so_far = neighbor_distances_(knns_ - 1, point_id);
+    bestDistSoFar = distances(distances.n_rows - 1, pointId);
   }
   else
   {
     // We'll order the computation by distance.
-    double left_distance = SortPolicy::BestPointToNodeDistance(point,
-        reference_node->Left());
-    double right_distance = SortPolicy::BestPointToNodeDistance(point,
-        reference_node->Right());
+    double leftDistance = SortPolicy::BestPointToNodeDistance(point,
+        referenceNode->Left());
+    double rightDistance = SortPolicy::BestPointToNodeDistance(point,
+        referenceNode->Right());
 
     // Recurse in the best direction first.
-    if (SortPolicy::IsBetter(left_distance, right_distance))
+    if (SortPolicy::IsBetter(leftDistance, rightDistance))
     {
-      if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
-        number_of_prunes_++; // Prune; no possibility of finding a better point.
+      if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
+        numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion_(point_id, point,
-            reference_node->Left(), best_dist_so_far);
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+            bestDistSoFar, neighbors, distances);
 
-      if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
-        number_of_prunes_++; // Prune; no possibility of finding a better point.
+      if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
+        numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion_(point_id, point,
-            reference_node->Right(), best_dist_so_far);
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+            bestDistSoFar, neighbors, distances);
 
     }
     else
     {
-      if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
-        number_of_prunes_++; // Prune; no possibility of finding a better point.
+      if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
+        numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion_(point_id, point,
-            reference_node->Right(), best_dist_so_far);
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+            bestDistSoFar, neighbors, distances);
 
-      if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
-        number_of_prunes_++; // Prune; no possibility of finding a better point.
+      if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
+        numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion_(point_id, point,
-            reference_node->Left(), best_dist_so_far);
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+            bestDistSoFar, neighbors, distances);
     }
   }
 }
 
 /**
- * Computes the best neighbors and stores them in resulting_neighbors and
- * distances.
- */
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeNeighbors(
-      arma::Mat<size_t>& resulting_neighbors,
-      arma::mat& distances)
-{
-  Timers::StartTimer("neighbor_search/computing_neighbors");
-  if (naive_)
-  {
-    // Run the base case computation on all nodes
-    if (query_tree_)
-      ComputeBaseCase_(query_tree_, reference_tree_);
-    else
-      ComputeBaseCase_(reference_tree_, reference_tree_);
-  }
-  else
-  {
-    if (dual_mode_)
-    {
-      // Start on the root of each tree
-      if (query_tree_)
-      {
-        ComputeDualNeighborsRecursion_(query_tree_, reference_tree_,
-            SortPolicy::BestNodeToNodeDistance(query_tree_, reference_tree_));
-      }
-      else
-      {
-        ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_,
-            SortPolicy::BestNodeToNodeDistance(reference_tree_,
-            reference_tree_));
-      }
-    }
-    else
-    {
-      size_t chunk = queries_.n_cols / 10;
-
-      for (size_t i = 0; i < 10; i++)
-      {
-        for (size_t j = 0; j < chunk; j++)
-        {
-          arma::vec point = queries_.unsafe_col(i * chunk + j);
-          double best_dist_so_far = SortPolicy::WorstDistance();
-          ComputeSingleNeighborsRecursion_(i * chunk + j, point,
-              reference_tree_, best_dist_so_far);
-        }
-      }
-
-      for (size_t i = 0; i < queries_.n_cols % 10; i++)
-      {
-        size_t ind = (queries_.n_cols / 10) * 10 + i;
-        arma::vec point = queries_.unsafe_col(ind);
-        double best_dist_so_far = SortPolicy::WorstDistance();
-        ComputeSingleNeighborsRecursion_(ind, point, reference_tree_,
-            best_dist_so_far);
-      }
-    }
-  }
-
-  Timers::StopTimer("neighbor_search/computing_neighbors");
-
-  // We need to initialize the results list before filling it
-  resulting_neighbors.set_size(neighbor_indices_.n_rows,
-      neighbor_indices_.n_cols);
-  distances.set_size(neighbor_distances_.n_rows, neighbor_distances_.n_cols);
-
-  // We need to map the indices back from how they have been permuted
-  if (query_tree_ != NULL)
-  {
-    for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
-    {
-      for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
-      {
-        resulting_neighbors(k, old_from_new_queries_[i]) =
-            old_from_new_references_[neighbor_indices_(k, i)];
-        distances(k, old_from_new_queries_[i]) = neighbor_distances_(k, i);
-      }
-    }
-  }
-  else
-  {
-    for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
-    {
-      for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
-      {
-        resulting_neighbors(k, old_from_new_references_[i]) =
-            old_from_new_references_[neighbor_indices_(k, i)];
-        distances(k, old_from_new_references_[i]) = neighbor_distances_(k, i);
-      }
-    }
-  }
-} // ComputeNeighbors
-
-/***
  * Helper function to insert a point into the neighbors and distances matrices.
  *
- * @param query_index Index of point whose neighbors we are inserting into.
+ * @param queryIndex Index of point whose neighbors we are inserting into.
  * @param pos Position in list to insert into.
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::InsertNeighbor(size_t query_index,
-                                                        size_t pos,
-                                                        size_t neighbor,
-                                                        double distance)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(const size_t queryIndex,
+               const size_t pos,
+               const size_t neighbor,
+               const double distance,
+               arma::Mat<size_t>& neighbors,
+               arma::mat& distances)
 {
   // We only memmove() if there is actually a need to shift something.
-  if (pos < (knns_ - 1))
+  if (pos < (distances.n_rows - 1))
   {
-    int len = (knns_ - 1) - pos;
-    memmove(neighbor_distances_.colptr(query_index) + (pos + 1),
-        neighbor_distances_.colptr(query_index) + pos,
+    int len = (distances.n_rows - 1) - pos;
+    memmove(distances.colptr(queryIndex) + (pos + 1),
+        distances.colptr(queryIndex) + pos,
         sizeof(double) * len);
-    memmove(neighbor_indices_.colptr(query_index) + (pos + 1),
-        neighbor_indices_.colptr(query_index) + pos,
+    memmove(neighbors.colptr(queryIndex) + (pos + 1),
+        neighbors.colptr(queryIndex) + pos,
         sizeof(size_t) * len);
   }
 
   // Now put the new information in the right index.
-  neighbor_distances_(pos, query_index) = distance;
-  neighbor_indices_(pos, query_index) = neighbor;
+  distances(pos, queryIndex) = distance;
+  neighbors(pos, queryIndex) = neighbor;
 }
 
 #endif

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp	2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp	2011-12-02 02:59:07 UTC (rev 10499)
@@ -26,7 +26,7 @@
  * neighbors.  Squared distances are used because they are slightly faster than
  * non-squared distances (they have one fewer call to sqrt()).
  */
-typedef NeighborSearch<metric::SquaredEuclideanDistance, NearestNeighborSort>
+typedef NeighborSearch<NearestNeighborSort, metric::SquaredEuclideanDistance>
     AllkNN;
 
 /**
@@ -35,7 +35,7 @@
  * neighbors.  Squared distances are used because they are slightly faster than
  * non-squared distances (they have one fewer call to sqrt()).
  */
-typedef NeighborSearch<metric::SquaredEuclideanDistance, FurthestNeighborSort>
+typedef NeighborSearch<FurthestNeighborSort, metric::SquaredEuclideanDistance>
     AllkFN;
 
 }; // namespace neighbor




More information about the mlpack-svn mailing list