[mlpack-svn] r13235 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jul 16 11:41:51 EDT 2012


Author: rcurtin
Date: 2012-07-16 11:41:51 -0400 (Mon, 16 Jul 2012)
New Revision: 13235

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Clean up constructors.  Don't hand-allocate memory for minVals and maxVals;
remove GetMaxMinVals_() because it was only ever used once; inline it directly
into the constructor where it was used.


Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-16 15:41:51 UTC (rev 13235)
@@ -117,7 +117,7 @@
                    string unprunedTreeOutput = "")
 {
   // Initialize the tree.
-  DTree<eT>* dtree = new DTree<eT>(dataset);
+  DTree<eT>* dtree = new DTree<eT>(*dataset);
 
   // Getting ready to grow the tree...
   arma::Col<size_t> old_from_new(dataset->n_cols);
@@ -221,7 +221,7 @@
     assert(train->n_cols + test.n_cols == cvdata->n_cols);
 
     // Initialize the tree.
-    DTree<eT>* dtree_cv = new DTree<eT>(train);
+    DTree<eT>* dtree_cv = new DTree<eT>(*train);
 
     // Getting ready to grow the tree...
     arma::Col<size_t> old_from_new_cv(train->n_cols);
@@ -289,7 +289,7 @@
   Log::Info << "Optimal alpha: " << optimal_alpha << "." << std::endl;
 
   // Initialize the tree.
-  DTree<eT>* dtree_opt = new DTree<eT>(dataset);
+  DTree<eT>* dtree_opt = new DTree<eT>(*dataset);
 
   // Getting ready to grow the tree...
   for (size_t i = 0; i < old_from_new.n_elem; i++)

Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-16 15:41:51 UTC (rev 13235)
@@ -104,8 +104,8 @@
 
   // since we are using uniform density, we need
   // the max and min of every dimension for every node
-  VecType* max_vals_;
-  VecType* min_vals_;
+  arma::vec maxVals;
+  arma::vec minVals;
 
   // the tag for the leaf used for hashing points
   int bucket_tag_;
@@ -160,10 +160,6 @@
                    const double splitValue,
                    arma::Col<size_t>& oldFromNew) const;
 
-  void GetMaxMinVals_(MatType* data,
-                      VecType* max_vals,
-                      VecType* min_vals);
-
   bool WithinRange_(VecType* query);
 
   ///////////////////// Public Functions //////////////////////////////////////
@@ -174,23 +170,23 @@
   // Root node initializer
   // with the bounding box of the data
   // it contains instead of just the data.
-  DTree(VecType* max_vals,
-        VecType* min_vals,
+  DTree(const arma::vec& maxVals,
+        const arma::vec& minVals,
         size_t total_points);
 
   // Root node initializer
   // with the data, no bounding box.
-  DTree(MatType* data);
+  DTree(arma::mat& data);
 
   // Non-root node initializers
-  DTree(VecType* max_vals,
-        VecType* min_vals,
+  DTree(const arma::vec& max_vals,
+        const arma::vec& min_vals,
         size_t start,
         size_t end,
         cT error);
 
-  DTree(VecType* max_vals,
-        VecType* min_vals,
+  DTree(const arma::vec& max_vals,
+        const arma::vec& min_vals,
         size_t total_points,
         size_t start,
         size_t end);

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-16 15:41:51 UTC (rev 13235)
@@ -22,7 +22,7 @@
   // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
   return 2 * std::log((double) (end_ - start_)) -
          2 * std::log((double) totalPoints) -
-         arma::accu(arma::log((*max_vals_) - (*min_vals_)));
+         arma::accu(arma::log(maxVals - minVals));
 }
 
 // This function finds the best split with respect to the L2-error, by trying
@@ -39,8 +39,8 @@
 {
   // Ensure the dimensionality of the data is the same as the dimensionality of
   // the bounding rectangle.
-  assert(data.n_rows == max_vals_->n_elem);
-  assert(data.n_rows == min_vals_->n_elem);
+  assert(data.n_rows == maxVals.n_elem);
+  assert(data.n_rows == minVals.n_elem);
 
   const size_t points = end_ - start_;
 
@@ -48,12 +48,12 @@
   bool splitFound = false;
 
   // Loop through each dimension.
-  for (size_t dim = 0; dim < max_vals_->n_elem; dim++)
+  for (size_t dim = 0; dim < maxVals.n_elem; dim++)
   {
     // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
     // think of how to do that...
-    const double min = (*min_vals_)[dim];
-    const double max = (*max_vals_)[dim];
+    const double min = minVals[dim];
+    const double max = maxVals[dim];
 
     // If there is nothing to split in this dimension, move on.
     if (max - min == 0.0)
@@ -69,11 +69,11 @@
 
     // Find the log volume of all the other dimensions.
     double volumeWithoutDim = 0;
-    for (size_t i = 0; i < max_vals_->n_elem; ++i)
+    for (size_t i = 0; i < maxVals.n_elem; ++i)
     {
-      if (((*max_vals_)[i] - (*min_vals_)[i] > 0.0) && (i != dim))
+      if ((maxVals[i] - minVals[i] > 0.0) && (i != dim))
       {
-        volumeWithoutDim += std::log((*max_vals_)[i] - (*min_vals_)[i]);
+        volumeWithoutDim += std::log(maxVals[i] - minVals[i]);
       }
     }
 
@@ -188,30 +188,9 @@
 
 
 template<typename eT, typename cT>
-void DTree<eT, cT>::GetMaxMinVals_(MatType* data,
-                                   VecType *max_vals,
-                                   VecType *min_vals)
-{
-  max_vals->set_size(data->n_rows);
-  min_vals->set_size(data->n_rows);
-
-  MatType temp_d = arma::trans(*data);
-
-  for (size_t i = 0; i < temp_d.n_cols; ++i)
-  {
-    VecType dim_vals = arma::sort(temp_d.col(i));
-    (*min_vals)[i] = dim_vals[0];
-    (*max_vals)[i] = dim_vals[dim_vals.n_elem - 1];
-  }
-}
-
-
-template<typename eT, typename cT>
 DTree<eT, cT>::DTree() :
     start_(0),
     end_(0),
-    max_vals_(NULL),
-    min_vals_(NULL),
     left_(NULL),
     right_(NULL)
 { /* Nothing to do. */ }
@@ -219,13 +198,13 @@
 
 // Root node initializers
 template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
-                     VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
                      size_t total_points) :
     start_(0),
     end_(total_points),
-    max_vals_(max_vals),
-    min_vals_(min_vals),
+    maxVals(maxVals),
+    minVals(minVals),
     left_(NULL),
     right_(NULL)
 {
@@ -237,19 +216,33 @@
 
 
 template<typename eT, typename cT>
-DTree<eT, cT>::DTree(MatType* data) :
+DTree<eT, cT>::DTree(arma::mat& data) :
     start_(0),
-    end_(data->n_cols),
+    end_(data.n_cols),
     left_(NULL),
     right_(NULL)
 {
-  max_vals_ = new VecType();
-  min_vals_ = new VecType();
+  maxVals.set_size(data.n_rows);
+  minVals.set_size(data.n_cols);
 
-  GetMaxMinVals_(data, max_vals_, min_vals_);
+  // Initialize to first column; values will be overwritten if necessary.
+  maxVals = data.col(0);
+  minVals = data.col(0);
 
-  error_ = -std::exp(LogNegativeError(data->n_cols));
+  // Loop over data to extract maximum and minimum values in each dimension.
+  for (size_t i = 1; i < data.n_cols; ++i)
+  {
+    for (size_t j = 0; j < data.n_rows; ++j)
+    {
+      if (data(j, i) > maxVals[j])
+        maxVals[j] = data(j, i);
+      if (data(j, i) < minVals[j])
+        minVals[j] = data(j, i);
+    }
+  }
 
+  error_ = -std::exp(LogNegativeError(data.n_cols));
+
   bucket_tag_ = -1;
   root_ = true;
 }
@@ -257,16 +250,16 @@
 
 // Non-root node initializers
 template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
-                     VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
                      size_t start,
                      size_t end,
                      cT error) :
     start_(start),
     end_(end),
     error_(error),
-    max_vals_(max_vals),
-    min_vals_(min_vals),
+    maxVals(maxVals),
+    minVals(minVals),
     left_(NULL),
     right_(NULL)
 {
@@ -276,15 +269,15 @@
 
 
 template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
-                     VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
                      size_t total_points,
                      size_t start,
                      size_t end) :
     start_(start),
     end_(end),
-    max_vals_(max_vals),
-    min_vals_(min_vals),
+    maxVals(maxVals),
+    minVals(minVals),
     left_(NULL),
     right_(NULL)
 {
@@ -303,12 +296,6 @@
 
   if (right_ != NULL)
     delete right_;
-
-  if (min_vals_ != NULL)
-    delete min_vals_;
-
-  if (max_vals_ != NULL)
-    delete max_vals_;
 }
 
 
@@ -320,8 +307,8 @@
                        size_t maxLeafSize,
                        size_t minLeafSize)
 {
-  assert(data->n_rows == max_vals_->n_elem);
-  assert(data->n_rows == min_vals_->n_elem);
+  assert(data->n_rows == maxVals.n_elem);
+  assert(data->n_rows == minVals.n_elem);
 
   cT left_g, right_g;
 
@@ -330,10 +317,10 @@
 
   // Compute the v_t_inv: the inverse of the volume of the node.
   cT log_vol_t = 0;
-  for (size_t i = 0; i < max_vals_->n_elem; ++i)
-    if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+  for (size_t i = 0; i < maxVals.n_elem; ++i)
+    if (maxVals[i] - minVals[i] > 0.0)
       // Use log to prevent overflow.
-      log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+      log_vol_t += (cT) std::log(maxVals[i] - minVals[i]);
 
   // Check for overflow.
   assert(std::exp(log_vol_t) > 0.0);
@@ -355,13 +342,13 @@
           *old_from_new);
 
       // Make max and min vals for the children.
-      VecType* max_vals_l = new VecType(*max_vals_);
-      VecType* max_vals_r = new VecType(*max_vals_);
-      VecType* min_vals_l = new VecType(*min_vals_);
-      VecType* min_vals_r = new VecType(*min_vals_);
+      arma::vec max_vals_l(maxVals);
+      arma::vec max_vals_r(maxVals);
+      arma::vec min_vals_l(minVals);
+      arma::vec min_vals_r(minVals);
 
-      (*max_vals_l)[dim] = splitValue;
-      (*min_vals_r)[dim] = splitValue;
+      max_vals_l[dim] = splitValue;
+      min_vals_r[dim] = splitValue;
 
       // Store split dim and split val in the node.
       split_value_ = splitValue;
@@ -524,7 +511,7 @@
 bool DTree<eT, cT>::WithinRange_(VecType* query)
 {
   for (size_t i = 0; i < query->n_elem; ++i)
-    if (((*query)[i] < (*min_vals_)[i]) || ((*query)[i] > (*max_vals_)[i]))
+    if (((*query)[i] < minVals[i]) || ((*query)[i] > maxVals[i]))
       return false;
 
   return true;
@@ -534,7 +521,7 @@
 template<typename eT, typename cT>
 cT DTree<eT, cT>::ComputeValue(VecType* query)
 {
-  assert(query->n_elem == max_vals_->n_elem);
+  assert(query->n_elem == maxVals.n_elem);
 
   if (root_ == 1) // If we are the root...
     // Check if the query is within range.
@@ -601,7 +588,7 @@
 template<typename eT, typename cT>
 int DTree<eT, cT>::FindBucket(VecType* query)
 {
-  assert(query->n_elem == max_vals_->n_elem);
+  assert(query->n_elem == maxVals.n_elem);
 
   if (subtree_leaves_ == 1) // If we are a leaf...
   {




More information about the mlpack-svn mailing list