[mlpack-git] master: Add RSModel (not yet tested). (156787d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:31 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7

>---------------------------------------------------------------

commit 156787dd4f372a7fd740f733127ac200ea2564b7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 2 17:17:49 2015 +0000

    Add RSModel (not yet tested).


>---------------------------------------------------------------

156787dd4f372a7fd740f733127ac200ea2564b7
 src/mlpack/methods/range_search/CMakeLists.txt    |   3 +
 src/mlpack/methods/range_search/range_search.hpp  |   6 +
 src/mlpack/methods/range_search/rs_model.cpp      | 305 ++++++++++++++++++++++
 src/mlpack/methods/range_search/rs_model.hpp      | 171 ++++++++++++
 src/mlpack/methods/range_search/rs_model_impl.hpp | 138 ++++++++++
 5 files changed, 623 insertions(+)

diff --git a/src/mlpack/methods/range_search/CMakeLists.txt b/src/mlpack/methods/range_search/CMakeLists.txt
index 763f0a0..171dd86 100644
--- a/src/mlpack/methods/range_search/CMakeLists.txt
+++ b/src/mlpack/methods/range_search/CMakeLists.txt
@@ -6,6 +6,9 @@ set(SOURCES
   range_search_rules.hpp
   range_search_rules_impl.hpp
   range_search_stat.hpp
+  rs_model.hpp
+  rs_model_impl.hpp
+  rs_model.cpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index ffb3d86..1ecf0f4 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -16,6 +16,9 @@
 namespace mlpack {
 namespace range /** Range-search routines. */ {
 
+//! Forward declaration.
+class RSModel;
+
 /**
  * The RangeSearch class is a template class for performing range searches.  It
  * is implemented in the style of a generalized tree-independent dual-tree
@@ -316,6 +319,9 @@ class RangeSearch
   size_t baseCases;
   //! The total number of scores during the last search.
   size_t scores;
+
+  //! For access to mappings when building models.
+  friend RSModel;
 };
 
 } // namespace range
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
new file mode 100644
index 0000000..3c9a4ef
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -0,0 +1,305 @@
+/**
+ * @file rs_model.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the range search model class.
+ */
+#include "rs_model.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+
+/**
+ * Initialize the RSModel with the given tree type and whether or not a random
+ * basis should be used.  No range search object is created here; all of the
+ * per-tree pointers start as NULL until BuildModel() is called.
+ *
+ * leafSize is given a default value so that it never holds an indeterminate
+ * value (Search() reads it when building query trees); BuildModel() replaces
+ * it with the leaf size actually requested.
+ */
+RSModel::RSModel(int treeType, bool randomBasis) :
+    treeType(treeType),
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeRS(NULL),
+    coverTreeRS(NULL),
+    rTreeRS(NULL),
+    rStarTreeRS(NULL),
+    ballTreeRS(NULL)
+{
+  // Nothing else to do.
+}
+
+/**
+ * Destroy the model, releasing whichever range search object it still owns.
+ */
+RSModel::~RSModel()
+{
+  CleanMemory();
+}
+
+/**
+ * Build the range search object for the configured tree type.  The reference
+ * set is moved into the model (after projection onto a random basis, if one
+ * is in use).  Any previously built search object is released first.
+ *
+ * The leaf size is stored in the model because Search() also uses it when it
+ * builds query trees for dual-tree search.
+ */
+void RSModel::BuildModel(arma::mat&& referenceSet,
+                         const size_t leafSize,
+                         const bool naive,
+                         const bool singleMode)
+{
+  // Initialize random basis if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << endl;
+    math::RandomBasis(q, referenceSet.n_rows);
+  }
+
+  // Remember the leaf size for later query-tree construction.  The parameter
+  // shadows the member, hence the explicit this->.
+  this->leafSize = leafSize;
+
+  // Clean memory, if necessary.
+  CleanMemory();
+
+  // Do we need to modify the reference set?
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // If necessary, build the tree.
+      if (naive)
+      {
+        kdTreeRS = new RSType<tree::KDTree>(move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        vector<size_t> oldFromNewReferences;
+        typename RSType<tree::KDTree>::Tree* kdTree =
+            new typename RSType<tree::KDTree>::Tree(move(referenceSet),
+            oldFromNewReferences, leafSize);
+        kdTreeRS = new RSType<tree::KDTree>(kdTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        kdTreeRS->treeOwner = true;
+        kdTreeRS->oldFromNewReferences = move(oldFromNewReferences);
+      }
+
+      break;
+
+    case COVER_TREE:
+      coverTreeRS = new RSType<tree::StandardCoverTree>(move(referenceSet),
+          naive, singleMode);
+      break;
+
+    case R_TREE:
+      rTreeRS = new RSType<tree::RTree>(move(referenceSet), naive,
+          singleMode);
+      break;
+
+    case R_STAR_TREE:
+      rStarTreeRS = new RSType<tree::RStarTree>(move(referenceSet), naive,
+          singleMode);
+      break;
+
+    case BALL_TREE:
+      // If necessary, build the ball tree.
+      if (naive)
+      {
+        ballTreeRS = new RSType<tree::BallTree>(move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        vector<size_t> oldFromNewReferences;
+        typename RSType<tree::BallTree>::Tree* ballTree =
+            new typename RSType<tree::BallTree>::Tree(move(referenceSet),
+            oldFromNewReferences, leafSize);
+        ballTreeRS = new RSType<tree::BallTree>(ballTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        ballTreeRS->treeOwner = true;
+        ballTreeRS->oldFromNewReferences = move(oldFromNewReferences);
+      }
+
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << endl;
+  }
+}
+
+/**
+ * Perform bichromatic range search with the given query set.  The query set
+ * is moved into this function (and projected onto the random basis, if one
+ * is in use), so it is not usable afterwards.
+ *
+ * For dual-tree search with the kd-tree or ball tree, a query tree is built
+ * here.  Building it permutes the query points, so the results, which come
+ * back in the tree's order, are mapped back to the original query order.
+ */
+void RSModel::Search(arma::mat&& querySet,
+                     const math::Range& range,
+                     vector<vector<size_t>>& neighbors,
+                     vector<vector<double>>& distances)
+{
+  // We may need to map the query set randomly.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  if (!Naive() && !SingleMode())
+    Log::Info << "dual-tree " << TreeName() << " search..." << endl;
+  else if (!Naive())
+    Log::Info << "single-tree " << TreeName() << " search..." << endl;
+  else
+    Log::Info << "brute-force (naive) search..." << endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      if (!kdTreeRS->Naive() && !kdTreeRS->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << endl;
+        vector<size_t> oldFromNewQueries;
+        typename RSType<tree::KDTree>::Tree queryTree(move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << endl;
+        Timer::Stop("tree_building");
+
+        vector<vector<size_t>> neighborsOut;
+        vector<vector<double>> distancesOut;
+        kdTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
+
+        // Remap the query points: the tree's point i was originally point
+        // oldFromNewQueries[i], so its results belong at that index.
+        neighbors.resize(queryTree.Dataset().n_cols);
+        distances.resize(queryTree.Dataset().n_cols);
+        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+        {
+          neighbors[oldFromNewQueries[i]] = move(neighborsOut[i]);
+          distances[oldFromNewQueries[i]] = move(distancesOut[i]);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        kdTreeRS->Search(querySet, range, neighbors, distances);
+      }
+      break;
+
+    case COVER_TREE:
+      coverTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case R_TREE:
+      rTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case R_STAR_TREE:
+      rStarTreeRS->Search(querySet, range, neighbors, distances);
+      break;
+
+    case BALL_TREE:
+      if (!ballTreeRS->Naive() && !ballTreeRS->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << endl;
+        vector<size_t> oldFromNewQueries;
+        typename RSType<tree::BallTree>::Tree queryTree(move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << endl;
+        Timer::Stop("tree_building");
+
+        vector<vector<size_t>> neighborsOut;
+        vector<vector<double>> distancesOut;
+        ballTreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
+
+        // Remap the query points: the tree's point i was originally point
+        // oldFromNewQueries[i], so its results belong at that index.
+        neighbors.resize(queryTree.Dataset().n_cols);
+        distances.resize(queryTree.Dataset().n_cols);
+        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+        {
+          neighbors[oldFromNewQueries[i]] = move(neighborsOut[i]);
+          distances[oldFromNewQueries[i]] = move(distancesOut[i]);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        ballTreeRS->Search(querySet, range, neighbors, distances);
+      }
+      break;
+  }
+}
+
+/**
+ * Monochromatic range search: the model's own reference set is used as the
+ * query set.  Output format matches RangeSearch<>::Search().
+ */
+void RSModel::Search(const math::Range& range,
+                     vector<vector<size_t>>& neighbors,
+                     vector<vector<double>>& distances)
+{
+  Log::Info << "Search for points in the range [" << range.Lo() << ", "
+      << range.Hi() << "] with ";
+  // Report the search strategy, checking naive mode first.
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << endl;
+  else if (SingleMode())
+    Log::Info << "single-tree " << TreeName() << " search..." << endl;
+  else
+    Log::Info << "dual-tree " << TreeName() << " search..." << endl;
+
+  // Forward to the search object that matches the tree type.
+  if (treeType == KD_TREE)
+    kdTreeRS->Search(range, neighbors, distances);
+  else if (treeType == COVER_TREE)
+    coverTreeRS->Search(range, neighbors, distances);
+  else if (treeType == R_TREE)
+    rTreeRS->Search(range, neighbors, distances);
+  else if (treeType == R_STAR_TREE)
+    rStarTreeRS->Search(range, neighbors, distances);
+  else if (treeType == BALL_TREE)
+    ballTreeRS->Search(range, neighbors, distances);
+}
+
+/**
+ * Return a human-readable name for the current tree type, for log output.
+ * Values outside the TreeTypes enumeration yield "unknown tree".
+ */
+std::string RSModel::TreeName() const
+{
+  if (treeType == KD_TREE)
+    return "kd-tree";
+  if (treeType == COVER_TREE)
+    return "cover tree";
+  if (treeType == R_TREE)
+    return "R tree";
+  if (treeType == R_STAR_TREE)
+    return "R* tree";
+  if (treeType == BALL_TREE)
+    return "ball tree";
+
+  return "unknown tree";
+}
+
+/**
+ * Release any range search object the model holds and reset every per-tree
+ * pointer to NULL.  Deleting a NULL pointer is a no-op, so no checks are
+ * needed before each delete.
+ */
+void RSModel::CleanMemory()
+{
+  delete kdTreeRS;
+  kdTreeRS = NULL;
+
+  delete coverTreeRS;
+  coverTreeRS = NULL;
+
+  delete rTreeRS;
+  rTreeRS = NULL;
+
+  delete rStarTreeRS;
+  rStarTreeRS = NULL;
+
+  delete ballTreeRS;
+  ballTreeRS = NULL;
+}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
new file mode 100644
index 0000000..0a249fd
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -0,0 +1,171 @@
+/**
+ * @file rs_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for range search.  It is useful in that it provides an easy
+ * way to serialize a model, abstracts away the different types of trees, and
+ * also reflects the RangeSearch API and automatically directs to the right
+ * tree types.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "range_search.hpp"
+
+namespace mlpack {
+namespace range {
+
+/**
+ * RSModel wraps a RangeSearch object whose tree type is chosen at runtime.
+ * It provides one interface for building, searching, and serializing a range
+ * search model regardless of the tree type in use.  At most one of the
+ * internal RangeSearch pointers is non-NULL at any time.
+ *
+ * The model owns the RangeSearch object it holds and deletes it in its
+ * destructor, so copying is disabled to prevent a double delete.
+ */
+class RSModel
+{
+ public:
+  //! The types of tree that the model can use.
+  enum TreeTypes
+  {
+    KD_TREE,
+    COVER_TREE,
+    R_TREE,
+    R_STAR_TREE,
+    BALL_TREE
+  };
+
+ private:
+  //! The type of tree in use; one of the TreeTypes values.
+  int treeType;
+  //! Leaf size used when building trees.  Currently only the kd-tree and the
+  //! ball tree use it, including the query trees built for dual-tree search.
+  size_t leafSize;
+
+  //! If true, we randomly project the data into a new basis before search.
+  bool randomBasis;
+  //! Random projection matrix.
+  arma::mat q;
+
+  //! The mostly-specified type of the range search model.
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using RSType = RangeSearch<metric::EuclideanDistance, arma::mat, TreeType>;
+
+  // Only one of these pointers will be non-NULL.
+  //! kd-tree based range search object (NULL if not in use).
+  RSType<tree::KDTree>* kdTreeRS;
+  //! Cover tree based range search object (NULL if not in use).
+  RSType<tree::StandardCoverTree>* coverTreeRS;
+  //! R tree based range search object (NULL if not in use).
+  RSType<tree::RTree>* rTreeRS;
+  //! R* tree based range search object (NULL if not in use).
+  RSType<tree::RStarTree>* rStarTreeRS;
+  //! Ball tree based range search object (NULL if not in use).
+  RSType<tree::BallTree>* ballTreeRS;
+
+ public:
+  /**
+   * Initialize the RSModel with the given type and whether or not a random
+   * basis should be used.  No search object exists until BuildModel() is
+   * called.
+   *
+   * @param treeType Type of tree to use.
+   * @param randomBasis Whether or not to use a random basis.
+   */
+  RSModel(const int treeType = TreeTypes::KD_TREE,
+          const bool randomBasis = false);
+
+  //! Copying is disabled: the model owns a raw pointer to its RangeSearch
+  //! object, so a shallow copy would lead to a double delete.
+  RSModel(const RSModel& other) = delete;
+  //! Copy assignment is disabled for the same reason as copy construction.
+  RSModel& operator=(const RSModel& other) = delete;
+
+  /**
+   * Clean memory, if necessary.
+   */
+  ~RSModel();
+
+  //! Serialize the range search model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Expose the dataset.  Throws std::runtime_error if no model is built.
+  const arma::mat& Dataset() const;
+
+  //! Get whether the model is in single-tree search mode.  Throws
+  //! std::runtime_error if no model is built.
+  bool SingleMode() const;
+  //! Modify whether the model is in single-tree search mode.  Throws
+  //! std::runtime_error if no model is built.
+  bool& SingleMode();
+
+  //! Get whether the model is in naive search mode.  Throws
+  //! std::runtime_error if no model is built.
+  bool Naive() const;
+  //! Modify whether the model is in naive search mode.  Throws
+  //! std::runtime_error if no model is built.
+  bool& Naive();
+
+  //! Get the leaf size (currently used only by the kd-tree and ball tree).
+  size_t LeafSize() const { return leafSize; }
+  //! Modify the leaf size (currently used only by the kd-tree and ball tree).
+  size_t& LeafSize() { return leafSize; }
+
+  //! Get the type of tree.
+  int TreeType() const { return treeType; }
+  //! Modify the type of tree (don't do this after the model has been built).
+  int& TreeType() { return treeType; }
+
+  //! Get whether a random basis is used.
+  bool RandomBasis() const { return randomBasis; }
+  //! Modify whether a random basis is used (don't do this after the model has
+  //! been built).
+  bool& RandomBasis() { return randomBasis; }
+
+  /**
+   * Build the reference tree on the given dataset with the given parameters.
+   * This takes possession of the reference set to avoid a copy.
+   *
+   * @param referenceSet Set of reference points.
+   * @param leafSize Leaf size of tree (currently used only by the kd-tree and
+   *     ball tree).
+   * @param naive Whether naive search should be used.
+   * @param singleMode Whether single-tree search should be used.
+   */
+  void BuildModel(arma::mat&& referenceSet,
+                  const size_t leafSize,
+                  const bool naive,
+                  const bool singleMode);
+
+  /**
+   * Perform range search.  This takes possession of the query set, so the query
+   * set will not be usable after the search.  For more information on the
+   * output format, see RangeSearch<>::Search().
+   *
+   * @param querySet Set of query points.
+   * @param range Range to search for.
+   * @param neighbors Output: neighbors falling within the desired range.
+   * @param distances Output: distances of neighbors.
+   */
+  void Search(arma::mat&& querySet,
+              const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
+
+  /**
+   * Perform monochromatic range search, with the reference set as the query
+   * set.  For more information on the output format, see
+   * RangeSearch<>::Search().
+   *
+   * @param range Range to search for.
+   * @param neighbors Output: neighbors falling within the desired range.
+   * @param distances Output: distances of neighbors.
+   */
+  void Search(const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
+
+ private:
+  /**
+   * Return a string representing the name of the tree.  This is used for
+   * logging output.
+   */
+  std::string TreeName() const;
+
+  /**
+   * Clean up memory: delete the held search object and reset the pointers.
+   */
+  void CleanMemory();
+};
+
+} // namespace range
+} // namespace mlpack
+
+// Include implementation (of Serialize() and inline functions).
+#include "rs_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
new file mode 100644
index 0000000..55f38d0
--- /dev/null
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -0,0 +1,138 @@
+/**
+ * @file rs_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of Serialize() and inline functions for RSModel.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "rs_model.hpp"
+
+namespace mlpack {
+namespace range {
+
+/**
+ * Serialize the model.  Only the RangeSearch object that matches the tree
+ * type is written or read.
+ */
+template<typename Archive>
+void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(treeType, "treeType");
+  // Search() uses the leaf size to build query trees for dual-tree search, so
+  // it has to survive a save/load round trip.
+  ar & CreateNVP(leafSize, "leafSize");
+  ar & CreateNVP(randomBasis, "randomBasis");
+  ar & CreateNVP(q, "q");
+
+  // When loading into a model that already holds a search object, release it
+  // first so it is not leaked when its pointer is replaced below.
+  if (Archive::is_loading::value)
+    CleanMemory();
+
+  // We'll only need to serialize one of the model objects, based on the type.
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & CreateNVP(kdTreeRS, "range_search_model");
+      break;
+
+    case COVER_TREE:
+      ar & CreateNVP(coverTreeRS, "range_search_model");
+      break;
+
+    case R_TREE:
+      ar & CreateNVP(rTreeRS, "range_search_model");
+      break;
+
+    case R_STAR_TREE:
+      ar & CreateNVP(rStarTreeRS, "range_search_model");
+      break;
+
+    case BALL_TREE:
+      ar & CreateNVP(ballTreeRS, "range_search_model");
+      break;
+  }
+}
+
+// Return the reference set of whichever range search object is held; throws
+// std::runtime_error when no model has been built.
+inline const arma::mat& RSModel::Dataset() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->ReferenceSet();
+  if (coverTreeRS)
+    return coverTreeRS->ReferenceSet();
+  if (rTreeRS)
+    return rTreeRS->ReferenceSet();
+  if (rStarTreeRS)
+    return rStarTreeRS->ReferenceSet();
+  if (ballTreeRS)
+    return ballTreeRS->ReferenceSet();
+
+  // None of the per-tree pointers is set.
+  throw std::runtime_error("no range search model initialized");
+}
+
+// Report single-tree mode of the held search object; throws
+// std::runtime_error when no model has been built.
+inline bool RSModel::SingleMode() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->SingleMode();
+  if (coverTreeRS)
+    return coverTreeRS->SingleMode();
+  if (rTreeRS)
+    return rTreeRS->SingleMode();
+  if (rStarTreeRS)
+    return rStarTreeRS->SingleMode();
+  if (ballTreeRS)
+    return ballTreeRS->SingleMode();
+
+  // None of the per-tree pointers is set.
+  throw std::runtime_error("no range search model initialized");
+}
+
+// Give mutable access to the held search object's single-tree flag; throws
+// std::runtime_error when no model has been built.
+inline bool& RSModel::SingleMode()
+{
+  if (kdTreeRS)
+    return kdTreeRS->SingleMode();
+  if (coverTreeRS)
+    return coverTreeRS->SingleMode();
+  if (rTreeRS)
+    return rTreeRS->SingleMode();
+  if (rStarTreeRS)
+    return rStarTreeRS->SingleMode();
+  if (ballTreeRS)
+    return ballTreeRS->SingleMode();
+
+  // None of the per-tree pointers is set.
+  throw std::runtime_error("no range search model initialized");
+}
+
+// Report naive (brute-force) mode of the held search object; throws
+// std::runtime_error when no model has been built.
+inline bool RSModel::Naive() const
+{
+  if (kdTreeRS)
+    return kdTreeRS->Naive();
+  if (coverTreeRS)
+    return coverTreeRS->Naive();
+  if (rTreeRS)
+    return rTreeRS->Naive();
+  if (rStarTreeRS)
+    return rStarTreeRS->Naive();
+  if (ballTreeRS)
+    return ballTreeRS->Naive();
+
+  // None of the per-tree pointers is set.
+  throw std::runtime_error("no range search model initialized");
+}
+
+// Give mutable access to the held search object's naive-mode flag; throws
+// std::runtime_error when no model has been built.
+inline bool& RSModel::Naive()
+{
+  if (kdTreeRS)
+    return kdTreeRS->Naive();
+  if (coverTreeRS)
+    return coverTreeRS->Naive();
+  if (rTreeRS)
+    return rTreeRS->Naive();
+  if (rStarTreeRS)
+    return rStarTreeRS->Naive();
+  if (ballTreeRS)
+    return ballTreeRS->Naive();
+
+  // None of the per-tree pointers is set.
+  throw std::runtime_error("no range search model initialized");
+}
+
+} // namespace range
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list