[mlpack-git] master: Add ball tree support to NSModel. (abce060)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Oct 20 09:47:56 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/67e0a132c7f62820c734eb508fe1bc83128a3e13...00eccfdb0d315de3d94bfa1da84cc1dc65c8af39

>---------------------------------------------------------------

commit abce06073c625f8fe48d0caad26e4ca658c2a39d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 19 20:57:18 2015 +0000

    Add ball tree support to NSModel.


>---------------------------------------------------------------

abce06073c625f8fe48d0caad26e4ca658c2a39d
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  4 +-
 .../methods/neighbor_search/ns_model_impl.hpp      | 80 +++++++++++++++++++++-
 src/mlpack/tests/allknn_test.cpp                   | 12 ++--
 3 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 6412922..6242377 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -46,7 +46,8 @@ class NSModel
     KD_TREE,
     COVER_TREE,
     R_TREE,
-    R_STAR_TREE
+    R_STAR_TREE,
+    BALL_TREE
   };
 
  private:
@@ -70,6 +71,7 @@ class NSModel
   NSType<tree::StandardCoverTree>* coverTreeNS;
   NSType<tree::RTree>* rTreeNS;
   NSType<tree::RStarTree>* rStarTreeNS;
+  NSType<tree::BallTree>* ballTreeNS;
 
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1f9ba6f..5eb5511 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -27,7 +27,8 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
     kdTreeNS(NULL),
     coverTreeNS(NULL),
     rTreeNS(NULL),
-    rStarTreeNS(NULL)
+    rStarTreeNS(NULL),
+    ballTreeNS(NULL)
 {
   // Nothing to do.
 }
@@ -44,6 +45,8 @@ NSModel<SortPolicy>::~NSModel()
     delete rTreeNS;
   if (rStarTreeNS)
     delete rStarTreeNS;
+  if (ballTreeNS)
+    delete ballTreeNS;
 }
 
 //! Serialize the kNN model.
@@ -67,6 +70,8 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
       delete rTreeNS;
     if (rStarTreeNS)
       delete rStarTreeNS;
+    if (ballTreeNS)
+      delete ballTreeNS;
 
     // Set all the pointers to NULL.
     kdTreeNS = NULL;
@@ -91,6 +96,9 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
     case R_STAR_TREE:
       ar & data::CreateNVP(rStarTreeNS, name);
       break;
+    case BALL_TREE:
+      ar & data::CreateNVP(ballTreeNS, name);
+      break;
   }
 }
 
@@ -105,6 +113,8 @@ const arma::mat& NSModel<SortPolicy>::Dataset() const
     return rTreeNS->ReferenceSet();
   else if (rStarTreeNS)
     return rStarTreeNS->ReferenceSet();
+  else if (ballTreeNS)
+    return ballTreeNS->ReferenceSet();
 
   throw std::runtime_error("no neighbor search model initialized");
 }
@@ -121,6 +131,8 @@ bool NSModel<SortPolicy>::SingleMode() const
     return rTreeNS->SingleMode();
   else if (rStarTreeNS)
     return rStarTreeNS->SingleMode();
+  else if (ballTreeNS)
+    return ballTreeNS->SingleMode();
 
   throw std::runtime_error("no neighbor search model initialized");
 }
@@ -136,6 +148,8 @@ bool& NSModel<SortPolicy>::SingleMode()
     return rTreeNS->SingleMode();
   else if (rStarTreeNS)
     return rStarTreeNS->SingleMode();
+  else if (ballTreeNS)
+    return ballTreeNS->SingleMode();
 
   throw std::runtime_error("no neighbor search model initialized");
 }
@@ -151,6 +165,8 @@ bool NSModel<SortPolicy>::Naive() const
     return rTreeNS->Naive();
   else if (rStarTreeNS)
     return rStarTreeNS->Naive();
+  else if (ballTreeNS)
+    return ballTreeNS->Naive();
 
   throw std::runtime_error("no neighbor search model initialized");
 }
@@ -166,6 +182,8 @@ bool& NSModel<SortPolicy>::Naive()
     return rTreeNS->Naive();
   else if (rStarTreeNS)
     return rStarTreeNS->Naive();
+  else if (ballTreeNS)
+    return ballTreeNS->Naive();
 
   throw std::runtime_error("no neighbor search model initialized");
 }
@@ -218,6 +236,8 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
     delete rTreeNS;
   if (rStarTreeNS)
     delete rStarTreeNS;
+  if (ballTreeNS)
+    delete ballTreeNS;
 
   // Do we need to modify the reference set?
   if (randomBasis)
@@ -267,6 +287,27 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
       rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
           singleMode);
       break;
+    case BALL_TREE:
+      // Naive mode needs no tree; otherwise build the ball tree ourselves.
+      if (naive)
+      {
+        ballTreeNS = new NSType<tree::BallTree>(std::move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        std::vector<size_t> oldFromNewReferences;
+        typename NSType<tree::BallTree>::Tree* ballTree =
+            new typename NSType<tree::BallTree>::Tree(std::move(referenceSet),
+            oldFromNewReferences, leafSize);
+        ballTreeNS = new NSType<tree::BallTree>(ballTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        ballTreeNS->treeOwner = true;
+        ballTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
+      }
+
+      break;
   }
 
   if (!naive)
@@ -340,6 +381,38 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
       // No mapping necessary.
       rStarTreeNS->Search(querySet, k, neighbors, distances);
       break;
+    case BALL_TREE:
+      if (!ballTreeNS->Naive() && !ballTreeNS->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << std::endl;
+        std::vector<size_t> oldFromNewQueries;
+        typename NSType<tree::BallTree>::Tree queryTree(std::move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << std::endl;
+        Timer::Stop("tree_building");
+
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        ballTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
+
+        // Unmap the query points.
+        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+        {
+          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        ballTreeNS->Search(querySet, k, neighbors, distances);
+      }
+
+      break;
   }
 }
 
@@ -371,6 +444,9 @@ void NSModel<SortPolicy>::Search(const size_t k,
     case R_STAR_TREE:
       rStarTreeNS->Search(k, neighbors, distances);
       break;
+    case BALL_TREE:
+      ballTreeNS->Search(k, neighbors, distances);
+      break;
   }
 }
 
@@ -388,6 +464,8 @@ void NSModel<SortPolicy>::TreeName() const
       return "R tree";
     case R_STAR_TREE:
       return "R* tree";
+    case BALL_TREE:
+      return "ball tree";
     default:
       return "unknown tree";
   }
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 150a765..82b8564 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -975,7 +975,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  KNNModel models[8];
+  KNNModel models[10];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -984,6 +984,8 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
   models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
   models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
   models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = KNNModel(KNNModel::TreeTypes::BALL_TREE, true);
+  models[9] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -993,7 +995,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
     arma::mat baselineDistances;
     knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
 
-    for (size_t i = 0; i < 8; ++i)
+    for (size_t i = 0; i < 10; ++i)
     {
       // We only have std::move() constructors so make a copy of our data.
       arma::mat referenceCopy(referenceData);
@@ -1037,7 +1039,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  KNNModel models[8];
+  KNNModel models[10];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1046,6 +1048,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
   models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
   models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
   models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = KNNModel(KNNModel::TreeTypes::BALL_TREE, true);
+  models[9] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1055,7 +1059,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
     arma::mat baselineDistances;
     knn.Search(3, baselineNeighbors, baselineDistances);
 
-    for (size_t i = 0; i < 8; ++i)
+    for (size_t i = 0; i < 10; ++i)
     {
       // We only have a std::move() constructor... so copy the data.
       arma::mat referenceCopy(referenceData);



More information about the mlpack-git mailing list