[mlpack-svn] r16320 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 19 19:34:21 EST 2014


Author: rcurtin
Date: Wed Feb 19 19:34:21 2014
New Revision: 16320

Log:
Refactor NeighborSearch so it works with arbitrary TreeType::Mat types.  That
abstraction should probably be cleaned up a bit.


Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	Wed Feb 19 19:34:21 2014
@@ -199,14 +199,14 @@
  private:
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
-  arma::mat referenceCopy;
+  typename TreeType::Mat referenceCopy;
   //! Copy of query dataset (if we need it, because tree building modifies it).
-  arma::mat queryCopy;
+  typename TreeType::Mat queryCopy;
 
   //! Reference dataset.
-  const arma::mat& referenceSet;
+  const typename TreeType::Mat& referenceSet;
   //! Query dataset (may not be given).
-  const arma::mat& querySet;
+  const typename TreeType::Mat& querySet;
 
   //! Pointer to the root of the reference tree.
   TreeType* referenceTree;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	Wed Feb 19 19:34:21 2014
@@ -17,8 +17,8 @@
 class NeighborSearchRules
 {
  public:
-  NeighborSearchRules(const arma::mat& referenceSet,
-                      const arma::mat& querySet,
+  NeighborSearchRules(const typename TreeType::Mat& referenceSet,
+                      const typename TreeType::Mat& querySet,
                       arma::Mat<size_t>& neighbors,
                       arma::mat& distances,
                       MetricType& metric);
@@ -95,10 +95,10 @@
 
  private:
   //! The reference set.
-  const arma::mat& referenceSet;
+  const typename TreeType::Mat& referenceSet;
 
   //! The query set.
-  const arma::mat& querySet;
+  const typename TreeType::Mat& querySet;
 
   //! The matrix the resultant neighbor indices should be stored in.
   arma::Mat<size_t>& neighbors;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	Wed Feb 19 19:34:21 2014
@@ -15,8 +15,8 @@
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
-    const arma::mat& referenceSet,
-    const arma::mat& querySet,
+    const typename TreeType::Mat& referenceSet,
+    const typename TreeType::Mat& querySet,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances,
     MetricType& metric) :
@@ -51,12 +51,12 @@
   if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
     return lastBaseCase;
 
-  double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
-                                    referenceSet.unsafe_col(referenceIndex));
+  double distance = metric.Evaluate(querySet.col(queryIndex),
+                                    referenceSet.col(referenceIndex));
   ++baseCases;
 
   // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the position to insert it into.
+  // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
   arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
   const size_t insertPosition = SortPolicy::SortDistance(queryDist,
@@ -105,8 +105,8 @@
   }
   else
   {
-    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-    distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode);
+    distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
+        &referenceNode);
   }
 
   // Compare against the best k'th distance for this query point so far.



More information about the mlpack-svn mailing list