[mlpack-git] master: - First successful build. (37a9e50)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 18 05:43:36 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit 37a9e5047db2a83cacc1ceb68ef0d47cbebbef7b
Author: theJonan <ivan at jonan.info>
Date:   Thu Oct 13 19:53:50 2016 +0300

    - First successful build.


>---------------------------------------------------------------

37a9e5047db2a83cacc1ceb68ef0d47cbebbef7b
 src/mlpack/methods/det/CMakeLists.txt              |   4 -
 src/mlpack/methods/det/det_main.cpp                |   4 +-
 src/mlpack/methods/det/dt_utils.hpp                |  23 +--
 .../det/{dt_utils.cpp => dt_utils_impl.hpp}        |  59 ++++---
 src/mlpack/methods/det/dtree.hpp                   |   6 +-
 src/mlpack/methods/det/dtree_impl.hpp              | 171 +++++++++++++--------
 6 files changed, 152 insertions(+), 115 deletions(-)

diff --git a/src/mlpack/methods/det/CMakeLists.txt b/src/mlpack/methods/det/CMakeLists.txt
index 66edbe6..4dd3bc3 100644
--- a/src/mlpack/methods/det/CMakeLists.txt
+++ b/src/mlpack/methods/det/CMakeLists.txt
@@ -5,10 +5,6 @@ set(SOURCES
   # the DET class
   dtree.hpp
   dtree_impl.hpp
-
-  # the util file
-  dt_utils.hpp
-  dt_utils.cpp
 )
 
 # add directory name to sources
diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index f1b33ba..26b394d 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -101,7 +101,7 @@ int main(int argc, char *argv[])
         << "(-T) is not specified." << endl;
 
   // Are we training a DET or loading from file?
-  DTree* tree;
+  DTree<arma::mat, arma::vec, int>* tree;
   if (CLI::HasParam("training_file"))
   {
     const string trainSetFile = CLI::GetParam<string>("training_file");
@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
 
     // Obtain the optimal tree.
     Timer::Start("det_training");
-    tree = Trainer(trainingData, folds, regularization, maxLeafSize,
+    tree = Trainer<arma::mat, arma::vec, int>(trainingData, folds, regularization, maxLeafSize,
         minLeafSize, "");
     Timer::Stop("det_training");
 
diff --git a/src/mlpack/methods/det/dt_utils.hpp b/src/mlpack/methods/det/dt_utils.hpp
index c20f78a..067e6fe 100644
--- a/src/mlpack/methods/det/dt_utils.hpp
+++ b/src/mlpack/methods/det/dt_utils.hpp
@@ -25,8 +25,9 @@ namespace det {
  * @param numClasses Number of classes in dataset.
  * @param leafClassMembershipFile Name of file to print to (optional).
  */
-void PrintLeafMembership(DTree* dtree,
-                         const arma::mat& data,
+template <typename MatType, typename VecType, typename TagType>
+void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+                         const MatType& data,
                          const arma::Mat<size_t>& labels,
                          const size_t numClasses,
                          const std::string leafClassMembershipFile = "");
@@ -39,7 +40,8 @@ void PrintLeafMembership(DTree* dtree,
  * @param dtree Density tree to use.
  * @param viFile Name of file to print to (optional).
  */
-void PrintVariableImportance(const DTree* dtree,
+template <typename MatType, typename VecType, typename TagType>
+void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
                              const std::string viFile = "");
 
 /**
@@ -54,14 +56,17 @@ void PrintVariableImportance(const DTree* dtree,
  * @param minLeafSize Minimum number of points allowed in a leaf.
  * @param unprunedTreeOutput Filename to print unpruned tree to (optional).
  */
-DTree* Trainer(arma::mat& dataset,
-               const size_t folds,
-               const bool useVolumeReg = false,
-               const size_t maxLeafSize = 10,
-               const size_t minLeafSize = 5,
-               const std::string unprunedTreeOutput = "");
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>* Trainer(MatType& dataset,
+                                          const size_t folds,
+                                          const bool useVolumeReg = false,
+                                          const size_t maxLeafSize = 10,
+                                          const size_t minLeafSize = 5,
+                                          const std::string unprunedTreeOutput = "");
 
 } // namespace det
 } // namespace mlpack
 
+#include "dt_utils_impl.hpp"
+
 #endif // MLPACK_METHODS_DET_DT_UTILS_HPP
diff --git a/src/mlpack/methods/det/dt_utils.cpp b/src/mlpack/methods/det/dt_utils_impl.hpp
similarity index 83%
rename from src/mlpack/methods/det/dt_utils.cpp
rename to src/mlpack/methods/det/dt_utils_impl.hpp
index b8946a9..4c057f7 100644
--- a/src/mlpack/methods/det/dt_utils.cpp
+++ b/src/mlpack/methods/det/dt_utils_impl.hpp
@@ -10,22 +10,23 @@
 using namespace mlpack;
 using namespace det;
 
-void mlpack::det::PrintLeafMembership(DTree* dtree,
-                                      const arma::mat& data,
+template <typename MatType, typename VecType, typename TagType>
+void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+                                      const MatType& data,
                                       const arma::Mat<size_t>& labels,
                                       const size_t numClasses,
                                       const std::string leafClassMembershipFile)
 {
   // Tag the leaves with numbers.
-  int numLeaves = dtree->TagTree();
+  TagType numLeaves = dtree->TagTree();
 
   arma::Mat<size_t> table(numLeaves, (numClasses + 1));
   table.zeros();
 
   for (size_t i = 0; i < data.n_cols; i++)
   {
-    const arma::vec testPoint = data.unsafe_col(i);
-    const int leafTag = dtree->FindBucket(testPoint);
+    const VecType testPoint = data.unsafe_col(i);
+    const TagType leafTag = dtree->FindBucket(testPoint);
     const size_t label = labels[i];
     table(leafTag, label) += 1;
   }
@@ -57,8 +58,8 @@ void mlpack::det::PrintLeafMembership(DTree* dtree,
   return;
 }
 
-
-void mlpack::det::PrintVariableImportance(const DTree* dtree,
+template <typename MatType, typename VecType, typename TagType>
+void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
                                           const std::string viFile)
 {
   arma::vec imps;
@@ -96,15 +97,16 @@ void mlpack::det::PrintVariableImportance(const DTree* dtree,
 
 // This function trains the optimal decision tree using the given number of
 // folds.
-DTree* mlpack::det::Trainer(arma::mat& dataset,
-                            const size_t folds,
-                            const bool useVolumeReg,
-                            const size_t maxLeafSize,
-                            const size_t minLeafSize,
-                            const std::string unprunedTreeOutput)
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
+                                                       const size_t folds,
+                                                       const bool useVolumeReg,
+                                                       const size_t maxLeafSize,
+                                                       const size_t minLeafSize,
+                                                       const std::string unprunedTreeOutput)
 {
   // Initialize the tree.
-  DTree dtree(dataset);
+  DTree<MatType, VecType, TagType> dtree(dataset);
 
   // Prepare to grow the tree...
   arma::Col<size_t> oldFromNew(dataset.n_cols);
@@ -112,7 +114,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
     oldFromNew[i] = i;
 
   // Save the dataset since it would be modified while growing the tree.
-  arma::mat newDataset(dataset);
+  MatType newDataset(dataset);
 
   // Growing the tree
   double oldAlpha = 0.0;
@@ -148,8 +150,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
   std::vector<std::pair<double, double> > prunedSequence;
   while (dtree.SubtreeLeaves() > 1)
   {
-    std::pair<double, double> treeSeq(oldAlpha,
-        dtree.SubtreeLeavesLogNegError());
+    std::pair<double, double> treeSeq(oldAlpha, dtree.SubtreeLeavesLogNegError());
     prunedSequence.push_back(treeSeq);
     oldAlpha = alpha;
     alpha = dtree.PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
@@ -157,20 +158,18 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
     // Some sanity checks.  It seems that on some datasets, the error does not
     // increase as the tree is pruned but instead stays the same---hence the
     // "<=" in the final assert.
-    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
-        (dtree.SubtreeLeaves() == 1));
+    Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree.SubtreeLeaves() == 1));
     Log::Assert(alpha > oldAlpha);
     Log::Assert(dtree.SubtreeLeavesLogNegError() <= treeSeq.second);
   }
 
-  std::pair<double, double> treeSeq(oldAlpha,
-      dtree.SubtreeLeavesLogNegError());
+  std::pair<double, double> treeSeq(oldAlpha, dtree.SubtreeLeavesLogNegError());
   prunedSequence.push_back(treeSeq);
 
   Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
       << " " << oldAlpha << "." << std::endl;
 
-  arma::mat cvData(dataset);
+  MatType cvData(dataset);
   size_t testSize = dataset.n_cols / folds;
 
   arma::vec regularizationConstants(prunedSequence.size());
@@ -194,8 +193,8 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
     size_t start = fold * testSize;
     size_t end = std::min((size_t) (fold + 1) * testSize, (size_t) cvData.n_cols);
 
-    arma::mat test = cvData.cols(start, end - 1);
-    arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols);
+    MatType test = cvData.cols(start, end - 1);
+    MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
 
     if (start == 0 && end < cvData.n_cols)
     {
@@ -212,7 +211,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
     }
 
     // Initialize the tree.
-    DTree cvDTree(train);
+    DTree<MatType, VecType, TagType> cvDTree(train);
 
     // Getting ready to grow the tree...
     arma::Col<size_t> cvOldFromNew(train.n_cols);
@@ -252,13 +251,12 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
     double cvVal = 0.0;
     for (size_t i = 0; i < test.n_cols; ++i)
     {
-      arma::vec testPoint = test.unsafe_col(i);
+      VecType testPoint = test.unsafe_col(i);
       cvVal += cvDTree.ComputeValue(testPoint);
     }
 
     if (prunedSequence.size() > 2)
-      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal /
-          (double) dataset.n_cols;
+      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal / (double) dataset.n_cols;
 
     #pragma omp critical
     regularizationConstants += cvRegularizationConstants;
@@ -285,7 +283,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
   Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
 
   // Initialize the tree.
-  DTree* dtreeOpt = new DTree(dataset);
+  DTree<MatType, VecType, TagType>* dtreeOpt = new DTree<MatType, VecType, TagType>(dataset);
 
   // Getting ready to grow the tree...
   for (size_t i = 0; i < oldFromNew.n_elem; i++)
@@ -296,8 +294,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
 
   // Grow the tree.
   oldAlpha = -DBL_MAX;
-  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
-      minLeafSize);
+  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize);
 
   // Prune with optimal alpha.
   while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index f39b5bf..f34750f 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -154,7 +154,7 @@ class DTree
    *
    * @param tag Tag for the next leaf; leave at 0 for the initial call.
    */
-  TagType TagTree(const TagType tag = 0);
+  TagType TagTree(const TagType& tag = 0);
 
   /**
    * Return the tag of the leaf containing the query.  This is useful for
@@ -271,13 +271,9 @@ class DTree
 
   //! Return the maximum values.
   const VecType& MaxVals() const { return maxVals; }
-  //! Modify the maximum values.
-  VecType& MaxVals() { return maxVals; }
 
   //! Return the minimum values.
   const VecType& MinVals() const { return minVals; }
-  //! Modify the minimum values.
-  VecType& MinVals() { return minVals; }
 
   /**
    * Serialize the density estimation tree.
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 5ef3758..58131fb 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -1,4 +1,4 @@
- /**
+/**
  * @file dtree.cpp
  * @author Parikshit Ram (pram at cc.gatech.edu)
  *
@@ -8,16 +8,91 @@
  */
 #include "dtree.hpp"
 #include <stack>
+#include <mlpack/core/tree/perform_split.hpp>
 
 using namespace mlpack;
 using namespace det;
 
+namespace detail  // NOTE(review): opened at global scope (not inside mlpack::det), so every includer of this header sees ::detail
+{
+  template <typename ElemType>
+  class DTreeSplit  // Split rule: a point goes to the left child iff point[splitDimension] < splitVal.
+  {
+  public:
+    typedef DTreeSplit<ElemType>    SplitInfo;  // The rule object doubles as its own split-info type.
+    
+    template<typename VecType>
+    static bool AssignToLeftNode(const VecType& point,  // Strict '<': a value equal to splitVal goes right.
+                                 const SplitInfo& splitInfo)
+    {
+      return point[splitInfo.splitDimension] < splitInfo.splitVal;  // NOTE(review): ComputeValue/FindBucket in this file route ties left ('<='); reconcile before wiring this rule in
+    }
+    
+  private:
+    ElemType    splitVal;  // NOTE(review): no constructor or setter in this class initializes these two members
+    size_t      splitDimension;
+  };
+  
+  /**
+   * Compute the per-row (per-dimension) minimum and maximum of data; kept as a
+   * free function so sparse matrices can get an overload faster than this loop.
+   */
+  template <typename MatType, typename VecType>
+  void ExtractMinMax(const MatType& data,  // Assumes data.n_cols >= 1: column 0 is read unconditionally.
+                     VecType& minVals,
+                     VecType& maxVals)
+  {
+    // Initialize to first column; values will be overwritten if necessary.
+    maxVals = data.col(0);
+    minVals = data.col(0);
+    
+    // Loop over data to extract maximum and minimum values in each dimension.
+    for (size_t i = 1; i < data.n_cols; ++i)
+    {
+      for (size_t j = 0; j < data.n_rows; ++j)
+      {
+        if (data(j, i) > maxVals[j])
+          maxVals[j] = data(j, i);
+        if (data(j, i) < minVals[j])
+          minVals[j] = data(j, i);
+      }
+    }
+  }
+  
+  /**
+   * Sparse overload: more specialized, so it is chosen for SpMat data with SpCol outputs.
+   */
+  template <typename ElemType>
+  void ExtractMinMax(const arma::SpMat<ElemType>& data,  // Assumes data.n_cols >= 1, like the dense version.
+                     arma::SpCol<ElemType>& minVals,
+                     arma::SpCol<ElemType>& maxVals)
+  {
+    // Initialize to first column; values will be overwritten if necessary.
+    maxVals = data.col(0);
+    minVals = data.col(0);
+    
+    typename arma::sp_mat::iterator dataEnd = data.end();  // NOTE(review): data is a const SpMat<ElemType>; a SpMat<ElemType>::const_iterator (not sp_mat::iterator, i.e. SpMat<double>) looks required -- confirm this compiles once instantiated
+    
+    // Loop over data to extract maximum and minimum values in each dimension.
+    for (typename arma::sp_mat::iterator i = data.begin(); i != dataEnd; ++i)  // NOTE(review): sparse iterators visit stored non-zeros only, so an implicit 0 outside column 0 never updates min/max (results can differ from the dense loop) -- TODO verify
+    {
+      size_t j = i.row();
+      if (i.col() == 0)
+        continue; // we've already taken these values.
+      else if (*i > maxVals[j])  // 'else if' is safe: minVals[j] <= maxVals[j], so one value cannot update both.
+        maxVals[j] = *i;
+      else if (*i < minVals[j])
+        minVals[j] = *i;
+    }
+  }
+};  // NOTE(review): the ';' after a namespace's closing brace is a redundant empty declaration
+
 template <typename MatType, typename VecType, typename TagType>
 DTree<MatType, VecType, TagType>::DTree() :
     start(0),
     end(0),
     splitDim(size_t(-1)),
-    splitValue(DBL_MAX),
+    splitValue(std::numeric_limits<ElemType>::max()),
     logNegError(-DBL_MAX),
     subtreeLeavesLogNegError(-DBL_MAX),
     subtreeLeaves(0),
@@ -42,7 +117,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     maxVals(maxVals),
     minVals(minVals),
     splitDim(size_t(-1)),
-    splitValue(DBL_MAX),
+    splitValue(std::numeric_limits<ElemType>::max()),
     logNegError(LogNegativeError(totalPoints)),
     subtreeLeavesLogNegError(-DBL_MAX),
     subtreeLeaves(0),
@@ -60,7 +135,7 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
     start(0),
     end(data.n_cols),
     splitDim(size_t(-1)),
-    splitValue(DBL_MAX),
+    splitValue(std::numeric_limits<ElemType>::max()),
     subtreeLeavesLogNegError(-DBL_MAX),
     subtreeLeaves(0),
     root(true),
@@ -71,26 +146,10 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
     left(NULL),
     right(NULL)
 {
-  // Initialize to first column; values will be overwritten if necessary.
-  maxVals = data.col(0);
-  minVals = data.col(0);
-
-  // Loop over data to extract maximum and minimum values in each dimension.
-  for (size_t i = 1; i < data.n_cols; ++i)
-  {
-    for (size_t j = 0; j < data.n_rows; ++j)
-    {
-      if (data(j, i) > maxVals[j])
-        maxVals[j] = data(j, i);
-      if (data(j, i) < minVals[j])
-        minVals[j] = data(j, i);
-    }
-  }
-
+  detail::ExtractMinMax(data, minVals, maxVals);
   logNegError = LogNegativeError(data.n_cols);
 }
 
-
 // Non-root node initializers
 template <typename MatType, typename VecType, typename TagType>
 DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
@@ -103,7 +162,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     maxVals(maxVals),
     minVals(minVals),
     splitDim(size_t(-1)),
-    splitValue(DBL_MAX),
+    splitValue(std::numeric_limits<ElemType>::max()),
     logNegError(logNegError),
     subtreeLeavesLogNegError(-DBL_MAX),
     subtreeLeaves(0),
@@ -127,7 +186,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     maxVals(maxVals),
     minVals(minVals),
     splitDim(size_t(-1)),
-    splitValue(DBL_MAX),
+    splitValue(std::numeric_limits<ElemType>::max()),
     logNegError(LogNegativeError(totalPoints)),
     subtreeLeavesLogNegError(-DBL_MAX),
     subtreeLeaves(0),
@@ -156,12 +215,13 @@ double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoin
   double err = 2 * std::log((double) (end - start)) -
                2 * std::log((double) totalPoints);
 
-  arma::vec valDiffs = maxVals - minVals;
-  for (size_t i = 0; i < maxVals.n_elem; ++i)
+  VecType valDiffs = maxVals - minVals;
+  typename VecType::iterator valEnd = valDiffs.end();
+  for (typename VecType::iterator i = valDiffs.begin(); i != valEnd; ++i)
   {
     // Ignore very small dimensions to prevent overflow.
-    if (valDiffs[i] > 1e-50)
-      err -= std::log(valDiffs[i]);
+    if (*i > 1e-50)
+      err -= std::log(*i);
   }
 
   return err;
@@ -191,11 +251,11 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
   // Loop through each dimension.
 #ifdef _WIN32
   #pragma omp parallel for default(none) \
-    shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
-  for (intmax_t dim = 0; fold < (intmax_t) maxVals.n_elem; ++dim)
+    shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+  for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
 #else
   #pragma omp parallel for default(none) \
-    shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+    shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
   for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
 #endif
   {
@@ -220,7 +280,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
     double volumeWithoutDim = logVolume - std::log(max - min);
 
     // Get the values for the dimension.
-    arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
+    VecType dimVec = data.row(dim).subvec(start, end - 1);
 
     // Sort the values in ascending order.
     dimVec = arma::sort(dimVec);
@@ -265,9 +325,9 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
       }
     }
 
-    double actualMinDimError = std::log(minDimError)
-        - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+    double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
 
+#pragma omp critical
     if ((actualMinDimError > minError) && dimSplitFound)
     {
       // Calculate actual error (in logspace) by adding terms back to our
@@ -275,10 +335,8 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
       minError = actualMinDimError;
       splitDim = dim;
       splitValue = dimSplitValue;
-      leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
-          - volumeWithoutDim;
-      rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
-          - volumeWithoutDim;
+      leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+      rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
       splitFound = true;
     } // end if better split found in this dimension.
   }
@@ -289,7 +347,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
 template <typename MatType, typename VecType, typename TagType>
 size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
                                                    const size_t splitDim,
-                                                   const double splitValue,
+                                                   const ElemType splitValue,
                                                    arma::Col<size_t>& oldFromNew) const
 {
   // Swap all columns such that any columns with value in dimension splitDim
@@ -310,7 +368,7 @@ size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
 
     data.swap_cols(left, right);
 
-    // Store the mapping from old to new.
+    // Store the mapping from old to new. Do not put std::swap here...
     const size_t tmp = oldFromNew[left];
     oldFromNew[left] = oldFromNew[right];
     oldFromNew[right] = tmp;
@@ -353,6 +411,7 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
     {
       // Move the data around for the children to have points in a node lie
       // contiguously (to increase efficiency during the training).
+//      const size_t splitIndex = splt::PerformSplit(data, start, end - start, )
       const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
 
       // Make max and min vals for the children.
@@ -372,10 +431,8 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
       left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
       right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
 
-      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
-          minLeafSize);
-      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
-          minLeafSize);
+      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
+      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
 
       // Store values of R(T~) and |T~|.
       subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
@@ -440,14 +497,12 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
 
     if (right->SubtreeLeaves() > 1)
     {
-      const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
-          right->AlphaUpper();
+      const double exponent = 2 * std::log((double) data.n_cols) + logVolume + right->AlphaUpper();
 
       tmpAlphaSum += std::exp(exponent);
     }
 
-    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
-        - logVolume;
+    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols) - logVolume;
 
     double gT;
     if (useVolReg)
@@ -613,15 +668,8 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
   }
   else
   {
-    if (query[splitDim] <= splitValue)
-    {
-      // If left subtree, go to left child.
-      return left->ComputeValue(query);
-    }
-    else  // If right subtree, go to right child
-    {
-      return right->ComputeValue(query);
-    }
+    // Return either of the two children - left or right, depending on the splitValue
+    return (query[splitDim] <= splitValue) ? left->ComputeValue(query) : right->ComputeValue(query);
   }
 
   return 0.0;
@@ -630,7 +678,7 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
 
 // Index the buckets for possible usage later.
 template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::TagTree(const TagType tag)
+TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
 {
   if (subtreeLeaves == 1)
   {
@@ -654,15 +702,10 @@ TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
   {
     return bucketTag;
   }
-  else if (query[splitDim] <= splitValue)
-  {
-    // If left subtree, go to left child.
-    return left->FindBucket(query);
-  }
   else
   {
-    // If right subtree, go to right child.
-    return right->FindBucket(query);
+    // Return the tag from either of the two children - left or right.
+    return (query[splitDim] <= splitValue) ? left->FindBucket(query) : right->FindBucket(query);
   }
 }
 




More information about the mlpack-git mailing list