[mlpack-git] master, mlpack-1.0.x: Make rules classes for various dual-tree algorithms support (but not use) the idea of TraversalInfo classes. (99140c6)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:21 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 99140c68c2ca29dc43e6a95b5c38e9228125abc6
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 6 20:15:09 2014 +0000
Make rules classes for various dual-tree algorithms support (but not use) the
idea of TraversalInfo classes.
>---------------------------------------------------------------
99140c68c2ca29dc43e6a95b5c38e9228125abc6
src/mlpack/methods/emst/dtb_rules.hpp | 9 +++++++++
src/mlpack/methods/fastmks/fastmks_rules.hpp | 9 +++++++++
src/mlpack/methods/range_search/range_search_rules.hpp | 9 +++++++++
src/mlpack/methods/rann/ra_search_rules.hpp | 9 +++++++++
src/mlpack/methods/rann/ra_search_rules_impl.hpp | 4 +++-
5 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/emst/dtb_rules.hpp b/src/mlpack/methods/emst/dtb_rules.hpp
index 2566a54..5024ff6 100644
--- a/src/mlpack/methods/emst/dtb_rules.hpp
+++ b/src/mlpack/methods/emst/dtb_rules.hpp
@@ -9,6 +9,8 @@
#include <mlpack/core.hpp>
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace emst {
@@ -103,6 +105,11 @@ class DTBRules
TreeType& referenceNode,
const double oldScore) const;
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The data points.
const arma::mat& dataSet;
@@ -129,6 +136,8 @@ class DTBRules
*/
inline double CalculateBound(TreeType& queryNode) const;
+ TraversalInfoType traversalInfo;
+
}; // class DTBRules
} // emst namespace
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 7413c08..659e448 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -10,6 +10,8 @@
#include <mlpack/core.hpp>
#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace fastmks {
@@ -89,6 +91,11 @@ class FastMKSRules
//! Modify the number of times Score() was called.
size_t& Scores() { return scores; }
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference dataset.
const arma::mat& referenceSet;
@@ -128,6 +135,8 @@ class FastMKSRules
size_t baseCases;
//! For benchmarking.
size_t scores;
+
+ TraversalInfoType traversalInfo;
};
}; // namespace fastmks
diff --git a/src/mlpack/methods/range_search/range_search_rules.hpp b/src/mlpack/methods/range_search/range_search_rules.hpp
index 5518296..5d925b1 100644
--- a/src/mlpack/methods/range_search/range_search_rules.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules.hpp
@@ -7,6 +7,8 @@
#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace range {
@@ -91,6 +93,11 @@ class RangeSearchRules
TreeType& referenceNode,
const double oldScore) const;
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -120,6 +127,8 @@ class RangeSearchRules
//! add that to the results twice.
void AddResult(const size_t queryIndex,
TreeType& referenceNode);
+
+ TraversalInfoType traversalInfo;
};
}; // namespace range
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 2e53ceb..453a9bf 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -9,6 +9,8 @@
#ifndef __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
#define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace neighbor {
@@ -206,6 +208,11 @@ class RASearchRules
return arma::sum(numSamplesMade);
}
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -243,6 +250,8 @@ class RASearchRules
// TO REMOVE: just for testing
size_t numDistComputations;
+ TraversalInfoType traversalInfo;
+
/**
* Insert a point into the neighbors and distances matrices; this is a helper
* function.
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index b8da274..2fbfe58 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -277,7 +277,9 @@ double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+ distance);
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
More information about the mlpack-git mailing list