[mlpack-svn] r15566 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 26 17:04:48 EDT 2013
Author: rcurtin
Date: Fri Jul 26 17:04:47 2013
New Revision: 15566
Log:
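Clean up code and refactor DTBRules so that it no longer depends on
UpdateAfterRecursion(); node bounds are now computed by CalculateBound().
The default metric is now EuclideanDistance, so edge distances are no longer
square-rooted after the search. EMST now actually provides a speedup when run
in dual-tree mode (hooray!).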
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp Fri Jul 26 17:04:47 2013
@@ -40,6 +40,14 @@
//! Upper bound on the distance to the nearest neighbor of any point in this
//! node.
double maxNeighborDistance;
+
+ //! Smallest current candidate neighbor distance over the points in this
+ //! node (each point contributes its component's candidate distance).
+ double minNeighborDistance;
+
+ //! Pruning bound: min(max bound, min bound + 2 * furthest descendant dist).
+ double bound;
+
//! The index of the component that all points in this node belong to. This
//! is the same index returned by UnionFind for all points in this node. If
//! points in this node are in different components, this value will be
@@ -68,6 +76,16 @@
//! Modify the maximum neighbor distance.
double& MaxNeighborDistance() { return maxNeighborDistance; }
+ //! Get the minimum neighbor distance.
+ double MinNeighborDistance() const { return minNeighborDistance; }
+ //! Modify the minimum neighbor distance.
+ double& MinNeighborDistance() { return minNeighborDistance; }
+
+ //! Get the pruning bound.
+ double Bound() const { return bound; }
+ //! Modify the pruning bound.
+ double& Bound() { return bound; }
+
//! Get the component membership of this node.
int ComponentMembership() const { return componentMembership; }
//! Modify the component membership of this node.
@@ -114,7 +132,7 @@
* @tparam TreeType Type of tree to use. Should use DTBStat as a statistic.
*/
template<
- typename MetricType = metric::SquaredEuclideanDistance,
+ typename MetricType = metric::EuclideanDistance,
typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
>
class DualTreeBoruvka
@@ -151,7 +169,7 @@
//! Total distance of the tree.
double totalDist;
- //! The metric
+ //! The instantiated metric.
MetricType metric;
// For sorting the edge list after the computation.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp Fri Jul 26 17:04:47 2013
@@ -4,7 +4,6 @@
*
* Implementation of DTB.
*/
-
#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
@@ -18,7 +17,11 @@
/**
* A generic initializer.
*/
-DTBStat::DTBStat() : maxNeighborDistance(DBL_MAX), componentMembership(-1)
+DTBStat::DTBStat() :
+ maxNeighborDistance(DBL_MAX),
+ minNeighborDistance(DBL_MAX),
+ bound(DBL_MAX),
+ componentMembership(-1)
{
// Nothing to do.
}
@@ -29,6 +32,8 @@
template<typename TreeType>
DTBStat::DTBStat(const TreeType& node) :
maxNeighborDistance(DBL_MAX),
+ minNeighborDistance(DBL_MAX),
+ bound(DBL_MAX),
componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
node.Point(0) : -1)
{
@@ -124,7 +129,6 @@
while (edges.size() < (data.n_cols - 1))
{
-
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*tree, *tree);
@@ -133,7 +137,7 @@
Cleanup();
- Log::Info << edges.size() << " edges found so far.\n";
+ Log::Info << edges.size() << " edges found so far." << std::endl;
}
Timer::Stop("emst/mst_computation");
@@ -175,7 +179,5 @@
{
- //totalDist = totalDist + dist;
- // changed to make this agree with the cover tree code
- totalDist += sqrt(neighborsDistances[component]);
+ totalDist += neighborsDistances[component];
AddEdge(inEdge, outEdge, neighborsDistances[component]);
connections.Union(inEdge, outEdge);
}
@@ -217,7 +219,7 @@
results(0, i) = edges[i].Lesser();
results(1, i) = edges[i].Greater();
- results(2, i) = sqrt(edges[i].Distance());
+ results(2, i) = edges[i].Distance();
}
}
else
@@ -226,7 +228,7 @@
{
results(0, i) = edges[i].Lesser();
results(1, i) = edges[i].Greater();
- results(2, i) = sqrt(edges[i].Distance());
+ results(2, i) = edges[i].Distance();
}
}
} // EmitResults
@@ -239,6 +241,8 @@
void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
{
tree->Stat().MaxNeighborDistance() = DBL_MAX;
+ tree->Stat().MinNeighborDistance() = DBL_MAX;
+ tree->Stat().Bound() = DBL_MAX;
if (!tree->IsLeaf())
{
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp Fri Jul 26 17:04:47 2013
@@ -25,9 +25,6 @@
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
- // Update bounds. Needs a better name.
- void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
/**
* Get the score for recursion order. A low score indicates priority for
* recursion, while DBL_MAX indicates that the node should not be recursed
@@ -124,8 +121,14 @@
//! of the candidate edge.
arma::Col<size_t>& neighborsOutComponent;
- //! The metric
+ //! The instantiated metric.
MetricType& metric;
+
+ /**
+ * Calculate, cache in the node's statistic, and return the pruning bound.
+ */
+ inline double CalculateBound(TreeType& queryNode) const;
+
}; // class DTBRules
} // emst namespace
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp Fri Jul 26 17:04:47 2013
@@ -55,7 +55,6 @@
neighborsDistances[queryComponentIndex] = distance;
neighborsInComponent[queryComponentIndex] = queryIndex;
neighborsOutComponent[queryComponentIndex] = referenceIndex;
-
}
}
@@ -68,34 +67,6 @@
}
template<typename MetricType, typename TreeType>
-void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
- TreeType& queryNode,
- TreeType& /* referenceNode */)
-{
- // Find the worst distance that the children found (including any points), and
- // update the bound accordingly.
- double newUpperBound = 0.0;
-
- // First look through children nodes.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
- newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
- }
-
- // Now look through children points.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- size_t pointComponent = connections.Find(queryNode.Point(i));
- if (newUpperBound < neighborsDistances[pointComponent])
- newUpperBound = neighborsDistances[pointComponent];
- }
-
- // Update the bound in the query's statistic.
- queryNode.Stat().MaxNeighborDistance() = newUpperBound;
-}
-
-template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode)
{
@@ -109,7 +80,6 @@
return DBL_MAX;
const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
const double distance = referenceNode.MinDistance(queryPoint);
// If all the points in the reference node are farther than the candidate
@@ -166,11 +136,11 @@
return DBL_MAX;
const double distance = queryNode.MinDistance(&referenceNode);
+ const double bound = CalculateBound(queryNode);
// If all the points in the reference node are farther than the candidate
// nearest neighbor for all queries in the node, we prune.
- return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
- distance;
+ return (bound < distance) ? DBL_MAX : distance;
}
template<typename MetricType, typename TreeType>
@@ -185,13 +155,12 @@
referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
- const double distance = queryNode.MinDistance(referenceNode,
- baseCaseResult);
+ const double distance = queryNode.MinDistance(referenceNode, baseCaseResult);
+ const double bound = CalculateBound(queryNode);
// If all the points in the reference node are farther than the candidate
// nearest neighbor for all queries in the node, we prune.
- return (queryNode.Stat().MaxNeighborDistance() < distance) ? DBL_MAX :
- distance;
+ return (bound < distance) ? DBL_MAX : distance;
}
template<typename MetricType, typename TreeType>
@@ -199,12 +168,63 @@
TreeType& /* referenceNode */,
const double oldScore) const
{
- return (oldScore > queryNode.Stat().MaxNeighborDistance()) ? DBL_MAX :
- oldScore;
+ const double bound = CalculateBound(queryNode);
+ return (oldScore > bound) ? DBL_MAX : oldScore;
+}
+
+// Calculate the pruning bound for a query node in its current state, store it
+// and the max/min neighbor distances in the node's statistic, and return it.
+template<typename MetricType, typename TreeType>
+inline double DTBRules<MetricType, TreeType>::CalculateBound(
+ TreeType& queryNode) const
+{
+ double worstPointBound = -DBL_MAX;
+ double bestPointBound = DBL_MAX;
+
+ double worstChildBound = -DBL_MAX;
+ double bestChildBound = DBL_MAX;
+
+ // Find the best and worst bounds among the points held by this node.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ const size_t pointComponent = connections.Find(queryNode.Point(i));
+ const double bound = neighborsDistances[pointComponent];
+
+ if (bound > worstPointBound)
+ worstPointBound = bound;
+ if (bound < bestPointBound)
+ bestPointBound = bound;
+ }
+
+ // Find the best and worst child bounds.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
+ if (maxBound > worstChildBound)
+ worstChildBound = maxBound;
+
+ const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
+ if (minBound < bestChildBound)
+ bestChildBound = minBound;
+ }
+
+ // Now calculate the actual bounds.
+ const double worstBound = std::max(worstPointBound, worstChildBound);
+ const double bestBound = std::min(bestPointBound, bestChildBound);
+ // We must check that bestBound != DBL_MAX; otherwise, we risk overflow.
+ const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
+ bestBound + 2 * queryNode.FurthestDescendantDistance();
+
+ // Update the relevant quantities in the node.
+ queryNode.Stat().MaxNeighborDistance() = worstBound;
+ queryNode.Stat().MinNeighborDistance() = bestBound;
+ queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);
+
+ return queryNode.Stat().Bound();
}
} // namespace emst
} // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp Fri Jul 26 17:04:47 2013
@@ -44,60 +44,54 @@
using namespace mlpack;
using namespace mlpack::emst;
using namespace mlpack::tree;
+using namespace std;
int main(int argc, char* argv[])
{
CLI::ParseCommandLine(argc, argv);
- ///////////////// READ IN DATA //////////////////////////////////
- std::string dataFilename = CLI::GetParam<std::string>("input_file");
-
- Log::Info << "Reading in data.\n";
+ const string dataFilename = CLI::GetParam<string>("input_file");
arma::mat dataPoints;
data::Load(dataFilename, dataPoints, true);
- // Do naive.
+ // Do naive computation if necessary.
if (CLI::GetParam<bool>("naive"))
{
- Log::Info << "Running naive algorithm.\n";
+ Log::Info << "Running naive algorithm." << endl;
DualTreeBoruvka<> naive(dataPoints, true);
arma::mat naiveResults;
naive.ComputeMST(naiveResults);
- std::string outputFilename = CLI::GetParam<std::string>("output_file");
+ const string outputFilename = CLI::GetParam<string>("output_file");
data::Save(outputFilename, naiveResults, true);
}
else
{
- Log::Info << "Data read, building tree.\n";
+ Log::Info << "Building tree." << endl;
- /////////////// Initialize DTB //////////////////////
+ // Check that the leaf size is reasonable.
if (CLI::GetParam<int>("leaf_size") <= 0)
{
Log::Fatal << "Invalid leaf size (" << CLI::GetParam<int>("leaf_size")
<< ")! Must be greater than or equal to 1." << std::endl;
}
- size_t leafSize = CLI::GetParam<int>("leaf_size");
-
+ // Initialize the tree and get ready to compute the MST.
+ const size_t leafSize = static_cast<size_t>(CLI::GetParam<int>("leaf_size"));
DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
- Log::Info << "Tree built, running algorithm.\n";
-
- ////////////// Run DTB /////////////////////
+ // Run the DTB algorithm.
+ Log::Info << "Calculating minimum spanning tree." << endl;
arma::mat results;
-
dtb.ComputeMST(results);
- //////////////// Output the Results ////////////////
- std::string outputFilename = CLI::GetParam<std::string>("output_file");
+ // Output the results.
+ const string outputFilename = CLI::GetParam<string>("output_file");
data::Save(outputFilename, results, true);
}
-
- return 0;
}
More information about the mlpack-svn
mailing list