[mlpack-git] master: Add RSModel (not yet tested). (156787d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:31 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7
>---------------------------------------------------------------
commit 156787dd4f372a7fd740f733127ac200ea2564b7
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 2 17:17:49 2015 +0000
Add RSModel (not yet tested).
>---------------------------------------------------------------
156787dd4f372a7fd740f733127ac200ea2564b7
src/mlpack/methods/range_search/CMakeLists.txt | 3 +
src/mlpack/methods/range_search/range_search.hpp | 6 +
src/mlpack/methods/range_search/rs_model.cpp | 305 ++++++++++++++++++++++
src/mlpack/methods/range_search/rs_model.hpp | 171 ++++++++++++
src/mlpack/methods/range_search/rs_model_impl.hpp | 138 ++++++++++
5 files changed, 623 insertions(+)
diff --git a/src/mlpack/methods/range_search/CMakeLists.txt b/src/mlpack/methods/range_search/CMakeLists.txt
index 763f0a0..171dd86 100644
--- a/src/mlpack/methods/range_search/CMakeLists.txt
+++ b/src/mlpack/methods/range_search/CMakeLists.txt
@@ -6,6 +6,9 @@ set(SOURCES
range_search_rules.hpp
range_search_rules_impl.hpp
range_search_stat.hpp
+ rs_model.hpp
+ rs_model_impl.hpp
+ rs_model.cpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index ffb3d86..1ecf0f4 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -16,6 +16,9 @@
namespace mlpack {
namespace range /** Range-search routines. */ {
+//! Forward declaration.
+class RSModel;
+
/**
* The RangeSearch class is a template class for performing range searches. It
* is implemented in the style of a generalized tree-independent dual-tree
@@ -316,6 +319,9 @@ class RangeSearch
size_t baseCases;
//! The total number of scores during the last search.
size_t scores;
+
+ //! For access to mappings when building models.
+ friend RSModel;
};
} // namespace range
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
new file mode 100644
index 0000000..3c9a4ef
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -0,0 +1,305 @@
+/**
+ * @file rs_model.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the range search model class.
+ */
+#include "rs_model.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+
+/**
+ * Initialize the RSModel with the given tree type and whether or not a random
+ * basis should be used.  No range search object is created here; that happens
+ * in BuildModel().
+ *
+ * leafSize is given a default value (20) so that it is never read while
+ * indeterminate, e.g. if Search() is reached without BuildModel() having set
+ * it.  BuildModel() overwrites it with the caller's value.
+ */
+RSModel::RSModel(int treeType, bool randomBasis) :
+    treeType(treeType),
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeRS(NULL),
+    coverTreeRS(NULL),
+    rTreeRS(NULL),
+    rStarTreeRS(NULL),
+    ballTreeRS(NULL)
+{
+  // Nothing to do.
+}
+
+/**
+ * Destroy the model, releasing whichever range search object it holds.
+ */
+RSModel::~RSModel() { CleanMemory(); }
+
+/**
+ * Build the range search object (and, when not in naive mode, the reference
+ * tree) for the configured tree type.  The reference set is moved from, so the
+ * caller's matrix should not be used afterwards.
+ *
+ * @param referenceSet Set of reference points.
+ * @param leafSize Leaf size used for the kd-tree and ball tree reference trees;
+ *     it is also stored so Search() can build query trees of the same kind.
+ * @param naive Whether naive (brute-force) search should be used.
+ * @param singleMode Whether single-tree search should be used.
+ */
+void RSModel::BuildModel(arma::mat&& referenceSet,
+                         const size_t leafSize,
+                         const bool naive,
+                         const bool singleMode)
+{
+  // The parameter shadows the member of the same name; store it explicitly,
+  // because Search() reads the member when it builds a query tree.
+  this->leafSize = leafSize;
+
+  // Initialize random basis if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << endl;
+    math::RandomBasis(q, referenceSet.n_rows);
+  }
+
+  // Clean memory, if necessary.
+  CleanMemory();
+
+  // Do we need to modify the reference set?
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // If necessary, build the tree.
+      if (naive)
+      {
+        kdTreeRS = new RSType<tree::KDTree>(move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        vector<size_t> oldFromNewReferences;
+        typename RSType<tree::KDTree>::Tree* kdTree =
+            new typename RSType<tree::KDTree>::Tree(move(referenceSet),
+            oldFromNewReferences, leafSize);
+        kdTreeRS = new RSType<tree::KDTree>(kdTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        kdTreeRS->treeOwner = true;
+        kdTreeRS->oldFromNewReferences = move(oldFromNewReferences);
+      }
+
+      break;
+
+    case COVER_TREE:
+      coverTreeRS = new RSType<tree::StandardCoverTree>(move(referenceSet),
+          naive, singleMode);
+      break;
+
+    case R_TREE:
+      rTreeRS = new RSType<tree::RTree>(move(referenceSet), naive,
+          singleMode);
+      break;
+
+    case R_STAR_TREE:
+      rStarTreeRS = new RSType<tree::RStarTree>(move(referenceSet), naive,
+          singleMode);
+      break;
+
+    case BALL_TREE:
+      // If necessary, build the ball tree.
+      if (naive)
+      {
+        ballTreeRS = new RSType<tree::BallTree>(move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        vector<size_t> oldFromNewReferences;
+        typename RSType<tree::BallTree>::Tree* ballTree =
+            new typename RSType<tree::BallTree>::Tree(move(referenceSet),
+            oldFromNewReferences, leafSize);
+        ballTreeRS = new RSType<tree::BallTree>(ballTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        ballTreeRS->treeOwner = true;
+        ballTreeRS->oldFromNewReferences = move(oldFromNewReferences);
+      }
+
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << endl;
+  }
+}
+
+/**
+ * Dual-tree search helper for tree types whose construction rearranges the
+ * points of the dataset (here, the kd-tree and the ball tree).  It builds a
+ * query tree with the given leaf size, runs the search, and writes the results
+ * back in the original order of the query points.
+ */
+template<typename RSTypeT>
+static void DualTreeSearchWithMapping(RSTypeT* rs,
+                                      arma::mat&& querySet,
+                                      const size_t leafSize,
+                                      const math::Range& range,
+                                      vector<vector<size_t>>& neighbors,
+                                      vector<vector<double>>& distances)
+{
+  // Build a second tree and search.
+  Timer::Start("tree_building");
+  Log::Info << "Building query tree..." << endl;
+  vector<size_t> oldFromNewQueries;
+  typename RSTypeT::Tree queryTree(move(querySet), oldFromNewQueries,
+      leafSize);
+  Log::Info << "Tree built." << endl;
+  Timer::Stop("tree_building");
+
+  vector<vector<size_t>> neighborsOut;
+  vector<vector<double>> distancesOut;
+  rs->Search(&queryTree, range, neighborsOut, distancesOut);
+
+  // Results are indexed by each query point's position in the query tree's
+  // rearranged dataset.  oldFromNewQueries[i] is the original index of the
+  // point now at position i, so result i belongs at that original index.
+  const size_t numQueries = queryTree.Dataset().n_cols;
+  neighbors.resize(numQueries);
+  distances.resize(numQueries);
+  for (size_t i = 0; i < numQueries; ++i)
+  {
+    neighbors[oldFromNewQueries[i]] = move(neighborsOut[i]);
+    distances[oldFromNewQueries[i]] = move(distancesOut[i]);
+  }
+}
+
+/**
+ * Perform range search with a separate query set.  The query set is projected
+ * onto the random basis if the model uses one, and is moved from when a query
+ * tree is built.
+ */
+void RSModel::Search(arma::mat&& querySet,
+                     const math::Range& range,
+                     vector<vector<size_t>>& neighbors,
+                     vector<vector<double>>& distances)
+{
+  // We may need to map the query set randomly.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree " << TreeName() << " search..." << endl;
+  else if (!Naive())
+    Log::Info << "single-tree " << TreeName() << " search..." << endl;
+  else
+    Log::Info << "brute-force (naive) search..." << endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      if (!kdTreeRS->Naive() && !kdTreeRS->SingleMode())
+      {
+        DualTreeSearchWithMapping(kdTreeRS, move(querySet), leafSize, range,
+            neighbors, distances);
+      }
+      else
+      {
+        // Search without building a second tree.
+        kdTreeRS->Search(querySet, range, neighbors, distances);
+      }
+      break;
+
+    case COVER_TREE:
+      coverTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case R_TREE:
+      rTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case R_STAR_TREE:
+      rStarTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case BALL_TREE:
+      if (!ballTreeRS->Naive() && !ballTreeRS->SingleMode())
+      {
+        DualTreeSearchWithMapping(ballTreeRS, move(querySet), leafSize, range,
+            neighbors, distances);
+      }
+      else
+      {
+        // Search without building a second tree.
+        ballTreeRS->Search(querySet, range, neighbors, distances);
+      }
+      break;
+  }
+}
+
+/**
+ * Monochromatic range search: the reference set also serves as the query set,
+ * so no query data is passed in.  The call is forwarded to whichever range
+ * search object matches the stored tree type.
+ */
+void RSModel::Search(const math::Range& range,
+                     vector<vector<size_t>>& neighbors,
+                     vector<vector<double>>& distances)
+{
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << endl;
+  else if (SingleMode())
+    Log::Info << "single-tree " << TreeName() << " search..." << endl;
+  else
+    Log::Info << "dual-tree " << TreeName() << " search..." << endl;
+
+  if (treeType == KD_TREE)
+    kdTreeRS->Search(range, neighbors, distances);
+  else if (treeType == COVER_TREE)
+    coverTreeRS->Search(range, neighbors, distances);
+  else if (treeType == R_TREE)
+    rTreeRS->Search(range, neighbors, distances);
+  else if (treeType == R_STAR_TREE)
+    rStarTreeRS->Search(range, neighbors, distances);
+  else if (treeType == BALL_TREE)
+    ballTreeRS->Search(range, neighbors, distances);
+}
+
+/**
+ * Return a human-readable name for the stored tree type, for log output.
+ * Any value outside the TreeTypes enumeration maps to "unknown tree".
+ */
+std::string RSModel::TreeName() const
+{
+  const char* label = "unknown tree";
+  switch (treeType)
+  {
+    case KD_TREE:     label = "kd-tree";    break;
+    case COVER_TREE:  label = "cover tree"; break;
+    case R_TREE:      label = "R tree";     break;
+    case R_STAR_TREE: label = "R* tree";    break;
+    case BALL_TREE:   label = "ball tree";  break;
+  }
+  return label;
+}
+
+/**
+ * Delete whichever range search object the model holds and reset every
+ * pointer to NULL.  Deleting a NULL pointer is a no-op, so no checks are
+ * needed before each delete.
+ */
+void RSModel::CleanMemory()
+{
+  delete kdTreeRS;
+  kdTreeRS = NULL;
+
+  delete coverTreeRS;
+  coverTreeRS = NULL;
+
+  delete rTreeRS;
+  rTreeRS = NULL;
+
+  delete rStarTreeRS;
+  rStarTreeRS = NULL;
+
+  delete ballTreeRS;
+  ballTreeRS = NULL;
+}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
new file mode 100644
index 0000000..0a249fd
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -0,0 +1,171 @@
+/**
+ * @file rs_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for range search. It is useful in that it provides an easy
+ * way to serialize a model, abstracts away the different types of trees, and
+ * also reflects the RangeSearch API and automatically directs to the right
+ * tree types.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "range_search.hpp"
+
+namespace mlpack {
+namespace range {
+
+/**
+ * RSModel holds a RangeSearch object whose tree type is selected at runtime.
+ * Exactly one tree type is in use at a time; the model owns the corresponding
+ * search object and deletes it on destruction.
+ */
+class RSModel
+{
+ public:
+  //! The tree types that the model can be built with.
+  enum TreeTypes
+  {
+    KD_TREE,
+    COVER_TREE,
+    R_TREE,
+    R_STAR_TREE,
+    BALL_TREE
+  };
+
+ private:
+  //! The tree type in use (expected to be one of the TreeTypes values).
+  int treeType;
+  //! Leaf size used when building kd-trees and ball trees (reference trees and
+  //! the query trees built for dual-tree search).
+  size_t leafSize;
+
+  //! If true, we randomly project the data into a new basis before search.
+  bool randomBasis;
+  //! Random projection matrix.
+  arma::mat q;
+
+  //! The mostly-specified type of the range search model.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
+
+  // Only one of these pointers will be non-NULL.
+  //! kd-tree based range search object (NULL if not in use).
+  RSType<tree::KDTree>* kdTreeRS;
+  //! Cover tree based range search object (NULL if not in use).
+  RSType<tree::StandardCoverTree>* coverTreeRS;
+  //! R tree based range search object (NULL if not in use).
+  RSType<tree::RTree>* rTreeRS;
+  //! R* tree based range search object (NULL if not in use).
+  RSType<tree::RStarTree>* rStarTreeRS;
+  //! Ball tree based range search object (NULL if not in use).
+  RSType<tree::BallTree>* ballTreeRS;
+
+ public:
+  /**
+   * Initialize the RSModel with the given type and whether or not a random
+   * basis should be used.
+   *
+   * @param treeType Type of tree to use.
+   * @param randomBasis Whether or not to use a random basis.
+   */
+  RSModel(const int treeType = TreeTypes::KD_TREE,
+          const bool randomBasis = false);
+
+  /**
+   * Copying is disabled.  The model owns its range search object through a
+   * raw pointer and deletes it in the destructor, so a member-wise copy would
+   * delete the same object twice.
+   */
+  RSModel(const RSModel& other) = delete;
+  //! Copy assignment is disabled for the same reason as copy construction.
+  RSModel& operator=(const RSModel& other) = delete;
+
+  /**
+   * Clean memory, if necessary.
+   */
+  ~RSModel();
+
+  //! Serialize the range search model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Expose the dataset.
+  const arma::mat& Dataset() const;
+
+  //! Get whether the model is in single-tree search mode.
+  bool SingleMode() const;
+  //! Modify whether the model is in single-tree search mode.
+  bool& SingleMode();
+
+  //! Get whether the model is in naive search mode.
+  bool Naive() const;
+  //! Modify whether the model is in naive search mode.
+  bool& Naive();
+
+  //! Get the leaf size (used for the kd-tree and the ball tree).
+  size_t LeafSize() const { return leafSize; }
+  //! Modify the leaf size (used for the kd-tree and the ball tree).
+  size_t& LeafSize() { return leafSize; }
+
+  //! Get the type of tree.
+  int TreeType() const { return treeType; }
+  //! Modify the type of tree (don't do this after the model has been built).
+  int& TreeType() { return treeType; }
+
+  //! Get whether a random basis is used.
+  bool RandomBasis() const { return randomBasis; }
+  //! Modify whether a random basis is used (don't do this after the model has
+  //! been built).
+  bool& RandomBasis() { return randomBasis; }
+
+  /**
+   * Build the reference tree on the given dataset with the given parameters.
+   * This takes possession of the reference set to avoid a copy.
+   *
+   * @param referenceSet Set of reference points.
+   * @param leafSize Leaf size of tree (used for the kd-tree and the ball tree).
+   * @param naive Whether naive search should be used.
+   * @param singleMode Whether single-tree search should be used.
+   */
+  void BuildModel(arma::mat&& referenceSet,
+                  const size_t leafSize,
+                  const bool naive,
+                  const bool singleMode);
+
+  /**
+   * Perform range search. This takes possession of the query set, so the query
+   * set will not be usable after the search. For more information on the
+   * output format, see RangeSearch<>::Search().
+   *
+   * @param querySet Set of query points.
+   * @param range Range to search for.
+   * @param neighbors Output: neighbors falling within the desired range.
+   * @param distances Output: distances of neighbors.
+   */
+  void Search(arma::mat&& querySet,
+              const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
+
+  /**
+   * Perform monochromatic range search, with the reference set as the query
+   * set. For more information on the output format, see
+   * RangeSearch<>::Search().
+   *
+   * @param range Range to search for.
+   * @param neighbors Output: neighbors falling within the desired range.
+   * @param distances Output: distances of neighbors.
+   */
+  void Search(const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
+
+ private:
+  /**
+   * Return a string representing the name of the tree. This is used for
+   * logging output.
+   */
+  std::string TreeName() const;
+
+  /**
+   * Clean up memory.
+   */
+  void CleanMemory();
+};
+
+} // namespace range
+} // namespace mlpack
+
+// Include implementation (of Serialize() and inline functions).
+#include "rs_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
new file mode 100644
index 0000000..55f38d0
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -0,0 +1,138 @@
+/**
+ * @file rs_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of Serialize() and inline functions for RSModel.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "rs_model.hpp"
+
+namespace mlpack {
+namespace range {
+
+/**
+ * Serialize the model.  The tree type, leaf size, random-basis flag and
+ * projection matrix are always saved.  Only the range search object matching
+ * the tree type is saved.
+ *
+ * leafSize is included because Search() uses it to build query trees for
+ * dual-tree kd-tree and ball tree search; without it, a loaded model would not
+ * have the leaf size it was built with.
+ */
+template<typename Archive>
+void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(treeType, "treeType");
+  ar & CreateNVP(leafSize, "leafSize");
+  ar & CreateNVP(randomBasis, "randomBasis");
+  ar & CreateNVP(q, "q");
+
+  // This should never happen, but just in case...
+  if (Archive::is_loading::value)
+    CleanMemory();
+
+  // We'll only need to serialize one of the model objects, based on the type.
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & CreateNVP(kdTreeRS, "range_search_model");
+      break;
+
+    case COVER_TREE:
+      ar & CreateNVP(coverTreeRS, "range_search_model");
+      break;
+
+    case R_TREE:
+      ar & CreateNVP(rTreeRS, "range_search_model");
+      break;
+
+    case R_STAR_TREE:
+      ar & CreateNVP(rStarTreeRS, "range_search_model");
+      break;
+
+    case BALL_TREE:
+      ar & CreateNVP(ballTreeRS, "range_search_model");
+      break;
+  }
+}
+
+//! Return the reference set of whichever range search object is in use; throws
+//! std::runtime_error if the model has not been built.
+inline const arma::mat& RSModel::Dataset() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->ReferenceSet();
+  if (coverTreeRS)
+    return coverTreeRS->ReferenceSet();
+  if (rTreeRS)
+    return rTreeRS->ReferenceSet();
+  if (rStarTreeRS)
+    return rStarTreeRS->ReferenceSet();
+  if (ballTreeRS)
+    return ballTreeRS->ReferenceSet();
+
+  throw std::runtime_error("no range search model initialized");
+}
+
+//! Report the single-tree flag of whichever range search object is in use;
+//! throws std::runtime_error if the model has not been built.
+inline bool RSModel::SingleMode() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->SingleMode();
+  if (coverTreeRS)
+    return coverTreeRS->SingleMode();
+  if (rTreeRS)
+    return rTreeRS->SingleMode();
+  if (rStarTreeRS)
+    return rStarTreeRS->SingleMode();
+  if (ballTreeRS)
+    return ballTreeRS->SingleMode();
+
+  throw std::runtime_error("no range search model initialized");
+}
+
+//! Return a modifiable reference to the single-tree flag of whichever range
+//! search object is in use; throws std::runtime_error if the model has not
+//! been built.
+inline bool& RSModel::SingleMode()
+{
+  if (kdTreeRS)
+    return kdTreeRS->SingleMode();
+  if (coverTreeRS)
+    return coverTreeRS->SingleMode();
+  if (rTreeRS)
+    return rTreeRS->SingleMode();
+  if (rStarTreeRS)
+    return rStarTreeRS->SingleMode();
+  if (ballTreeRS)
+    return ballTreeRS->SingleMode();
+
+  throw std::runtime_error("no range search model initialized");
+}
+
+//! Report the naive-search flag of whichever range search object is in use;
+//! throws std::runtime_error if the model has not been built.
+inline bool RSModel::Naive() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->Naive();
+  if (coverTreeRS)
+    return coverTreeRS->Naive();
+  if (rTreeRS)
+    return rTreeRS->Naive();
+  if (rStarTreeRS)
+    return rStarTreeRS->Naive();
+  if (ballTreeRS)
+    return ballTreeRS->Naive();
+
+  throw std::runtime_error("no range search model initialized");
+}
+
+//! Return a modifiable reference to the naive-search flag of whichever range
+//! search object is in use; throws std::runtime_error if the model has not
+//! been built.
+inline bool& RSModel::Naive()
+{
+  if (kdTreeRS)
+    return kdTreeRS->Naive();
+  if (coverTreeRS)
+    return coverTreeRS->Naive();
+  if (rTreeRS)
+    return rTreeRS->Naive();
+  if (rStarTreeRS)
+    return rStarTreeRS->Naive();
+  if (ballTreeRS)
+    return ballTreeRS->Naive();
+
+  throw std::runtime_error("no range search model initialized");
+}
+
+} // namespace range
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list