[mlpack-svn] r16224 - in mlpack/trunk/src/mlpack/methods: emst fastmks range_search rann

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 6 15:15:09 EST 2014


Author: rcurtin
Date: Thu Feb  6 15:15:09 2014
New Revision: 16224

Log:
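Make the rules classes for several dual-tree algorithms (EMST, FastMKS, range
search, and RANN) support, but not yet use, TraversalInfo classes: each now has
a TraversalInfoType typedef, const and non-const TraversalInfo() accessors, and
a traversalInfo member. Also pass the query's current neighbor indices to
SortPolicy::SortDistance() in ra_search_rules_impl.hpp.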


Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	Thu Feb  6 15:15:09 2014
@@ -9,6 +9,8 @@
 
 #include <mlpack/core.hpp>
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace emst {
 
@@ -103,6 +105,11 @@
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The data points.
   const arma::mat& dataSet;
@@ -129,6 +136,8 @@
    */
   inline double CalculateBound(TreeType& queryNode) const;
 
+  TraversalInfoType traversalInfo;
+
 }; // class DTBRules
 
 } // emst namespace

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	Thu Feb  6 15:15:09 2014
@@ -10,6 +10,8 @@
 #include <mlpack/core.hpp>
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace fastmks {
 
@@ -89,6 +91,11 @@
   //! Modify the number of times Score() was called.
   size_t& Scores() { return scores; }
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference dataset.
   const arma::mat& referenceSet;
@@ -128,6 +135,8 @@
   size_t baseCases;
   //! For benchmarking.
   size_t scores;
+
+  TraversalInfoType traversalInfo;
 };
 
 }; // namespace fastmks

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	Thu Feb  6 15:15:09 2014
@@ -7,6 +7,8 @@
 #ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
 #define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace range {
 
@@ -91,6 +93,11 @@
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -120,6 +127,8 @@
   //! add that to the results twice.
   void AddResult(const size_t queryIndex,
                  TreeType& referenceNode);
+
+  TraversalInfoType traversalInfo;
 };
 
 }; // namespace range

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp	Thu Feb  6 15:15:09 2014
@@ -9,6 +9,8 @@
 #ifndef __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace neighbor {
 
@@ -206,6 +208,11 @@
       return arma::sum(numSamplesMade);
   }
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -243,6 +250,8 @@
   // TO REMOVE: just for testing
   size_t numDistComputations;
 
+  TraversalInfoType traversalInfo;
+
   /**
    * Insert a point into the neighbors and distances matrices; this is a helper
    * function.

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp	Thu Feb  6 15:15:09 2014
@@ -277,7 +277,9 @@
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
-  size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+  size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+      distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))



More information about the mlpack-svn mailing list