[mlpack-svn] r16224 - in mlpack/trunk/src/mlpack/methods: emst fastmks range_search rann
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 6 15:15:09 EST 2014
Author: rcurtin
Date: Thu Feb 6 15:15:09 2014
New Revision: 16224
Log:
Make the rules classes for various dual-tree algorithms support (but not yet use)
TraversalInfo classes.
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp
mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp Thu Feb 6 15:15:09 2014
@@ -9,6 +9,8 @@
#include <mlpack/core.hpp>
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace emst {
@@ -103,6 +105,11 @@
TreeType& referenceNode,
const double oldScore) const;
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The data points.
const arma::mat& dataSet;
@@ -129,6 +136,8 @@
*/
inline double CalculateBound(TreeType& queryNode) const;
+ TraversalInfoType traversalInfo;
+
}; // class DTBRules
} // emst namespace
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp Thu Feb 6 15:15:09 2014
@@ -10,6 +10,8 @@
#include <mlpack/core.hpp>
#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace fastmks {
@@ -89,6 +91,11 @@
//! Modify the number of times Score() was called.
size_t& Scores() { return scores; }
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference dataset.
const arma::mat& referenceSet;
@@ -128,6 +135,8 @@
size_t baseCases;
//! For benchmarking.
size_t scores;
+
+ TraversalInfoType traversalInfo;
};
}; // namespace fastmks
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp Thu Feb 6 15:15:09 2014
@@ -7,6 +7,8 @@
#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace range {
@@ -91,6 +93,11 @@
TreeType& referenceNode,
const double oldScore) const;
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -120,6 +127,8 @@
//! add that to the results twice.
void AddResult(const size_t queryIndex,
TreeType& referenceNode);
+
+ TraversalInfoType traversalInfo;
};
}; // namespace range
Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp Thu Feb 6 15:15:09 2014
@@ -9,6 +9,8 @@
#ifndef __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
#define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
+#include "../neighbor_search/ns_traversal_info.hpp"
+
namespace mlpack {
namespace neighbor {
@@ -206,6 +208,11 @@
return arma::sum(numSamplesMade);
}
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -243,6 +250,8 @@
// TO REMOVE: just for testing
size_t numDistComputations;
+ TraversalInfoType traversalInfo;
+
/**
* Insert a point into the neighbors and distances matrices; this is a helper
* function.
Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp Thu Feb 6 15:15:09 2014
@@ -277,7 +277,9 @@
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+ distance);
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
More information about the mlpack-svn
mailing list