[mlpack-svn] r10764 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 05:53:00 EST 2011
Author: rcurtin
Date: 2011-12-14 05:53:00 -0500 (Wed, 14 Dec 2011)
New Revision: 10764
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
Log:
Refactor and clean up EMST code.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,6 +1,5 @@
/**
* @file dtb.hpp
- *
* @author Bill March (march at gatech.edu)
*
* Contains an implementation of the DualTreeBoruvka algorithm for finding a
@@ -37,18 +36,12 @@
class DTBStat
{
private:
- double max_neighbor_distance_;
- int component_membership_;
+ //! Maximum neighbor distance.
+ double maxNeighborDistance;
+ //! Component membership of this node.
+ int componentMembership;
public:
- void set_max_neighbor_distance(double distance);
-
- double max_neighbor_distance();
-
- void set_component_membership(int membership);
-
- int component_membership();
-
/**
* A generic initializer.
*/
@@ -67,137 +60,152 @@
DTBStat(const MatType& dataset, const size_t start, const size_t count,
const DTBStat& leftStat, const DTBStat& rightStat);
+ //! Get the maximum neighbor distance.
+ double MaxNeighborDistance() const { return maxNeighborDistance; }
+ //! Modify the maximum neighbor distance.
+ double& MaxNeighborDistance() { return maxNeighborDistance; }
+
+ //! Get the component membership of this node.
+ int ComponentMembership() const { return componentMembership; }
+ //! Modify the component membership of this node.
+ int& ComponentMembership() { return componentMembership; }
+
}; // class DTBStat
/**
* Performs the MST calculation using the Dual-Tree Boruvka algorithm.
*/
+template<
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
+>
class DualTreeBoruvka
{
- public:
- // For now, everything is in Euclidean space
- static const size_t metric = 2;
+ private:
+ //! Copy of the data (if necessary).
+ arma::mat dataCopy;
+ //! Reference to the data (this is what should be used for accessing data).
+ arma::mat& data;
- typedef tree::BinarySpaceTree<bound::HRectBound<metric>, DTBStat> DTBTree;
+ //! Pointer to the root of the tree.
+ TreeType* tree;
+ //! Indicates whether or not we "own" the tree.
+ bool ownTree;
- //////// Member Variables /////////////////////
+ //! Indicates whether or not O(n^2) naive mode will be used.
+ bool naive;
- private:
- size_t number_of_edges_;
- std::vector<EdgePair> edges_; // must use vector with non-numerical types
- size_t number_of_points_;
- UnionFind connections_;
- struct datanode* module_;
- arma::mat data_points_;
- size_t leaf_size_;
+ //! Edges.
+ std::vector<EdgePair> edges; // must use vector with non-numerical types
- // lists
- std::vector<size_t> old_from_new_permutation_;
- arma::Col<size_t> neighbors_in_component_;
- arma::Col<size_t> neighbors_out_component_;
- arma::vec neighbors_distances_;
+ //! Connections.
+ UnionFind connections;
+ //! Permutations of points during tree building.
+ std::vector<size_t> oldFromNew;
+ //! List of edge nodes.
+ arma::Col<size_t> neighborsInComponent;
+ //! List of edge nodes.
+ arma::Col<size_t> neighborsOutComponent;
+ //! List of edge distances.
+ arma::vec neighborsDistances;
+
// output info
- double total_dist_;
- size_t number_of_loops_;
- size_t number_distance_prunes_;
- size_t number_component_prunes_;
- size_t number_leaf_computations_;
- size_t number_q_recursions_;
- size_t number_r_recursions_;
- size_t number_both_recursions_;
+ double totalDist;
- bool do_naive_;
-
- DTBTree* tree_;
-
- // for sorting the edge list after the computation
- struct SortEdgesHelper_
+ // For sorting the edge list after the computation.
+ struct SortEdgesHelper
{
- bool operator() (const EdgePair& pairA, const EdgePair& pairB)
+ bool operator()(const EdgePair& pairA, const EdgePair& pairB)
{
- return (pairA.distance() < pairB.distance());
+ return (pairA.Distance() < pairB.Distance());
}
} SortFun;
-
+
////////////////// Constructors ////////////////////////
public:
- DualTreeBoruvka() { }
+ /**
+ * Create the tree from the given dataset. This copies the dataset to an
+ * internal copy, because tree-building modifies the dataset.
+ *
+ * @param data Dataset to build a tree for.
+ * @param naive Whether the computation should be done in O(n^2) naive mode.
+ * @param leafSize The leaf size to be used during tree construction.
+ */
+ DualTreeBoruvka(const typename TreeType::Mat& dataset,
+ const bool naive = false,
+ const size_t leafSize = 1);
+ /**
+ * Create the DualTreeBoruvka object with an already initialized tree. This
+ * will not copy the dataset, and can save a little processing power. Naive
+ * mode is not available as an option for this constructor; instead, to run
+ * naive computation, construct a tree with all the points in one leaf (i.e.
+ * leafSize = number of points).
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param tree Pre-built tree.
+ * @param dataset Dataset corresponding to the pre-built tree.
+ */
+ DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset);
+
+ /**
+ * Delete the tree, if it was created inside the object.
+ */
~DualTreeBoruvka();
+ /**
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete, then write the MST edges to results.
+ *
+ * @param results Output matrix, resized to 3 x (n - 1) for n points: each
+ * column holds the lesser point index, the greater point index, and the edge
+ * length (square root of the stored edge distance).
+ *
+ * @note Call this at most once per object: the edge list is kept, and a
+ * second call would re-sort and (for trees built by this object) re-unpermute
+ * the already-stored edges.
+ */
+ void ComputeMST(arma::mat& results);
+
////////////////////////// Private Functions ////////////////////
private:
/**
* Adds a single edge to the edge list
*/
- void AddEdge_(size_t e1, size_t e2, double distance);
-
+ void AddEdge(const size_t e1, const size_t e2, const double distance);
+
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
- void AddAllEdges_();
-
+ void AddAllEdges();
+
/**
* Handles the base case computation. Also called by naive.
*/
- double ComputeBaseCase_(size_t query_start, size_t query_end,
- size_t reference_start, size_t reference_end);
-
+ double BaseCase(const TreeType* queryNode, const TreeType* referenceNode);
+
/**
* Handles the recursive calls to find the nearest neighbors in an iteration
*/
- void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
- double incoming_distance);
-
- /**
- * Computes the nearest neighbor of each point in each iteration
- * of the algorithm
- */
- void ComputeNeighbors_();
+ void DualTreeRecursion(TreeType *queryNode,
+ TreeType *referenceNode,
+ double incomingDistance);
-
- void SortEdges_();
-
/**
- * Unpermute the edge list and output it to results
- *
+ * Unpermute the edge list and output it to results.
*/
- void EmitResults_(arma::mat& results);
+ void EmitResults(arma::mat& results);
/**
* This function resets the values in the nodes of the tree nearest neighbor
- * distance, check for fully connected nodes
+ * distance, and checks for fully connected nodes.
*/
- void CleanupHelper_(DTBTree* tree);
+ void CleanupHelper(TreeType* tree);
/**
* The values stored in the tree must be reset on each iteration.
*/
- void Cleanup_();
-
- /**
- * Format and output the results
- */
- void OutputResults_();
-
- /////////// Public Functions ///////////////////
- public:
- size_t number_of_edges();
+ void Cleanup();
- /**
- * Takes in a reference to the data set. Copies the data, builds the tree,
- * and initializes all of the member variables.
- */
- void Init(const arma::mat& data, bool naive, size_t leafSize);
-
- /**
- * Call this function after Init. It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
- */
- void ComputeMST(arma::mat& results);
-
}; // class DualTreeBoruvka
}; // namespace emst
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,10 +1,8 @@
-/*
- * dtb_impl.hpp
- *
+/**
+ * @file dtb_impl.hpp
+ * @author Bill March (march at gatech.edu)
*
- * Created by William March on 12/6/11.
- * Copyright 2011 __MyCompanyName__. All rights reserved.
- *
+ * Implementation of DTB.
*/
#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
@@ -12,37 +10,17 @@
#include <mlpack/core.hpp>
+namespace mlpack {
+namespace emst {
+// DTBStat
-using namespace mlpack::emst;
-
-void DTBStat::set_max_neighbor_distance(double distance)
-{
- max_neighbor_distance_ = distance;
-}
-
-double DTBStat::max_neighbor_distance()
-{
- return max_neighbor_distance_;
-}
-
-void DTBStat::set_component_membership(int membership)
-{
- component_membership_ = membership;
-}
-
-int DTBStat::component_membership()
-{
- return component_membership_;
-}
-
/**
* A generic initializer.
*/
-DTBStat::DTBStat()
+inline DTBStat::DTBStat() :
+ maxNeighborDistance(DBL_MAX),
+ componentMembership(-1)
{
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
+ // Nothing to do. This non-template definition lives in a header, so it is
+ // marked inline to avoid multiple-definition errors at link time.
}
/**
@@ -51,516 +29,430 @@
template<typename MatType>
DTBStat::DTBStat(const MatType& dataset,
const size_t start,
- const size_t count)
+ const size_t count) :
+ maxNeighborDistance(DBL_MAX),
+ // Cast explicitly: mixing size_t and int in ?: would convert -1 to SIZE_MAX
+ // before it is narrowed back into the int member.
+ componentMembership((count == 1) ? static_cast<int>(start) : -1)
{
- if (count == 1)
- {
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
- }
- else
- {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
+ // A single-point leaf is labeled with that point's index; larger nodes start
+ // with no component (-1). Nothing else to do.
}
/**
- * An initializer for non-leaves. Simply calls the leaf initializer.
+ * An initializer for non-leaves. The child statistics are not consulted: the
+ * node starts with a maximum neighbor distance of DBL_MAX and no component
+ * (-1), unless it holds exactly one point.
*/
template<typename MatType>
DTBStat::DTBStat(const MatType& dataset,
const size_t start,
const size_t count,
const DTBStat& leftStat,
- const DTBStat& right_stat)
+ const DTBStat& rightStat) :
+ maxNeighborDistance(DBL_MAX),
+ // Cast explicitly: mixing size_t and int in ?: would convert -1 to SIZE_MAX
+ // before it is narrowed back into the int member.
+ componentMembership((count == 1) ? static_cast<int>(start) : -1)
{
- if (count == 1)
+ // Nothing to do.
+}
+
+// DualTreeBoruvka
+
+/**
+ * Takes in a reference to the data set. Copies the data, builds the tree,
+ * and initializes all of the member variables.
+ */
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+ const typename TreeType::Mat& dataset,
+ const bool naive,
+ const size_t leafSize) :
+ dataCopy(dataset),
+ data(dataCopy), // The reference points to our copy of the data.
+ ownTree(true),
+ naive(naive),
+ connections(data.n_cols),
+ totalDist(0.0)
+{
+ Timers::StartTimer("emst/treebuilding");
+
+ if (!naive)
{
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
+ // Default leaf size is 1; this gives the best pruning, empirically. Use
+ // leaf_size = 1 unless space is a big concern.
+ tree = new TreeType(data, oldFromNew, leafSize);
}
else
{
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
+ // Naive tree holds all data in one leaf.
+ tree = new TreeType(data, oldFromNew, data.n_cols);
}
-}
+ Timers::StopTimer("emst/treebuilding");
-DualTreeBoruvka::~DualTreeBoruvka()
+ edges.reserve(data.n_cols - 1); // Set size.
+
+ neighborsInComponent.set_size(data.n_cols);
+ neighborsOutComponent.set_size(data.n_cols);
+ neighborsDistances.set_size(data.n_cols);
+ neighborsDistances.fill(DBL_MAX);
+} // Constructor
+
+/**
+ * Use a pre-built tree and its (already reordered) dataset. Both remain owned
+ * by the caller, and EmitResults() performs no index unpermutation.
+ */
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::DualTreeBoruvka(
+ TreeType* tree,
+ const typename TreeType::Mat& dataset) :
+ // The member is a non-const reference because the other constructor hands
+ // it to tree building; on this path the object only reads the data, so
+ // casting away const never modifies the caller's matrix.
+ data(const_cast<typename TreeType::Mat&>(dataset)),
+ tree(tree),
+ ownTree(false), // The caller built the tree, so the caller must free it.
+ naive(false),
+ connections(data.n_cols),
+ totalDist(0.0)
{
- if (tree_ != NULL)
- delete tree_;
+ edges.reserve(data.n_cols - 1); // An MST on n points has n - 1 edges.
+
+ neighborsInComponent.set_size(data.n_cols);
+ neighborsOutComponent.set_size(data.n_cols);
+ neighborsDistances.set_size(data.n_cols);
+ neighborsDistances.fill(DBL_MAX);
}
+template<typename TreeType>
+DualTreeBoruvka<TreeType>::~DualTreeBoruvka()
+{
+ if (ownTree)
+ delete tree;
+}
+
/**
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete, then write the edges to results via EmitResults() (one column
+ * per edge: lesser index, greater index, edge length).
+ *
+ * @param results Output matrix for the MST edges.
+ */
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::ComputeMST(arma::mat& results)
+{
+ Timers::StartTimer("emst/mst_computation");
+
+ while (edges.size() < (data.n_cols - 1))
+ {
+ // Compute neighbors.
+ if (naive)
+ {
+ BaseCase(tree, tree);
+ }
+ else
+ {
+ DualTreeRecursion(tree, tree, DBL_MAX);
+ }
+
+ AddAllEdges();
+
+ Cleanup();
+
+ Log::Info << edges.size() << " edges found so far.\n";
+ }
+
+ Timers::StopTimer("emst/mst_computation");
+
+ EmitResults(results);
+
+ // Edge distances are stored squared, and AddAllEdges() accumulates their
+ // square roots, so totalDist is the total edge length, not a squared length.
+ Log::Info << "Total length: " << totalDist << std::endl;
+} // ComputeMST
+
+/**
* Adds a single edge to the edge list
*/
-void DualTreeBoruvka::AddEdge_(size_t e1, size_t e2, double distance)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::AddEdge(const size_t e1,
+ const size_t e2,
+ const double distance)
{
- //EdgePair edge;
- mlpack::Log::Assert((e1 != e2),
- "Indices are equal in DualTreeBoruvka.add_edge(...)");
-
- mlpack::Log::Assert((distance >= 0.0),
- "Negative distance input in DualTreeBoruvka.add_edge(...)");
-
+ Log::Assert((distance >= 0.0),
+ "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
+
if (e1 < e2)
- edges_[number_of_edges_].Init(e1, e2, distance);
+ edges.push_back(EdgePair(e1, e2, distance));
else
- edges_[number_of_edges_].Init(e2, e1, distance);
-
- number_of_edges_++;
-
-} // AddEdge_
+ edges.push_back(EdgePair(e2, e1, distance));
+} // AddEdge
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
-void DualTreeBoruvka::AddAllEdges_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::AddAllEdges()
{
- for (size_t i = 0; i < number_of_points_; i++)
+ // Each point looks up the candidate edge recorded for its component.
+ for (size_t i = 0; i < data.n_cols; i++)
{
- size_t component_i = connections_.Find(i);
- size_t in_edge_i = neighbors_in_component_[component_i];
- size_t out_edge_i = neighbors_out_component_[component_i];
- if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+ size_t component = connections.Find(i);
+ size_t inEdge = neighborsInComponent[component];
+ size_t outEdge = neighborsOutComponent[component];
+ // Skip the edge if its endpoints are already connected, e.g. because an
+ // earlier point in this loop (possibly from the other component) added it.
+ if (connections.Find(inEdge) != connections.Find(outEdge))
{
- double dist = neighbors_distances_[component_i];
- //total_dist_ = total_dist_ + dist;
+ // Stored distances are squared; accumulate the square root (edge length).
// changed to make this agree with the cover tree code
- total_dist_ = total_dist_ + sqrt(dist);
- AddEdge_(in_edge_i, out_edge_i, dist);
- connections_.Union(in_edge_i, out_edge_i);
+ totalDist += sqrt(neighborsDistances[component]);
+ AddEdge(inEdge, outEdge, neighborsDistances[component]);
+ connections.Union(inEdge, outEdge);
}
}
-} // AddAllEdges_
+} // AddAllEdges
/**
* Handles the base case computation. Also called by naive.
*/
-double DualTreeBoruvka::ComputeBaseCase_(size_t query_start, size_t query_end,
- size_t reference_start,
- size_t reference_end)
+template<typename TreeType>
+double DualTreeBoruvka<TreeType>::BaseCase(const TreeType* queryNode,
+ const TreeType* referenceNode)
{
- number_leaf_computations_++;
-
- double new_upper_bound = -1.0;
-
- for (size_t query_index = query_start; query_index < query_end;
- query_index++)
+ double newUpperBound = -1.0;
+
+ for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
+ ++queryIndex)
{
- // Find the index of the component the query is in
- size_t query_component_index = connections_.Find(query_index);
-
- arma::vec query_point = data_points_.col(query_index);
-
- for (size_t reference_index = reference_start;
- reference_index < reference_end; reference_index++)
+ // Find the index of the component the query is in.
+ size_t queryComponentIndex = connections.Find(queryIndex);
+
+ for (size_t referenceIndex = referenceNode->Begin();
+ referenceIndex < referenceNode->End(); ++referenceIndex)
{
- size_t reference_component_index = connections_.Find(reference_index);
-
- if (query_component_index != reference_component_index)
+ size_t referenceComponentIndex = connections.Find(referenceIndex);
+
+ if (queryComponentIndex != referenceComponentIndex)
{
- arma::vec reference_point = data_points_.col(reference_index);
-
- double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
- reference_point);
-
- if (distance < neighbors_distances_[query_component_index])
+ double distance = metric::LMetric<2>::Evaluate(data.col(queryIndex),
+ data.col(referenceIndex));
+
+ if (distance < neighborsDistances[queryComponentIndex])
{
- mlpack::Log::Assert(query_index != reference_index);
-
- neighbors_distances_[query_component_index] = distance;
- neighbors_in_component_[query_component_index] = query_index;
- neighbors_out_component_[query_component_index] = reference_index;
+ Log::Assert(queryIndex != referenceIndex);
+
+ neighborsDistances[queryComponentIndex] = distance;
+ neighborsInComponent[queryComponentIndex] = queryIndex;
+ neighborsOutComponent[queryComponentIndex] = referenceIndex;
} // if distance
} // if indices not equal
- } // for reference_index
-
- if (new_upper_bound < neighbors_distances_[query_component_index])
- new_upper_bound = neighbors_distances_[query_component_index];
-
- } // for query_index
-
- mlpack::Log::Assert(new_upper_bound >= 0.0);
- return new_upper_bound;
-
-} // ComputeBaseCase_
+ } // for referenceIndex
+ if (newUpperBound < neighborsDistances[queryComponentIndex])
+ newUpperBound = neighborsDistances[queryComponentIndex];
+ } // for queryIndex
+
+ Log::Assert(newUpperBound >= 0.0);
+
+ return newUpperBound;
+
+} // BaseCase
+
+
/**
* Handles the recursive calls to find the nearest neighbors in an iteration
*/
-void DualTreeBoruvka::ComputeNeighborsRecursion_(DTBTree *query_node,
- DTBTree *reference_node,
- double incoming_distance)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::DualTreeRecursion(TreeType *queryNode,
+ TreeType *referenceNode,
+ double incomingDistance)
{
// Check for a distance prune.
- if (query_node->Stat().max_neighbor_distance() < incoming_distance)
+ if (queryNode->Stat().MaxNeighborDistance() < incomingDistance)
{
// Pruned by distance.
- number_distance_prunes_++;
+ return;
}
// Check for a component prune.
- else if ((query_node->Stat().component_membership() >= 0)
- && (query_node->Stat().component_membership() ==
- reference_node->Stat().component_membership()))
+ else if ((queryNode->Stat().ComponentMembership() >= 0)
+ && (queryNode->Stat().ComponentMembership() ==
+ referenceNode->Stat().ComponentMembership()))
{
// Pruned by component membership.
- mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
- mlpack::Log::Info << query_node->Stat().component_membership()
- << "q mem\n";
- mlpack::Log::Info << reference_node->Stat().component_membership()
- << "r mem\n";
-
- number_component_prunes_++;
+ Log::Assert(referenceNode->Stat().ComponentMembership() >= 0);
+ return;
}
- else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
+ else if (queryNode->IsLeaf() && referenceNode->IsLeaf()) // Base case.
{
- double new_bound = ComputeBaseCase_(query_node->Begin(),
- query_node->End(), reference_node->Begin(), reference_node->End());
-
- query_node->Stat().set_max_neighbor_distance(new_bound);
+ double new_bound = BaseCase(queryNode, referenceNode);
+ queryNode->Stat().MaxNeighborDistance() = new_bound;
}
- else if (query_node->IsLeaf()) // Other recursive calls.
+ else if (queryNode->IsLeaf()) // Other recursive calls.
{
- // Recurse on reference_node only.
- number_r_recursions_++;
-
- double left_dist =
- query_node->Bound().MinDistance(reference_node->Left()->Bound());
- double right_dist =
- query_node->Bound().MinDistance(reference_node->Right()->Bound());
- mlpack::Log::Assert(left_dist >= 0.0);
- mlpack::Log::Assert(right_dist >= 0.0);
-
- if (left_dist < right_dist)
+ // Recurse on referenceNode only.
+ double leftDist =
+ queryNode->Bound().MinDistance(referenceNode->Left()->Bound());
+ double rightDist =
+ queryNode->Bound().MinDistance(referenceNode->Right()->Bound());
+
+ if (leftDist < rightDist)
{
- ComputeNeighborsRecursion_(query_node, reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node, reference_node->Right(),
- right_dist);
+ DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
+ DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
}
else
{
- ComputeNeighborsRecursion_(query_node, reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node, reference_node->Left(),
- left_dist);
+ DualTreeRecursion(queryNode, referenceNode->Right(), rightDist);
+ DualTreeRecursion(queryNode, referenceNode->Left(), leftDist);
}
}
- else if (reference_node->IsLeaf())
+ else if (referenceNode->IsLeaf())
{
- // Recurse on query_node only.
- number_q_recursions_++;
-
- double left_dist =
- query_node->Left()->Bound().MinDistance(reference_node->Bound());
- double right_dist =
- query_node->Right()->Bound().MinDistance(reference_node->Bound());
-
- ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node,
- right_dist);
-
- // Update query_node's stat.
- query_node->Stat().set_max_neighbor_distance(
- std::max(query_node->Left()->Stat().max_neighbor_distance(),
- query_node->Right()->Stat().max_neighbor_distance()));
-
+ // Recurse on queryNode only.
+ double leftDist =
+ queryNode->Left()->Bound().MinDistance(referenceNode->Bound());
+ double rightDist =
+ queryNode->Right()->Bound().MinDistance(referenceNode->Bound());
+
+ DualTreeRecursion(queryNode->Left(), referenceNode, leftDist);
+ DualTreeRecursion(queryNode->Right(), referenceNode, rightDist);
+
+ // Update queryNode's stat.
+ queryNode->Stat().MaxNeighborDistance() =
+ std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
+ queryNode->Right()->Stat().MaxNeighborDistance());
}
else
{
// Recurse on both.
- number_both_recursions_++;
-
- double left_dist = query_node->Left()->Bound().MinDistance(
- reference_node->Left()->Bound());
- double right_dist = query_node->Left()->Bound().MinDistance(
- reference_node->Right()->Bound());
-
- if (left_dist < right_dist)
+ double leftDist = queryNode->Left()->Bound().MinDistance(
+ referenceNode->Left()->Bound());
+ double rightDist = queryNode->Left()->Bound().MinDistance(
+ referenceNode->Right()->Bound());
+
+ if (leftDist < rightDist)
{
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_dist);
+ DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
+ DualTreeRecursion(queryNode->Left(), referenceNode->Right(),
+ rightDist);
}
else
{
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_dist);
+ DualTreeRecursion(queryNode->Left(), referenceNode->Right(), rightDist);
+ DualTreeRecursion(queryNode->Left(), referenceNode->Left(), leftDist);
}
-
- left_dist = query_node->Right()->Bound().MinDistance(
- reference_node->Left()->Bound());
- right_dist = query_node->Right()->Bound().MinDistance(
- reference_node->Right()->Bound());
-
- if (left_dist < right_dist)
+
+ leftDist = queryNode->Right()->Bound().MinDistance(
+ referenceNode->Left()->Bound());
+ rightDist = queryNode->Right()->Bound().MinDistance(
+ referenceNode->Right()->Bound());
+
+ if (leftDist < rightDist)
{
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_dist);
+ DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
+ DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
}
else
{
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_dist);
+ DualTreeRecursion(queryNode->Right(), referenceNode->Right(), rightDist);
+ DualTreeRecursion(queryNode->Right(), referenceNode->Left(), leftDist);
}
-
- query_node->Stat().set_max_neighbor_distance(
- std::max(query_node->Left()->Stat().max_neighbor_distance(),
- query_node->Right()->Stat().max_neighbor_distance()));
+
+ queryNode->Stat().MaxNeighborDistance() =
+ std::max(queryNode->Left()->Stat().MaxNeighborDistance(),
+ queryNode->Right()->Stat().MaxNeighborDistance());
}
-} // ComputeNeighborsRecursion_
+} // DualTreeRecursion
/**
- * Computes the nearest neighbor of each point in each iteration
- * of the algorithm
+ * Unpermute the edge list (if necessary) and output it to results.
*/
-void DualTreeBoruvka::ComputeNeighbors_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::EmitResults(arma::mat& results)
{
- if (do_naive_)
- {
- ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
- }
- else
- {
- ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
- }
-} // ComputeNeighbors_
+ // Sort the edges.
+ std::sort(edges.begin(), edges.end(), SortFun);
-void DualTreeBoruvka::SortEdges_()
-{
- std::sort(edges_.begin(), edges_.end(), SortFun);
-} // SortEdges_()
+ Log::Assert(edges.size() == data.n_cols - 1);
+ results.set_size(3, edges.size());
-/**
- * Unpermute the edge list and output it to results
- *
- */
-void DualTreeBoruvka::EmitResults_(arma::mat& results)
-{
- SortEdges_();
-
- mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
- results.set_size(number_of_edges_, 3);
-
// Need to unpermute the point labels.
- if (!do_naive_)
+ if (!naive && ownTree)
{
- for (size_t i = 0; i < (number_of_points_ - 1); i++)
+ for (size_t i = 0; i < (data.n_cols - 1); i++)
{
// Make sure the edge list stores the smaller index first to
// make checking correctness easier
- size_t ind1, ind2;
- ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
- ind2 = old_from_new_permutation_[edges_[i].greater_index()];
-
- edges_[i].set_lesser_index(std::min(ind1, ind2));
- edges_[i].set_greater_index(std::max(ind1, ind2));
-
- results(i, 0) = edges_[i].lesser_index();
- results(i, 1) = edges_[i].greater_index();
- results(i, 2) = sqrt(edges_[i].distance());
+ size_t ind1 = oldFromNew[edges[i].Lesser()];
+ size_t ind2 = oldFromNew[edges[i].Greater()];
+
+ if (ind1 < ind2)
+ {
+ edges[i].Lesser() = ind1;
+ edges[i].Greater() = ind2;
+ }
+ else
+ {
+ edges[i].Lesser() = ind2;
+ edges[i].Greater() = ind1;
+ }
+
+ results(0, i) = edges[i].Lesser();
+ results(1, i) = edges[i].Greater();
+ results(2, i) = sqrt(edges[i].Distance());
}
}
else
{
- for (size_t i = 0; i < number_of_edges_; i++)
+ for (size_t i = 0; i < edges.size(); i++)
{
- results(i, 0) = edges_[i].lesser_index();
- results(i, 1) = edges_[i].greater_index();
- results(i, 2) = sqrt(edges_[i].distance());
+ results(0, i) = edges[i].Lesser();
+ results(1, i) = edges[i].Greater();
+ results(2, i) = sqrt(edges[i].Distance());
}
}
-} // EmitResults_
+} // EmitResults
/**
* This function resets the values in the nodes of the tree nearest neighbor
* distance, check for fully connected nodes
*/
-void DualTreeBoruvka::CleanupHelper_(DTBTree* tree)
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::CleanupHelper(TreeType* tree)
{
- tree->Stat().set_max_neighbor_distance(DBL_MAX);
-
+ // Reset the nearest-neighbor bound for the next Boruvka iteration.
+ tree->Stat().MaxNeighborDistance() = DBL_MAX;
+
if (!tree->IsLeaf())
{
- CleanupHelper_(tree->Left());
- CleanupHelper_(tree->Right());
-
- if ((tree->Left()->Stat().component_membership() >= 0)
- && (tree->Left()->Stat().component_membership() ==
- tree->Right()->Stat().component_membership()))
+ CleanupHelper(tree->Left());
+ CleanupHelper(tree->Right());
+
+ // The node is fully connected when both children belong to the same
+ // (assigned) component.
+ if ((tree->Left()->Stat().ComponentMembership() >= 0)
+ && (tree->Left()->Stat().ComponentMembership() ==
+ tree->Right()->Stat().ComponentMembership()))
{
- tree->Stat().set_component_membership(tree->Left()->Stat().
- component_membership());
+ tree->Stat().ComponentMembership() =
+ tree->Left()->Stat().ComponentMembership();
}
}
else
{
- size_t new_membership = connections_.Find(tree->Begin());
-
- for (size_t i = tree->Begin(); i < tree->End(); i++)
+ const size_t newMembership = connections.Find(tree->Begin());
+
+ for (size_t i = tree->Begin(); i < tree->End(); ++i)
{
- if (new_membership != connections_.Find(i))
+ if (newMembership != connections.Find(i))
{
- new_membership = -1;
- mlpack::Log::Assert(tree->Stat().component_membership() < 0);
+ // Mixed components: the leaf keeps its unassigned (-1) membership.
+ Log::Assert(tree->Stat().ComponentMembership() < 0);
return;
}
}
- tree->Stat().set_component_membership(new_membership);
+ // Every point in the leaf shares one component.
+ tree->Stat().ComponentMembership() = static_cast<int>(newMembership);
}
-} // CleanupHelper_
+} // CleanupHelper
/**
* The values stored in the tree must be reset on each iteration.
*/
-void DualTreeBoruvka::Cleanup_()
+template<typename TreeType>
+void DualTreeBoruvka<TreeType>::Cleanup()
{
- for (size_t i = 0; i < number_of_points_; i++)
+ for (size_t i = 0; i < data.n_cols; i++)
{
- neighbors_distances_[i] = DBL_MAX;
+ neighborsDistances[i] = DBL_MAX;
}
- number_of_loops_++;
-
- if (!do_naive_)
+
+ if (!naive)
{
- CleanupHelper_(tree_);
+ CleanupHelper(tree);
}
}
-/**
- * Format and output the results
- */
-void DualTreeBoruvka::OutputResults_()
-{
- /* fx_result_double(module_, "total_squared_length", total_dist_);
- fx_result_int(module_, "number_of_points", number_of_points_);
- fx_result_int(module_, "dimension", data_points_.n_rows);
- fx_result_int(module_, "number_of_loops", number_of_loops_);
- fx_result_int(module_, "number_distance_prunes", number_distance_prunes_);
- fx_result_int(module_, "number_component_prunes", number_component_prunes_);
- fx_result_int(module_, "number_leaf_computations", number_leaf_computations_);
- fx_result_int(module_, "number_q_recursions", number_q_recursions_);
- fx_result_int(module_, "number_r_recursions", number_r_recursions_);
- fx_result_int(module_, "number_both_recursions", number_both_recursions_);*/
- // TODO, not sure how I missed this last time.
- mlpack::Log::Info << "Total squared length: " << total_dist_ << std::endl;
- mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
- mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
- /*
- mlpack::Log::Info << "number_of_loops" << std::endl;
- mlpack::Log::Info << "number_distance_prunes" << std::endl;
- mlpack::Log::Info << "number_component_prunes" << std::endl;
- mlpack::Log::Info << "number_leaf_computations" << std::endl;
- mlpack::Log::Info << "number_q_recursions" << std::endl;
- mlpack::Log::Info << "number_r_recursions" << std::endl;
- mlpack::Log::Info << "number_both_recursions" << std::endl;
- */
-
-} // OutputResults_
+}; // namespace emst
+}; // namespace mlpack
-size_t DualTreeBoruvka::number_of_edges()
-{
- return number_of_edges_;
-}
-
-/**
- * Takes in a reference to the data set. Copies the data, builds the tree,
- * and initializes all of the member variables.
- */
-void DualTreeBoruvka::Init(const arma::mat& data, bool naive = false,
- size_t leafSize = 1)
-{
- number_of_edges_ = 0;
- data_points_ = data; // copy
-
- do_naive_ = naive;
-
- if (!do_naive_)
- {
- // Default leaf size is 1
- // This gives best pruning empirically
- // Use leaf_size=1 unless space is a big concern
- Timers::StartTimer("emst/tree_building");
-
- tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
-
- Timers::StopTimer("emst/tree_building");
- }
- else
- {
- tree_ = NULL;
- old_from_new_permutation_.resize(0);
- }
-
- number_of_points_ = data_points_.n_cols;
- edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
- connections_.Init(number_of_points_);
-
- neighbors_in_component_.set_size(number_of_points_);
- neighbors_out_component_.set_size(number_of_points_);
- neighbors_distances_.set_size(number_of_points_);
- neighbors_distances_.fill(DBL_MAX);
-
- total_dist_ = 0.0;
- number_of_loops_ = 0;
- number_distance_prunes_ = 0;
- number_component_prunes_ = 0;
- number_leaf_computations_ = 0;
- number_q_recursions_ = 0;
- number_r_recursions_ = 0;
- number_both_recursions_ = 0;
-} // Init
-
-/**
- * Call this function after Init. It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
- */
-void DualTreeBoruvka::ComputeMST(arma::mat& results)
-{
- Timers::StartTimer("emst/MST_computation");
-
- while (number_of_edges_ < (number_of_points_ - 1))
- {
- ComputeNeighbors_();
-
- AddAllEdges_();
-
- Cleanup_();
-
- Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
- Log::Info << number_of_edges_ << " edges found so far.\n\n";
- /*
- Log::Info << number_leaf_computations_ << " base cases.\n";
- Log::Info << number_distance_prunes_ << " distance prunes.\n";
- Log::Info << number_component_prunes_ << " component prunes.\n";
- Log::Info << number_r_recursions_ << " reference recursions.\n";
- Log::Info << number_q_recursions_ << " query recursions.\n";
- Log::Info << number_both_recursions_ << " dual recursions.\n\n";
- */
- }
-
- Timers::StopTimer("emst/MST_computation");
-
- EmitResults_(results);
-
- OutputResults_();
-} // ComputeMST
-
-
-
-
-
-#endif
+#endif
Modified: mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp 2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp 2011-12-14 10:53:00 UTC (rev 10764)
@@ -23,9 +23,12 @@
class EdgePair
{
private:
- size_t lesser_index_;
- size_t greater_index_;
- double distance_;
+ //! Lesser index.
+ size_t lesser;
+ //! Greater index.
+ size_t greater;
+ //! Distance between two indices.
+ double distance;
public:
/**
@@ -34,44 +37,28 @@
* Init. However, this is not necessary for functionality; it is just a way
* to keep the edge list organized in other code.
*/
- void Init(size_t lesser, size_t greater, double dist)
+ EdgePair(const size_t lesser, const size_t greater, const double dist) :
+ lesser(lesser), greater(greater), distance(dist)
{
- mlpack::Log::Assert(lesser != greater,
- "indices equal when creating EdgePair, lesser == greater");
- lesser_index_ = lesser;
- greater_index_ = greater;
- distance_ = dist;
+ Log::Assert(lesser != greater,
+ "EdgePair::EdgePair(): indices cannot be equal.");
}
- size_t lesser_index()
- {
- return lesser_index_;
- }
+ //! Get the lesser index.
+ size_t Lesser() const { return lesser; }
+ //! Modify the lesser index.
+ size_t& Lesser() { return lesser; }
- void set_lesser_index(size_t index)
- {
- lesser_index_ = index;
- }
+ //! Get the greater index.
+ size_t Greater() const { return greater; }
+ //! Modify the greater index.
+ size_t& Greater() { return greater; }
- size_t greater_index()
- {
- return greater_index_;
- }
+ //! Get the distance.
+ double Distance() const { return distance; }
+ //! Modify the distance.
+ double& Distance() { return distance; }
- void set_greater_index(size_t index)
- {
- greater_index_ = index;
- }
-
- double distance() const
- {
- return distance_;
- }
-
- void set_distance(double new_dist)
- {
- distance_ = new_dist;
- }
}; // class EdgePair
}; // namespace emst
Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp 2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp 2011-12-14 10:53:00 UTC (rev 10764)
@@ -27,7 +27,11 @@
"Conference\n on Knowledge Discovery and Data Mining},\n"
" series = {KDD '10},\n"
" year = {2010}\n"
- " }\n");
+ " }\n\n"
+ "The output is saved in a three-column matrix, where each row indicates an "
+ "edge. The first column corresponds to the lesser index of the edge; the "
+ "second column corresponds to the greater index of the edge; and the third "
+ "column corresponds to the distance between the two points.");
PARAM_STRING_REQ("input_file", "Data input file.", "i");
PARAM_STRING("output_file", "Data output file. Stored as an edge list.", "o",
@@ -41,6 +45,7 @@
using namespace mlpack;
using namespace mlpack::emst;
+using namespace mlpack::tree;
int main(int argc, char* argv[])
{
@@ -54,21 +59,19 @@
arma::mat dataPoints;
data::Load(dataFilename.c_str(), dataPoints, true);
- // Do naive
+ // Do naive.
if (CLI::GetParam<bool>("naive"))
{
Log::Info << "Running naive algorithm.\n";
- DualTreeBoruvka naive;
+ DualTreeBoruvka<> naive(dataPoints, true);
- naive.Init(dataPoints, true);
+ arma::mat naiveResults;
+ naive.ComputeMST(naiveResults);
- arma::mat naive_results;
- naive.ComputeMST(naive_results);
-
std::string outputFilename = CLI::GetParam<std::string>("output_file");
- data::Save(outputFilename.c_str(), naive_results, true);
+ data::Save(outputFilename.c_str(), naiveResults, true);
}
else
{
@@ -83,10 +86,9 @@
size_t leafSize = CLI::GetParam<int>("leaf_size");
- DualTreeBoruvka dtb;
- dtb.Init(dataPoints, false, leafSize);
+ DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
- Log::Info << "Tree built, running algorithm.\n\n";
+ Log::Info << "Tree built, running algorithm.\n";
////////////// Run DTB /////////////////////
arma::mat results;
Modified: mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.hpp 2011-12-14 08:19:14 UTC (rev 10763)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.hpp 2011-12-14 10:53:00 UTC (rev 10764)
@@ -1,6 +1,5 @@
/**
- * @file union_find.h
- *
+ * @file union_find.hpp
* @author Bill March (march at gatech.edu)
*
* Implements a union-find data structure. This structure tracks the components
@@ -22,82 +21,69 @@
class UnionFind
{
private:
- arma::Col<size_t> parent_;
- arma::ivec rank_;
- size_t number_of_elements_;
+ size_t size;
+ arma::Col<size_t> parent;
+ arma::ivec rank;
public:
- UnionFind() {}
-
- ~UnionFind() {}
-
- /**
- * Initializes the structure. This implementation assumes
- * that the size is known advance and fixed
- *
- * @param size The number of elements to be tracked.
- */
- void Init(size_t size)
+ UnionFind(const size_t size) : size(size), parent(size), rank(size)
{
- number_of_elements_ = size;
- parent_.set_size(number_of_elements_);
- rank_.set_size(number_of_elements_);
- for (size_t i = 0; i < number_of_elements_; i++)
+ for (size_t i = 0; i < size; ++i)
{
- parent_[i] = i;
- rank_[i] = 0;
+ parent[i] = i;
+ rank[i] = 0;
}
}
+ ~UnionFind() {}
+
/**
- * Returns the component containing an element
+ * Returns the component containing an element.
*
* @param x the component to be found
* @return The index of the component containing x
*/
- size_t Find(size_t x)
+ size_t Find(const size_t x)
{
- if (parent_[x] == x)
+ if (parent[x] == x)
{
return x;
}
else
{
// This ensures that the tree has a small depth
- parent_[x] = Find(parent_[x]);
- return parent_[x];
+ parent[x] = Find(parent[x]);
+ return parent[x];
}
}
/**
- * @function Union
+ * Union the components containing x and y.
*
- * Union the components containing x and y
- *
* @param x one component
* @param y the other component
*/
- void Union(size_t x, size_t y)
+ void Union(const size_t x, const size_t y)
{
- size_t x_root = Find(x);
- size_t y_root = Find(y);
+ const size_t xRoot = Find(x);
+ const size_t yRoot = Find(y);
- if (x_root == y_root)
+ if (xRoot == yRoot)
{
return;
}
- else if (rank_[x_root] == rank_[y_root])
+ else if (rank[xRoot] == rank[yRoot])
{
- parent_[y_root] = parent_[x_root];
- rank_[x_root] = rank_[x_root] + 1;
+ parent[yRoot] = parent[xRoot];
+ rank[xRoot] = rank[xRoot] + 1;
}
- else if (rank_[x_root] > rank_[y_root])
+ else if (rank[xRoot] > rank[yRoot])
{
- parent_[y_root] = x_root;
+ parent[yRoot] = xRoot;
}
else
{
- parent_[x_root] = y_root;
+ parent[xRoot] = yRoot;
}
}
}; // class UnionFind
More information about the mlpack-svn
mailing list