[mlpack-svn] r12601 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 2 12:31:09 EDT 2012


Author: rcurtin
Date: 2012-05-02 12:31:08 -0400 (Wed, 02 May 2012)
New Revision: 12601

Added:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Add the NeighborSearchRules class, which defines how the
SingleTreeDepthFirstTraverser can perform a NeighborSearch.  Adapt the
NeighborSearch class to use this.  It is not as fast as it could be.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	2012-05-02 16:30:21 UTC (rev 12600)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	2012-05-02 16:31:08 UTC (rev 12601)
@@ -5,6 +5,8 @@
 set(SOURCES
   neighbor_search.hpp
   neighbor_search_impl.hpp
+  neighbor_search_rules.hpp
+  neighbor_search_rules_impl.hpp
   sort_policies/nearest_neighbor_sort.hpp
   sort_policies/nearest_neighbor_sort.cpp
   sort_policies/nearest_neighbor_sort_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-05-02 16:30:21 UTC (rev 12600)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-05-02 16:31:08 UTC (rev 12601)
@@ -10,6 +10,9 @@
 
 #include <mlpack/core.hpp>
 
+#include <mlpack/core/tree/traversers/single_tree_depth_first_traverser.hpp>
+#include "neighbor_search_rules.hpp"
+
 using namespace mlpack::neighbor;
 
 // Construct the object.
@@ -182,28 +185,20 @@
   {
     if (singleMode)
     {
-      // Do one tenth of the query set at a time.
-      size_t chunk = querySet.n_cols / 10;
+      // Create the helper object for the tree traversal.
+      NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
+          querySet, *neighborPtr, *distancePtr, metric);
 
-      for (size_t i = 0; i < 10; i++)
-      {
-        for (size_t j = 0; j < chunk; j++)
-        {
-          double worstDistance = SortPolicy::WorstDistance();
-          ComputeSingleNeighborsRecursion(i * chunk + j,
-              querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
-              *neighborPtr, *distancePtr);
-        }
-      }
+      // Create the traverser.
+      tree::SingleTreeDepthFirstTraverser<TreeType,
+          NeighborSearchRules<SortPolicy, MetricType, TreeType> >
+          traverser(rules);
 
-      // The last tenth is differently sized...
-      for (size_t i = 0; i < querySet.n_cols % 10; i++)
-      {
-        size_t ind = (querySet.n_cols / 10) * 10 + i;
-        double worstDistance = SortPolicy::WorstDistance();
-        ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
-            referenceTree, worstDistance, *neighborPtr, *distancePtr);
-      }
+      // Now have it traverse for each point.
+      for (size_t i = 0; i < querySet.n_cols; ++i)
+        traverser.Traverse(i, *referenceTree);
+
+      Log::Info << "Pruned " << traverser.NumPrunes() << " nodes." << std::endl;
     }
     else // Dual-tree recursion.
     {

Added: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-05-02 16:31:08 UTC (rev 12601)
@@ -0,0 +1,66 @@
+/**
+ * @file neighbor_search_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the pruning rules and base case rules necessary to perform a
+ * tree-based search (with an arbitrary tree) for the NeighborSearch class.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+//! Neighbor-search rules for a tree traverser; holds references, not copies.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+class NeighborSearchRules
+{
+ public:
+  NeighborSearchRules(const arma::mat& referenceSet,
+                      const arma::mat& querySet,
+                      arma::Mat<size_t>& neighbors,
+                      arma::mat& distances,
+                      MetricType& metric);
+  //! Evaluate one query/reference pair and record it if it is good enough.
+  void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  //! For single-tree traversal: return true when nothing in referenceNode can
+  //! improve the current results for queryIndex, so the node may be skipped.
+  bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
+
+ private:
+  //! The reference set.
+  const arma::mat& referenceSet;
+
+  //! The query set.
+  const arma::mat& querySet;
+
+  //! The matrix the resultant neighbor indices should be stored in.
+  arma::Mat<size_t>& neighbors;
+
+  //! The matrix the resultant neighbor distances should be stored in.
+  arma::mat& distances;
+
+  //! The instantiated metric.
+  MetricType& metric;
+
+  /**
+   * Insert a point into the neighbors and distances matrices: entries from
+   * row pos onward in column queryIndex move down one row (the last drops).
+   *
+   * @param queryIndex Index of point whose neighbors we are inserting into.
+   * @param pos Position in list to insert into.
+   * @param neighbor Index of reference point which is being inserted.
+   * @param distance Distance from query point to reference point.
+   */
+  void InsertNeighbor(const size_t queryIndex,
+                      const size_t pos,
+                      const size_t neighbor,
+                      const double distance);
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "neighbor_search_rules_impl.hpp"
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP

Added: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-05-02 16:31:08 UTC (rev 12601)
@@ -0,0 +1,107 @@
+/**
+ * @file neighbor_search_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of NeighborSearchRules.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "neighbor_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+//! Bind the rule object to the caller's data, result matrices and metric.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances,
+    MetricType& metric) :
+    referenceSet(referenceSet), querySet(querySet),
+    neighbors(neighbors), distances(distances), metric(metric)
+{
+  // Every member is a reference, so there is nothing to copy or compute.
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
+{
+  // A self-search (query and reference are the same object) must never report
+  // a point as its own neighbor.
+  const bool selfSearch = (&querySet == &referenceSet);
+  if (selfSearch && (queryIndex == referenceIndex))
+    return;
+
+  double candidate = metric.Evaluate(querySet.col(queryIndex),
+                                     referenceSet.col(referenceIndex));
+
+  // The sort policy reports where the candidate belongs among this query's
+  // current results, or (size_t() - 1) if it does not belong at all.
+  arma::vec currentResults = distances.unsafe_col(queryIndex);
+  const size_t slot = SortPolicy::SortDistance(currentResults, candidate);
+  if (slot == (size_t() - 1))
+    return;
+  InsertNeighbor(queryIndex, slot, referenceIndex, candidate);
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::CanPrune(
+    const size_t queryIndex,
+    TreeType& referenceNode)
+{
+  // Best-case distance between the query point and anything in the node.
+  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+  const double distance =
+      SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode);
+
+  // The last row is the k-th (worst kept) candidate for this query, assuming
+  // SortDistance() keeps each column ordered best-first.  It is NOT the best
+  // distance found so far; a new point would have to beat it to be inserted.
+  const double kthDistance = distances(distances.n_rows - 1, queryIndex);
+
+  // Prune unless the node's best case beats the current k-th candidate.
+  return !SortPolicy::IsBetter(distance, kthDistance);
+}
+
+/**
+ * Helper function to insert a point into the neighbors and distances matrices.
+ * Entries from row pos onward move down one row; the last entry is discarded.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+    const size_t queryIndex,
+    const size_t pos,
+    const size_t neighbor,
+    const double distance)
+{
+  // Shift only if pos is not the last row; size_t avoids narrowing to int.
+  if (pos < (distances.n_rows - 1))
+  {
+    const size_t len = (distances.n_rows - 1) - pos;
+    memmove(distances.colptr(queryIndex) + (pos + 1),
+        distances.colptr(queryIndex) + pos,
+        sizeof(double) * len);
+    memmove(neighbors.colptr(queryIndex) + (pos + 1),
+        neighbors.colptr(queryIndex) + pos,
+        sizeof(size_t) * len);
+  }
+  // Now put the new information in the right index.
+  distances(pos, queryIndex) = distance;
+  neighbors(pos, queryIndex) = neighbor;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP




More information about the mlpack-svn mailing list