[mlpack-svn] r10499 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 1 21:59:07 EST 2011
Author: rcurtin
Date: 2011-12-01 21:59:07 -0500 (Thu, 01 Dec 2011)
New Revision: 10499
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp
Log:
Overhaul API for NeighborSearch (it probably is not complete) and remove
dependence on CLI to help fix #150.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2011-12-02 02:59:07 UTC (rev 10499)
@@ -16,6 +16,7 @@
using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
+using namespace mlpack::tree;
// Information about the program itself.
PROGRAM_INFO("All K-Furthest-Neighbors",
@@ -25,107 +26,158 @@
"and query set."
"\n\n"
"For example, the following will calculate the 5 furthest neighbors of each"
- "point in 'input.csv' and store the results in 'output.csv':"
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in 'neighbors.csv':"
"\n\n"
- "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
- " --output_file=output.csv", "neighbor_search");
+ "$ allkfn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+ " --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th furthest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.", "");
// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
"");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
- "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
- "");
+PARAM_STRING("query_file", "File containing query points (optional).", "", "");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "");
+PARAM_INT("leaf_size", "Leaf size for tree building.", "", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search).", "");
+PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "");
+
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
CLI::ParseCommandLine(argc, argv);
- string reference_file = CLI::GetParam<string>("reference_file");
- string output_file = CLI::GetParam<string>("output_file");
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+ string outputFile = CLI::GetParam<string>("output_file");
- arma::mat reference_data;
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
- arma::Mat<size_t> neighbors;
- arma::mat distances;
+ int leafSize = CLI::GetParam<int>("leaf_size");
- if (!data::Load(reference_file.c_str(), reference_data))
- Log::Fatal << "Reference file " << reference_file << "not found." << endl;
+ size_t k = CLI::GetParam<int>("k");
- Log::Info << "Loaded reference data from " << reference_file << endl;
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+ arma::mat referenceData;
+ if (!data::Load(referenceFile.c_str(), referenceData))
+ Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+ Log::Info << "Loaded reference data from " << referenceFile << endl;
+
// Sanity check on k value: must be greater than 0, must be less than the
// number of reference points.
- size_t k = CLI::GetParam<int>("neighbor_search/k");
- if ((k <= 0) || (k >= reference_data.n_cols))
+ if ((k <= 0) || (k >= referenceData.n_cols))
{
Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
Log::Fatal << "than the number of reference points (";
- Log::Fatal << reference_data.n_cols << ")." << endl;
+ Log::Fatal << referenceData.n_cols << ")." << endl;
}
// Sanity check on leaf size.
- if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+ if (leafSize < 0)
{
- Log::Fatal << "Invalid leaf size: "
- << CLI::GetParam<int>("allknn/leaf_size") << endl;
+ Log::Fatal << "Invalid leaf size: " << leafSize << ". Must be greater "
+ "than or equal to 0." << endl;
}
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
AllkFN* allkfn = NULL;
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs);
+
+ Timers::StopTimer("neighbor_search/tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
if (CLI::GetParam<string>("query_file") != "")
{
- string query_file = CLI::GetParam<string>("query_file");
- arma::mat query_data;
+ string queryFile = CLI::GetParam<string>("query_file");
+ arma::mat queryData;
- if (!data::Load(query_file.c_str(), query_data))
- Log::Fatal << "Query file " << query_file << " not found" << endl;
+ if (!data::Load(queryFile.c_str(), queryData))
+ Log::Fatal << "Query file " << queryFile << " not found" << endl;
- Log::Info << "Query data loaded from " << query_file << endl;
+ Log::Info << "Query data loaded from " << queryFile << endl;
- Log::Info << "Building query and reference trees..." << endl;
- allkfn = new AllkFN(query_data, reference_data);
+ Log::Info << "Building query tree..." << endl;
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
+ queryTree(queryData, oldFromNewRefs);
+
+ Timers::StopTimer("neighbor_search/tree_building");
+
+ allkfn = new AllkFN(referenceData, queryData, naive, singleMode, 20,
+ &refTree, &queryTree);
+
+ Log::Info << "Tree built." << endl;
}
else
{
- Log::Info << "Building reference tree..." << endl;
- allkfn = new AllkFN(reference_data);
+ allkfn = new AllkFN(referenceData, naive, singleMode, 20, &refTree);
+
+ Log::Info << "Trees built." << endl;
}
- Log::Info << "Tree(s) built." << endl;
-
Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allkfn->ComputeNeighbors(neighbors, distances);
+ allkfn->ComputeNeighbors(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
- Log::Info << "Exporting results..." << endl;
- // Should be using data::Save or a related function instead of being written
- // by hand.
- try
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ arma::mat distancesOut(distances.n_rows, distances.n_cols);
+ arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+ // Do the actual remapping.
+ for (size_t i = 0; i < distances.n_cols; i++)
{
- ofstream out(output_file.c_str());
+ // Map distances (copy a column).
+ distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
- for (size_t col = 0; col < neighbors.n_cols; col++)
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
{
- out << col << ", ";
- for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
- {
- out << neighbors(j, col) << ", " << distances(j, col) << ", ";
- }
- out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+ neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
}
-
- out.close();
}
- catch (exception& e)
- {
- Log::Fatal << "Error while opening " << output_file << ": " << e.what()
- << endl;
- }
+ // Save output.
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
+
delete allkfn;
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2011-12-02 02:59:07 UTC (rev 10499)
@@ -16,6 +16,7 @@
using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
+using namespace mlpack::tree;
// Information about the program itself.
PROGRAM_INFO("All K-Nearest-Neighbors",
@@ -25,107 +26,160 @@
"and query set."
"\n\n"
"For example, the following will calculate the 5 nearest neighbors of each"
- "point in 'input.csv' and store the results in 'output.csv':"
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in 'neighbors.csv':"
"\n\n"
- "$ allknn --neighbor_search/k=5 --reference_file=input.csv\n"
- " --output_file=output.csv", "neighbor_search");
+ "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+ " --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th nearest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.", "");
// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "CSV file containing the reference dataset.",
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
"");
-PARAM_STRING("query_file", "CSV file containing query points (optional).",
- "", "");
-PARAM_STRING_REQ("output_file", "File to output CSV-formatted results into.",
- "");
+PARAM_STRING("query_file", "File containing query points (optional).", "", "");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "");
+PARAM_INT("leaf_size", "Leaf size for tree building.", "", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search).", "");
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "");
+
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
CLI::ParseCommandLine(argc, argv);
- string reference_file = CLI::GetParam<string>("reference_file");
- string output_file = CLI::GetParam<string>("output_file");
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+ string outputFile = CLI::GetParam<string>("output_file");
- arma::mat reference_data;
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
- arma::Mat<size_t> neighbors;
- arma::mat distances;
+ int leafSize = CLI::GetParam<int>("leaf_size");
- if (!data::Load(reference_file.c_str(), reference_data))
- Log::Fatal << "Reference file " << reference_file << " not found." << endl;
+ size_t k = CLI::GetParam<int>("k");
- Log::Info << "Loaded reference data from " << reference_file << endl;
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+ arma::mat referenceData;
+ if (!data::Load(referenceFile.c_str(), referenceData))
+ Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+ Log::Info << "Loaded reference data from " << referenceFile << endl;
+
// Sanity check on k value: must be greater than 0, must be less than the
// number of reference points.
- size_t k = CLI::GetParam<int>("neighbor_search/k");
- if ((k <= 0) || (k >= reference_data.n_cols))
+ if ((k <= 0) || (k >= referenceData.n_cols))
{
Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
Log::Fatal << "than the number of reference points (";
- Log::Fatal << reference_data.n_cols << ")." << endl;
+ Log::Fatal << referenceData.n_cols << ")." << endl;
}
// Sanity check on leaf size.
- if (CLI::GetParam<int>("tree/leaf_size") <= 0)
+ if (leafSize < 0)
{
- Log::Fatal << "Invalid leaf size: "
- << CLI::GetParam<int>("allknn/leaf_size") << endl;
+ Log::Fatal << "Invalid leaf size: " << leafSize << ". Must be greater "
+ "than or equal to 0." << endl;
}
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ // Because we may construct it differently, we need a pointer.
AllkNN* allknn = NULL;
+ // Mappings for when we build the tree.
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+
+ Timers::StopTimer("neighbor_search/tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
if (CLI::GetParam<string>("query_file") != "")
{
- string query_file = CLI::GetParam<string>("query_file");
- arma::mat query_data;
+ string queryFile = CLI::GetParam<string>("query_file");
+ arma::mat queryData;
- if (!data::Load(query_file.c_str(), query_data))
- Log::Fatal << "Query file " << query_file << " not found" << endl;
+ if (!data::Load(queryFile.c_str(), queryData))
+ Log::Fatal << "Query file " << queryFile << " not found" << endl;
- Log::Info << "Query data loaded from " << query_file << endl;
+ Log::Info << "Query data loaded from " << queryFile << endl;
- Log::Info << "Building query and reference trees..." << endl;
- allknn = new AllkNN(query_data, reference_data);
+ Log::Info << "Building query tree..." << endl;
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timers::StartTimer("neighbor_search/tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+ queryTree(queryData, oldFromNewRefs, leafSize);
+
+ Timers::StopTimer("neighbor_search/tree_building");
+
+ allknn = new AllkNN(referenceData, queryData, naive, singleMode, 20,
+ &refTree, &queryTree);
+
+ Log::Info << "Tree built." << endl;
}
else
{
- Log::Info << "Building reference tree..." << endl;
- allknn = new AllkNN(reference_data);
+ allknn = new AllkNN(referenceData, naive, singleMode, 20, &refTree);
+
+ Log::Info << "Trees built." << endl;
}
- Log::Info << "Tree(s) built." << endl;
-
Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->ComputeNeighbors(neighbors, distances);
+ allknn->ComputeNeighbors(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
- Log::Info << "Exporting results..." << endl;
- // Should be using data::Save or a related function instead of being written
- // by hand.
- try
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ arma::mat distancesOut(distances.n_rows, distances.n_cols);
+ arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+ // Do the actual remapping.
+ for (size_t i = 0; i < distances.n_cols; i++)
{
- ofstream out(output_file.c_str());
+ // Map distances (copy a column).
+ distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
- for (size_t col = 0; col < neighbors.n_cols; col++)
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
{
- out << col << ", ";
- for (size_t j = 0; j < (k - 1) /* last is special case */; j++)
- {
- out << neighbors(j, col) << ", " << distances(j, col) << ", ";
- }
- out << neighbors((k - 1), col) << ", " << distances((k - 1), col) << endl;
+ neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
}
-
- out.close();
}
- catch (exception& e)
- {
- Log::Fatal << "Error while opening " << output_file << ": " << e.what()
- << endl;
- }
+ // Save output.
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
+
delete allknn;
}
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-02 02:59:07 UTC (rev 10499)
@@ -22,15 +22,30 @@
* all-nearest-neighbors and all-furthest-neighbors
* searches. */ {
-// Define CLI parameters for the NeighborSearch class.
-PARAM_MODULE("neighbor_search",
- "Parameters for the distance-based neighbor search.");
-PARAM_INT("k", "Number of neighbors to search for.", "neighbor_search", 5);
-PARAM_FLAG("single_mode", "If set, use single-tree mode (instead of "
- "dual-tree).", "neighbor_search");
-PARAM_FLAG("naive_mode", "If set, use naive computations (no trees). This "
- "overrides the single_mode flag.", "neighbor_search");
+/**
+ * Extra data for each node in the tree. For neighbor searches, each node only
+ * needs to store a bound on neighbor distances.
+ */
+template<typename SortPolicy>
+class QueryStat
+{
+ private:
+ //! The bound on the node's neighbor distances.
+ double bound;
+ public:
+ /**
+ * Initialize the statistic with the worst possible distance according to
+ * our sorting policy.
+ */
+ QueryStat() : bound(SortPolicy::WorstDistance()) { }
+
+ //! Get the bound.
+ const double Bound() const { return bound; }
+ //! Modify the bound.
+ double& Bound() { return bound; }
+};
+
/**
* The NeighborSearch class is a template class for performing distance-based
* neighbor searches. It takes a query dataset and a reference dataset (or just
@@ -40,120 +55,88 @@
* reference dataset, and if that constructor is used, the given reference
* dataset is also used as the query dataset.
*
- * The template parameters Kernel and SortPolicy define the distance function
- * used and the sort function used. More information on those classes can be
- * found in the kernel::ExampleKernel class and the NearestNeighborSort class.
+ * The template parameters SortPolicy and MetricType define the sort function
+ * used and the metric (distance function) used. More information on those
+ * classes can be found in the NearestNeighborSort class and the
+ * kernel::ExampleKernel class.
*
- * This class has several parameters configurable by the CLI interface:
- *
- * @param neighbor_search/k Parameters for the distance-based neighbor search.
- * @param neighbor_search/single_mode If set, single-tree mode will be used
- * (instead of dual-tree mode).
- * @param neighbor_search/naive_mode If set, naive computation will be used (no
- * trees). This overrides the single_mode flag.
- *
- * @tparam Kernel The kernel function; see kernel::ExampleKernel.
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ * @tparam MetricType The metric to use for computation.
+ * @tparam TreeType The tree type to use.
*/
-template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename SortPolicy = NearestNeighborSort>
+template<typename SortPolicy = NearestNeighborSort,
+ typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<SortPolicy> > >
class NeighborSearch
{
- /**
- * Extra data for each node in the tree. For neighbor searches, each node
- * only needs to store a bound on neighbor distances.
- */
- class QueryStat
- {
- public:
- //! The bound on the node's neighbor distances.
- double bound_;
-
- /**
- * Initialize the statistic with the worst possible distance according to
- * our sorting policy.
- */
- QueryStat() : bound_(SortPolicy::WorstDistance()) { }
- };
-
- /**
- * Simple typedef for the trees, which use a bound and a QueryStat (to store
- * distances for each node). The bound should be configurable...
- */
- typedef tree::BinarySpaceTree<bound::HRectBound<2>, QueryStat> TreeType;
-
- private:
- //! Reference dataset.
- arma::mat references_;
- //! Query dataset (may not be given).
- arma::mat queries_;
-
- //! Instantiation of kernel.
- MetricType kernel_;
-
- //! Pointer to the root of the reference tree.
- TreeType* reference_tree_;
- //! Pointer to the root of the query tree (might not exist).
- TreeType* query_tree_;
-
- //! Permutations of query points during tree building.
- std::vector<size_t> old_from_new_queries_;
- //! Permutations of reference points during tree building.
- std::vector<size_t> old_from_new_references_;
-
- //! Indicates if O(n^2) naive search is being used.
- bool naive_;
- //! Indicates if dual-tree search is being used (opposed to single-tree).
- bool dual_mode_;
-
- //! Number of points in a leaf.
- size_t leaf_size_;
-
- //! Number of neighbors to compute.
- size_t knns_;
-
- //! Total number of pruned nodes during the neighbor search.
- size_t number_of_prunes_;
-
- //! The distance to the candidate nearest neighbor for each query
- arma::mat neighbor_distances_;
-
- //! The indices of the candidate nearest neighbor for each query
- arma::Mat<size_t> neighbor_indices_;
-
public:
/**
* Initialize the NeighborSearch object, passing both a query and reference
- * dataset. An initialized distance metric can be given, for cases where the
- * metric has internal data (i.e. the distance::MahalanobisDistance class).
+ * dataset. Optionally, already-built trees can be passed, for the case where
+ * a special tree-building procedure is needed. If referenceTree is given, it
+ * is assumed that the points in referenceTree correspond to the points in
+ * referenceSet. The same is true for queryTree and querySet. An initialized
+ * distance metric can be given, for cases where the metric has internal data
+ * (e.g. the distance::MahalanobisDistance class).
*
- * @param queries_in Set of query points.
- * @param references_in Set of reference points.
- * @param alias_matrix If true, alias the passed matrices instead of copying
- * them. While this lowers memory footprint and computational load, the
- * points in the matrices will be rearranged during the tree-building
- * process! Defaults to false.
- * @param kernel An optional instance of the Kernel class.
+ * If naive mode is being used and a pre-built tree is given, it may not work:
+ * naive mode operates by building a one-node tree (the root node holds all
+ * the points). If that condition is not satisfied with the pre-built tree,
+ * then naive mode will not work.
+ *
+ * @param referenceSet Set of reference points.
+ * @param querySet Set of query points.
+ * @param naive If true, O(n^2) naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param referenceTree Optionally pass a pre-built tree for the reference
+ * set.
+ * @param queryTree Optionally pass a pre-built tree for the query set.
+ * @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(arma::mat& queries_in, arma::mat& references_in,
- bool alias_matrix = false, MetricType kernel = MetricType());
+ NeighborSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ TreeType* referenceTree = NULL,
+ TreeType* queryTree = NULL,
+ const MetricType metric = MetricType());
/**
- * Initialize the NeighborSearch object, passing only one dataset. In this
- * case, the query dataset is equivalent to the reference dataset, with one
- * caveat: for any point, the returned list of neighbors will not include
- * itself. An initialized distance metric can be given, for cases where the
- * metric has internal data (i.e. the distance::MahalanobisDistance class).
+ * Initialize the NeighborSearch object, passing only one dataset, which is
+ * used as both the query and the reference dataset. Optionally, an
+ * already-built tree can be passed, for the case where a special
+ * tree-building procedure is needed. If referenceTree is given, it is
+ * assumed that the points in referenceTree correspond to the points in
+ * referenceSet. An initialized distance metric can be given, for cases where
+ * the metric has internal data (e.g. the distance::MahalanobisDistance
+ * class).
*
- * @param references_in Set of reference points.
- * @param alias_matrix If true, alias the passed matrices instead of copying
- * them. While this lowers memory footprint and computational load, the
- * points in the matrices will be rearranged during the tree-building
- * process! Defaults to false.
- * @param kernel An optional instance of the Kernel class.
+ * If naive mode is being used and a pre-built tree is given, it may not work:
+ * naive mode operates by building a one-node tree (the root node holds all
+ * the points). If that condition is not satisfied with the pre-built tree,
+ * then naive mode will not work.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, O(n^2) naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param referenceTree Optionally pass a pre-built tree for the reference
+ * set.
+ * @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(arma::mat& references_in, bool alias_matrix = false,
- MetricType kernel = MetricType());
+ NeighborSearch(const arma::mat& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ TreeType* referenceTree = NULL,
+ const MetricType metric = MetricType());
/**
* Delete the NeighborSearch object. The tree is the only member we are
@@ -167,70 +150,123 @@
* number of points in the query dataset and k is the number of neighbors
* being searched for.
*
- * The parameter k is set through the CLI interface, not in the arguments to
- * this method; this allows users to specify k on the command line
- * ("--neighbor_search/k"). See the CLI documentation for further information
- * on how to use this functionality.
- *
- * @param resulting_neighbors Matrix storing lists of neighbors for each query
+ * @param k Number of neighbors to search for.
+ * @param resultingNeighbors Matrix storing lists of neighbors for each query
* point.
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void ComputeNeighbors(arma::Mat<size_t>& resulting_neighbors,
+ void ComputeNeighbors(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances);
private:
/**
* Perform exhaustive computation between two leaves, comparing every node in
* the leaf to the other leaf to find the furthest neighbor. The
- * neighbor_indices_ and neighbor_distances_ matrices will be updated with the
- * changed information.
+ * neighbors and distances matrices will be updated with the changed
+ * information.
*
- * @param query_node Node in query tree. This should be a leaf
+ * @param queryNode Node in query tree. This should be a leaf
* (bottom-level).
- * @param reference_node Node in reference tree. This should be a leaf
+ * @param referenceNode Node in reference tree. This should be a leaf
* (bottom-level).
+ * @param neighbors List of neighbors for each point.
+ * @param distances List of distances for each point.
*/
- void ComputeBaseCase_(TreeType* query_node, TreeType* reference_node);
+ void ComputeBaseCase(TreeType* queryNode,
+ TreeType* referenceNode,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Recurse down the trees, computing base case computations when the leaves
* are reached.
*
- * @param query_node Node in query tree.
- * @param reference_node Node in reference tree.
- * @param lower_bound The lower bound; if above this, we can prune.
+ * @param queryNode Node in query tree.
+ * @param referenceNode Node in reference tree.
+ * @param lowerBound The lower bound; if above this, we can prune.
+ * @param neighbors List of neighbors for each point.
+ * @param distances List of distances for each point.
*/
- void ComputeDualNeighborsRecursion_(TreeType* query_node,
- TreeType* reference_node,
- double lower_bound);
+ void ComputeDualNeighborsRecursion(TreeType* queryNode,
+ TreeType* referenceNode,
+ const double lowerBound,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Perform a recursion only on the reference tree; the query point is given.
- * This method is similar to ComputeBaseCase_().
+ * This method is similar to ComputeBaseCase().
*
- * @param point_id Index of query point.
+ * @param pointId Index of query point.
* @param point The query point.
- * @param reference_node Reference node.
- * @param best_dist_so_far Best distance to a node so far -- used for pruning.
+ * @param referenceNode Reference node.
+ * @param bestDistSoFar Best distance to a node so far -- used for pruning.
+ * @param neighbors List of neighbors for each point.
+ * @param distances List of distances for each point.
*/
- void ComputeSingleNeighborsRecursion_(size_t point_id, arma::vec& point,
- TreeType* reference_node,
- double& best_dist_so_far);
+ void ComputeSingleNeighborsRecursion(const size_t pointId,
+ const arma::vec& point,
+ TreeType* referenceNode,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Insert a point into the neighbors and distances matrices; this is a helper
* function.
*
- * @param query_index Index of point whose neighbors we are inserting into.
+ * @param queryIndex Index of point whose neighbors we are inserting into.
* @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
+ * @param neighbors List of neighbors for each point.
+ * @param distances List of distances for each point.
*/
- void InsertNeighbor(size_t query_index, size_t pos, size_t neighbor,
- double distance);
+ void InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+ //! Copy of reference dataset (if we need it, because tree building modifies
+ //! it).
+ arma::mat referenceCopy;
+ //! Copy of query dataset (if we need it, because tree building modifies it).
+ arma::mat queryCopy;
+
+ //! Reference dataset.
+ const arma::mat& referenceSet;
+ //! Query dataset (may not be given).
+ const arma::mat& querySet;
+
+ //! Indicates if O(n^2) naive search is being used.
+ bool naive;
+ //! Indicates if single-tree search is being used (opposed to dual-tree).
+ bool singleMode;
+
+ //! Pointer to the root of the reference tree.
+ TreeType* referenceTree;
+ //! Pointer to the root of the query tree (might not exist).
+ TreeType* queryTree;
+
+ //! Indicates if we should free the reference tree at deletion time.
+ bool ownReferenceTree;
+ //! Indicates if we should free the query tree at deletion time.
+ bool ownQueryTree;
+
+ //! Instantiation of metric.
+ MetricType metric;
+
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> oldFromNewReferences;
+ //! Permutations of query points during tree building.
+ std::vector<size_t> oldFromNewQueries;
+
+ //! Total number of pruned nodes during the neighbor search.
+ size_t numberOfPrunes;
}; // class NeighborSearch
}; // namespace neighbor
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-02 02:59:07 UTC (rev 10499)
@@ -12,499 +12,594 @@
using namespace mlpack::neighbor;
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::NeighborSearch(arma::mat& queries_in,
- arma::mat& references_in,
- bool alias_matrix,
- MetricType kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(queries_in.memptr(), queries_in.n_rows, queries_in.n_cols,
- !alias_matrix),
- kernel_(kernel),
- naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
- dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
- number_of_prunes_(0)
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ TreeType* referenceTree,
+ TreeType* queryTree,
+ const MetricType metric) :
+ referenceCopy(referenceTree ? 0 : referenceSet),
+ queryCopy(queryTree ? 0 : querySet),
+ referenceSet(referenceTree ? referenceSet : referenceCopy),
+ querySet(queryTree ? querySet : queryCopy),
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ referenceTree(referenceTree),
+ queryTree(queryTree),
+ ownReferenceTree(!referenceTree), // False if a tree was passed.
+ ownQueryTree(!queryTree), // False if a tree was passed.
+ metric(metric),
+ numberOfPrunes(0)
{
// C++11 will allow us to call out to other constructors so we can avoid this
// copypasta problem.
- // Get the leaf size; naive ensures that the entire tree is one node
- if (naive_)
- CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
+ // We'll time tree building, but only if we are building trees.
+ if (!referenceTree || !queryTree)
+ Timers::StartTimer("neighbor_search/tree_building");
- // K-nearest neighbors initialization
- knns_ = CLI::GetParam<int>("neighbor_search/k");
+ if (!referenceTree)
+ {
+ // Construct as a naive object if we need to.
+ if (naive)
+ this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ referenceSet.n_cols /* everything in one leaf */);
+ else
+ this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ leafSize);
+ }
- // Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, queries_.n_cols);
+ if (!queryTree)
+ {
+ // Construct as a naive object if we need to.
+ if (naive)
+ this->queryTree = new TreeType(queryCopy, oldFromNewQueries,
+ querySet.n_cols);
+ else
+ this->queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+ }
- // Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, queries_.n_cols);
- neighbor_distances_.fill(SortPolicy::WorstDistance());
+ // Stop the timer we started above (if we need to).
+ if (!referenceTree || !queryTree)
+ Timers::StopTimer("neighbor_search/tree_building");
+}
- // We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const arma::mat& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ TreeType* referenceTree,
+ const MetricType metric) :
+ referenceCopy(referenceTree ? 0 : referenceSet),
+ referenceSet(referenceTree ? referenceSet : referenceCopy),
+ querySet(referenceTree ? referenceSet : referenceCopy),
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ ownReferenceTree(!referenceTree),
+ ownQueryTree(false), // Since it will be the same as referenceTree.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // We'll time tree building, but only if we are building trees.
+ if (!referenceTree)
+ {
+ Timers::StartTimer("neighbor_search/tree_building");
- // This call makes each tree from a matrix, leaf size, and two arrays
- // that record the permutation of the data points
- query_tree_ = new TreeType(queries_, old_from_new_queries_);
- reference_tree_ = new TreeType(references_, old_from_new_references_);
+ // Construct as a naive object if we need to.
+ if (naive)
+ this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ referenceSet.n_cols /* everything in one leaf */);
+ else
+ this->referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ leafSize);
- // Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
+ // Stop the timer we started above.
+ Timers::StopTimer("neighbor_search/tree_building");
+ }
}
-// We call an advanced constructor of arma::mat which allows us to alias a
-// matrix (if the user has asked for that).
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::NeighborSearch(arma::mat& references_in,
- bool alias_matrix,
- MetricType kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(references_.memptr(), references_.n_rows, references_.n_cols,
- false),
- kernel_(kernel),
- naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
- dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
- number_of_prunes_(0)
+/**
+ * The tree is the only member we may be responsible for deleting. The others
+ * will take care of themselves.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
{
- // Get the leaf size from the module
- if (naive_)
- CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
+ if (ownReferenceTree)
+ delete referenceTree;
+ if (ownQueryTree)
+ delete queryTree;
+}
- // K-nearest neighbors initialization
- knns_ = CLI::GetParam<int>("neighbor_search/k");
+/**
+ * Computes the best neighbors and stores them in resultingNeighbors and
+ * distances.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeNeighbors(
+ const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances)
+{
+ Timers::StartTimer("neighbor_search/computing_neighbors");
- // Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, references_.n_cols);
+ // If we have built the trees ourselves, then we will have to map all the
+ // indices back to their original indices when this computation is finished.
+ // To avoid an extra copy, the neighbors and distances are stored in separate
+ // temporary matrices, which the mapping step below writes into the outputs.
+ arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+ arma::mat* distancePtr = &distances;
- // Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, references_.n_cols);
- neighbor_distances_.fill(SortPolicy::WorstDistance());
+ if (ownQueryTree || (ownReferenceTree && !queryTree))
+ distancePtr = new arma::mat; // Query indices need to be mapped.
+ if (ownReferenceTree || ownQueryTree)
+ neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
- // We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
+ // Set the size of the neighbor and distance matrices.
+ neighborPtr->set_size(k, querySet.n_cols);
+ distancePtr->set_size(k, querySet.n_cols);
+ distancePtr->fill(SortPolicy::WorstDistance());
- // This call makes each tree from a matrix, leaf size, and two arrays
- // that record the permutation of the data points
- // Instead of NULL, it is possible to specify an array new_from_old_
- query_tree_ = NULL;
- reference_tree_ = new TreeType(references_, old_from_new_references_);
+ if (naive)
+ {
+ // Run the base case computation on all nodes
+ if (queryTree)
+ ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
+ else
+ ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+ }
+ else
+ {
+ if (singleMode)
+ {
+ // Do one tenth of the query set at a time.
+ size_t chunk = querySet.n_cols / 10;
- // Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
-}
+ for (size_t i = 0; i < 10; i++)
+ {
+ for (size_t j = 0; j < chunk; j++)
+ {
+ double worstDistance = SortPolicy::WorstDistance();
+ ComputeSingleNeighborsRecursion(i * chunk + j,
+ querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
+ *neighborPtr, *distancePtr);
+ }
+ }
-/**
- * The tree is the only member we are responsible for deleting. The others will
- * take care of themselves.
- */
-template<typename MetricType, typename SortPolicy>
-NeighborSearch<MetricType, SortPolicy>::~NeighborSearch()
-{
- if (reference_tree_ != query_tree_)
- delete reference_tree_;
- if (query_tree_ != NULL)
- delete query_tree_;
-}
+ // The last tenth is differently sized...
+ for (size_t i = 0; i < querySet.n_cols % 10; i++)
+ {
+ size_t ind = (querySet.n_cols / 10) * 10 + i;
+ double worstDistance = SortPolicy::WorstDistance();
+ ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
+ referenceTree, worstDistance, *neighborPtr, *distancePtr);
+ }
+ }
+ else // Dual-tree recursion.
+ {
+ // Start on the root of each tree.
+ if (queryTree)
+ {
+ ComputeDualNeighborsRecursion(queryTree, referenceTree,
+ SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
+ *neighborPtr, *distancePtr);
+ }
+ else
+ {
+ ComputeDualNeighborsRecursion(referenceTree, referenceTree,
+ SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
+ *neighborPtr, *distancePtr);
+ }
+ }
+ }
+ Timers::StopTimer("neighbor_search/computing_neighbors");
+
+ // Now, do we need to do mapping of indices?
+ if (!ownReferenceTree && !ownQueryTree)
+ {
+ // No mapping needed. We are done.
+ return;
+ }
+ else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ {
+ // Set size of output matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewQueries[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (ownReferenceTree)
+ {
+ if (!queryTree) // No query tree -- map both references and queries.
+ {
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewReferences[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+ else // Map only references.
+ {
+ // Set size of neighbor indices matrix correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+
+ // Map indices of neighbors.
+ for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+ {
+ for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+ {
+ resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+
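+ // Finished with temporary neighbor matrix. NOTE(review): when !queryTree, distancePtr was also allocated with new above and is not deleted here (looks like a leak).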
+ delete neighborPtr;
+ }
+ else if (ownQueryTree)
+ {
+ // Set size of matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+} // ComputeNeighbors
+
/**
* Performs exhaustive computation between two leaves.
*/
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeBaseCase_(
- TreeType* query_node,
- TreeType* reference_node)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
+ TreeType* queryNode,
+ TreeType* referenceNode,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
// Used to find the query node's new upper bound.
- double query_worst_distance = SortPolicy::BestDistance();
+ double queryWorstDistance = SortPolicy::BestDistance();
- // node->begin() is the index of the first point in the node,
- // node->end is one past the last index.
- for (size_t query_index = query_node->Begin();
- query_index < query_node->End(); query_index++)
+ // node->Begin() is the index of the first point in the node,
+ // node->End() is one past the last index.
+ for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
+ queryIndex++)
{
// Get the query point from the matrix.
- arma::vec query_point = queries_.unsafe_col(query_index);
+ arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- double query_to_node_distance =
- SortPolicy::BestPointToNodeDistance(query_point, reference_node);
+ double queryToNodeDistance =
+ SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
- if (SortPolicy::IsBetter(query_to_node_distance,
- neighbor_distances_(knns_ - 1, query_index)))
+ if (SortPolicy::IsBetter(queryToNodeDistance,
+ distances(distances.n_rows - 1, queryIndex)))
{
// We'll do the same for the references.
- for (size_t reference_index = reference_node->Begin();
- reference_index < reference_node->End(); reference_index++)
+ for (size_t referenceIndex = referenceNode->Begin();
+ referenceIndex < referenceNode->End(); referenceIndex++)
{
// Confirm that points do not identify themselves as neighbors
// in the monochromatic case.
- if (reference_node != query_node || reference_index != query_index)
+ if (referenceNode != queryNode || referenceIndex != queryIndex)
{
- arma::vec reference_point = references_.unsafe_col(reference_index);
+ arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
- double distance = kernel_.Evaluate(query_point, reference_point);
+ double distance = metric.Evaluate(queryPoint, referencePoint);
// If the reference point is closer than any of the current
// candidates, add it to the list.
- arma::vec query_dist = neighbor_distances_.unsafe_col(query_index);
- size_t insert_position = SortPolicy::SortDistance(query_dist,
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist,
distance);
- if (insert_position != (size_t() - 1))
- InsertNeighbor(query_index, insert_position, reference_index,
- distance);
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex,
+ distance, neighbors, distances);
}
}
}
// We need to find the upper bound distance for this query node
- if (SortPolicy::IsBetter(query_worst_distance,
- neighbor_distances_(knns_ - 1, query_index)))
- query_worst_distance = neighbor_distances_(knns_ - 1, query_index);
+ if (SortPolicy::IsBetter(queryWorstDistance,
+ distances(distances.n_rows - 1, queryIndex)))
+ queryWorstDistance = distances(distances.n_rows - 1, queryIndex);
}
- // Update the upper bound for the query_node
- query_node->Stat().bound_ = query_worst_distance;
+ // Update the upper bound for the queryNode
+ queryNode->Stat().Bound() = queryWorstDistance;
-} // ComputeBaseCase_
+} // ComputeBaseCase()
/**
* The recursive function for dual tree.
*/
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeDualNeighborsRecursion_(
- TreeType* query_node,
- TreeType* reference_node,
- double lower_bound)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeDualNeighborsRecursion(
+ TreeType* queryNode,
+ TreeType* referenceNode,
+ const double lowerBound,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
- if (SortPolicy::IsBetter(query_node->Stat().bound_, lower_bound))
+ if (SortPolicy::IsBetter(queryNode->Stat().Bound(), lowerBound))
{
- number_of_prunes_++; // Pruned by distance; the nodes cannot be any closer
- return; // than the already established lower bound.
+ numberOfPrunes++; // Pruned by distance; the nodes cannot be any closer
+ return; // than the already established lower bound.
}
- if (query_node->IsLeaf() && reference_node->IsLeaf())
+ if (queryNode->IsLeaf() && referenceNode->IsLeaf())
{
- ComputeBaseCase_(query_node, reference_node); // Base case: both are leaves.
+ // Base case: both are leaves.
+ ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
return;
}
- if (query_node->IsLeaf())
+ if (queryNode->IsLeaf())
{
// We must keep descending down the reference node to get to a leaf.
// We'll order the computation by distance; descend in the direction of less
// distance first.
- double left_distance = SortPolicy::BestNodeToNodeDistance(query_node,
- reference_node->Left());
- double right_distance = SortPolicy::BestNodeToNodeDistance(query_node,
- reference_node->Right());
+ double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
+ referenceNode->Left());
+ double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
+ referenceNode->Right());
- if (SortPolicy::IsBetter(left_distance, right_distance))
+ if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion_(query_node, reference_node->Left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node, reference_node->Right(),
- right_distance);
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
+ leftDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
+ rightDistance, neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion_(query_node, reference_node->Right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node, reference_node->Left(),
- left_distance);
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
+ rightDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
+ leftDistance, neighbors, distances);
}
return;
}
- if (reference_node->IsLeaf())
+ if (referenceNode->IsLeaf())
{
// We must descend down the query node to get to a leaf.
- double left_distance = SortPolicy::BestNodeToNodeDistance(
- query_node->Left(), reference_node);
- double right_distance = SortPolicy::BestNodeToNodeDistance(
- query_node->Right(), reference_node);
+ double leftDistance = SortPolicy::BestNodeToNodeDistance(
+ queryNode->Left(), referenceNode);
+ double rightDistance = SortPolicy::BestNodeToNodeDistance(
+ queryNode->Right(), referenceNode);
- ComputeDualNeighborsRecursion_(query_node->Left(), reference_node,
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->Right(), reference_node,
- right_distance);
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
+ leftDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
+ rightDistance, neighbors, distances);
- // We need to update the upper bound based on the new upper bounds of
- // the children
- double left_bound = query_node->Left()->Stat().bound_;
- double right_bound = query_node->Right()->Stat().bound_;
+ // We need to update the upper bound based on the new upper bounds of the
+ // children.
+ double leftBound = queryNode->Left()->Stat().Bound();
+ double rightBound = queryNode->Right()->Stat().Bound();
- if (SortPolicy::IsBetter(left_bound, right_bound))
- query_node->Stat().bound_ = right_bound;
+ if (SortPolicy::IsBetter(leftBound, rightBound))
+ queryNode->Stat().Bound() = rightBound;
else
- query_node->Stat().bound_ = left_bound;
+ queryNode->Stat().Bound() = leftBound;
return;
}
// Neither side is a leaf; so we recurse on all combinations of both. The
// calculations are ordered by distance.
- double left_distance = SortPolicy::BestNodeToNodeDistance(query_node->Left(),
- reference_node->Left());
- double right_distance = SortPolicy::BestNodeToNodeDistance(query_node->Left(),
- reference_node->Right());
+ double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
+ referenceNode->Left());
+ double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
+ referenceNode->Right());
- // Recurse on query_node->left() first.
- if (SortPolicy::IsBetter(left_distance, right_distance))
+ // Recurse on queryNode->Left() first.
+ if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_distance);
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
+ leftDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
+ rightDistance, neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_distance);
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
+ rightDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
+ leftDistance, neighbors, distances);
}
- left_distance = SortPolicy::BestNodeToNodeDistance(query_node->Right(),
- reference_node->Left());
- right_distance = SortPolicy::BestNodeToNodeDistance(query_node->Right(),
- reference_node->Right());
+ leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
+ referenceNode->Left());
+ rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
+ referenceNode->Right());
- // Now recurse on query_node->right().
- if (SortPolicy::IsBetter(left_distance, right_distance))
+ // Now recurse on queryNode->Right().
+ if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_distance);
- ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_distance);
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
+ leftDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
+ rightDistance, neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_distance);
- ComputeDualNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_distance);
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
+ rightDistance, neighbors, distances);
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
+ leftDistance, neighbors, distances);
}
// Update the upper bound as above
- double left_bound = query_node->Left()->Stat().bound_;
- double right_bound = query_node->Right()->Stat().bound_;
+ double leftBound = queryNode->Left()->Stat().Bound();
+ double rightBound = queryNode->Right()->Stat().Bound();
- if (SortPolicy::IsBetter(left_bound, right_bound))
- query_node->Stat().bound_ = right_bound;
+ if (SortPolicy::IsBetter(leftBound, rightBound))
+ queryNode->Stat().Bound() = rightBound;
else
- query_node->Stat().bound_ = left_bound;
+ queryNode->Stat().Bound() = leftBound;
-} // ComputeDualNeighborsRecursion_
+} // ComputeDualNeighborsRecursion()
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeSingleNeighborsRecursion_(
- size_t point_id,
- arma::vec& point,
- TreeType* reference_node,
- double& best_dist_so_far)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeSingleNeighborsRecursion(const size_t pointId,
+ const arma::vec& point,
+ TreeType* referenceNode,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
- if (reference_node->IsLeaf())
+ if (referenceNode->IsLeaf())
{
// Base case: reference node is a leaf.
- for (size_t reference_index = reference_node->Begin();
- reference_index < reference_node->End(); reference_index++)
+ for (size_t referenceIndex = referenceNode->Begin();
+ referenceIndex < referenceNode->End(); referenceIndex++)
{
// Confirm that points do not identify themselves as neighbors
// in the monochromatic case
- if (!(references_.memptr() == queries_.memptr() &&
- reference_index == point_id))
+ if (!(referenceSet.memptr() == querySet.memptr() &&
+ referenceIndex == pointId))
{
- arma::vec reference_point = references_.unsafe_col(reference_index);
+ arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
- double distance = kernel_.Evaluate(point, reference_point);
+ double distance = metric.Evaluate(point, referencePoint);
// If the reference point is better than any of the current candidates,
// insert it into the list correctly.
- arma::vec query_dist = neighbor_distances_.unsafe_col(point_id);
- size_t insert_position = SortPolicy::SortDistance(query_dist,
- distance);
+ arma::vec queryDist = distances.unsafe_col(pointId);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
- if (insert_position != (size_t() - 1))
- InsertNeighbor(point_id, insert_position, reference_index, distance);
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
+ neighbors, distances);
}
- } // for reference_index
+ } // for referenceIndex
- best_dist_so_far = neighbor_distances_(knns_ - 1, point_id);
+ bestDistSoFar = distances(distances.n_rows - 1, pointId);
}
else
{
// We'll order the computation by distance.
- double left_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->Left());
- double right_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->Right());
+ double leftDistance = SortPolicy::BestPointToNodeDistance(point,
+ referenceNode->Left());
+ double rightDistance = SortPolicy::BestPointToNodeDistance(point,
+ referenceNode->Right());
// Recurse in the best direction first.
- if (SortPolicy::IsBetter(left_distance, right_distance))
+ if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
+ if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
+ numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->Left(), best_dist_so_far);
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+ bestDistSoFar, neighbors, distances);
- if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
+ if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
+ numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->Right(), best_dist_so_far);
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+ bestDistSoFar, neighbors, distances);
}
else
{
- if (SortPolicy::IsBetter(best_dist_so_far, right_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
+ if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
+ numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->Right(), best_dist_so_far);
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+ bestDistSoFar, neighbors, distances);
- if (SortPolicy::IsBetter(best_dist_so_far, left_distance))
- number_of_prunes_++; // Prune; no possibility of finding a better point.
+ if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
+ numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion_(point_id, point,
- reference_node->Left(), best_dist_so_far);
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+ bestDistSoFar, neighbors, distances);
}
}
}
/**
- * Computes the best neighbors and stores them in resulting_neighbors and
- * distances.
- */
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::ComputeNeighbors(
- arma::Mat<size_t>& resulting_neighbors,
- arma::mat& distances)
-{
- Timers::StartTimer("neighbor_search/computing_neighbors");
- if (naive_)
- {
- // Run the base case computation on all nodes
- if (query_tree_)
- ComputeBaseCase_(query_tree_, reference_tree_);
- else
- ComputeBaseCase_(reference_tree_, reference_tree_);
- }
- else
- {
- if (dual_mode_)
- {
- // Start on the root of each tree
- if (query_tree_)
- {
- ComputeDualNeighborsRecursion_(query_tree_, reference_tree_,
- SortPolicy::BestNodeToNodeDistance(query_tree_, reference_tree_));
- }
- else
- {
- ComputeDualNeighborsRecursion_(reference_tree_, reference_tree_,
- SortPolicy::BestNodeToNodeDistance(reference_tree_,
- reference_tree_));
- }
- }
- else
- {
- size_t chunk = queries_.n_cols / 10;
-
- for (size_t i = 0; i < 10; i++)
- {
- for (size_t j = 0; j < chunk; j++)
- {
- arma::vec point = queries_.unsafe_col(i * chunk + j);
- double best_dist_so_far = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion_(i * chunk + j, point,
- reference_tree_, best_dist_so_far);
- }
- }
-
- for (size_t i = 0; i < queries_.n_cols % 10; i++)
- {
- size_t ind = (queries_.n_cols / 10) * 10 + i;
- arma::vec point = queries_.unsafe_col(ind);
- double best_dist_so_far = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion_(ind, point, reference_tree_,
- best_dist_so_far);
- }
- }
- }
-
- Timers::StopTimer("neighbor_search/computing_neighbors");
-
- // We need to initialize the results list before filling it
- resulting_neighbors.set_size(neighbor_indices_.n_rows,
- neighbor_indices_.n_cols);
- distances.set_size(neighbor_distances_.n_rows, neighbor_distances_.n_cols);
-
- // We need to map the indices back from how they have been permuted
- if (query_tree_ != NULL)
- {
- for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
- {
- for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
- {
- resulting_neighbors(k, old_from_new_queries_[i]) =
- old_from_new_references_[neighbor_indices_(k, i)];
- distances(k, old_from_new_queries_[i]) = neighbor_distances_(k, i);
- }
- }
- }
- else
- {
- for (size_t i = 0; i < neighbor_indices_.n_cols; i++)
- {
- for (size_t k = 0; k < neighbor_indices_.n_rows; k++)
- {
- resulting_neighbors(k, old_from_new_references_[i]) =
- old_from_new_references_[neighbor_indices_(k, i)];
- distances(k, old_from_new_references_[i]) = neighbor_distances_(k, i);
- }
- }
- }
-} // ComputeNeighbors
-
-/***
* Helper function to insert a point into the neighbors and distances matrices.
*
- * @param query_index Index of point whose neighbors we are inserting into.
+ * @param queryIndex Index of point whose neighbors we are inserting into.
* @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename MetricType, typename SortPolicy>
-void NeighborSearch<MetricType, SortPolicy>::InsertNeighbor(size_t query_index,
- size_t pos,
- size_t neighbor,
- double distance)
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
// We only memmove() if there is actually a need to shift something.
- if (pos < (knns_ - 1))
+ if (pos < (distances.n_rows - 1))
{
- int len = (knns_ - 1) - pos;
- memmove(neighbor_distances_.colptr(query_index) + (pos + 1),
- neighbor_distances_.colptr(query_index) + pos,
+ int len = (distances.n_rows - 1) - pos;
+ memmove(distances.colptr(queryIndex) + (pos + 1),
+ distances.colptr(queryIndex) + pos,
sizeof(double) * len);
- memmove(neighbor_indices_.colptr(query_index) + (pos + 1),
- neighbor_indices_.colptr(query_index) + pos,
+ memmove(neighbors.colptr(queryIndex) + (pos + 1),
+ neighbors.colptr(queryIndex) + pos,
sizeof(size_t) * len);
}
// Now put the new information in the right index.
- neighbor_distances_(pos, query_index) = distance;
- neighbor_indices_(pos, query_index) = neighbor;
+ distances(pos, queryIndex) = distance;
+ neighbors(pos, queryIndex) = neighbor;
}
#endif
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp 2011-12-02 01:00:03 UTC (rev 10498)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.hpp 2011-12-02 02:59:07 UTC (rev 10499)
@@ -26,7 +26,7 @@
* neighbors. Squared distances are used because they are slightly faster than
* non-squared distances (they have one fewer call to sqrt()).
*/
-typedef NeighborSearch<metric::SquaredEuclideanDistance, NearestNeighborSort>
+typedef NeighborSearch<NearestNeighborSort, metric::SquaredEuclideanDistance>
AllkNN;
/**
@@ -35,7 +35,7 @@
* neighbors. Squared distances are used because they are slightly faster than
* non-squared distances (they have one fewer call to sqrt()).
*/
-typedef NeighborSearch<metric::SquaredEuclideanDistance, FurthestNeighborSort>
+typedef NeighborSearch<FurthestNeighborSort, metric::SquaredEuclideanDistance>
AllkFN;
}; // namespace neighbor
More information about the mlpack-svn
mailing list