[mlpack-git] master, mlpack-1.0.x: Make rules classes for various dual-tree algorithms support (but not use) the idea of TraversalInfo classes. (99140c6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:21 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 99140c68c2ca29dc43e6a95b5c38e9228125abc6
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 6 20:15:09 2014 +0000

    Make rules classes for various dual-tree algorithms support (but not use) the
    idea of TraversalInfo classes.


>---------------------------------------------------------------

99140c68c2ca29dc43e6a95b5c38e9228125abc6
 src/mlpack/methods/emst/dtb_rules.hpp                  | 9 +++++++++
 src/mlpack/methods/fastmks/fastmks_rules.hpp           | 9 +++++++++
 src/mlpack/methods/range_search/range_search_rules.hpp | 9 +++++++++
 src/mlpack/methods/rann/ra_search_rules.hpp            | 9 +++++++++
 src/mlpack/methods/rann/ra_search_rules_impl.hpp       | 4 +++-
 5 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/emst/dtb_rules.hpp b/src/mlpack/methods/emst/dtb_rules.hpp
index 2566a54..5024ff6 100644
--- a/src/mlpack/methods/emst/dtb_rules.hpp
+++ b/src/mlpack/methods/emst/dtb_rules.hpp
@@ -9,6 +9,8 @@
 
 #include <mlpack/core.hpp>
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace emst {
 
@@ -103,6 +105,11 @@ class DTBRules
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The data points.
   const arma::mat& dataSet;
@@ -129,6 +136,8 @@ class DTBRules
    */
   inline double CalculateBound(TreeType& queryNode) const;
 
+  TraversalInfoType traversalInfo;
+
 }; // class DTBRules
 
 } // emst namespace
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 7413c08..659e448 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -10,6 +10,8 @@
 #include <mlpack/core.hpp>
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace fastmks {
 
@@ -89,6 +91,11 @@ class FastMKSRules
   //! Modify the number of times Score() was called.
   size_t& Scores() { return scores; }
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference dataset.
   const arma::mat& referenceSet;
@@ -128,6 +135,8 @@ class FastMKSRules
   size_t baseCases;
   //! For benchmarking.
   size_t scores;
+
+  TraversalInfoType traversalInfo;
 };
 
 }; // namespace fastmks
diff --git a/src/mlpack/methods/range_search/range_search_rules.hpp b/src/mlpack/methods/range_search/range_search_rules.hpp
index 5518296..5d925b1 100644
--- a/src/mlpack/methods/range_search/range_search_rules.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules.hpp
@@ -7,6 +7,8 @@
 #ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
 #define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace range {
 
@@ -91,6 +93,11 @@ class RangeSearchRules
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -120,6 +127,8 @@ class RangeSearchRules
   //! add that to the results twice.
   void AddResult(const size_t queryIndex,
                  TreeType& referenceNode);
+
+  TraversalInfoType traversalInfo;
 };
 
 }; // namespace range
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 2e53ceb..453a9bf 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -9,6 +9,8 @@
 #ifndef __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
+#include "../neighbor_search/ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace neighbor {
 
@@ -206,6 +208,11 @@ class RASearchRules
       return arma::sum(numSamplesMade);
   }
 
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -243,6 +250,8 @@ class RASearchRules
   // TO REMOVE: just for testing
   size_t numDistComputations;
 
+  TraversalInfoType traversalInfo;
+
   /**
    * Insert a point into the neighbors and distances matrices; this is a helper
    * function.
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index b8da274..2fbfe58 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -277,7 +277,9 @@ double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
-  size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+  size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+      distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))



More information about the mlpack-git mailing list