[mlpack-svn] r16320 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 19 19:34:21 EST 2014
Author: rcurtin
Date: Wed Feb 19 19:34:21 2014
New Revision: 16320
Log:
Refactor NeighborSearch so it works with arbitrary TreeType::Mat types. That
abstraction should probably be cleaned up a bit.
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp Wed Feb 19 19:34:21 2014
@@ -199,14 +199,14 @@
private:
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
- arma::mat referenceCopy;
+ typename TreeType::Mat referenceCopy;
//! Copy of query dataset (if we need it, because tree building modifies it).
- arma::mat queryCopy;
+ typename TreeType::Mat queryCopy;
//! Reference dataset.
- const arma::mat& referenceSet;
+ const typename TreeType::Mat& referenceSet;
//! Query dataset (may not be given).
- const arma::mat& querySet;
+ const typename TreeType::Mat& querySet;
//! Pointer to the root of the reference tree.
TreeType* referenceTree;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp Wed Feb 19 19:34:21 2014
@@ -17,8 +17,8 @@
class NeighborSearchRules
{
public:
- NeighborSearchRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ NeighborSearchRules(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
MetricType& metric);
@@ -95,10 +95,10 @@
private:
//! The reference set.
- const arma::mat& referenceSet;
+ const typename TreeType::Mat& referenceSet;
//! The query set.
- const arma::mat& querySet;
+ const typename TreeType::Mat& querySet;
//! The matrix the resultant neighbor indices should be stored in.
arma::Mat<size_t>& neighbors;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp Wed Feb 19 19:34:21 2014
@@ -15,8 +15,8 @@
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
- const arma::mat& referenceSet,
- const arma::mat& querySet,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
MetricType& metric) :
@@ -51,12 +51,12 @@
if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
return lastBaseCase;
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
+ double distance = metric.Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceIndex));
++baseCases;
// If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
+ // SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
const size_t insertPosition = SortPolicy::SortDistance(queryDist,
@@ -105,8 +105,8 @@
}
else
{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode);
+ distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
+ &referenceNode);
}
// Compare against the best k'th distance for this query point so far.
More information about the mlpack-svn mailing list