[mlpack-svn] r13335 - in mlpack/trunk/src/mlpack/methods: emst kmeans maxip neighbor_search range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Aug 4 23:01:42 EDT 2012
Author: rcurtin
Date: 2012-08-04 23:01:42 -0400 (Sat, 04 Aug 2012)
New Revision: 13335
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp
mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
Log:
Change include files and APIs to use the new tree traverser setup.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -23,10 +23,10 @@
#include "edge_pair.hpp"
#include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
namespace mlpack {
namespace emst /** Euclidean Minimum Spanning Trees. */ {
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -13,6 +13,8 @@
#include "random_partition.hpp"
#include "max_variance_new_cluster.hpp"
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
namespace mlpack {
namespace kmeans /** K-Means clustering. */ {
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -7,8 +7,6 @@
*/
#include "kmeans.hpp"
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <mlpack/core/tree/hrectbound.hpp>
#include <mlpack/core/tree/mrkd_statistic.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,7 +9,7 @@
#include <mlpack/core.hpp>
#include "ip_metric.hpp"
-#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
namespace mlpack {
namespace maxip {
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -13,6 +13,7 @@
#include "max_ip_rules.hpp"
#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <queue>
namespace mlpack {
namespace maxip {
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,7 +9,7 @@
#define __MLPACK_METHODS_MAXIP_MAX_IP_RULES_HPP
#include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
namespace mlpack {
namespace maxip {
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,11 +9,11 @@
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
#include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
#include <vector>
#include <string>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
#include <mlpack/core/metrics/lmetric.hpp>
#include "sort_policies/nearest_neighbor_sort.hpp"
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -10,9 +10,6 @@
#include <mlpack/core.hpp>
-#include <mlpack/core/tree/traversers/single_tree_depth_first_traverser.hpp>
-#include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
-#include <mlpack/core/tree/traversers/dual_tree_depth_first_traverser.hpp>
#include "neighbor_search_rules.hpp"
using namespace mlpack::neighbor;
@@ -180,13 +177,11 @@
if (singleMode)
{
// Create the helper object for the tree traversal.
- NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
- querySet, *neighborPtr, *distancePtr, metric);
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
// Create the traverser.
- typename TreeType::template PreferredTraverser<
- NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
- traverser(rules);
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -196,17 +191,12 @@
}
else // Dual-tree recursion.
{
- // Breaking a lot of design rules here...
- typedef typename TreeType::template PreferredRules<SortPolicy, MetricType,
- TreeType>::Type RuleType;
-
+ // Create the helper object for the tree traversal.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
- typedef typename TreeType::template PreferredDualTraverser<RuleType>::Type
- TraverserType;
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
- TraverserType traverser(rules);
-
if (queryTree)
traverser.Traverse(*queryTree, *referenceTree);
else
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -30,11 +30,8 @@
bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
// Get the order of points to recurse to.
- void RecursionOrder(TreeType& queryNode,
- TreeType& referenceNode,
- arma::Mat<size_t>& recursionOrder,
- bool& queryRecurse,
- bool& referenceRecurse);
+ bool LeftFirst(const size_t queryIndex, TreeType& referenceNode);
+ bool LeftFirst(TreeType& staticNode, TreeType& recurseNode);
// Update bounds. Needs a better name.
void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -84,154 +84,34 @@
return true;
}
-// Return the order in which we should recurse.
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline void NeighborSearchRules<
- SortPolicy,
- MetricType,
- TreeType>::
-RecursionOrder(TreeType& queryNode,
- TreeType& referenceNode,
- arma::Mat<size_t>& recursionOrder,
- bool& queryRecurse,
- bool& referenceRecurse)
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
+ const size_t queryIndex,
+ TreeType& referenceNode)
{
- queryRecurse = !(queryNode.IsLeaf());
- referenceRecurse = !(referenceNode.IsLeaf());
+ // This ends up with us calculating this distance twice (it will be done again
+ // in CanPrune()), but because single-neighbors recursion is not the most
+ // important in this method, we can let it slide.
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ referenceNode.Left());
+ const double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ referenceNode.Right());
- if (queryRecurse && !referenceRecurse)
- {
- // We only need to recurse into the query children. Therefore, the elements
- // in row 1 can be ignored.
- recursionOrder.set_size(2, queryNode.NumChildren());
- arma::vec recursionDistances(queryNode.NumChildren());
- recursionDistances.fill(SortPolicy::WorstDistance());
- size_t children = 0; // Number of children to recurse to.
+ return SortPolicy::IsBetter(leftDistance, rightDistance);
+}
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- double distance = SortPolicy::BestNodeToNodeDistance(&queryNode.Child(i),
- &referenceNode);
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
+ TreeType& staticNode,
+ TreeType& recurseNode)
+{
+ const double leftDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
+ recurseNode.Left());
+ const double rightDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
+ recurseNode.Right());
- // Find where to insert.
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < children; ++insertPosition)
- if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
- break;
-
- // Now perform the actual insertion.
- if ((children - insertPosition) > 0)
- {
- memmove(recursionDistances.memptr() + insertPosition + 1,
- recursionDistances.memptr() + insertPosition,
- sizeof(double) * (children - insertPosition));
- memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
- recursionOrder.memptr() + (insertPosition * 2),
- sizeof(size_t) * (children - insertPosition) * 2);
- }
-
- // Insert.
- recursionDistances[insertPosition] = distance;
- recursionOrder(0, insertPosition) = i;
- ++children;
- }
-
- // Strip extra columns.
- if (children < queryNode.NumChildren())
- recursionOrder.shed_cols(children, queryNode.NumChildren() - 1);
- }
- else if (!queryRecurse && referenceRecurse)
- {
- // We only need to recurse into the reference children. Therefore, the
- // elements in row 0 can be ignored.
- recursionOrder.set_size(2, referenceNode.NumChildren());
- arma::vec recursionDistances(referenceNode.NumChildren());
- recursionDistances.fill(SortPolicy::WorstDistance());
- size_t children = 0; // Number of children to recurse into.
-
- for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
- {
- double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode.Child(i));
-
- // Find where to insert.
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < children; ++insertPosition)
- if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
- break;
-
- // Now perform the actual insertion.
- if ((children - insertPosition) > 0)
- {
- memmove(recursionDistances.memptr() + insertPosition + 1,
- recursionDistances.memptr() + insertPosition,
- sizeof(double) * (children - insertPosition));
- memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
- recursionOrder.memptr() + (insertPosition * 2),
- sizeof(size_t) * (children - insertPosition) * 2);
- }
-
- // Insert.
- recursionDistances[insertPosition] = distance;
- recursionOrder(1, insertPosition) = i;
- ++children;
- }
-
- // Strip extra columns.
- if (children < referenceNode.NumChildren())
- recursionOrder.shed_cols(children, referenceNode.NumChildren() - 1);
- }
- else if (queryRecurse && referenceRecurse)
- {
- // We need to recurse into both children.
- const size_t maxChildren = referenceNode.NumChildren() *
- queryNode.NumChildren();
- recursionOrder.set_size(2, maxChildren);
- arma::vec recursionDistances(maxChildren);
- recursionDistances.fill(SortPolicy::WorstDistance());
- size_t children = 0; // Number of children to recurse into.
-
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- // Check if we should even continue this direction.
- if (CanPrune(queryNode.Child(i), referenceNode))
- continue; // Don't go this way.
-
- for (size_t j = 0; j < referenceNode.NumChildren(); ++j)
- {
- double distance = SortPolicy::BestNodeToNodeDistance(
- &queryNode.Child(i), &referenceNode.Child(j));
-
- // Find where to insert.
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < children; ++insertPosition)
- if (SortPolicy::IsBetter(distance,
- recursionDistances[insertPosition]))
- break;
-
- // Move things to prepare for insertion.
- if ((children - insertPosition) > 0)
- {
- memmove(recursionDistances.memptr() + insertPosition + 1,
- recursionDistances.memptr() + insertPosition,
- sizeof(double) * (children - insertPosition));
- memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
- recursionOrder.memptr() + (insertPosition * 2),
- sizeof(size_t) * (children - insertPosition) * 2);
- }
-
- // Insert.
- recursionDistances[insertPosition] = distance;
- recursionOrder(0, insertPosition) = i;
- recursionOrder(1, insertPosition) = j;
- ++children;
- }
- }
-
- // Strip extra columns.
- if (children < maxChildren)
- recursionOrder.shed_cols(children, maxChildren - 1);
- }
+ return SortPolicy::IsBetter(leftDistance, rightDistance);
}
template<typename SortPolicy, typename MetricType, typename TreeType>
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,11 +9,11 @@
#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
#include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
namespace mlpack {
namespace range /** Range-search routines. */ {
More information about the mlpack-svn mailing list