[mlpack-svn] r14260 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Feb 9 20:35:52 EST 2013
Author: rcurtin
Date: 2013-02-09 20:35:52 -0500 (Sat, 09 Feb 2013)
New Revision: 14260
Added:
mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
Log:
Abstract away the unmapping of points. Also fix a few bugs in AllkFN because
that executable was really not kept up to date with the AllkNN executable.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt 2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt 2013-02-10 01:35:52 UTC (rev 14260)
@@ -12,6 +12,8 @@
sort_policies/furthest_neighbor_sort.cpp
sort_policies/furthest_neighbor_sort_impl.hpp
typedef.hpp
+ unmap.hpp
+ unmap.cpp
)
# Add directory name to sources.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-02-10 01:35:52 UTC (rev 14260)
@@ -172,36 +172,17 @@
arma::mat distancesOut(distances.n_rows, distances.n_cols);
arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
- // Do the actual remapping.
- if (CLI::GetParam<string>("query_file") != "")
- {
- for (size_t i = 0; i < distances.n_cols; ++i)
- {
- // Map distances (copy a column).
- distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; ++j)
- {
- neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
- }
- }
- }
+ // Map the points back to their original locations.
+ if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+ Unmap(neighbors, distances, oldFromNewReferences, oldFromNewQueries,
+ neighborsOut, distancesOut, true);
+ else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+ Unmap(neighbors, distances, oldFromNewReferences, neighborsOut,
+ distancesOut, true);
else
- {
- for (size_t i = 0; i < distances.n_cols; ++i)
- {
- // Map distances (copy a column).
- distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+ Unmap(neighbors, distances, oldFromNewReferences, oldFromNewReferences,
+ neighborsOut, distancesOut, true);
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; ++j)
- {
- neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
- }
- }
- }
-
// Clean up.
if (queryTree)
delete queryTree;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-02-09 23:06:02 UTC (rev 14259)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-02-10 01:35:52 UTC (rev 14260)
@@ -13,6 +13,7 @@
#include <iostream>
#include "neighbor_search.hpp"
+#include "unmap.hpp"
using namespace std;
using namespace mlpack;
@@ -53,20 +54,29 @@
PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
+ "dual-tree search).", "s");
PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
"(experimental, may be slow).", "c");
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+    "random orthogonal basis.", "R");
+// No short alias here: "s" already belongs to --single_mode above.
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "", 0);
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
CLI::ParseCommandLine(argc, argv);
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
// Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+ const string queryFile = CLI::GetParam<string>("query_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
+ const string distancesFile = CLI::GetParam<string>("distances_file");
+ const string neighborsFile = CLI::GetParam<string>("neighbors_file");
int lsInt = CLI::GetParam<int>("leaf_size");
@@ -74,14 +84,22 @@
bool naive = CLI::HasParam("naive");
bool singleMode = CLI::HasParam("single_mode");
+ const bool randomBasis = CLI::HasParam("random_basis");
arma::mat referenceData;
arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
+ data::Load(referenceFile, referenceData, true);
Log::Info << "Loaded reference data from '" << referenceFile << "' ("
<< referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+ if (queryFile != "")
+ {
+ data::Load(queryFile, queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ }
+
// Sanity check on k value: must be greater than 0, must be less than the
// number of reference points.
if (k > referenceData.n_cols)
@@ -108,6 +126,43 @@
if (naive)
leafSize = referenceData.n_cols;
+ // See if we want to project onto a random basis.
+ if (randomBasis)
+ {
+ // Generate the random basis.
+ while (true)
+ {
+ // [Q, R] = qr(randn(d, d));
+ // Q = Q * diag(sign(diag(R)));
+ arma::mat q, r;
+ if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
+ referenceData.n_rows)))
+ {
+ arma::vec rDiag(r.n_rows);
+ for (size_t i = 0; i < rDiag.n_elem; ++i)
+ {
+ if (r(i, i) < 0)
+ rDiag(i) = -1;
+ else if (r(i, i) > 0)
+ rDiag(i) = 1;
+ else
+ rDiag(i) = 0;
+ }
+
+ q *= arma::diagmat(rDiag);
+
+ // Check if the determinant is positive.
+ if (arma::det(q) >= 0)
+ {
+ referenceData = q * referenceData;
+ if (queryFile != "")
+ queryData = q * queryData;
+ break;
+ }
+ }
+ }
+ }
+
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -125,9 +180,9 @@
Timer::Start("tree_building");
BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
+ refTree(referenceData, oldFromNewRefs, leafSize);
BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
- queryTree = NULL; // Empty for now.
+ queryTree = NULL; // Empty for now.
Timer::Stop("tree_building");
@@ -135,15 +190,11 @@
if (CLI::GetParam<string>("query_file") != "")
{
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
-
if (naive && leafSize < queryData.n_cols)
leafSize = queryData.n_cols;
Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
Log::Info << "Building query tree..." << endl;
@@ -184,51 +235,17 @@
// construction.
Log::Info << "Re-mapping indices..." << endl;
- neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
- distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-
- // Do the actual remapping.
+ // Map the results back to the correct places.
if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column) and square root.
- distances.col(oldFromNewQueries[i]) = sqrt(distancesOut.col(i));
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewQueries[i]) =
- oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
+ neighbors, distances, true);
else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- {
- // No remapping of queries is necessary. So distances are the same.
- distances = sqrt(distancesOut);
-
- // The neighbor indices must be mapped.
- for (size_t j = 0; j < neighborsOut.n_elem; ++j)
- {
- neighbors[j] = oldFromNewRefs[neighborsOut[j]];
- }
- }
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances,
+ true);
else
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewRefs[i]) = sqrt(distancesOut.col(i));
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
+ neighbors, distances, true);
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
-
// Clean up.
if (queryTree)
delete queryTree;
@@ -256,10 +273,6 @@
// See if we have query data.
if (CLI::HasParam("query_file"))
{
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile, queryData, true);
-
// Build query tree.
if (!singleMode)
{
Added: mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.cpp 2013-02-10 01:35:52 UTC (rev 14260)
@@ -0,0 +1,64 @@
+/**
+ * @file unmap.cpp
+ * @author Ryan Curtin
+ *
+ * Auxiliary function to unmap neighbor search results.
+ */
+#include "unmap.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Dual-tree variant: both the query and the reference points were reordered.
+ * Column i of the inputs is written to column queryMap[i] of the outputs, and
+ * every neighbor index x is translated to referenceMap[x].  When squareRoot is
+ * true, the copied distances are square-rooted on the way.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           const std::vector<size_t>& queryMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot)
+{
+  // The outputs take exactly the shapes of the inputs.
+  neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+  distancesOut.set_size(distances.n_rows, distances.n_cols);
+
+  const size_t numColumns = distances.n_cols;
+  const size_t numRows = distances.n_rows;
+
+  for (size_t col = 0; col < numColumns; ++col)
+  {
+    const size_t originalCol = queryMap[col];
+
+    // sqrt(...) produces a different Armadillo expression type than a plain
+    // column view, so the two cases are separate statements, not a ternary.
+    if (!squareRoot)
+      distancesOut.col(originalCol) = distances.col(col);
+    else
+      distancesOut.col(originalCol) = sqrt(distances.col(col));
+
+    // Translate each neighbor index back into the original reference order.
+    for (size_t row = 0; row < numRows; ++row)
+      neighborsOut(row, originalCol) = referenceMap[neighbors(row, col)];
+  }
+}
+
+/**
+ * Single-tree variant: only the reference points were reordered, so columns
+ * stay where they are.  Every neighbor index x is translated to
+ * referenceMap[x], and the distances are copied (square-rooted when
+ * squareRoot is true).
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot)
+{
+  // Distances need no reordering; copy them, optionally taking the root.
+  if (!squareRoot)
+    distancesOut = distances;
+  else
+    distancesOut = sqrt(distances);
+
+  // Neighbor indices keep their positions but are translated element-wise.
+  neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+  for (size_t idx = 0; idx < neighbors.n_elem; ++idx)
+    neighborsOut[idx] = referenceMap[neighbors[idx]];
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
Added: mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/unmap.hpp 2013-02-10 01:35:52 UTC (rev 14260)
@@ -0,0 +1,61 @@
+/**
+ * @file unmap.hpp
+ * @author Ryan Curtin
+ *
+ * Convenience methods to unmap results.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Assuming that the datasets have been mapped using the referenceMap and the
+ * queryMap (such as during kd-tree construction), unmap the columns of the
+ * distances and neighbors matrices into distancesOut and neighborsOut, and also
+ * unmap the entries in each row of neighbors.  This is useful for the dual-tree
+ * case, or any time the query points have been reordered as well as the
+ * reference points.
+ *
+ * For every column i and row j of the inputs, the result satisfies
+ * neighborsOut(j, queryMap[i]) == referenceMap[neighbors(j, i)], and column
+ * queryMap[i] of distancesOut holds column i of distances (square-rooted if
+ * squareRoot is true).  The outputs are resized to the shapes of the inputs;
+ * neighbors and distances are expected to have the same dimensions.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap Mapping of reference set to old points: referenceMap[i]
+ *     is the original index of the reference point now stored at index i.
+ * @param queryMap Mapping of query set to old points: queryMap[i] is the
+ *     original index of the query point now stored at index i.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store unmapped distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           const std::vector<size_t>& queryMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot = false);
+
+/**
+ * Assuming that only the reference dataset has been mapped using referenceMap
+ * (such as during kd-tree construction) and the query points are still in
+ * their original order, unmap the entries of neighbors into neighborsOut.
+ * Columns are not moved: each entry x of neighbors becomes referenceMap[x] at
+ * the same position in neighborsOut, and distancesOut is a copy of distances
+ * (square-rooted if squareRoot is true).  This is useful for the single-tree
+ * case.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap Mapping of reference set to old points: referenceMap[i]
+ *     is the original index of the reference point now stored at index i.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store unmapped distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+           const arma::mat& distances,
+           const std::vector<size_t>& referenceMap,
+           arma::Mat<size_t>& neighborsOut,
+           arma::mat& distancesOut,
+           const bool squareRoot = false);
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list