[mlpack-svn] r15787 - in mlpack/trunk/src/mlpack/methods/neighbor_search: . sort_policies

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Sep 16 20:06:00 EDT 2013


Author: rcurtin
Date: Mon Sep 16 20:05:59 2013
New Revision: 15787

Log:
Overhaul NeighborSearch so that it only needs one overload of Score() and does
not pass around the base case.  This fixes the failing tests with cover trees.
Also, QueryStat is split out into its own file, neighbor_search_stat.hpp, and
renamed NeighborSearchStat.


Added:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp
      - copied, changed from r15759, /mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/CMakeLists.txt	Mon Sep 16 20:05:59 2013
@@ -5,6 +5,7 @@
   neighbor_search_impl.hpp
   neighbor_search_rules.hpp
   neighbor_search_rules_impl.hpp
+  neighbor_search_stat.hpp
   sort_policies/nearest_neighbor_sort.hpp
   sort_policies/nearest_neighbor_sort.cpp
   sort_policies/nearest_neighbor_sort_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	Mon Sep 16 20:05:59 2013
@@ -116,9 +116,11 @@
   Log::Info << "Building reference tree..." << endl;
   Timer::Start("reference_tree_building");
 
-  BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
+  BinarySpaceTree<bound::HRectBound<2>,
+      NeighborSearchStat<FurthestNeighborSort> >
       refTree(referenceData, oldFromNewRefs, leafSize);
-  BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >*
+  BinarySpaceTree<bound::HRectBound<2>,
+      NeighborSearchStat<FurthestNeighborSort> >*
       queryTree = NULL; // Empty for now.
 
   Timer::Stop("reference_tree_building");
@@ -144,7 +146,7 @@
     Timer::Start("query_tree_building");
 
     queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-        QueryStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
+        NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
         leafSize);
 
     Timer::Stop("query_tree_building");

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	Mon Sep 16 20:05:59 2013
@@ -179,9 +179,11 @@
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
 
-    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+    BinarySpaceTree<bound::HRectBound<2>,
+        NeighborSearchStat<NearestNeighborSort> >
         refTree(referenceData, oldFromNewRefs, leafSize);
-    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
+    BinarySpaceTree<bound::HRectBound<2>,
+        NeighborSearchStat<NearestNeighborSort> >*
         queryTree = NULL; // Empty for now.
 
     Timer::Stop("tree_building");
@@ -205,8 +207,8 @@
         Timer::Start("tree_building");
 
         queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-            QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
-            leafSize);
+            NeighborSearchStat<NearestNeighborSort> >(queryData,
+            oldFromNewQueries, leafSize);
 
         Timer::Stop("tree_building");
       }
@@ -260,14 +262,15 @@
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
     CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        QueryStat<NearestNeighborSort> > referenceTree(referenceData, 1.3);
+        NeighborSearchStat<NearestNeighborSort> > referenceTree(referenceData,
+        1.3);
     CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        QueryStat<NearestNeighborSort> >* queryTree = NULL;
+        NeighborSearchStat<NearestNeighborSort> >* queryTree = NULL;
     Timer::Stop("tree_building");
 
     NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
         CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        QueryStat<NearestNeighborSort> > >* allknn = NULL;
+        NeighborSearchStat<NearestNeighborSort> > >* allknn = NULL;
 
     // See if we have query data.
     if (CLI::HasParam("query_file"))
@@ -278,22 +281,22 @@
         Log::Info << "Building query tree..." << endl;
         Timer::Start("tree_building");
         queryTree = new CoverTree<metric::LMetric<2, true>,
-            tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData,
-            1.3);
+            tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> >(
+            queryData, 1.3);
         Timer::Stop("tree_building");
       }
 
       allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
           CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-          QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
+          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree, queryTree,
           referenceData, queryData, singleMode);
     }
     else
     {
       allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
           CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-          QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
-          singleMode);
+          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree,
+          referenceData, singleMode);
     }
 
     Log::Info << "Computing " << k << " nearest neighbors..." << endl;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	Mon Sep 16 20:05:59 2013
@@ -15,6 +15,7 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 
 #include <mlpack/core/metrics/lmetric.hpp>
+#include "neighbor_search_stat.hpp"
 #include "sort_policies/nearest_neighbor_sort.hpp"
 
 namespace mlpack {
@@ -23,59 +24,6 @@
                     * searches. */ {
 
 /**
- * Extra data for each node in the tree.  For neighbor searches, each node only
- * needs to store a bound on neighbor distances.
- */
-template<typename SortPolicy>
-class QueryStat
-{
- private:
-  //! The first bound on the node's neighbor distances (B_1).  This represents
-  //! the worst candidate distance of any descendants of this node.
-  double firstBound;
-  //! The second bound on the node's neighbor distances (B_2).  This represents
-  //! a bound on the worst distance of any descendants of this node assembled
-  //! using the best descendant candidate distance modified by the furthest
-  //! descendant distance.
-  double secondBound;
-  //! The better of the two bounds.
-  double bound;
-
- public:
-  /**
-   * Initialize the statistic with the worst possible distance according to
-   * our sorting policy.
-   */
-  QueryStat() :
-      firstBound(SortPolicy::WorstDistance()),
-      secondBound(SortPolicy::WorstDistance()),
-      bound(SortPolicy::WorstDistance()) { }
-
-  /**
-   * Initialization for a fully initialized node.  In this case, we don't need
-   * to worry about the node.
-   */
-  template<typename TreeType>
-  QueryStat(TreeType& /* node */) :
-      firstBound(SortPolicy::WorstDistance()),
-      secondBound(SortPolicy::WorstDistance()),
-      bound(SortPolicy::WorstDistance()) { }
-
-  //! Get the first bound.
-  double FirstBound() const { return firstBound; }
-  //! Modify the first bound.
-  double& FirstBound() { return firstBound; }
-  //! Get the second bound.
-  double SecondBound() const { return secondBound; }
-  //! Modify the second bound.
-  double& SecondBound() { return secondBound; }
-  //! Get the overall bound (the better of the two bounds).
-  double Bound() const { return bound; }
-  //! Modify the overall bound (it should be the better of the two bounds).
-  double& Bound() { return bound; }
-};
-
-/**
  * The NeighborSearch class is a template class for performing distance-based
  * neighbor searches.  It takes a query dataset and a reference dataset (or just
  * a reference dataset) and, for each point in the query dataset, finds the k
@@ -96,7 +44,7 @@
 template<typename SortPolicy = NearestNeighborSort,
          typename MetricType = mlpack::metric::SquaredEuclideanDistance,
          typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-                                                   QueryStat<SortPolicy> > >
+             NeighborSearchStat<SortPolicy> > >
 class NeighborSearch
 {
  public:
@@ -262,10 +210,10 @@
   //! Pointer to the root of the query tree (might not exist).
   TreeType* queryTree;
 
-  //! Indicates if we should free the reference tree at deletion time.
-  bool ownReferenceTree;
-  //! Indicates if we should free the query tree at deletion time.
-  bool ownQueryTree;
+  //! If true, this object created the trees and is responsible for them.
+  bool treeOwner;
+  //! Indicates if a separate query set was passed.
+  bool hasQuerySet;
 
   //! Indicates if O(n^2) naive search is being used.
   bool naive;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	Mon Sep 16 20:05:59 2013
@@ -29,8 +29,8 @@
     querySet(queryCopy),
     referenceTree(NULL),
     queryTree(NULL),
-    ownReferenceTree(true), // False if a tree was passed.
-    ownQueryTree(true), // False if a tree was passed.
+    treeOwner(true), // False if a tree was passed.
+    hasQuerySet(true),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
@@ -40,19 +40,18 @@
   // copypasta problem.
 
   // We'll time tree building, but only if we are building trees.
-  if (!referenceTree || !queryTree)
-    Timer::Start("tree_building");
+  Timer::Start("tree_building");
 
   // Construct as a naive object if we need to.
   referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
       (naive ? referenceCopy.n_cols : leafSize));
 
-  queryTree = new TreeType(queryCopy, oldFromNewQueries,
-      (naive ? querySet.n_cols : leafSize));
+  if (!singleMode)
+    queryTree = new TreeType(queryCopy, oldFromNewQueries,
+        (naive ? querySet.n_cols : leafSize));
 
   // Stop the timer we started above (if we need to).
-  if (!referenceTree || !queryTree)
-    Timer::Stop("tree_building");
+  Timer::Stop("tree_building");
 }
 
 // Construct the object.
@@ -68,8 +67,8 @@
     querySet(referenceCopy),
     referenceTree(NULL),
     queryTree(NULL),
-    ownReferenceTree(true),
-    ownQueryTree(false), // Since it will be the same as referenceTree.
+    treeOwner(true),
+    hasQuerySet(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
@@ -81,6 +80,8 @@
   // Construct as a naive object if we need to.
   referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
       (naive ? referenceSet.n_cols : leafSize));
+  if (!singleMode)
+    queryTree = new TreeType(*referenceTree);
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -99,8 +100,8 @@
     querySet(querySet),
     referenceTree(referenceTree),
     queryTree(queryTree),
-    ownReferenceTree(false),
-    ownQueryTree(false),
+    treeOwner(false),
+    hasQuerySet(true),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -120,14 +121,20 @@
     querySet(referenceSet),
     referenceTree(referenceTree),
     queryTree(NULL),
-    ownReferenceTree(false),
-    ownQueryTree(false),
+    treeOwner(false),
+    hasQuerySet(false), // In this case we will own a tree, if !singleMode.
     naive(false),
     singleMode(singleMode),
     metric(metric),
     numberOfPrunes(0)
 {
-  // Nothing else to initialize.
+  Timer::Start("tree_building");
+
+  // The query tree cannot be the same as the reference tree.
+  if (referenceTree && !singleMode)
+    queryTree = new TreeType(*referenceTree);
+
+  Timer::Stop("tree_building");
 }
 
 /**
@@ -137,10 +144,18 @@
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
 {
-  if (ownReferenceTree)
-    delete referenceTree;
-  if (ownQueryTree)
+  if (treeOwner)
+  {
+    if (referenceTree)
+      delete referenceTree;
+    if (queryTree)
+      delete queryTree;
+  }
+  else if (!treeOwner && !hasQuerySet && !singleMode)
+  {
+    // We replicated the reference tree to create a query tree.
     delete queryTree;
+  }
 }
 
 /**
@@ -162,9 +177,9 @@
   arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
   arma::mat* distancePtr = &distances;
 
-  if (ownQueryTree || (ownReferenceTree && !queryTree))
+  if (treeOwner && !(singleMode && hasQuerySet))
     distancePtr = new arma::mat; // Query indices need to be mapped.
-  if (ownReferenceTree || ownQueryTree)
+  if (treeOwner)
     neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
 
   // Set the size of the neighbor and distance matrices.
@@ -186,8 +201,6 @@
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
-
-    numPrunes = traverser.NumPrunes();
   }
   else // Dual-tree recursion.
   {
@@ -197,25 +210,22 @@
 
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
-    if (queryTree)
-      traverser.Traverse(*queryTree, *referenceTree);
-    else
-      traverser.Traverse(*referenceTree, *referenceTree);
+    traverser.Traverse(*queryTree, *referenceTree);
 
-    numPrunes = traverser.NumPrunes();
+    Log::Info << traverser.NumVisited() << " node combinations were visited.\n";
+    Log::Info << traverser.NumScores() << " node combinations were scored.\n";
+    Log::Info << traverser.NumBaseCases() << " base cases were calculated.\n";
   }
 
-  Log::Debug << "Pruned " << numPrunes << " nodes." << std::endl;
-
   Timer::Stop("computing_neighbors");
 
   // Now, do we need to do mapping of indices?
-  if (!ownReferenceTree && !ownQueryTree)
+  if (!treeOwner)
   {
     // No mapping needed.  We are done.
     return;
   }
-  else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+  else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
   {
     // Set size of output matrices correctly.
     resultingNeighbors.set_size(k, querySet.n_cols);
@@ -238,62 +248,40 @@
     delete neighborPtr;
     delete distancePtr;
   }
-  else if (ownReferenceTree)
+  else if (treeOwner && !hasQuerySet)
   {
-    if (!queryTree) // No query tree -- map both references and queries.
-    {
-      resultingNeighbors.set_size(k, querySet.n_cols);
-      distances.set_size(k, querySet.n_cols);
-
-      for (size_t i = 0; i < distances.n_cols; i++)
-      {
-        // Map distances (copy a column).
-        distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+    resultingNeighbors.set_size(k, querySet.n_cols);
+    distances.set_size(k, querySet.n_cols);
 
-        // Map indices of neighbors.
-        for (size_t j = 0; j < distances.n_rows; j++)
-        {
-          resultingNeighbors(j, oldFromNewReferences[i]) =
-              oldFromNewReferences[(*neighborPtr)(j, i)];
-        }
-      }
-    }
-    else // Map only references.
+    for (size_t i = 0; i < distances.n_cols; i++)
     {
-      // Set size of neighbor indices matrix correctly.
-      resultingNeighbors.set_size(k, querySet.n_cols);
+      // Map distances (copy a column).
+      distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
 
       // Map indices of neighbors.
-      for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+      for (size_t j = 0; j < distances.n_rows; j++)
       {
-        for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
-        {
-          resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
-        }
+        resultingNeighbors(j, oldFromNewReferences[i]) =
+            oldFromNewReferences[(*neighborPtr)(j, i)];
       }
     }
-
-    // Finished with temporary matrix.
-    delete neighborPtr;
   }
-  else if (ownQueryTree)
+  else if (treeOwner && hasQuerySet && singleMode) // Map only references.
   {
-    // Set size of matrices correctly.
+    // Set size of neighbor indices matrix correctly.
     resultingNeighbors.set_size(k, querySet.n_cols);
-    distances.set_size(k, querySet.n_cols);
 
-    for (size_t i = 0; i < distances.n_cols; i++)
+    // Map indices of neighbors.
+    for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
     {
-      // Map distances (copy a column).
-      distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
-      // Map indices of neighbors.
-      resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+      for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+      {
+        resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+      }
     }
 
-    // Finished with temporary matrices.
+    // Finished with temporary matrix.
     delete neighborPtr;
-    delete distancePtr;
   }
 } // Search
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	Mon Sep 16 20:05:59 2013
@@ -24,24 +24,6 @@
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
   /**
-   * Get the score for the recursion order, in general before the base case is
-   * computed.  This is useful for cover trees or other trees that can cache
-   * some statistic that could be used to make a prune of a child before its
-   * base case is computed.
-   *
-   * @param queryNode Query node.
-   * @param referenceNode Reference node.
-   */
-  double Prescore(TreeType& queryNode,
-                  TreeType& referenceNode,
-                  TreeType& referenceChildNode,
-                  const double baseCaseResult) const;
-  double PrescoreQ(TreeType& queryNode,
-                   TreeType& queryChildNode,
-                   TreeType& referenceNode,
-                   const double baseCaseResult) const;
-
-  /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).
@@ -49,21 +31,7 @@
    * @param queryIndex Index of query point.
    * @param referenceNode Candidate node to be recursed into.
    */
-  double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
-  /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
-   */
-  double Score(const size_t queryIndex,
-               TreeType& referenceNode,
-               const double baseCaseResult) const;
+  double Score(const size_t queryIndex, TreeType& referenceNode);
 
   /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
@@ -88,21 +56,7 @@
    * @param queryNode Candidate query node to recurse into.
    * @param referenceNode Candidate reference node to recurse into.
    */
-  double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
-  /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryNode Candidate query node to recurse into.
-   * @param referenceNode Candidate reference node to recurse into.
-   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
-   */
-  double Score(TreeType& queryNode,
-               TreeType& referenceNode,
-               const double baseCaseResult) const;
+  double Score(TreeType& queryNode, TreeType& referenceNode);
 
   /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
@@ -135,6 +89,13 @@
   //! The instantiated metric.
   MetricType& metric;
 
+  //! The last query point BaseCase() was called with.
+  size_t lastQueryIndex;
+  //! The last reference point BaseCase() was called with.
+  size_t lastReferenceIndex;
+  //! The last base case result.
+  double lastBaseCase;
+
   /**
    * Recalculate the bound for a given query node.
    */

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	Mon Sep 16 20:05:59 2013
@@ -24,7 +24,9 @@
     querySet(querySet),
     neighbors(neighbors),
     distances(distances),
-    metric(metric)
+    metric(metric),
+    lastQueryIndex(querySet.n_cols),
+    lastReferenceIndex(referenceSet.n_cols)
 { /* Nothing left to do. */ }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -37,6 +39,10 @@
   if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
     return 0.0;
 
+  // If we have already performed this base case, then do not perform it again.
+  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
+    return lastBaseCase;
+
   double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
                                     referenceSet.unsafe_col(referenceIndex));
 
@@ -49,63 +55,49 @@
   if (insertPosition != (size_t() - 1))
     InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
 
-  return distance;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Prescore(
-    TreeType& queryNode,
-    TreeType& referenceNode,
-    TreeType& referenceChildNode,
-    const double baseCaseResult) const
-{
-  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
-      &referenceNode, &referenceChildNode, baseCaseResult);
-
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
-
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::PrescoreQ(
-    TreeType& queryNode,
-    TreeType& queryChildNode,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
-{
-  const double distance = SortPolicy::BestNodeToNodeDistance(&referenceNode,
-      &queryNode, &queryChildNode, baseCaseResult);
-
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  // Cache this information for the next time BaseCase() is called.
+  lastQueryIndex = queryIndex;
+  lastReferenceIndex = referenceIndex;
+  lastBaseCase = distance;
 
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+  return distance;
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     const size_t queryIndex,
-    TreeType& referenceNode) const
+    TreeType& referenceNode)
 {
-  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-  const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
-      &referenceNode);
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  double distance;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    // The first point in the tree is the centroid.  So we can then calculate
+    // the base case between that and the query point.
+    double baseCase;
+    if (tree::TreeTraits<TreeType>::HasSelfChildren)
+    {
+      // If the parent node holds the same point (this node is a self-child),
+      // then we have already calculated the base case.
+      if ((referenceNode.Parent() != NULL) &&
+          (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
+        baseCase = referenceNode.Parent()->Stat().LastDistance();
+      else
+        baseCase = BaseCase(queryIndex, referenceNode.Point(0));
+
+      // Save this evaluation.
+      referenceNode.Stat().LastDistance() = baseCase;
+    }
 
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
+    distance = SortPolicy::CombineBest(baseCase,
+        referenceNode.FurthestDescendantDistance());
+  }
+  else
+  {
+    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+    distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode);
+  }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
-    const size_t queryIndex,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
-{
-  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-  const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
-      &referenceNode, baseCaseResult);
+  // Compare against the best k'th distance for this query point so far.
   const double bestDistance = distances(distances.n_rows - 1, queryIndex);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -130,25 +122,95 @@
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     TreeType& queryNode,
-    TreeType& referenceNode) const
-{
-  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
-      &referenceNode);
-
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
-
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
-    TreeType& queryNode,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
+    TreeType& referenceNode)
 {
-  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
-      &referenceNode, baseCaseResult);
+  double distance;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    // The first point in the node is the centroid, so we can calculate the
+    // distance between the two points using BaseCase() and then find the
+    // bounds.  This is potentially loose for non-ball bounds.
+    bool alreadyDone = false;
+    double baseCase;
+    if (tree::TreeTraits<TreeType>::HasSelfChildren)
+    {
+      // In this case, we may have already calculated the base case.
+      TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();
+      TreeType* lastQuery = (TreeType*) referenceNode.Stat().LastDistanceNode();
+
+      // Does the query node have the base case cached?
+      if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
+      {
+        baseCase = queryNode.Stat().LastDistance();
+        alreadyDone = true;
+      }
+
+      // Does the reference node have the base case cached?
+      if ((lastQuery != NULL) &&
+          (queryNode.Point(0) == lastQuery->Point(0)))
+      {
+        baseCase = queryNode.Stat().LastDistance();
+        alreadyDone = true;
+      }
+
+      // Is the query node a self-child, and if so, does the query node's parent
+      // have the base case cached?
+      if ((queryNode.Parent() != NULL) &&
+          (queryNode.Parent()->Point(0) == queryNode.Point(0)))
+      {
+        TreeType* lastParentRef = (TreeType*)
+            queryNode.Parent()->Stat().LastDistanceNode();
+        if (lastParentRef->Point(0) == referenceNode.Point(0))
+        {
+          baseCase = queryNode.Parent()->Stat().LastDistance();
+          alreadyDone = true;
+        }
+      }
+
+      // Is the reference node a self-child, and if so, does the reference
+      // node's parent have the base case cached?
+      if ((referenceNode.Parent() != NULL) &&
+          (referenceNode.Parent()->Point(0) == referenceNode.Point(0)))
+      {
+        TreeType* lastParentRef = (TreeType*)
+            referenceNode.Parent()->Stat().LastDistanceNode();
+        if (lastParentRef->Point(0) == queryNode.Point(0))
+        {
+          baseCase = referenceNode.Parent()->Stat().LastDistance();
+          alreadyDone = true;
+        }
+      }
+    }
+
+    // If we did not find a cached base case, then recalculate it.
+    if (!alreadyDone)
+    {
+      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+    }
+    else
+    {
+      // Set lastQueryIndex and lastReferenceIndex, so that BaseCase() does not
+      // duplicate work.
+      lastQueryIndex = queryNode.Point(0);
+      lastReferenceIndex = referenceNode.Point(0);
+      lastBaseCase = baseCase;
+    }
+
+//    distance = SortPolicy::CombineBest(baseCase,
+//        queryNode.FurthestDescendantDistance() +
+//        referenceNode.FurthestDescendantDistance());
+    distance = 0;
+
+    // Update the last distance calculation for the query and reference nodes.
+    queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
+    queryNode.Stat().LastDistance() = baseCase;
+    referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
+    referenceNode.Stat().LastDistance() = baseCase;
+  }
+  else
+  {
+    distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
+  }
 
   // Update our bound.
   const double bestDistance = CalculateBound(queryNode);

Copied: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp (from r15759, /mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp)
==============================================================================
--- /mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp	Mon Sep 16 20:05:59 2013
@@ -5,29 +5,20 @@
  * Defines the NeighborSearch class, which performs an abstract
  * nearest-neighbor-like query on two datasets.
  */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_STAT_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_STAT_HPP
 
 #include <mlpack/core.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "sort_policies/nearest_neighbor_sort.hpp"
 
 namespace mlpack {
-namespace neighbor /** Neighbor-search routines.  These include
-                    * all-nearest-neighbors and all-furthest-neighbors
-                    * searches. */ {
+namespace neighbor {
 
 /**
  * Extra data for each node in the tree.  For neighbor searches, each node only
  * needs to store a bound on neighbor distances.
  */
 template<typename SortPolicy>
-class QueryStat
+class NeighborSearchStat
 {
  private:
   //! The first bound on the node's neighbor distances (B_1).  This represents
@@ -41,25 +32,34 @@
   //! The better of the two bounds.
   double bound;
 
+  //! The node the last distance evaluation was made against (untyped).
+  void* lastDistanceNode;
+  //! The distance computed in the last distance evaluation.
+  double lastDistance;
+
  public:
   /**
    * Initialize the statistic with the worst possible distance according to
    * our sorting policy.
    */
-  QueryStat() :
+  NeighborSearchStat() :
       firstBound(SortPolicy::WorstDistance()),
       secondBound(SortPolicy::WorstDistance()),
-      bound(SortPolicy::WorstDistance()) { }
+      bound(SortPolicy::WorstDistance()),
+      lastDistanceNode(NULL),
+      lastDistance(0.0) { }
 
   /**
    * Initialization for a fully initialized node.  In this case, we don't need
    * to worry about the node.
    */
   template<typename TreeType>
-  QueryStat(TreeType& /* node */) :
+  NeighborSearchStat(TreeType& /* node */) :
       firstBound(SortPolicy::WorstDistance()),
       secondBound(SortPolicy::WorstDistance()),
-      bound(SortPolicy::WorstDistance()) { }
+      bound(SortPolicy::WorstDistance()),
+      lastDistanceNode(NULL),
+      lastDistance(0.0) { }
 
   //! Get the first bound.
   double FirstBound() const { return firstBound; }
@@ -73,224 +73,17 @@
   double Bound() const { return bound; }
   //! Modify the overall bound (it should be the better of the two bounds).
   double& Bound() { return bound; }
+  //! Get the last distance evaluation node.
+  void* LastDistanceNode() const { return lastDistanceNode; }
+  //! Modify the last distance evaluation node.
+  void*& LastDistanceNode() { return lastDistanceNode; }
+  //! Get the last distance calculation.
+  double LastDistance() const { return lastDistance; }
+  //! Modify the last distance calculation.
+  double& LastDistance() { return lastDistance; }
 };
 
-/**
- * The NeighborSearch class is a template class for performing distance-based
- * neighbor searches.  It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy.  A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
- *
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used.  More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use.
- */
-template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::SquaredEuclideanDistance,
-         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-                                                   QueryStat<SortPolicy> > >
-class NeighborSearch
-{
- public:
-  /**
-   * Initialize the NeighborSearch object, passing both a query and reference
-   * dataset.  Optionally, perform the computation in naive mode or single-tree
-   * mode, and set the leaf size used for tree-building.  An initialized
-   * distance metric can be given, for cases where the metric has internal data
-   * (i.e. the distance::MahalanobisDistance class).
-   *
-   * This method will copy the matrices to internal copies, which are rearranged
-   * during tree-building.  You can avoid this extra copy by pre-constructing
-   * the trees and passing them using a diferent constructor.
-   *
-   * @param referenceSet Set of reference points.
-   * @param querySet Set of query points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
-   * @param metric An optional instance of the MetricType class.
-   */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
-                 const typename TreeType::Mat& querySet,
-                 const bool naive = false,
-                 const bool singleMode = false,
-                 const size_t leafSize = 20,
-                 const MetricType metric = MetricType());
-
-  /**
-   * Initialize the NeighborSearch object, passing only one dataset, which is
-   * used as both the query and the reference dataset.  Optionally, perform the
-   * computation in naive mode or single-tree mode, and set the leaf size used
-   * for tree-building.  An initialized distance metric can be given, for cases
-   * where the metric has internal data (i.e. the distance::MahalanobisDistance
-   * class).
-   *
-   * If naive mode is being used and a pre-built tree is given, it may not work:
-   * naive mode operates by building a one-node tree (the root node holds all
-   * the points).  If that condition is not satisfied with the pre-built tree,
-   * then naive mode will not work.
-   *
-   * @param referenceSet Set of reference points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
-   * @param metric An optional instance of the MetricType class.
-   */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
-                 const bool naive = false,
-                 const bool singleMode = false,
-                 const size_t leafSize = 20,
-                 const MetricType metric = MetricType());
-
-  /**
-   * Initialize the NeighborSearch object with the given datasets and
-   * pre-constructed trees.  It is assumed that the points in referenceSet and
-   * querySet correspond to the points in referenceTree and queryTree,
-   * respectively.  Optionally, choose to use single-tree mode.  Naive mode is
-   * not available as an option for this constructor; instead, to run naive
-   * computation, construct a tree with all of the points in one leaf (i.e.
-   * leafSize = number of points).  Additionally, an instantiated distance
-   * metric can be given, for cases where the distance metric holds data.
-   *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
-   *
-   * @note
-   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
-   * of a matrix, be sure you pass the modified matrix to this object!  In
-   * addition, mapping the points of the matrix back to their original indices
-   * is not done when this constructor is used.
-   * @endnote
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param queryTree Pre-built tree for query points.
-   * @param referenceSet Set of reference points corresponding to referenceTree.
-   * @param querySet Set of query points corresponding to queryTree.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param metric Instantiated distance metric.
-   */
-  NeighborSearch(TreeType* referenceTree,
-                 TreeType* queryTree,
-                 const typename TreeType::Mat& referenceSet,
-                 const typename TreeType::Mat& querySet,
-                 const bool singleMode = false,
-                 const MetricType metric = MetricType());
-
-  /**
-   * Initialize the NeighborSearch object with the given reference dataset and
-   * pre-constructed tree.  It is assumed that the points in referenceSet
-   * correspond to the points in referenceTree.  Optionally, choose to use
-   * single-tree mode.  Naive mode is not available as an option for this
-   * constructor; instead, to run naive computation, construct a tree with all
-   * the points in one leaf (i.e. leafSize = number of points).  Additionally,
-   * an instantiated distance metric can be given, for the case where the
-   * distance metric holds data.
-   *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
-   *
-   * @note
-   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
-   * of a matrix, be sure you pass the modified matrix to this object!  In
-   * addition, mapping the points of the matrix back to their original indices
-   * is not done when this constructor is used.
-   * @endnote
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param referenceSet Set of reference points corresponding to referenceTree.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param metric Instantiated distance metric.
-   */
-  NeighborSearch(TreeType* referenceTree,
-                 const typename TreeType::Mat& referenceSet,
-                 const bool singleMode = false,
-                 const MetricType metric = MetricType());
-
-
-  /**
-   * Delete the NeighborSearch object. The tree is the only member we are
-   * responsible for deleting.  The others will take care of themselves.
-   */
-  ~NeighborSearch();
-
-  /**
-   * Compute the nearest neighbors and store the output in the given matrices.
-   * The matrices will be set to the size of n columns by k rows, where n is the
-   * number of points in the query dataset and k is the number of neighbors
-   * being searched for.
-   *
-   * @param k Number of neighbors to search for.
-   * @param resultingNeighbors Matrix storing lists of neighbors for each query
-   *     point.
-   * @param distances Matrix storing distances of neighbors for each query
-   *     point.
-   */
-  void Search(const size_t k,
-              arma::Mat<size_t>& resultingNeighbors,
-              arma::mat& distances);
-
- private:
-  //! Copy of reference dataset (if we need it, because tree building modifies
-  //! it).
-  arma::mat referenceCopy;
-  //! Copy of query dataset (if we need it, because tree building modifies it).
-  arma::mat queryCopy;
-
-  //! Reference dataset.
-  const arma::mat& referenceSet;
-  //! Query dataset (may not be given).
-  const arma::mat& querySet;
-
-  //! Pointer to the root of the reference tree.
-  TreeType* referenceTree;
-  //! Pointer to the root of the query tree (might not exist).
-  TreeType* queryTree;
-
-  //! Indicates if we should free the reference tree at deletion time.
-  bool ownReferenceTree;
-  //! Indicates if we should free the query tree at deletion time.
-  bool ownQueryTree;
-
-  //! Indicates if O(n^2) naive search is being used.
-  bool naive;
-  //! Indicates if single-tree search is being used (opposed to dual-tree).
-  bool singleMode;
-
-  //! Instantiation of metric.
-  MetricType metric;
-
-  //! Permutations of reference points during tree building.
-  std::vector<size_t> oldFromNewReferences;
-  //! Permutations of query points during tree building.
-  std::vector<size_t> oldFromNewQueries;
-
-  //! Total number of pruned nodes during the neighbor search.
-  size_t numberOfPrunes;
-}; // class NeighborSearch
-
 }; // namespace neighbor
 }; // namespace mlpack
 
-// Include implementation.
-#include "neighbor_search_impl.hpp"
-
-// Include convenience typedefs.
-#include "typedef.hpp"
-
 #endif

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp	Mon Sep 16 20:05:59 2013
@@ -129,6 +129,16 @@
   static inline double BestDistance() { return DBL_MAX; }
 
   /**
+   * Return the best (largest) combination of the two distances.
+   */
+  static inline double CombineBest(const double a, const double b)
+  {
+    if (a == DBL_MAX || b == DBL_MAX)
+      return DBL_MAX;
+    return a + b;
+  }
+
+  /**
    * Return the worst combination of the two distances.
    */
   static inline double CombineWorst(const double a, const double b)

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp	Mon Sep 16 20:05:59 2013
@@ -132,13 +132,22 @@
   static inline double BestDistance() { return 0.0; }
 
   /**
+   * Return the best (smallest) combination of the two distances, at least 0.
+   */
+  static inline double CombineBest(const double a, const double b)
+  {
+    return std::max(a - b, 0.0);
+  }
+
+  /**
    * Return the worst combination of the two distances.
    */
   static inline double CombineWorst(const double a, const double b)
   {
     if (a == DBL_MAX || b == DBL_MAX)
       return DBL_MAX;
-    return a + b; }
+    return a + b;
+  }
 };
 
 }; // namespace neighbor



More information about the mlpack-svn mailing list