[mlpack-svn] r15566 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 26 17:04:48 EDT 2013


Author: rcurtin
Date: Fri Jul 26 17:04:47 2013
New Revision: 15566

Log:
Code cleanup, and refactor DTBRules so it does not depend on
UpdateAfterRecursion().  Now, EMST actually provides speedup when run in
dual-tree mode (hooray!).


Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
   mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	Fri Jul 26 17:04:47 2013
@@ -40,6 +40,14 @@
   //! Upper bound on the distance to the nearest neighbor of any point in this
   //! node.
   double maxNeighborDistance;
+
+  //! Lower bound on the distance to the nearest neighbor of any point in this
+  //! node.
+  double minNeighborDistance;
+
+  //! Total bound for pruning.
+  double bound;
+
   //! The index of the component that all points in this node belong to.  This
   //! is the same index returned by UnionFind for all points in this node.  If
   //! points in this node are in different components, this value will be
@@ -68,6 +76,16 @@
   //! Modify the maximum neighbor distance.
   double& MaxNeighborDistance() { return maxNeighborDistance; }
 
+  //! Get the minimum neighbor distance.
+  double MinNeighborDistance() const { return minNeighborDistance; }
+  //! Modify the minimum neighbor distance.
+  double& MinNeighborDistance() { return minNeighborDistance; }
+
+  //! Get the total bound for pruning.
+  double Bound() const { return bound; }
+  //! Modify the total bound for pruning.
+  double& Bound() { return bound; }
+
   //! Get the component membership of this node.
   int ComponentMembership() const { return componentMembership; }
   //! Modify the component membership of this node.
@@ -114,7 +132,7 @@
  * @tparam TreeType Type of tree to use.  Should use DTBStat as a statistic.
  */
 template<
-  typename MetricType = metric::SquaredEuclideanDistance,
+  typename MetricType = metric::EuclideanDistance,
   typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
 >
 class DualTreeBoruvka
@@ -151,7 +169,7 @@
   //! Total distance of the tree.
   double totalDist;
 
-  //! The metric
+  //! The instantiated metric.
   MetricType metric;
 
   // For sorting the edge list after the computation.

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	Fri Jul 26 17:04:47 2013
@@ -4,7 +4,6 @@
  *
  * Implementation of DTB.
  */
-
 #ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
 #define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
 
@@ -18,7 +17,11 @@
 /**
  * A generic initializer.
  */
-DTBStat::DTBStat() : maxNeighborDistance(DBL_MAX), componentMembership(-1)
+DTBStat::DTBStat() :
+    maxNeighborDistance(DBL_MAX),
+    minNeighborDistance(DBL_MAX),
+    bound(DBL_MAX),
+    componentMembership(-1)
 {
   // Nothing to do.
 }
@@ -29,6 +32,8 @@
 template<typename TreeType>
 DTBStat::DTBStat(const TreeType& node) :
     maxNeighborDistance(DBL_MAX),
+    minNeighborDistance(DBL_MAX),
+    bound(DBL_MAX),
     componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
         node.Point(0) : -1)
 {
@@ -124,7 +129,6 @@
 
   while (edges.size() < (data.n_cols - 1))
   {
-
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*tree, *tree);
@@ -133,7 +137,7 @@
 
     Cleanup();
 
-    Log::Info << edges.size() << " edges found so far.\n";
+    Log::Info << edges.size() << " edges found so far." << std::endl;
   }
 
   Timer::Stop("emst/mst_computation");
@@ -175,7 +179,7 @@
     {
       //totalDist = totalDist + dist;
       // changed to make this agree with the cover tree code
-      totalDist += sqrt(neighborsDistances[component]);
+      totalDist += neighborsDistances[component];
       AddEdge(inEdge, outEdge, neighborsDistances[component]);
       connections.Union(inEdge, outEdge);
     }
@@ -217,7 +221,7 @@
 
       results(0, i) = edges[i].Lesser();
       results(1, i) = edges[i].Greater();
-      results(2, i) = sqrt(edges[i].Distance());
+      results(2, i) = edges[i].Distance();
     }
   }
   else
@@ -226,7 +230,7 @@
     {
       results(0, i) = edges[i].Lesser();
       results(1, i) = edges[i].Greater();
-      results(2, i) = sqrt(edges[i].Distance());
+      results(2, i) = edges[i].Distance();
     }
   }
 } // EmitResults
@@ -239,6 +243,8 @@
 void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
 {
   tree->Stat().MaxNeighborDistance() = DBL_MAX;
+  tree->Stat().MinNeighborDistance() = DBL_MAX;
+  tree->Stat().Bound() = DBL_MAX;
 
   if (!tree->IsLeaf())
   {

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp	Fri Jul 26 17:04:47 2013
@@ -25,9 +25,6 @@
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
-  // Update bounds.  Needs a better name.
-  void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
   /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursion, while DBL_MAX indicates that the node should not be recursed
@@ -124,8 +121,14 @@
   //! of the candidate edge.
   arma::Col<size_t>& neighborsOutComponent;
 
-  //! The metric
+  //! The instantiated metric.
   MetricType& metric;
+
+  /**
+   * Recompute, store, and return the pruning bound of the given query node.
+   */
+  inline double CalculateBound(TreeType& queryNode) const;
+
 }; // class DTBRules
 
 } // emst namespace

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp	Fri Jul 26 17:04:47 2013
@@ -55,7 +55,6 @@
       neighborsDistances[queryComponentIndex] = distance;
       neighborsInComponent[queryComponentIndex] = queryIndex;
       neighborsOutComponent[queryComponentIndex] = referenceIndex;
-
     }
   }
 
@@ -68,34 +67,6 @@
 }
 
 template<typename MetricType, typename TreeType>
-void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
-    TreeType& queryNode,
-    TreeType& /* referenceNode */)
-{
-  // Find the worst distance that the children found (including any points), and
-  // update the bound accordingly.
-  double newUpperBound = 0.0;
-
-  // First look through children nodes.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
-      newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
-  }
-
-  // Now look through children points.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    size_t pointComponent = connections.Find(queryNode.Point(i));
-    if (newUpperBound < neighborsDistances[pointComponent])
-      newUpperBound = neighborsDistances[pointComponent];
-  }
-
-  // Update the bound in the query's statistic.
-  queryNode.Stat().MaxNeighborDistance() = newUpperBound;
-}
-
-template<typename MetricType, typename TreeType>
 double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                              TreeType& referenceNode)
 {
@@ -109,7 +80,6 @@
     return DBL_MAX;
 
   const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
   const double distance = referenceNode.MinDistance(queryPoint);
 
   // If all the points in the reference node are farther than the candidate
@@ -166,11 +136,11 @@
     return DBL_MAX;
 
   const double distance = queryNode.MinDistance(&referenceNode);
+  const double bound = CalculateBound(queryNode);
 
   // If all the points in the reference node are farther than the candidate
   // nearest neighbor for all queries in the node, we prune.
-  return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
-      distance;
+  return (bound < distance) ? DBL_MAX : distance;
 }
 
 template<typename MetricType, typename TreeType>
@@ -185,13 +155,12 @@
            referenceNode.Stat().ComponentMembership()))
     return DBL_MAX;
 
-  const double distance = queryNode.MinDistance(referenceNode,
-                                                baseCaseResult);
+  const double distance = queryNode.MinDistance(referenceNode, baseCaseResult);
+  const double bound = CalculateBound(queryNode);
 
   // If all the points in the reference node are farther than the candidate
   // nearest neighbor for all queries in the node, we prune.
-  return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
-      distance;
+  return (bound < distance) ? DBL_MAX : distance;
 }
 
 template<typename MetricType, typename TreeType>
@@ -199,12 +168,63 @@
                                                TreeType& /* referenceNode */,
                                                const double oldScore) const
 {
-  return (oldScore > queryNode.Stat().MaxNeighborDistance()) ? DBL_MAX :
-      oldScore;
+  const double bound = CalculateBound(queryNode);
+  return (oldScore > bound) ? DBL_MAX : oldScore;
+}
+
+// Calculate the bound for a given query node in its current state, and store
+// it (with the min/max neighbor distances) in the node's statistic.
+template<typename MetricType, typename TreeType>
+inline double DTBRules<MetricType, TreeType>::CalculateBound(
+    TreeType& queryNode) const
+{
+  double worstPointBound = -DBL_MAX;
+  double bestPointBound = DBL_MAX;
+
+  double worstChildBound = -DBL_MAX;
+  double bestChildBound = DBL_MAX;
+
+  // Now, find the best and worst point bounds.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    const size_t pointComponent = connections.Find(queryNode.Point(i));
+    const double bound = neighborsDistances[pointComponent];
+
+    if (bound > worstPointBound)
+      worstPointBound = bound;
+    if (bound < bestPointBound)
+      bestPointBound = bound;
+  }
+
+  // Find the best and worst child bounds.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
+    if (maxBound > worstChildBound)
+      worstChildBound = maxBound;
+
+    const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
+    if (minBound < bestChildBound)
+      bestChildBound = minBound;
+  }
+
+  // Now calculate the actual bounds.
+  const double worstBound = std::max(worstPointBound, worstChildBound);
+  const double bestBound = std::min(bestPointBound, bestChildBound);
+  // We must check that bestBound != DBL_MAX; otherwise, we risk overflow.
+  const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
+      bestBound + 2 * queryNode.FurthestDescendantDistance();
+
+  // Update the relevant quantities in the node.
+  queryNode.Stat().MaxNeighborDistance() = worstBound;
+  queryNode.Stat().MinNeighborDistance() = bestBound;
+  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);
+
+  return queryNode.Stat().Bound();
 }
 
-} // namespace emst
-} // namespace mlpack
+}; // namespace emst
+}; // namespace mlpack
 
 
 

Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	Fri Jul 26 17:04:47 2013
@@ -44,60 +44,54 @@
 using namespace mlpack;
 using namespace mlpack::emst;
 using namespace mlpack::tree;
+using namespace std;
 
 int main(int argc, char* argv[])
 {
   CLI::ParseCommandLine(argc, argv);
 
-  ///////////////// READ IN DATA //////////////////////////////////
-  std::string dataFilename = CLI::GetParam<std::string>("input_file");
-
-  Log::Info << "Reading in data.\n";
+  const string dataFilename = CLI::GetParam<string>("input_file");
 
   arma::mat dataPoints;
   data::Load(dataFilename, dataPoints, true);
 
-  // Do naive.
+  // Do naive computation if necessary.
   if (CLI::GetParam<bool>("naive"))
   {
-    Log::Info << "Running naive algorithm.\n";
+    Log::Info << "Running naive algorithm." << endl;
 
     DualTreeBoruvka<> naive(dataPoints, true);
 
     arma::mat naiveResults;
     naive.ComputeMST(naiveResults);
 
-    std::string outputFilename = CLI::GetParam<std::string>("output_file");
+    const string outputFilename = CLI::GetParam<string>("output_file");
 
     data::Save(outputFilename, naiveResults, true);
   }
   else
   {
-    Log::Info << "Data read, building tree.\n";
+    Log::Info << "Building tree.\n";
 
-    /////////////// Initialize DTB //////////////////////
+    // Check that the leaf size is reasonable.
     if (CLI::GetParam<int>("leaf_size") <= 0)
     {
       Log::Fatal << "Invalid leaf size (" << CLI::GetParam<int>("leaf_size")
           << ")!  Must be greater than or equal to 1." << std::endl;
     }
 
-    size_t leafSize = CLI::GetParam<int>("leaf_size");
-
+    // Initialize the tree and get ready to compute the MST.
+    const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
     DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
 
-    Log::Info << "Tree built, running algorithm.\n";
-
-    ////////////// Run DTB /////////////////////
+    // Run the DTB algorithm.
+    Log::Info << "Calculating minimum spanning tree." << endl;
     arma::mat results;
-
     dtb.ComputeMST(results);
 
-    //////////////// Output the Results ////////////////
-    std::string outputFilename = CLI::GetParam<std::string>("output_file");
+    // Output the results.
+    const string outputFilename = CLI::GetParam<string>("output_file");
 
     data::Save(outputFilename, results, true);
   }
-
-  return 0;
 }



More information about the mlpack-svn mailing list