[mlpack-git] master: Added the UB tree to RSModel. Update range search tests. Replaced 'maxPRTreeRS' by 'maxRPTreeRS'. (43a065b)

gitdub at mlpack.org gitdub at mlpack.org
Fri Aug 26 16:34:49 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit 43a065b288999f3b044dd02a2cb6f942376c035c
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Fri Aug 26 23:34:49 2016 +0300

    Added the UB tree to RSModel. Update range search tests. Replaced 'maxPRTreeRS' by 'maxRPTreeRS'.


>---------------------------------------------------------------

43a065b288999f3b044dd02a2cb6f942376c035c
 .../methods/range_search/range_search_main.cpp     | 10 +++---
 src/mlpack/methods/range_search/rs_model.cpp       | 33 +++++++++++++++-----
 src/mlpack/methods/range_search/rs_model.hpp       |  8 +++--
 src/mlpack/methods/range_search/rs_model_impl.hpp  | 36 +++++++++++++++-------
 src/mlpack/tests/range_search_test.cpp             | 12 +++++---
 5 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 2505024..9487832 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -71,10 +71,10 @@ PARAM_DOUBLE_IN("min", "Lower bound in range.", "L", 0.0);
 // The user may specify the type of tree to use, and a few parameters for tree
 // building.
 PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
-    "'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
+    "'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
     "'r-plus-plus'.", "t", "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
-    "vp trees, random projection trees, R trees, R* trees, X trees, "
+    "vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
     "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
@@ -189,10 +189,12 @@ int main(int argc, char *argv[])
       tree = RSModel::RP_TREE;
     else if (treeType == "max-rp")
       tree = RSModel::MAX_RP_TREE;
+    else if (treeType == "ub")
+      tree = RSModel::UB_TREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are "
-          << "'kd', 'vp', 'rp', 'max-rp', 'cover', 'r', 'r-star', 'x', 'ball', "
-          << "'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+          << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
+          << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
 
     rs.TreeType() = tree;
     rs.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index c926e4b..6025471 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -28,7 +28,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
     rPlusPlusTreeRS(NULL),
     vpTreeRS(NULL),
     rpTreeRS(NULL),
-    maxPRTreeRS(NULL)
+    maxRPTreeRS(NULL),
+    ubTreeRS(NULL)
 {
   // Nothing to do.
 }
@@ -155,7 +156,12 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
       break;
 
     case MAX_RP_TREE:
-      maxPRTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
+      maxRPTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
+          naive, singleMode);
+      break;
+
+    case UB_TREE:
+      ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
           naive, singleMode);
       break;
   }
@@ -289,7 +295,11 @@ void RSModel::Search(arma::mat&& querySet,
       break;
 
     case MAX_RP_TREE:
-      maxPRTreeRS->Search(querySet, range, neighbors, distances);
+      maxRPTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case UB_TREE:
+      ubTreeRS->Search(querySet, range, neighbors, distances);
       break;
   }
 }
@@ -355,7 +365,11 @@ void RSModel::Search(const math::Range& range,
       break;
 
     case MAX_RP_TREE:
-      maxPRTreeRS->Search(range, neighbors, distances);
+      maxRPTreeRS->Search(range, neighbors, distances);
+      break;
+
+    case UB_TREE:
+      ubTreeRS->Search(range, neighbors, distances);
       break;
   }
 }
@@ -389,6 +403,8 @@ std::string RSModel::TreeName() const
       return "random projection tree (mean split)";
     case MAX_RP_TREE:
       return "random projection tree (max split)";
+    case UB_TREE:
+      return "UB tree";
     default:
       return "unknown tree";
   }
@@ -419,8 +435,10 @@ void RSModel::CleanMemory()
     delete vpTreeRS;
   if (rpTreeRS)
     delete rpTreeRS;
-  if (maxPRTreeRS)
-    delete maxPRTreeRS;
+  if (maxRPTreeRS)
+    delete maxRPTreeRS;
+  if (ubTreeRS)
+    delete ubTreeRS;
 
   kdTreeRS = NULL;
   coverTreeRS = NULL;
@@ -433,5 +451,6 @@ void RSModel::CleanMemory()
   rPlusPlusTreeRS = NULL;
   vpTreeRS = NULL;
   rpTreeRS = NULL;
-  maxPRTreeRS = NULL;
+  maxRPTreeRS = NULL;
+  ubTreeRS = NULL;
 }
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index a073968..659319d 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -35,7 +35,8 @@ class RSModel
     R_PLUS_PLUS_TREE,
     VP_TREE,
     RP_TREE,
-    MAX_RP_TREE
+    MAX_RP_TREE,
+    UB_TREE
   };
 
  private:
@@ -79,7 +80,10 @@ class RSModel
   RSType<tree::RPTree>* rpTreeRS;
   //! Random projection tree (max) based range search object
   //! (NULL if not in use).
-  RSType<tree::MaxRPTree>* maxPRTreeRS;
+  RSType<tree::MaxRPTree>* maxRPTreeRS;
+  //! Universal B tree based range search object
+  //! (NULL if not in use).
+  RSType<tree::UBTree>* ubTreeRS;
 
  public:
   /**
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 183f599..2a8502f 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -75,7 +75,11 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
       break;
 
     case MAX_RP_TREE:
-      ar & CreateNVP(maxPRTreeRS, "range_search_model");
+      ar & CreateNVP(maxRPTreeRS, "range_search_model");
+      break;
+
+    case UB_TREE:
+      ar & CreateNVP(ubTreeRS, "range_search_model");
       break;
   }
 }
@@ -104,8 +108,10 @@ inline const arma::mat& RSModel::Dataset() const
     return vpTreeRS->ReferenceSet();
   else if (rpTreeRS)
     return rpTreeRS->ReferenceSet();
-  else if (maxPRTreeRS)
-    return maxPRTreeRS->ReferenceSet();
+  else if (maxRPTreeRS)
+    return maxRPTreeRS->ReferenceSet();
+  else if (ubTreeRS)
+    return ubTreeRS->ReferenceSet();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -134,8 +140,10 @@ inline bool RSModel::SingleMode() const
     return vpTreeRS->SingleMode();
   else if (rpTreeRS)
     return rpTreeRS->SingleMode();
-  else if (maxPRTreeRS)
-    return maxPRTreeRS->SingleMode();
+  else if (maxRPTreeRS)
+    return maxRPTreeRS->SingleMode();
+  else if (ubTreeRS)
+    return ubTreeRS->SingleMode();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -164,8 +172,10 @@ inline bool& RSModel::SingleMode()
     return vpTreeRS->SingleMode();
   else if (rpTreeRS)
     return rpTreeRS->SingleMode();
-  else if (maxPRTreeRS)
-    return maxPRTreeRS->SingleMode();
+  else if (maxRPTreeRS)
+    return maxRPTreeRS->SingleMode();
+  else if (ubTreeRS)
+    return ubTreeRS->SingleMode();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -194,8 +204,10 @@ inline bool RSModel::Naive() const
     return vpTreeRS->Naive();
   else if (rpTreeRS)
     return rpTreeRS->Naive();
-  else if (maxPRTreeRS)
-    return maxPRTreeRS->Naive();
+  else if (maxRPTreeRS)
+    return maxRPTreeRS->Naive();
+  else if (ubTreeRS)
+    return ubTreeRS->Naive();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -224,8 +236,10 @@ inline bool& RSModel::Naive()
     return vpTreeRS->Naive();
   else if (rpTreeRS)
     return rpTreeRS->Naive();
-  else if (maxPRTreeRS)
-    return maxPRTreeRS->Naive();
+  else if (maxRPTreeRS)
+    return maxRPTreeRS->Naive();
+  else if (ubTreeRS)
+    return ubTreeRS->Naive();
 
   throw std::runtime_error("no range search model initialized");
 }
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index fcfa9eb..0968b07 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1249,7 +1249,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  RSModel models[24];
+  RSModel models[26];
   models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
   models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
   models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1274,6 +1274,8 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
   models[21] = RSModel(RSModel::TreeTypes::RP_TREE, false);
   models[22] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
   models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+  models[24] = RSModel(RSModel::TreeTypes::UB_TREE, true);
+  models[25] = RSModel(RSModel::TreeTypes::UB_TREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1287,7 +1289,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
     vector<vector<pair<double, size_t>>> baselineSorted;
     SortResults(baselineNeighbors, baselineDistances, baselineSorted);
 
-    for (size_t i = 0; i < 24; ++i)
+    for (size_t i = 0; i < 26; ++i)
     {
       // We only have std::move() constructors, so make a copy of our data.
       arma::mat referenceCopy(referenceData);
@@ -1331,7 +1333,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  RSModel models[24];
+  RSModel models[26];
   models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
   models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
   models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1356,6 +1358,8 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
   models[21] = RSModel(RSModel::TreeTypes::RP_TREE, false);
   models[22] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
   models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+  models[24] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
+  models[25] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1368,7 +1372,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
     vector<vector<pair<double, size_t>>> baselineSorted;
     SortResults(baselineNeighbors, baselineDistances, baselineSorted);
 
-    for (size_t i = 0; i < 24; ++i)
+    for (size_t i = 0; i < 26; ++i)
     {
       // We only have std::move() cosntructors, so make a copy of our data.
       arma::mat referenceCopy(referenceData);




More information about the mlpack-git mailing list