[mlpack-git] master: - DTree class templated. (aa2ad99)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 18 05:43:35 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit aa2ad993b56e379261ff7eecfebb666d80ee8e64
Author: theJonan <ivan at jonan.info>
Date:   Thu Oct 13 16:10:43 2016 +0300

    - DTree class templated.


>---------------------------------------------------------------

aa2ad993b56e379261ff7eecfebb666d80ee8e64
 src/mlpack/methods/det/CMakeLists.txt              |   2 +-
 src/mlpack/methods/det/dtree.hpp                   |  98 +++++-------
 .../methods/det/{dtree.cpp => dtree_impl.hpp}      | 178 ++++++++++++---------
 3 files changed, 141 insertions(+), 137 deletions(-)

diff --git a/src/mlpack/methods/det/CMakeLists.txt b/src/mlpack/methods/det/CMakeLists.txt
index 73b0474..66edbe6 100644
--- a/src/mlpack/methods/det/CMakeLists.txt
+++ b/src/mlpack/methods/det/CMakeLists.txt
@@ -4,7 +4,7 @@
 set(SOURCES
   # the DET class
   dtree.hpp
-  dtree.cpp
+  dtree_impl.hpp
 
   # the util file
   dt_utils.hpp
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index 8e8bbdf..f39b5bf 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -36,10 +36,18 @@ namespace det /** Density Estimation Trees */ {
  * }
  * @endcode
  */
+template <typename MatType,
+          typename VecType,
+          typename TagType = int>
 class DTree
 {
  public:
   /**
+   * The actual, underlying type we're working with
+   */
+  typedef typename MatType::elem_type ElemType;
+  
+  /**
    * Create an empty density estimation tree.
    */
   DTree();
@@ -52,8 +60,8 @@ class DTree
    * @param minVals Minimum values of the bounding box.
    * @param totalPoints Total number of points in the dataset.
    */
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
+  DTree(const VecType& maxVals,
+        const VecType& minVals,
         const size_t totalPoints);
 
   /**
@@ -64,7 +72,7 @@ class DTree
    *
    * @param data Dataset to build tree on.
    */
-  DTree(arma::mat& data);
+  DTree(MatType& data);
 
   /**
    * Create a child node of a density estimation tree given the bounding box
@@ -78,8 +86,8 @@ class DTree
    * @param end End of points represented by this node in the data matrix.
    * @param error log-negative error of this node.
    */
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
+  DTree(const VecType& maxVals,
+        const VecType& minVals,
         const size_t start,
         const size_t end,
         const double logNegError);
@@ -95,8 +103,8 @@ class DTree
    * @param start Start of points represented by this node in the data matrix.
    * @param end End of points represented by this node in the data matrix.
    */
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
+  DTree(const VecType& maxVals,
+        const VecType& minVals,
         const size_t totalPoints,
         const size_t start,
         const size_t end);
@@ -114,7 +122,7 @@ class DTree
    * @param maxLeafSize Maximum size of a leaf.
    * @param minLeafSize Minimum size of a leaf.
    */
-  double Grow(arma::mat& data,
+  double Grow(MatType& data,
               arma::Col<size_t>& oldFromNew,
               const bool useVolReg = false,
               const size_t maxLeafSize = 10,
@@ -137,16 +145,7 @@ class DTree
    *
    * @param query Point to estimate density of.
    */
-  double ComputeValue(const arma::vec& query) const;
-
-  /**
-   * Print the tree in a depth-first manner (this function is called
-   * recursively).
-   *
-   * @param fp File to write the tree to.
-   * @param level Level of the tree (should start at 0).
-   */
-  void WriteTree(FILE *fp, const size_t level = 0) const;
+  double ComputeValue(const VecType& query) const;
 
   /**
    * Index the buckets for possible usage later; this results in every leaf in
@@ -155,7 +154,7 @@ class DTree
    *
    * @param tag Tag for the next leaf; leave at 0 for the initial call.
    */
-  int TagTree(const int tag = 0);
+  TagType TagTree(const TagType tag = 0);
 
   /**
    * Return the tag of the leaf containing the query.  This is useful for
@@ -163,7 +162,7 @@ class DTree
    *
    * @param query Query to search for.
    */
-  int FindBucket(const arma::vec& query) const;
+  TagType FindBucket(const VecType& query) const;
 
   /**
    * Compute the variable importance of each dimension in the learned tree.
@@ -183,7 +182,7 @@ class DTree
   /**
    * Return whether a query point is within the range of this node.
    */
-  bool WithinRange(const arma::vec& query) const;
+  bool WithinRange(const VecType& query) const;
 
  private:
   // The indices in the complete set of points
@@ -208,7 +207,7 @@ class DTree
   size_t splitDim;
 
   //! The split value on the splitting dimension for this node.
-  double splitValue;
+  ElemType splitValue;
 
   //! log-negative-L2-error of the node.
   double logNegError;
@@ -229,7 +228,7 @@ class DTree
   double logVolume;
 
   //! The tag for the leaf, used for hashing points.
-  int bucketTag;
+  TagType bucketTag;
 
   //! Upper part of alpha sum; used for pruning.
   double alphaUpper;
@@ -247,7 +246,7 @@ class DTree
   //! Return the split dimension of this node.
   size_t SplitDim() const { return splitDim; }
   //! Return the split value of this node.
-  double SplitValue() const { return splitValue; }
+  ElemType SplitValue() const { return splitValue; }
   //! Return the log negative error of this node.
   double LogNegError() const { return logNegError; }
   //! Return the log negative error of all descendants of this node.
@@ -267,51 +266,24 @@ class DTree
   bool Root() const { return root; }
   //! Return the upper part of the alpha sum.
   double AlphaUpper() const { return alphaUpper; }
+  //! Return the current bucket's ID, if leaf, or -1 otherwise
+  TagType BucketTag() const { return subtreeLeaves == 1 ? bucketTag : -1; }
 
   //! Return the maximum values.
-  const arma::vec& MaxVals() const { return maxVals; }
+  const VecType& MaxVals() const { return maxVals; }
   //! Modify the maximum values.
-  arma::vec& MaxVals() { return maxVals; }
+  VecType& MaxVals() { return maxVals; }
 
   //! Return the minimum values.
-  const arma::vec& MinVals() const { return minVals; }
+  const VecType& MinVals() const { return minVals; }
   //! Modify the minimum values.
-  arma::vec& MinVals() { return minVals; }
+  VecType& MinVals() { return minVals; }
 
   /**
    * Serialize the density estimation tree.
    */
   template<typename Archive>
-  void Serialize(Archive& ar, const unsigned int /* version */)
-  {
-    using data::CreateNVP;
-
-    ar & CreateNVP(start, "start");
-    ar & CreateNVP(end, "end");
-    ar & CreateNVP(maxVals, "maxVals");
-    ar & CreateNVP(minVals, "minVals");
-    ar & CreateNVP(splitDim, "splitDim");
-    ar & CreateNVP(splitValue, "splitValue");
-    ar & CreateNVP(logNegError, "logNegError");
-    ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
-    ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
-    ar & CreateNVP(root, "root");
-    ar & CreateNVP(ratio, "ratio");
-    ar & CreateNVP(logVolume, "logVolume");
-    ar & CreateNVP(bucketTag, "bucketTag");
-    ar & CreateNVP(alphaUpper, "alphaUpper");
-
-    if (Archive::is_loading::value)
-    {
-      if (left)
-        delete left;
-      if (right)
-        delete right;
-    }
-
-    ar & CreateNVP(left, "left");
-    ar & CreateNVP(right, "right");
-  }
+  void Serialize(Archive& ar, const unsigned int /* version */);
 
  private:
 
@@ -320,9 +292,9 @@ class DTree
   /**
    * Find the dimension to split on.
    */
-  bool FindSplit(const arma::mat& data,
+  bool FindSplit(const MatType& data,
                  size_t& splitDim,
-                 double& splitValue,
+                 ElemType& splitValue,
                  double& leftError,
                  double& rightError,
                  const size_t minLeafSize = 5) const;
@@ -330,9 +302,9 @@ class DTree
   /**
    * Split the data, returning the number of points left of the split.
    */
-  size_t SplitData(arma::mat& data,
+  size_t SplitData(MatType& data,
                    const size_t splitDim,
-                   const double splitValue,
+                   const ElemType splitValue,
                    arma::Col<size_t>& oldFromNew) const;
 
 };
@@ -340,4 +312,6 @@ class DTree
 } // namespace det
 } // namespace mlpack
 
+#include "dtree_impl.hpp"
+
 #endif // MLPACK_METHODS_DET_DTREE_HPP
diff --git a/src/mlpack/methods/det/dtree.cpp b/src/mlpack/methods/det/dtree_impl.hpp
similarity index 77%
rename from src/mlpack/methods/det/dtree.cpp
rename to src/mlpack/methods/det/dtree_impl.hpp
index c88d480..5ef3758 100644
--- a/src/mlpack/methods/det/dtree.cpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -12,7 +12,8 @@
 using namespace mlpack;
 using namespace det;
 
-DTree::DTree() :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree() :
     start(0),
     end(0),
     splitDim(size_t(-1)),
@@ -31,9 +32,11 @@ DTree::DTree() :
 
 
 // Root node initializers
-DTree::DTree(const arma::vec& maxVals,
-             const arma::vec& minVals,
-             const size_t totalPoints) :
+
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+                                        const VecType& minVals,
+                                        const size_t totalPoints) :
     start(0),
     end(totalPoints),
     maxVals(maxVals),
@@ -52,7 +55,8 @@ DTree::DTree(const arma::vec& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-DTree::DTree(arma::mat& data) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(MatType & data) :
     start(0),
     end(data.n_cols),
     splitDim(size_t(-1)),
@@ -88,11 +92,12 @@ DTree::DTree(arma::mat& data) :
 
 
 // Non-root node initializers
-DTree::DTree(const arma::vec& maxVals,
-             const arma::vec& minVals,
-             const size_t start,
-             const size_t end,
-             const double logNegError) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+                                        const VecType& minVals,
+                                        const size_t start,
+                                        const size_t end,
+                                        const double logNegError) :
     start(start),
     end(end),
     maxVals(maxVals),
@@ -111,11 +116,12 @@ DTree::DTree(const arma::vec& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-DTree::DTree(const arma::vec& maxVals,
-             const arma::vec& minVals,
-             const size_t totalPoints,
-             const size_t start,
-             const size_t end) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+                                        const VecType& minVals,
+                                        const size_t totalPoints,
+                                        const size_t start,
+                                        const size_t end) :
     start(start),
     end(end),
     maxVals(maxVals),
@@ -134,7 +140,8 @@ DTree::DTree(const arma::vec& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-DTree::~DTree()
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::~DTree()
 {
   delete left;
   delete right;
@@ -142,7 +149,8 @@ DTree::~DTree()
 
 // This function computes the log-l2-negative-error of a given node from the
 // formula R(t) = log(|t|^2 / (N^2 V_t)).
-double DTree::LogNegativeError(const size_t totalPoints) const
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoints) const
 {
   // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
   double err = 2 * std::log((double) (end - start)) -
@@ -162,12 +170,13 @@ double DTree::LogNegativeError(const size_t totalPoints) const
 // This function finds the best split with respect to the L2-error, by trying
 // all possible splits.  The dataset is the full data set but the start and
 // end are used to obtain the point in this node.
-bool DTree::FindSplit(const arma::mat& data,
-                      size_t& splitDim,
-                      double& splitValue,
-                      double& leftError,
-                      double& rightError,
-                      const size_t minLeafSize) const
+template <typename MatType, typename VecType, typename TagType>
+bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
+                                                 size_t& splitDim,
+                                                 ElemType& splitValue,
+                                                 double& leftError,
+                                                 double& rightError,
+                                                 const size_t minLeafSize) const
 {
   // Ensure the dimensionality of the data is the same as the dimensionality of
   // the bounding rectangle.
@@ -180,12 +189,20 @@ bool DTree::FindSplit(const arma::mat& data,
   bool splitFound = false;
 
   // Loop through each dimension.
-  for (size_t dim = 0; dim < maxVals.n_elem; dim++)
+#ifdef _WIN32
+  #pragma omp parallel for default(none) \
+    shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+  for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
+#else
+  #pragma omp parallel for default(none) \
+    shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+  for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
+#endif
   {
     // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
     // think of how to do that...
-    const double min = minVals[dim];
-    const double max = maxVals[dim];
+    const ElemType min = minVals[dim];
+    const ElemType max = maxVals[dim];
 
     // If there is nothing to split in this dimension, move on.
     if (max - min == 0.0)
@@ -197,7 +214,7 @@ bool DTree::FindSplit(const arma::mat& data,
     double minDimError = std::pow(points, 2.0) / (max - min);
     double dimLeftError = 0.0; // For -Wuninitialized.  These variables will
     double dimRightError = 0.0; // always be set to something else before use.
-    double dimSplitValue = 0.0;
+    ElemType dimSplitValue = 0.0;
 
     // Find the log volume of all the other dimensions.
     double volumeWithoutDim = logVolume - std::log(max - min);
@@ -214,7 +231,7 @@ bool DTree::FindSplit(const arma::mat& data,
     {
       // This makes sense for real continuous data.  This kinda corrupts the
       // data and estimation if the data is ordinal.
-      const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+      const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
 
       if (split == dimVec[i])
         continue; // We can't split here (two points are the same).
@@ -269,10 +286,11 @@ bool DTree::FindSplit(const arma::mat& data,
   return splitFound;
 }
 
-size_t DTree::SplitData(arma::mat& data,
-                        const size_t splitDim,
-                        const double splitValue,
-                        arma::Col<size_t>& oldFromNew) const
+template <typename MatType, typename VecType, typename TagType>
+size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
+                                                   const size_t splitDim,
+                                                   const ElemType splitValue,
+                                                   arma::Col<size_t>& oldFromNew) const
 {
   // Swap all columns such that any columns with value in dimension splitDim
   // less than or equal to splitValue are on the left side, and all others are
@@ -303,11 +321,12 @@ size_t DTree::SplitData(arma::mat& data,
 }
 
 // Greedily expand the tree
-double DTree::Grow(arma::mat& data,
-                   arma::Col<size_t>& oldFromNew,
-                   const bool useVolReg,
-                   const size_t maxLeafSize,
-                   const size_t minLeafSize)
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::Grow(MatType& data,
+                                              arma::Col<size_t>& oldFromNew,
+                                              const bool useVolReg,
+                                              const size_t maxLeafSize,
+                                              const size_t minLeafSize)
 {
   Log::Assert(data.n_rows == maxVals.n_elem);
   Log::Assert(data.n_rows == minVals.n_elem);
@@ -450,10 +469,10 @@ double DTree::Grow(arma::mat& data,
 }
 
 
-double DTree::PruneAndUpdate(const double oldAlpha,
-                             const size_t points,
-                             const bool useVolReg)
-
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
+                                                        const size_t points,
+                                                        const bool useVolReg)
 {
   // Compute gT.
   if (subtreeLeaves == 1) // If we are a leaf...
@@ -565,7 +584,8 @@ double DTree::PruneAndUpdate(const double oldAlpha,
 //
 // Future improvement: Open up the range with epsilons on both sides where
 // epsilon depends on the density near the boundary.
-bool DTree::WithinRange(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
 {
   for (size_t i = 0; i < query.n_elem; ++i)
     if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
@@ -575,7 +595,8 @@ bool DTree::WithinRange(const arma::vec& query) const
 }
 
 
-double DTree::ComputeValue(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) const
 {
   Log::Assert(query.n_elem == maxVals.n_elem);
 
@@ -607,35 +628,9 @@ double DTree::ComputeValue(const arma::vec& query) const
 }
 
 
-void DTree::WriteTree(FILE *fp, const size_t level) const
-{
-  if (subtreeLeaves > 1)
-  {
-    fprintf(fp, "\n");
-    for (size_t i = 0; i < level; ++i)
-      fprintf(fp, "|\t");
-    fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
-
-    right->WriteTree(fp, level + 1);
-
-    fprintf(fp, "\n");
-    for (size_t i = 0; i < level; ++i)
-      fprintf(fp, "|\t");
-    fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
-
-    left->WriteTree(fp, level);
-  }
-  else // If we are a leaf...
-  {
-    fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
-    if (bucketTag != -1)
-      fprintf(fp, " BT:%d", bucketTag);
-  }
-}
-
-
 // Index the buckets for possible usage later.
-int DTree::TagTree(const int tag)
+template <typename MatType, typename VecType, typename TagType>
+TagType DTree<MatType, VecType, TagType>::TagTree(const TagType tag)
 {
   if (subtreeLeaves == 1)
   {
@@ -650,7 +645,8 @@ int DTree::TagTree(const int tag)
 }
 
 
-int DTree::FindBucket(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
 {
   Log::Assert(query.n_elem == maxVals.n_elem);
 
@@ -670,8 +666,8 @@ int DTree::FindBucket(const arma::vec& query) const
   }
 }
 
-
-void DTree::ComputeVariableImportance(arma::vec& importances) const
+template <typename MatType, typename VecType, typename TagType>
+void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& importances) const
 {
   // Clear and set to right size.
   importances.zeros(maxVals.n_elem);
@@ -697,3 +693,37 @@ void DTree::ComputeVariableImportance(arma::vec& importances) const
     nodes.push(curNode.Right());
   }
 }
+
+template <typename MatType, typename VecType, typename TagType>
+template <typename Archive>
+void DTree<MatType, VecType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+  
+  ar & CreateNVP(start, "start");
+  ar & CreateNVP(end, "end");
+  ar & CreateNVP(maxVals, "maxVals");
+  ar & CreateNVP(minVals, "minVals");
+  ar & CreateNVP(splitDim, "splitDim");
+  ar & CreateNVP(splitValue, "splitValue");
+  ar & CreateNVP(logNegError, "logNegError");
+  ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
+  ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
+  ar & CreateNVP(root, "root");
+  ar & CreateNVP(ratio, "ratio");
+  ar & CreateNVP(logVolume, "logVolume");
+  ar & CreateNVP(bucketTag, "bucketTag");
+  ar & CreateNVP(alphaUpper, "alphaUpper");
+  
+  if (Archive::is_loading::value)
+  {
+    if (left)
+      delete left;
+    if (right)
+      delete right;
+  }
+  
+  ar & CreateNVP(left, "left");
+  ar & CreateNVP(right, "right");
+}
+




More information about the mlpack-git mailing list