[mlpack-svn] r13244 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 18 16:41:52 EDT 2012


Author: rcurtin
Date: 2012-07-18 16:41:52 -0400 (Wed, 18 Jul 2012)
New Revision: 13244

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Use double for ComputeValue(), and take a reference instead of a pointer; also
fix the const-correctness of the function.


Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-18 20:41:52 UTC (rev 13244)
@@ -125,7 +125,7 @@
   Timer::Start("DET/EstimationTime");
   for (size_t i = 0; i < training_data.n_cols; i++) {
     arma::Col<double> test_p = training_data.unsafe_col(i);
-    long double f = dtree_opt->ComputeValue(&test_p);
+    long double f = dtree_opt->ComputeValue(test_p);
     if (fp != NULL)
       fprintf(fp, "%Lg\n", f);
   } // end for
@@ -160,7 +160,7 @@
     Timer::Start("DET/TestSetEstimation");
     for (size_t i = 0; i < test_data.n_cols; i++) {
       arma::Col<double> test_p = test_data.unsafe_col(i);
-      long double f = dtree_opt->ComputeValue(&test_p);
+      long double f = dtree_opt->ComputeValue(test_p);
       if (fp != NULL)
 	fprintf(fp, "%Lg\n", f);
     } // end for

Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-18 20:41:52 UTC (rev 13244)
@@ -148,7 +148,7 @@
       for (size_t i = 0; i < dataset->n_cols; ++i)
       {
         arma::Col<eT> test_p = dataset->unsafe_col(i);
-        outfile << dtree->ComputeValue(&test_p) << endl;
+        outfile << dtree->ComputeValue(test_p) << endl;
       }
     }
     else
@@ -243,7 +243,7 @@
       for (size_t i = 0; i < test.n_cols; i++)
       {
         arma::Col<eT> test_point = test.unsafe_col(i);
-        val_cv += dtree_cv->ComputeValue(&test_point);
+        val_cv += dtree_cv->ComputeValue(test_point);
       }
 
       // Update the cv error value.
@@ -259,7 +259,7 @@
     for (size_t i = 0; i < test.n_cols; ++i)
     {
       arma::Col<eT> test_point = test.unsafe_col(i);
-      val_cv += dtree_cv->ComputeValue(&test_point);
+      val_cv += dtree_cv->ComputeValue(test_point);
     }
 
     // Update the cv error value.

Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-18 20:41:52 UTC (rev 13244)
@@ -204,7 +204,7 @@
   double PruneAndUpdate(const double old_alpha, const bool useVolReg = false);
 
   // compute the density at a given point
-  cT ComputeValue(VecType* query);
+  double ComputeValue(const arma::vec& query) const;
 
   // print the tree (in a DFS manner)
   void WriteTree(size_t level, FILE *fp);

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-18 20:41:52 UTC (rev 13244)
@@ -492,24 +492,32 @@
 
 
 template<typename eT, typename cT>
-cT DTree<eT, cT>::ComputeValue(VecType* query)
+double DTree<eT, cT>::ComputeValue(const arma::vec& query) const
 {
-  assert(query->n_elem == maxVals.n_elem);
+  Log::Assert(query.n_elem == maxVals.n_elem);
 
   if (root_ == 1) // If we are the root...
+  {
     // Check if the query is within range.
-    if (!WithinRange(*query))
+    if (!WithinRange(query))
       return 0.0;
+  }
 
   if (subtree_leaves_ == 1)  // If we are a leaf...
+  {
     return ratio_ * v_t_inv_;
+  }
   else
   {
-    if ((*query)[split_dim_] <= split_value_)
+    if (query[split_dim_] <= split_value_)
+    {
       // If left subtree, go to left child.
       return left_->ComputeValue(query);
+    }
     else  // If right subtree, go to right child
+    {
       return right_->ComputeValue(query);
+    }
   }
 }
 




More information about the mlpack-svn mailing list