[mlpack-git] master: Added the UB tree to RAModel. Fixed UB tree traits (added UniqueNumDescendants). (d3563e9)

gitdub at mlpack.org gitdub at mlpack.org
Fri Aug 26 17:55:09 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit d3563e9b121054e4e5c1f04849b5edcb7416d1dd
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Sat Aug 27 00:55:09 2016 +0300

    Added the UB tree to RAModel. Fixed UB tree traits (added UniqueNumDescendants).


>---------------------------------------------------------------

d3563e9b121054e4e5c1f04849b5edcb7416d1dd
 src/mlpack/core/tree/binary_space_tree/traits.hpp |  1 +
 src/mlpack/methods/rann/krann_main.cpp            | 12 +++--
 src/mlpack/methods/rann/ra_model.hpp              |  5 +-
 src/mlpack/methods/rann/ra_model_impl.hpp         | 56 ++++++++++++++++++++++-
 src/mlpack/tests/krann_search_test.cpp            |  6 ++-
 5 files changed, 71 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp
index 69bda7c..ac6dc18 100644
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@ -237,6 +237,7 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType,
   static const bool HasSelfChildren = false;
   static const bool RearrangesDataset = true;
   static const bool BinaryTree = true;
+  static const bool UniqueNumDescendants = true;
 };
 
 } // namespace tree
diff --git a/src/mlpack/methods/rann/krann_main.cpp b/src/mlpack/methods/rann/krann_main.cpp
index 9950cd2..591d741 100644
--- a/src/mlpack/methods/rann/krann_main.cpp
+++ b/src/mlpack/methods/rann/krann_main.cpp
@@ -64,11 +64,11 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
 
 // The user may specify the type of tree to use, and a few parameters for tree
 // building.
-PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'cover', 'r', or "
+PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'ub', 'cover', 'r', "
     "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
-PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
-    "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l",
-    20);
+PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
+    "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and "
+    "R++ trees).", "l", 20);
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -180,9 +180,11 @@ int main(int argc, char *argv[])
       tree = RANNModel::R_PLUS_TREE;
     else if (treeType == "r-plus-plus")
       tree = RANNModel::R_PLUS_PLUS_TREE;
+    else if (treeType == "ub")
+      tree = RANNModel::UB_TREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
-          << "'kd', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
+          << "'kd', 'ub', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
           << "'r-plus' and 'r-plus-plus'." << endl;
 
     rann.TreeType() = tree;
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
index 1e755d3..48be9e9 100644
--- a/src/mlpack/methods/rann/ra_model.hpp
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -43,7 +43,8 @@ class RAModel
     X_TREE,
     HILBERT_R_TREE,
     R_PLUS_TREE,
-    R_PLUS_PLUS_TREE
+    R_PLUS_PLUS_TREE,
+    UB_TREE
   };
 
  private:
@@ -82,6 +83,8 @@ class RAModel
   RAType<tree::RPlusTree>* rPlusTreeRA;
   //! Non-NULL if the R++ tree is used.
   RAType<tree::RPlusPlusTree>* rPlusPlusTreeRA;
+  //! Non-NULL if the UB tree is used.
+  RAType<tree::UBTree>* ubTreeRA;
 
  public:
   /**
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
index f096540..30452ef 100644
--- a/src/mlpack/methods/rann/ra_model_impl.hpp
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -25,7 +25,8 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
     xTreeRA(NULL),
     hilbertRTreeRA(NULL),
     rPlusTreeRA(NULL),
-    rPlusPlusTreeRA(NULL)
+    rPlusPlusTreeRA(NULL),
+    ubTreeRA(NULL)
 {
   // Nothing to do.
 }
@@ -49,6 +50,8 @@ RAModel<SortPolicy>::~RAModel()
     delete rPlusTreeRA;
   if (rPlusPlusTreeRA)
     delete rPlusPlusTreeRA;
+  if (ubTreeRA)
+    delete ubTreeRA;
 }
 
 template<typename SortPolicy>
@@ -79,6 +82,8 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
       delete rPlusTreeRA;
     if (rPlusPlusTreeRA)
       delete rPlusPlusTreeRA;
+    if (ubTreeRA)
+      delete ubTreeRA;
 
     // Set all the pointers to NULL.
     kdTreeRA = NULL;
@@ -89,6 +94,7 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
     hilbertRTreeRA = NULL;
     rPlusPlusTreeRA = NULL;
     rPlusTreeRA = NULL;
+    ubTreeRA = NULL;
   }
 
   // We only need to serialize one of the kRANN objects.
@@ -118,6 +124,9 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
     case R_PLUS_PLUS_TREE:
       ar & data::CreateNVP(rPlusPlusTreeRA, "ra_model");
       break;
+    case UB_TREE:
+      ar & data::CreateNVP(ubTreeRA, "ra_model");
+      break;
   }
 }
 
@@ -140,6 +149,8 @@ const arma::mat& RAModel<SortPolicy>::Dataset() const
     return rPlusTreeRA->ReferenceSet();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->ReferenceSet();
+  else if (ubTreeRA)
+    return ubTreeRA->ReferenceSet();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -164,6 +175,8 @@ bool RAModel<SortPolicy>::Naive() const
     return rPlusTreeRA->Naive();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Naive();
+  else if (ubTreeRA)
+    return ubTreeRA->Naive();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -188,6 +201,8 @@ bool& RAModel<SortPolicy>::Naive()
     return rPlusTreeRA->Naive();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Naive();
+  else if (ubTreeRA)
+    return ubTreeRA->Naive();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -212,6 +227,8 @@ bool RAModel<SortPolicy>::SingleMode() const
     return rPlusTreeRA->SingleMode();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SingleMode();
+  else if (ubTreeRA)
+    return ubTreeRA->SingleMode();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -236,6 +253,8 @@ bool& RAModel<SortPolicy>::SingleMode()
     return rPlusTreeRA->SingleMode();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SingleMode();
+  else if (ubTreeRA)
+    return ubTreeRA->SingleMode();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -260,6 +279,8 @@ double RAModel<SortPolicy>::Tau() const
     return rPlusTreeRA->Tau();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Tau();
+  else if (ubTreeRA)
+    return ubTreeRA->Tau();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -284,6 +305,8 @@ double& RAModel<SortPolicy>::Tau()
     return rPlusTreeRA->Tau();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Tau();
+  else if (ubTreeRA)
+    return ubTreeRA->Tau();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -308,6 +331,8 @@ double RAModel<SortPolicy>::Alpha() const
     return rPlusTreeRA->Alpha();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Alpha();
+  else if (ubTreeRA)
+    return ubTreeRA->Alpha();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -332,6 +357,8 @@ double& RAModel<SortPolicy>::Alpha()
     return rPlusTreeRA->Alpha();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->Alpha();
+  else if (ubTreeRA)
+    return ubTreeRA->Alpha();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -356,6 +383,8 @@ bool RAModel<SortPolicy>::SampleAtLeaves() const
     return rPlusTreeRA->SampleAtLeaves();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SampleAtLeaves();
+  else if (ubTreeRA)
+    return ubTreeRA->SampleAtLeaves();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -380,6 +409,8 @@ bool& RAModel<SortPolicy>::SampleAtLeaves()
     return rPlusTreeRA->SampleAtLeaves();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SampleAtLeaves();
+  else if (ubTreeRA)
+    return ubTreeRA->SampleAtLeaves();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -404,6 +435,8 @@ bool RAModel<SortPolicy>::FirstLeafExact() const
     return rPlusTreeRA->FirstLeafExact();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->FirstLeafExact();
+  else if (ubTreeRA)
+    return ubTreeRA->FirstLeafExact();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -428,6 +461,8 @@ bool& RAModel<SortPolicy>::FirstLeafExact()
     return rPlusTreeRA->FirstLeafExact();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->FirstLeafExact();
+  else if (ubTreeRA)
+    return ubTreeRA->FirstLeafExact();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -452,6 +487,8 @@ size_t RAModel<SortPolicy>::SingleSampleLimit() const
     return rPlusTreeRA->SingleSampleLimit();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SingleSampleLimit();
+  else if (ubTreeRA)
+    return ubTreeRA->SingleSampleLimit();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -476,6 +513,8 @@ size_t& RAModel<SortPolicy>::SingleSampleLimit()
     return rPlusTreeRA->SingleSampleLimit();
   else if (rPlusPlusTreeRA)
     return rPlusPlusTreeRA->SingleSampleLimit();
+  else if (ubTreeRA)
+    return ubTreeRA->SingleSampleLimit();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -547,6 +586,8 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
     delete rPlusTreeRA;
   if (rPlusPlusTreeRA)
     delete rPlusPlusTreeRA;
+  if (ubTreeRA)
+    delete ubTreeRA;
 
   if (randomBasis)
     referenceSet = q * referenceSet;
@@ -607,6 +648,10 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
       rPlusPlusTreeRA = new RAType<tree::RPlusPlusTree>(std::move(referenceSet),
           naive, singleMode);
       break;
+    case UB_TREE:
+      ubTreeRA = new RAType<tree::UBTree>(std::move(referenceSet),
+          naive, singleMode);
+      break;
   }
 
   if (!naive)
@@ -696,6 +741,10 @@ void RAModel<SortPolicy>::Search(arma::mat&& querySet,
       // No mapping necessary.
       rPlusPlusTreeRA->Search(querySet, k, neighbors, distances);
       break;
+    case UB_TREE:
+      // No mapping necessary.
+      ubTreeRA->Search(querySet, k, neighbors, distances);
+      break;
   }
 }
 
@@ -739,6 +788,9 @@ void RAModel<SortPolicy>::Search(const size_t k,
     case R_PLUS_PLUS_TREE:
       rPlusPlusTreeRA->Search(k, neighbors, distances);
       break;
+    case UB_TREE:
+      ubTreeRA->Search(k, neighbors, distances);
+      break;
   }
 }
 
@@ -763,6 +815,8 @@ std::string RAModel<SortPolicy>::TreeName() const
       return "R+ tree";
     case R_PLUS_PLUS_TREE:
       return "R++ tree";
+    case UB_TREE:
+      return "UB tree";
     default:
       return "unknown tree";
   }
diff --git a/src/mlpack/tests/krann_search_test.cpp b/src/mlpack/tests/krann_search_test.cpp
index 34f87d7..3d3f918 100644
--- a/src/mlpack/tests/krann_search_test.cpp
+++ b/src/mlpack/tests/krann_search_test.cpp
@@ -625,7 +625,7 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
   // Build all the possible models.
-  KNNModel models[16];
+  KNNModel models[18];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
@@ -642,13 +642,15 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
   models[13] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true);
   models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
   models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
+  models[16] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+  models[17] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
 
   arma::Mat<size_t> qrRanks;
   data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
 
   for (size_t j = 0; j < 3; ++j)
   {
-    for (size_t i = 0; i < 16; ++i)
+    for (size_t i = 0; i < 18; ++i)
     {
       // We only have std::move() constructors so make a copy of our data.
       arma::mat referenceCopy(referenceData);




More information about the mlpack-git mailing list