[mlpack-git] master: Add RAModel. (26e8560)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 9 14:37:20 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cec4ac427536cbd9738a33e0c6facabeeadd31b0...4a39d474593067343b4972d4a5217bcfae84ca5d

>---------------------------------------------------------------

commit 26e85605312127537eb47e5484981776ade10b9b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Dec 9 09:39:02 2015 -0500

    Add RAModel.


>---------------------------------------------------------------

26e85605312127537eb47e5484981776ade10b9b
 src/mlpack/methods/rann/CMakeLists.txt    |   4 +
 src/mlpack/methods/rann/ra_model.hpp      | 172 +++++++++
 src/mlpack/methods/rann/ra_model_impl.hpp | 555 ++++++++++++++++++++++++++++++
 src/mlpack/methods/rann/ra_search.hpp     |   7 +
 4 files changed, 738 insertions(+)

diff --git a/src/mlpack/methods/rann/CMakeLists.txt b/src/mlpack/methods/rann/CMakeLists.txt
index 30e1a89..f325954 100644
--- a/src/mlpack/methods/rann/CMakeLists.txt
+++ b/src/mlpack/methods/rann/CMakeLists.txt
@@ -19,6 +19,10 @@ set(SOURCES
   # utilities
   ra_util.hpp
   ra_util.cpp
+
+  # model
+  ra_model.hpp
+  ra_model_impl.hpp
 )
 
 # add directory name to sources
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
new file mode 100644
index 0000000..26af54b
--- /dev/null
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -0,0 +1,172 @@
+/**
+ * @file ra_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for rank-approximate nearest neighbor search.  It provides an
+ * easy way to serialize a rank-approximate neighbor search model by abstracting
+ * the types of trees and reflecting the RASearch API.
+ */
+#ifndef __MLPACK_METHODS_RANN_RA_MODEL_HPP
+#define __MLPACK_METHODS_RANN_RA_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "ra_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The RAModel class provides an abstraction for the RASearch class, abstracting
+ * away the TreeType parameter and allowing it to be specified at runtime in
+ * this class.  This class is written for the sake of the 'allkrann' program,
+ * but is not necessarily restricted to that use.
+ *
+ * The model owns the RASearch object it creates (one pointer per supported
+ * tree type; BuildModel() and Serialize() only allocate the one matching the
+ * current tree type).  Because it owns raw pointers, RAModel is not copyable:
+ * a member-wise copy would make two models delete the same RASearch object.
+ *
+ * Accessors that forward to the underlying RASearch object throw
+ * std::runtime_error if no RASearch object has been created yet.
+ *
+ * @param SortPolicy Sorting policy for neighbor searching (see RASearch).
+ */
+template<typename SortPolicy>
+class RAModel
+{
+ public:
+  /**
+   * The list of tree types we can use with RASearch.  Does not include ball
+   * trees; see #338.
+   */
+  enum TreeTypes
+  {
+    KD_TREE,
+    COVER_TREE,
+    R_TREE,
+    R_STAR_TREE
+  };
+
+ private:
+  //! The type of tree being used (one of the TreeTypes values).
+  int treeType;
+  //! The leaf size of the tree being used (useful only for the kd-tree).
+  size_t leafSize;
+
+  //! If true, randomly project into a new basis.
+  bool randomBasis;
+  //! The basis to project into (only meaningful when randomBasis is true).
+  arma::mat q;
+
+  //! Typedef the RASearch class we'll use.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using RAType = RASearch<SortPolicy,
+                          metric::EuclideanDistance,
+                          arma::mat,
+                          TreeType>;
+
+  //! Non-NULL if the kd-tree is used.
+  RAType<tree::KDTree>* kdTreeRA;
+  //! Non-NULL if the cover tree is used.
+  RAType<tree::StandardCoverTree>* coverTreeRA;
+  //! Non-NULL if the R tree is used.
+  RAType<tree::RTree>* rTreeRA;
+  //! Non-NULL if the R* tree is used.
+  RAType<tree::RStarTree>* rStarTreeRA;
+
+ public:
+  /**
+   * Initialize the RAModel with the given type and whether or not a random
+   * basis should be used.  No search model is built until BuildModel() is
+   * called or a model is loaded with Serialize().
+   */
+  RAModel(int treeType = TreeTypes::KD_TREE, bool randomBasis = false);
+
+  //! RAModel owns raw pointers, so copying would lead to a double delete.
+  RAModel(const RAModel& other) = delete;
+  //! RAModel owns raw pointers, so copying would lead to a double delete.
+  RAModel& operator=(const RAModel& other) = delete;
+
+  //! Clean memory, if necessary.
+  ~RAModel();
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Expose the dataset.
+  const arma::mat& Dataset() const;
+
+  //! Get whether or not single-tree search is being used.
+  bool SingleMode() const;
+  //! Modify whether or not single-tree search is being used.
+  bool& SingleMode();
+
+  //! Get whether or not naive search is being used.
+  bool Naive() const;
+  //! Modify whether or not naive search is being used.
+  bool& Naive();
+
+  //! Get the rank-approximation in percentile of the data.
+  double Tau() const;
+  //! Modify the rank-approximation in percentile of the data.
+  double& Tau();
+
+  //! Get the desired success probability.
+  double Alpha() const;
+  //! Modify the desired success probability.
+  double& Alpha();
+
+  //! Get whether or not sampling is done at the leaves.
+  bool SampleAtLeaves() const;
+  //! Modify whether or not sampling is done at the leaves.
+  bool& SampleAtLeaves();
+
+  //! Get whether or not we traverse to the first leaf without approximation.
+  bool FirstLeafExact() const;
+  //! Modify whether or not we traverse to the first leaf without approximation.
+  bool& FirstLeafExact();
+
+  //! Get the limit on the size of a node that can be approximated.
+  size_t SingleSampleLimit() const;
+  //! Modify the limit on the size of a node that can be approximated.
+  size_t& SingleSampleLimit();
+
+  //! Get the leaf size (only relevant when the kd-tree is used).
+  size_t LeafSize() const;
+  //! Modify the leaf size (only relevant when the kd-tree is used).
+  size_t& LeafSize();
+
+  //! Get the type of tree being used.
+  int TreeType() const;
+  //! Modify the type of tree being used.  Search() dispatches on this value,
+  //! so rebuild the model using BuildModel() after changing it.
+  int& TreeType();
+
+  //! Get whether or not a random basis is being used.
+  bool RandomBasis() const;
+  //! Modify whether or not a random basis is being used.  Be sure to rebuild
+  //! the model using BuildModel().
+  bool& RandomBasis();
+
+  /**
+   * Build the search model for the current tree type, taking ownership of the
+   * reference set.  Any previously built model is discarded.  If a random
+   * basis is in use, a new basis is generated and the reference set is
+   * projected into it before the tree is built.
+   *
+   * @param referenceSet Reference dataset (moved into the model).
+   * @param leafSize Leaf size for the kd-tree; ignored by other tree types.
+   * @param naive If true, use brute-force search (no tree is built).
+   * @param singleMode If true, use single-tree instead of dual-tree search.
+   */
+  void BuildModel(arma::mat&& referenceSet,
+                  const size_t leafSize,
+                  const bool naive,
+                  const bool singleMode);
+
+  /**
+   * Perform rank-approximate neighbor search, taking ownership of the query
+   * set.  The query set is projected into the random basis if one is in use.
+   *
+   * @param querySet Query points (moved into this call).
+   * @param k Number of neighbors to find for each query point.
+   * @param neighbors Output matrix of neighbor indices.
+   * @param distances Output matrix of neighbor distances.
+   */
+  void Search(arma::mat&& querySet,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  /**
+   * Perform rank-approximate neighbor search, using the reference set as the
+   * query set.
+   *
+   * @param k Number of neighbors to find for each reference point.
+   * @param neighbors Output matrix of neighbor indices.
+   * @param distances Output matrix of neighbor distances.
+   */
+  void Search(const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  //! Get the name of the tree type.
+  std::string TreeName() const;
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+#include "ra_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
new file mode 100644
index 0000000..56d2156
--- /dev/null
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -0,0 +1,555 @@
+/**
+ * @file ra_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the RAModel class.
+ */
+#ifndef __MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ra_model.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Construct an empty model of the given tree type.  No RASearch object is
+ * allocated here; all pointers start out NULL.
+ *
+ * leafSize is given a defined value (20, a conventional kd-tree leaf size) so
+ * that LeafSize() and Serialize() never read an uninitialized member before a
+ * model has been built.
+ */
+template<typename SortPolicy>
+RAModel<SortPolicy>::RAModel(const int treeType, const bool randomBasis) :
+    treeType(treeType),
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeRA(NULL),
+    coverTreeRA(NULL),
+    rTreeRA(NULL),
+    rStarTreeRA(NULL)
+{
+  // Nothing to do.
+}
+
+/**
+ * Release every RASearch object this model owns.  Applying delete to a NULL
+ * pointer has no effect, so the unused pointers need no special handling.
+ */
+template<typename SortPolicy>
+RAModel<SortPolicy>::~RAModel()
+{
+  delete kdTreeRA;
+  delete coverTreeRA;
+  delete rTreeRA;
+  delete rStarTreeRA;
+}
+
+/**
+ * Save or load the model.  Only the RASearch object that matches treeType is
+ * written; when loading, any model already held by this object is discarded
+ * first.
+ */
+template<typename SortPolicy>
+template<typename Archive>
+void RAModel<SortPolicy>::Serialize(Archive& ar,
+                                    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(treeType, "treeType");
+  // Search() needs the leaf size to build kd-tree query trees, so it has to
+  // survive a save/load round trip.
+  ar & data::CreateNVP(leafSize, "leafSize");
+  ar & data::CreateNVP(randomBasis, "randomBasis");
+  ar & data::CreateNVP(q, "q");
+
+  // Loading into a model that already holds a search object is possible, so
+  // free it (delete on NULL is a no-op) and reset every pointer; the accessors
+  // test the pointers in order and must not find a stale one.
+  if (Archive::is_loading::value)
+  {
+    delete kdTreeRA;
+    delete coverTreeRA;
+    delete rTreeRA;
+    delete rStarTreeRA;
+
+    kdTreeRA = NULL;
+    coverTreeRA = NULL;
+    rTreeRA = NULL;
+    rStarTreeRA = NULL;
+  }
+
+  // We only need to serialize one of the kRANN objects.
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & data::CreateNVP(kdTreeRA, "ra_model");
+      break;
+    case COVER_TREE:
+      ar & data::CreateNVP(coverTreeRA, "ra_model");
+      break;
+    case R_TREE:
+      ar & data::CreateNVP(rTreeRA, "ra_model");
+      break;
+    case R_STAR_TREE:
+      ar & data::CreateNVP(rStarTreeRA, "ra_model");
+      break;
+  }
+}
+
+//! Return the reference set of whichever RASearch object exists (checked in
+//! the order kd, cover, R, R*); throws if none has been created.
+template<typename SortPolicy>
+const arma::mat& RAModel<SortPolicy>::Dataset() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->ReferenceSet();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->ReferenceSet();
+  if (rTreeRA != NULL)
+    return rTreeRA->ReferenceSet();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->ReferenceSet();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read the naive-search flag of whichever RASearch object exists (checked in
+//! the order kd, cover, R, R*); throws if none has been created.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::Naive() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Naive();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Naive();
+  if (rTreeRA != NULL)
+    return rTreeRA->Naive();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Naive();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to the naive-search flag of whichever RASearch
+//! object exists (checked in the order kd, cover, R, R*); throws if none has
+//! been created.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::Naive()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Naive();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Naive();
+  if (rTreeRA != NULL)
+    return rTreeRA->Naive();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Naive();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read the single-tree flag of whichever RASearch object exists (checked in
+//! the order kd, cover, R, R*); throws if none has been created.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::SingleMode() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SingleMode();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SingleMode();
+  if (rTreeRA != NULL)
+    return rTreeRA->SingleMode();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SingleMode();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to the single-tree flag of whichever RASearch
+//! object exists (checked in the order kd, cover, R, R*); throws if none has
+//! been created.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::SingleMode()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SingleMode();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SingleMode();
+  if (rTreeRA != NULL)
+    return rTreeRA->SingleMode();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SingleMode();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read tau (rank-approximation percentile) from whichever RASearch object
+//! exists (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+double RAModel<SortPolicy>::Tau() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Tau();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Tau();
+  if (rTreeRA != NULL)
+    return rTreeRA->Tau();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Tau();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to tau of whichever RASearch object exists
+//! (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+double& RAModel<SortPolicy>::Tau()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Tau();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Tau();
+  if (rTreeRA != NULL)
+    return rTreeRA->Tau();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Tau();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read alpha (desired success probability) from whichever RASearch object
+//! exists (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+double RAModel<SortPolicy>::Alpha() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Alpha();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Alpha();
+  if (rTreeRA != NULL)
+    return rTreeRA->Alpha();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Alpha();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to alpha of whichever RASearch object exists
+//! (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+double& RAModel<SortPolicy>::Alpha()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->Alpha();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->Alpha();
+  if (rTreeRA != NULL)
+    return rTreeRA->Alpha();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->Alpha();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read the sample-at-leaves flag of whichever RASearch object exists
+//! (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::SampleAtLeaves() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SampleAtLeaves();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SampleAtLeaves();
+  if (rTreeRA != NULL)
+    return rTreeRA->SampleAtLeaves();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SampleAtLeaves();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to the sample-at-leaves flag of whichever
+//! RASearch object exists (checked in the order kd, cover, R, R*); throws if
+//! none exists.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::SampleAtLeaves()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SampleAtLeaves();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SampleAtLeaves();
+  if (rTreeRA != NULL)
+    return rTreeRA->SampleAtLeaves();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SampleAtLeaves();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read the first-leaf-exact flag of whichever RASearch object exists
+//! (checked in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::FirstLeafExact() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->FirstLeafExact();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->FirstLeafExact();
+  if (rTreeRA != NULL)
+    return rTreeRA->FirstLeafExact();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->FirstLeafExact();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to the first-leaf-exact flag of whichever
+//! RASearch object exists (checked in the order kd, cover, R, R*); throws if
+//! none exists.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::FirstLeafExact()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->FirstLeafExact();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->FirstLeafExact();
+  if (rTreeRA != NULL)
+    return rTreeRA->FirstLeafExact();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->FirstLeafExact();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Read the single-sample limit of whichever RASearch object exists (checked
+//! in the order kd, cover, R, R*); throws if none exists.
+template<typename SortPolicy>
+size_t RAModel<SortPolicy>::SingleSampleLimit() const
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SingleSampleLimit();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SingleSampleLimit();
+  if (rTreeRA != NULL)
+    return rTreeRA->SingleSampleLimit();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SingleSampleLimit();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return a writable reference to the single-sample limit of whichever
+//! RASearch object exists (checked in the order kd, cover, R, R*); throws if
+//! none exists.
+template<typename SortPolicy>
+size_t& RAModel<SortPolicy>::SingleSampleLimit()
+{
+  if (kdTreeRA != NULL)
+    return kdTreeRA->SingleSampleLimit();
+  if (coverTreeRA != NULL)
+    return coverTreeRA->SingleSampleLimit();
+  if (rTreeRA != NULL)
+    return rTreeRA->SingleSampleLimit();
+  if (rStarTreeRA != NULL)
+    return rStarTreeRA->SingleSampleLimit();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search model "
+      "initialized");
+}
+
+//! Return (a copy of) the stored leaf size.
+template<typename SortPolicy>
+size_t RAModel<SortPolicy>::LeafSize() const
+{
+  return this->leafSize;
+}
+
+//! Return a writable reference to the stored leaf size.
+template<typename SortPolicy>
+size_t& RAModel<SortPolicy>::LeafSize()
+{
+  return this->leafSize;
+}
+
+//! Return (a copy of) the tree type identifier.
+template<typename SortPolicy>
+int RAModel<SortPolicy>::TreeType() const
+{
+  return this->treeType;
+}
+
+//! Return a writable reference to the tree type identifier.
+template<typename SortPolicy>
+int& RAModel<SortPolicy>::TreeType()
+{
+  return this->treeType;
+}
+
+//! Return (a copy of) the random-basis flag.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::RandomBasis() const
+{
+  return this->randomBasis;
+}
+
+//! Return a writable reference to the random-basis flag.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::RandomBasis()
+{
+  return this->randomBasis;
+}
+
+/**
+ * Build a new search model of the current tree type from referenceSet (which
+ * is moved into the model), discarding any previous model.
+ *
+ * @param referenceSet Reference data; projected into a fresh random basis
+ *     first when randomBasis is set.
+ * @param leafSize Leaf size for the kd-tree; also remembered for building
+ *     query trees in Search().
+ * @param naive If true, no tree is built and brute-force search is used.
+ * @param singleMode If true, single-tree search is used.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
+                                     const size_t leafSize,
+                                     const bool naive,
+                                     const bool singleMode)
+{
+  // The parameter shadows the member; store it so that LeafSize() reports it
+  // and Search() builds kd-tree query trees with the same leaf size.
+  this->leafSize = leafSize;
+
+  // Initialize random basis, if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << std::endl;
+    math::RandomBasis(q, referenceSet.n_rows);
+  }
+
+  // Clean memory.  Only the pointer matching treeType is reallocated below, so
+  // every pointer must also be reset: otherwise, if the tree type changed since
+  // the last build, the accessors would use a freed object and the destructor
+  // would delete it a second time.
+  delete kdTreeRA;
+  delete coverTreeRA;
+  delete rTreeRA;
+  delete rStarTreeRA;
+  kdTreeRA = NULL;
+  coverTreeRA = NULL;
+  rTreeRA = NULL;
+  rStarTreeRA = NULL;
+
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << std::endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // Build tree, if necessary.
+      if (naive)
+      {
+        kdTreeRA = new RAType<tree::KDTree>(std::move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        std::vector<size_t> oldFromNewReferences;
+        typename RAType<tree::KDTree>::Tree* kdTree =
+            new typename RAType<tree::KDTree>::Tree(std::move(referenceSet),
+            oldFromNewReferences, leafSize);
+        kdTreeRA = new RAType<tree::KDTree>(kdTree, singleMode);
+
+        // Give the model ownership of the tree.
+        kdTreeRA->treeOwner = true;
+        kdTreeRA->oldFromNewReferences = oldFromNewReferences;
+      }
+      break;
+    case COVER_TREE:
+      coverTreeRA = new RAType<tree::StandardCoverTree>(std::move(referenceSet),
+          naive, singleMode);
+      break;
+    case R_TREE:
+      rTreeRA = new RAType<tree::RTree>(std::move(referenceSet), naive,
+          singleMode);
+      break;
+    case R_STAR_TREE:
+      rStarTreeRA = new RAType<tree::RStarTree>(std::move(referenceSet), naive,
+          singleMode);
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << std::endl;
+  }
+}
+
+/**
+ * Search for the k rank-approximate nearest neighbors of each point in
+ * querySet (moved into this call).  The query set is projected into the random
+ * basis when one is in use.  For dual-tree kd-tree search a query tree is built
+ * here, and the results are mapped back to the original query ordering.
+ * Throws std::runtime_error (from Naive()) if no model has been built.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::Search(arma::mat&& querySet,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Apply the random basis if necessary.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  Log::Info << "Searching for " << k << " approximate nearest neighbors with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree rank-approximate " << TreeName() << " search...";
+  else if (!Naive())
+    Log::Info << "single-tree rank-approximate " << TreeName() << " search...";
+  else
+    Log::Info << "brute-force (naive) rank-approximate search...";
+  // Terminate the line; otherwise the next message (e.g. "Building query
+  // tree...") is appended to this one.
+  Log::Info << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      if (!kdTreeRA->Naive() && !kdTreeRA->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << std::endl;
+        std::vector<size_t> oldFromNewQueries;
+        typename RAType<tree::KDTree>::Tree queryTree(std::move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << std::endl;
+        Timer::Stop("tree_building");
+
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        kdTreeRA->Search(&queryTree, k, neighborsOut, distancesOut);
+
+        // Unmap the query points: column i of the tree results belongs to
+        // original query oldFromNewQueries[i].
+        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+        {
+          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        kdTreeRA->Search(querySet, k, neighbors, distances);
+      }
+      break;
+    case COVER_TREE:
+      // No mapping necessary.
+      coverTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+    case R_TREE:
+      // No mapping necessary.
+      rTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      // No mapping necessary.
+      rStarTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+  }
+}
+
+/**
+ * Search for the k rank-approximate nearest neighbors of every reference
+ * point, using the reference set as the query set.  Throws std::runtime_error
+ * (from Naive()) if no model has been built.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::Search(const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  Log::Info << "Searching for " << k << " approximate nearest neighbors with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree rank-approximate " << TreeName() << " search...";
+  else if (!Naive())
+    Log::Info << "single-tree rank-approximate " << TreeName() << " search...";
+  else
+    Log::Info << "brute-force (naive) rank-approximate search...";
+  // Terminate the line so later log output starts on a fresh line.
+  Log::Info << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      kdTreeRA->Search(k, neighbors, distances);
+      break;
+    case COVER_TREE:
+      coverTreeRA->Search(k, neighbors, distances);
+      break;
+    case R_TREE:
+      rTreeRA->Search(k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      rStarTreeRA->Search(k, neighbors, distances);
+      break;
+  }
+}
+
+//! Map the stored tree type to a human-readable name; values outside
+//! TreeTypes yield "unknown tree".
+template<typename SortPolicy>
+std::string RAModel<SortPolicy>::TreeName() const
+{
+  if (treeType == KD_TREE)
+    return "kd-tree";
+  if (treeType == COVER_TREE)
+    return "cover tree";
+  if (treeType == R_TREE)
+    return "R tree";
+  if (treeType == R_STAR_TREE)
+    return "R* tree";
+
+  return "unknown tree";
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 9a4084b..feaf4b5 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -31,6 +31,10 @@
 namespace mlpack {
 namespace neighbor {
 
+// Forward declaration of RAModel, which RASearch names as a friend so that the
+// model can adjust RASearch's internal members.
+template<class SortPolicy>
+class RAModel;
+
 /**
  * The RASearch class: This class provides a generic manner to perform
  * rank-approximate search via random-sampling. If the 'naive' option is chosen,
@@ -444,6 +448,9 @@ class RASearch
 
   //! Instantiation of kernel.
   MetricType metric;
+
+  //! RAModel can modify internal members as necessary.
+  friend class RAModel<SortPolicy>;
 }; // class RASearch
 
 } // namespace neighbor



More information about the mlpack-git mailing list