[mlpack-git] master: Merge branch 'spill-trees' of https://github.com/MarcosPividori/mlpack into MarcosPividori-spill-trees (6618cf3)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:34:30 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 6618cf3ed088b082b1c5ac1111edad60098fe81e
Merge: 0f4b25a 815391b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Aug 18 13:34:30 2016 -0400

    Merge branch 'spill-trees' of https://github.com/MarcosPividori/mlpack into MarcosPividori-spill-trees


>---------------------------------------------------------------

6618cf3ed088b082b1c5ac1111edad60098fe81e
 CMakeLists.txt                                     |   4 +-
 src/mlpack/core/data/load_impl.hpp                 |   4 +-
 src/mlpack/core/tree/CMakeLists.txt                |  18 +
 .../core/tree/binary_space_tree/mean_split.hpp     |   1 +
 src/mlpack/core/tree/binary_space_tree/traits.hpp  |  20 +
 .../core/tree/cover_tree/cover_tree_impl.hpp       |  25 +-
 src/mlpack/core/tree/cover_tree/traits.hpp         |   5 +
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp |   8 +-
 src/mlpack/core/tree/rectangle_tree/traits.hpp     |  12 +
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp |   2 +-
 src/mlpack/core/tree/space_split/hyperplane.hpp    | 151 +++++
 .../core/tree/space_split/mean_space_split.hpp     |  45 ++
 .../tree/space_split/mean_space_split_impl.hpp     |  45 ++
 .../core/tree/space_split/midpoint_space_split.hpp |  45 ++
 .../tree/space_split/midpoint_space_split_impl.hpp |  40 ++
 .../core/tree/space_split/projection_vector.hpp    | 149 ++++
 src/mlpack/core/tree/space_split/space_split.hpp   |  67 ++
 .../core/tree/space_split/space_split_impl.hpp     | 103 +++
 src/mlpack/core/tree/spill_tree.hpp                |  21 +
 src/mlpack/core/tree/spill_tree/is_spill_tree.hpp  |  38 ++
 .../spill_dual_tree_traverser.hpp}                 |  35 +-
 .../spill_tree/spill_dual_tree_traverser_impl.hpp  | 416 ++++++++++++
 .../spill_single_tree_traverser.hpp}               |  29 +-
 .../spill_single_tree_traverser_impl.hpp           | 125 ++++
 src/mlpack/core/tree/spill_tree/spill_tree.hpp     | 450 ++++++++++++
 .../core/tree/spill_tree/spill_tree_impl.hpp       | 755 +++++++++++++++++++++
 src/mlpack/core/tree/spill_tree/traits.hpp         |  68 ++
 src/mlpack/core/tree/spill_tree/typedef.hpp        | 119 ++++
 src/mlpack/core/tree/tree_traits.hpp               |   6 +
 src/mlpack/core/util/backtrace.cpp                 |   4 +-
 .../simple_tolerance_termination.hpp               |   8 +-
 .../validation_RMSE_termination.hpp                |  10 +-
 .../svd_complete_incremental_learning.hpp          |   8 +-
 .../svd_incomplete_incremental_learning.hpp        |   6 +-
 .../ann/activation_functions/logistic_function.hpp |   2 +-
 src/mlpack/methods/ann/layer/dropconnect_layer.hpp |  28 +-
 src/mlpack/methods/cf/svd_wrapper_impl.hpp         |   4 +-
 src/mlpack/methods/hmm/hmm_regression_impl.hpp     |   2 +-
 src/mlpack/methods/neighbor_search/kfn_main.cpp    |  73 +-
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 116 +++-
 .../methods/neighbor_search/neighbor_search.hpp    |  54 +-
 .../neighbor_search/neighbor_search_impl.hpp       | 237 +++++--
 .../neighbor_search/neighbor_search_rules.hpp      |   3 +
 .../neighbor_search/neighbor_search_rules_impl.hpp |   7 +-
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  51 +-
 .../methods/neighbor_search/ns_model_impl.hpp      |  99 ++-
 src/mlpack/methods/neighbor_search/typedef.hpp     |  29 +
 src/mlpack/methods/pca/pca_main.cpp                |   4 +-
 src/mlpack/methods/rann/ra_search_rules.hpp        |   3 +
 .../regularized_svd/regularized_svd_function.cpp   |   2 +-
 src/mlpack/tests/CMakeLists.txt                    |  16 +-
 src/mlpack/tests/adaboost_test.cpp                 |   2 +-
 src/mlpack/tests/aknn_test.cpp                     |  41 ++
 src/mlpack/tests/hyperplane_test.cpp               | 135 ++++
 src/mlpack/tests/knn_test.cpp                      |  90 +++
 src/mlpack/tests/rectangle_tree_test.cpp           |   4 +-
 src/mlpack/tests/spill_tree_test.cpp               | 306 +++++++++
 57 files changed, 3935 insertions(+), 215 deletions(-)

diff --cc src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index 960b5f3,8450f9d..bbc25d7
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@@ -23,17 -23,29 +23,18 @@@ template<typename BoundType, typename M
  class MeanSplit
  {
   public:
 -  /**
 -   * Split the node according to the mean value in the dimension with maximum
 -   * width.
 -   *
 -   * @param bound The bound used for this node.
 -   * @param data The dataset used by the binary space tree.
 -   * @param begin Index of the starting point in the dataset that belongs to
 -   *    this node.
 -   * @param count Number of points in this node.
 -   * @param splitDimension This will be filled with the dimension the node is to
 -   *    be split on.
 -   * @param splitCol The index at which the dataset is divided into two parts
 -   *    after the rearrangement.
 -   */
 -  static bool SplitNode(const BoundType& bound,
 -                        MatType& data,
 -                        const size_t begin,
 -                        const size_t count,
 -                        size_t& splitCol);
 +  struct SplitInfo
 +  {
 +    //! The dimension to split the node on.
 +    size_t splitDimension;
 +    //! The split in dimension splitDimension is based on this value.
 +    double splitVal;
 +  };
+ 
    /**
 -   * Split the node according to the mean value in the dimension with maximum
 -   * width and return a list of changed indices.
 +   * Find the partition of the node. This method fills up the dimension
 +   * that will be used to split the node and the value according which the split
 +   * will be performed.
     *
     * @param bound The bound used for this node.
     * @param data The dataset used by the binary space tree.
diff --cc src/mlpack/core/tree/binary_space_tree/traits.hpp
index ce02ee6,799168c..f5dc4ac
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@@ -42,99 -42,6 +42,111 @@@ class TreeTraits<BinarySpaceTree<Metric
    static const bool FirstPointIsCentroid = false;
  
    /**
 +   * The tree has not got duplicated points.
 +   */
 +  static const bool HasDuplicatedPoints = false;
 +
 +  /**
 +   * Points are not contained at multiple levels of the binary space tree.
 +   */
 +  static const bool HasSelfChildren = false;
 +
 +  /**
 +   * Points are rearranged during building of the tree.
 +   */
 +  static const bool RearrangesDataset = true;
 +
 +  /**
 +   * This is always a binary tree.
 +   */
 +  static const bool BinaryTree = true;
++
++  /**
++   * Binary space trees don't have duplicated points, so NumDescendants()
++   * represents the number of unique descendant points.
++   */
++  static const bool UniqueNumDescendants = true;
 +};
 +
 +/**
 + * This is a specialization of the TreeType class to the max-split random
 + * projection tree. The only difference with general BinarySpaceTree is that the
 + * tree can have overlapping children.
 + */
 +template<typename MetricType,
 +         typename StatisticType,
 +         typename MatType,
 +         template<typename BoundMetricType, typename...> class BoundType>
 +class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
 +                                 RPTreeMaxSplit>>
 +{
 + public:
 +  /**
 +   * Children of a random projection tree node may overlap.
 +   */
 +  static const bool HasOverlappingChildren = true;
 +
 +  /**
 +   * The tree has not got duplicated points.
 +   */
 +  static const bool HasDuplicatedPoints = false;
 +
 +  /**
 +   * There is no guarantee that the first point in a node is its centroid.
 +   */
 +  static const bool FirstPointIsCentroid = false;
 +
 +  /**
 +   * Points are not contained at multiple levels of the binary space tree.
 +   */
 +  static const bool HasSelfChildren = false;
 +
 +  /**
 +   * Points are rearranged during building of the tree.
 +   */
 +  static const bool RearrangesDataset = true;
 +
 +  /**
 +   * This is always a binary tree.
 +   */
 +  static const bool BinaryTree = true;
++
++  /**
++   * Binary space trees don't have duplicated points, so NumDescendants()
++   * represents the number of unique descendant points.
++   */
++  static const bool UniqueNumDescendants = true;
 +};
 +
 +/**
 + * This is a specialization of the TreeType class to the mean-split random
 + * projection tree. The only difference with general BinarySpaceTree is that the
 + * tree can have overlapping children.
 + */
 +template<typename MetricType,
 +         typename StatisticType,
 +         typename MatType,
 +         template<typename BoundMetricType, typename...> class BoundType>
 +class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
 +                                 RPTreeMeanSplit>>
 +{
 + public:
 +  /**
 +   * Children of a random projection tree node may overlap.
 +   */
 +  static const bool HasOverlappingChildren = true;
 +
 +  /**
 +   * The tree has not got duplicated points.
 +   */
 +  static const bool HasDuplicatedPoints = false;
 +
 +  /**
 +   * There is no guarantee that the first point in a node is its centroid.
 +   */
 +  static const bool FirstPointIsCentroid = false;
 +
 +  /**
     * Points are not contained at multiple levels of the binary space tree.
     */
    static const bool HasSelfChildren = false;
@@@ -171,14 -84,8 +189,15 @@@ class TreeTraits<BinarySpaceTree<Metric
    static const bool HasSelfChildren = false;
    static const bool RearrangesDataset = true;
    static const bool BinaryTree = true;
++  static const bool UniqueNumDescendants = true;
  };
  
 +/**
 + * This is a specialization of the TreeType class to an arbitrary tree with
 + * HollowBallBound (currently only the vantage point tree is supported).
 + * The only difference with general BinarySpaceTree is that the tree can have
 + * overlapping children.
 + */
  template<typename MetricType,
           typename StatisticType,
           typename MatType,
diff --cc src/mlpack/methods/neighbor_search/kfn_main.cpp
index 0a67344,c2f1ea6..64790da
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@@ -224,15 -223,19 +228,19 @@@ int main(int argc, char *argv[]
      const string inputModelFile = CLI::GetParam<string>("input_model_file");
      data::Load(inputModelFile, "kfn_model", kfn, true); // Fatal on failure.
  
-     Log::Info << "Loaded kFN model from '" << inputModelFile << "' (trained on "
-         << kfn.Dataset().n_rows << "x" << kfn.Dataset().n_cols << " dataset)."
-         << endl;
- 
-     // Adjust singleMode and naive if necessary.
 -    knn.SingleMode() = CLI::HasParam("single_mode");
 -    knn.Naive() = CLI::HasParam("naive");
 -    knn.Epsilon() = epsilon;
 +    kfn.SingleMode() = CLI::HasParam("single_mode");
 +    kfn.Naive() = CLI::HasParam("naive");
-     kfn.LeafSize() = size_t(lsInt);
 +    kfn.Epsilon() = epsilon;
+ 
+     // If leaf_size wasn't provided, let's consider the current value in the
+     // loaded model.  Else, update it (only considered when building the query
+     // tree).
+     if (CLI::HasParam("leaf_size"))
 -      knn.LeafSize() = size_t(lsInt);
++      kfn.LeafSize() = size_t(lsInt);
+ 
+     Log::Info << "Loaded kFN model from '" << inputModelFile << "' (trained on "
+         << kfn.Dataset().n_rows << "x" << kfn.Dataset().n_cols << " dataset)."
+         << endl;
    }
  
    // Perform search, if desired.
@@@ -278,6 -281,47 +286,47 @@@
        data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
      if (CLI::HasParam("distances_file"))
        data::Save(CLI::GetParam<string>("distances_file"), distances);
+ 
+     // Calculate the effective error, if desired.
+     if (CLI::HasParam("true_distances_file"))
+     {
 -      if (knn.Epsilon() == 0)
++      if (kfn.Epsilon() == 0)
+         Log::Warn << "--true_distances_file (-D) specified on exact neighbor "
+             << "search." << endl;
+ 
+       const string trueDistancesFile = CLI::GetParam<string>(
+           "true_distances_file");
+       arma::mat trueDistances;
+       data::Load(trueDistancesFile, trueDistances, true);
+ 
+       if (trueDistances.n_rows != distances.n_rows ||
+           trueDistances.n_cols != distances.n_cols)
+         Log::Fatal << "The true distances file must have the same number of "
+             << "values than the set of distances being queried!" << endl;
+ 
+       Log::Info << "Effective error: " << KFN::EffectiveError(distances,
+           trueDistances) << endl;
+     }
+ 
+     // Calculate the recall, if desired.
+     if (CLI::HasParam("true_neighbors_file"))
+     {
 -      if (knn.Epsilon() == 0)
++      if (kfn.Epsilon() == 0)
+         Log::Warn << "--true_neighbors_file (-T) specified on exact neighbor "
+             << "search." << endl;
+ 
+       const string trueNeighborsFile = CLI::GetParam<string>(
+           "true_neighbors_file");
+       arma::Mat<size_t> trueNeighbors;
+       data::Load(trueNeighborsFile, trueNeighbors, true);
+ 
+       if (trueNeighbors.n_rows != neighbors.n_rows ||
+           trueNeighbors.n_cols != neighbors.n_cols)
+         Log::Fatal << "The true neighbors file must have the same number of "
+             << "values than the set of neighbors being queried!" << endl;
+ 
+       Log::Info << "Recall: " << KFN::Recall(neighbors, trueNeighbors) << endl;
+     }
    }
  
    if (CLI::HasParam("output_model_file"))
diff --cc src/mlpack/methods/neighbor_search/knn_main.cpp
index 486a46f,9ff6bd7..9b2921f
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@@ -63,12 -68,17 +68,17 @@@ PARAM_INT_IN("k", "Number of nearest ne
  
  // The user may specify the type of tree to use, and a few parameters for tree
  // building.
 -PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'cover', 'r', "
 -    "'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'spill'.",
 -    "t", "kd");
 +PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
-     "'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
-     "'r-plus-plus'.", "t", "kd");
- PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
-     "vp trees, random projection trees, R trees, R* trees, X trees, "
-     "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
++    "'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus',"
++    " 'spill'.", "t", "kd");
+ PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, vp "
 -    "trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees "
 -    "and spill trees).", "l", 20);
++    "trees, random projection trees, R trees, R* trees, X trees, "
++    "Hilbert R trees, R+ trees, R++ trees and spill trees).", "l", 20);
+ PARAM_DOUBLE_IN("tau", "Overlapping size (only valid for spill trees).", "u",
+     0);
+ PARAM_DOUBLE_IN("rho", "Balance threshold (only valid for spill trees).", "b",
+     0.7);
+ 
  PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
      "random orthogonal basis.", "R");
  PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@@ -181,16 -211,14 +211,18 @@@ int main(int argc, char *argv[]
        tree = KNNModel::R_PLUS_TREE;
      else if (treeType == "r-plus-plus")
        tree = KNNModel::R_PLUS_PLUS_TREE;
+     else if (treeType == "spill")
+       tree = KNNModel::SPILL_TREE;
      else if (treeType == "vp")
        tree = KNNModel::VP_TREE;
 +    else if (treeType == "rp")
 +      tree = KNNModel::RP_TREE;
 +    else if (treeType == "max-rp")
 +      tree = KNNModel::MAX_RP_TREE;
      else
        Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
 -          << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
 -          << "'r-plus', 'r-plus-plus', 'vp' and 'spill'." << endl;
 +          << "'kd', 'vp', 'rp', 'max-rp', 'cover', 'r', 'r-star', 'x', 'ball', "
-           << "'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
++          << "'hilbert-r', 'r-plus' and 'r-plus-plus', and 'spill'." << endl;
  
      knn.TreeType() = tree;
      knn.RandomBasis() = randomBasis;
diff --cc src/mlpack/methods/neighbor_search/ns_model.hpp
index 5aece6b,bafccc1..48f6b26
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@@ -257,9 -277,8 +277,10 @@@ class NSMode
      HILBERT_R_TREE,
      R_PLUS_TREE,
      R_PLUS_PLUS_TREE,
 -    SPILL_TREE,
 -    VP_TREE
 +    VP_TREE,
 +    RP_TREE,
-     MAX_RP_TREE
++    MAX_RP_TREE,
++    SPILL_TREE
    };
  
   private:
@@@ -289,8 -313,7 +315,9 @@@
                   NSType<SortPolicy, tree::RPlusTree>*,
                   NSType<SortPolicy, tree::RPlusPlusTree>*,
                   NSType<SortPolicy, tree::VPTree>*,
 +                 NSType<SortPolicy, tree::RPTree>*,
-                  NSType<SortPolicy, tree::MaxRPTree>*> nSearch;
++                 NSType<SortPolicy, tree::MaxRPTree>*,
+                  SpillKNN*> nSearch;
  
   public:
    /**
diff --cc src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 46275f0,3fb2b2e..bdc0366
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@@ -398,13 -463,8 +463,15 @@@ void NSModel<SortPolicy>::BuildModel(ar
        nSearch = new NSType<SortPolicy, tree::VPTree>(naive, singleMode,
            epsilon);
        break;
 +    case RP_TREE:
 +      nSearch = new NSType<SortPolicy, tree::RPTree>(naive, singleMode,
 +          epsilon);
 +      break;
 +    case MAX_RP_TREE:
 +      nSearch = new NSType<SortPolicy, tree::MaxRPTree>(naive, singleMode,
 +          epsilon);
+     case SPILL_TREE:
+       nSearch = new SpillKNN(naive, singleMode, epsilon);
        break;
    }
  
@@@ -490,12 -550,10 +557,14 @@@ std::string NSModel<SortPolicy>::TreeNa
        return "R+ tree";
      case R_PLUS_PLUS_TREE:
        return "R++ tree";
+     case SPILL_TREE:
+       return "Spill tree";
      case VP_TREE:
 -      return "Vantage point tree";
 +      return "vantage point tree";
 +    case RP_TREE:
 +      return "random projection tree (mean split)";
 +    case MAX_RP_TREE:
 +      return "random projection tree (max split)";
      default:
        return "unknown tree";
    }




More information about the mlpack-git mailing list