[mlpack-git] master: Add RAModel. (26e8560)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 9 14:37:20 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cec4ac427536cbd9738a33e0c6facabeeadd31b0...4a39d474593067343b4972d4a5217bcfae84ca5d
>---------------------------------------------------------------
commit 26e85605312127537eb47e5484981776ade10b9b
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Dec 9 09:39:02 2015 -0500
Add RAModel.
>---------------------------------------------------------------
26e85605312127537eb47e5484981776ade10b9b
src/mlpack/methods/rann/CMakeLists.txt | 4 +
src/mlpack/methods/rann/ra_model.hpp | 172 +++++++++
src/mlpack/methods/rann/ra_model_impl.hpp | 555 ++++++++++++++++++++++++++++++
src/mlpack/methods/rann/ra_search.hpp | 7 +
4 files changed, 738 insertions(+)
diff --git a/src/mlpack/methods/rann/CMakeLists.txt b/src/mlpack/methods/rann/CMakeLists.txt
index 30e1a89..f325954 100644
--- a/src/mlpack/methods/rann/CMakeLists.txt
+++ b/src/mlpack/methods/rann/CMakeLists.txt
@@ -19,6 +19,10 @@ set(SOURCES
# utilities
ra_util.hpp
ra_util.cpp
+
+ # model
+ ra_model.hpp
+ ra_model_impl.hpp
)
# add directory name to sources
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
new file mode 100644
index 0000000..26af54b
--- /dev/null
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -0,0 +1,172 @@
+/**
+ * @file ra_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for rank-approximate nearest neighbor search. It provides an
+ * easy way to serialize a rank-approximate neighbor search model by abstracting
+ * the types of trees and reflecting the RASearch API.
+ */
+#ifndef __MLPACK_METHODS_RANN_RA_MODEL_HPP
+#define __MLPACK_METHODS_RANN_RA_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "ra_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The RAModel class provides an abstraction for the RASearch class, abstracting
+ * away the TreeType parameter and allowing it to be specified at runtime in
+ * this class. This class is written for the sake of the 'allkrann' program,
+ * but is not necessarily restricted to that use.
+ *
+ * An RAModel owns the RASearch object it builds (the object is freed in the
+ * destructor), so RAModel objects cannot be copied.
+ *
+ * @param SortPolicy Sorting policy for neighbor searching (see RASearch).
+ */
+template<typename SortPolicy>
+class RAModel
+{
+ public:
+  /**
+   * The list of tree types we can use with RASearch. Does not include ball
+   * trees; see #338.
+   */
+  enum TreeTypes
+  {
+    KD_TREE,
+    COVER_TREE,
+    R_TREE,
+    R_STAR_TREE
+  };
+
+ private:
+  //! The type of tree being used.
+  int treeType;
+  //! The leaf size of the tree being used (useful only for the kd-tree; it is
+  //! also used for the query kd-tree built during dual-tree search).
+  size_t leafSize;
+
+  //! If true, randomly project into a new basis.
+  bool randomBasis;
+  //! The basis to project into.
+  arma::mat q;
+
+  //! Typedef the RASearch class we'll use.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using RAType = RASearch<SortPolicy,
+                          metric::EuclideanDistance,
+                          arma::mat,
+                          TreeType>;
+
+  //! Non-NULL if the kd-tree is used.
+  RAType<tree::KDTree>* kdTreeRA;
+  //! Non-NULL if the cover tree is used.
+  RAType<tree::StandardCoverTree>* coverTreeRA;
+  //! Non-NULL if the R tree is used.
+  RAType<tree::RTree>* rTreeRA;
+  //! Non-NULL if the R* tree is used.
+  RAType<tree::RStarTree>* rStarTreeRA;
+
+ public:
+  /**
+   * Initialize the RAModel with the given type and whether or not a random
+   * basis should be used.
+   *
+   * @param treeType Type of tree to use (one of the TreeTypes values).
+   * @param randomBasis Whether or not to project into a random basis.
+   */
+  RAModel(int treeType = TreeTypes::KD_TREE, bool randomBasis = false);
+
+  //! Clean memory, if necessary.
+  ~RAModel();
+
+  //! Copying is disabled: the model owns raw RASearch pointers, and a
+  //! memberwise copy would cause them to be deleted twice.
+  RAModel(const RAModel& other) = delete;
+  //! Copy assignment is disabled for the same reason as copy construction.
+  RAModel& operator=(const RAModel& other) = delete;
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Expose the dataset.
+  const arma::mat& Dataset() const;
+
+  //! Get whether or not single-tree search is being used.
+  bool SingleMode() const;
+  //! Modify whether or not single-tree search is being used.
+  bool& SingleMode();
+
+  //! Get whether or not naive search is being used.
+  bool Naive() const;
+  //! Modify whether or not naive search is being used.
+  bool& Naive();
+
+  //! Get the rank-approximation in percentile of the data.
+  double Tau() const;
+  //! Modify the rank-approximation in percentile of the data.
+  double& Tau();
+
+  //! Get the desired success probability.
+  double Alpha() const;
+  //! Modify the desired success probability.
+  double& Alpha();
+
+  //! Get whether or not sampling is done at the leaves.
+  bool SampleAtLeaves() const;
+  //! Modify whether or not sampling is done at the leaves.
+  bool& SampleAtLeaves();
+
+  //! Get whether or not we traverse to the first leaf without approximation.
+  bool FirstLeafExact() const;
+  //! Modify whether or not we traverse to the first leaf without approximation.
+  bool& FirstLeafExact();
+
+  //! Get the limit on the size of a node that can be approximated.
+  size_t SingleSampleLimit() const;
+  //! Modify the limit on the size of a node that can be approximated.
+  size_t& SingleSampleLimit();
+
+  //! Get the leaf size (only relevant when the kd-tree is used).
+  size_t LeafSize() const;
+  //! Modify the leaf size (only relevant when the kd-tree is used).
+  size_t& LeafSize();
+
+  //! Get the type of tree being used.
+  int TreeType() const;
+  //! Modify the type of tree being used.
+  int& TreeType();
+
+  //! Get whether or not a random basis is being used.
+  bool RandomBasis() const;
+  //! Modify whether or not a random basis is being used. Be sure to rebuild
+  //! the model using BuildModel().
+  bool& RandomBasis();
+
+  //! Build the reference tree.
+  void BuildModel(arma::mat&& referenceSet,
+                  const size_t leafSize,
+                  const bool naive,
+                  const bool singleMode);
+
+  //! Perform rank-approximate neighbor search, taking ownership of the query
+  //! set.
+  void Search(arma::mat&& querySet,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  /**
+   * Perform rank-approximate neighbor search, using the reference set as the
+   * query set.
+   */
+  void Search(const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  //! Get the name of the tree type.
+  std::string TreeName() const;
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+#include "ra_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
new file mode 100644
index 0000000..56d2156
--- /dev/null
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -0,0 +1,555 @@
+/**
+ * @file ra_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the RAModel class.
+ */
+#ifndef __MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ra_model.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Construct an empty model. No RASearch object is built until BuildModel() or
+ * Serialize() is called.
+ *
+ * @param treeType Type of tree to use (one of the TreeTypes values).
+ * @param randomBasis Whether or not to project data into a random basis.
+ */
+template<typename SortPolicy>
+RAModel<SortPolicy>::RAModel(const int treeType, const bool randomBasis) :
+    treeType(treeType),
+    // Previously left uninitialized; LeafSize() and query-tree construction
+    // could read an indeterminate value. 20 matches mlpack's usual default.
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeRA(NULL),
+    coverTreeRA(NULL),
+    rTreeRA(NULL),
+    rStarTreeRA(NULL)
+{
+  // Nothing to do.
+}
+
+/**
+ * Destroy the model, freeing whichever RASearch object has been allocated.
+ */
+template<typename SortPolicy>
+RAModel<SortPolicy>::~RAModel()
+{
+  // Deleting a NULL pointer is a no-op, so no explicit checks are required.
+  delete kdTreeRA;
+  delete coverTreeRA;
+  delete rTreeRA;
+  delete rStarTreeRA;
+}
+
+/**
+ * Save or load the model. The tree type, leaf size, random-basis flag, and
+ * basis matrix are always serialized, followed by the single RASearch object
+ * that matches the tree type.
+ */
+template<typename SortPolicy>
+template<typename Archive>
+void RAModel<SortPolicy>::Serialize(Archive& ar,
+                                    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(treeType, "treeType");
+  // The leaf size must be saved. Otherwise a loaded kd-tree model would build
+  // its query trees in Search() with an unset leaf size.
+  ar & data::CreateNVP(leafSize, "leafSize");
+  ar & data::CreateNVP(randomBasis, "randomBasis");
+  ar & data::CreateNVP(q, "q");
+
+  // When loading into a model that already holds a search object, free it so
+  // it is not leaked when the pointer below is overwritten.
+  if (Archive::is_loading::value)
+  {
+    delete kdTreeRA;
+    delete coverTreeRA;
+    delete rTreeRA;
+    delete rStarTreeRA;
+
+    // Set all the pointers to NULL.
+    kdTreeRA = NULL;
+    coverTreeRA = NULL;
+    rTreeRA = NULL;
+    rStarTreeRA = NULL;
+  }
+
+  // We only need to serialize one of the kRANN objects.
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & data::CreateNVP(kdTreeRA, "ra_model");
+      break;
+    case COVER_TREE:
+      ar & data::CreateNVP(coverTreeRA, "ra_model");
+      break;
+    case R_TREE:
+      ar & data::CreateNVP(rTreeRA, "ra_model");
+      break;
+    case R_STAR_TREE:
+      ar & data::CreateNVP(rStarTreeRA, "ra_model");
+      break;
+  }
+}
+
+/**
+ * Return the reference set held by whichever RASearch object has been built.
+ * The pointers are checked in the order kd-tree, cover tree, R tree, R* tree.
+ * Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+const arma::mat& RAModel<SortPolicy>::Dataset() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->ReferenceSet();
+  if (coverTreeRA != NULL)  return coverTreeRA->ReferenceSet();
+  if (rTreeRA != NULL)      return rTreeRA->ReferenceSet();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->ReferenceSet();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return whether naive (brute-force) search is enabled, as reported by the
+ * first non-NULL RASearch object. Throws std::runtime_error if no model has
+ * been built.
+ */
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::Naive() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Naive();
+  if (coverTreeRA != NULL)  return coverTreeRA->Naive();
+  if (rTreeRA != NULL)      return rTreeRA->Naive();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Naive();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to the naive-search flag of the first non-NULL
+ * RASearch object. Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::Naive()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Naive();
+  if (coverTreeRA != NULL)  return coverTreeRA->Naive();
+  if (rTreeRA != NULL)      return rTreeRA->Naive();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Naive();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return whether single-tree search is enabled, as reported by the first
+ * non-NULL RASearch object. Throws std::runtime_error if no model has been
+ * built.
+ */
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::SingleMode() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SingleMode();
+  if (coverTreeRA != NULL)  return coverTreeRA->SingleMode();
+  if (rTreeRA != NULL)      return rTreeRA->SingleMode();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SingleMode();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to the single-tree flag of the first non-NULL
+ * RASearch object. Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::SingleMode()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SingleMode();
+  if (coverTreeRA != NULL)  return coverTreeRA->SingleMode();
+  if (rTreeRA != NULL)      return rTreeRA->SingleMode();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SingleMode();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return the rank-approximation percentile (tau) of the first non-NULL
+ * RASearch object. Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+double RAModel<SortPolicy>::Tau() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Tau();
+  if (coverTreeRA != NULL)  return coverTreeRA->Tau();
+  if (rTreeRA != NULL)      return rTreeRA->Tau();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Tau();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to tau in the first non-NULL RASearch object.
+ * Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+double& RAModel<SortPolicy>::Tau()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Tau();
+  if (coverTreeRA != NULL)  return coverTreeRA->Tau();
+  if (rTreeRA != NULL)      return rTreeRA->Tau();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Tau();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return the desired success probability (alpha) of the first non-NULL
+ * RASearch object. Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+double RAModel<SortPolicy>::Alpha() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Alpha();
+  if (coverTreeRA != NULL)  return coverTreeRA->Alpha();
+  if (rTreeRA != NULL)      return rTreeRA->Alpha();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Alpha();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to alpha in the first non-NULL RASearch
+ * object. Throws std::runtime_error if no model has been built.
+ */
+template<typename SortPolicy>
+double& RAModel<SortPolicy>::Alpha()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->Alpha();
+  if (coverTreeRA != NULL)  return coverTreeRA->Alpha();
+  if (rTreeRA != NULL)      return rTreeRA->Alpha();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->Alpha();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return whether sampling is done at the leaves, as reported by the first
+ * non-NULL RASearch object. Throws std::runtime_error if no model has been
+ * built.
+ */
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::SampleAtLeaves() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SampleAtLeaves();
+  if (coverTreeRA != NULL)  return coverTreeRA->SampleAtLeaves();
+  if (rTreeRA != NULL)      return rTreeRA->SampleAtLeaves();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SampleAtLeaves();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to the sample-at-leaves flag of the first
+ * non-NULL RASearch object. Throws std::runtime_error if no model has been
+ * built.
+ */
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::SampleAtLeaves()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SampleAtLeaves();
+  if (coverTreeRA != NULL)  return coverTreeRA->SampleAtLeaves();
+  if (rTreeRA != NULL)      return rTreeRA->SampleAtLeaves();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SampleAtLeaves();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return whether the traversal to the first leaf is done exactly, as reported
+ * by the first non-NULL RASearch object. Throws std::runtime_error if no model
+ * has been built.
+ */
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::FirstLeafExact() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->FirstLeafExact();
+  if (coverTreeRA != NULL)  return coverTreeRA->FirstLeafExact();
+  if (rTreeRA != NULL)      return rTreeRA->FirstLeafExact();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->FirstLeafExact();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to the first-leaf-exact flag of the first
+ * non-NULL RASearch object. Throws std::runtime_error if no model has been
+ * built.
+ */
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::FirstLeafExact()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->FirstLeafExact();
+  if (coverTreeRA != NULL)  return coverTreeRA->FirstLeafExact();
+  if (rTreeRA != NULL)      return rTreeRA->FirstLeafExact();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->FirstLeafExact();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return the limit on the size of a node that can be approximated, as reported
+ * by the first non-NULL RASearch object. Throws std::runtime_error if no model
+ * has been built.
+ */
+template<typename SortPolicy>
+size_t RAModel<SortPolicy>::SingleSampleLimit() const
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SingleSampleLimit();
+  if (coverTreeRA != NULL)  return coverTreeRA->SingleSampleLimit();
+  if (rTreeRA != NULL)      return rTreeRA->SingleSampleLimit();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SingleSampleLimit();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+/**
+ * Return a modifiable reference to the single-sample limit of the first
+ * non-NULL RASearch object. Throws std::runtime_error if no model has been
+ * built.
+ */
+template<typename SortPolicy>
+size_t& RAModel<SortPolicy>::SingleSampleLimit()
+{
+  if (kdTreeRA != NULL)     return kdTreeRA->SingleSampleLimit();
+  if (coverTreeRA != NULL)  return coverTreeRA->SingleSampleLimit();
+  if (rTreeRA != NULL)      return rTreeRA->SingleSampleLimit();
+  if (rStarTreeRA != NULL)  return rStarTreeRA->SingleSampleLimit();
+
+  throw std::runtime_error("no rank-approximate nearest neighbor search "
+      "model initialized");
+}
+
+//! Return the stored leaf size (only relevant for the kd-tree).
+template<typename SortPolicy>
+size_t RAModel<SortPolicy>::LeafSize() const
+{
+  return this->leafSize;
+}
+
+//! Return a modifiable reference to the stored leaf size.
+template<typename SortPolicy>
+size_t& RAModel<SortPolicy>::LeafSize()
+{
+  return this->leafSize;
+}
+
+//! Return the tree type (a TreeTypes value stored as an int).
+template<typename SortPolicy>
+int RAModel<SortPolicy>::TreeType() const
+{
+  return this->treeType;
+}
+
+//! Return a modifiable reference to the tree type.
+template<typename SortPolicy>
+int& RAModel<SortPolicy>::TreeType()
+{
+  return this->treeType;
+}
+
+//! Return whether data is projected into a random basis.
+template<typename SortPolicy>
+bool RAModel<SortPolicy>::RandomBasis() const
+{
+  return this->randomBasis;
+}
+
+//! Return a modifiable reference to the random-basis flag.
+template<typename SortPolicy>
+bool& RAModel<SortPolicy>::RandomBasis()
+{
+  return this->randomBasis;
+}
+
+/**
+ * Build the rank-approximate search model on the given reference set, taking
+ * ownership of it. Any previously built model is destroyed first.
+ *
+ * @param referenceSet Reference data (moved into the model).
+ * @param leafSize Leaf size for the kd-tree. It is stored in the model so that
+ *     Search() uses the same value when it builds a query kd-tree.
+ * @param naive If true, no reference tree is built (brute-force search).
+ * @param singleMode If true, single-tree search will be used.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
+                                     const size_t leafSize,
+                                     const bool naive,
+                                     const bool singleMode)
+{
+  // The parameter shadows the member. Store it explicitly, because Search()
+  // reads the member when building a query kd-tree.
+  this->leafSize = leafSize;
+
+  // Initialize random basis, if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << std::endl;
+    math::RandomBasis(q, referenceSet.n_rows);
+  }
+
+  // Clean memory, and reset every pointer to NULL. The accessors use the first
+  // non-NULL pointer, so a stale pointer from a model built with a different
+  // tree type would otherwise be used after being freed, and then freed again
+  // in the destructor.
+  delete kdTreeRA;
+  kdTreeRA = NULL;
+  delete coverTreeRA;
+  coverTreeRA = NULL;
+  delete rTreeRA;
+  rTreeRA = NULL;
+  delete rStarTreeRA;
+  rStarTreeRA = NULL;
+
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << std::endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // Build tree, if necessary.
+      if (naive)
+      {
+        kdTreeRA = new RAType<tree::KDTree>(std::move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        std::vector<size_t> oldFromNewReferences;
+        typename RAType<tree::KDTree>::Tree* kdTree =
+            new typename RAType<tree::KDTree>::Tree(std::move(referenceSet),
+            oldFromNewReferences, leafSize);
+        kdTreeRA = new RAType<tree::KDTree>(kdTree, singleMode);
+
+        // Give the model ownership of the tree.
+        kdTreeRA->treeOwner = true;
+        kdTreeRA->oldFromNewReferences = oldFromNewReferences;
+      }
+      break;
+    case COVER_TREE:
+      coverTreeRA = new RAType<tree::StandardCoverTree>(std::move(referenceSet),
+          naive, singleMode);
+      break;
+    case R_TREE:
+      rTreeRA = new RAType<tree::RTree>(std::move(referenceSet), naive,
+          singleMode);
+      break;
+    case R_STAR_TREE:
+      rStarTreeRA = new RAType<tree::RStarTree>(std::move(referenceSet), naive,
+          singleMode);
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << std::endl;
+  }
+}
+
+/**
+ * Perform rank-approximate neighbor search for the given query set, taking
+ * ownership of it. The query set is projected into the random basis if one is
+ * in use. For dual-tree kd-tree search, a query kd-tree is built with the
+ * model's leaf size, and the results are mapped back to the original query
+ * order.
+ *
+ * @param querySet Query points (moved into this function).
+ * @param k Number of neighbors to find.
+ * @param neighbors Output matrix of neighbor indices.
+ * @param distances Output matrix of neighbor distances.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::Search(arma::mat&& querySet,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Apply the random basis if necessary.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  // Terminate the mode line with std::endl; otherwise the next log message
+  // (e.g. "Building query tree...") is appended to the same line.
+  Log::Info << "Searching for " << k << " approximate nearest neighbors with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree rank-approximate " << TreeName() << " search..."
+        << std::endl;
+  else if (!Naive())
+    Log::Info << "single-tree rank-approximate " << TreeName() << " search..."
+        << std::endl;
+  else
+    Log::Info << "brute-force (naive) rank-approximate search..." << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      if (!kdTreeRA->Naive() && !kdTreeRA->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << std::endl;
+        std::vector<size_t> oldFromNewQueries;
+        typename RAType<tree::KDTree>::Tree queryTree(std::move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << std::endl;
+        Timer::Stop("tree_building");
+
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        kdTreeRA->Search(&queryTree, k, neighborsOut, distancesOut);
+
+        // Unmap the query points.
+        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+        {
+          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        kdTreeRA->Search(querySet, k, neighbors, distances);
+      }
+      break;
+    case COVER_TREE:
+      // No mapping necessary.
+      coverTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+    case R_TREE:
+      // No mapping necessary.
+      rTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      // No mapping necessary.
+      rStarTreeRA->Search(querySet, k, neighbors, distances);
+      break;
+  }
+}
+
+/**
+ * Perform rank-approximate neighbor search using the reference set as the
+ * query set, dispatching to the RASearch object for the current tree type.
+ *
+ * @param k Number of neighbors to find.
+ * @param neighbors Output matrix of neighbor indices.
+ * @param distances Output matrix of neighbor distances.
+ */
+template<typename SortPolicy>
+void RAModel<SortPolicy>::Search(const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Terminate the mode line so later log output starts on its own line.
+  Log::Info << "Searching for " << k << " approximate nearest neighbors with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree rank-approximate " << TreeName() << " search..."
+        << std::endl;
+  else if (!Naive())
+    Log::Info << "single-tree rank-approximate " << TreeName() << " search..."
+        << std::endl;
+  else
+    Log::Info << "brute-force (naive) rank-approximate search..." << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      kdTreeRA->Search(k, neighbors, distances);
+      break;
+    case COVER_TREE:
+      coverTreeRA->Search(k, neighbors, distances);
+      break;
+    case R_TREE:
+      rTreeRA->Search(k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      rStarTreeRA->Search(k, neighbors, distances);
+      break;
+  }
+}
+
+/**
+ * Return a human-readable name for the current tree type, or "unknown tree"
+ * if treeType is not one of the TreeTypes values.
+ */
+template<typename SortPolicy>
+std::string RAModel<SortPolicy>::TreeName() const
+{
+  if (treeType == KD_TREE)
+    return "kd-tree";
+  if (treeType == COVER_TREE)
+    return "cover tree";
+  if (treeType == R_TREE)
+    return "R tree";
+  if (treeType == R_STAR_TREE)
+    return "R* tree";
+
+  return "unknown tree";
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 9a4084b..feaf4b5 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -31,6 +31,10 @@
namespace mlpack {
namespace neighbor {
+// Forward declaration.
+template<typename SortPolicy>
+class RAModel;
+
/**
* The RASearch class: This class provides a generic manner to perform
* rank-approximate search via random-sampling. If the 'naive' option is chosen,
@@ -444,6 +448,9 @@ class RASearch
//! Instantiation of kernel.
MetricType metric;
+
+ //! RAModel can modify internal members as necessary.
+ friend class RAModel<SortPolicy>;
}; // class RASearch
} // namespace neighbor
More information about the mlpack-git
mailing list