[mlpack-svn] r10719 - mlpack/trunk/src/mlpack/methods/range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 11:12:01 EST 2011
Author: rcurtin
Date: 2011-12-12 11:12:00 -0500 (Mon, 12 Dec 2011)
New Revision: 10719
Modified:
mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
Log:
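Stop using unsafe_col(); pass the query index explicitly to SingleTreeRecursion() and use TreeType::Mat in place of arma::mat.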
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2011-12-12 10:47:10 UTC (rev 10718)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2011-12-12 16:12:00 UTC (rev 10719)
@@ -45,8 +45,8 @@
* @param leafSize The leaf size to be used during tree construction.
* @param metric Instantiated distance metric.
*/
- RangeSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ RangeSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -71,7 +71,7 @@
* @param leafSize The leaf size to be used during tree construction.
* @param metric Instantiated distance metric.
*/
- RangeSearch(const arma::mat& referenceSet,
+ RangeSearch(const typename TreeType::Mat& referenceSet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -108,8 +108,8 @@
*/
RangeSearch(const TreeType* referenceTree,
const TreeType* queryTree,
- const arma::mat& referenceSet,
- const arma::mat& querySet,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -141,7 +141,7 @@
* @param metric Instantiated distance metric.
*/
RangeSearch(const TreeType* referenceTree,
- const arma::mat& referenceSet,
+ const typename TreeType::Mat& referenceSet,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -220,25 +220,28 @@
*
* @param referenceNode Reference node.
* @param queryPoint Point to query for.
+ * @param queryIndex Index of the query point in the query set.
* @param range Range of distances to search for.
* @param neighbors Object holding list of neighbors.
* @param distances Object holding list of distances.
*/
+ template<typename VecType>
void SingleTreeRecursion(const TreeType* referenceNode,
- const arma::vec& queryPoint,
+ const VecType& queryPoint,
+ const size_t queryIndex,
const math::Range& range,
std::vector<size_t>& neighbors,
std::vector<double>& distances);
//! Copy of reference matrix; used when a tree is built internally.
- arma::mat referenceCopy;
+ typename TreeType::Mat referenceCopy;
//! Copy of query matrix; used when a tree is built internally.
- arma::mat queryCopy;
+ typename TreeType::Mat queryCopy;
//! Reference set (data should be accessed using this).
- const arma::mat& referenceSet;
+ const typename TreeType::Mat& referenceSet;
//! Query set (data should be accessed using this).
- const arma::mat& querySet;
+ const typename TreeType::Mat& querySet;
//! Reference tree.
TreeType* referenceTree;
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp 2011-12-12 10:47:10 UTC (rev 10718)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp 2011-12-12 16:12:00 UTC (rev 10719)
@@ -14,12 +14,13 @@
namespace range {
template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
referenceCopy(referenceSet),
queryCopy(querySet),
referenceSet(referenceCopy),
@@ -45,11 +46,12 @@
}
template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const typename TreeType::Mat& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
referenceCopy(referenceSet),
referenceSet(referenceCopy),
querySet(referenceCopy),
@@ -72,12 +74,13 @@
}
template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
- const TreeType* queryTree,
- const arma::mat& referenceSet,
- const arma::mat& querySet,
- const bool singleMode,
- const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const TreeType* referenceTree,
+ const TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode,
+ const MetricType metric) :
referenceSet(referenceSet),
querySet(querySet),
referenceTree(referenceTree),
@@ -93,10 +96,11 @@
}
template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
- const arma::mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode,
+ const MetricType metric) :
referenceSet(referenceSet),
querySet(referenceSet),
referenceTree(referenceTree),
@@ -164,7 +168,7 @@
// Loop over each of the query points.
for (size_t i = 0; i < querySet.n_cols; i++)
{
- SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), range,
+ SingleTreeRecursion(referenceTree, querySet.col(i), i, range,
(*neighborPtr)[i], (*distancePtr)[i]);
}
}
@@ -296,12 +300,11 @@
for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
queryIndex++)
{
- // Get the query point from the matrix.
- arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ double minDistance =
+ referenceNode->Bound().MinDistance(querySet.col(queryIndex));
+ double maxDistance =
+ referenceNode->Bound().MaxDistance(querySet.col(queryIndex));
- double minDistance = referenceNode->Bound().MinDistance(queryPoint);
- double maxDistance = referenceNode->Bound().MaxDistance(queryPoint);
-
// Now see if any points could fall into the range.
if (range.Contains(math::Range(minDistance, maxDistance)))
{
@@ -312,10 +315,9 @@
// We can't add points that are ourselves.
if (referenceNode != queryNode || referenceIndex != queryIndex)
{
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceIndex));
- double distance = metric.Evaluate(queryPoint, referencePoint);
-
// If this lies in the range, add it.
if (range.Contains(distance))
{
@@ -384,9 +386,11 @@
}
template<typename MetricType, typename TreeType>
+template<typename VecType>
void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
const TreeType* referenceNode,
- const arma::vec& queryPoint,
+ const VecType& queryPoint,
+ const size_t queryIndex,
const math::Range& range,
std::vector<size_t>& neighbors,
std::vector<double>& distances)
@@ -399,12 +403,11 @@
referenceNode->End(); referenceIndex++)
{
// Don't add this point if it is the same as the query point.
- if (!(referenceSet.colptr(referenceIndex) == queryPoint.memptr()))
+ if (!queryTree && !(referenceIndex == queryIndex))
{
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(queryPoint,
+ referenceSet.col(referenceIndex));
- double distance = metric.Evaluate(queryPoint, referencePoint);
-
// See if the point is in the range we are looking for.
if (range.Contains(distance))
{
@@ -425,8 +428,8 @@
if (range.Contains(distanceLeft))
{
// The left may have points we want to recurse to.
- SingleTreeRecursion(referenceNode->Left(), queryPoint, range, neighbors,
- distances);
+ SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+ range, neighbors, distances);
}
else
{
@@ -436,8 +439,8 @@
if (range.Contains(distanceRight))
{
// The right may have points we want to recurse to.
- SingleTreeRecursion(referenceNode->Right(), queryPoint, range, neighbors,
- distances);
+ SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+ range, neighbors, distances);
}
else
{
More information about the mlpack-svn mailing list