[mlpack-svn] r10363 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 23 17:22:16 EST 2011


Author: rcurtin
Date: 2011-11-23 17:22:16 -0500 (Wed, 23 Nov 2011)
New Revision: 10363

Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
   mlpack/trunk/src/mlpack/methods/emst/emst.hpp
   mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
   mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
Log:
Format EMST as per #153.


Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-11-23 22:22:16 UTC (rev 10363)
@@ -10,7 +10,6 @@
  * Spanning Tree: Algorithm, Analysis, Applications.  In KDD, 2010.
  *
  */
-
 #ifndef __MLPACK_METHODS_EMST_DTB_HPP
 #define __MLPACK_METHODS_EMST_DTB_HPP
 
@@ -25,7 +24,7 @@
 namespace emst {
 
 /**
-* A Stat class for use with fastlib's trees.  This one only stores two values.
+ * A Stat class for use with fastlib's trees.  This one only stores two values.
  *
  * @param max_neighbor_distance The upper bound on the distance to the nearest
  * neighbor of any point in this node.
@@ -35,72 +34,83 @@
  * points in this node.  If points in this node are in different components,
  * this value will be negative.
  */
-
-class DTBStat {
+class DTBStat
+{
  private:
   double max_neighbor_distance_;
   int component_membership_;
 
  public:
-  void set_max_neighbor_distance(double distance) {
+  void set_max_neighbor_distance(double distance)
+  {
     max_neighbor_distance_ = distance;
   }
 
-  double max_neighbor_distance() {
+  double max_neighbor_distance()
+  {
     return max_neighbor_distance_;
   }
 
-  void set_component_membership(int membership) {
+  void set_component_membership(int membership)
+  {
     component_membership_ = membership;
   }
 
-  int component_membership() {
+  int component_membership()
+  {
     return component_membership_;
   }
 
   /**
-    * A generic initializer.
-    */
-  DTBStat() {
+   * A generic initializer.
+   */
+  DTBStat()
+  {
     set_max_neighbor_distance(DBL_MAX);
     set_component_membership(-1);
   }
 
   /**
-    * An initializer for leaves.
+   * An initializer for leaves.
    */
-  DTBStat(const arma::mat& dataset, size_t start, size_t count) {
-    if (count == 1) {
+  DTBStat(const arma::mat& dataset, size_t start, size_t count)
+  {
+    if (count == 1)
+    {
       set_component_membership(start);
       set_max_neighbor_distance(DBL_MAX);
-    } else {
+    }
+    else
+    {
       set_max_neighbor_distance(DBL_MAX);
       set_component_membership(-1);
     }
   }
 
   /**
-    * An initializer for non-leaves.  Simply calls the leaf initializer.
+   * An initializer for non-leaves.  Simply calls the leaf initializer.
    */
   DTBStat(const arma::mat& dataset, size_t start, size_t count,
-          const DTBStat& left_stat, const DTBStat& right_stat) {
-    if (count == 1) {
+          const DTBStat& left_stat, const DTBStat& right_stat)
+  {
+    if (count == 1)
+    {
       set_component_membership(start);
       set_max_neighbor_distance(DBL_MAX);
-    } else {
+    }
+    else
+    {
       set_max_neighbor_distance(DBL_MAX);
       set_component_membership(-1);
     }
   }
-
 }; // class DTBStat
 
-
 /**
  * Performs the MST calculation using the Dual-Tree Boruvka algorithm.
  */
-class DualTreeBoruvka {
-
+class DualTreeBoruvka
+{
  public:
   // For now, everything is in Euclidean space
   static const size_t metric = 2;
@@ -110,7 +120,6 @@
   //////// Member Variables /////////////////////
 
  private:
-
   size_t number_of_edges_;
   std::vector<EdgePair> edges_; // must use vector with non-numerical types
   size_t number_of_points_;
@@ -139,28 +148,23 @@
 
   DTBTree* tree_;
 
-
 ////////////////// Constructors ////////////////////////
-
  public:
+  DualTreeBoruvka() { }
 
-  DualTreeBoruvka() {}
-
-  ~DualTreeBoruvka() {
-    if (tree_ != NULL) {
+  ~DualTreeBoruvka()
+  {
+    if (tree_ != NULL)
       delete tree_;
-    }
   }
 
-
   ////////////////////////// Private Functions ////////////////////
  private:
-
   /**
-  * Adds a single edge to the edge list
+   * Adds a single edge to the edge list
    */
-  void AddEdge_(size_t e1, size_t e2, double distance) {
-
+  void AddEdge_(size_t e1, size_t e2, double distance)
+  {
     //EdgePair edge;
     mlpack::Log::Assert((e1 != e2),
         "Indices are equal in DualTreeBoruvka.add_edge(...)");
@@ -168,28 +172,27 @@
     mlpack::Log::Assert((distance >= 0.0),
         "Negative distance input in DualTreeBoruvka.add_edge(...)");
 
-    if (e1 < e2) {
+    if (e1 < e2)
       edges_[number_of_edges_].Init(e1, e2, distance);
-    }
-    else {
+    else
       edges_[number_of_edges_].Init(e2, e1, distance);
-    }
 
     number_of_edges_++;
 
   } // AddEdge_
 
-
   /**
    * Adds all the edges found in one iteration to the list of neighbors.
    */
-  void AddAllEdges_() {
-
-    for (size_t i = 0; i < number_of_points_; i++) {
+  void AddAllEdges_()
+  {
+    for (size_t i = 0; i < number_of_points_; i++)
+    {
       size_t component_i = connections_.Find(i);
       size_t in_edge_i = neighbors_in_component_[component_i];
       size_t out_edge_i = neighbors_out_component_[component_i];
-      if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i)) {
+      if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+      {
         double dist = neighbors_distances_[component_i];
         //total_dist_ = total_dist_ + dist;
         // changed to make this agree with the cover tree code
@@ -198,58 +201,52 @@
         connections_.Union(in_edge_i, out_edge_i);
       }
     }
-
   } // AddAllEdges_
 
 
   /**
-    * Handles the base case computation.  Also called by naive.
-  */
+   * Handles the base case computation.  Also called by naive.
+   */
   double ComputeBaseCase_(size_t query_start, size_t query_end,
-                          size_t reference_start, size_t reference_end) {
-
+                          size_t reference_start, size_t reference_end)
+  {
     number_leaf_computations_++;
 
     double new_upper_bound = -1.0;
 
     for (size_t query_index = query_start; query_index < query_end;
-         query_index++) {
-
+         query_index++)
+    {
       // Find the index of the component the query is in
       size_t query_component_index = connections_.Find(query_index);
 
       arma::vec query_point = data_points_.col(query_index);
 
       for (size_t reference_index = reference_start;
-           reference_index < reference_end; reference_index++) {
-
+           reference_index < reference_end; reference_index++)
+      {
         size_t reference_component_index = connections_.Find(reference_index);
 
-        if (query_component_index != reference_component_index) {
-
+        if (query_component_index != reference_component_index)
+        {
           arma::vec reference_point = data_points_.col(reference_index);
 
           double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
               reference_point);
 
-          if (distance < neighbors_distances_[query_component_index]) {
-
+          if (distance < neighbors_distances_[query_component_index])
+          {
             mlpack::Log::Assert(query_index != reference_index);
 
             neighbors_distances_[query_component_index] = distance;
             neighbors_in_component_[query_component_index] = query_index;
             neighbors_out_component_[query_component_index] = reference_index;
-
           } // if distance
-
         } // if indices not equal
-
-
       } // for reference_index
 
-      if (new_upper_bound < neighbors_distances_[query_component_index]) {
+      if (new_upper_bound < neighbors_distances_[query_component_index])
         new_upper_bound = neighbors_distances_[query_component_index];
-      }
 
     } // for query_index
 
@@ -260,178 +257,183 @@
 
 
   /**
-    * Handles the recursive calls to find the nearest neighbors in an iteration
+   * Handles the recursive calls to find the nearest neighbors in an iteration
    */
   void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
-                                  double incoming_distance) {
-
-    // Check for a distance prune
-    if (query_node->stat().max_neighbor_distance() < incoming_distance) {
-      //pruned by distance
+                                  double incoming_distance)
+  {
+    // Check for a distance prune.
+    if (query_node->stat().max_neighbor_distance() < incoming_distance)
+    {
+      // Pruned by distance.
       number_distance_prunes_++;
     }
-    // Check for a component prune
+    // Check for a component prune.
     else if ((query_node->stat().component_membership() >= 0)
-             && (query_node->stat().component_membership() ==
-                 reference_node->stat().component_membership())) {
-      //pruned by component membership
-
+        && (query_node->stat().component_membership() ==
+            reference_node->stat().component_membership()))
+    {
+      // Pruned by component membership.
       mlpack::Log::Assert(reference_node->stat().component_membership() >= 0);
+      mlpack::Log::Info << query_node->stat().component_membership()
+          << "q mem\n";
+      mlpack::Log::Info << reference_node->stat().component_membership()
+          << "r mem\n";
 
-      mlpack::Log::Info << query_node->stat().component_membership() << "q mem\n";
-      mlpack::Log::Info << reference_node->stat().component_membership() << "r mem\n";
-
       number_component_prunes_++;
     }
-    // The base case
-    else if (query_node->is_leaf() && reference_node->is_leaf()) {
+    else if (query_node->is_leaf() && reference_node->is_leaf()) // Base case.
+    {
+      double new_bound = ComputeBaseCase_(query_node->begin(),
+          query_node->end(), reference_node->begin(), reference_node->end());
 
-      double new_bound =
-      ComputeBaseCase_(query_node->begin(), query_node->end(),
-                       reference_node->begin(), reference_node->end());
-
       query_node->stat().set_max_neighbor_distance(new_bound);
-
     }
-    // Other recursive calls
-    else if (query_node->is_leaf()) {
-      //recurse on reference_node only
+    else if (query_node->is_leaf()) // Other recursive calls.
+    {
+      // Recurse on reference_node only.
       number_r_recursions_++;
 
       double left_dist =
-        query_node->bound().MinDistance(reference_node->left()->bound());
+          query_node->bound().MinDistance(reference_node->left()->bound());
       double right_dist =
-        query_node->bound().MinDistance(reference_node->right()->bound());
+          query_node->bound().MinDistance(reference_node->right()->bound());
       mlpack::Log::Assert(left_dist >= 0.0);
       mlpack::Log::Assert(right_dist >= 0.0);
 
-      if (left_dist < right_dist) {
-        ComputeNeighborsRecursion_(query_node,
-                                   reference_node->left(), left_dist);
-        ComputeNeighborsRecursion_(query_node,
-                                   reference_node->right(), right_dist);
+      if (left_dist < right_dist)
+      {
+        ComputeNeighborsRecursion_(query_node, reference_node->left(),
+            left_dist);
+        ComputeNeighborsRecursion_(query_node, reference_node->right(),
+            right_dist);
       }
-      else {
-        ComputeNeighborsRecursion_(query_node,
-                                   reference_node->right(), right_dist);
-        ComputeNeighborsRecursion_(query_node,
-                                   reference_node->left(), left_dist);
+      else
+      {
+        ComputeNeighborsRecursion_(query_node, reference_node->right(),
+            right_dist);
+        ComputeNeighborsRecursion_(query_node, reference_node->left(),
+            left_dist);
       }
-
     }
-    else if (reference_node->is_leaf()) {
-      //recurse on query_node only
-
+    else if (reference_node->is_leaf())
+    {
+      // Recurse on query_node only.
       number_q_recursions_++;
 
       double left_dist =
-        query_node->left()->bound().MinDistance(reference_node->bound());
+          query_node->left()->bound().MinDistance(reference_node->bound());
       double right_dist =
-        query_node->right()->bound().MinDistance(reference_node->bound());
+          query_node->right()->bound().MinDistance(reference_node->bound());
 
-      ComputeNeighborsRecursion_(query_node->left(),
-                                 reference_node, left_dist);
-      ComputeNeighborsRecursion_(query_node->right(),
-                                 reference_node, right_dist);
+      ComputeNeighborsRecursion_(query_node->left(), reference_node, left_dist);
+      ComputeNeighborsRecursion_(query_node->right(), reference_node,
+          right_dist);
 
-      // Update query_node's stat
+      // Update query_node's stat.
       query_node->stat().set_max_neighbor_distance(
           std::max(query_node->left()->stat().max_neighbor_distance(),
-              query_node->right()->stat().max_neighbor_distance()));
+          query_node->right()->stat().max_neighbor_distance()));
 
     }
-    else {
-      //recurse on both
-
+    else
+    {
+      // Recurse on both.
       number_both_recursions_++;
 
-      double left_dist = query_node->left()->bound().
-        MinDistance(reference_node->left()->bound());
-      double right_dist = query_node->left()->bound().
-        MinDistance(reference_node->right()->bound());
+      double left_dist = query_node->left()->bound().MinDistance(
+          reference_node->left()->bound());
+      double right_dist = query_node->left()->bound().MinDistance(
+          reference_node->right()->bound());
 
-      if (left_dist < right_dist) {
-        ComputeNeighborsRecursion_(query_node->left(),
-                                   reference_node->left(), left_dist);
-        ComputeNeighborsRecursion_(query_node->left(),
-                                   reference_node->right(), right_dist);
+      if (left_dist < right_dist)
+      {
+        ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
+            left_dist);
+        ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
+            right_dist);
       }
-      else {
-        ComputeNeighborsRecursion_(query_node->left(),
-                                   reference_node->right(), right_dist);
-        ComputeNeighborsRecursion_(query_node->left(),
-                                   reference_node->left(), left_dist);
+      else
+      {
+        ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
+            right_dist);
+        ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
+            left_dist);
       }
 
-      left_dist = query_node->right()->bound().
-        MinDistance(reference_node->left()->bound());
-      right_dist = query_node->right()->bound().
-        MinDistance(reference_node->right()->bound());
+      left_dist = query_node->right()->bound().MinDistance(
+          reference_node->left()->bound());
+      right_dist = query_node->right()->bound().MinDistance(
+          reference_node->right()->bound());
 
-      if (left_dist < right_dist) {
-        ComputeNeighborsRecursion_(query_node->right(),
-                                   reference_node->left(), left_dist);
-        ComputeNeighborsRecursion_(query_node->right(),
-                                   reference_node->right(), right_dist);
+      if (left_dist < right_dist)
+      {
+        ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
+            left_dist);
+        ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
+            right_dist);
       }
-      else {
-        ComputeNeighborsRecursion_(query_node->right(),
-                                   reference_node->right(), right_dist);
-        ComputeNeighborsRecursion_(query_node->right(),
-                                   reference_node->left(), left_dist);
+      else
+      {
+        ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
+            right_dist);
+        ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
+            left_dist);
       }
 
       query_node->stat().set_max_neighbor_distance(
           std::max(query_node->left()->stat().max_neighbor_distance(),
-              query_node->right()->stat().max_neighbor_distance()));
-
-    }// end else
-
+          query_node->right()->stat().max_neighbor_distance()));
+    }
   } // ComputeNeighborsRecursion_
 
   /**
-    * Computes the nearest neighbor of each point in each iteration
+   * Computes the nearest neighbor of each point in each iteration
    * of the algorithm
    */
-  void ComputeNeighbors_() {
-    if (do_naive_) {
+  void ComputeNeighbors_()
+  {
+    if (do_naive_)
+    {
       ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
     }
-    else {
+    else
+    {
       ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
     }
   } // ComputeNeighbors_
 
-
-  struct SortEdgesHelper_ {
-    bool operator() (const EdgePair& pairA, const EdgePair& pairB) {
+  struct SortEdgesHelper_
+  {
+    bool operator() (const EdgePair& pairA, const EdgePair& pairB)
+    {
       return (pairA.distance() < pairB.distance());
     }
   } SortFun;
 
-  void SortEdges_() {
-
+  void SortEdges_()
+  {
     std::sort(edges_.begin(), edges_.end(), SortFun);
-
   } // SortEdges_()
 
   /**
-    * Unpermute the edge list and output it to results
+   * Unpermute the edge list and output it to results
    *
    * TODO: Make this sort the edge list by distance as well for hierarchical
    * clusterings.
    */
-  void EmitResults_(arma::mat& results) {
-
+  void EmitResults_(arma::mat& results)
+  {
     SortEdges_();
 
     mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
     results.set_size(number_of_edges_, 3);
 
-    // need to unpermute the point labels
-    if (!do_naive_) {
-      for (size_t i = 0; i < (number_of_points_ - 1); i++) {
-
+    // Need to unpermute the point labels.
+    if (!do_naive_)
+    {
+      for (size_t i = 0; i < (number_of_points_ - 1); i++)
+      {
         // Make sure the edge list stores the smaller index first to
         // make checking correctness easier
         size_t ind1, ind2;
@@ -444,81 +446,79 @@
         results(i, 0) = edges_[i].lesser_index();
         results(i, 1) = edges_[i].greater_index();
         results(i, 2) = sqrt(edges_[i].distance());
-
       }
     }
-    else {
-
-      for (size_t i = 0; i < number_of_edges_; i++) {
+    else
+    {
+      for (size_t i = 0; i < number_of_edges_; i++)
+      {
         results(i, 0) = edges_[i].lesser_index();
         results(i, 1) = edges_[i].greater_index();
         results(i, 2) = sqrt(edges_[i].distance());
       }
-
     }
-
   } // EmitResults_
 
-
-
   /**
-    * This function resets the values in the nodes of the tree
-   * nearest neighbor distance, check for fully connected nodes
+   * This function resets the values in the nodes of the tree nearest neighbor
+   * distance, check for fully connected nodes
    */
-  void CleanupHelper_(DTBTree* tree) {
-
+  void CleanupHelper_(DTBTree* tree)
+  {
     tree->stat().set_max_neighbor_distance(DBL_MAX);
 
-    if (!tree->is_leaf()) {
+    if (!tree->is_leaf())
+    {
       CleanupHelper_(tree->left());
       CleanupHelper_(tree->right());
 
       if ((tree->left()->stat().component_membership() >= 0)
           && (tree->left()->stat().component_membership() ==
-              tree->right()->stat().component_membership())) {
+              tree->right()->stat().component_membership()))
+      {
         tree->stat().set_component_membership(tree->left()->stat().
-                                              component_membership());
+            component_membership());
       }
     }
-    else {
-
+    else
+    {
       size_t new_membership = connections_.Find(tree->begin());
 
-      for (size_t i = tree->begin(); i < tree->end(); i++) {
-        if (new_membership != connections_.Find(i)) {
+      for (size_t i = tree->begin(); i < tree->end(); i++)
+      {
+        if (new_membership != connections_.Find(i))
+        {
           new_membership = -1;
           mlpack::Log::Assert(tree->stat().component_membership() < 0);
           return;
         }
       }
       tree->stat().set_component_membership(new_membership);
-
     }
-
   } // CleanupHelper_
 
   /**
-    * The values stored in the tree must be reset on each iteration.
+   * The values stored in the tree must be reset on each iteration.
    */
-  void Cleanup_() {
-
-    for (size_t i = 0; i < number_of_points_; i++) {
+  void Cleanup_()
+  {
+    for (size_t i = 0; i < number_of_points_; i++)
+    {
       neighbors_distances_[i] = DBL_MAX;
-      //DEBUG_ONLY(neighbors_in_component_[i] = BIG_BAD_NUMBER);
-      //DEBUG_ONLY(neighbors_out_component_[i] = BIG_BAD_NUMBER);
     }
     number_of_loops_++;
 
-    if (!do_naive_) {
+    if (!do_naive_)
+    {
       CleanupHelper_(tree_);
     }
   }
 
   /**
-    * Format and output the results
+   * Format and output the results
    */
-  void OutputResults_() {
-
+  void OutputResults_()
+  {
     /* fx_result_double(module_, "total_squared_length", total_dist_);
     fx_result_int(module_, "number_of_points", number_of_points_);
     fx_result_int(module_, "dimension", data_points_.n_rows);
@@ -544,31 +544,28 @@
      */
 
     mlpack::CLI::GetParam<double>("dtb/total_squared_length") = total_dist_;
-
   } // OutputResults_
 
   /////////// Public Functions ///////////////////
-
  public:
-
-  size_t number_of_edges() {
+  size_t number_of_edges()
+  {
     return number_of_edges_;
   }
 
-
   /**
-   * Takes in a reference to the data set.  Copies the data,
-   * builds the tree, and initializes all of the member variables.
-   *
+   * Takes in a reference to the data set.  Copies the data, builds the tree,
+   * and initializes all of the member variables.
    */
-  void Init(const arma::mat& data) {
-
+  void Init(const arma::mat& data)
+  {
     number_of_edges_ = 0;
     data_points_ = data; // copy
 
     do_naive_ = CLI::GetParam<bool>("naive/do_naive");
 
-    if (!do_naive_) {
+    if (!do_naive_)
+    {
       // Default leaf size is 1
       // This gives best pruning empirically
       // Use leaf_size=1 unless space is a big concern
@@ -580,13 +577,11 @@
       tree_ = new DTBTree(data_points_, old_from_new_permutation_);
 
       Timers::StopTimer("emst/tree_building");
-
     }
-    else {
-
+    else
+    {
       tree_ = NULL;
       old_from_new_permutation_.resize(0);
-
     }
 
     number_of_points_ = data_points_.n_cols;
@@ -608,24 +603,22 @@
     number_both_recursions_ = 0;
   } // Init
 
-
   /**
    * Call this function after Init.  It will iteratively find the nearest
    * neighbor of each component until the MST is complete.
    */
-  void ComputeMST(arma::mat& results) {
-
+  void ComputeMST(arma::mat& results)
+  {
     Timers::StartTimer("emst/MST_computation");
 
-    while (number_of_edges_ < (number_of_points_ - 1)) {
-
+    while (number_of_edges_ < (number_of_points_ - 1))
+    {
       ComputeNeighbors_();
 
       AddAllEdges_();
 
       Cleanup_();
 
-
       Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
       Log::Info << number_of_edges_ << " edges found so far.\n\n";
       /*
@@ -636,7 +629,6 @@
       Log::Info << number_q_recursions_ << " query recursions.\n";
       Log::Info << number_both_recursions_ << " dual recursions.\n\n";
       */
-
     }
 
     Timers::StopTimer("emst/MST_computation");
@@ -644,10 +636,9 @@
     EmitResults_(results);
 
     OutputResults_();
-
   } // ComputeMST
 
-}; //class DualTreeBoruvka
+}; // class DualTreeBoruvka
 
 }; // namespace emst
 }; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/emst/emst.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.hpp	2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.hpp	2011-11-23 22:22:16 UTC (rev 10363)
@@ -1,12 +1,11 @@
 /**
-* @file emst.h
-*
-* @author Bill March (march at gatech.edu)
-*
-* This file contains utilities necessary for all of the minimum spanning tree
-* algorithms.
-*/
-
+ * @file emst.h
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * This file contains utilities necessary for all of the minimum spanning tree
+ * algorithms.
+ */
 #ifndef __MLPACK_METHODS_EMST_EMST_HPP
 #define __MLPACK_METHODS_EMST_EMST_HPP
 
@@ -20,59 +19,61 @@
 /**
  * An edge pair is simply two indices and a distance.  It is used as the
  * basic element of an edge list when computing a minimum spanning tree.
-*/
-class EdgePair {
-
-private:
+ */
+class EdgePair
+{
+ private:
   size_t lesser_index_;
   size_t greater_index_;
   double distance_;
 
-public:
+ public:
+  /**
+   * Initialize an EdgePair with two indices and a distance.  The indices are
+   * called lesser and greater, implying that they be sorted before calling
+   * Init.  However, this is not necessary for functionality; it is just a way
+   * to keep the edge list organized in other code.
+   */
+  void Init(size_t lesser, size_t greater, double dist)
+  {
+    mlpack::Log::Assert(lesser != greater,
+        "indices equal when creating EdgePair, lesser == greater");
+    lesser_index_ = lesser;
+    greater_index_ = greater;
+    distance_ = dist;
+  }
 
-
-    /**
-     * Initialize an EdgePair with two indices and a distance.  The indices are
-     * called lesser and greater, implying that they be sorted before calling
-     * Init.  However, this is not necessary for functionality; it is just a way
-     * to keep the edge list organized in other code.
-     */
-    void Init(size_t lesser, size_t greater, double dist) {
-
-      mlpack::Log::Assert(lesser != greater,
-          "indices equal when creating EdgePair, lesser == greater");
-      lesser_index_ = lesser;
-      greater_index_ = greater;
-      distance_ = dist;
-
-    }
-
-  size_t lesser_index() {
+  size_t lesser_index()
+  {
     return lesser_index_;
   }
 
-  void set_lesser_index(size_t index) {
+  void set_lesser_index(size_t index)
+  {
     lesser_index_ = index;
   }
 
-  size_t greater_index() {
+  size_t greater_index()
+  {
     return greater_index_;
   }
 
-  void set_greater_index(size_t index) {
+  void set_greater_index(size_t index)
+  {
     greater_index_ = index;
   }
 
-  double distance() const {
+  double distance() const
+  {
     return distance_;
   }
 
-  void set_distance(double new_dist) {
+  void set_distance(double new_dist)
+  {
     distance_ = new_dist;
   }
+}; // class EdgePair
 
-};// class EdgePair
-
 }; // namespace emst
 }; // namespace mlpack
 

Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp	2011-11-23 22:22:16 UTC (rev 10363)
@@ -1,5 +1,5 @@
 /**
-* @file emst.cc
+ * @file emst.cc
  *
  * Calls the DualTreeBoruvka algorithm from dtb.h
  * Can optionally call Naive Boruvka's method
@@ -10,7 +10,7 @@
  * In KDD, 2010.
  *
  * @author Bill March (march at gatech.edu)
-*/
+ */
 
 #include "dtb.hpp"
 
@@ -27,11 +27,10 @@
 using namespace mlpack;
 using namespace mlpack::emst;
 
-int main(int argc, char* argv[]) {
-
+int main(int argc, char* argv[])
+{
   CLI::ParseCommandLine(argc, argv);
 
-
   ///////////////// READ IN DATA //////////////////////////////////
   std::string data_file_name = CLI::GetParam<std::string>("emst/input_file");
 
@@ -41,8 +40,8 @@
   data::Load(data_file_name.c_str(), data_points, true);
 
   // Do naive
-  if (CLI::GetParam<bool>("naive/do_naive")) {
-
+  if (CLI::GetParam<bool>("naive/do_naive"))
+  {
     Log::Info << "Running naive algorithm.\n";
 
     DualTreeBoruvka naive;
@@ -58,8 +57,8 @@
 
     data::Save(naive_output_filename.c_str(), naive_results, true);
   }
-  else {
-
+  else
+  {
     Log::Info << "Data read, building tree.\n";
 
     /////////////// Initialize DTB //////////////////////
@@ -83,5 +82,4 @@
   }
 
   return 0;
-
 }

Modified: mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.hpp	2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.hpp	2011-11-23 22:22:16 UTC (rev 10363)
@@ -8,7 +8,6 @@
  * Calling unionfind.Union(x, y) unites the components indexed by x and y.
  * unionfind.Find(x) returns the index of the component containing point x.
  */
-
 #ifndef __MLPACK_METHODS_EMST_UNION_FIND_HPP
 #define __MLPACK_METHODS_EMST_UNION_FIND_HPP
 
@@ -18,20 +17,16 @@
 namespace emst {
 
 /**
- * @class UnionFind
- *
- *A Union-Find data structure.  See Cormen, Rivest, & Stein for details.
+ * A Union-Find data structure.  See Cormen, Rivest, & Stein for details.
  */
-class UnionFind {
-  friend class TestUnionFind;
-private:
-
+class UnionFind
+{
+ private:
   arma::Col<size_t> parent_;
   arma::ivec rank_;
   size_t number_of_elements_;
 
-public:
-
+ public:
   UnionFind() {}
 
   ~UnionFind() {}
@@ -42,17 +37,16 @@
    *
    * @param size The number of elements to be tracked.
    */
-
-  void Init(size_t size) {
-
+  void Init(size_t size)
+  {
     number_of_elements_ = size;
     parent_.set_size(number_of_elements_);
     rank_.set_size(number_of_elements_);
-    for (size_t i = 0; i < number_of_elements_; i++) {
+    for (size_t i = 0; i < number_of_elements_; i++)
+    {
       parent_[i] = i;
       rank_[i] = 0;
     }
-
   }
 
   /**
@@ -61,17 +55,18 @@
    * @param x the component to be found
    * @return The index of the component containing x
    */
-  size_t Find(size_t x) {
-
-    if (parent_[x] == x) {
+  size_t Find(size_t x)
+  {
+    if (parent_[x] == x)
+    {
       return x;
     }
-    else {
+    else
+    {
       // This ensures that the tree has a small depth
       parent_[x] = Find(parent_[x]);
       return parent_[x];
     }
-
   }
 
   /**
@@ -82,29 +77,31 @@
    * @param x one component
    * @param y the other component
    */
-  void Union(size_t x, size_t y) {
-
+  void Union(size_t x, size_t y)
+  {
     size_t x_root = Find(x);
     size_t y_root = Find(y);
 
-    if (x_root == y_root) {
+    if (x_root == y_root)
+    {
       return;
     }
-    else if (rank_[x_root] == rank_[y_root]) {
+    else if (rank_[x_root] == rank_[y_root])
+    {
       parent_[y_root] = parent_[x_root];
       rank_[x_root] = rank_[x_root] + 1;
     }
-    else if (rank_[x_root] > rank_[y_root]) {
+    else if (rank_[x_root] > rank_[y_root])
+    {
       parent_[y_root] = x_root;
     }
-    else {
+    else
+    {
       parent_[x_root] = y_root;
     }
-
   }
+}; // class UnionFind
 
-}; //class UnionFind
-
 }; // namespace emst
 }; // namespace mlpack
 




More information about the mlpack-svn mailing list