[mlpack-git] master: Add Train() functions to NeighborSearch for consistency. (60c8170)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:25 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit 60c81702c432014544351948c091d03deeef6985
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 25 17:18:34 2015 -0400

    Add Train() functions to NeighborSearch for consistency.
    
    This will make serialization much easier, and also adds more flexibility to the
    class.


>---------------------------------------------------------------

60c81702c432014544351948c091d03deeef6985
 .../methods/neighbor_search/neighbor_search.hpp    |  52 ++++-
 .../neighbor_search/neighbor_search_impl.hpp       | 247 +++++++++++++++++++--
 2 files changed, 265 insertions(+), 34 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 0e56d25..ff01c88 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -116,12 +116,42 @@ class NeighborSearch
                  const MetricType metric = MetricType());
 
   /**
+   * Create a NeighborSearch object without any reference data.  If Search() is
+   * called before a reference set is set with Train(), an exception will be
+   * thrown.
+   *
+   * @param naive Whether to use naive search.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param metric Instantiated metric.
+   */
+  NeighborSearch(const bool naive = false,
+                 const bool singleMode = false,
+                 const MetricType metric = MetricType());
+
+
+  /**
    * Delete the NeighborSearch object. The tree is the only member we are
    * responsible for deleting.  The others will take care of themselves.
    */
   ~NeighborSearch();
 
   /**
+   * Set the reference set to a new reference set, and build a tree if
+   * necessary.  This method is called 'Train()' in order to match the rest of
+   * the mlpack abstractions, even though calling this "training" is maybe a bit
+   * of a stretch.
+   *
+   * @param referenceSet New set of reference data.
+   */
+  void Train(const MatType& referenceSet);
+
+  /**
+   * Set the reference tree to a new reference tree.  The tree is not copied
+   * and this object does not take ownership of it, so the caller must keep it
+   * alive while it is in use here.  The tree's dataset becomes the reference
+   * set.
+   *
+   * @param referenceTree Pre-built tree on the new reference data.
+   * @throws std::invalid_argument if naive search is enabled.
+   */
+  void Train(Tree* referenceTree);
+
+  /**
    * For each point in the query set, compute the nearest neighbors and store
    * the output in the given matrices.  The matrices will be set to the size of
    * n columns by k rows, where n is the number of points in the query dataset
@@ -182,16 +212,12 @@ class NeighborSearch
   //! Returns a string representation of this object.
   std::string ToString() const;
 
-  //! Return the total number of base case evaluations performed during
-  //! searches.
+  //! Return the total number of base case evaluations performed during the last
+  //! search.
   size_t BaseCases() const { return baseCases; }
-  //! Modify the total number of base case evaluations.
-  size_t& BaseCases() { return baseCases; }
 
-  //! Return the number of node combination scores during the search.
+  //! Return the number of node combination scores during the last search.
   size_t Scores() const { return scores; }
-  //! Modify the number of node combination scores.
-  size_t& Scores() { return scores; }
 
   //! Access whether or not search is done in naive linear scan mode.
   bool Naive() const { return naive; }
@@ -203,20 +229,26 @@ class NeighborSearch
   //! Modify whether or not search is done in single-tree mode.
   bool& SingleMode() { return singleMode; }
 
+  //! Serialize the NeighborSearch model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   //! Permutations of reference points during tree building.
   std::vector<size_t> oldFromNewReferences;
   //! Pointer to the root of the reference tree.
   Tree* referenceTree;
-  //! Reference to reference dataset.
-  const MatType& referenceSet;
+  //! Reference dataset.  In some situations we may be the owner of this.
+  const MatType* referenceSet;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
+  //! If true, we own the reference set.
+  bool setOwner;
 
   //! Indicates if O(n^2) naive search is being used.
   bool naive;
-  //! Indicates if single-tree search is being used (opposed to dual-tree).
+  //! Indicates if single-tree search is being used (as opposed to dual-tree).
   bool singleMode;
 
   //! Instantiation of metric.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 82bec9c..19eabc8 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -54,8 +54,9 @@ NeighborSearch(const MatType& referenceSetIn,
                const MetricType metric) :
     referenceTree(naive ? NULL :
         BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
-    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
     treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
+    setOwner(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
@@ -78,8 +79,9 @@ NeighborSearch(Tree* referenceTree,
                const bool singleMode,
                const MetricType metric) :
     referenceTree(referenceTree),
-    referenceSet(referenceTree->Dataset()),
+    referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
+    setOwner(false),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -89,6 +91,37 @@ NeighborSearch(Tree* referenceTree,
   // Nothing else to initialize.
 }
 
+// Construct the object without a reference dataset.  We allocate (and own) an
+// empty reference matrix until Train() supplies real data; if tree-based
+// search is requested, a tree is also built on that empty matrix.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+    NeighborSearch(const bool naive,
+                   const bool singleMode,
+                   const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty matrix.
+    treeOwner(false),
+    setOwner(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    metric(metric),
+    baseCases(0),
+    scores(0)
+{
+  // Build the tree on the empty dataset, if necessary.
+  if (!naive)
+  {
+    try
+    {
+      referenceTree = BuildTree<MatType, Tree>(*referenceSet,
+          oldFromNewReferences);
+    }
+    catch (...)
+    {
+      // The destructor does not run when a constructor throws, so release the
+      // matrix allocated above before letting the exception propagate.
+      delete referenceSet;
+      throw;
+    }
+
+    treeOwner = true;
+  }
+}
+
 // Clean memory.
 template<typename SortPolicy,
          typename MetricType,
@@ -98,10 +131,69 @@ template<typename SortPolicy,
                   typename TreeMatType> class TreeType,
          template<typename> class TraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-    ~NeighborSearch()
+~NeighborSearch()
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
+  if (setOwner && referenceSet)
+    delete referenceSet;
+}
+
+// Set (or replace) the reference set.  In tree mode a new tree is built on the
+// given data and owned by this object; in naive mode only a pointer to the
+// caller's matrix is kept, so that matrix must stay alive while it is in use.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(const MatType& referenceSet)
+{
+  if (!naive)
+  {
+    // Build the new tree before touching the old one: if BuildTree() throws we
+    // keep our previous, valid state, and the given set may alias the old
+    // tree's dataset, which must stay alive until the new tree exists.
+    std::vector<size_t> newOldFromNew;
+    Tree* newTree = BuildTree<MatType, Tree>(referenceSet, newOldFromNew);
+
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+
+    referenceTree = newTree;
+    oldFromNewReferences.swap(newOldFromNew);
+    treeOwner = true;
+  }
+  else
+  {
+    // Naive search uses no tree.  Drop any tree left over from an earlier
+    // configuration so its (now stale) index mapping is never applied; this
+    // matches what Serialize() does when loading a naive model.
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+
+    referenceTree = NULL;
+    oldFromNewReferences.clear();
+    treeOwner = false;
+  }
+
+  // The new reference set: the caller's matrix in naive mode, otherwise the
+  // tree's dataset.
+  const MatType* newSet = naive ? &referenceSet : &referenceTree->Dataset();
+  if (newSet != this->referenceSet)
+  {
+    // Free a set we own, but never the very set we are about to use.
+    if (setOwner && this->referenceSet)
+      delete this->referenceSet;
+
+    this->referenceSet = newSet;
+    setOwner = false; // We don't own the caller's set or the tree's dataset.
+  }
+}
+
+// Use an already-built tree for future searches.  The tree is not copied and
+// the caller keeps ownership of it; its dataset becomes the reference set.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(Tree* referenceTree)
+{
+  if (naive)
+    throw std::invalid_argument("cannot train on given reference tree when "
+        "naive search (without trees) is desired");
+
+  if (referenceTree == NULL)
+    throw std::invalid_argument("cannot train on a NULL reference tree");
+
+  // Careful: the parameter shadows the member, so the member is always spelled
+  // this->referenceTree below.  If we are handed the tree we already hold, keep
+  // it (and any ownership of it) instead of deleting it and then using the
+  // freed pointer.
+  if (referenceTree != this->referenceTree)
+  {
+    if (treeOwner && this->referenceTree)
+      delete this->referenceTree;
+
+    this->referenceTree = referenceTree;
+    oldFromNewReferences.clear(); // The old mapping belonged to the old tree.
+    treeOwner = false; // The caller owns the tree.
+  }
+
+  // Point at the tree's dataset, freeing a set we own unless it is that very
+  // dataset.
+  const MatType* newSet = &referenceTree->Dataset();
+  if (newSet != this->referenceSet)
+  {
+    if (setOwner && this->referenceSet)
+      delete this->referenceSet;
+
+    this->referenceSet = newSet;
+    setOwner = false;
+  }
 }
 
 /**
@@ -121,8 +213,19 @@ Search(const MatType& querySet,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
 {
+  if (k > referenceSet->n_cols)
+  {
+    std::stringstream ss;
+    ss << "requested value of k (" << k << ") is greater than the number of "
+        << "points in the reference set (" << referenceSet->n_cols << ")";
+    throw std::invalid_argument(ss.str());
+  }
+
   Timer::Start("computing_neighbors");
 
+  baseCases = 0;
+  scores = 0;
+
   // This will hold mappings for query points, if necessary.
   std::vector<size_t> oldFromNewQueries;
 
@@ -154,19 +257,19 @@ Search(const MatType& querySet,
   if (naive)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
 
     // The naive brute-force traversal.
     for (size_t i = 0; i < querySet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+      for (size_t j = 0; j < referenceSet->n_cols; ++j)
         rules.BaseCase(i, j);
 
-    baseCases += querySet.n_cols * referenceSet.n_cols;
+    baseCases += querySet.n_cols * referenceSet->n_cols;
   }
   else if (singleMode)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
 
     // Create the traverser.
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -191,7 +294,7 @@ Search(const MatType& querySet,
     Timer::Start("computing_neighbors");
 
     // Create the helper object for the tree traversal.
-    RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+    RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
         *distancePtr, metric);
 
     // Create the traverser.
@@ -283,16 +386,27 @@ Search(Tree* queryTree,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
 {
-  Timer::Start("computing_neighbors");
-
-  // Get a reference to the query set.
-  const MatType& querySet = queryTree->Dataset();
+  if (k > referenceSet->n_cols)
+  {
+    std::stringstream ss;
+    ss << "requested value of k (" << k << ") is greater than the number of "
+        << "points in the reference set (" << referenceSet->n_cols << ")";
+    throw std::invalid_argument(ss.str());
+  }
 
   // Make sure we are in dual-tree mode.
   if (singleMode || naive)
     throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
         "query tree when naive or singleMode are set to true");
 
+  Timer::Start("computing_neighbors");
+
+  baseCases = 0;
+  scores = 0;
+
+  // Get a reference to the query set.
+  const MatType& querySet = queryTree->Dataset();
+
   // We won't need to map query indices, but will we need to map distances?
   arma::Mat<size_t>* neighborPtr = &neighbors;
 
@@ -306,7 +420,7 @@ Search(Tree* queryTree,
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
+  RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric);
 
   // Create the traverser.
   TraversalType<RuleType> traverser(rules);
@@ -345,8 +459,19 @@ Search(const size_t k,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
 {
+  if (k > referenceSet->n_cols)
+  {
+    std::stringstream ss;
+    ss << "requested value of k (" << k << ") is greater than the number of "
+        << "points in the reference set (" << referenceSet->n_cols << ")";
+    throw std::invalid_argument(ss.str());
+  }
+
   Timer::Start("computing_neighbors");
 
+  baseCases = 0;
+  scores = 0;
+
   arma::Mat<size_t>* neighborPtr = &neighbors;
   arma::mat* distancePtr = &distances;
 
@@ -358,24 +483,24 @@ Search(const size_t k,
   }
 
   // Initialize results.
-  neighborPtr->set_size(k, referenceSet.n_cols);
+  neighborPtr->set_size(k, referenceSet->n_cols);
   neighborPtr->fill(size_t() - 1);
-  distancePtr->set_size(k, referenceSet.n_cols);
+  distancePtr->set_size(k, referenceSet->n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+  RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
       metric, true /* don't return the same point as nearest neighbor */);
 
   if (naive)
   {
     // The naive brute-force solution.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
+      for (size_t j = 0; j < referenceSet->n_cols; ++j)
         rules.BaseCase(i, j);
 
-    baseCases += referenceSet.n_cols * referenceSet.n_cols;
+    baseCases += referenceSet->n_cols * referenceSet->n_cols;
   }
   else if (singleMode)
   {
@@ -383,7 +508,7 @@ Search(const size_t k,
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
 
     scores += rules.Scores();
@@ -411,8 +536,8 @@ Search(const size_t k,
   // Do we need to map the reference indices?
   if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
-    neighbors.set_size(k, referenceSet.n_cols);
-    distances.set_size(k, referenceSet.n_cols);
+    neighbors.set_size(k, referenceSet->n_cols);
+    distances.set_size(k, referenceSet->n_cols);
 
     for (size_t i = 0; i < distances.n_cols; ++i)
     {
@@ -444,8 +569,8 @@ std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 {
   std::ostringstream convert;
   convert << "NeighborSearch [" << this << "]" << std::endl;
-  convert << "  Reference set: " << referenceSet.n_rows << "x" ;
-  convert << referenceSet.n_cols << std::endl;
+  convert << "  Reference set: " << referenceSet->n_rows << "x" ;
+  convert << referenceSet->n_cols << std::endl;
   if (referenceTree)
     convert << "  Reference tree: " << referenceTree << std::endl;
   convert << "  Tree owner: " << treeOwner << std::endl;
@@ -455,6 +580,80 @@ std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
   return convert.str();
 }
 
+//! Serialize the NeighborSearch model.  In naive mode the reference matrix and
+//! metric are stored; otherwise the reference tree and its index mapping are
+//! stored, and on load the reference set and metric are taken from the tree.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+template<typename Archive>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+    Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Serialize preferences for search.
+  ar & CreateNVP(naive, "naive");
+  ar & CreateNVP(singleMode, "singleMode");
+
+  // If we are doing naive search, we serialize the dataset.  Otherwise we
+  // serialize the tree.
+  if (naive)
+  {
+    // Delete the current reference set, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      // Don't leave a dangling pointer behind: if loading below throws, the
+      // destructor must not free this set a second time.
+      referenceSet = NULL;
+      setOwner = true; // We will own the reference set when we load it.
+    }
+
+    ar & CreateNVP(referenceSet, "referenceSet");
+    ar & CreateNVP(metric, "metric");
+
+    // If we are loading, set the tree to NULL and clean up memory if necessary.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      referenceTree = NULL;
+      oldFromNewReferences.clear();
+      treeOwner = false;
+    }
+  }
+  else
+  {
+    // Delete the current reference tree, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      // As above: if loading the tree throws, the destructor must not see the
+      // pointer we just freed.
+      referenceTree = NULL;
+      // After we load the tree, we will own it.
+      treeOwner = true;
+    }
+
+    ar & CreateNVP(referenceTree, "referenceTree");
+    ar & CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+    // If we are loading, set the dataset accordingly and clean up memory if
+    // necessary.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      referenceSet = &referenceTree->Dataset();
+      metric = referenceTree->Metric(); // Get the metric from the tree.
+      setOwner = false;
+    }
+  }
+}
+
 } // namespace neighbor
 } // namespace mlpack
 



More information about the mlpack-git mailing list