[mlpack-svn] r15624 - mlpack/trunk/src/mlpack/methods/range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Aug 15 15:17:27 EDT 2013
Author: rcurtin
Date: Thu Aug 15 15:17:27 2013
New Revision: 15624
Log:
Add --cover_tree option to range_search.
Modified:
mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp Thu Aug 15 15:17:27 2013
@@ -8,6 +8,7 @@
*/
#include <mlpack/core.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include "range_search.hpp"
@@ -60,10 +61,12 @@
PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
+ "dual-tree search).", "s");
+PARAM_FLAG("cover_tree", "If true, use a cover tree for range searching "
+ "(instead of a kd-tree).", "c");
-typedef RangeSearch<metric::SquaredEuclideanDistance,
- BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > RSType;
+typedef RangeSearch<> RSType;
+typedef RangeSearch<metric::EuclideanDistance, CoverTree<> > RSCoverType;
int main(int argc, char *argv[])
{
@@ -81,8 +84,9 @@
double max = CLI::GetParam<double>("max");
double min = CLI::GetParam<double>("min");
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
+ const bool naive = CLI::HasParam("naive");
+ const bool singleMode = CLI::HasParam("single_mode");
+ bool coverTree = CLI::HasParam("cover_tree");
arma::mat referenceData;
arma::mat queryData; // So it doesn't go out of scope.
@@ -115,117 +119,165 @@
if (naive)
leafSize = referenceData.n_cols;
+ if (coverTree && naive)
+ {
+ Log::Warn << "--cover_tree ignored because --naive is present." << endl;
+ coverTree = false;
+ }
+
vector<vector<size_t> > neighbors;
vector<vector<double> > distances;
- // Because we may construct it differently, we need a pointer.
- RSType* rangeSearch = NULL;
+ // The cover tree implies different types, so we must split this section.
+ // Test coverTree (not the raw CLI flag), since --naive may have cleared it.
+ if (coverTree)
+ {
+ Log::Info << "Using cover trees." << endl;
- // Mappings for when we build the tree.
- vector<size_t> oldFromNewRefs;
+ // This is significantly simpler than kd-tree construction because the data
+ // matrix is not modified.
+ RSCoverType* rangeSearch = NULL;
+ CoverTree<> referenceTree(referenceData);
+ CoverTree<>* queryTree = NULL;
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
+ if (CLI::GetParam<string>("query_file") == "")
+ {
+ // Single dataset.
+ rangeSearch = new RSCoverType(&referenceTree, referenceData, singleMode);
+ }
+ else
+ {
+ // Two datasets.
+ const string queryFile = CLI::GetParam<string>("query_file");
+ data::Load(queryFile, queryData, true);
+ queryTree = new CoverTree<>(queryData);
- BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
- queryTree = NULL; // Empty for now.
+ rangeSearch = new RSCoverType(&referenceTree, queryTree, referenceData,
+ queryData, singleMode);
+ }
- Timer::Stop("tree_building");
+ Log::Info << "Trees built." << endl;
- std::vector<size_t> oldFromNewQueries;
+ const math::Range r(min, max);
+ rangeSearch->Search(r, neighbors, distances);
- if (CLI::GetParam<string>("query_file") != "")
+ if (queryTree)
+ delete queryTree;
+ delete rangeSearch;
+ }
+ else
{
- string queryFile = CLI::GetParam<string>("query_file");
+ // Because we may construct it differently, we need a pointer.
+ RSType* rangeSearch = NULL;
- if (!data::Load(queryFile, queryData))
- Log::Fatal << "Query file " << queryFile << " not found" << endl;
-
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
-
- Log::Info << "Building query tree..." << endl;
+ // Mappings for when we build the tree.
+ vector<size_t> oldFromNewRefs;
// Build trees by hand, so we can save memory: if we pass a tree to
// NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
Timer::Start("tree_building");
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- tree::EmptyStatistic >(queryData, oldFromNewQueries,
- leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
+ queryTree = NULL; // Empty for now.
Timer::Stop("tree_building");
- rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
- singleMode);
+ vector<size_t> oldFromNewQueries;
- Log::Info << "Tree built." << endl;
- }
- else
- {
- rangeSearch = new RSType(&refTree, referenceData, singleMode);
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ const string queryFile = CLI::GetParam<string>("query_file");
+ data::Load(queryFile, queryData, true);
- Log::Info << "Trees built." << endl;
- }
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
- Log::Info << "Computing neighbors within range [" << min << ", " << max
- << "]." << endl;
+ Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
- math::Range r = math::Range(min, max);
- rangeSearch->Search(r, neighbors, distances);
+ Log::Info << "Building query tree..." << endl;
- Log::Info << "Neighbors computed." << endl;
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("tree_building");
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ tree::EmptyStatistic >(queryData, oldFromNewQueries,
+ leafSize);
- vector<vector<double> > distancesOut;
- distancesOut.resize(distances.size());
- vector<vector<size_t> > neighborsOut;
- neighborsOut.resize(neighbors.size());
+ Timer::Stop("tree_building");
- // Do the actual remapping.
- if (CLI::GetParam<string>("query_file") != "")
- {
- for (size_t i = 0; i < distances.size(); ++i)
+ rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+ Log::Info << "Tree built." << endl;
+ }
+ else
{
- // Map distances (copy a column).
- distancesOut[oldFromNewQueries[i]] = distances[i];
+ rangeSearch = new RSType(&refTree, referenceData, singleMode);
+
+ Log::Info << "Trees built." << endl;
+ }
+
+ Log::Info << "Computing neighbors within range [" << min << ", " << max
+ << "]." << endl;
+
+ // Collect the results in these vectors before remapping.
+ vector<vector<double> > distancesOut;
+ vector<vector<size_t> > neighborsOut;
- // Map indices of neighbors.
- neighborsOut[oldFromNewQueries[i]].resize(neighbors[i].size());
- for (size_t j = 0; j < distances[i].size(); ++j)
+ const math::Range r(min, max);
+ rangeSearch->Search(r, neighborsOut, distancesOut);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ distances.resize(distancesOut.size());
+ neighbors.resize(neighborsOut.size());
+
+ // Do the actual remapping.
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ for (size_t i = 0; i < distances.size(); ++i)
{
- neighborsOut[oldFromNewQueries[i]][j] = oldFromNewRefs[neighbors[i][j]];
+ // Map distances (copy a column).
+ distances[oldFromNewQueries[i]] = distancesOut[i];
+
+ // Map indices of neighbors.
+ neighbors[oldFromNewQueries[i]].resize(neighborsOut[i].size());
+ for (size_t j = 0; j < distancesOut[i].size(); ++j)
+ {
+ neighbors[oldFromNewQueries[i]][j] =
+ oldFromNewRefs[neighborsOut[i][j]];
+ }
}
}
- }
- else
- {
- for (size_t i = 0; i < distances.size(); ++i)
+ else
{
- // Map distances (copy a column).
- distancesOut[oldFromNewRefs[i]] = distances[i];
-
- // Map indices of neighbors.
- neighborsOut[oldFromNewRefs[i]].resize(neighbors[i].size());
- for (size_t j = 0; j < distances[i].size(); ++j)
+ for (size_t i = 0; i < distances.size(); ++i)
{
- neighborsOut[oldFromNewRefs[i]][j] = oldFromNewRefs[neighbors[i][j]];
+ // Map distances (copy a column).
+ distances[oldFromNewRefs[i]] = distancesOut[i];
+
+ // Map indices of neighbors.
+ neighbors[oldFromNewRefs[i]].resize(neighborsOut[i].size());
+ for (size_t j = 0; j < distancesOut[i].size(); ++j)
+ {
+ neighbors[oldFromNewRefs[i]][j] = oldFromNewRefs[neighborsOut[i][j]];
+ }
}
}
- }
- // Clean up.
- if (queryTree)
- delete queryTree;
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+ delete rangeSearch;
+ }
// Save output. We have to do this by hand.
fstream distancesStr(distancesFile.c_str(), fstream::out);
@@ -237,17 +289,17 @@
else
{
// Loop over each point.
- for (size_t i = 0; i < distancesOut.size(); ++i)
+ for (size_t i = 0; i < distances.size(); ++i)
{
// Store the distances of each point. We may have 0 points to store, so
// we must account for that possibility.
- for (size_t j = 0; j + 1 < distancesOut[i].size(); ++j)
+ for (size_t j = 0; j + 1 < distances[i].size(); ++j)
{
- distancesStr << distancesOut[i][j] << ", ";
+ distancesStr << distances[i][j] << ", ";
}
- if (distancesOut[i].size() > 0)
- distancesStr << distancesOut[i][distancesOut[i].size() - 1];
+ if (distances[i].size() > 0)
+ distancesStr << distances[i][distances[i].size() - 1];
distancesStr << endl;
}
@@ -264,23 +316,21 @@
else
{
// Loop over each point.
- for (size_t i = 0; i < neighborsOut.size(); ++i)
+ for (size_t i = 0; i < neighbors.size(); ++i)
{
// Store the neighbors of each point. We may have 0 points to store, so
// we must account for that possibility.
- for (size_t j = 0; j + 1 < neighborsOut[i].size(); ++j)
+ for (size_t j = 0; j + 1 < neighbors[i].size(); ++j)
{
- neighborsStr << neighborsOut[i][j] << ", ";
+ neighborsStr << neighbors[i][j] << ", ";
}
- if (neighborsOut[i].size() > 0)
- neighborsStr << neighborsOut[i][neighborsOut[i].size() - 1];
+ if (neighbors[i].size() > 0)
+ neighborsStr << neighbors[i][neighbors[i].size() - 1];
neighborsStr << endl;
}
neighborsStr.close();
}
-
- delete rangeSearch;
}
More information about the mlpack-svn mailing list