[mlpack-svn] r15549 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 25 18:06:30 EDT 2013
Author: rcurtin
Date: Thu Jul 25 18:06:29 2013
New Revision: 15549
Log:
Clean up trailing whitespace, reformat a few lines.
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp Thu Jul 25 18:06:29 2013
@@ -4,8 +4,6 @@
*
* Tree traverser rules for the DualTreeBoruvka algorithm.
*/
-
-
#ifndef __MLPACK_METHODS_EMST_DTB_RULES_HPP
#define __MLPACK_METHODS_EMST_DTB_RULES_HPP
@@ -18,19 +16,18 @@
class DTBRules
{
public:
-
DTBRules(const arma::mat& dataSet,
UnionFind& connections,
arma::vec& neighborsDistances,
arma::Col<size_t>& neighborsInComponent,
arma::Col<size_t>& neighborsOutComponent,
MetricType& metric);
-
+
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
+
// Update bounds. Needs a better name.
void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
+
/**
* Get the score for recursion order. A low score indicates priority for
* recursion, while DBL_MAX indicates that the node should not be recursed
@@ -40,7 +37,7 @@
* @param referenceNode Candidate node to be recursed into.
*/
double Score(const size_t queryIndex, TreeType& referenceNode);
-
+
/**
* Get the score for recursion order, passing the base case result (in the
* situation where it may be needed to calculate the recursion order). A low
@@ -54,7 +51,7 @@
double Score(const size_t queryIndex,
TreeType& referenceNode,
const double baseCaseResult);
-
+
/**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
@@ -69,7 +66,7 @@
double Rescore(const size_t queryIndex,
TreeType& referenceNode,
const double oldScore);
-
+
/**
* Get the score for recursion order. A low score indicates priority for
* recursionm while DBL_MAX indicates that the node should not be recursed
@@ -79,7 +76,7 @@
* @param referenceNode Candidate reference node to recurse into.
*/
double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
+
/**
* Get the score for recursion order, passing the base case result (in the
* situation where it may be needed to calculate the recursion order). A low
@@ -93,7 +90,7 @@
double Score(TreeType& queryNode,
TreeType& referenceNode,
const double baseCaseResult) const;
-
+
/**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
@@ -110,36 +107,25 @@
const double oldScore) const;
private:
-
- // This class needs to know what points are connected to one another
-
-
- // Things I need
- // UnionFind storing the tree structure at this iteration
- // neighborDistances
- // neighborInComponent
- // neighborOutComponent
-
//! The data points.
const arma::mat& dataSet;
-
+
//! Stores the tree structure so far
UnionFind& connections;
-
+
//! The distance to the candidate nearest neighbor for each component.
arma::vec& neighborsDistances;
-
+
//! The index of the point in the component that is an endpoint of the
//! candidate edge.
arma::Col<size_t>& neighborsInComponent;
-
+
//! The index of the point outside of the component that is an endpoint
//! of the candidate edge.
arma::Col<size_t>& neighborsOutComponent;
-
+
//! The metric
MetricType& metric;
-
}; // class DTBRules
} // emst namespace
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp Thu Jul 25 18:06:29 2013
@@ -4,8 +4,6 @@
*
* Tree traverser rules for the DualTreeBoruvka algorithm.
*/
-
-
#ifndef __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
#define __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
@@ -28,67 +26,63 @@
neighborsOutComponent(neighborsOutComponent),
metric(metric)
{
- // nothing else to do
-} // constructor
+ // Nothing else to do.
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
const size_t referenceIndex)
{
-
// Check if the points are in the same component at this iteration.
- // If not, return the distance between them.
- // Also responsible for storing this as the current neighbor
-
+ // If not, return the distance between them. Also, store a better result as
+ // the current neighbor, if necessary.
double newUpperBound = -1.0;
-
+
// Find the index of the component the query is in.
size_t queryComponentIndex = connections.Find(queryIndex);
-
+
size_t referenceComponentIndex = connections.Find(referenceIndex);
-
+
if (queryComponentIndex != referenceComponentIndex)
{
double distance = metric.Evaluate(dataSet.col(queryIndex),
dataSet.col(referenceIndex));
-
+
if (distance < neighborsDistances[queryComponentIndex])
{
Log::Assert(queryIndex != referenceIndex);
-
+
neighborsDistances[queryComponentIndex] = distance;
neighborsInComponent[queryComponentIndex] = queryIndex;
neighborsOutComponent[queryComponentIndex] = referenceIndex;
-
- } // if distance
- } // if indices not equal
+
+ }
+ }
if (newUpperBound < neighborsDistances[queryComponentIndex])
newUpperBound = neighborsDistances[queryComponentIndex];
-
+
Log::Assert(newUpperBound >= 0.0);
-
+
return newUpperBound;
-
-} // BaseCase()
+}
template<typename MetricType, typename TreeType>
void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
TreeType& queryNode,
TreeType& /* referenceNode */)
{
-
// Find the worst distance that the children found (including any points), and
// update the bound accordingly.
double newUpperBound = 0.0;
-
+
// First look through children nodes.
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
}
-
+
// Now look through children points.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
@@ -96,83 +90,69 @@
if (newUpperBound < neighborsDistances[pointComponent])
newUpperBound = neighborsDistances[pointComponent];
}
-
- // Update the bound in the query's stat
+
+ // Update the bound in the query's statistic.
queryNode.Stat().MaxNeighborDistance() = newUpperBound;
-
-} // UpdateAfterRecursion
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode)
{
-
size_t queryComponentIndex = connections.Find(queryIndex);
-
+
// If the query belongs to the same component as all of the references,
- // then prune.
- // Casting this to stop a warning about comparing unsigned to signed
- // values.
- if (queryComponentIndex == (size_t)referenceNode.Stat().ComponentMembership())
+ // then prune. The cast is to stop a warning about comparing unsigned to
+ // signed values.
+ if (queryComponentIndex ==
+ (size_t) referenceNode.Stat().ComponentMembership())
return DBL_MAX;
-
+
const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
+
const double distance = referenceNode.MinDistance(queryPoint);
-
+
// If all the points in the reference node are farther than the candidate
// nearest neighbor for the query's component, we prune.
return neighborsDistances[queryComponentIndex] < distance
? DBL_MAX : distance;
-
-} // Score()
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode,
- const double baseCaseResult)
+ const double baseCaseResult)
{
// I don't really understand the last argument here
// It just gets passed in the distance call, otherwise this function
- // is the same as the one above
-
+ // is the same as the one above.
size_t queryComponentIndex = connections.Find(queryIndex);
-
- // if the query belongs to the same component as all of the references,
- // then prune
+
+ // If the query belongs to the same component as all of the references,
+ // then prune.
if (queryComponentIndex == referenceNode.Stat().ComponentMembership())
return DBL_MAX;
-
+
const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
const double distance = referenceNode.MinDistance(queryPoint,
baseCaseResult);
-
+
// If all the points in the reference node are farther than the candidate
// nearest neighbor for the query's component, we prune.
- return neighborsDistances[queryComponentIndex] < distance
- ? DBL_MAX : distance;
-
-} // Score()
+ return (neighborsDistances[queryComponentIndex] < distance) ? DBL_MAX :
+ distance;
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
TreeType& referenceNode,
- const double oldScore)
+ const double oldScore)
{
// We don't need to check component membership again, because it can't
// change inside a single iteration.
-
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (oldScore > neighborsDistances[connections.Find(queryIndex)])
- return DBL_MAX;
- else
- return oldScore;
-
-} // Rescore
+ return (oldScore > neighborsDistances[connections.Find(queryIndex)])
+ ? DBL_MAX : oldScore;
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
@@ -180,66 +160,53 @@
{
// If all the queries belong to the same component as all the references
// then we prune.
- if ((queryNode.Stat().ComponentMembership() >= 0)
- && (queryNode.Stat().ComponentMembership() ==
- referenceNode.Stat().ComponentMembership()))
+ if ((queryNode.Stat().ComponentMembership() >= 0) &&
+ (queryNode.Stat().ComponentMembership() ==
+ referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
- double distance = queryNode.MinDistance(&referenceNode);
-
+ const double distance = queryNode.MinDistance(&referenceNode);
+
// If all the points in the reference node are farther than the candidate
// nearest neighbor for all queries in the node, we prune.
- return queryNode.Stat().MaxNeighborDistance() < distance
- ? DBL_MAX : distance;
-
-} // Score()
+ return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
+ distance;
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
TreeType& referenceNode,
const double baseCaseResult) const
{
-
// If all the queries belong to the same component as all the references
// then we prune.
- if ((queryNode.Stat().ComponentMembership() >= 0)
- && (queryNode.Stat().ComponentMembership() ==
- referenceNode.Stat().ComponentMembership()))
+ if ((queryNode.Stat().ComponentMembership() >= 0) &&
+ (queryNode.Stat().ComponentMembership() ==
+ referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
-
+
const double distance = queryNode.MinDistance(referenceNode,
baseCaseResult);
-
+
// If all the points in the reference node are farther than the candidate
// nearest neighbor for all queries in the node, we prune.
- return queryNode.Stat().MaxNeighborDistance() < distance
- ? DBL_MAX : distance;
-
-} // Score()
+ return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
+ distance;
+}
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
TreeType& /* referenceNode */,
const double oldScore) const
{
-
- // Same as above, but for nodes,
-
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (oldScore > queryNode.Stat().MaxNeighborDistance())
- return DBL_MAX;
- else
- return oldScore;
-
-} // Rescore
+ return (oldScore > queryNode.Stat().MaxNeighborDistance()) ? DBL_MAX :
+ oldScore;
+}
} // namespace emst
} // namespace mlpack
-#endif
+#endif
More information about the mlpack-svn mailing list.