[mlpack-svn] r10605 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 18:51:56 EST 2011


Author: march
Date: 2011-12-06 18:51:55 -0500 (Tue, 06 Dec 2011)
New Revision: 10605

Added:
   mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
   mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/emst/emst.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
Log:
Split dtb.hpp into declarations (dtb.hpp) and definitions (dtb_impl.hpp); this should finish off ticket 117.

Modified: mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt	2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt	2011-12-06 23:51:55 UTC (rev 10605)
@@ -7,7 +7,8 @@
    union_find.hpp
    # dtb
    dtb.hpp
-   emst.hpp
+   dtb_impl.hpp   # definitions; included at the end of dtb.hpp
+   edge_pair.hpp  # EdgePair, used by dtb.hpp
 )
 
 # Add directory name to sources.

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2011-12-06 23:51:55 UTC (rev 10605)
@@ -13,7 +13,7 @@
 #ifndef __MLPACK_METHODS_EMST_DTB_HPP
 #define __MLPACK_METHODS_EMST_DTB_HPP
 
-#include "emst.hpp"
+#include "edge_pair.hpp" // EdgePair
 
 #include <mlpack/core.hpp>
 #include <mlpack/core/tree/bounds.hpp>
@@ -41,69 +41,30 @@
   int component_membership_;
 
  public:
-  void set_max_neighbor_distance(double distance)
-  {
-    max_neighbor_distance_ = distance;
-  }
+  void set_max_neighbor_distance(double distance);
 
-  double max_neighbor_distance()
-  {
-    return max_neighbor_distance_;
-  }
+  double max_neighbor_distance();
 
-  void set_component_membership(int membership)
-  {
-    component_membership_ = membership;
-  }
+  void set_component_membership(int membership);
 
-  int component_membership()
-  {
-    return component_membership_;
-  }
+  int component_membership();
 
   /**
    * A generic initializer.
    */
-  DTBStat()
-  {
-    set_max_neighbor_distance(DBL_MAX);
-    set_component_membership(-1);
-  }
+  DTBStat();
 
   /**
    * An initializer for leaves.
    */
-  DTBStat(const arma::mat& dataset, size_t start, size_t count)
-  {
-    if (count == 1)
-    {
-      set_component_membership(start);
-      set_max_neighbor_distance(DBL_MAX);
-    }
-    else
-    {
-      set_max_neighbor_distance(DBL_MAX);
-      set_component_membership(-1);
-    }
-  }
+  DTBStat(const arma::mat& dataset, size_t start, size_t count);
 
   /**
    * An initializer for non-leaves.  Simply calls the leaf initializer.
    */
   DTBStat(const arma::mat& dataset, size_t start, size_t count,
-          const DTBStat& left_stat, const DTBStat& right_stat)
-  {
-    if (count == 1)
-    {
-      set_component_membership(start);
-      set_max_neighbor_distance(DBL_MAX);
-    }
-    else
-    {
-      set_max_neighbor_distance(DBL_MAX);
-      set_component_membership(-1);
-    }
-  }
+          const DTBStat& left_stat, const DTBStat& right_stat);
+  // NOTE(review): this overload ignores left_stat/right_stat (leaf rule only).
 }; // class DTBStat
 
 /**
@@ -148,495 +109,98 @@
 
   DTBTree* tree_;
 
+  // Orders EdgePairs by ascending distance; used to sort edges_.
+  struct SortEdgesHelper_
+  {
+    bool operator() (const EdgePair& pairA, const EdgePair& pairB) const
+    {
+      return (pairA.distance() < pairB.distance());
+    }
+  } SortFun;
+  
+
 ////////////////// Constructors ////////////////////////
  public:
-  DualTreeBoruvka() { }
+  DualTreeBoruvka() : tree_(NULL) { } // ~DualTreeBoruvka() deletes tree_.
 
-  ~DualTreeBoruvka()
-  {
-    if (tree_ != NULL)
-      delete tree_;
-  }
+  ~DualTreeBoruvka();
 
   ////////////////////////// Private Functions ////////////////////
  private:
   /**
    * Adds a single edge to the edge list
    */
-  void AddEdge_(size_t e1, size_t e2, double distance)
-  {
-    //EdgePair edge;
-    mlpack::Log::Assert((e1 != e2),
-        "Indices are equal in DualTreeBoruvka.add_edge(...)");
-
-    mlpack::Log::Assert((distance >= 0.0),
-        "Negative distance input in DualTreeBoruvka.add_edge(...)");
-
-    if (e1 < e2)
-      edges_[number_of_edges_].Init(e1, e2, distance);
-    else
-      edges_[number_of_edges_].Init(e2, e1, distance);
-
-    number_of_edges_++;
-
-  } // AddEdge_
-
+  void AddEdge_(size_t e1, size_t e2, double distance);
+  
   /**
    * Adds all the edges found in one iteration to the list of neighbors.
    */
-  void AddAllEdges_()
-  {
-    for (size_t i = 0; i < number_of_points_; i++)
-    {
-      size_t component_i = connections_.Find(i);
-      size_t in_edge_i = neighbors_in_component_[component_i];
-      size_t out_edge_i = neighbors_out_component_[component_i];
-      if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
-      {
-        double dist = neighbors_distances_[component_i];
-        //total_dist_ = total_dist_ + dist;
-        // changed to make this agree with the cover tree code
-        total_dist_ = total_dist_ + sqrt(dist);
-        AddEdge_(in_edge_i, out_edge_i, dist);
-        connections_.Union(in_edge_i, out_edge_i);
-      }
-    }
-  } // AddAllEdges_
-
-
+  void AddAllEdges_();
+  
   /**
    * Handles the base case computation.  Also called by naive.
    */
   double ComputeBaseCase_(size_t query_start, size_t query_end,
-                          size_t reference_start, size_t reference_end)
-  {
-    number_leaf_computations_++;
-
-    double new_upper_bound = -1.0;
-
-    for (size_t query_index = query_start; query_index < query_end;
-         query_index++)
-    {
-      // Find the index of the component the query is in
-      size_t query_component_index = connections_.Find(query_index);
-
-      arma::vec query_point = data_points_.col(query_index);
-
-      for (size_t reference_index = reference_start;
-           reference_index < reference_end; reference_index++)
-      {
-        size_t reference_component_index = connections_.Find(reference_index);
-
-        if (query_component_index != reference_component_index)
-        {
-          arma::vec reference_point = data_points_.col(reference_index);
-
-          double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
-              reference_point);
-
-          if (distance < neighbors_distances_[query_component_index])
-          {
-            mlpack::Log::Assert(query_index != reference_index);
-
-            neighbors_distances_[query_component_index] = distance;
-            neighbors_in_component_[query_component_index] = query_index;
-            neighbors_out_component_[query_component_index] = reference_index;
-          } // if distance
-        } // if indices not equal
-      } // for reference_index
-
-      if (new_upper_bound < neighbors_distances_[query_component_index])
-        new_upper_bound = neighbors_distances_[query_component_index];
-
-    } // for query_index
-
-    mlpack::Log::Assert(new_upper_bound >= 0.0);
-    return new_upper_bound;
-
-  } // ComputeBaseCase_
-
-
+                          size_t reference_start, size_t reference_end);
+  // Returns the largest best-neighbor distance among the query components.
   /**
    * Handles the recursive calls to find the nearest neighbors in an iteration
    */
   void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
-                                  double incoming_distance)
-  {
-    // Check for a distance prune.
-    if (query_node->Stat().max_neighbor_distance() < incoming_distance)
-    {
-      // Pruned by distance.
-      number_distance_prunes_++;
-    }
-    // Check for a component prune.
-    else if ((query_node->Stat().component_membership() >= 0)
-        && (query_node->Stat().component_membership() ==
-            reference_node->Stat().component_membership()))
-    {
-      // Pruned by component membership.
-      mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
-      mlpack::Log::Info << query_node->Stat().component_membership()
-          << "q mem\n";
-      mlpack::Log::Info << reference_node->Stat().component_membership()
-          << "r mem\n";
-
-      number_component_prunes_++;
-    }
-    else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
-    {
-      double new_bound = ComputeBaseCase_(query_node->Begin(),
-          query_node->End(), reference_node->Begin(), reference_node->End());
-
-      query_node->Stat().set_max_neighbor_distance(new_bound);
-    }
-    else if (query_node->IsLeaf()) // Other recursive calls.
-    {
-      // Recurse on reference_node only.
-      number_r_recursions_++;
-
-      double left_dist =
-          query_node->Bound().MinDistance(reference_node->Left()->Bound());
-      double right_dist =
-          query_node->Bound().MinDistance(reference_node->Right()->Bound());
-      mlpack::Log::Assert(left_dist >= 0.0);
-      mlpack::Log::Assert(right_dist >= 0.0);
-
-      if (left_dist < right_dist)
-      {
-        ComputeNeighborsRecursion_(query_node, reference_node->Left(),
-            left_dist);
-        ComputeNeighborsRecursion_(query_node, reference_node->Right(),
-            right_dist);
-      }
-      else
-      {
-        ComputeNeighborsRecursion_(query_node, reference_node->Right(),
-            right_dist);
-        ComputeNeighborsRecursion_(query_node, reference_node->Left(),
-            left_dist);
-      }
-    }
-    else if (reference_node->IsLeaf())
-    {
-      // Recurse on query_node only.
-      number_q_recursions_++;
-
-      double left_dist =
-          query_node->Left()->Bound().MinDistance(reference_node->Bound());
-      double right_dist =
-          query_node->Right()->Bound().MinDistance(reference_node->Bound());
-
-      ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
-      ComputeNeighborsRecursion_(query_node->Right(), reference_node,
-          right_dist);
-
-      // Update query_node's stat.
-      query_node->Stat().set_max_neighbor_distance(
-          std::max(query_node->Left()->Stat().max_neighbor_distance(),
-          query_node->Right()->Stat().max_neighbor_distance()));
-
-    }
-    else
-    {
-      // Recurse on both.
-      number_both_recursions_++;
-
-      double left_dist = query_node->Left()->Bound().MinDistance(
-          reference_node->Left()->Bound());
-      double right_dist = query_node->Left()->Bound().MinDistance(
-          reference_node->Right()->Bound());
-
-      if (left_dist < right_dist)
-      {
-        ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-            left_dist);
-        ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-            right_dist);
-      }
-      else
-      {
-        ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
-            right_dist);
-        ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
-            left_dist);
-      }
-
-      left_dist = query_node->Right()->Bound().MinDistance(
-          reference_node->Left()->Bound());
-      right_dist = query_node->Right()->Bound().MinDistance(
-          reference_node->Right()->Bound());
-
-      if (left_dist < right_dist)
-      {
-        ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-            left_dist);
-        ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-            right_dist);
-      }
-      else
-      {
-        ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
-            right_dist);
-        ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
-            left_dist);
-      }
-
-      query_node->Stat().set_max_neighbor_distance(
-          std::max(query_node->Left()->Stat().max_neighbor_distance(),
-          query_node->Right()->Stat().max_neighbor_distance()));
-    }
-  } // ComputeNeighborsRecursion_
-
+                                  double incoming_distance);
+  
   /**
    * Computes the nearest neighbor of each point in each iteration
    * of the algorithm
    */
-  void ComputeNeighbors_()
-  {
-    if (do_naive_)
-    {
-      ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
-    }
-    else
-    {
-      ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
-    }
-  } // ComputeNeighbors_
+  void ComputeNeighbors_();
 
-  struct SortEdgesHelper_
-  {
-    bool operator() (const EdgePair& pairA, const EdgePair& pairB)
-    {
-      return (pairA.distance() < pairB.distance());
-    }
-  } SortFun;
-
-  void SortEdges_()
-  {
-    std::sort(edges_.begin(), edges_.end(), SortFun);
-  } // SortEdges_()
-
+  // Sorts edges_ by ascending distance, using SortFun.
+  void SortEdges_();
+  
   /**
    * Unpermute the edge list and output it to results
    *
-   * TODO: Make this sort the edge list by distance as well for hierarchical
-   * clusterings.
    */
-  void EmitResults_(arma::mat& results)
-  {
-    SortEdges_();
+  void EmitResults_(arma::mat& results);
 
-    mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
-    results.set_size(number_of_edges_, 3);
-
-    // Need to unpermute the point labels.
-    if (!do_naive_)
-    {
-      for (size_t i = 0; i < (number_of_points_ - 1); i++)
-      {
-        // Make sure the edge list stores the smaller index first to
-        // make checking correctness easier
-        size_t ind1, ind2;
-        ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
-        ind2 = old_from_new_permutation_[edges_[i].greater_index()];
-
-        edges_[i].set_lesser_index(std::min(ind1, ind2));
-        edges_[i].set_greater_index(std::max(ind1, ind2));
-
-        results(i, 0) = edges_[i].lesser_index();
-        results(i, 1) = edges_[i].greater_index();
-        results(i, 2) = sqrt(edges_[i].distance());
-      }
-    }
-    else
-    {
-      for (size_t i = 0; i < number_of_edges_; i++)
-      {
-        results(i, 0) = edges_[i].lesser_index();
-        results(i, 1) = edges_[i].greater_index();
-        results(i, 2) = sqrt(edges_[i].distance());
-      }
-    }
-  } // EmitResults_
-
   /**
    * This function resets the values in the nodes of the tree nearest neighbor
    * distance, check for fully connected nodes
    */
-  void CleanupHelper_(DTBTree* tree)
-  {
-    tree->Stat().set_max_neighbor_distance(DBL_MAX);
+  void CleanupHelper_(DTBTree* tree);
 
-    if (!tree->IsLeaf())
-    {
-      CleanupHelper_(tree->Left());
-      CleanupHelper_(tree->Right());
-
-      if ((tree->Left()->Stat().component_membership() >= 0)
-          && (tree->Left()->Stat().component_membership() ==
-              tree->Right()->Stat().component_membership()))
-      {
-        tree->Stat().set_component_membership(tree->Left()->Stat().
-            component_membership());
-      }
-    }
-    else
-    {
-      size_t new_membership = connections_.Find(tree->Begin());
-
-      for (size_t i = tree->Begin(); i < tree->End(); i++)
-      {
-        if (new_membership != connections_.Find(i))
-        {
-          new_membership = -1;
-          mlpack::Log::Assert(tree->Stat().component_membership() < 0);
-          return;
-        }
-      }
-      tree->Stat().set_component_membership(new_membership);
-    }
-  } // CleanupHelper_
-
   /**
    * The values stored in the tree must be reset on each iteration.
    */
-  void Cleanup_()
-  {
-    for (size_t i = 0; i < number_of_points_; i++)
-    {
-      neighbors_distances_[i] = DBL_MAX;
-    }
-    number_of_loops_++;
-
-    if (!do_naive_)
-    {
-      CleanupHelper_(tree_);
-    }
-  }
-
+  void Cleanup_();
+  
   /**
    * Format and output the results
    */
-  void OutputResults_()
-  {
-    /* fx_result_double(module_, "total_squared_length", total_dist_);
-    fx_result_int(module_, "number_of_points", number_of_points_);
-    fx_result_int(module_, "dimension", data_points_.n_rows);
-    fx_result_int(module_, "number_of_loops", number_of_loops_);
-    fx_result_int(module_, "number_distance_prunes", number_distance_prunes_);
-    fx_result_int(module_, "number_component_prunes", number_component_prunes_);
-    fx_result_int(module_, "number_leaf_computations", number_leaf_computations_);
-    fx_result_int(module_, "number_q_recursions", number_q_recursions_);
-    fx_result_int(module_, "number_r_recursions", number_r_recursions_);
-    fx_result_int(module_, "number_both_recursions", number_both_recursions_);*/
-    // TODO, not sure how I missed this last time.
-    mlpack::Log::Info << "Total squared length: " << total_dist_ << std::endl;
-    mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
-    mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
-    /*
-    mlpack::Log::Info << "number_of_loops" << std::endl;
-    mlpack::Log::Info << "number_distance_prunes" << std::endl;
-    mlpack::Log::Info << "number_component_prunes" << std::endl;
-    mlpack::Log::Info << "number_leaf_computations" << std::endl;
-    mlpack::Log::Info << "number_q_recursions" << std::endl;
-    mlpack::Log::Info << "number_r_recursions" << std::endl;
-    mlpack::Log::Info << "number_both_recursions" << std::endl;
-     */
-
-  } // OutputResults_
-
+  void OutputResults_();
+  
   /////////// Public Functions ///////////////////
  public:
-  size_t number_of_edges()
-  {
-    return number_of_edges_;
-  }
+  size_t number_of_edges();
 
   /**
    * Takes in a reference to the data set.  Copies the data, builds the tree,
    * and initializes all of the member variables.
    */
-  void Init(const arma::mat& data, bool naive = false, size_t leafSize = 1)
-  {
-    number_of_edges_ = 0;
-    data_points_ = data; // copy
-
-    do_naive_ = naive;
-
-    if (!do_naive_)
-    {
-      // Default leaf size is 1
-      // This gives best pruning empirically
-      // Use leaf_size=1 unless space is a big concern
-      Timers::StartTimer("emst/tree_building");
-
-      tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
-
-      Timers::StopTimer("emst/tree_building");
-    }
-    else
-    {
-      tree_ = NULL;
-      old_from_new_permutation_.resize(0);
-    }
-
-    number_of_points_ = data_points_.n_cols;
-    edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
-    connections_.Init(number_of_points_);
-
-    neighbors_in_component_.set_size(number_of_points_);
-    neighbors_out_component_.set_size(number_of_points_);
-    neighbors_distances_.set_size(number_of_points_);
-    neighbors_distances_.fill(DBL_MAX);
-
-    total_dist_ = 0.0;
-    number_of_loops_ = 0;
-    number_distance_prunes_ = 0;
-    number_component_prunes_ = 0;
-    number_leaf_computations_ = 0;
-    number_q_recursions_ = 0;
-    number_r_recursions_ = 0;
-    number_both_recursions_ = 0;
-  } // Init
-
+  void Init(const arma::mat& data, bool naive, size_t leafSize);
+  // NOTE(review): the inline version defaulted naive = false, leafSize = 1.
   /**
    * Call this function after Init.  It will iteratively find the nearest
    * neighbor of each component until the MST is complete.
    */
-  void ComputeMST(arma::mat& results)
-  {
-    Timers::StartTimer("emst/MST_computation");
-
-    while (number_of_edges_ < (number_of_points_ - 1))
-    {
-      ComputeNeighbors_();
-
-      AddAllEdges_();
-
-      Cleanup_();
-
-      Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
-      Log::Info << number_of_edges_ << " edges found so far.\n\n";
-      /*
-      Log::Info << number_leaf_computations_ << " base cases.\n";
-      Log::Info << number_distance_prunes_ << " distance prunes.\n";
-      Log::Info << number_component_prunes_ << " component prunes.\n";
-      Log::Info << number_r_recursions_ << " reference recursions.\n";
-      Log::Info << number_q_recursions_ << " query recursions.\n";
-      Log::Info << number_both_recursions_ << " dual recursions.\n\n";
-      */
-    }
-
-    Timers::StopTimer("emst/MST_computation");
-
-    EmitResults_(results);
-
-    OutputResults_();
-  } // ComputeMST
-
+  void ComputeMST(arma::mat& results);
+  
 }; // class DualTreeBoruvka
 
 }; // namespace emst
 }; // namespace mlpack
 
+#include "dtb_impl.hpp"
+
 #endif // __MLPACK_METHODS_EMST_DTB_HPP

Added: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	2011-12-06 23:51:55 UTC (rev 10605)
@@ -0,0 +1,558 @@
+/**
+ * @file dtb_impl.hpp
+ * @author William March
+ *
+ * Out-of-class definitions for DTBStat and DualTreeBoruvka.  This file is
+ * included at the end of dtb.hpp and should not be included directly.
+ * Since it is a header, every non-template definition here must be inline.
+ */
+
+#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+#include <algorithm> // std::sort, std::min, std::max
+#include <cfloat>    // DBL_MAX
+#include <mlpack/core.hpp>
+
+// NOTE(review): this using-directive leaks into every includer of dtb.hpp.
+using namespace mlpack::emst;
+
+inline void DTBStat::set_max_neighbor_distance(double distance)
+{
+  max_neighbor_distance_ = distance;
+}
+
+inline double DTBStat::max_neighbor_distance()
+{
+  return max_neighbor_distance_;
+}
+
+inline void DTBStat::set_component_membership(int membership)
+{
+  component_membership_ = membership;
+}
+
+inline int DTBStat::component_membership()
+{
+  return component_membership_;
+}
+
+/**
+ * Default initializer: no neighbor distance yet and no component (-1).
+ */
+inline DTBStat::DTBStat()
+{
+  set_max_neighbor_distance(DBL_MAX);
+  set_component_membership(-1);
+}
+
+/**
+ * Leaf initializer: membership is start for a one-point node, else -1.
+ */
+inline DTBStat::DTBStat(const arma::mat& dataset, size_t start, size_t count)
+{
+  if (count == 1)
+  {
+    set_component_membership(start);
+    set_max_neighbor_distance(DBL_MAX);
+  }
+  else
+  {
+    set_max_neighbor_distance(DBL_MAX);
+    set_component_membership(-1);
+  }
+}
+
+/**
+ * Non-leaf initializer: same rule as for leaves; the child stats are unused.
+ */
+inline DTBStat::DTBStat(const arma::mat& dataset, size_t start, size_t count,
+                        const DTBStat& left_stat, const DTBStat& right_stat)
+{
+  if (count == 1)
+  {
+    set_component_membership(start);
+    set_max_neighbor_distance(DBL_MAX);
+  }
+  else
+  {
+    set_max_neighbor_distance(DBL_MAX);
+    set_component_membership(-1);
+  }
+}
+
+inline DualTreeBoruvka::~DualTreeBoruvka()
+{
+  if (tree_ != NULL)
+    delete tree_;
+}
+
+/**
+ * Records edge (e1, e2) with the given distance in the next slot of edges_.
+ */
+inline void DualTreeBoruvka::AddEdge_(size_t e1, size_t e2, double distance)
+{
+  // An edge joins two distinct points and has a non-negative length.
+  mlpack::Log::Assert((e1 != e2),
+      "Indices are equal in DualTreeBoruvka.add_edge(...)");
+
+  mlpack::Log::Assert((distance >= 0.0),
+      "Negative distance input in DualTreeBoruvka.add_edge(...)");
+
+  // Store the smaller index first so each edge has one canonical form.
+  if (e1 < e2)
+    edges_[number_of_edges_].Init(e1, e2, distance);
+  else
+    edges_[number_of_edges_].Init(e2, e1, distance);
+
+  number_of_edges_++;
+} // AddEdge_
+
+/**
+ * Adds each component's best edge from this iteration, unless already joined.
+ */
+inline void DualTreeBoruvka::AddAllEdges_()
+{
+  for (size_t i = 0; i < number_of_points_; i++)
+  {
+    size_t component_i = connections_.Find(i);
+    size_t in_edge_i = neighbors_in_component_[component_i];
+    size_t out_edge_i = neighbors_out_component_[component_i];
+    if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+    {
+      double dist = neighbors_distances_[component_i];
+      // dist is treated as squared (EmitResults_ also takes sqrt); sum roots
+      // so the total agrees with the cover tree code.
+      total_dist_ = total_dist_ + sqrt(dist);
+      AddEdge_(in_edge_i, out_edge_i, dist);
+      connections_.Union(in_edge_i, out_edge_i);
+    }
+  }
+} // AddAllEdges_
+
+
+/**
+ * Brute-force base case over [query_start, query_end) x [reference_start,
+ * reference_end); also used by naive.  Returns the largest current best
+ * distance among the query points' components.
+ */
+inline double DualTreeBoruvka::ComputeBaseCase_(size_t query_start,
+                                                size_t query_end,
+                                                size_t reference_start,
+                                                size_t reference_end)
+{
+  number_leaf_computations_++;
+
+  double new_upper_bound = -1.0;
+
+  for (size_t query_index = query_start; query_index < query_end;
+       query_index++)
+  {
+    // Find the index of the component the query is in.
+    size_t query_component_index = connections_.Find(query_index);
+    arma::vec query_point = data_points_.col(query_index);
+
+    for (size_t reference_index = reference_start;
+         reference_index < reference_end; reference_index++)
+    {
+      size_t reference_component_index = connections_.Find(reference_index);
+
+      // Only points in a different component are candidate neighbors.
+      if (query_component_index != reference_component_index)
+      {
+        arma::vec reference_point = data_points_.col(reference_index);
+        double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
+            reference_point);
+
+        if (distance < neighbors_distances_[query_component_index])
+        {
+          mlpack::Log::Assert(query_index != reference_index);
+
+          neighbors_distances_[query_component_index] = distance;
+          neighbors_in_component_[query_component_index] = query_index;
+          neighbors_out_component_[query_component_index] = reference_index;
+        } // if distance
+      } // if components differ
+    } // for reference_index
+
+    if (new_upper_bound < neighbors_distances_[query_component_index])
+      new_upper_bound = neighbors_distances_[query_component_index];
+  } // for query_index
+
+  mlpack::Log::Assert(new_upper_bound >= 0.0);
+  return new_upper_bound;
+} // ComputeBaseCase_
+
+
+/**
+ * Dual-tree search for each component's nearest out-of-component neighbor.
+ */
+inline void DualTreeBoruvka::ComputeNeighborsRecursion_(DTBTree* query_node,
+    DTBTree* reference_node,
+    double incoming_distance)
+{
+  // Check for a distance prune.
+  if (query_node->Stat().max_neighbor_distance() < incoming_distance)
+  {
+    // Pruned by distance.
+    number_distance_prunes_++;
+  }
+  // Check for a component prune.
+  else if ((query_node->Stat().component_membership() >= 0)
+      && (query_node->Stat().component_membership() ==
+          reference_node->Stat().component_membership()))
+  {
+    // Pruned by component membership.
+    mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
+    mlpack::Log::Info << query_node->Stat().component_membership()
+        << "q mem\n";
+    mlpack::Log::Info << reference_node->Stat().component_membership()
+        << "r mem\n";
+
+    number_component_prunes_++;
+  }
+  else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
+  {
+    double new_bound = ComputeBaseCase_(query_node->Begin(),
+        query_node->End(), reference_node->Begin(), reference_node->End());
+
+    query_node->Stat().set_max_neighbor_distance(new_bound);
+  }
+  else if (query_node->IsLeaf()) // Other recursive calls.
+  {
+    // Recurse on reference_node only, visiting the closer child first.
+    number_r_recursions_++;
+
+    double left_dist =
+        query_node->Bound().MinDistance(reference_node->Left()->Bound());
+    double right_dist =
+        query_node->Bound().MinDistance(reference_node->Right()->Bound());
+    mlpack::Log::Assert(left_dist >= 0.0);
+    mlpack::Log::Assert(right_dist >= 0.0);
+
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
+          left_dist);
+    }
+  }
+  else if (reference_node->IsLeaf())
+  {
+    // Recurse on query_node only.
+    number_q_recursions_++;
+
+    double left_dist =
+        query_node->Left()->Bound().MinDistance(reference_node->Bound());
+    double right_dist =
+        query_node->Right()->Bound().MinDistance(reference_node->Bound());
+
+    ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
+    ComputeNeighborsRecursion_(query_node->Right(), reference_node,
+        right_dist);
+
+    // A node's bound is the larger of its children's bounds.
+    query_node->Stat().set_max_neighbor_distance(
+        std::max(query_node->Left()->Stat().max_neighbor_distance(),
+                 query_node->Right()->Stat().max_neighbor_distance()));
+
+  }
+  else
+  {
+    // Recurse on both, visiting the closer reference child first.
+    number_both_recursions_++;
+
+    double left_dist = query_node->Left()->Bound().MinDistance(
+        reference_node->Left()->Bound());
+    double right_dist = query_node->Left()->Bound().MinDistance(
+        reference_node->Right()->Bound());
+
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
+          left_dist);
+    }
+
+    left_dist = query_node->Right()->Bound().MinDistance(
+        reference_node->Left()->Bound());
+    right_dist = query_node->Right()->Bound().MinDistance(
+        reference_node->Right()->Bound());
+
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
+          left_dist);
+    }
+
+    query_node->Stat().set_max_neighbor_distance(
+        std::max(query_node->Left()->Stat().max_neighbor_distance(),
+                 query_node->Right()->Stat().max_neighbor_distance()));
+  }
+} // ComputeNeighborsRecursion_
+
+/**
+ * Finds each component's nearest point outside that component: by brute
+ * force in naive mode, otherwise by dual-tree recursion on tree_.
+ */
+inline void DualTreeBoruvka::ComputeNeighbors_()
+{
+  if (do_naive_)
+  {
+    ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
+  }
+  else
+  {
+    ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
+  }
+} // ComputeNeighbors_
+// Sorts edges_ by ascending distance (SortFun compares EdgePair::distance()).
+inline void DualTreeBoruvka::SortEdges_()
+{
+  std::sort(edges_.begin(), edges_.end(), SortFun);
+} // SortEdges_()
+
+/**
+ * Sort the edges by distance, map point indices back to the original
+ * ordering, and write one row (lesser, greater, sqrt(distance)) per edge.
+ */
+inline void DualTreeBoruvka::EmitResults_(arma::mat& results)
+{
+  SortEdges_();
+
+  mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
+  results.set_size(number_of_edges_, 3);
+
+  // Tree building permutes the points, so unpermute the labels.
+  if (!do_naive_)
+  {
+    // Loop over number_of_edges_ (the row count of results) so the writes
+    // stay in bounds even if the Assert above is not enforced.
+    for (size_t i = 0; i < number_of_edges_; i++)
+    {
+      const size_t ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
+      const size_t ind2 = old_from_new_permutation_[edges_[i].greater_index()];
+
+      // Keep the smaller original index first to ease checking.
+      edges_[i].set_lesser_index(std::min(ind1, ind2));
+      edges_[i].set_greater_index(std::max(ind1, ind2));
+
+      results(i, 0) = edges_[i].lesser_index();
+      results(i, 1) = edges_[i].greater_index();
+      results(i, 2) = sqrt(edges_[i].distance());
+    }
+  }
+  else
+  {
+    for (size_t i = 0; i < number_of_edges_; i++)
+    {
+      results(i, 0) = edges_[i].lesser_index();
+      results(i, 1) = edges_[i].greater_index();
+      results(i, 2) = sqrt(edges_[i].distance());
+    }
+  }
+} // EmitResults_
+
+/**
+ * Recursively prepares the subtree rooted at tree for the next iteration.
+ * It resets every node's max_neighbor_distance bound to DBL_MAX and refreshes
+ * component_membership for nodes whose points all lie in one component.
+ *
+ * A leaf whose points all share one union-find representative is marked with
+ * that representative.  Otherwise its membership is left untouched and is
+ * asserted to still be negative.  An internal node inherits its children's
+ * membership when both children carry the same non-negative value.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ *
+ * @param tree Root of the subtree to clean up.
+ */
+inline void DualTreeBoruvka::CleanupHelper_(DTBTree* tree)
+{
+  tree->Stat().set_max_neighbor_distance(DBL_MAX);
+
+  if (!tree->IsLeaf())
+  {
+    CleanupHelper_(tree->Left());
+    CleanupHelper_(tree->Right());
+
+    // The children were refreshed above; a matching non-negative membership
+    // means every point below this node is in the same component.
+    const int left_membership = tree->Left()->Stat().component_membership();
+    if ((left_membership >= 0) &&
+        (left_membership == tree->Right()->Stat().component_membership()))
+    {
+      tree->Stat().set_component_membership(left_membership);
+    }
+  }
+  else
+  {
+    const size_t new_membership = connections_.Find(tree->Begin());
+
+    // The first point trivially matches itself, so start at the second one.
+    for (size_t i = tree->Begin() + 1; i < tree->End(); i++)
+    {
+      if (new_membership != connections_.Find(i))
+      {
+        // Components only ever merge, so a leaf whose points still span
+        // several components was never fully connected and should still
+        // carry a negative membership.
+        mlpack::Log::Assert(tree->Stat().component_membership() < 0);
+        return;
+      }
+    }
+    tree->Stat().set_component_membership(static_cast<int>(new_membership));
+  }
+} // CleanupHelper_
+
+/**
+ * Resets per-iteration state before the next Boruvka iteration.
+ * Every entry of neighbors_distances_ goes back to DBL_MAX and the
+ * iteration counter number_of_loops_ is incremented.  When a tree was built,
+ * its node statistics are refreshed by CleanupHelper_.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ */
+inline void DualTreeBoruvka::Cleanup_()
+{
+  // Same reset Init() applies when it allocates neighbors_distances_.
+  neighbors_distances_.fill(DBL_MAX);
+  number_of_loops_++;
+
+  if (!do_naive_)
+  {
+    CleanupHelper_(tree_);
+  }
+}
+
+/**
+ * Logs a short summary of the computed MST to mlpack::Log::Info.  It reports
+ * total_dist_ (labelled as the total squared length), the number of points
+ * and the dimensionality of the data.
+ *
+ * The traversal counters reset in Init (number_of_loops_,
+ * number_distance_prunes_, number_component_prunes_,
+ * number_leaf_computations_, number_q_recursions_, number_r_recursions_ and
+ * number_both_recursions_) are not reported here.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ */
+inline void DualTreeBoruvka::OutputResults_()
+{
+  mlpack::Log::Info << "Total squared length: " << total_dist_ << std::endl;
+  mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
+  mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
+} // OutputResults_
+
+/**
+ * Returns the number of MST edges found so far.  Init resets it to 0, and
+ * ComputeMST keeps iterating until it reaches number_of_points_ - 1.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ */
+inline size_t DualTreeBoruvka::number_of_edges()
+{
+  return number_of_edges_;
+}
+
+/**
+ * Takes in a reference to the data set.  Copies the data, builds the tree
+ * (unless naive is set), and initializes all of the member variables.
+ *
+ * @param data Dataset; number_of_points_ is taken from data.n_cols, so each
+ *     column is one point.
+ * @param naive If true, no tree is built and neighbors are later found by
+ *     brute force.
+ * @param leafSize Leaf size passed to the DTBTree constructor.  A leaf size
+ *     of 1 gives the best pruning empirically; only raise it when memory is a
+ *     concern.
+ *
+ * NOTE(review): assumes data has at least one column.  With zero columns,
+ * number_of_points_ - 1 wraps around in the edges_.resize call below.
+ * NOTE(review): tree_ is assigned a fresh tree without freeing a previous
+ * one, so calling Init twice presumably leaks; confirm against the
+ * destructor.
+ * NOTE(review): default arguments normally belong on the in-class
+ * declaration.  They are kept here because callers may rely on them.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ */
+inline void DualTreeBoruvka::Init(const arma::mat& data, bool naive = false,
+                                  size_t leafSize = 1)
+{
+  number_of_edges_ = 0;
+  data_points_ = data; // copy
+
+  do_naive_ = naive;
+
+  if (!do_naive_)
+  {
+    Timers::StartTimer("emst/tree_building");
+
+    // The tree constructor fills old_from_new_permutation_.  EmitResults_
+    // later uses it to map point indices back to the original ordering.
+    tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
+
+    Timers::StopTimer("emst/tree_building");
+  }
+  else
+  {
+    tree_ = NULL;
+    old_from_new_permutation_.resize(0);
+  }
+
+  number_of_points_ = data_points_.n_cols;
+  // A spanning tree on n points has n - 1 edges.
+  edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
+  connections_.Init(number_of_points_);
+
+  neighbors_in_component_.set_size(number_of_points_);
+  neighbors_out_component_.set_size(number_of_points_);
+  neighbors_distances_.set_size(number_of_points_);
+  neighbors_distances_.fill(DBL_MAX);
+
+  total_dist_ = 0.0;
+  number_of_loops_ = 0;
+  number_distance_prunes_ = 0;
+  number_component_prunes_ = 0;
+  number_leaf_computations_ = 0;
+  number_q_recursions_ = 0;
+  number_r_recursions_ = 0;
+  number_both_recursions_ = 0;
+} // Init
+
+/**
+ * Call this function after Init.  It will iteratively find the nearest
+ * neighbor of each component until the MST is complete.  Each iteration
+ * runs ComputeNeighbors_, AddAllEdges_ and Cleanup_.  Afterwards the edges
+ * are written to results and a summary is logged.
+ *
+ * @param results Output matrix with one row per edge: smaller point index,
+ *     larger point index, and edge length.
+ *
+ * Declared inline because this definition lives in a header; without it,
+ * including the header from several translation units violates the
+ * one-definition rule.
+ */
+inline void DualTreeBoruvka::ComputeMST(arma::mat& results)
+{
+  Timers::StartTimer("emst/MST_computation");
+
+  // The loop condition depends only on number_of_edges_.  Termination
+  // therefore relies on every iteration adding at least one edge.
+  while (number_of_edges_ < (number_of_points_ - 1))
+  {
+    ComputeNeighbors_();
+
+    AddAllEdges_();
+
+    Cleanup_();
+
+    Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
+    Log::Info << number_of_edges_ << " edges found so far.\n\n";
+  }
+
+  Timers::StopTimer("emst/MST_computation");
+
+  EmitResults_(results);
+
+  OutputResults_();
+} // ComputeMST
+
+
+
+
+
+#endif 
\ No newline at end of file

Copied: mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp (from rev 10604, mlpack/trunk/src/mlpack/methods/emst/emst.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp	2011-12-06 23:51:55 UTC (rev 10605)
@@ -0,0 +1,80 @@
+/**
+ * @file edge_pair.hpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * This file defines the EdgePair class, the basic element of the edge lists
+ * used by the minimum spanning tree algorithms.
+ */
+#ifndef __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+#define __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+
+#include <mlpack/core.hpp>
+
+#include "union_find.hpp"
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * An edge pair is simply two indices and a distance.  It is used as the
+ * basic element of an edge list when computing a minimum spanning tree.
+ */
+class EdgePair
+{
+ private:
+  //! Index of the first endpoint (conventionally the smaller index).
+  size_t lesser_index_;
+  //! Index of the second endpoint (conventionally the larger index).
+  size_t greater_index_;
+  //! Distance associated with the edge, as supplied by the caller.
+  double distance_;
+
+ public:
+  /**
+   * Initialize an EdgePair with two indices and a distance.  The indices are
+   * called lesser and greater, implying that they be sorted before calling
+   * Init.  However, this is not necessary for functionality; it is just a way
+   * to keep the edge list organized in other code.
+   *
+   * @param lesser Index of the first endpoint.
+   * @param greater Index of the second endpoint; must differ from lesser.
+   * @param dist Distance to associate with the edge.
+   */
+  void Init(size_t lesser, size_t greater, double dist)
+  {
+    mlpack::Log::Assert(lesser != greater,
+        "indices equal when creating EdgePair, lesser == greater");
+    lesser_index_ = lesser;
+    greater_index_ = greater;
+    distance_ = dist;
+  }
+
+  //! Get the first (lesser) endpoint index.
+  size_t lesser_index() const
+  {
+    return lesser_index_;
+  }
+
+  //! Set the first (lesser) endpoint index.
+  void set_lesser_index(size_t index)
+  {
+    lesser_index_ = index;
+  }
+
+  //! Get the second (greater) endpoint index.
+  size_t greater_index() const
+  {
+    return greater_index_;
+  }
+
+  //! Set the second (greater) endpoint index.
+  void set_greater_index(size_t index)
+  {
+    greater_index_ = index;
+  }
+
+  //! Get the distance associated with the edge.
+  double distance() const
+  {
+    return distance_;
+  }
+
+  //! Set the distance associated with the edge.
+  void set_distance(double new_dist)
+  {
+    distance_ = new_dist;
+  }
+}; // class EdgePair
+
+} // namespace emst
+} // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_EDGE_PAIR_HPP

Deleted: mlpack/trunk/src/mlpack/methods/emst/emst.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.hpp	2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.hpp	2011-12-06 23:51:55 UTC (rev 10605)
@@ -1,80 +0,0 @@
-/**
- * @file emst.h
- *
- * @author Bill March (march at gatech.edu)
- *
- * This file contains utilities necessary for all of the minimum spanning tree
- * algorithms.
- */
-#ifndef __MLPACK_METHODS_EMST_EMST_HPP
-#define __MLPACK_METHODS_EMST_EMST_HPP
-
-#include <mlpack/core.hpp>
-
-#include "union_find.hpp"
-
-namespace mlpack {
-namespace emst {
-
-/**
- * An edge pair is simply two indices and a distance.  It is used as the
- * basic element of an edge list when computing a minimum spanning tree.
- */
-class EdgePair
-{
- private:
-  size_t lesser_index_;
-  size_t greater_index_;
-  double distance_;
-
- public:
-  /**
-   * Initialize an EdgePair with two indices and a distance.  The indices are
-   * called lesser and greater, implying that they be sorted before calling
-   * Init.  However, this is not necessary for functionality; it is just a way
-   * to keep the edge list organized in other code.
-   */
-  void Init(size_t lesser, size_t greater, double dist)
-  {
-    mlpack::Log::Assert(lesser != greater,
-        "indices equal when creating EdgePair, lesser == greater");
-    lesser_index_ = lesser;
-    greater_index_ = greater;
-    distance_ = dist;
-  }
-
-  size_t lesser_index()
-  {
-    return lesser_index_;
-  }
-
-  void set_lesser_index(size_t index)
-  {
-    lesser_index_ = index;
-  }
-
-  size_t greater_index()
-  {
-    return greater_index_;
-  }
-
-  void set_greater_index(size_t index)
-  {
-    greater_index_ = index;
-  }
-
-  double distance() const
-  {
-    return distance_;
-  }
-
-  void set_distance(double new_dist)
-  {
-    distance_ = new_dist;
-  }
-}; // class EdgePair
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_EMST_EMST_HPP




More information about the mlpack-svn mailing list