[mlpack-svn] r10378 - in mlpack/trunk/src/mlpack/methods/neighbor_search: . sort_policies
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 23 21:42:07 EST 2011
Author: rcurtin
Date: 2011-11-23 21:42:06 -0500 (Wed, 23 Nov 2011)
New Revision: 10378
Added:
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp
Removed:
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cc
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h
mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
Log:
Rename files to .cpp and .hpp in accordance with #152, and fix style to be in
line with decisions in #153.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt 2011-11-24 02:42:06 UTC (rev 10378)
@@ -3,15 +3,15 @@
# Define the files we need to compile.
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
- neighbor_search.h
- neighbor_search_impl.h
+ neighbor_search.hpp
+ neighbor_search_impl.hpp
sort_policies/nearest_neighbor_sort.hpp
sort_policies/nearest_neighbor_sort.cpp
sort_policies/nearest_neighbor_sort_impl.hpp
sort_policies/furthest_neighbor_sort.hpp
sort_policies/furthest_neighbor_sort.cpp
sort_policies/furthest_neighbor_sort_impl.hpp
- typedef.h
+ typedef.hpp
)
# Add directory name to sources.
@@ -24,14 +24,14 @@
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
add_executable(allknn
- allknn_main.cc
+ allknn_main.cpp
)
target_link_libraries(allknn
mlpack
)
add_executable(allkfn
- allkfn_main.cc
+ allkfn_main.cpp
)
target_link_libraries(allkfn
mlpack
Deleted: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,120 +0,0 @@
-/**
- * @file main.cc
- *
- * Implementation of the AllkNN executable. Allows some number of standard
- * options.
- *
- * @author Ryan Curtin
- */
-#include <mlpack/core.h>
-#include "neighbor_search.h"
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Furthest-Neighbors",
- "This program will calculate the all k-furthest-neighbors of a set of "
- "points. You may specify a separate set of reference points and query "
- "points, or just a reference set which will be used as both the reference "
- "and query set."
- "\n\n"
- "For example, the following will calculate the 5 furthest neighbors of each"
- "point in 'input.csv' and store the results in 'output.csv':"
- "\n\n"
- "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
- " --output_file=output.csv", "neighbor_search");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
- "");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
- "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
- "");
-
-int main(int argc, char *argv[]) {
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- string reference_file = CLI::GetParam<string>("reference_file");
- string output_file = CLI::GetParam<string>("output_file");
-
- arma::mat reference_data;
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (!data::Load(reference_file.c_str(), reference_data))
- Log::Fatal << "Reference file " << reference_file << "not found." << endl;
-
- Log::Info << "Loaded reference data from " << reference_file << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- size_t k = CLI::GetParam<int>("neighbor_search/k");
- if ((k <= 0) || (k >= reference_data.n_cols)) {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than the number of reference points (";
- Log::Fatal << reference_data.n_cols << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (CLI::GetParam<int>("tree/leaf_size") <= 0) {
- Log::Fatal << "Invalid leaf size: "
- << CLI::GetParam<int>("allknn/leaf_size") << endl;
- }
-
- AllkFN* allkfn = NULL;
-
- if (CLI::GetParam<string>("query_file") != "") {
- string query_file = CLI::GetParam<string>("query_file");
- arma::mat query_data;
-
- if (!data::Load(query_file.c_str(), query_data))
- Log::Fatal << "Query file " << query_file << " not found" << endl;
-
- Log::Info << "Query data loaded from " << query_file << endl;
-
- Log::Info << "Building query and reference trees..." << endl;
- allkfn = new AllkFN(query_data, reference_data);
-
- } else {
- Log::Info << "Building reference tree..." << endl;
- allkfn = new AllkFN(reference_data);
- }
-
- Log::Info << "Tree(s) built." << endl;
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allkfn->ComputeNeighbors(neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
- Log::Info << "Exporting results..." << endl;
-
- // Should be using data::Save or a related function instead of being written
- // by hand.
- try {
- ofstream out(output_file.c_str());
-
- for (size_t col = 0; col < neighbors.n_cols; col++) {
- out << col << ", ";
- for (size_t j = 0; j < (k - 1) /* last is special case */; j++) {
- out << neighbors(j, col) << ", " << distances(j, col) << ", ";
- }
- out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
- }
-
- out.close();
- } catch(exception& e) {
- Log::Fatal << "Error while opening " << output_file << ": " << e.what()
- << endl;
- }
-
- delete allkfn;
-}
Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp (from rev 10352, mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -0,0 +1,131 @@
+/**
+ * @file allkfn_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllkFN executable. Allows some number of standard
+ * options.
+ */
+#include <mlpack/core.h>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "neighbor_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Furthest-Neighbors",
+ "This program will calculate the all k-furthest-neighbors of a set of "
+ "points. You may specify a separate set of reference points and query "
+ "points, or just a reference set which will be used as both the reference "
+ "and query set."
+ "\n\n"
+ "For example, the following will calculate the 5 furthest neighbors of each"
+ "point in 'input.csv' and store the results in 'output.csv':"
+ "\n\n"
+ "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
+ " --output_file=output.csv", "neighbor_search");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+ "");
+PARAM_STRING("query_file", "CSV file containing query points (optional).",
+ "", "");
+PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
+ "");
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ string reference_file = CLI::GetParam<string>("reference_file");
+ string output_file = CLI::GetParam<string>("output_file");
+
+ arma::mat reference_data;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (!data::Load(reference_file.c_str(), reference_data))
+ Log::Fatal << "Reference file " << reference_file << "not found." << endl;
+
+ Log::Info << "Loaded reference data from " << reference_file << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ size_t k = CLI::GetParam<int>("neighbor_search/k");
+ if ((k <= 0) || (k >= reference_data.n_cols))
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than the number of reference points (";
+ Log::Fatal << reference_data.n_cols << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+ {
+ Log::Fatal << "Invalid leaf size: "
+ << CLI::GetParam<int>("allknn/leaf_size") << endl;
+ }
+
+ AllkFN* allkfn = NULL;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string query_file = CLI::GetParam<string>("query_file");
+ arma::mat query_data;
+
+ if (!data::Load(query_file.c_str(), query_data))
+ Log::Fatal << "Query file " << query_file << " not found" << endl;
+
+ Log::Info << "Query data loaded from " << query_file << endl;
+
+ Log::Info << "Building query and reference trees..." << endl;
+ allkfn = new AllkFN(query_data, reference_data);
+
+ }
+ else
+ {
+ Log::Info << "Building reference tree..." << endl;
+ allkfn = new AllkFN(reference_data);
+ }
+
+ Log::Info << "Tree(s) built." << endl;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allkfn->ComputeNeighbors(neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+ Log::Info << "Exporting results..." << endl;
+
+ // Should be using data::Save or a related function instead of being written
+ // by hand.
+ try
+ {
+ ofstream out(output_file.c_str());
+
+ for (size_t col = 0; col < neighbors.n_cols; col++)
+ {
+ out << col << ", ";
+ for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
+ {
+ out << neighbors(j, col) << ", " << distances(j, col) << ", ";
+ }
+ out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+ }
+
+ out.close();
+ }
+ catch (exception& e)
+ {
+ Log::Fatal << "Error while opening " << output_file << ": " << e.what()
+ << endl;
+ }
+
+ delete allkfn;
+}
Deleted: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cc 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cc 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,119 +0,0 @@
-/**
- * @file main.cc
- *
- * Implementation of the AllkNN executable. Allows some number of standard
- * options.
- *
- * @author Ryan Curtin
- */
-#include <mlpack/core.h>
-#include "neighbor_search.h"
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Nearest-Neighbors",
- "This program will calculate the all k-nearest-neighbors of a set of "
- "points. You may specify a separate set of reference points and query "
- "points, or just a reference set which will be used as both the reference "
- "and query set."
- "\n\n"
- "For example, the following will calculate the 5 nearest neighbors of each"
- "point in 'input.csv' and store the results in 'output.csv':"
- "\n\n"
- "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
- " --output_file=output.csv", "neighbor_search");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
- "");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
- "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
- "");
-
-int main(int argc, char *argv[]) {
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- string reference_file = CLI::GetParam<string>("reference_file");
- string output_file = CLI::GetParam<string>("output_file");
-
- arma::mat reference_data;
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (!data::Load(reference_file.c_str(), reference_data))
- Log::Fatal << "Reference file " << reference_file << " not found." << endl;
-
- Log::Info << "Loaded reference data from " << reference_file << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- size_t k = CLI::GetParam<int>("neighbor_search/k");
- if ((k <= 0) || (k >= reference_data.n_cols)) {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than the number of reference points (";
- Log::Fatal << reference_data.n_cols << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (CLI::GetParam<int>("tree/leaf_size") <= 0) {
- Log::Fatal << "Invalid leaf size: " << CLI::GetParam<int>("allknn/leaf_size")
- << endl;
- }
-
- AllkNN* allknn = NULL;
-
- if (CLI::GetParam<string>("query_file") != "") {
- string query_file = CLI::GetParam<string>("query_file");
- arma::mat query_data;
-
- if (!data::Load(query_file.c_str(), query_data))
- Log::Fatal << "Query file " << query_file << " not found" << endl;
-
- Log::Info << "Query data loaded from " << query_file << endl;
-
- Log::Info << "Building query and reference trees..." << endl;
- allknn = new AllkNN(query_data, reference_data);
-
- } else {
- Log::Info << "Building reference tree..." << endl;
- allknn = new AllkNN(reference_data);
- }
-
- Log::Info << "Tree(s) built." << endl;
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->ComputeNeighbors(neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
- Log::Info << "Exporting results..." << endl;
-
- // Should be using data::Save or a related function instead of being written
- // by hand.
- try {
- ofstream out(output_file.c_str());
-
- for (size_t col = 0; col < neighbors.n_cols; col++) {
- out << col << ", ";
- for (size_t j = 0; j < (k - 1) /* last is special case */; j++) {
- out << neighbors(j, col) << ", " << distances(j, col) << ", ";
- }
- out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
- }
-
- out.close();
- } catch(exception& e) {
- Log::Fatal << "Error while opening " << output_file << ": " << e.what()
- << endl;
- }
- delete allknn;
-}
Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp (from rev 10352, mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -0,0 +1,131 @@
+/**
+ * @file allknn_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllkNN executable. Allows some number of standard
+ * options.
+ */
+#include <mlpack/core.h>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "neighbor_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Nearest-Neighbors",
+ "This program will calculate the all k-nearest-neighbors of a set of "
+ "points. You may specify a separate set of reference points and query "
+ "points, or just a reference set which will be used as both the reference "
+ "and query set."
+ "\n\n"
+ "For example, the following will calculate the 5 nearest neighbors of each"
+ "point in 'input.csv' and store the results in 'output.csv':"
+ "\n\n"
+ "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
+ " --output_file=output.csv", "neighbor_search");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+ "");
+PARAM_STRING("query_file", "CSV file containing query points (optional).",
+ "", "");
+PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
+ "");
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ string reference_file = CLI::GetParam<string>("reference_file");
+ string output_file = CLI::GetParam<string>("output_file");
+
+ arma::mat reference_data;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (!data::Load(reference_file.c_str(), reference_data))
+ Log::Fatal << "Reference file " << reference_file << " not found." << endl;
+
+ Log::Info << "Loaded reference data from " << reference_file << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ size_t k = CLI::GetParam<int>("neighbor_search/k");
+ if ((k <= 0) || (k >= reference_data.n_cols))
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than the number of reference points (";
+ Log::Fatal << reference_data.n_cols << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+ {
+ Log::Fatal << "Invalid leaf size: "
+ << CLI::GetParam<int>("allknn/leaf_size") << endl;
+ }
+
+ AllkNN* allknn = NULL;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string query_file = CLI::GetParam<string>("query_file");
+ arma::mat query_data;
+
+ if (!data::Load(query_file.c_str(), query_data))
+ Log::Fatal << "Query file " << query_file << " not found" << endl;
+
+ Log::Info << "Query data loaded from " << query_file << endl;
+
+ Log::Info << "Building query and reference trees..." << endl;
+ allknn = new AllkNN(query_data, reference_data);
+
+ }
+ else
+ {
+ Log::Info << "Building reference tree..." << endl;
+ allknn = new AllkNN(reference_data);
+ }
+
+ Log::Info << "Tree(s) built." << endl;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->ComputeNeighbors(neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+ Log::Info << "Exporting results..." << endl;
+
+ // Should be using data::Save or a related function instead of being written
+ // by hand.
+ try
+ {
+ ofstream out(output_file.c_str());
+
+ for (size_t col = 0; col < neighbors.n_cols; col++)
+ {
+ out << col << ", ";
+ for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
+ {
+ out << neighbors(j, col) << ", " << distances(j, col) << ", ";
+ }
+ out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+ }
+
+ out.close();
+ }
+ catch (exception& e)
+ {
+ Log::Fatal << "Error while opening " << output_file << ": " << e.what()
+ << endl;
+ }
+
+ delete allknn;
+}
Deleted: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,244 +0,0 @@
-/**
- * @file neighbor_search.h
- *
- * Defines the AllkNN class to perform all-k-nearest-neighbors on two specified
- * data sets.
- */
-#ifndef __MLPACK_NEIGHBOR_SEARCH_H
-#define __MLPACK_NEIGHBOR_SEARCH_H
-
-#include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "sort_policies/nearest_neighbor_sort.hpp"
-
-namespace mlpack {
-namespace neighbor /** Neighbor-search routines. These include
- * all-nearest-neighbors and all-furthest-neighbors
- * searches. */ {
-
-// Define CLI parameters for the NeighborSearch class.
-PARAM_MODULE("neighbor_search",
- "Parameters for the distance-based neighbor search.");
-PARAM_INT("k", "Number of neighbors to search for.", "neighbor_search", 5);
-PARAM_FLAG("single_mode", "If set, use single-tree mode (instead of "
- "dual-tree).", "neighbor_search");
-PARAM_FLAG("naive_mode", "If set, use naive computations (no trees). This "
- "overrides the single_mode flag.", "neighbor_search");
-
-/**
- * The NeighborSearch class is a template class for performing distance-based
- * neighbor searches. It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy. A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
- *
- * The template parameters Kernel and SortPolicy define the distance function
- * used and the sort function used. More information on those classes can be
- * found in the kernel::ExampleKernel class and the NearestNeighborSort class.
- *
- * This class has several parameters configurable by the CLI interface:
- *
- * @param neighbor_search/k Parameters for the distance-based neighbor search.
- * @param neighbor_search/single_mode If set, single-tree mode will be used
- * (instead of dual-tree mode).
- * @param neighbor_search/naive_mode If set, naive computation will be used (no
- * trees). This overrides the single_mode flag.
- *
- * @tparam Kernel The kernel function; see kernel::ExampleKernel.
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- */
-template<typename Kernel = mlpack::metric::SquaredEuclideanDistance,
- typename SortPolicy = NearestNeighborSort>
-class NeighborSearch {
-
- /**
- * Extra data for each node in the tree. For neighbor searches, each node
- * only needs to store a bound on neighbor distances.
- */
- class QueryStat {
-
- public:
- //! The bound on the node's neighbor distances.
- double bound_;
-
- /**
- * Initialize the statistic with the worst possible distance according to
- * our sorting policy.
- */
- QueryStat() : bound_(SortPolicy::WorstDistance()) { }
- };
-
- /**
- * Simple typedef for the trees, which use a bound and a QueryStat (to store
- * distances for each node). The bound should be configurable...
- */
- typedef tree::BinarySpaceTree<bound::HRectBound<2>, QueryStat> TreeType;
-
- private:
- //! Reference dataset.
- arma::mat references_;
- //! Query dataset (may not be given).
- arma::mat queries_;
-
- //! Instantiation of kernel.
- Kernel kernel_;
-
- //! Pointer to the root of the reference tree.
- TreeType* reference_tree_;
- //! Pointer to the root of the query tree (might not exist).
- TreeType* query_tree_;
-
- //! Permutations of query points during tree building.
- std::vector<size_t> old_from_new_queries_;
- //! Permutations of reference points during tree building.
- std::vector<size_t> old_from_new_references_;
-
- //! Indicates if O(n^2) naive search is being used.
- bool naive_;
- //! Indicates if dual-tree search is being used (opposed to single-tree).
- bool dual_mode_;
-
- //! Number of points in a leaf.
- size_t leaf_size_;
-
- //! Number of neighbors to compute.
- size_t knns_;
-
- //! Total number of pruned nodes during the neighbor search.
- size_t number_of_prunes_;
-
- //! The distance to the candidate nearest neighbor for each query
- arma::mat neighbor_distances_;
-
- //! The indices of the candidate nearest neighbor for each query
- arma::Mat<size_t> neighbor_indices_;
-
- public:
- /**
- * Initialize the NeighborSearch object, passing both a query and reference
- * dataset. An initialized distance metric can be given, for cases where the
- * metric has internal data (i.e. the distance::MahalanobisDistance class).
- *
- * @param queries_in Set of query points.
- * @param references_in Set of reference points.
- * @param alias_matrix If true, alias the passed matrices instead of copying
- * them. While this lowers memory footprint and computational load, the
- * points in the matrices will be rearranged during the tree-building
- * process! Defaults to false.
- * @param kernel An optional instance of the Kernel class.
- */
- NeighborSearch(arma::mat& queries_in, arma::mat& references_in,
- bool alias_matrix = false, Kernel kernel = Kernel());
-
- /**
- * Initialize the NeighborSearch object, passing only one dataset. In this
- * case, the query dataset is equivalent to the reference dataset, with one
- * caveat: for any point, the returned list of neighbors will not include
- * itself. An initialized distance metric can be given, for cases where the
- * metric has internal data (i.e. the distance::MahalanobisDistance class).
- *
- * @param references_in Set of reference points.
- * @param alias_matrix If true, alias the passed matrices instead of copying
- * them. While this lowers memory footprint and computational load, the
- * points in the matrices will be rearranged during the tree-building
- * process! Defaults to false.
- * @param kernel An optional instance of the Kernel class.
- */
- NeighborSearch(arma::mat& references_in, bool alias_matrix = false,
- Kernel kernel = Kernel());
-
- /**
- * Delete the NeighborSearch object. The tree is the only member we are
- * responsible for deleting. The others will take care of themselves.
- */
- ~NeighborSearch();
-
- /**
- * Compute the nearest neighbors and store the output in the given matrices.
- * The matrices will be set to the size of n columns by k rows, where n is the
- * number of points in the query dataset and k is the number of neighbors
- * being searched for.
- *
- * The parameter k is set through the CLI interface, not in the arguments to
- * this method; this allows users to specify k on the command line
- * ("--neighbor_search/k"). See the CLI documentation for further information
- * on how to use this functionality.
- *
- * @param resulting_neighbors Matrix storing lists of neighbors for each query
- * point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void ComputeNeighbors(arma::Mat<size_t>& resulting_neighbors,
- arma::mat& distances);
-
- private:
- /**
- * Perform exhaustive computation between two leaves, comparing every node in
- * the leaf to the other leaf to find the furthest neighbor. The
- * neighbor_indices_ and neighbor_distances_ matrices will be updated with the
- * changed information.
- *
- * @param query_node Node in query tree. This should be a leaf
- * (bottom-level).
- * @param reference_node Node in reference tree. This should be a leaf
- * (bottom-level).
- */
- void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node);
-
- /**
- * Recurse down the trees, computing base case computations when the leaves
- * are reached.
- *
- * @param query_node Node in query tree.
- * @param reference_node Node in reference tree.
- * @param lower_bound The lower bound; if above this, we can prune.
- */
- void ComputeDualNeighborsRecursion_(TreeType* query_node,
- TreeType* reference_node,
- double lower_bound);
-
- /**
- * Perform a recursion only on the reference tree; the query point is given.
- * This method is similar to ComputeBaseCase_().
- *
- * @param point_id Index of query point.
- * @param point The query point.
- * @param reference_node Reference node.
- * @param best_dist_so_far Best distance to a node so far -- used for pruning.
- */
- void ComputeSingleNeighborsRecursion_(size_t point_id, arma::vec& point,
- TreeType* reference_node,
- double& best_dist_so_far);
-
- /**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
- *
- * @param query_index Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
- void InsertNeighbor(size_t query_index, size_t pos, size_t neighbor,
- double distance);
-
-}; // class NeighborSearch
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "neighbor_search_impl.h"
-
-// Include convenience typedefs.
-#include "typedef.h"
-
-#endif
Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp (from rev 10353, mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -0,0 +1,245 @@
+/**
+ * @file neighbor_search.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the AllkNN class to perform all-k-nearest-neighbors on two specified
+ * data sets.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+
+#include <mlpack/core.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <vector>
+#include <string>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "sort_policies/nearest_neighbor_sort.hpp"
+
+namespace mlpack {
+namespace neighbor /** Neighbor-search routines. These include
+ * all-nearest-neighbors and all-furthest-neighbors
+ * searches. */ {
+
+// Define CLI parameters for the NeighborSearch class.
+PARAM_MODULE("neighbor_search",
+ "Parameters for the distance-based neighbor search.");
+PARAM_INT("k", "Number of neighbors to search for.", "neighbor_search", 5);
+PARAM_FLAG("single_mode", "If set, use single-tree mode (instead of "
+ "dual-tree).", "neighbor_search");
+PARAM_FLAG("naive_mode", "If set, use naive computations (no trees). This "
+ "overrides the single_mode flag.", "neighbor_search");
+
+/**
+ * The NeighborSearch class is a template class for performing distance-based
+ * neighbor searches. It takes a query dataset and a reference dataset (or just
+ * a reference dataset) and, for each point in the query dataset, finds the k
+ * neighbors in the reference dataset which have the 'best' distance according
+ * to a given sorting policy. A constructor is given which takes only a
+ * reference dataset, and if that constructor is used, the given reference
+ * dataset is also used as the query dataset.
+ *
+ * The template parameters Kernel and SortPolicy define the distance function
+ * used and the sort function used. More information on those classes can be
+ * found in the kernel::ExampleKernel class and the NearestNeighborSort class.
+ *
+ * This class has several parameters configurable by the CLI interface:
+ *
+ * @param neighbor_search/k Parameters for the distance-based neighbor search.
+ * @param neighbor_search/single_mode If set, single-tree mode will be used
+ * (instead of dual-tree mode).
+ * @param neighbor_search/naive_mode If set, naive computation will be used (no
+ * trees). This overrides the single_mode flag.
+ *
+ * @tparam Kernel The kernel function; see kernel::ExampleKernel.
+ * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ */
+template<typename Kernel = mlpack::metric::SquaredEuclideanDistance,
+ typename SortPolicy = NearestNeighborSort>
+class NeighborSearch
+{
+ /**
+ * Extra data for each node in the tree. For neighbor searches, each node
+ * only needs to store a bound on neighbor distances.
+ */
+ class QueryStat
+ {
+ public:
+ //! The bound on the node's neighbor distances.
+ double bound_;
+
+ /**
+ * Initialize the statistic with the worst possible distance according to
+ * our sorting policy.
+ */
+ QueryStat() : bound_(SortPolicy::WorstDistance()) { }
+ };
+
+ /**
+ * Simple typedef for the trees, which use a bound and a QueryStat (to store
+ * distances for each node). The bound should be configurable...
+ */
+ typedef tree::BinarySpaceTree<bound::HRectBound<2>, QueryStat> TreeType;
+
+ private:
+ //! Reference dataset.
+ arma::mat references_;
+ //! Query dataset (may not be given).
+ arma::mat queries_;
+
+ //! Instantiation of kernel.
+ Kernel kernel_;
+
+ //! Pointer to the root of the reference tree.
+ TreeType* reference_tree_;
+ //! Pointer to the root of the query tree (might not exist).
+ TreeType* query_tree_;
+
+ //! Permutations of query points during tree building.
+ std::vector<size_t> old_from_new_queries_;
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> old_from_new_references_;
+
+ //! Indicates if O(n^2) naive search is being used.
+ bool naive_;
+ //! Indicates if dual-tree search is being used (opposed to single-tree).
+ bool dual_mode_;
+
+ //! Number of points in a leaf.
+ size_t leaf_size_;
+
+ //! Number of neighbors to compute.
+ size_t knns_;
+
+ //! Total number of pruned nodes during the neighbor search.
+ size_t number_of_prunes_;
+
+ //! The distance to the candidate nearest neighbor for each query
+ arma::mat neighbor_distances_;
+
+ //! The indices of the candidate nearest neighbor for each query
+ arma::Mat<size_t> neighbor_indices_;
+
+ public:
+ /**
+ * Initialize the NeighborSearch object, passing both a query and reference
+ * dataset. An initialized distance metric can be given, for cases where the
+ * metric has internal data (i.e. the distance::MahalanobisDistance class).
+ *
+ * @param queries_in Set of query points.
+ * @param references_in Set of reference points.
+ * @param alias_matrix If true, alias the passed matrices instead of copying
+ * them. While this lowers memory footprint and computational load, the
+ * points in the matrices will be rearranged during the tree-building
+ * process! Defaults to false.
+ * @param kernel An optional instance of the Kernel class.
+ */
+ NeighborSearch(arma::mat& queries_in, arma::mat& references_in,
+ bool alias_matrix = false, Kernel kernel = Kernel());
+
+ /**
+ * Initialize the NeighborSearch object, passing only one dataset. In this
+ * case, the query dataset is equivalent to the reference dataset, with one
+ * caveat: for any point, the returned list of neighbors will not include
+ * itself. An initialized distance metric can be given, for cases where the
+ * metric has internal data (i.e. the distance::MahalanobisDistance class).
+ *
+ * @param references_in Set of reference points.
+ * @param alias_matrix If true, alias the passed matrices instead of copying
+ * them. While this lowers memory footprint and computational load, the
+ * points in the matrices will be rearranged during the tree-building
+ * process! Defaults to false.
+ * @param kernel An optional instance of the Kernel class.
+ */
+ NeighborSearch(arma::mat& references_in, bool alias_matrix = false,
+ Kernel kernel = Kernel());
+
+ /**
+ * Delete the NeighborSearch object. The tree is the only member we are
+ * responsible for deleting. The others will take care of themselves.
+ */
+ ~NeighborSearch();
+
+ /**
+ * Compute the nearest neighbors and store the output in the given matrices.
+ * The matrices will be set to the size of n columns by k rows, where n is the
+ * number of points in the query dataset and k is the number of neighbors
+ * being searched for.
+ *
+ * The parameter k is set through the CLI interface, not in the arguments to
+ * this method; this allows users to specify k on the command line
+ * ("--neighbor_search/k"). See the CLI documentation for further information
+ * on how to use this functionality.
+ *
+ * @param resulting_neighbors Matrix storing lists of neighbors for each query
+ * point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ */
+ void ComputeNeighbors(arma::Mat<size_t>& resulting_neighbors,
+ arma::mat& distances);
+
+ private:
+ /**
+ * Perform exhaustive computation between two leaves, comparing every node in
+ * the leaf to the other leaf to find the furthest neighbor. The
+ * neighbor_indices_ and neighbor_distances_ matrices will be updated with the
+ * changed information.
+ *
+ * @param query_node Node in query tree. This should be a leaf
+ * (bottom-level).
+ * @param reference_node Node in reference tree. This should be a leaf
+ * (bottom-level).
+ */
+ void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node);
+
+ /**
+ * Recurse down the trees, computing base case computations when the leaves
+ * are reached.
+ *
+ * @param query_node Node in query tree.
+ * @param reference_node Node in reference tree.
+ * @param lower_bound The lower bound; if above this, we can prune.
+ */
+ void ComputeDualNeighborsRecursion_(TreeType* query_node,
+ TreeType* reference_node,
+ double lower_bound);
+
+ /**
+ * Perform a recursion only on the reference tree; the query point is given.
+ * This method is similar to ComputeBaseCase_().
+ *
+ * @param point_id Index of query point.
+ * @param point The query point.
+ * @param reference_node Reference node.
+ * @param best_dist_so_far Best distance to a node so far -- used for pruning.
+ */
+ void ComputeSingleNeighborsRecursion_(size_t point_id, arma::vec& point,
+ TreeType* reference_node,
+ double& best_dist_so_far);
+
+ /**
+ * Insert a point into the neighbors and distances matrices; this is a helper
+ * function.
+ *
+ * @param query_index Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+ void InsertNeighbor(size_t query_index, size_t pos, size_t neighbor,
+ double distance);
+
+}; // class NeighborSearch
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "neighbor_search_impl.hpp"
+
+// Include convenience typedefs.
+#include "typedef.hpp"
+
+#endif
Deleted: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,464 +0,0 @@
-/**
- * @file neighbor_search.cc
- *
- * Implementation of AllkNN class to perform all-nearest-neighbors on two
- * specified data sets.
- */
-#ifndef __MLPACK_NEIGHBOR_SEARCH_IMPL_H
-#define __MLPACK_NEIGHBOR_SEARCH_IMPL_H
-
-#include <mlpack/core.h>
-
-using namespace mlpack::neighbor;
-
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& queries_in,
- arma::mat& references_in,
- bool alias_matrix,
- Kernel kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(queries_in.memptr(), queries_in.n_rows, queries_in.n_cols,
- !alias_matrix),
- kernel_(kernel),
- naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
- dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
- number_of_prunes_(0) {
-
- // C++0x will allow us to call out to other constructors so we can avoid this
- // copypasta problem.
-
- // Get the leaf size; naive ensures that the entire tree is one node
- if (naive_)
- CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
-
- // K-nearest neighbors initialization
- knns_ = CLI::GetParam<int>("neighbor_search/k");
-
- // Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, queries_.n_cols);
-
- // Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, queries_.n_cols);
- neighbor_distances_.fill(SortPolicy::WorstDistance());
-
- // We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
-
- // This call makes each tree from a matrix, leaf size, and two arrays
- // that record the permutation of the data points
- query_tree_ = new TreeType(queries_, old_from_new_queries_);
- reference_tree_ = new TreeType(references_, old_from_new_references_);
-
- // Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
-}
-
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& references_in,
- bool alias_matrix,
- Kernel kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(references_.memptr(), references_.n_rows, references_.n_cols,
- false),
- kernel_(kernel),
- naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
- dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
- number_of_prunes_(0) {
-
- // Get the leaf size from the module
- if (naive_)
- CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
-
- // K-nearest neighbors initialization
- knns_ = CLI::GetParam<int>("neighbor_search/k");
-
- // Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, references_.n_cols);
-
- // Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, references_.n_cols);
- neighbor_distances_.fill(SortPolicy::WorstDistance());
-
- // We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
-
- // This call makes each tree from a matrix, leaf size, and two arrays
- // that record the permutation of the data points
- // Instead of NULL, it is possible to specify an array new_from_old_
- query_tree_ = NULL;
- reference_tree_ = new TreeType(references_, old_from_new_references_);
-
- // Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
-}
-
-/**
- * The tree is the only member we are responsible for deleting. The others will
- * take care of themselves.
- */
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::~NeighborSearch() {
- if (reference_tree_ != query_tree_)
- delete reference_tree_;
- if (query_tree_ != NULL)
- delete query_tree_;
-}
-
-/**
- * Performs exhaustive computation between two leaves.
- */
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeBaseCase_(
- TreeType* query_node,
- TreeType* reference_node) {
- // Used to find the query node's new upper bound
- double query_worst_distance = SortPolicy::BestDistance();
-
- // node->begin() is the index of the first point in the node,
- // node->end is one past the last index
- for (size_t query_index = query_node->begin();
- query_index < query_node->end(); query_index++) {
-
- // Get the query point from the matrix
- arma::vec query_point = queries_.unsafe_col(query_index);
-
- double query_to_node_distance =
- SortPolicy::BestPointToNodeDistance(query_point, reference_node);
-
- if (SortPolicy::IsBetter(query_to_node_distance,
- neighbor_distances_(knns_ - 1, query_index))) {
- // We'll do the same for the references
- for (size_t reference_index = reference_node->begin();
- reference_index < reference_node->end(); reference_index++) {
-
- // Confirm that points do not identify themselves as neighbors
- // in the monochromatic case
- if (reference_node != query_node || reference_index != query_index) {
- arma::vec reference_point = references_.unsafe_col(reference_index);
-
- double distance = kernel_.Evaluate(query_point, reference_point);
-
- // If the reference point is closer than any of the current
- // candidates, add it to the list.
- arma::vec query_dist = neighbor_distances_.unsafe_col(query_index);
- size_t insert_position = SortPolicy::SortDistance(query_dist,
- distance);
-
- if (insert_position != (size_t() - 1)) {
- InsertNeighbor(query_index, insert_position, reference_index,
- distance);
- }
- }
- }
- }
-
- // We need to find the upper bound distance for this query node
- if (SortPolicy::IsBetter(query_worst_distance,
- neighbor_distances_(knns_ - 1, query_index)))
- query_worst_distance = neighbor_distances_(knns_ - 1, query_index);
- }
-
- // Update the upper bound for the query_node
- query_node->stat().bound_ = query_worst_distance;
-
-} // ComputeBaseCase_
-
-/**
- * The recursive function for dual tree
- */
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeDualNeighborsRecursion_(
- TreeType* query_node,
- TreeType* reference_node,
- double lower_bound) {
-
- if (SortPolicy::IsBetter(query_node->stat().bound_, lower_bound)) {
- number_of_prunes_++; // Pruned by distance; the nodes cannot be any closer
- return; // than the already established lower bound.
- }
-
- if (query_node->is_leaf() && reference_node->is_leaf()) {
- ComputeBaseCase_(query_node, reference_node); // Base case: both are leaves.
- return;
- }
-
- if (query_node->is_leaf()) {
- // We must keep descending down the reference node to get to a leaf.
-
- // We'll order the computation by distance; descend in the direction of less
- // distance first.
- double left_distance = SortPolicy::BestNodeToNodeDistance(query_node,
- reference_node->left());
- double right_distance = SortPolicy::BestNodeToNodeDistance(query_node,
- reference_node->right());
-
- if (SortPolicy::IsBetter(left_distance, right_distance)) {
- ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
- right_distance);
- } else {
- ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
- left_distance);
- }
- return;
- }
-
- if (reference_node->is_leaf()) {
- // We must descend down the query node to get to a leaf.
- double left_distance = SortPolicy::BestNodeToNodeDistance(
- query_node->left(), reference_node);
- double right_distance = SortPolicy::BestNodeToNodeDistance(
- query_node->right(), reference_node);
-
- ComputeDualNeighborsRecursion_(query_node->left(), reference_node,
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->right(), reference_node,
- right_distance);
-
- // We need to update the upper bound based on the new upper bounds of
- // the children
- double left_bound = query_node->left()->stat().bound_;
- double right_bound = query_node->right()->stat().bound_;
-
- if (SortPolicy::IsBetter(left_bound, right_bound))
- query_node->stat().bound_ = right_bound;
- else
- query_node->stat().bound_ = left_bound;
-
- return;
- }
-
- // Neither side is a leaf; so we recurse on all combinations of both. The
- // calculations are ordered by distance.
- double left_distance = SortPolicy::BestNodeToNodeDistance(query_node->left(),
- reference_node->left());
- double right_distance = SortPolicy::BestNodeToNodeDistance(query_node->left(),
- reference_node->right());
-
- // Recurse on query_node->left() first.
- if (SortPolicy::IsBetter(left_distance, right_distance)) {
- ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
- right_distance);
- } else {
- ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
- left_distance);
- }
-
- left_distance = SortPolicy::BestNodeToNodeDistance(query_node->right(),
- reference_node->left());
- right_distance = SortPolicy::BestNodeToNodeDistance(query_node->right(),
- reference_node->right());
-
- // Now recurse on query_node->right().
- if (SortPolicy::IsBetter(left_distance, right_distance)) {
- ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
- right_distance);
- } else {
- ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
- left_distance);
- }
-
- // Update the upper bound as above
- double left_bound = query_node->left()->stat().bound_;
- double right_bound = query_node->right()->stat().bound_;
-
- if (SortPolicy::IsBetter(left_bound, right_bound))
- query_node->stat().bound_ = right_bound;
- else
- query_node->stat().bound_ = left_bound;
-
-} // ComputeDualNeighborsRecursion_
-
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeSingleNeighborsRecursion_(
- size_t point_id,
- arma::vec& point,
- TreeType* reference_node,
- double& best_dist_so_far) {
-
- if (reference_node->is_leaf()) {
- // Base case: reference node is a leaf
-
- for (size_t reference_index = reference_node->begin();
- reference_index < reference_node->end(); reference_index++) {
- // Confirm that points do not identify themselves as neighbors
- // in the monochromatic case
- if (!(references_.memptr() == queries_.memptr() &&
- reference_index == point_id)) {
- arma::vec reference_point = references_.unsafe_col(reference_index);
-
- double distance = kernel_.Evaluate(point, reference_point);
-
- // If the reference point is better than any of the current candidates,
- // insert it into the list correctly.
- arma::vec query_dist = neighbor_distances_.unsafe_col(point_id);
- size_t insert_position = SortPolicy::SortDistance(query_dist,
- distance);
-
- if (insert_position != (size_t() - 1))
- InsertNeighbor(point_id, insert_position, reference_index, distance);
- }
- } // for reference_index
-
- best_dist_so_far = neighbor_distances_(knns_ - 1, point_id);
- } else {
- // We'll order the computation by distance.
- double left_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->left());
- double right_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->right());
-
- // Recurse in the best direction first.
- if (SortPolicy::IsBetter(left_distance, right_distance)) {
- if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->left(), best_dist_so_far);
-
- if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->right(), best_dist_so_far);
-
- } else {
- if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->right(), best_dist_so_far);
-
- if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->left(), best_dist_so_far);
- }
- }
-}
-
-/**
- * Computes the best neighbors and stores them in resulting_neighbors and
- * distances.
- */
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeNeighbors(
- arma::Mat<size_t>& resulting_neighbors,
- arma::mat& distances) {
-
- Timers::StartTimer("neighbor_search/computing_neighbors");
- if (naive_) {
- // Run the base case computation on all nodes
- if (query_tree_)
- ComputeBaseCase_(query_tree_, reference_tree_);
- else
- ComputeBaseCase_(reference_tree_, reference_tree_);
- } else {
- if (dual_mode_) {
- // Start on the root of each tree
- if (query_tree_) {
- ComputeDualNeighborsRecursion_(query_tree_, reference_tree_,
- SortPolicy::BestNodeToNodeDistance(query_tree_, reference_tree_));
- } else {
- ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_,
- SortPolicy::BestNodeToNodeDistance(reference_tree_,
- reference_tree_));
- }
- } else {
- size_t chunk = queries_.n_cols / 10;
-
- for(size_t i = 0; i < 10; i++) {
- for(size_t j = 0; j < chunk; j++) {
- arma::vec point = queries_.unsafe_col(i * chunk + j);
- double best_dist_so_far = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion_(i * chunk + j, point,
- reference_tree_, best_dist_so_far);
- }
- }
- for(size_t i = 0; i < queries_.n_cols % 10; i++) {
- size_t ind = (queries_.n_cols / 10) * 10 + i;
- arma::vec point = queries_.unsafe_col(ind);
- double best_dist_so_far = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion_(ind, point, reference_tree_,
- best_dist_so_far);
- }
- }
- }
-
- Timers::StopTimer("neighbor_search/computing_neighbors");
-
- // We need to initialize the results list before filling it
- resulting_neighbors.set_size(neighbor_indices_.n_rows,
- neighbor_indices_.n_cols);
- distances.set_size(neighbor_distances_.n_rows, neighbor_distances_.n_cols);
-
- // We need to map the indices back from how they have been permuted
- if (query_tree_ != NULL) {
- for (size_t i = 0; i < neighbor_indices_.n_cols; i++) {
- for (size_t k = 0; k < neighbor_indices_.n_rows; k++) {
- resulting_neighbors(k, old_from_new_queries_[i]) =
- old_from_new_references_[neighbor_indices_(k, i)];
- distances(k, old_from_new_queries_[i]) = neighbor_distances_(k, i);
- }
- }
- } else {
- for (size_t i = 0; i < neighbor_indices_.n_cols; i++) {
- for (size_t k = 0; k < neighbor_indices_.n_rows; k++) {
- resulting_neighbors(k, old_from_new_references_[i]) =
- old_from_new_references_[neighbor_indices_(k, i)];
- distances(k, old_from_new_references_[i]) = neighbor_distances_(k, i);
- }
- }
- }
-} // ComputeNeighbors
-
-/***
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param query_index Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::InsertNeighbor(size_t query_index,
- size_t pos,
- size_t neighbor,
- double distance) {
- // We only memmove() if there is actually a need to shift something.
- if (pos < (knns_ - 1)) {
- int len = (knns_ - 1) - pos;
- memmove(neighbor_distances_.colptr(query_index) + (pos + 1),
- neighbor_distances_.colptr(query_index) + pos,
- sizeof(double) * len);
- memmove(neighbor_indices_.colptr(query_index) + (pos + 1),
- neighbor_indices_.colptr(query_index) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- neighbor_distances_(pos, query_index) = distance;
- neighbor_indices_(pos, query_index) = neighbor;
-}
-
-#endif
Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp (from rev 10352, mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -0,0 +1,510 @@
+/**
+ * @file neighbor_search_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of Neighbor-Search class to perform all-nearest-neighbors on
+ * two specified data sets.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
+
+#include <mlpack/core.h>
+
+using namespace mlpack::neighbor;
+
+// We call an advanced constructor of arma::mat which allows us to alias a
+// matrix (if the user has asked for that).
+template<typename Kernel, typename SortPolicy>
+NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& queries_in,
+ arma::mat& references_in,
+ bool alias_matrix,
+ Kernel kernel) :
+ references_(references_in.memptr(), references_in.n_rows,
+ references_in.n_cols, !alias_matrix),
+ queries_(queries_in.memptr(), queries_in.n_rows, queries_in.n_cols,
+ !alias_matrix),
+ kernel_(kernel),
+ naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
+ dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
+ number_of_prunes_(0)
+{
+ // C++11 will allow us to call out to other constructors so we can avoid this
+ // copypasta problem.
+
+ // Get the leaf size; naive ensures that the entire tree is one node
+ if (naive_)
+ CLI::GetParam<int>("tree/leaf_size") =
+ std::max(queries_.n_cols, references_.n_cols);
+
+ // K-nearest neighbors initialization
+ knns_ = CLI::GetParam<int>("neighbor_search/k");
+
+ // Initialize the list of nearest neighbor candidates
+ neighbor_indices_.set_size(knns_, queries_.n_cols);
+
+ // Initialize the vector of upper bounds for each point.
+ neighbor_distances_.set_size(knns_, queries_.n_cols);
+ neighbor_distances_.fill(SortPolicy::WorstDistance());
+
+ // We'll time tree building
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ // This call makes each tree from a matrix, leaf size, and two arrays
+ // that record the permutation of the data points
+ query_tree_ = new TreeType(queries_, old_from_new_queries_);
+ reference_tree_ = new TreeType(references_, old_from_new_references_);
+
+ // Stop the timer we started above
+ Timers::StopTimer("neighbor_search/tree_building");
+}
+
+// We call an advanced constructor of arma::mat which allows us to alias a
+// matrix (if the user has asked for that).
+template<typename Kernel, typename SortPolicy>
+NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& references_in,
+ bool alias_matrix,
+ Kernel kernel) :
+ references_(references_in.memptr(), references_in.n_rows,
+ references_in.n_cols, !alias_matrix),
+ queries_(references_.memptr(), references_.n_rows, references_.n_cols,
+ false),
+ kernel_(kernel),
+ naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
+ dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
+ number_of_prunes_(0)
+{
+ // Get the leaf size from the module
+ if (naive_)
+ CLI::GetParam<int>("tree/leaf_size") =
+ std::max(queries_.n_cols, references_.n_cols);
+
+ // K-nearest neighbors initialization
+ knns_ = CLI::GetParam<int>("neighbor_search/k");
+
+ // Initialize the list of nearest neighbor candidates
+ neighbor_indices_.set_size(knns_, references_.n_cols);
+
+ // Initialize the vector of upper bounds for each point.
+ neighbor_distances_.set_size(knns_, references_.n_cols);
+ neighbor_distances_.fill(SortPolicy::WorstDistance());
+
+ // We'll time tree building
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ // This call makes each tree from a matrix, leaf size, and two arrays
+ // that record the permutation of the data points
+ // Instead of NULL, it is possible to specify an array new_from_old_
+ query_tree_ = NULL;
+ reference_tree_ = new TreeType(references_, old_from_new_references_);
+
+ // Stop the timer we started above
+ Timers::StopTimer("neighbor_search/tree_building");
+}
+
+/**
+ * The tree is the only member we are responsible for deleting. The others will
+ * take care of themselves.
+ */
+template<typename Kernel, typename SortPolicy>
+NeighborSearch<Kernel, SortPolicy>::~NeighborSearch()
+{
+ if (reference_tree_ != query_tree_)
+ delete reference_tree_;
+ if (query_tree_ != NULL)
+ delete query_tree_;
+}
+
+/**
+ * Performs exhaustive computation between two leaves.
+ */
+template<typename Kernel, typename SortPolicy>
+void NeighborSearch<Kernel, SortPolicy>::ComputeBaseCase_(
+ TreeType* query_node,
+ TreeType* reference_node)
+{
+ // Used to find the query node's new upper bound.
+ double query_worst_distance = SortPolicy::BestDistance();
+
+ // node->begin() is the index of the first point in the node,
+ // node->end is one past the last index.
+ for (size_t query_index = query_node->begin();
+ query_index < query_node->end(); query_index++)
+ {
+ // Get the query point from the matrix.
+ arma::vec query_point = queries_.unsafe_col(query_index);
+
+ double query_to_node_distance =
+ SortPolicy::BestPointToNodeDistance(query_point, reference_node);
+
+ if (SortPolicy::IsBetter(query_to_node_distance,
+ neighbor_distances_(knns_ - 1, query_index)))
+ {
+ // We'll do the same for the references.
+ for (size_t reference_index = reference_node->begin();
+ reference_index < reference_node->end(); reference_index++)
+ {
+ // Confirm that points do not identify themselves as neighbors
+ // in the monochromatic case.
+ if (reference_node != query_node || reference_index != query_index)
+ {
+ arma::vec reference_point = references_.unsafe_col(reference_index);
+
+ double distance = kernel_.Evaluate(query_point, reference_point);
+
+ // If the reference point is closer than any of the current
+ // candidates, add it to the list.
+ arma::vec query_dist = neighbor_distances_.unsafe_col(query_index);
+ size_t insert_position = SortPolicy::SortDistance(query_dist,
+ distance);
+
+ if (insert_position != (size_t() - 1))
+ InsertNeighbor(query_index, insert_position, reference_index,
+ distance);
+ }
+ }
+ }
+
+ // We need to find the upper bound distance for this query node
+ if (SortPolicy::IsBetter(query_worst_distance,
+ neighbor_distances_(knns_ - 1, query_index)))
+ query_worst_distance = neighbor_distances_(knns_ - 1, query_index);
+ }
+
+ // Update the upper bound for the query_node
+ query_node->stat().bound_ = query_worst_distance;
+
+} // ComputeBaseCase_
+
+/**
+ * The recursive function for dual tree.
+ */
+template<typename Kernel, typename SortPolicy>
+void NeighborSearch<Kernel, SortPolicy>::ComputeDualNeighborsRecursion_(
+ TreeType* query_node,
+ TreeType* reference_node,
+ double lower_bound)
+{
+ if (SortPolicy::IsBetter(query_node->stat().bound_, lower_bound))
+ {
+ number_of_prunes_++; // Pruned by distance; the nodes cannot be any closer
+ return; // than the already established lower bound.
+ }
+
+ if (query_node->is_leaf() && reference_node->is_leaf())
+ {
+ ComputeBaseCase_(query_node, reference_node); // Base case: both are leaves.
+ return;
+ }
+
+ if (query_node->is_leaf())
+ {
+ // We must keep descending down the reference node to get to a leaf.
+
+ // We'll order the computation by distance; descend in the direction of less
+ // distance first.
+ double left_distance = SortPolicy::BestNodeToNodeDistance(query_node,
+ reference_node->left());
+ double right_distance = SortPolicy::BestNodeToNodeDistance(query_node,
+ reference_node->right());
+
+ if (SortPolicy::IsBetter(left_distance, right_distance))
+ {
+ ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
+ left_distance);
+ ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
+ right_distance);
+ }
+ else
+ {
+ ComputeDualNeighborsRecursion_(query_node, reference_node->right(),
+ right_distance);
+ ComputeDualNeighborsRecursion_(query_node, reference_node->left(),
+ left_distance);
+ }
+ return;
+ }
+
+ if (reference_node->is_leaf())
+ {
+ // We must descend down the query node to get to a leaf.
+ double left_distance = SortPolicy::BestNodeToNodeDistance(
+ query_node->left(), reference_node);
+ double right_distance = SortPolicy::BestNodeToNodeDistance(
+ query_node->right(), reference_node);
+
+ ComputeDualNeighborsRecursion_(query_node->left(), reference_node,
+ left_distance);
+ ComputeDualNeighborsRecursion_(query_node->right(), reference_node,
+ right_distance);
+
+ // We need to update the upper bound based on the new upper bounds of
+ // the children
+ double left_bound = query_node->left()->stat().bound_;
+ double right_bound = query_node->right()->stat().bound_;
+
+ if (SortPolicy::IsBetter(left_bound, right_bound))
+ query_node->stat().bound_ = right_bound;
+ else
+ query_node->stat().bound_ = left_bound;
+
+ return;
+ }
+
+ // Neither side is a leaf; so we recurse on all combinations of both. The
+ // calculations are ordered by distance.
+ double left_distance = SortPolicy::BestNodeToNodeDistance(query_node->left(),
+ reference_node->left());
+ double right_distance = SortPolicy::BestNodeToNodeDistance(query_node->left(),
+ reference_node->right());
+
+ // Recurse on query_node->left() first.
+ if (SortPolicy::IsBetter(left_distance, right_distance))
+ {
+ ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
+ left_distance);
+ ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
+ right_distance);
+ }
+ else
+ {
+ ComputeDualNeighborsRecursion_(query_node->left(), reference_node->right(),
+ right_distance);
+ ComputeDualNeighborsRecursion_(query_node->left(), reference_node->left(),
+ left_distance);
+ }
+
+ left_distance = SortPolicy::BestNodeToNodeDistance(query_node->right(),
+ reference_node->left());
+ right_distance = SortPolicy::BestNodeToNodeDistance(query_node->right(),
+ reference_node->right());
+
+ // Now recurse on query_node->right().
+ if (SortPolicy::IsBetter(left_distance, right_distance))
+ {
+ ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
+ left_distance);
+ ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
+ right_distance);
+ }
+ else
+ {
+ ComputeDualNeighborsRecursion_(query_node->right(), reference_node->right(),
+ right_distance);
+ ComputeDualNeighborsRecursion_(query_node->right(), reference_node->left(),
+ left_distance);
+ }
+
+ // Update the upper bound as above
+ double left_bound = query_node->left()->stat().bound_;
+ double right_bound = query_node->right()->stat().bound_;
+
+ if (SortPolicy::IsBetter(left_bound, right_bound))
+ query_node->stat().bound_ = right_bound;
+ else
+ query_node->stat().bound_ = left_bound;
+
+} // ComputeDualNeighborsRecursion_
+
+template<typename Kernel, typename SortPolicy>
+void NeighborSearch<Kernel, SortPolicy>::ComputeSingleNeighborsRecursion_(
+ size_t point_id,
+ arma::vec& point,
+ TreeType* reference_node,
+ double& best_dist_so_far)
+{
+ if (reference_node->is_leaf())
+ {
+ // Base case: reference node is a leaf.
+ for (size_t reference_index = reference_node->begin();
+ reference_index < reference_node->end(); reference_index++)
+ {
+ // Confirm that points do not identify themselves as neighbors
+ // in the monochromatic case
+ if (!(references_.memptr() == queries_.memptr() &&
+ reference_index == point_id))
+ {
+ arma::vec reference_point = references_.unsafe_col(reference_index);
+
+ double distance = kernel_.Evaluate(point, reference_point);
+
+ // If the reference point is better than any of the current candidates,
+ // insert it into the list correctly.
+ arma::vec query_dist = neighbor_distances_.unsafe_col(point_id);
+ size_t insert_position = SortPolicy::SortDistance(query_dist,
+ distance);
+
+ if (insert_position != (size_t() - 1))
+ InsertNeighbor(point_id, insert_position, reference_index, distance);
+ }
+ } // for reference_index
+
+ best_dist_so_far = neighbor_distances_(knns_ - 1, point_id);
+ }
+ else
+ {
+ // We'll order the computation by distance.
+ double left_distance = SortPolicy::BestPointToNodeDistance(point,
+ reference_node->left());
+ double right_distance = SortPolicy::BestPointToNodeDistance(point,
+ reference_node->right());
+
+ // Recurse in the best direction first.
+ if (SortPolicy::IsBetter(left_distance, right_distance))
+ {
+ if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
+ number_of_prunes_++; // Prune; no possibility of finding a better point.
+ else
+ ComputeSingleNeighborsRecursion_(point_id, point,
+ reference_node->left(), best_dist_so_far);
+
+ if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
+ number_of_prunes_++; // Prune; no possibility of finding a better point.
+ else
+ ComputeSingleNeighborsRecursion_(point_id, point,
+ reference_node->right(), best_dist_so_far);
+
+ }
+ else
+ {
+ if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
+ number_of_prunes_++; // Prune; no possibility of finding a better point.
+ else
+ ComputeSingleNeighborsRecursion_(point_id, point,
+ reference_node->right(), best_dist_so_far);
+
+ if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
+ number_of_prunes_++; // Prune; no possibility of finding a better point.
+ else
+ ComputeSingleNeighborsRecursion_(point_id, point,
+ reference_node->left(), best_dist_so_far);
+ }
+ }
+}
+
+/**
+ * Computes the best neighbors and stores them in resulting_neighbors and
+ * distances.
+ */
+template<typename Kernel, typename SortPolicy>
+void NeighborSearch<Kernel, SortPolicy>::ComputeNeighbors(
+ arma::Mat<size_t>& resulting_neighbors,
+ arma::mat& distances)
+{
+ Timers::StartTimer("neighbor_search/computing_neighbors");
+ if (naive_)
+ {
+ // Run the base case computation on all nodes
+ if (query_tree_)
+ ComputeBaseCase_(query_tree_, reference_tree_);
+ else
+ ComputeBaseCase_(reference_tree_, reference_tree_);
+ }
+ else
+ {
+ if (dual_mode_)
+ {
+ // Start on the root of each tree
+ if (query_tree_)
+ {
+ ComputeDualNeighborsRecursion_(query_tree_, reference_tree_,
+ SortPolicy::BestNodeToNodeDistance(query_tree_, reference_tree_));
+ }
+ else
+ {
+ ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_,
+ SortPolicy::BestNodeToNodeDistance(reference_tree_,
+ reference_tree_));
+ }
+ }
+ else
+ {
+ size_t chunk = queries_.n_cols / 10;
+
+ for (size_t i = 0; i < 10; i++)
+ {
+ for (size_t j = 0; j < chunk; j++)
+ {
+ arma::vec point = queries_.unsafe_col(i * chunk + j);
+ double best_dist_so_far = SortPolicy::WorstDistance();
+ ComputeSingleNeighborsRecursion_(i * chunk + j, point,
+ reference_tree_, best_dist_so_far);
+ }
+ }
+
+ for (size_t i = 0; i < queries_.n_cols % 10; i++)
+ {
+ size_t ind = (queries_.n_cols / 10) * 10 + i;
+ arma::vec point = queries_.unsafe_col(ind);
+ double best_dist_so_far = SortPolicy::WorstDistance();
+ ComputeSingleNeighborsRecursion_(ind, point, reference_tree_,
+ best_dist_so_far);
+ }
+ }
+ }
+
+ Timers::StopTimer("neighbor_search/computing_neighbors");
+
+ // We need to initialize the results list before filling it
+ resulting_neighbors.set_size(neighbor_indices_.n_rows,
+ neighbor_indices_.n_cols);
+ distances.set_size(neighbor_distances_.n_rows, neighbor_distances_.n_cols);
+
+ // We need to map the indices back from how they have been permuted
+ if (query_tree_ != NULL)
+ {
+ for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
+ {
+ for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
+ {
+ resulting_neighbors(k, old_from_new_queries_[i]) =
+ old_from_new_references_[neighbor_indices_(k, i)];
+ distances(k, old_from_new_queries_[i]) = neighbor_distances_(k, i);
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
+ {
+ for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
+ {
+ resulting_neighbors(k, old_from_new_references_[i]) =
+ old_from_new_references_[neighbor_indices_(k, i)];
+ distances(k, old_from_new_references_[i]) = neighbor_distances_(k, i);
+ }
+ }
+ }
+} // ComputeNeighbors
+
+/***
+ * Helper function to insert a point into the neighbors and distances matrices.
+ *
+ * @param query_index Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename Kernel, typename SortPolicy>
+void NeighborSearch<Kernel, SortPolicy>::InsertNeighbor(size_t query_index,
+ size_t pos,
+ size_t neighbor,
+ double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (knns_ - 1))
+ {
+ int len = (knns_ - 1) - pos;
+ memmove(neighbor_distances_.colptr(query_index) + (pos + 1),
+ neighbor_distances_.colptr(query_index) + pos,
+ sizeof(double) * len);
+ memmove(neighbor_indices_.colptr(query_index) + (pos + 1),
+ neighbor_indices_.colptr(query_index) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ neighbor_distances_(pos, query_index) = distance;
+ neighbor_indices_(pos, query_index) = neighbor;
+}
+
+#endif
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -9,14 +9,16 @@
using namespace mlpack::neighbor;
size_t FurthestNeighborSort::SortDistance(const arma::vec& list,
- double new_distance) {
+ double new_distance)
+{
// The first element in the list is the nearest neighbor. We only want to
// insert if the new distance is greater than the last element in the list.
if (new_distance < list[list.n_elem - 1])
return (size_t() - 1); // Do not insert.
// Search from the beginning. This may not be the best way.
- for (size_t i = 0; i < list.n_elem; i++) {
+ for (size_t i = 0; i < list.n_elem; i++)
+ {
if (new_distance >= list[i])
return i;
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -5,8 +5,8 @@
* Implementation of the SortPolicy class for NeighborSearch; in this case, the
* furthest neighbors are those that are most important.
*/
-#ifndef __MLPACK_NEIGHBOR_FURTHEST_NEIGHBOR_SORT_HPP
-#define __MLPACK_NEIGHBOR_FURTHEST_NEIGHBOR_SORT_HPP
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
#include <mlpack/core.h>
@@ -19,7 +19,8 @@
* minimum distance is the best (so, when used with NeighborSearch, the output
* is furthest neighbors).
*/
-class FurthestNeighborSort {
+class FurthestNeighborSort
+{
public:
/**
* Return the index in the vector where the new distance should be inserted,
@@ -46,7 +47,8 @@
*
* @return bool indicating whether or not (value > ref).
*/
- static inline bool IsBetter(const double value, const double ref) {
+ static inline bool IsBetter(const double value, const double ref)
+ {
return (value > ref);
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,12 +1,12 @@
/***
- * @file nearest_neighbor_sort_impl.hpp
+ * @file furthest_neighbor_sort_impl.hpp
* @author Ryan Curtin
*
* Implementation of templated methods for the FurthestNeighborSort SortPolicy
* class for the NeighborSearch class.
*/
-#ifndef __MLPACK_NEIGHBOR_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
-#define __MLPACK_NEIGHBOR_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
namespace mlpack {
namespace neighbor {
@@ -14,7 +14,8 @@
template<typename TreeType>
double FurthestNeighborSort::BestNodeToNodeDistance(
const TreeType* query_node,
- const TreeType* reference_node) {
+ const TreeType* reference_node)
+{
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
return query_node->bound().MaxDistance(reference_node->bound());
@@ -23,7 +24,8 @@
template<typename TreeType>
double FurthestNeighborSort::BestPointToNodeDistance(
const arma::vec& point,
- const TreeType* reference_node) {
+ const TreeType* reference_node)
+{
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
return reference_node->bound().MaxDistance(point);
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -9,14 +9,16 @@
using namespace mlpack::neighbor;
size_t NearestNeighborSort::SortDistance(const arma::vec& list,
- double new_distance) {
+ double new_distance)
+{
// The first element in the list is the nearest neighbor. We only want to
// insert if the new distance is less than the last element in the list.
if (new_distance > list[list.n_elem - 1])
return (size_t() - 1); // Do not insert.
// Search from the beginning. This may not be the best way.
- for (size_t i = 0; i < list.n_elem; i++) {
+ for (size_t i = 0; i < list.n_elem; i++)
+ {
if (new_distance <= list[i])
return i;
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,12 +1,12 @@
-/***
+/**
* @file nearest_neighbor_sort.hpp
* @author Ryan Curtin
*
* Implementation of the SortPolicy class for NeighborSearch; in this case, the
* nearest neighbors are those that are most important.
*/
-#ifndef __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_HPP
-#define __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_HPP
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
#include <mlpack/core.h>
@@ -23,7 +23,8 @@
* SortPolicy. All of the methods implemented here must be implemented by any
* other SortPolicy classes.
*/
-class NearestNeighborSort {
+class NearestNeighborSort
+{
public:
/**
* Return the index in the vector where the new distance should be inserted,
@@ -50,7 +51,8 @@
*
* @return bool indicating whether or not (value < ref).
*/
- static inline bool IsBetter(const double value, const double ref) {
+ static inline bool IsBetter(const double value, const double ref)
+ {
return (value < ref);
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,4 +1,4 @@
-/***
+/**
* @file nearest_neighbor_sort_impl.hpp
* @author Ryan Curtin
*
@@ -14,7 +14,8 @@
template<typename TreeType>
double NearestNeighborSort::BestNodeToNodeDistance(
const TreeType* query_node,
- const TreeType* reference_node) {
+ const TreeType* reference_node)
+{
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
return query_node->bound().MinDistance(reference_node->bound());
@@ -23,7 +24,8 @@
template<typename TreeType>
double NearestNeighborSort::BestPointToNodeDistance(
const arma::vec& point,
- const TreeType* reference_node) {
+ const TreeType* reference_node)
+{
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
return reference_node->bound().MinDistance(point);
Deleted: mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h 2011-11-24 02:26:11 UTC (rev 10377)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h 2011-11-24 02:42:06 UTC (rev 10378)
@@ -1,44 +0,0 @@
-/**
- * @file typedef.h
- * @author Ryan Curtin
- *
- * Simple typedefs describing template instantiations of the NeighborSearch
- * class which are commonly used. This is meant to be included by
- * neighbor_search.h but is a separate file for simplicity.
- */
-#ifndef __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
-#define __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
-
-// In case someone included this directly.
-#include "neighbor_search.h"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include "sort_policies/nearest_neighbor_sort.hpp"
-#include "sort_policies/furthest_neighbor_sort.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * The AllkNN class is the all-k-nearest-neighbors method. It returns squared
- * L2 distances (squared Euclidean distances) for each of the k nearest
- * neighbors. Squared distances are used because they are slightly faster than
- * non-squared distances (they have one fewer call to sqrt()).
- */
-typedef NeighborSearch<metric::SquaredEuclideanDistance, NearestNeighborSort>
- AllkNN;
-
-/**
- * The AllkFN class is the all-k-furthest-neighbors method. It returns squared
- * L2 distances (squared Euclidean distances) for each of the k furthest
- * neighbors. Squared distances are used because they are slightly faster than
- * non-squared distances (they have one fewer call to sqrt()).
- */
-typedef NeighborSearch<metric::SquaredEuclideanDistance, FurthestNeighborSort>
- AllkFN;
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp (from rev 10353, mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp 2011-11-24 02:42:06 UTC (rev 10378)
@@ -0,0 +1,44 @@
+/**
+ * @file typedef.hpp
+ * @author Ryan Curtin
+ *
+ * Simple typedefs describing template instantiations of the NeighborSearch
+ * class which are commonly used. This is meant to be included by
+ * neighbor_search.h but is a separate file for simplicity.
+ */
+#ifndef __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
+#define __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
+
+// In case someone included this directly.
+#include "neighbor_search.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "sort_policies/nearest_neighbor_sort.hpp"
+#include "sort_policies/furthest_neighbor_sort.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The AllkNN class is the all-k-nearest-neighbors method. It returns squared
+ * L2 distances (squared Euclidean distances) for each of the k nearest
+ * neighbors. Squared distances are used because they are slightly faster than
+ * non-squared distances (they have one fewer call to sqrt()).
+ */
+typedef NeighborSearch<metric::SquaredEuclideanDistance, NearestNeighborSort>
+ AllkNN;
+
+/**
+ * The AllkFN class is the all-k-furthest-neighbors method. It returns squared
+ * L2 distances (squared Euclidean distances) for each of the k furthest
+ * neighbors. Squared distances are used because they are slightly faster than
+ * non-squared distances (they have one fewer call to sqrt()).
+ */
+typedef NeighborSearch<metric::SquaredEuclideanDistance, FurthestNeighborSort>
+ AllkFN;
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list