[mlpack-git] master: Add rho and leafSize members on SpillSearch, and a command line parameter '--rho' for mlpack_knn. (fa81be4)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 11 11:16:49 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit fa81be4b559909c3c066645dc707efc5535e9ca6
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Aug 11 12:16:49 2016 -0300

    Add rho and leafSize members on SpillSearch, and a command line parameter '--rho' for mlpack_knn.


>---------------------------------------------------------------

fa81be4b559909c3c066645dc707efc5535e9ca6
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 14 ++++++++
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 18 ++++++++--
 .../methods/neighbor_search/ns_model_impl.hpp      | 24 ++++++++-----
 .../methods/neighbor_search/spill_search.hpp       | 28 +++++++++++++++
 .../methods/neighbor_search/spill_search_impl.hpp  | 42 +++++++++++++++++-----
 5 files changed, 106 insertions(+), 20 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index c201db9..97db49d 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -70,6 +70,7 @@ PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
     "trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees, and Spill "
     "trees).", "l", 20);
 PARAM_DOUBLE_IN("tau", "Overlapping size (for spill trees).", "u", 0);
+PARAM_DOUBLE_IN("rho", "Balance threshold (for spill trees).", "b", 0.7);
 
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
@@ -117,6 +118,9 @@ int main(int argc, char *argv[])
     if (CLI::HasParam("tau"))
       Log::Warn << "--tau (-u) will be ignored because --input_model_file"
           << " is specified." << endl;
+    if (CLI::HasParam("rho"))
+      Log::Warn << "--rho (-b) will be ignored because --input_model_file"
+          << " is specified." << endl;
     if (CLI::HasParam("random_basis"))
       Log::Warn << "--random_basis (-R) will be ignored because "
           << "--input_model_file is specified." << endl;
@@ -157,6 +161,14 @@ int main(int argc, char *argv[])
   if (CLI::HasParam("tau") && "spill" != CLI::GetParam<string>("tree_type"))
     Log::Fatal << "Tau parameter is only valid for spill trees." << endl;
 
+  // Sanity check on rho.
+  const double rho = CLI::GetParam<double>("rho");
+  if (rho < 0 || rho > 1)
+    Log::Fatal << "Invalid rho: " << rho << ".  Must be in the range [0,1]. "
+        << endl;
+  if (CLI::HasParam("rho") && "spill" != CLI::GetParam<string>("tree_type"))
+    Log::Fatal << "Rho parameter is only valid for spill trees." << endl;
+
   // Sanity check on epsilon.
   const double epsilon = CLI::GetParam<double>("epsilon");
   if (epsilon < 0)
@@ -204,6 +216,7 @@ int main(int argc, char *argv[])
     knn.RandomBasis() = randomBasis;
     knn.LeafSize() = size_t(lsInt);
     knn.Tau() = tau;
+    knn.Rho() = rho;
 
     arma::mat referenceSet;
     data::Load(referenceFile, referenceSet, true);
@@ -231,6 +244,7 @@ int main(int argc, char *argv[])
     knn.LeafSize() = size_t(lsInt);
     knn.Epsilon() = epsilon;
     knn.Tau() = tau;
+    knn.Rho() = rho;
   }
 
   // Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 03855a1..4f34e07 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -110,6 +110,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
   const size_t leafSize;
   //! Overlapping size (for spill trees).
   const double tau;
+  //! Balance threshold (for spill trees).
+  const double rho;
 
   //! Bichromatic neighbor search on the given NSType considering the leafSize.
   template<typename NSType>
@@ -143,7 +145,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
                   arma::Mat<size_t>& neighbors,
                   arma::mat& distances,
                   const size_t leafSize,
-                  const double tau);
+                  const double tau,
+                  const double rho);
 };
 
 /**
@@ -162,6 +165,8 @@ class TrainVisitor : public boost::static_visitor<void>
   size_t leafSize;
   //! Overlapping size (for spill trees).
   const double tau;
+  //! Balance threshold (for spill trees).
+  const double rho;
 
   //! Train on the given NSType considering the leafSize.
   template<typename NSType>
@@ -190,10 +195,11 @@ class TrainVisitor : public boost::static_visitor<void>
   void operator()(NSSpillType* ns) const;
 
   //! Construct the TrainVisitor object with the given reference set, leafSize
-  //! for BinarySpaceTrees, and tau for spill trees.
+  //! for BinarySpaceTrees, and tau and rho for spill trees.
   TrainVisitor(arma::mat&& referenceSet,
                const size_t leafSize,
-               const double tau);
+               const double tau,
+               const double rho);
 };
 
 /**
@@ -289,6 +295,8 @@ class NSModel
 
   //! Overlapping size (for spill trees).
   double tau;
+  //! Balance threshold (for spill trees).
+  double rho;
 
   //! If true, random projections are used.
   bool randomBasis;
@@ -348,6 +356,10 @@ class NSModel
   double Tau() const { return tau; }
   double& Tau() { return tau; }
 
+  //! Expose rho.
+  double Rho() const { return rho; }
+  double& Rho() { return rho; }
+
   //! Expose treeType.
   TreeTypes TreeType() const { return treeType; }
   TreeTypes& TreeType() { return treeType; }
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index d46b7e5..29cc167 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -34,13 +34,15 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
                                              arma::Mat<size_t>& neighbors,
                                              arma::mat& distances,
                                              const size_t leafSize,
-                                             const double tau) :
+                                             const double tau,
+                                             const double rho) :
     querySet(querySet),
     k(k),
     neighbors(neighbors),
     distances(distances),
     leafSize(leafSize),
-    tau(tau)
+    tau(tau),
+    rho(rho)
 {}
 
 //! Default Bichromatic neighbor search on the given NSType instance.
@@ -84,7 +86,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
       // For Dual Tree Search on SpillTrees, the queryTree must be built with
       // non overlapping (tau = 0).
       typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
-          leafSize);
+          leafSize, rho);
       ns->Search(&queryTree, k, neighbors, distances);
     }
     else
@@ -126,10 +128,12 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
 template<typename SortPolicy>
 TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
                                        const size_t leafSize,
-                                       const double tau) :
+                                       const double tau,
+                                       const double rho) :
     referenceSet(std::move(referenceSet)),
     leafSize(leafSize),
-    tau(tau)
+    tau(tau),
+    rho(rho)
 {}
 
 //! Default Train on the given NSType instance.
@@ -173,7 +177,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
     else
     {
       typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
-          std::move(referenceSet), tau, leafSize);
+          std::move(referenceSet), tau, leafSize, rho);
       ns->Train(tree);
       // Give the model ownership of the tree.
       ns->neighborSearch.treeOwner = true;
@@ -257,6 +261,7 @@ NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
     treeType(treeType),
     leafSize(20),
     tau(0),
+    rho(0.7),
     randomBasis(randomBasis)
 {
   // Nothing to do.
@@ -307,6 +312,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
   ar & data::CreateNVP(treeType, "treeType");
   ar & data::CreateNVP(leafSize, "leafSize");
   ar & data::CreateNVP(tau, "tau");
+  ar & data::CreateNVP(rho, "rho");
   ar & data::CreateNVP(randomBasis, "randomBasis");
   ar & data::CreateNVP(q, "q");
 
@@ -454,11 +460,11 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
           epsilon);
       break;
     case SPILL_TREE:
-      nSearch = new NSSpillType(naive, singleMode, tau, epsilon);
+      nSearch = new NSSpillType(naive, singleMode, tau, leafSize, rho, epsilon);
       break;
   }
 
-  TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau);
+  TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
   boost::apply_visitor(tn, nSearch);
 
   if (!naive)
@@ -491,7 +497,7 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
         << std::endl;
 
   BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
-      leafSize, tau);
+      leafSize, tau, rho);
   boost::apply_visitor(search, nSearch);
 }
 
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
index e1d4008..8437012 100644
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -68,6 +68,8 @@ class SpillSearch
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
    * @param tau Overlapping size (non-negative).
+   * @param leafSize Max size of each leaf in the tree.
+   * @param rho Balance threshold (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
@@ -75,6 +77,8 @@ class SpillSearch
               const bool naive = false,
               const bool singleMode = false,
               const double tau = 0,
+              const double leafSize = 20,
+              const double rho = 0.7,
               const double epsilon = 0,
               const MetricType metric = MetricType());
 
@@ -91,6 +95,8 @@ class SpillSearch
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
    * @param tau Overlapping size (non-negative).
+   * @param leafSize Max size of each leaf in the tree.
+   * @param rho Balance threshold (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
@@ -98,6 +104,8 @@ class SpillSearch
               const bool naive = false,
               const bool singleMode = false,
               const double tau = 0,
+              const double leafSize = 20,
+              const double rho = 0.7,
               const double epsilon = 0,
               const MetricType metric = MetricType());
 
@@ -113,12 +121,16 @@ class SpillSearch
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
    * @param tau Overlapping size (non-negative).
+   * @param leafSize Max size of each leaf in the tree.
+   * @param rho Balance threshold (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated distance metric.
    */
   SpillSearch(Tree* referenceTree,
               const bool singleMode = false,
               const double tau = 0,
+              const double leafSize = 20,
+              const double rho = 0.7,
               const double epsilon = 0,
               const MetricType metric = MetricType());
 
@@ -131,12 +143,16 @@ class SpillSearch
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
    * @param tau Overlapping size (non-negative).
+   * @param leafSize Max size of each leaf in the tree.
+   * @param rho Balance threshold (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated metric.
    */
   SpillSearch(const bool naive = false,
               const bool singleMode = false,
               const double tau = 0,
+              const double leafSize = 20,
+              const double rho = 0.7,
               const double epsilon = 0,
               const MetricType metric = MetricType());
 
@@ -262,6 +278,12 @@ class SpillSearch
   //! Access the overlapping size.
   double Tau() const { return tau; }
 
+  //! Access the balance threshold.
+  double Rho() const { return rho; }
+
+  //! Access the leaf size.
+  double LeafSize() const { return leafSize; }
+
   //! Access the reference dataset.
   const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
 
@@ -277,6 +299,12 @@ class SpillSearch
   //! Overlapping size.
   double tau;
 
+  //! Balance threshold.
+  double rho;
+
+  //! Max leaf size.
+  double leafSize;
+
   //! The NSModel class should have access to internal members.
   template<typename SortPolicy>
   friend class TrainVisitor;
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
index 0a06037..5638959 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -25,13 +25,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     const bool naive,
     const bool singleMode,
     const double tau,
+    const double leafSize,
+    const double rho,
     const double epsilon,
     const MetricType metric) :
     neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau)
+    tau(tau),
+    rho(rho),
+    leafSize(leafSize)
 {
   if (tau < 0)
     throw std::invalid_argument("tau must be non-negative");
+  if (rho < 0 || rho > 1)
+    throw std::invalid_argument("rho must be in the range [0,1]");
   Train(referenceSetIn);
 }
 
@@ -45,13 +51,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     const bool naive,
     const bool singleMode,
     const double tau,
+    const double leafSize,
+    const double rho,
     const double epsilon,
     const MetricType metric) :
     neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau)
+    tau(tau),
+    rho(rho),
+    leafSize(leafSize)
 {
   if (tau < 0)
     throw std::invalid_argument("tau must be non-negative");
+  if (rho < 0 || rho > 1)
+    throw std::invalid_argument("rho must be in the range [0,1]");
   Train(std::move(referenceSetIn));
 }
 
@@ -64,13 +76,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     Tree* referenceTree,
     const bool singleMode,
     const double tau,
+    const double leafSize,
+    const double rho,
     const double epsilon,
     const MetricType metric) :
     neighborSearch(singleMode, epsilon, metric),
-    tau(tau)
+    tau(tau),
+    rho(rho),
+    leafSize(leafSize)
 {
   if (tau < 0)
     throw std::invalid_argument("tau must be non-negative");
+  if (rho < 0 || rho > 1)
+    throw std::invalid_argument("rho must be in the range [0,1]");
   Train(referenceTree);
 }
 
@@ -83,13 +101,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     const bool naive,
     const bool singleMode,
     const double tau,
+    const double leafSize,
+    const double rho,
     const double epsilon,
     const MetricType metric) :
     neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau)
+    tau(tau),
+    rho(rho),
+    leafSize(leafSize)
 {
   if (tau < 0)
     throw std::invalid_argument("tau must be non-negative");
+  if (rho < 0 || rho > 1)
+    throw std::invalid_argument("rho must be in the range [0,1]");
 }
 
 // Clean memory.
@@ -115,7 +139,7 @@ Train(const MatType& referenceSet)
   else
   {
     // Build reference tree with proper value for tau.
-    Tree* tree = new Tree(referenceSet, tau);
+    Tree* tree = new Tree(referenceSet, tau, leafSize, rho);
     neighborSearch.Train(tree);
     // Give the model ownership of the tree.
     neighborSearch.treeOwner = true;
@@ -134,7 +158,7 @@ Train(MatType&& referenceSetIn)
   else
   {
     // Build reference tree with proper value for tau.
-    Tree* tree = new Tree(std::move(referenceSetIn), tau);
+    Tree* tree = new Tree(std::move(referenceSetIn), tau, leafSize, rho);
     neighborSearch.Train(tree);
     // Give the model ownership of the tree.
     neighborSearch.treeOwner = true;
@@ -167,7 +191,7 @@ Search(const MatType& querySet,
   {
     // For Dual Tree Search on SpillTrees, the queryTree must be built with non
     // overlapping (tau = 0).
-    Tree queryTree(querySet, 0 /* tau */);
+    Tree queryTree(querySet, 0 /* tau */, leafSize, rho);
     neighborSearch.Search(&queryTree, k, neighbors, distances);
   }
 }
@@ -201,7 +225,7 @@ Search(const size_t k,
     // For Dual Tree Search on SpillTrees, the queryTree must be built with non
     // overlapping (tau = 0). If the referenceTree was built with a non-zero
     // value for tau, we need to build a new queryTree.
-    Tree queryTree(ReferenceSet(), 0 /* tau */);
+    Tree queryTree(ReferenceSet(), 0 /* tau */, leafSize, rho);
     neighborSearch.Search(&queryTree, k, neighbors, distances, true);
   }
 }
@@ -217,6 +241,8 @@ void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 {
   ar & data::CreateNVP(neighborSearch, "neighborSearch");
   ar & data::CreateNVP(tau, "tau");
+  ar & data::CreateNVP(rho, "rho");
+  ar & data::CreateNVP(leafSize, "leafSize");
 }
 
 } // namespace neighbor




More information about the mlpack-git mailing list