[mlpack-svn] r12663 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:25:47 EDT 2012
Author: rcurtin
Date: 2012-05-09 16:25:47 -0400 (Wed, 09 May 2012)
New Revision: 12663
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
Log:
Add an option to use cover trees.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2012-05-09 20:25:01 UTC (rev 12662)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2012-05-09 20:25:47 UTC (rev 12663)
@@ -6,6 +6,7 @@
* options.
*/
#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include <string>
#include <fstream>
@@ -52,6 +53,8 @@
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search.", "s");
+PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search.",
+ "c");
int main(int argc, char *argv[])
{
@@ -107,111 +110,170 @@
arma::Mat<size_t> neighbors;
arma::mat distances;
- // Because we may construct it differently, we need a pointer.
- AllkNN* allknn = NULL;
+ if (!CLI::HasParam("cover_tree"))
+ {
+ // Because we may construct it differently, we need a pointer.
+ AllkNN* allknn = NULL;
- // Mappings for when we build the tree.
- std::vector<size_t> oldFromNewRefs;
+ // Mappings for when we build the tree.
+ std::vector<size_t> oldFromNewRefs;
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
queryTree = NULL; // Empty for now.
- Timer::Stop("tree_building");
+ Timer::Stop("tree_building");
- std::vector<size_t> oldFromNewQueries;
+ std::vector<size_t> oldFromNewQueries;
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile.c_str(), queryData, true);
+ data::Load(queryFile.c_str(), queryData, true);
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
- Log::Info << "Loaded query data from '" << queryFile << "' ("
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
<< queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- Log::Info << "Building query tree..." << endl;
+ Log::Info << "Building query tree..." << endl;
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("tree_building");
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("tree_building");
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
- leafSize);
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
+ leafSize);
- Timer::Stop("tree_building");
+ Timer::Stop("tree_building");
- allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
- singleMode);
+ allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allknn = new AllkNN(&refTree, referenceData, singleMode);
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ allknn = new AllkNN(&refTree, referenceData, singleMode);
- Log::Info << "Trees built." << endl;
- }
+ Log::Info << "Trees built." << endl;
+ }
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->Search(k, neighbors, distances);
+ arma::mat distancesOut;
+ arma::Mat<size_t> neighborsOut;
- Log::Info << "Neighbors computed." << endl;
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighborsOut, distancesOut);
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
+ Log::Info << "Neighbors computed." << endl;
- arma::mat distancesOut(distances.n_rows, distances.n_cols);
- arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
- // Do the actual remapping.
- if (CLI::GetParam<string>("query_file") != "")
- {
- for (size_t i = 0; i < distances.n_cols; ++i)
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+
+ // Do the actual remapping.
+ if (CLI::GetParam<string>("query_file") != "")
{
- // Map distances (copy a column).
- distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; ++j)
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewQueries[i]) =
+ oldFromNewRefs[neighborsOut(j, i)];
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
{
- neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+ // Map distances (copy a column).
+ distances.col(oldFromNewRefs[i]) = distancesOut.col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
+ }
}
}
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ delete allknn;
}
- else
+ else // Cover trees.
{
- for (size_t i = 0; i < distances.n_cols; ++i)
+ // Build our reference tree.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+ CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > referenceTree(referenceData);
+ CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >* queryTree = NULL;
+ Timer::Stop("tree_building");
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+ CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >* allknn = NULL;
+
+ // See if we have query data.
+ if (CLI::HasParam("query_file"))
{
- // Map distances (copy a column).
- distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+ string queryFile = CLI::GetParam<string>("query_file");
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; ++j)
- {
- neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
- }
+ data::Load(queryFile, queryData, true);
+
+ // Build query tree.
+ Log::Info << "Building query tree..." << endl;
+ Timer::Start("tree_building");
+ queryTree = new CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >(queryData);
+ Timer::Stop("tree_building");
+
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+ CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
+ referenceData, queryData, true);
}
+ else
+ {
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+ CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
+ true);
+ }
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ delete allknn;
+
+ if (queryTree)
+ delete queryTree;
}
- // Clean up.
- if (queryTree)
- delete queryTree;
-
// Save output.
- data::Save(distancesFile, distancesOut);
- data::Save(neighborsFile, neighborsOut);
-
- delete allknn;
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
}
More information about the mlpack-svn mailing list