[mlpack-git] master: Add support for spill trees in knn search. (2e67697)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:00 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 2e676974ef119f128147c2f0567706e306c0e7e7
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Jul 12 20:52:09 2016 -0300

    Add support for spill trees in knn search.


>---------------------------------------------------------------

2e676974ef119f128147c2f0567706e306c0e7e7
 src/mlpack/core/tree/spill_tree/spill_tree.hpp     |  6 +-
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 26 +++++++--
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 35 +++++++++--
 .../methods/neighbor_search/ns_model_impl.hpp      | 67 ++++++++++++++++++++--
 4 files changed, 115 insertions(+), 19 deletions(-)

diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index e563b49..d5c13fc 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -127,7 +127,7 @@ class SpillTree
    * @param rho Balance threshold.
    */
   SpillTree(const MatType& data,
-            const double tau,
+            const double tau = 0,
             const size_t maxLeafSize = 20,
             const double rho = 0.7);
 
@@ -143,7 +143,7 @@ class SpillTree
    * @param rho Balance threshold.
    */
   SpillTree(MatType&& data,
-            const double tau,
+            const double tau = 0,
             const size_t maxLeafSize = 20,
             const double rho = 0.7);
 
@@ -163,7 +163,7 @@ class SpillTree
   SpillTree(SpillTree* parent,
             std::vector<size_t>& points,
             const size_t overlapIndex,
-            const double tau,
+            const double tau = 0,
             const size_t maxLeafSize = 20,
             const double rho = 0.7);
 
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 14e07db..c201db9 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -64,10 +64,13 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
 // The user may specify the type of tree to use, and a few parameters for tree
 // building.
 PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'cover', 'r', "
-    "'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
+    "'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'spill'.",
+    "t", "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
-    "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l",
-    20);
+    "trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees, and Spill "
+    "trees).", "l", 20);
+PARAM_DOUBLE_IN("tau", "Overlapping size (for spill trees).", "u", 0);
+
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -111,6 +114,9 @@ int main(int argc, char *argv[])
     if (CLI::HasParam("leaf_size"))
       Log::Warn << "--leaf_size (-l) will be ignored because --input_model_file"
           << " is specified." << endl;
+    if (CLI::HasParam("tau"))
+      Log::Warn << "--tau (-u) will be ignored because --input_model_file"
+          << " is specified." << endl;
     if (CLI::HasParam("random_basis"))
       Log::Warn << "--random_basis (-R) will be ignored because "
           << "--input_model_file is specified." << endl;
@@ -144,6 +150,13 @@ int main(int argc, char *argv[])
     Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
         "than 0." << endl;
 
+  // Sanity check on tau.
+  const double tau = CLI::GetParam<double>("tau");
+  if (tau < 0)
+    Log::Fatal << "Invalid tau: " << tau << ".  Must be non-negative. " << endl;
+  if (CLI::HasParam("tau") && "spill" != CLI::GetParam<string>("tree_type"))
+    Log::Fatal << "Tau parameter is only valid for spill trees." << endl;
+
   // Sanity check on epsilon.
   const double epsilon = CLI::GetParam<double>("epsilon");
   if (epsilon < 0)
@@ -180,13 +193,17 @@ int main(int argc, char *argv[])
       tree = KNNModel::R_PLUS_TREE;
     else if (treeType == "r-plus-plus")
       tree = KNNModel::R_PLUS_PLUS_TREE;
+    else if (treeType == "spill")
+      tree = KNNModel::SPILL_TREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
           << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
-          << "'r-plus' and 'r-plus-plus'." << endl;
+          << "'r-plus', 'r-plus-plus' and 'spill'." << endl;
 
     knn.TreeType() = tree;
     knn.RandomBasis() = randomBasis;
+    knn.LeafSize() = size_t(lsInt);
+    knn.Tau() = tau;
 
     arma::mat referenceSet;
     data::Load(referenceFile, referenceSet, true);
@@ -213,6 +230,7 @@ int main(int argc, char *argv[])
     knn.Naive() = CLI::HasParam("naive");
     knn.LeafSize() = size_t(lsInt);
     knn.Epsilon() = epsilon;
+    knn.Tau() = tau;
   }
 
   // Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 38e4748..6a31d5b 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -13,6 +13,7 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/spill_tree.hpp>
 #include <boost/variant.hpp>
 #include "neighbor_search.hpp"
 
@@ -101,6 +102,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
   arma::mat& distances;
   //! The number of points in a leaf (for BinarySpaceTrees).
   const size_t leafSize;
+  //! Overlapping size (for spill trees).
+  const double tau;
 
   //! Bichromatic neighbor search on the given NSType considering the leafSize.
   template<typename NSType>
@@ -125,12 +128,16 @@ class BiSearchVisitor : public boost::static_visitor<void>
   //! Bichromatic neighbor search on the given NSType specialized for BallTrees.
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
+  //! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+  void operator()(NSTypeT<tree::SPTree>* ns) const;
+
   //! Construct the BiSearchVisitor.
   BiSearchVisitor(const arma::mat& querySet,
                   const size_t k,
                   arma::Mat<size_t>& neighbors,
                   arma::mat& distances,
-                  const size_t leafSize);
+                  const size_t leafSize,
+                  const double tau);
 };
 
 /**
@@ -147,6 +154,8 @@ class TrainVisitor : public boost::static_visitor<void>
   arma::mat&& referenceSet;
   //! The leaf size, used only by BinarySpaceTree.
   size_t leafSize;
+  //! Overlapping size (for spill trees).
+  const double tau;
 
   //! Train on the given NSType considering the leafSize.
   template<typename NSType>
@@ -171,9 +180,14 @@ class TrainVisitor : public boost::static_visitor<void>
   //! Train on the given NSType specialized for BallTrees.
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
-  //! Construct the TrainVisitor object with the given reference set and leaf
-  //! size for BinarySpaceTrees.
-  TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+  //! Train on the given NSType specialized for SPTrees.
+  void operator()(NSTypeT<tree::SPTree>* ns) const;
+
+  //! Construct the TrainVisitor object with the given reference set, leafSize
+  //! for BinarySpaceTrees, and tau for spill trees.
+  TrainVisitor(arma::mat&& referenceSet,
+               const size_t leafSize,
+               const double tau);
 };
 
 /**
@@ -256,7 +270,8 @@ class NSModel
     X_TREE,
     HILBERT_R_TREE,
     R_PLUS_TREE,
-    R_PLUS_PLUS_TREE
+    R_PLUS_PLUS_TREE,
+    SPILL_TREE
   };
 
  private:
@@ -266,6 +281,9 @@ class NSModel
   //! For tree types that accept the maxLeafSize parameter.
   size_t leafSize;
 
+  //! Overlapping size (for spill trees).
+  double tau;
+
   //! If true, random projections are used.
   bool randomBasis;
   //! This is the random projection matrix; only used if randomBasis is true.
@@ -284,7 +302,8 @@ class NSModel
                  NSType<SortPolicy, tree::XTree>*,
                  NSType<SortPolicy, tree::HilbertRTree>*,
                  NSType<SortPolicy, tree::RPlusTree>*,
-                 NSType<SortPolicy, tree::RPlusPlusTree>*> nSearch;
+                 NSType<SortPolicy, tree::RPlusPlusTree>*,
+                 NSType<SortPolicy, tree::SPTree>*> nSearch;
 
  public:
   /**
@@ -319,6 +338,10 @@ class NSModel
   size_t LeafSize() const { return leafSize; }
   size_t& LeafSize() { return leafSize; }
 
+  //! Expose tau.
+  double Tau() const { return tau; }
+  double& Tau() { return tau; }
+
   //! Expose treeType.
   TreeTypes TreeType() const { return treeType; }
   TreeTypes& TreeType() { return treeType; }
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index acbed6c..71430e6 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -33,12 +33,14 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
                                              const size_t k,
                                              arma::Mat<size_t>& neighbors,
                                              arma::mat& distances,
-                                             const size_t leafSize) :
+                                             const size_t leafSize,
+                                             const double tau) :
     querySet(querySet),
     k(k),
     neighbors(neighbors),
     distances(distances),
-    leafSize(leafSize)
+    leafSize(leafSize),
+    tau(tau)
 {}
 
 //! Default Bichromatic neighbor search on the given NSType instance.
@@ -71,6 +73,25 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
+//! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+template<typename SortPolicy>
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::SPTree>* ns) const
+{
+  if (ns)
+  {
+    if (!ns->Naive() && !ns->SingleMode())
+    {
+      typename NSTypeT<tree::SPTree>::Tree queryTree(std::move(querySet), tau,
+          leafSize);
+      ns->Search(&queryTree, k, neighbors, distances);
+    }
+    else
+      ns->Search(querySet, k, neighbors, distances);
+  }
+  else
+    throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Bichromatic neighbor search on the given NSType considering the leafSize.
 template<typename SortPolicy>
 template<typename NSType>
@@ -102,9 +123,11 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
 //! Save parameters for Train.
 template<typename SortPolicy>
 TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
-                                       const size_t leafSize) :
+                                       const size_t leafSize,
+                                       const double tau) :
     referenceSet(std::move(referenceSet)),
-    leafSize(leafSize)
+    leafSize(leafSize),
+    tau(tau)
 {}
 
 //! Default Train on the given NSType instance.
@@ -137,6 +160,27 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
+//! Train on the given NSType specialized for SPTrees.
+template<typename SortPolicy>
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
+{
+  if (ns)
+  {
+    if (ns->Naive())
+      ns->Train(std::move(referenceSet));
+    else
+    {
+      typename NSTypeT<tree::SPTree>::Tree* tree = new typename
+          NSTypeT<tree::SPTree>::Tree(std::move(referenceSet), tau, leafSize);
+      ns->Train(tree);
+      // Give the model ownership of the tree.
+      ns->treeOwner = true;
+    }
+  }
+  else
+    throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Train on the given NSType considering the leafSize.
 template<typename SortPolicy>
 template<typename NSType>
@@ -209,6 +253,8 @@ void DeleteVisitor::operator()(NSType* ns) const
 template<typename SortPolicy>
 NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
     treeType(treeType),
+    leafSize(20),
+    tau(0),
     randomBasis(randomBasis)
 {
   // Nothing to do.
@@ -249,6 +295,8 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
                                     const unsigned int /* version */)
 {
   ar & data::CreateNVP(treeType, "treeType");
+  ar & data::CreateNVP(leafSize, "leafSize");
+  ar & data::CreateNVP(tau, "tau");
   ar & data::CreateNVP(randomBasis, "randomBasis");
   ar & data::CreateNVP(q, "q");
 
@@ -313,6 +361,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
                                      const bool singleMode,
                                      const double epsilon)
 {
+  this->leafSize = leafSize;
   // Initialize random basis if necessary.
   if (randomBasis)
   {
@@ -394,9 +443,13 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
       nSearch = new NSType<SortPolicy, tree::RPlusPlusTree>(naive, singleMode,
           epsilon);
       break;
+    case SPILL_TREE:
+      nSearch = new NSType<SortPolicy, tree::SPTree>(naive, singleMode,
+          epsilon);
+      break;
   }
 
-  TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize);
+  TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau);
   boost::apply_visitor(tn, nSearch);
 
   if (!naive)
@@ -429,7 +482,7 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
         << std::endl;
 
   BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
-      leafSize);
+      leafSize, tau);
   boost::apply_visitor(search, nSearch);
 }
 
@@ -478,6 +531,8 @@ std::string NSModel<SortPolicy>::TreeName() const
       return "R+ tree";
     case R_PLUS_PLUS_TREE:
       return "R++ tree";
+    case SPILL_TREE:
+      return "Spill tree";
     default:
       return "unknown tree";
   }




More information about the mlpack-git mailing list