[mlpack-svn] r13680 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 11 09:41:45 EDT 2012


Author: march
Date: 2012-10-11 09:41:44 -0400 (Thu, 11 Oct 2012)
New Revision: 13680

Added:
   mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
Log:
Updated EMST code to use tree traverser abstractions.

Modified: mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt	2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt	2012-10-11 13:41:44 UTC (rev 13680)
@@ -8,6 +8,8 @@
    # dtb
    dtb.hpp
    dtb_impl.hpp
+   dtb_rules.hpp
+   dtb_rules_impl.hpp
    edge_pair.hpp
 )
 

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2012-10-11 13:41:44 UTC (rev 13680)
@@ -79,8 +79,7 @@
 
 /**
  * Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any
- * type of tree.  At the moment this class does not support arbitrary distance
- * metrics, and uses the squared Euclidean distance.
+ * type of tree.  
  *
  * For more information on the algorithm, see the following citation:
  *
@@ -110,9 +109,14 @@
  * More advanced usage of the class can use different types of trees, pass in an
  * already-built tree, or compute the MST using the O(n^2) naive algorithm.
  *
+ * @tparam MetricType The metric to use.  IMPORTANT: this hasn't really been 
+ * tested with anything other than the L2 metric, so user beware. Note that the 
+ * tree type needs to compute bounds using the same metric as the type 
+ * specified here.
  * @tparam TreeType Type of tree to use.  Should use DTBStat as a statistic.
  */
 template<
+  typename MetricType = metric::SquaredEuclideanDistance,
   typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
 >
 class DualTreeBoruvka
@@ -148,6 +152,9 @@
 
   //! Total distance of the tree.
   double totalDist;
+  
+  //! The metric
+  MetricType metric;
 
   // For sorting the edge list after the computation.
   struct SortEdgesHelper
@@ -169,7 +176,8 @@
    */
   DualTreeBoruvka(const typename TreeType::Mat& dataset,
                   const bool naive = false,
-                  const size_t leafSize = 1);
+                  const size_t leafSize = 1,
+                  const MetricType metric = MetricType());
 
   /**
    * Create the DualTreeBoruvka object with an already initialized tree.  This
@@ -188,7 +196,8 @@
    * @param tree Pre-built tree.
    * @param dataset Dataset corresponding to the pre-built tree.
    */
-  DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset);
+  DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+                  const MetricType metric = MetricType());
 
   /**
    * Delete the tree, if it was created inside the object.
@@ -218,18 +227,6 @@
   void AddAllEdges();
 
   /**
-   * Handles the base case computation.  Also called by naive.
-   */
-  double BaseCase(const TreeType* queryNode, const TreeType* referenceNode);
-
-  /**
-   * Handles the recursive calls to find the nearest neighbors in an iteration
-   */
-  void DualTreeRecursion(TreeType *queryNode,
-                         TreeType *referenceNode,
-                         double incomingDistance);
-
-  /**
    * Unpermute the edge list and output it to results.
    */
   void EmitResults(arma::mat& results);

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	2012-10-10 21:10:32 UTC (rev 13679)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	2012-10-11 13:41:44 UTC (rev 13680)
@@ -8,7 +8,7 @@
 #ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
 #define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
 
-#include <mlpack/core.hpp>
+#include "dtb_rules.hpp"
 
 namespace mlpack {
 namespace emst {
@@ -57,17 +57,19 @@
  * Takes in a reference to the data set.  Copies the data, builds the tree,
  * and initializes all of the member variables.
  */
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     const typename TreeType::Mat& dataset,
     const bool naive,
-    const size_t leafSize) :
+    const size_t leafSize,
+    const MetricType metric) :
     dataCopy(dataset),
     data(dataCopy), // The reference points to our copy of the data.
     ownTree(true),
     naive(naive),
     connections(data.n_cols),
-    totalDist(0.0)
+    totalDist(0.0),
+    metric(metric)
 {
   Timer::Start("emst/tree_building");
 
@@ -93,16 +95,18 @@
   neighborsDistances.fill(DBL_MAX);
 } // Constructor
 
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
     TreeType* tree,
-    const typename TreeType::Mat& dataset) :
+    const typename TreeType::Mat& dataset,
+    const MetricType metric) :
     data(dataset),
     tree(tree),
-    ownTree(true),
+    ownTree(false), // The tree was built by the caller; don't delete it.
     naive(false),
     connections(data.n_cols),
-    totalDist(0.0)
+    totalDist(0.0),
+    metric(metric)
 {
   edges.reserve(data.n_cols - 1); // fill with EdgePairs
 
@@ -112,8 +116,8 @@
   neighborsDistances.fill(DBL_MAX);
 }
 
-template<typename TreeType>
-DualTreeBoruvka<TreeType>::~DualTreeBoruvka()
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
 {
   if (ownTree)
     delete tree;
@@ -123,25 +127,24 @@
  * Iteratively find the nearest neighbor of each component until the MST is
  * complete.
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::ComputeMST(arma::mat& results)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
 {
   Timer::Start("emst/mst_computation");
 
   totalDist = 0; // Reset distance.
 
+  typedef DTBRules<MetricType, TreeType> RuleType;
+  RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
+                 neighborsOutComponent, metric);
+
   while (edges.size() < (data.n_cols - 1))
   {
-    // Compute neighbors.
-    if (naive)
-    {
-      BaseCase(tree, tree);
-    }
-    else
-    {
-      DualTreeRecursion(tree, tree, DBL_MAX);
-    }
-
+    
+    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    
+    traverser.Traverse(*tree, *tree);
+    
     AddAllEdges();
 
     Cleanup();
@@ -159,8 +162,8 @@
 /**
  * Adds a single edge to the edge list
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::AddEdge(const size_t e1,
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
                                         const size_t e2,
                                         const double distance)
 {
@@ -176,8 +179,8 @@
 /**
  * Adds all the edges found in one iteration to the list of neighbors.
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::AddAllEdges()
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
 {
   for (size_t i = 0; i < data.n_cols; i++)
   {
@@ -195,165 +198,11 @@
   }
 } // AddAllEdges
 
-
 /**
- * Handles the base case computation.  Also called by naive.
- */
-template<typename TreeType>
-double DualTreeBoruvka<TreeType>::BaseCase(const TreeType* queryNode,
-                                           const TreeType* referenceNode)
-{
-  double newUpperBound = -1.0;
-
-  for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
-       ++queryIndex)
-  {
-    // Find the index of the component the query is in.
-    size_t queryComponentIndex = connections.Find(queryIndex);
-
-    for (size_t referenceIndex = referenceNode->Begin();
-         referenceIndex < referenceNode->End(); ++referenceIndex)
-    {
-      size_t referenceComponentIndex = connections.Find(referenceIndex);
-
-      if (queryComponentIndex != referenceComponentIndex)
-      {
-        double distance = metric::LMetric<2>::Evaluate(data.col(queryIndex),
-            data.col(referenceIndex));
-
-        if (distance < neighborsDistances[queryComponentIndex])
-        {
-          Log::Assert(queryIndex != referenceIndex);
-
-          neighborsDistances[queryComponentIndex] = distance;
-          neighborsInComponent[queryComponentIndex] = queryIndex;
-          neighborsOutComponent[queryComponentIndex] = referenceIndex;
-        } // if distance
-      } // if indices not equal
-    } // for referenceIndex
-
-    if (newUpperBound < neighborsDistances[queryComponentIndex])
-      newUpperBound = neighborsDistances[queryComponentIndex];
-
-  } // for queryIndex
-
-  Log::Assert(newUpperBound >= 0.0);
-
-  return newUpperBound;
-
-} // BaseCase
-
-
-/**
- * Handles the recursive calls to find the nearest neighbors in an iteration
- */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::DualTreeRecursion(TreeType *queryNode,
-                                                  TreeType *referenceNode,
-                                                  double incomingDistance)
-{
-  // Check for a distance prune.
-  if (queryNode->Stat().MaxNeighborDistance() < incomingDistance)
-  {
-    // Pruned by distance.
-    return;
-  }
-  // Check for a component prune.
-  else if ((queryNode->Stat().ComponentMembership() >= 0)
-        && (queryNode->Stat().ComponentMembership() ==
-               referenceNode->Stat().ComponentMembership()))
-  {
-    // Pruned by component membership.
-    Log::Assert(referenceNode->Stat().ComponentMembership() >= 0);
-    return;
-  }
-  else if (queryNode->IsLeaf() && referenceNode->IsLeaf()) // Base case.
-  {
-    double new_bound = BaseCase(queryNode, referenceNode);
-    queryNode->Stat().MaxNeighborDistance() = new_bound;
-  }
-  else if (queryNode->IsLeaf()) // Other recursive calls.
-  {
-    // Recurse on referenceNode only.
-    double leftDist =
-        queryNode->Bound().MinDistance(referenceNode->Left()->Bound());
-    double rightDist =
-        queryNode->Bound().MinDistance(referenceNode->Right()->Bound());
-
-    if (leftDist < rightDist)
-    {
-      DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
-      DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
-    }
-    else
-    {
-      DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
-      DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
-    }
-  }
-  else if (referenceNode->IsLeaf())
-  {
-    // Recurse on queryNode only.
-    double leftDist =
-        queryNode->Left()->Bound().MinDistance(referenceNode->Bound());
-    double rightDist =
-        queryNode->Right()->Bound().MinDistance(referenceNode->Bound());
-
-    DualTreeRecursion(queryNode->Left(), referenceNode, leftDist);
-    DualTreeRecursion(queryNode->Right(), referenceNode, rightDist);
-
-    // Update queryNode's stat.
-    queryNode->Stat().MaxNeighborDistance() =
-        std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
-                 queryNode->Right()->Stat().MaxNeighborDistance());
-  }
-  else
-  {
-    // Recurse on both.
-    double leftDist = queryNode->Left()->Bound().MinDistance(
-        referenceNode->Left()->Bound());
-    double rightDist = queryNode->Left()->Bound().MinDistance(
-        referenceNode->Right()->Bound());
-
-    if (leftDist < rightDist)
-    {
-      DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
-      DualTreeRecursion(queryNode->Left(), referenceNode->Right(),
-          rightDist);
-    }
-    else
-    {
-      DualTreeRecursion(queryNode->Left(), referenceNode->Right(), rightDist);
-      DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
-    }
-
-    leftDist = queryNode->Right()->Bound().MinDistance(
-        referenceNode->Left()->Bound());
-    rightDist = queryNode->Right()->Bound().MinDistance(
-        referenceNode->Right()->Bound());
-
-    if (leftDist < rightDist)
-    {
-      DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
-      DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
-    }
-    else
-    {
-      DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
-      DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
-    }
-
-    queryNode->Stat().MaxNeighborDistance() =
-        std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
-                 queryNode->Right()->Stat().MaxNeighborDistance());
-  }
-} // DualTreeRecursion
-
-/**
  * Unpermute the edge list (if necessary) and output it to results.
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::EmitResults(arma::mat& results)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
 {
   // Sort the edges.
   std::sort(edges.begin(), edges.end(), SortFun);
@@ -400,10 +249,10 @@
 
 /**
  * This function resets the values in the nodes of the tree nearest neighbor
- * distance, check for fully connected nodes
+ * distance and checks for fully connected nodes.
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::CleanupHelper(TreeType* tree)
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
 {
   tree->Stat().MaxNeighborDistance() = DBL_MAX;
 
@@ -440,8 +289,8 @@
 /**
  * The values stored in the tree must be reset on each iteration.
  */
-template<typename TreeType>
-void DualTreeBoruvka<TreeType>::Cleanup()
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
 {
   for (size_t i = 0; i < data.n_cols; i++)
   {

Added: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	2012-10-11 13:41:44 UTC (rev 13680)
@@ -0,0 +1,150 @@
+/**
+ * @file dtb_rules.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Tree traverser rules for the DualTreeBoruvka algorithm.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * Traversal rules for DualTreeBoruvka: the base case, the pruning scores, and
+ * the query-node bound update used by the dual-tree traverser.
+ */
+template<typename MetricType, typename TreeType>
+class DTBRules
+{
+ public:
+  //! Store references to the data, the component state, and the metric.
+  DTBRules(const arma::mat& dataSet,
+           UnionFind& connections,
+           arma::vec& neighborsDistances,
+           arma::Col<size_t>& neighborsInComponent,
+           arma::Col<size_t>& neighborsOutComponent,
+           MetricType& metric);
+  
+  /**
+   * Base case: if the two points lie in different components and are closer
+   * than the query component's candidate edge, record them as its new edge.
+   * Returns the query component's candidate distance, not the pair distance.
+   */
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  
+  //! Set queryNode's MaxNeighborDistance() to the largest candidate distance
+  //! among its children's bounds and its points' components.
+  void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+  
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   */
+  double Score(const size_t queryIndex, TreeType& referenceNode);
+  
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+   */
+  double Score(const size_t queryIndex,
+               TreeType& referenceNode,
+               const double baseCaseResult);
+  
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore);
+  
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   */
+  double Score(TreeType& queryNode, TreeType& referenceNode) const;
+  
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+   */
+  double Score(TreeType& queryNode,
+               TreeType& referenceNode,
+               const double baseCaseResult) const;
+  
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+ private:
+  //! The data points.
+  const arma::mat& dataSet;
+  
+  //! Union-find structure holding the components (the MST built so far).
+  UnionFind& connections;
+  
+  //! The distance to the candidate nearest neighbor for each component.
+  arma::vec& neighborsDistances;
+  
+  //! The index of the point in the component that is an endpoint of the
+  //! candidate edge.
+  arma::Col<size_t>& neighborsInComponent;
+  
+  //! The index of the point outside of the component that is an endpoint
+  //! of the candidate edge.
+  arma::Col<size_t>& neighborsOutComponent;
+  
+  //! The metric used to evaluate point-to-point distances.
+  MetricType& metric;
+  
+}; // class DTBRules
+
+} // emst namespace
+} // mlpack namespace
+
+#include "dtb_rules_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp	2012-10-11 13:41:44 UTC (rev 13680)
@@ -0,0 +1,245 @@
+/**
+ * @file dtb_rules_impl.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Implementation of the DualTreeBoruvka tree traverser rules.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+
+namespace mlpack {
+namespace emst {
+
+template<typename MetricType, typename TreeType>
+DTBRules<MetricType, TreeType>::
+DTBRules(const arma::mat& dataSet,
+         UnionFind& connections,
+         arma::vec& neighborsDistances,
+         arma::Col<size_t>& neighborsInComponent,
+         arma::Col<size_t>& neighborsOutComponent,
+         MetricType& metric)
+:
+  dataSet(dataSet),
+  connections(connections),
+  neighborsDistances(neighborsDistances),
+  neighborsInComponent(neighborsInComponent),
+  neighborsOutComponent(neighborsOutComponent),
+  metric(metric)
+{
+  // Only references are stored; the caller must keep these objects alive.
+} // constructor
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
+                                                const size_t referenceIndex)
+{
+  // If the points are in different components and closer than the query
+  // component's candidate edge, they become its new candidate edge.  Returns
+  // the component's candidate distance, not the distance of this pair.
+  // NOTE(review): Score() receives this as baseCaseResult; confirm intent.
+  
+  double newUpperBound = -1.0;
+  
+  // Find the index of the component the query is in.
+  size_t queryComponentIndex = connections.Find(queryIndex);
+  // ...and the index of the component the reference is in.
+  size_t referenceComponentIndex = connections.Find(referenceIndex);
+  // Pairs inside a single component are skipped.
+  if (queryComponentIndex != referenceComponentIndex)
+  {
+    double distance = metric.Evaluate(dataSet.col(queryIndex),
+                                      dataSet.col(referenceIndex));
+    
+    if (distance < neighborsDistances[queryComponentIndex])
+    {
+      Log::Assert(queryIndex != referenceIndex);
+      
+      neighborsDistances[queryComponentIndex] = distance;
+      neighborsInComponent[queryComponentIndex] = queryIndex;
+      neighborsOutComponent[queryComponentIndex] = referenceIndex;
+      
+    } // if distance
+  } // if components differ
+  // newUpperBound starts at -1, so any non-negative candidate distance wins.
+  if (newUpperBound < neighborsDistances[queryComponentIndex])
+    newUpperBound = neighborsDistances[queryComponentIndex];
+  
+  Log::Assert(newUpperBound >= 0.0);
+  
+  return newUpperBound;
+  
+} // BaseCase()
+
+template<typename MetricType, typename TreeType>
+void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
+    TreeType& queryNode,
+    TreeType& /* referenceNode */)
+{
+  // The new bound is the largest candidate distance among the children's
+  // bounds and the components of the node's points.
+  // NOTE(review): a node with no children and no points would get bound 0.
+  double newUpperBound = 0.0;
+  
+  // First look through children nodes.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
+      newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
+  }
+  
+  // Now look through the node's points.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    size_t pointComponent = connections.Find(queryNode.Point(i));
+    if (newUpperBound < neighborsDistances[pointComponent])
+      newUpperBound = neighborsDistances[pointComponent];
+  }
+  
+  // Store the new bound in the query node's statistic.
+  queryNode.Stat().MaxNeighborDistance() = newUpperBound;
+  
+} // UpdateAfterRecursion
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                             TreeType& referenceNode)
+{
+  // Score one query point against a reference node (lower is better).
+  size_t queryComponentIndex = connections.Find(queryIndex);
+  
+  // If the query belongs to the same component as all of the references,
+  // then prune.
+  // The cast silences a signed/unsigned comparison warning; a negative
+  // membership becomes a huge value that no valid component index equals.
+  if (queryComponentIndex == (size_t)referenceNode.Stat().ComponentMembership())
+    return DBL_MAX;
+  
+  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+  
+  const double distance = referenceNode.MinDistance(queryPoint);
+  
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for the query's component, we prune.
+  return neighborsDistances[queryComponentIndex] < distance
+      ? DBL_MAX : distance;
+  
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                             TreeType& referenceNode,
+                                             const double baseCaseResult)
+{
+  // Same as Score(queryIndex, referenceNode), except that baseCaseResult is
+  // forwarded to referenceNode.MinDistance(); how it is used there is up to
+  // the tree type.
+  const size_t queryComponentIndex = connections.Find(queryIndex);
+
+  // If the query belongs to the same component as all of the references,
+  // then prune.  The cast avoids a signed/unsigned comparison.
+  if (queryComponentIndex ==
+      static_cast<size_t>(referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+  
+  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+  
+  const double distance = referenceNode.MinDistance(queryPoint,
+                                                    baseCaseResult);
+  
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for the query's component, we prune.
+  return neighborsDistances[queryComponentIndex] < distance
+      ? DBL_MAX : distance;
+  
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+                                               TreeType& /* referenceNode */,
+                                               const double oldScore)
+{
+  // We don't need to check component membership again, because it can't
+  // change inside a single iteration.
+  
+  // If we are already pruning, still prune.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  // Prune if the query component has since found a closer candidate edge.
+  if (oldScore > neighborsDistances[connections.Find(queryIndex)])
+    return DBL_MAX;
+  else
+    return oldScore;
+} // Rescore
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                             TreeType& referenceNode) const
+{
+  // If all the queries belong to the same component as all the references
+  // then we prune.  A negative membership is never treated as a match.
+  if ((queryNode.Stat().ComponentMembership() >= 0)
+      && (queryNode.Stat().ComponentMembership() ==
+          referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+  // Smallest possible distance between the two nodes.
+  double distance = queryNode.MinDistance(&referenceNode);
+  
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for all queries in the node, we prune.
+  return queryNode.Stat().MaxNeighborDistance() < distance
+      ? DBL_MAX : distance;
+  
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                             TreeType& referenceNode,
+                                             const double baseCaseResult) const
+{
+  // Same as Score(queryNode, referenceNode), but forwards baseCaseResult.
+  // If all the queries belong to the same component as all the references
+  // then we prune.
+  if ((queryNode.Stat().ComponentMembership() >= 0)
+      && (queryNode.Stat().ComponentMembership() ==
+          referenceNode.Stat().ComponentMembership()))
+    return DBL_MAX;
+  
+  // Pass the node's address, as the two-argument Score() does.
+  const double distance = queryNode.MinDistance(&referenceNode,
+                                                baseCaseResult);
+  
+  // If all the points in the reference node are farther than the candidate
+  // nearest neighbor for all queries in the node, we prune.
+  return queryNode.Stat().MaxNeighborDistance() < distance
+      ? DBL_MAX : distance;
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+                                               TreeType& /* referenceNode */,
+                                               const double oldScore) const
+{
+  // Node version of Rescore(): component membership is not rechecked, since
+  // it cannot change within an iteration; only the node's bound is compared.
+  
+  // If we are already pruning, still prune.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+  // Prune if the query node's bound is now below the old score.
+  if (oldScore > queryNode.Stat().MaxNeighborDistance())
+    return DBL_MAX;
+  else
+    return oldScore;
+  
+} // Rescore
+
+} // namespace emst
+} // namespace mlpack
+
+
+
+#endif 
+




More information about the mlpack-svn mailing list