[mlpack-svn] r13680 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 11 09:41:45 EDT 2012
Author: march
Date: 2012-10-11 09:41:44 -0400 (Thu, 11 Oct 2012)
New Revision: 13680
Added:
mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
Modified:
mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
Log:
Updated EMST code to use tree traverser abstractions.
Modified: mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2012-10-11 13:41:44 UTC (rev 13680)
@@ -8,6 +8,8 @@
# dtb
dtb.hpp
dtb_impl.hpp
+ dtb_rules.hpp
+ dtb_rules_impl.hpp
edge_pair.hpp
)
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2012-10-11 13:41:44 UTC (rev 13680)
@@ -79,8 +79,7 @@
/**
* Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any
- * type of tree. At the moment this class does not support arbitrary distance
- * metrics, and uses the squared Euclidean distance.
+ * type of tree.
*
* For more information on the algorithm, see the following citation:
*
@@ -110,9 +109,14 @@
* More advanced usage of the class can use different types of trees, pass in an
* already-built tree, or compute the MST using the O(n^2) naive algorithm.
*
+ * @tparam MetricType The metric to use. IMPORTANT: this hasn't really been
+ * tested with anything other than the L2 metric, so user beware. Note that the
+ * tree type needs to compute bounds using the same metric as the type
+ * specified here.
* @tparam TreeType Type of tree to use. Should use DTBStat as a statistic.
*/
template<
+ typename MetricType = metric::SquaredEuclideanDistance,
typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
>
class DualTreeBoruvka
@@ -148,6 +152,9 @@
//! Total distance of the tree.
double totalDist;
+
+ //! The metric
+ MetricType metric;
// For sorting the edge list after the computation.
struct SortEdgesHelper
@@ -169,7 +176,8 @@
*/
DualTreeBoruvka(const typename TreeType::Mat& dataset,
const bool naive = false,
- const size_t leafSize = 1);
+ const size_t leafSize = 1,
+ const MetricType metric = MetricType());
/**
* Create the DualTreeBoruvka object with an already initialized tree. This
@@ -188,7 +196,8 @@
* @param tree Pre-built tree.
* @param dataset Dataset corresponding to the pre-built tree.
*/
- DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset);
+ DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+ const MetricType metric = MetricType());
/**
* Delete the tree, if it was created inside the object.
@@ -218,18 +227,6 @@
void AddAllEdges();
/**
- * Handles the base case computation. Also called by naive.
- */
- double BaseCase(const TreeType* queryNode, const TreeType* referenceNode);
-
- /**
- * Handles the recursive calls to find the nearest neighbors in an iteration
- */
- void DualTreeRecursion(TreeType *queryNode,
- TreeType *referenceNode,
- double incomingDistance);
-
- /**
* Unpermute the edge list and output it to results.
*/
void EmitResults(arma::mat& results);
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2012-10-11 13:41:44 UTC (rev 13680)
@@ -8,7 +8,7 @@
#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
-#include <mlpack/core.hpp>
+#include "dtb_rules.hpp"
namespace mlpack {
namespace emst {
@@ -57,17 +57,19 @@
* Takes in a reference to the data set. Copies the data, builds the tree,
* and initializes all of the member variables.
*/
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
const typename TreeType::Mat& dataset,
const bool naive,
- const size_t leafSize) :
+ const size_t leafSize,
+ const MetricType metric) :
dataCopy(dataset),
data(dataCopy), // The reference points to our copy of the data.
ownTree(true),
naive(naive),
connections(data.n_cols),
- totalDist(0.0)
+ totalDist(0.0),
+ metric(metric)
{
Timer::Start("emst/tree_building");
@@ -93,16 +95,18 @@
neighborsDistances.fill(DBL_MAX);
} // Constructor
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
TreeType* tree,
- const typename TreeType::Mat& dataset) :
+ const typename TreeType::Mat& dataset,
+ const MetricType metric) :
data(dataset),
tree(tree),
ownTree(true),
naive(false),
connections(data.n_cols),
- totalDist(0.0)
+ totalDist(0.0),
+ metric(metric)
{
edges.reserve(data.n_cols - 1); // fill with EdgePairs
@@ -112,8 +116,8 @@
neighborsDistances.fill(DBL_MAX);
}
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::~DualTreeBoruvka()
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
{
if (ownTree)
delete tree;
@@ -123,25 +127,24 @@
* Iteratively find the nearest neighbor of each component until the MST is
* complete.
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::ComputeMST(arma::mat& results)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
{
Timer::Start("emst/mst_computation");
totalDist = 0; // Reset distance.
+ typedef DTBRules<MetricType, TreeType> RuleType;
+ RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
+ neighborsOutComponent, metric);
+
while (edges.size() < (data.n_cols - 1))
{
- // Compute neighbors.
- if (naive)
- {
- BaseCase(tree, tree);
- }
- else
- {
- DualTreeRecursion(tree, tree, DBL_MAX);
- }
-
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*tree, *tree);
+
AddAllEdges();
Cleanup();
@@ -159,8 +162,8 @@
/**
* Adds a single edge to the edge list
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::AddEdge(const size_t e1,
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
const size_t e2,
const double distance)
{
@@ -176,8 +179,8 @@
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::AddAllEdges()
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
{
for (size_t i = 0; i < data.n_cols; i++)
{
@@ -195,165 +198,11 @@
}
} // AddAllEdges
-
/**
- * Handles the base case computation. Also called by naive.
- */
-template<typename TreeType>
-double DualTreeBoruvka<TreeType>::BaseCase(const TreeType* queryNode,
- const TreeType* referenceNode)
-{
- double newUpperBound = -1.0;
-
- for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
- ++queryIndex)
- {
- // Find the index of the component the query is in.
- size_t queryComponentIndex = connections.Find(queryIndex);
-
- for (size_t referenceIndex = referenceNode->Begin();
- referenceIndex < referenceNode->End(); ++referenceIndex)
- {
- size_t referenceComponentIndex = connections.Find(referenceIndex);
-
- if (queryComponentIndex != referenceComponentIndex)
- {
- double distance = metric::LMetric<2>::Evaluate(data.col(queryIndex),
- data.col(referenceIndex));
-
- if (distance < neighborsDistances[queryComponentIndex])
- {
- Log::Assert(queryIndex != referenceIndex);
-
- neighborsDistances[queryComponentIndex] = distance;
- neighborsInComponent[queryComponentIndex] = queryIndex;
- neighborsOutComponent[queryComponentIndex] = referenceIndex;
- } // if distance
- } // if indices not equal
- } // for referenceIndex
-
- if (newUpperBound < neighborsDistances[queryComponentIndex])
- newUpperBound = neighborsDistances[queryComponentIndex];
-
- } // for queryIndex
-
- Log::Assert(newUpperBound >= 0.0);
-
- return newUpperBound;
-
-} // BaseCase
-
-
-/**
- * Handles the recursive calls to find the nearest neighbors in an iteration
- */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::DualTreeRecursion(TreeType *queryNode,
- TreeType *referenceNode,
- double incomingDistance)
-{
- // Check for a distance prune.
- if (queryNode->Stat().MaxNeighborDistance() < incomingDistance)
- {
- // Pruned by distance.
- return;
- }
- // Check for a component prune.
- else if ((queryNode->Stat().ComponentMembership() >= 0)
- && (queryNode->Stat().ComponentMembership() ==
- referenceNode->Stat().ComponentMembership()))
- {
- // Pruned by component membership.
- Log::Assert(referenceNode->Stat().ComponentMembership() >= 0);
- return;
- }
- else if (queryNode->IsLeaf() && referenceNode->IsLeaf()) // Base case.
- {
- double new_bound = BaseCase(queryNode, referenceNode);
- queryNode->Stat().MaxNeighborDistance() = new_bound;
- }
- else if (queryNode->IsLeaf()) // Other recursive calls.
- {
- // Recurse on referenceNode only.
- double leftDist =
- queryNode->Bound().MinDistance(referenceNode->Left()->Bound());
- double rightDist =
- queryNode->Bound().MinDistance(referenceNode->Right()->Bound());
-
- if (leftDist < rightDist)
- {
- DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
- DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
- }
- else
- {
- DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
- DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
- }
- }
- else if (referenceNode->IsLeaf())
- {
- // Recurse on queryNode only.
- double leftDist =
- queryNode->Left()->Bound().MinDistance(referenceNode->Bound());
- double rightDist =
- queryNode->Right()->Bound().MinDistance(referenceNode->Bound());
-
- DualTreeRecursion(queryNode->Left(), referenceNode, leftDist);
- DualTreeRecursion(queryNode->Right(), referenceNode, rightDist);
-
- // Update queryNode's stat.
- queryNode->Stat().MaxNeighborDistance() =
- std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
- queryNode->Right()->Stat().MaxNeighborDistance());
- }
- else
- {
- // Recurse on both.
- double leftDist = queryNode->Left()->Bound().MinDistance(
- referenceNode->Left()->Bound());
- double rightDist = queryNode->Left()->Bound().MinDistance(
- referenceNode->Right()->Bound());
-
- if (leftDist < rightDist)
- {
- DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
- DualTreeRecursion(queryNode->Left(), referenceNode->Right(),
- rightDist);
- }
- else
- {
- DualTreeRecursion(queryNode->Left(), referenceNode->Right(), rightDist);
- DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
- }
-
- leftDist = queryNode->Right()->Bound().MinDistance(
- referenceNode->Left()->Bound());
- rightDist = queryNode->Right()->Bound().MinDistance(
- referenceNode->Right()->Bound());
-
- if (leftDist < rightDist)
- {
- DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
- DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
- }
- else
- {
- DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
- DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
- }
-
- queryNode->Stat().MaxNeighborDistance() =
- std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
- queryNode->Right()->Stat().MaxNeighborDistance());
- }
-} // DualTreeRecursion
-
-/**
* Unpermute the edge list (if necessary) and output it to results.
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::EmitResults(arma::mat& results)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
{
// Sort the edges.
std::sort(edges.begin(), edges.end(), SortFun);
@@ -400,10 +249,10 @@
/**
* This function resets the values in the nodes of the tree nearest neighbor
- * distance, check for fully connected nodes
+ * distance and checks for fully connected nodes.
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::CleanupHelper(TreeType* tree)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
{
tree->Stat().MaxNeighborDistance() = DBL_MAX;
@@ -440,8 +289,8 @@
/**
* The values stored in the tree must be reset on each iteration.
*/
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::Cleanup()
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
{
for (size_t i = 0; i < data.n_cols; i++)
{
Added: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp 2012-10-11 13:41:44 UTC (rev 13680)
@@ -0,0 +1,150 @@
+/**
+ * @file dtb_rules.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Tree traverser rules for the DualTreeBoruvka algorithm.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * Tree traversal rules for one iteration of the Dual-Tree Boruvka algorithm:
+ * for every component, find the shortest edge leading out of it.
+ *
+ * @tparam MetricType Metric used to evaluate point-to-point distances.
+ * @tparam TreeType Tree type being traversed; its statistic must provide
+ *     ComponentMembership() and MaxNeighborDistance().
+ */
+template<typename MetricType, typename TreeType>
+class DTBRules
+{
+ public:
+  /**
+   * Construct the rules object.  Every argument is stored by reference (no
+   * copies are made), and BaseCase() writes the candidate-edge arrays in
+   * place, so all of the referenced objects must outlive this object.
+   *
+   * @param dataSet The data points.
+   * @param connections Union-find structure describing the components found
+   *     so far.
+   * @param neighborsDistances Distance of the current candidate edge for each
+   *     component.
+   * @param neighborsInComponent Endpoint of each candidate edge that lies
+   *     inside the component.
+   * @param neighborsOutComponent Endpoint of each candidate edge that lies
+   *     outside the component.
+   * @param metric Metric used to evaluate point-to-point distances.
+   */
+  DTBRules(const arma::mat& dataSet,
+           UnionFind& connections,
+           arma::vec& neighborsDistances,
+           arma::Col<size_t>& neighborsInComponent,
+           arma::Col<size_t>& neighborsOutComponent,
+           MetricType& metric);
+
+  /**
+   * Base case between a query point and a reference point.  If the two points
+   * are in different components, their distance is evaluated and, when it is
+   * shorter than the current candidate, recorded as the new candidate edge for
+   * the query point's component.
+   *
+   * @param queryIndex Index of the query point.
+   * @param referenceIndex Index of the reference point.
+   * @return The candidate distance for the query point's component (after any
+   *     update) -- not the query-reference distance itself.
+   */
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+  /**
+   * Refresh the query node's MaxNeighborDistance() bound once recursion below
+   * it has finished, using the bounds of its children and the candidate
+   * distances of the components of its own points.
+   *
+   * @param queryNode Query node whose bound is refreshed.
+   * @param referenceNode Reference node (unused).
+   */
+  void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   */
+  double Score(const size_t queryIndex, TreeType& referenceNode);
+
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+   */
+  double Score(const size_t queryIndex,
+               TreeType& referenceNode,
+               const double baseCaseResult);
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning, so the old score is checked against the new pruning bound.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore);
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   */
+  double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param baseCaseResult Result of BaseCase() for the nodes' points.
+   */
+  double Score(TreeType& queryNode,
+               TreeType& referenceNode,
+               const double baseCaseResult) const;
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning, so the old score is checked against the new pruning bound.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+ private:
+  //! The data points.
+  const arma::mat& dataSet;
+
+  //! Union-find structure recording which points are already connected.
+  UnionFind& connections;
+
+  //! The distance to the candidate nearest neighbor for each component.
+  arma::vec& neighborsDistances;
+
+  //! The index of the point in the component that is an endpoint of the
+  //! candidate edge.
+  arma::Col<size_t>& neighborsInComponent;
+
+  //! The index of the point outside of the component that is an endpoint
+  //! of the candidate edge.
+  arma::Col<size_t>& neighborsOutComponent;
+
+  //! The metric used to evaluate point-to-point distances.
+  MetricType& metric;
+
+}; // class DTBRules
+
+} // emst namespace
+} // mlpack namespace
+
+#include "dtb_rules_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp 2012-10-11 13:41:44 UTC (rev 13680)
@@ -0,0 +1,245 @@
+/**
+ * @file dtb_rules_impl.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Tree traverser rules for the DualTreeBoruvka algorithm.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * Bind the rules to the shared per-iteration state.  Nothing is copied: each
+ * argument is kept by reference in the corresponding member.
+ */
+template<typename MetricType, typename TreeType>
+DTBRules<MetricType, TreeType>::DTBRules(
+    const arma::mat& dataSet,
+    UnionFind& connections,
+    arma::vec& neighborsDistances,
+    arma::Col<size_t>& neighborsInComponent,
+    arma::Col<size_t>& neighborsOutComponent,
+    MetricType& metric) :
+    dataSet(dataSet),
+    connections(connections),
+    neighborsDistances(neighborsDistances),
+    neighborsInComponent(neighborsInComponent),
+    neighborsOutComponent(neighborsOutComponent),
+    metric(metric)
+{
+  // All state lives in the referenced objects; nothing else to set up.
+}
+
+/**
+ * Point-to-point base case.  When the two points lie in different components
+ * their distance is evaluated, and if it beats the current candidate for the
+ * query's component, that candidate edge is replaced.  The value returned is
+ * the (possibly updated) candidate distance of the query's component, not the
+ * distance between the two points.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
+                                                const size_t referenceIndex)
+{
+  const size_t queryComp = connections.Find(queryIndex);
+  const size_t refComp = connections.Find(referenceIndex);
+
+  // Points already joined into one component cannot form a new edge.
+  if (queryComp != refComp)
+  {
+    const double dist = metric.Evaluate(dataSet.col(queryIndex),
+                                        dataSet.col(referenceIndex));
+
+    if (dist < neighborsDistances[queryComp])
+    {
+      Log::Assert(queryIndex != referenceIndex);
+
+      neighborsDistances[queryComp] = dist;
+      neighborsInComponent[queryComp] = queryIndex;
+      neighborsOutComponent[queryComp] = referenceIndex;
+    }
+  }
+
+  // max(-1.0, candidate distance) -- the candidate wins unless it is not
+  // strictly greater than -1.0.
+  const double bound = (neighborsDistances[queryComp] > -1.0)
+      ? neighborsDistances[queryComp] : -1.0;
+
+  Log::Assert(bound >= 0.0);
+
+  return bound;
+}
+
+/**
+ * After recursion below a query node completes, reset its
+ * MaxNeighborDistance() to the largest of: its children's bounds and the
+ * candidate distances of the components its own points belong to.
+ */
+template<typename MetricType, typename TreeType>
+void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
+    TreeType& queryNode,
+    TreeType& /* referenceNode */)
+{
+  double worst = 0.0;
+
+  // Bounds already stored in the child nodes.
+  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
+  {
+    const double childBound = queryNode.Child(c).Stat().MaxNeighborDistance();
+    if (childBound > worst)
+      worst = childBound;
+  }
+
+  // Candidate distances of the components owning this node's points.
+  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
+  {
+    const double pointBound =
+        neighborsDistances[connections.Find(queryNode.Point(p))];
+    if (pointBound > worst)
+      worst = pointBound;
+  }
+
+  queryNode.Stat().MaxNeighborDistance() = worst;
+}
+
+/**
+ * Score a (query point, reference node) pair.  Returns DBL_MAX (prune) when
+ * the reference node lies wholly in the query point's component, or when the
+ * node is farther than the best candidate already found for that component;
+ * otherwise returns the minimum point-to-node distance.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                             TreeType& referenceNode)
+{
+  const size_t queryComponent = connections.Find(queryIndex);
+
+  // The cast silences a signed/unsigned comparison warning.
+  const bool sameComponent = (queryComponent ==
+      (size_t) referenceNode.Stat().ComponentMembership());
+  if (sameComponent)
+    return DBL_MAX;
+
+  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+  const double minDist = referenceNode.MinDistance(queryPoint);
+
+  // Nothing in the node can improve on the component's current candidate.
+  if (minDist > neighborsDistances[queryComponent])
+    return DBL_MAX;
+
+  return minDist;
+}
+
+/**
+ * Score a (query point, reference node) pair, forwarding a base case result
+ * to TreeType::MinDistance().  Otherwise identical to the two-argument point
+ * Score(): DBL_MAX (prune) when the reference node lies wholly in the query
+ * point's component or is farther than the component's current candidate.
+ *
+ * NOTE(review): BaseCase() in this class returns the component's candidate
+ * distance rather than the query-reference distance; confirm that is the value
+ * MinDistance() expects as its second argument.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                             TreeType& referenceNode,
+                                             const double baseCaseResult)
+{
+  const size_t queryComponentIndex = connections.Find(queryIndex);
+
+  // If the query belongs to the same component as all of the references,
+  // then prune.  The cast avoids a signed/unsigned comparison warning, as in
+  // the two-argument overload.
+  if (queryComponentIndex ==
+      static_cast<size_t>(referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+
+  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+
+  const double distance = referenceNode.MinDistance(queryPoint,
+                                                    baseCaseResult);
+
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for the query's component, we prune.
+  return neighborsDistances[queryComponentIndex] < distance
+      ? DBL_MAX : distance;
+
+} // Score()
+
+/**
+ * Re-check a (query point, reference node) score against the query
+ * component's current candidate distance, which may have shrunk since the
+ * score was computed.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+                                               TreeType& /* referenceNode */,
+                                               const double oldScore)
+{
+  // An already-pruned node stays pruned.  Component membership is not checked
+  // again because it cannot change within a single iteration.
+  if (oldScore == DBL_MAX)
+    return DBL_MAX;
+
+  const double bestSoFar = neighborsDistances[connections.Find(queryIndex)];
+  return (oldScore > bestSoFar) ? DBL_MAX : oldScore;
+}
+
+/**
+ * Score a (query node, reference node) pair.  Returns DBL_MAX (prune) when
+ * both nodes are wholly inside the same component, or when the nodes are
+ * farther apart than the query node's MaxNeighborDistance() bound; otherwise
+ * returns the minimum node-to-node distance.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                             TreeType& referenceNode) const
+{
+  // A negative membership is never treated as a shared component.
+  const bool queryHasComponent = (queryNode.Stat().ComponentMembership() >= 0);
+  if (queryHasComponent && (queryNode.Stat().ComponentMembership() ==
+      referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+
+  const double minDist = queryNode.MinDistance(&referenceNode);
+
+  // Prune when no reference point can beat the worst candidate in the node.
+  return (minDist > queryNode.Stat().MaxNeighborDistance()) ? DBL_MAX : minDist;
+}
+
+/**
+ * Score a (query node, reference node) pair, forwarding a base case result to
+ * TreeType::MinDistance().  Pruning follows the two-argument node Score():
+ * DBL_MAX when both nodes lie wholly in one component, or when the minimum
+ * node-to-node distance exceeds the query node's MaxNeighborDistance().
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                             TreeType& referenceNode,
+                                             const double baseCaseResult) const
+{
+  // If all the queries belong to the same component as all the references
+  // then we prune.
+  if ((queryNode.Stat().ComponentMembership() >= 0)
+      && (queryNode.Stat().ComponentMembership() ==
+          referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+
+  // Pass the other node by pointer, matching the MinDistance(&referenceNode)
+  // call in the two-argument node Score().  NOTE(review): assumes the
+  // distance-taking MinDistance() overload also takes a pointer -- confirm
+  // against TreeType.
+  const double distance = queryNode.MinDistance(&referenceNode,
+                                                baseCaseResult);
+
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for all queries in the node, we prune.
+  return queryNode.Stat().MaxNeighborDistance() < distance
+      ? DBL_MAX : distance;
+
+} // Score()
+
+/**
+ * Re-check a (query node, reference node) score against the query node's
+ * MaxNeighborDistance() bound, which may have tightened since the score was
+ * computed.
+ */
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+                                               TreeType& /* referenceNode */,
+                                               const double oldScore) const
+{
+  // An already-pruned pair stays pruned.
+  if (oldScore == DBL_MAX)
+    return DBL_MAX;
+
+  const double nodeBound = queryNode.Stat().MaxNeighborDistance();
+  return (oldScore > nodeBound) ? DBL_MAX : oldScore;
+}
+
+} // namespace emst
+} // namespace mlpack
+
+
+
+#endif
+
More information about the mlpack-svn
mailing list