[mlpack-svn] r10764 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 05:53:00 EST 2011


Author: rcurtin
Date: 2011-12-14 05:53:00 -0500 (Wed, 14 Dec 2011)
New Revision: 10764

Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
   mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
   mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
   mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
Log:
Refactor and clean up EMST code.


Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,6 +1,5 @@
 /**
  * @file dtb.hpp
- *
  * @author Bill March (march at gatech.edu)
  *
  * Contains an implementation of the DualTreeBoruvka algorithm for finding a
@@ -37,18 +36,12 @@
 class DTBStat
 {
  private:
-  double max_neighbor_distance_;
-  int component_membership_;
+  //! Maximum neighbor distance.
+  double maxNeighborDistance;
+  //! Component membership of this node.
+  int componentMembership;
 
  public:
-  void set_max_neighbor_distance(double distance);
-
-  double max_neighbor_distance();
-
-  void set_component_membership(int membership);
-
-  int component_membership();
-
   /**
    * A generic initializer.
    */
@@ -67,137 +60,152 @@
   DTBStat(const MatType& dataset, const size_t start, const size_t count,
           const DTBStat& leftStat, const DTBStat& rightStat);
 
+  //! Get the maximum neighbor distance.
+  double MaxNeighborDistance() const { return maxNeighborDistance; }
+  //! Modify the maximum neighbor distance.
+  double& MaxNeighborDistance() { return maxNeighborDistance; }
+
+  //! Get the component membership of this node.
+  int ComponentMembership() const { return componentMembership; }
+  //! Modify the component membership of this node.
+  int& ComponentMembership() { return componentMembership; }
+
 }; // class DTBStat
 
 /**
  * Performs the MST calculation using the Dual-Tree Boruvka algorithm.
  */
+template<
+  typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
+>
 class DualTreeBoruvka
 {
- public:
-  // For now, everything is in Euclidean space
-  static const size_t metric = 2;
+ private:
+  //! Copy of the data (if necessary).
+  arma::mat dataCopy;
+  //! Reference to the data (this is what should be used for accessing data).
+  arma::mat& data;
 
-  typedef tree::BinarySpaceTree<bound::HRectBound<metric>, DTBStat> DTBTree;
+  //! Pointer to the root of the tree.
+  TreeType* tree;
+  //! Indicates whether or not we "own" the tree.
+  bool ownTree;
 
-  //////// Member Variables /////////////////////
+  //! Indicates whether or not O(n^2) naive mode will be used.
+  bool naive;
 
- private:
-  size_t number_of_edges_;
-  std::vector<EdgePair> edges_; // must use vector with non-numerical types
-  size_t number_of_points_;
-  UnionFind connections_;
-  struct datanode* module_;
-  arma::mat data_points_;
-  size_t leaf_size_;
+  //! Edges.
+  std::vector<EdgePair> edges; // must use vector with non-numerical types
 
-  // lists
-  std::vector<size_t> old_from_new_permutation_;
-  arma::Col<size_t> neighbors_in_component_;
-  arma::Col<size_t> neighbors_out_component_;
-  arma::vec neighbors_distances_;
+  //! Connections.
+  UnionFind connections;
 
+  //! Permutations of points during tree building.
+  std::vector<size_t> oldFromNew;
+  //! List of edge nodes.
+  arma::Col<size_t> neighborsInComponent;
+  //! List of edge nodes.
+  arma::Col<size_t> neighborsOutComponent;
+  //! List of edge distances.
+  arma::vec neighborsDistances;
+
   // output info
-  double total_dist_;
-  size_t number_of_loops_;
-  size_t number_distance_prunes_;
-  size_t number_component_prunes_;
-  size_t number_leaf_computations_;
-  size_t number_q_recursions_;
-  size_t number_r_recursions_;
-  size_t number_both_recursions_;
+  double totalDist;
 
-  bool do_naive_;
-
-  DTBTree* tree_;
-
-  // for sorting the edge list after the computation
-  struct SortEdgesHelper_
+  // For sorting the edge list after the computation.
+  struct SortEdgesHelper
   {
-    bool operator() (const EdgePair& pairA, const EdgePair& pairB)
+    bool operator()(const EdgePair& pairA, const EdgePair& pairB)
     {
-      return (pairA.distance() < pairB.distance());
+      return (pairA.Distance() < pairB.Distance());
     }
   } SortFun;
-  
 
+
 ////////////////// Constructors ////////////////////////
  public:
-  DualTreeBoruvka() { }
+  /**
+   * Create the tree from the given dataset.  This copies the dataset to an
+   * internal copy, because tree-building modifies (reorders) the dataset.
+   *
+   * @param dataset Dataset to build a tree for.
+   * @param naive Whether the computation should be done in O(n^2) naive mode.
+   * @param leafSize Leaf size for tree construction (ignored in naive mode).
+   */
+  DualTreeBoruvka(const typename TreeType::Mat& dataset,
+                  const bool naive = false,
+                  const size_t leafSize = 1);
 
+  /**
+   * Create the DualTreeBoruvka object with an already initialized tree.  This
+   * will not copy the dataset, and can save a little processing power.  Naive
+   * mode is not available as an option for this constructor; instead, to run
+   * naive computation, construct a tree with all the points in one leaf (i.e.
+   * leafSize = number of points).
+   *
+   * @note
+   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+   * of a matrix, be sure you pass the modified matrix to this object!  In
+   * addition, mapping the points of the matrix back to their original indices
+   * is not done when this constructor is used.
+   * @endnote
+   *
+   * @param tree Pre-built tree.
+   * @param dataset Dataset corresponding to the pre-built tree.
+   */
+  DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset);
+
+  /**
+   * Delete the tree, if it was created inside the object.
+   */
   ~DualTreeBoruvka();
 
+  /**
+   * Iteratively find the nearest neighbor of each component until the MST is
+   * complete, then store the sorted 3 x (n - 1) edge list in results.
+   */
+  void ComputeMST(arma::mat& results);
+
   ////////////////////////// Private Functions ////////////////////
  private:
   /**
    * Adds a single edge to the edge list
    */
-  void AddEdge_(size_t e1, size_t e2, double distance);
-  
+  void AddEdge(const size_t e1, const size_t e2, const double distance);
+
   /**
    * Adds all the edges found in one iteration to the list of neighbors.
    */
-  void AddAllEdges_();
-  
+  void AddAllEdges();
+
   /**
    * Handles the base case computation.  Also called by naive.
    */
-  double ComputeBaseCase_(size_t query_start, size_t query_end,
-                          size_t reference_start, size_t reference_end);
-  
+  double BaseCase(const TreeType* queryNode, const TreeType* referenceNode);
+
   /**
    * Handles the recursive calls to find the nearest neighbors in an iteration
    */
-  void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
-                                  double incoming_distance);
-  
-  /**
-   * Computes the nearest neighbor of each point in each iteration
-   * of the algorithm
-   */
-  void ComputeNeighbors_();
+  void DualTreeRecursion(TreeType *queryNode,
+                         TreeType *referenceNode,
+                         double incomingDistance);
 
-  
-  void SortEdges_();
-  
   /**
-   * Unpermute the edge list and output it to results
-   *
+   * Unpermute the edge list and output it to results.
    */
-  void EmitResults_(arma::mat& results);
+  void EmitResults(arma::mat& results);
 
   /**
    * This function resets the values in the nodes of the tree nearest neighbor
-   * distance, check for fully connected nodes
+   * distance, and checks for fully connected nodes.
    */
-  void CleanupHelper_(DTBTree* tree);
+  void CleanupHelper(TreeType* tree);
 
   /**
    * The values stored in the tree must be reset on each iteration.
    */
-  void Cleanup_();
-  
-  /**
-   * Format and output the results
-   */
-  void OutputResults_();
-  
-  /////////// Public Functions ///////////////////
- public:
-  size_t number_of_edges();
+  void Cleanup();
 
-  /**
-   * Takes in a reference to the data set.  Copies the data, builds the tree,
-   * and initializes all of the member variables.
-   */
-  void Init(const arma::mat& data, bool naive, size_t leafSize);
-  
-  /**
-   * Call this function after Init.  It will iteratively find the nearest
-   * neighbor of each component until the MST is complete.
-   */
-  void ComputeMST(arma::mat& results);
-  
 }; // class DualTreeBoruvka
 
 }; // namespace emst

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,10 +1,8 @@
-/*
- *  dtb_impl.hpp
- *  
+/**
+ * @file dtb_impl.hpp
+ * @author Bill March (march at gatech.edu)
  *
- *  Created by William March on 12/6/11.
- *  Copyright 2011 __MyCompanyName__. All rights reserved.
- *
+ * Implementation of DTB.
  */
 
 #ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
@@ -12,37 +10,17 @@
 
 #include <mlpack/core.hpp>
 
+namespace mlpack {
+namespace emst {
 
+// DTBStat
 
-using namespace mlpack::emst;
-
-void DTBStat::set_max_neighbor_distance(double distance)
-{
-  max_neighbor_distance_ = distance;
-}
-
-double DTBStat::max_neighbor_distance()
-{
-  return max_neighbor_distance_;
-}
-
-void DTBStat::set_component_membership(int membership)
-{
-  component_membership_ = membership;
-}
-
-int DTBStat::component_membership()
-{
-  return component_membership_;
-}
-
 /**
  * A generic initializer.
  */
-DTBStat::DTBStat()
+inline DTBStat::DTBStat() : maxNeighborDistance(DBL_MAX), componentMembership(-1)
 {
-  set_max_neighbor_distance(DBL_MAX);
-  set_component_membership(-1);
+  // Nothing to do; inline because this non-template is defined in a header.
 }
 
 /**
@@ -51,516 +29,430 @@
 template<typename MatType>
 DTBStat::DTBStat(const MatType& dataset,
                  const size_t start,
-                 const size_t count)
+                 const size_t count) :
+    maxNeighborDistance(DBL_MAX),
+    componentMembership((count == 1) ? start : -1)
 {
-  if (count == 1)
-  {
-    set_component_membership(start);
-    set_max_neighbor_distance(DBL_MAX);
-  }
-  else
-  {
-    set_max_neighbor_distance(DBL_MAX);
-    set_component_membership(-1);
-  }
+  // Nothing to do.
 }
 
 /**
- * An initializer for non-leaves.  Simply calls the leaf initializer.
+ * An initializer for non-leaves.
  */
 template<typename MatType>
 DTBStat::DTBStat(const MatType& dataset,
                  const size_t start,
                  const size_t count,
                  const DTBStat& leftStat,
-                 const DTBStat& right_stat)
+                 const DTBStat& right_stat) :
+    maxNeighborDistance(DBL_MAX),
+    componentMembership((count == 1) ? start : -1)
 {
-  if (count == 1)
+  // Nothing to do.
+}
+
+// DualTreeBoruvka
+
+/**
+ * Takes in a reference to the data set.  Copies the data, builds the tree,
+ * and initializes all of the member variables.
+ */
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+    const typename TreeType::Mat& dataset,
+    const bool naive,
+    const size_t leafSize) :
+    dataCopy(dataset),
+    data(dataCopy), // The reference points to our copy of the data.
+    ownTree(true),
+    naive(naive),
+    connections(data.n_cols),
+    totalDist(0.0)
+{
+  Timers::StartTimer("emst/treebuilding");
+
+  if (!naive)
   {
-    set_component_membership(start);
-    set_max_neighbor_distance(DBL_MAX);
+    // Default leaf size is 1; this gives the best pruning, empirically.  Use
+    // leaf_size = 1 unless space is a big concern.
+    tree = new TreeType(data, oldFromNew, leafSize);
   }
   else
   {
-    set_max_neighbor_distance(DBL_MAX);
-    set_component_membership(-1);
+    // Naive tree holds all data in one leaf.
+    tree = new TreeType(data, oldFromNew, data.n_cols);
   }
-}
 
+  Timers::StopTimer("emst/treebuilding");
 
-DualTreeBoruvka::~DualTreeBoruvka()
+  edges.reserve(data.n_cols - 1); // Set size.
+
+  neighborsInComponent.set_size(data.n_cols);
+  neighborsOutComponent.set_size(data.n_cols);
+  neighborsDistances.set_size(data.n_cols);
+  neighborsDistances.fill(DBL_MAX);
+} // Constructor
+
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+    TreeType* tree,
+    const typename TreeType::Mat& dataset) :
+    data(const_cast<arma::mat&>(dataset)), // Only read after construction.
+    tree(tree),
+    ownTree(false), // The caller owns the tree; the destructor must not free it.
+    naive(false),
+    connections(data.n_cols),
+    totalDist(0.0)
 {
-  if (tree_ != NULL)
-    delete tree_;
+  edges.reserve(data.n_cols - 1); // A spanning tree has n - 1 edges.
+
+  neighborsInComponent.set_size(data.n_cols);
+  neighborsOutComponent.set_size(data.n_cols);
+  neighborsDistances.set_size(data.n_cols);
+  neighborsDistances.fill(DBL_MAX);
 }
 
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::~DualTreeBoruvka()
+{
+  if (ownTree)
+    delete tree;
+}
+
 /**
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete, then write the sorted 3 x (n - 1) edge list into results.
+ */
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::ComputeMST(arma::mat& results)
+{
+  Timers::StartTimer("emst/mst_computation");
+
+  while (edges.size() < (data.n_cols - 1))
+  {
+    // Compute neighbors.
+    if (naive)
+    {
+      BaseCase(tree, tree);
+    }
+    else
+    {
+      DualTreeRecursion(tree, tree, DBL_MAX);
+    }
+
+    AddAllEdges();
+
+    Cleanup();
+
+    Log::Info << edges.size() << " edges found so far.\n";
+  }
+
+  Timers::StopTimer("emst/mst_computation");
+
+  EmitResults(results);
+
+  Log::Info << "Total spanning tree length: " << totalDist << std::endl;
+} // ComputeMST
+
+/**
  * Adds a single edge to the edge list
  */
-void DualTreeBoruvka::AddEdge_(size_t e1, size_t e2, double distance)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::AddEdge(const size_t e1,
+                                        const size_t e2,
+                                        const double distance)
 {
-  //EdgePair edge;
-  mlpack::Log::Assert((e1 != e2),
-                      "Indices are equal in DualTreeBoruvka.add_edge(...)");
-  
-  mlpack::Log::Assert((distance >= 0.0),
-                      "Negative distance input in DualTreeBoruvka.add_edge(...)");
-  
+  Log::Assert((distance >= 0.0),
+      "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
+
   if (e1 < e2)
-    edges_[number_of_edges_].Init(e1, e2, distance);
+    edges.push_back(EdgePair(e1, e2, distance));
   else
-    edges_[number_of_edges_].Init(e2, e1, distance);
-  
-  number_of_edges_++;
-  
-} // AddEdge_
+    edges.push_back(EdgePair(e2, e1, distance));
+} // AddEdge
 
 /**
  * Adds all the edges found in one iteration to the list of neighbors.
  */
-void DualTreeBoruvka::AddAllEdges_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::AddAllEdges()
 {
-  for (size_t i = 0; i < number_of_points_; i++)
+  for (size_t i = 0; i < data.n_cols; i++)
   {
-    size_t component_i = connections_.Find(i);
-    size_t in_edge_i = neighbors_in_component_[component_i];
-    size_t out_edge_i = neighbors_out_component_[component_i];
-    if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+    size_t component = connections.Find(i);
+    size_t inEdge = neighborsInComponent[component];
+    size_t outEdge = neighborsOutComponent[component];
+    if (connections.Find(inEdge) != connections.Find(outEdge))
     {
-      double dist = neighbors_distances_[component_i];
-      //total_dist_ = total_dist_ + dist;
+      //totalDist = totalDist + dist;
       // changed to make this agree with the cover tree code
-      total_dist_ = total_dist_ + sqrt(dist);
-      AddEdge_(in_edge_i, out_edge_i, dist);
-      connections_.Union(in_edge_i, out_edge_i);
+      totalDist += sqrt(neighborsDistances[component]);
+      AddEdge(inEdge, outEdge, neighborsDistances[component]);
+      connections.Union(inEdge, outEdge);
     }
   }
-} // AddAllEdges_
+} // AddAllEdges
 
 
 /**
  * Handles the base case computation.  Also called by naive.
  */
-double DualTreeBoruvka::ComputeBaseCase_(size_t query_start, size_t query_end,
-                                         size_t reference_start, 
-                                         size_t reference_end)
+template<typename TreeType>
+double DualTreeBoruvka<TreeType>::BaseCase(const TreeType* queryNode,
+                                           const TreeType* referenceNode)
 {
-  number_leaf_computations_++;
-  
-  double new_upper_bound = -1.0;
-  
-  for (size_t query_index = query_start; query_index < query_end;
-       query_index++)
+  double newUpperBound = -1.0;
+
+  for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
+       ++queryIndex)
   {
-    // Find the index of the component the query is in
-    size_t query_component_index = connections_.Find(query_index);
-    
-    arma::vec query_point = data_points_.col(query_index);
-    
-    for (size_t reference_index = reference_start;
-         reference_index < reference_end; reference_index++)
+    // Find the index of the component the query is in.
+    size_t queryComponentIndex = connections.Find(queryIndex);
+
+    for (size_t referenceIndex = referenceNode->Begin();
+         referenceIndex < referenceNode->End(); ++referenceIndex)
     {
-      size_t reference_component_index = connections_.Find(reference_index);
-      
-      if (query_component_index != reference_component_index)
+      size_t referenceComponentIndex = connections.Find(referenceIndex);
+
+      if (queryComponentIndex != referenceComponentIndex)
       {
-        arma::vec reference_point = data_points_.col(reference_index);
-        
-        double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
-                                                               reference_point);
-        
-        if (distance < neighbors_distances_[query_component_index])
+        double distance = metric::LMetric<2>::Evaluate(data.col(queryIndex),
+            data.col(referenceIndex));
+
+        if (distance < neighborsDistances[queryComponentIndex])
         {
-          mlpack::Log::Assert(query_index != reference_index);
-          
-          neighbors_distances_[query_component_index] = distance;
-          neighbors_in_component_[query_component_index] = query_index;
-          neighbors_out_component_[query_component_index] = reference_index;
+          Log::Assert(queryIndex != referenceIndex);
+
+          neighborsDistances[queryComponentIndex] = distance;
+          neighborsInComponent[queryComponentIndex] = queryIndex;
+          neighborsOutComponent[queryComponentIndex] = referenceIndex;
         } // if distance
       } // if indices not equal
-    } // for reference_index
-    
-    if (new_upper_bound < neighbors_distances_[query_component_index])
-      new_upper_bound = neighbors_distances_[query_component_index];
-    
-  } // for query_index
-  
-  mlpack::Log::Assert(new_upper_bound >= 0.0);
-  return new_upper_bound;
-  
-} // ComputeBaseCase_
+    } // for referenceIndex
 
+    if (newUpperBound < neighborsDistances[queryComponentIndex])
+      newUpperBound = neighborsDistances[queryComponentIndex];
 
+  } // for queryIndex
+
+  Log::Assert(newUpperBound >= 0.0);
+
+  return newUpperBound;
+
+} // BaseCase
+
+
 /**
  * Handles the recursive calls to find the nearest neighbors in an iteration
  */
-void DualTreeBoruvka::ComputeNeighborsRecursion_(DTBTree *query_node, 
-                                                 DTBTree *reference_node,
-                                                 double incoming_distance)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::DualTreeRecursion(TreeType *queryNode,
+                                                  TreeType *referenceNode,
+                                                  double incomingDistance)
 {
   // Check for a distance prune.
-  if (query_node->Stat().max_neighbor_distance() < incoming_distance)
+  if (queryNode->Stat().MaxNeighborDistance() < incomingDistance)
   {
     // Pruned by distance.
-    number_distance_prunes_++;
+    return;
   }
   // Check for a component prune.
-  else if ((query_node->Stat().component_membership() >= 0)
-           && (query_node->Stat().component_membership() ==
-               reference_node->Stat().component_membership()))
+  else if ((queryNode->Stat().ComponentMembership() >= 0)
+        && (queryNode->Stat().ComponentMembership() ==
+               referenceNode->Stat().ComponentMembership()))
   {
     // Pruned by component membership.
-    mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
-    mlpack::Log::Info << query_node->Stat().component_membership()
-    << "q mem\n";
-    mlpack::Log::Info << reference_node->Stat().component_membership()
-    << "r mem\n";
-    
-    number_component_prunes_++;
+    Log::Assert(referenceNode->Stat().ComponentMembership() >= 0);
+    return;
   }
-  else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
+  else if (queryNode->IsLeaf() && referenceNode->IsLeaf()) // Base case.
   {
-    double new_bound = ComputeBaseCase_(query_node->Begin(),
-                                        query_node->End(), reference_node->Begin(), reference_node->End());
-    
-    query_node->Stat().set_max_neighbor_distance(new_bound);
+    double new_bound = BaseCase(queryNode, referenceNode);
+    queryNode->Stat().MaxNeighborDistance() = new_bound;
   }
-  else if (query_node->IsLeaf()) // Other recursive calls.
+  else if (queryNode->IsLeaf()) // Other recursive calls.
   {
-    // Recurse on reference_node only.
-    number_r_recursions_++;
-    
-    double left_dist =
-    query_node->Bound().MinDistance(reference_node->Left()->Bound());
-    double right_dist =
-    query_node->Bound().MinDistance(reference_node->Right()->Bound());
-    mlpack::Log::Assert(left_dist >= 0.0);
-    mlpack::Log::Assert(right_dist >= 0.0);
-    
-    if (left_dist < right_dist)
+    // Recurse on referenceNode only.
+    double leftDist =
+        queryNode->Bound().MinDistance(referenceNode->Left()->Bound());
+    double rightDist =
+        queryNode->Bound().MinDistance(referenceNode->Right()->Bound());
+
+    if (leftDist < rightDist)
     {
-      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
-                                 left_dist);
-      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
-                                 right_dist);
+      DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
+      DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
     }
     else
     {
-      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
-                                 right_dist);
-      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
-                                 left_dist);
+      DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
+      DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
     }
   }
-  else if (reference_node->IsLeaf())
+  else if (referenceNode->IsLeaf())
   {
-    // Recurse on query_node only.
-    number_q_recursions_++;
-    
-    double left_dist =
-    query_node->Left()->Bound().MinDistance(reference_node->Bound());
-    double right_dist =
-    query_node->Right()->Bound().MinDistance(reference_node->Bound());
-    
-    ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
-    ComputeNeighborsRecursion_(query_node->Right(), reference_node,
-                               right_dist);
-    
-    // Update query_node's stat.
-    query_node->Stat().set_max_neighbor_distance(
-                                                 std::max(query_node->Left()->Stat().max_neighbor_distance(),
-                                                          query_node->Right()->Stat().max_neighbor_distance()));
-    
+    // Recurse on queryNode only.
+    double leftDist =
+        queryNode->Left()->Bound().MinDistance(referenceNode->Bound());
+    double rightDist =
+        queryNode->Right()->Bound().MinDistance(referenceNode->Bound());
+
+    DualTreeRecursion(queryNode->Left(), referenceNode, leftDist);
+    DualTreeRecursion(queryNode->Right(), referenceNode, rightDist);
+
+    // Update queryNode's stat.
+    queryNode->Stat().MaxNeighborDistance() =
+        std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
+                 queryNode->Right()->Stat().MaxNeighborDistance());
   }
   else
   {
     // Recurse on both.
-    number_both_recursions_++;
-    
-    double left_dist = query_node->Left()->Bound().MinDistance(
-                                                               reference_node->Left()->Bound());
-    double right_dist = query_node->Left()->Bound().MinDistance(
-                                                                reference_node->Right()->Bound());
-    
-    if (left_dist < right_dist)
+    double leftDist = queryNode->Left()->Bound().MinDistance(
+        referenceNode->Left()->Bound());
+    double rightDist = queryNode->Left()->Bound().MinDistance(
+        referenceNode->Right()->Bound());
+
+    if (leftDist < rightDist)
     {
-      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-                                 left_dist);
-      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-                                 right_dist);
+      DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
+      DualTreeRecursion(queryNode->Left(), referenceNode->Right(),
+          rightDist);
     }
     else
     {
-      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-                                 right_dist);
-      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-                                 left_dist);
+      DualTreeRecursion(queryNode->Left(), referenceNode->Right(), rightDist);
+      DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
     }
-    
-    left_dist = query_node->Right()->Bound().MinDistance(
-                                                         reference_node->Left()->Bound());
-    right_dist = query_node->Right()->Bound().MinDistance(
-                                                          reference_node->Right()->Bound());
-    
-    if (left_dist < right_dist)
+
+    leftDist = queryNode->Right()->Bound().MinDistance(
+        referenceNode->Left()->Bound());
+    rightDist = queryNode->Right()->Bound().MinDistance(
+        referenceNode->Right()->Bound());
+
+    if (leftDist < rightDist)
     {
-      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-                                 left_dist);
-      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-                                 right_dist);
+      DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
+      DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
     }
     else
     {
-      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-                                 right_dist);
-      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-                                 left_dist);
+      DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
+      DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
     }
-    
-    query_node->Stat().set_max_neighbor_distance(
-                                                 std::max(query_node->Left()->Stat().max_neighbor_distance(),
-                                                          query_node->Right()->Stat().max_neighbor_distance()));
+
+    queryNode->Stat().MaxNeighborDistance() =
+        std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
+                 queryNode->Right()->Stat().MaxNeighborDistance());
   }
-} // ComputeNeighborsRecursion_
+} // DualTreeRecursion
 
 /**
- * Computes the nearest neighbor of each point in each iteration
- * of the algorithm
+ * Unpermute the edge list (if necessary) and output it to results.
  */
-void DualTreeBoruvka::ComputeNeighbors_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::EmitResults(arma::mat& results)
 {
-  if (do_naive_)
-  {
-    ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
-  }
-  else
-  {
-    ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
-  }
-} // ComputeNeighbors_
+  // Sort the edges.
+  std::sort(edges.begin(), edges.end(), SortFun);
 
-void DualTreeBoruvka::SortEdges_()
-{
-  std::sort(edges_.begin(), edges_.end(), SortFun);
-} // SortEdges_()
+  Log::Assert(edges.size() == data.n_cols - 1);
+  results.set_size(3, edges.size());
 
-/**
- * Unpermute the edge list and output it to results
- *
- */
-void DualTreeBoruvka::EmitResults_(arma::mat& results)
-{
-  SortEdges_();
-  
-  mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
-  results.set_size(number_of_edges_, 3);
-  
   // Need to unpermute the point labels.
-  if (!do_naive_)
+  if (!naive && ownTree)
   {
-    for (size_t i = 0; i < (number_of_points_ - 1); i++)
+    for (size_t i = 0; i < (data.n_cols - 1); i++)
     {
       // Make sure the edge list stores the smaller index first to
       // make checking correctness easier
-      size_t ind1, ind2;
-      ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
-      ind2 = old_from_new_permutation_[edges_[i].greater_index()];
-      
-      edges_[i].set_lesser_index(std::min(ind1, ind2));
-      edges_[i].set_greater_index(std::max(ind1, ind2));
-      
-      results(i, 0) = edges_[i].lesser_index();
-      results(i, 1) = edges_[i].greater_index();
-      results(i, 2) = sqrt(edges_[i].distance());
+      size_t ind1 = oldFromNew[edges[i].Lesser()];
+      size_t ind2 = oldFromNew[edges[i].Greater()];
+
+      if (ind1 < ind2)
+      {
+        edges[i].Lesser() = ind1;
+        edges[i].Greater() = ind2;
+      }
+      else
+      {
+        edges[i].Lesser() = ind2;
+        edges[i].Greater() = ind1;
+      }
+
+      results(0, i) = edges[i].Lesser();
+      results(1, i) = edges[i].Greater();
+      results(2, i) = sqrt(edges[i].Distance());
     }
   }
   else
   {
-    for (size_t i = 0; i < number_of_edges_; i++)
+    for (size_t i = 0; i < edges.size(); i++)
     {
-      results(i, 0) = edges_[i].lesser_index();
-      results(i, 1) = edges_[i].greater_index();
-      results(i, 2) = sqrt(edges_[i].distance());
+      results(0, i) = edges[i].Lesser();
+      results(1, i) = edges[i].Greater();
+      results(2, i) = sqrt(edges[i].Distance());
     }
   }
-} // EmitResults_
+} // EmitResults
 
 /**
- * This function resets the values in the nodes of the tree nearest neighbor
- * distance, check for fully connected nodes
+ * Recursively reset each node's maximum neighbor distance to DBL_MAX and
+ * record a node's component when all of its points lie in one component.
  */
-void DualTreeBoruvka::CleanupHelper_(DTBTree* tree)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::CleanupHelper(TreeType* tree)
 {
-  tree->Stat().set_max_neighbor_distance(DBL_MAX);
-  
+  tree->Stat().MaxNeighborDistance() = DBL_MAX;
+
   if (!tree->IsLeaf())
   {
-    CleanupHelper_(tree->Left());
-    CleanupHelper_(tree->Right());
-    
-    if ((tree->Left()->Stat().component_membership() >= 0)
-        && (tree->Left()->Stat().component_membership() ==
-            tree->Right()->Stat().component_membership()))
+    CleanupHelper(tree->Left());
+    CleanupHelper(tree->Right());
+
+    if ((tree->Left()->Stat().ComponentMembership() >= 0)
+        && (tree->Left()->Stat().ComponentMembership() ==
+            tree->Right()->Stat().ComponentMembership()))
     {
-      tree->Stat().set_component_membership(tree->Left()->Stat().
-                                            component_membership());
+      tree->Stat().ComponentMembership() =
+          tree->Left()->Stat().ComponentMembership();
     }
   }
   else
   {
-    size_t new_membership = connections_.Find(tree->Begin());
-    
-    for (size_t i = tree->Begin(); i < tree->End(); i++)
+    const size_t newMembership = connections.Find(tree->Begin());
+
+    for (size_t i = tree->Begin(); i < tree->End(); ++i)
     {
-      if (new_membership != connections_.Find(i))
+      if (newMembership != connections.Find(i))
       {
-        new_membership = -1;
-        mlpack::Log::Assert(tree->Stat().component_membership() < 0);
+        // This leaf spans several components, so its membership stays negative.
+        Log::Assert(tree->Stat().ComponentMembership() < 0);
         return;
       }
     }
-    tree->Stat().set_component_membership(new_membership);
+    tree->Stat().ComponentMembership() = static_cast<int>(newMembership);
   }
-} // CleanupHelper_
+} // CleanupHelper
 
 /**
- * The values stored in the tree must be reset on each iteration.
+ * Reset per-iteration state: neighbor distances and (unless naive) tree stats.
  */
-void DualTreeBoruvka::Cleanup_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::Cleanup()
 {
-  for (size_t i = 0; i < number_of_points_; i++)
+  for (size_t i = 0; i < data.n_cols; i++)
   {
-    neighbors_distances_[i] = DBL_MAX;
+    neighborsDistances[i] = DBL_MAX;
   }
-  number_of_loops_++;
-  
-  if (!do_naive_)
+
+  if (!naive)
   {
-    CleanupHelper_(tree_);
+    CleanupHelper(tree);
   }
 }
 
-/**
- * Format and output the results
- */
-void DualTreeBoruvka::OutputResults_()
-{
-  /* fx_result_double(module_, "total_squared_length", total_dist_);
-   fx_result_int(module_, "number_of_points", number_of_points_);
-   fx_result_int(module_, "dimension", data_points_.n_rows);
-   fx_result_int(module_, "number_of_loops", number_of_loops_);
-   fx_result_int(module_, "number_distance_prunes", number_distance_prunes_);
-   fx_result_int(module_, "number_component_prunes", number_component_prunes_);
-   fx_result_int(module_, "number_leaf_computations", number_leaf_computations_);
-   fx_result_int(module_, "number_q_recursions", number_q_recursions_);
-   fx_result_int(module_, "number_r_recursions", number_r_recursions_);
-   fx_result_int(module_, "number_both_recursions", number_both_recursions_);*/
-  // TODO, not sure how I missed this last time.
-  mlpack::Log::Info << "Total squared length: " << total_dist_ << std::endl;
-  mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
-  mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
-  /*
-   mlpack::Log::Info << "number_of_loops" << std::endl;
-   mlpack::Log::Info << "number_distance_prunes" << std::endl;
-   mlpack::Log::Info << "number_component_prunes" << std::endl;
-   mlpack::Log::Info << "number_leaf_computations" << std::endl;
-   mlpack::Log::Info << "number_q_recursions" << std::endl;
-   mlpack::Log::Info << "number_r_recursions" << std::endl;
-   mlpack::Log::Info << "number_both_recursions" << std::endl;
-   */
-  
-} // OutputResults_
+}; // namespace emst
+}; // namespace mlpack
 
-size_t DualTreeBoruvka::number_of_edges()
-{
-  return number_of_edges_;
-}
-
-/**
- * Takes in a reference to the data set.  Copies the data, builds the tree,
- * and initializes all of the member variables.
- */
-void DualTreeBoruvka::Init(const arma::mat& data, bool naive = false, 
-                           size_t leafSize = 1)
-{
-  number_of_edges_ = 0;
-  data_points_ = data; // copy
-  
-  do_naive_ = naive;
-  
-  if (!do_naive_)
-  {
-    // Default leaf size is 1
-    // This gives best pruning empirically
-    // Use leaf_size=1 unless space is a big concern
-    Timers::StartTimer("emst/tree_building");
-    
-    tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
-    
-    Timers::StopTimer("emst/tree_building");
-  }
-  else
-  {
-    tree_ = NULL;
-    old_from_new_permutation_.resize(0);
-  }
-  
-  number_of_points_ = data_points_.n_cols;
-  edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
-  connections_.Init(number_of_points_);
-  
-  neighbors_in_component_.set_size(number_of_points_);
-  neighbors_out_component_.set_size(number_of_points_);
-  neighbors_distances_.set_size(number_of_points_);
-  neighbors_distances_.fill(DBL_MAX);
-  
-  total_dist_ = 0.0;
-  number_of_loops_ = 0;
-  number_distance_prunes_ = 0;
-  number_component_prunes_ = 0;
-  number_leaf_computations_ = 0;
-  number_q_recursions_ = 0;
-  number_r_recursions_ = 0;
-  number_both_recursions_ = 0;
-} // Init
-
-/**
- * Call this function after Init.  It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
- */
-void DualTreeBoruvka::ComputeMST(arma::mat& results)
-{
-  Timers::StartTimer("emst/MST_computation");
-  
-  while (number_of_edges_ < (number_of_points_ - 1))
-  {
-    ComputeNeighbors_();
-    
-    AddAllEdges_();
-    
-    Cleanup_();
-    
-    Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
-    Log::Info << number_of_edges_ << " edges found so far.\n\n";
-    /*
-     Log::Info << number_leaf_computations_ << " base cases.\n";
-     Log::Info << number_distance_prunes_ << " distance prunes.\n";
-     Log::Info << number_component_prunes_ << " component prunes.\n";
-     Log::Info << number_r_recursions_ << " reference recursions.\n";
-     Log::Info << number_q_recursions_ << " query recursions.\n";
-     Log::Info << number_both_recursions_ << " dual recursions.\n\n";
-     */
-  }
-  
-  Timers::StopTimer("emst/MST_computation");
-  
-  EmitResults_(results);
-  
-  OutputResults_();
-} // ComputeMST
-
-
-
-
-
-#endif 
+#endif

Modified: mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp	2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp	2011-12-14 10:53:00 UTC (rev 10764)
@@ -23,9 +23,12 @@
 class EdgePair
 {
  private:
-  size_t lesser_index_;
-  size_t greater_index_;
-  double distance_;
+  //! Index of the lesser endpoint of the edge.
+  size_t lesser;
+  //! Index of the greater endpoint of the edge.
+  size_t greater;
+  //! Distance between the endpoints (squared, as used by DualTreeBoruvka).
+  double distance;
 
  public:
   /**
@@ -34,44 +37,28 @@
-   * Init.  However, this is not necessary for functionality; it is just a way
-   * to keep the edge list organized in other code.
+   * the constructor.  However, this is not necessary for functionality; it is
+   * just a way to keep the edge list organized in other code.
    */
-  void Init(size_t lesser, size_t greater, double dist)
+  EdgePair(const size_t lesser, const size_t greater, const double dist) :
+      lesser(lesser), greater(greater), distance(dist)
   {
-    mlpack::Log::Assert(lesser != greater,
-        "indices equal when creating EdgePair, lesser == greater");
-    lesser_index_ = lesser;
-    greater_index_ = greater;
-    distance_ = dist;
+    Log::Assert(lesser != greater,
+        "EdgePair::EdgePair(): indices cannot be equal.");
   }

-  size_t lesser_index()
-  {
-    return lesser_index_;
-  }
+  //! Get the lesser index.
+  size_t Lesser() const { return lesser; }
+  //! Modify the lesser index.
+  size_t& Lesser() { return lesser; }

-  void set_lesser_index(size_t index)
-  {
-    lesser_index_ = index;
-  }
+  //! Get the greater index.
+  size_t Greater() const { return greater; }
+  //! Modify the greater index.
+  size_t& Greater() { return greater; }

-  size_t greater_index()
-  {
-    return greater_index_;
-  }
+  //! Get the distance.
+  double Distance() const { return distance; }
+  //! Modify the distance.
+  double& Distance() { return distance; }

-  void set_greater_index(size_t index)
-  {
-    greater_index_ = index;
-  }
-
-  double distance() const
-  {
-    return distance_;
-  }
-
-  void set_distance(double new_dist)
-  {
-    distance_ = new_dist;
-  }
 }; // class EdgePair
 
 }; // namespace emst

Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	2011-12-14 10:53:00 UTC (rev 10764)
@@ -27,7 +27,11 @@
     "Conference\n        on Knowledge Discovery and Data Mining},\n"
     "    series = {KDD '10},\n"
     "    year = {2010}\n"
-    "  }\n");
+    "  }\n\n"
+    "The output is saved in a three-column matrix, where each row indicates an "
+    "edge.  The first column corresponds to the lesser index of the edge; the "
+    "second column corresponds to the greater index of the edge; and the third "
+    "column corresponds to the distance between the two points.");
 
 PARAM_STRING_REQ("input_file", "Data input file.", "i");
 PARAM_STRING("output_file", "Data output file.  Stored as an edge list.", "o",
@@ -41,6 +45,7 @@
 
 using namespace mlpack;
 using namespace mlpack::emst;
+using namespace mlpack::tree;
 
 int main(int argc, char* argv[])
 {
@@ -54,21 +59,19 @@
   arma::mat dataPoints;
   data::Load(dataFilename.c_str(), dataPoints, true);
 
-  // Do naive
+  // Do naive.
   if (CLI::GetParam<bool>("naive"))
   {
     Log::Info << "Running naive algorithm.\n";
 
-    DualTreeBoruvka naive;
+    DualTreeBoruvka<> naive(dataPoints, true);
 
-    naive.Init(dataPoints, true);
+    arma::mat naiveResults;
+    naive.ComputeMST(naiveResults);
 
-    arma::mat naive_results;
-    naive.ComputeMST(naive_results);
-
     std::string outputFilename = CLI::GetParam<std::string>("output_file");
 
-    data::Save(outputFilename.c_str(), naive_results, true);
+    data::Save(outputFilename.c_str(), naiveResults, true);
   }
   else
   {
@@ -83,10 +86,9 @@
 
     size_t leafSize = CLI::GetParam<int>("leaf_size");
 
-    DualTreeBoruvka dtb;
-    dtb.Init(dataPoints, false, leafSize);
+    DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
 
-    Log::Info << "Tree built, running algorithm.\n\n";
+    Log::Info << "Tree built, running algorithm.\n";
 
     ////////////// Run DTB /////////////////////
     arma::mat results;

Modified: mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.hpp	2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.hpp	2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,6 +1,5 @@
 /**
- * @file union_find.h
- *
+ * @file union_find.hpp
  * @author Bill March (march at gatech.edu)
  *
  * Implements a union-find data structure.  This structure tracks the components
@@ -22,82 +21,76 @@
 class UnionFind
 {
  private:
-  arma::Col<size_t> parent_;
-  arma::ivec rank_;
-  size_t number_of_elements_;
+  //! The number of elements tracked.
+  size_t size;
+  //! Parent of each element; an element that is its own parent is a root.
+  arma::Col<size_t> parent;
+  //! Rank of each element, used by Union() to keep the trees shallow.
+  arma::ivec rank;

  public:
-  UnionFind() {}
-
-  ~UnionFind() {}
-
-  /**
-   * Initializes the structure.  This implementation assumes
-   * that the size is known advance and fixed
-   *
-   * @param size The number of elements to be tracked.
-   */
-  void Init(size_t size)
+  /**
+   * Create the structure with each of the given number of elements in its own
+   * singleton component.  The number of elements cannot be changed later.
+   *
+   * @param size The number of elements to be tracked.
+   */
+  UnionFind(const size_t size) : size(size), parent(size), rank(size)
   {
-    number_of_elements_ = size;
-    parent_.set_size(number_of_elements_);
-    rank_.set_size(number_of_elements_);
-    for (size_t i = 0; i < number_of_elements_; i++)
+    for (size_t i = 0; i < size; ++i)
     {
-      parent_[i] = i;
-      rank_[i] = 0;
+      parent[i] = i;
+      rank[i] = 0;
     }
   }
 
   /**
-   * Returns the component containing an element
+   * Returns the component containing an element.
    *
-   * @param x the component to be found
+   * @param x The element whose component is to be found.
    * @return The index of the component containing x
    */
-  size_t Find(size_t x)
+  size_t Find(const size_t x)
   {
-    if (parent_[x] == x)
+    if (parent[x] == x)
     {
       return x;
     }
     else
     {
       // This ensures that the tree has a small depth
-      parent_[x] = Find(parent_[x]);
-      return parent_[x];
+      parent[x] = Find(parent[x]);
+      return parent[x];
     }
   }
 
   /**
-   * @function Union
+   * Union the components containing x and y.
    *
-   * Union the components containing x and y
-   *
-   * @param x one component
-   * @param y the other component
+   * @param x An element of the first component.
+   * @param y An element of the second component.
    */
-  void Union(size_t x, size_t y)
+  void Union(const size_t x, const size_t y)
   {
-    size_t x_root = Find(x);
-    size_t y_root = Find(y);
+    const size_t xRoot = Find(x);
+    const size_t yRoot = Find(y);
 
-    if (x_root == y_root)
+    if (xRoot == yRoot)
     {
       return;
     }
-    else if (rank_[x_root] == rank_[y_root])
+    else if (rank[xRoot] == rank[yRoot])
     {
-      parent_[y_root] = parent_[x_root];
-      rank_[x_root] = rank_[x_root] + 1;
+      parent[yRoot] = xRoot;
+      ++rank[xRoot];
     }
-    else if (rank_[x_root] > rank_[y_root])
+    else if (rank[xRoot] > rank[yRoot])
     {
-      parent_[y_root] = x_root;
+      parent[yRoot] = xRoot;
     }
     else
     {
-      parent_[x_root] = y_root;
+      parent[xRoot] = yRoot;
     }
   }
 }; // class UnionFind




More information about the mlpack-svn mailing list