[mlpack-svn] r11440 - in mlpack/trunk/src/mlpack: methods/range_search tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 8 15:48:31 EST 2012
Author: mamidon
Date: 2012-02-08 15:48:30 -0500 (Wed, 08 Feb 2012)
New Revision: 11440
Added:
mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp
Modified:
mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
mlpack/trunk/src/mlpack/tests/cli_test.cpp
Log:
Implemented range search, need to implement output functionality.
Also need to update the documentation to reflect realities of range search.
Modified: mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt 2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt 2012-02-08 20:48:30 UTC (rev 11440)
@@ -5,6 +5,7 @@
set(SOURCES
range_search.hpp
range_search_impl.hpp
+ range_main.cpp
)
# Add directory name to sources.
@@ -15,3 +16,11 @@
# Append sources (with directory name) to list of all MLPACK sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(rangesearch
+ range_main.cpp
+)
+
+target_link_libraries(rangesearch
+ mlpack
+)
Added: mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_main.cpp 2012-02-08 20:48:30 UTC (rev 11440)
@@ -0,0 +1,223 @@
+/**
+ * @file range_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the range search executable. Allows some number of
+ * standard options.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "range_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::tree;
+
+// Information about the program itself.  The adjacent string literals below
+// are concatenated by the compiler, so each fragment must carry its own
+// trailing space.
+PROGRAM_INFO("Range Search",
+    "This program will calculate the all nearest-neighbors of a set of "
+    "points constrained by a range. You may specify a separate set of reference points and query "
+    "points, or just a reference set which will be used as both the reference "
+    "and query set."
+    "\n\n"
+    "For example, the following will calculate nearest neighbors within 5 units of each "
+    "point in 'input.csv' and store the distances in 'distances.csv' and the "
+    "neighbors in the file 'neighbors.csv':"
+    "\n\n"
+    "$ rangesearch --min=0 --max=5 --reference_file=input.csv --distances_file=distances.csv\n"
+    " --neighbors_file=neighbors.csv"
+    "\n\n"
+    "The output files are organized such that row i and column j in the "
+    "neighbors output file corresponds to the index of the point in the "
+    "reference set which is the i'th nearest neighbor from the point in the "
+    "query set with index j. Row i and column j in the distances output file "
+    "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+
+// Input and output files (all required).
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+    "r");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+// Bounds of the search range.  Only the upper bound is required; the lower
+// bound defaults to 0.
+PARAM_DOUBLE_REQ("max", "Upper bound of the search range.", "A");
+PARAM_DOUBLE("min", "Lower bound of the search range.", "I", 0.0);
+
+// Optional separate query set; an empty string means "not given".
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+// Tree construction and search strategy options.
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "s");
+
+// The RangeSearch instantiation used by this executable: the metric is
+// SquaredEuclideanDistance and the tree is a BinarySpaceTree with
+// HRectBound<2> bounds and the EmptyStatistic statistic type.
+// NOTE(review): the metric is the *squared* Euclidean distance, so the
+// [min, max] range given on the command line is presumably compared against
+// squared distances -- confirm against RangeSearch::Search().
+typedef RangeSearch<metric::SquaredEuclideanDistance,
+ BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > AllInRange;
+
+/**
+ * Entry point of the range search executable.
+ *
+ * Parses the command line, loads the reference set (and the query set, if
+ * --query_file was given), builds the tree(s) by hand, and runs the range
+ * search for the [min, max] range.  Writing the results to the distances and
+ * neighbors files is not implemented yet; the old matrix-based remapping code
+ * is kept below, commented out, as a starting point.
+ *
+ * @param argc Number of command line arguments.
+ * @param argv Command line arguments.
+ */
+int main(int argc, char *argv[])
+{
+  // Give CLI the command line parameters the user passed in.
+  CLI::ParseCommandLine(argc, argv);
+
+  // Get all the parameters.
+  string referenceFile = CLI::GetParam<string>("reference_file");
+
+  // Not used yet: saving the results is still to be implemented (see the
+  // commented-out block near the end of this function).
+  string distancesFile = CLI::GetParam<string>("distances_file");
+  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+  int lsInt = CLI::GetParam<int>("leaf_size");
+
+  // "max" and "min" are declared with PARAM_DOUBLE_REQ / PARAM_DOUBLE, so they
+  // must be retrieved as doubles, not ints.
+  double max = CLI::GetParam<double>("max");
+  double min = CLI::GetParam<double>("min");
+
+  bool naive = CLI::HasParam("naive");
+  bool singleMode = CLI::HasParam("single_mode");
+
+  arma::mat referenceData;
+  arma::mat queryData; // So it doesn't go out of scope.
+  if (!data::Load(referenceFile.c_str(), referenceData))
+    Log::Fatal << "Reference file " << referenceFile << " not found." << endl;
+
+  Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
+
+  // Sanity check on range value: max must be greater than min.  The message
+  // is terminated with endl like every other fatal message in this file.
+  // NOTE(review): mlpack's fatal stream presumably only acts once the line is
+  // complete -- confirm against the Log documentation.
+  if (max <= min)
+  {
+    Log::Fatal << "Invalid range: maximum (" << max << ") must be greater "
+        "than minimum (" << min << ")." << endl;
+  }
+
+  // Sanity check on leaf size.
+  if (lsInt < 0)
+  {
+    Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+        "than or equal to 0." << endl;
+  }
+  size_t leafSize = lsInt;
+
+  // Naive mode overrides single mode.
+  if (singleMode && naive)
+  {
+    Log::Warn << "--single_mode ignored because --naive is present." << endl;
+  }
+
+  // In naive mode the leaf size is set to the number of reference points.
+  if (naive)
+    leafSize = referenceData.n_cols;
+
+  // Search results: one inner vector per query point.
+  std::vector<std::vector<size_t> > neighbors;
+  std::vector<std::vector<double> > distances;
+
+  // Because we may construct it differently, we need a pointer.
+  AllInRange* rangeSearch = NULL;
+
+  // Mappings for when we build the tree.
+  std::vector<size_t> oldFromNewRefs;
+
+  // Build trees by hand, so we can save memory: if we pass a tree to
+  // RangeSearch, it does not copy the matrix.
+  Log::Info << "Building reference tree..." << endl;
+  Timer::Start("tree_building");
+
+  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
+      refTree(referenceData, oldFromNewRefs, leafSize);
+  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
+      queryTree = NULL; // Empty for now.
+
+  Timer::Stop("tree_building");
+
+  std::vector<size_t> oldFromNewQueries;
+
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    string queryFile = CLI::GetParam<string>("query_file");
+
+    if (!data::Load(queryFile.c_str(), queryData))
+      Log::Fatal << "Query file " << queryFile << " not found" << endl;
+
+    // In naive mode the leaf size must also cover the whole query set.
+    if (naive && leafSize < queryData.n_cols)
+      leafSize = queryData.n_cols;
+
+    Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
+
+    Log::Info << "Building query tree..." << endl;
+
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // RangeSearch, it does not copy the matrix.
+    Timer::Start("tree_building");
+
+    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+        tree::EmptyStatistic >(queryData, oldFromNewQueries,
+        leafSize);
+
+    Timer::Stop("tree_building");
+
+    rangeSearch = new AllInRange(&refTree, queryTree, referenceData,
+        queryData, singleMode);
+
+    // Both the reference tree and the query tree exist at this point.
+    Log::Info << "Trees built." << endl;
+  }
+  else
+  {
+    rangeSearch = new AllInRange(&refTree, referenceData, singleMode);
+
+    // Only the reference tree was built.
+    Log::Info << "Tree built." << endl;
+  }
+
+  Log::Info << "Computing neighbors within [" << min << ", " << max << "]." << endl;
+
+  math::Range r(min, max);
+  rangeSearch->Search(r, neighbors, distances);
+
+  Log::Info << "Neighbors computed." << endl;
+
+  // We have to map back to the original indices from before the tree
+  // construction.
+  Log::Info << "Re-mapping indices..." << endl;
+
+  // TODO: remapping and output are not implemented yet.  The block below is
+  // the matrix-based code this file was derived from; it does not apply to the
+  // vector-of-vectors results produced here and is kept only as a reference.
+//  arma::mat distancesOut(distances.n_rows, distances.n_cols);
+//  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+  /*
+  // Do the actual remapping.
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+  else
+  {
+    for (size_t i = 0; i < distances.n_cols; ++i)
+    {
+      // Map distances (copy a column).
+      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+
+      // Map indices of neighbors.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+      {
+        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
+      }
+    }
+  }
+
+  // Save output.
+  data::Save(distancesFile, distances);
+  data::Save(neighborsFile, neighbors);
+*/
+
+  // Clean up.  The searcher is released first because it was handed pointers
+  // to the trees; the query tree (if one was allocated) is released afterwards
+  // so it no longer leaks.  Deleting a NULL pointer is a no-op.
+  delete rangeSearch;
+  delete queryTree;
+}
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2012-02-08 20:48:30 UTC (rev 11440)
@@ -106,8 +106,8 @@
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
- RangeSearch(const TreeType* referenceTree,
- const TreeType* queryTree,
+ RangeSearch(TreeType* referenceTree,
+ TreeType* queryTree,
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
const bool singleMode = false,
@@ -140,7 +140,7 @@
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
- RangeSearch(const TreeType* referenceTree,
+ RangeSearch(TreeType* referenceTree,
const typename TreeType::Mat& referenceSet,
const bool singleMode = false,
const MetricType metric = MetricType());
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp 2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp 2012-02-08 20:48:30 UTC (rev 11440)
@@ -75,8 +75,8 @@
template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::RangeSearch(
- const TreeType* referenceTree,
- const TreeType* queryTree,
+ TreeType* referenceTree,
+ TreeType* queryTree,
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
const bool singleMode,
@@ -97,7 +97,7 @@
template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::RangeSearch(
- const TreeType* referenceTree,
+ TreeType* referenceTree,
const typename TreeType::Mat& referenceSet,
const bool singleMode,
const MetricType metric) :
Modified: mlpack/trunk/src/mlpack/tests/cli_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/cli_test.cpp 2012-02-08 20:14:09 UTC (rev 11439)
+++ mlpack/trunk/src/mlpack/tests/cli_test.cpp 2012-02-08 20:48:30 UTC (rev 11440)
@@ -192,7 +192,7 @@
PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
pss << "I have a precise number which is ";
- pss << std::setw(6) << std::setfill('0') << 156;
+ pss << std::setw(6) << std::setfill('0') << (int)156;
BOOST_REQUIRE_EQUAL(ss.str(),
BASH_GREEN "[INFO ] " BASH_CLEAR
More information about the mlpack-svn
mailing list