[mlpack-svn] r15777 - mlpack/trunk/src/mlpack/methods/fastmks

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 15:53:57 EDT 2013


Author: rcurtin
Date: Fri Sep 13 15:53:57 2013
New Revision: 15777

Log:
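Print base cases and scores as output even when not debugging. Also always build (or, for a user-supplied reference tree, copy) a separate query tree for dual-tree search, so the traversal no longer needs a special case for a missing query tree.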


Modified:
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	Fri Sep 13 15:53:57 2013
@@ -36,6 +36,9 @@
   if (!naive)
     referenceTree = new TreeType(referenceSet);
 
+  if (!naive && !single)
+    queryTree = new TreeType(referenceSet);
+
   Timer::Stop("tree_building");
 }
 
@@ -86,6 +89,9 @@
   if (!naive)
     referenceTree = new TreeType(referenceSet, metric);
 
+  if (!naive && !single)
+    queryTree = new TreeType(referenceSet, metric);
+
   Timer::Stop("tree_building");
 }
 
@@ -132,7 +138,9 @@
     naive(naive),
     metric(referenceTree->Metric())
 {
-  // Nothing to do.
+  // The query tree cannot be the same as the reference tree.
+  if (referenceTree)
+    queryTree = new TreeType(*referenceTree);
 }
 
 // Two datasets, pre-built trees.
@@ -166,6 +174,12 @@
     if (referenceTree)
       delete referenceTree;
   }
+  else if (&querySet == &referenceSet)
+  {
+    // The user passed in a reference tree which we needed to copy.
+    if (queryTree)
+      delete queryTree;
+  }
 }
 
 template<typename KernelType, typename TreeType>
@@ -228,6 +242,9 @@
 
     Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
 
+    Log::Info << rules.BaseCases() << " base cases." << std::endl;
+    Log::Info << rules.Scores() << " scores." << std::endl;
+
     Timer::Stop("computing_products");
     return;
   }
@@ -238,14 +255,13 @@
 
   typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
-  if (queryTree)
-    traverser.Traverse(*queryTree, *referenceTree);
-  else
-    traverser.Traverse(*referenceTree, *referenceTree);
+  traverser.Traverse(*queryTree, *referenceTree);
 
   const size_t numPrunes = traverser.NumPrunes();
 
   Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+  Log::Info << rules.BaseCases() << " base cases." << std::endl;
+  Log::Info << rules.Scores() << " scores." << std::endl;
 
   Timer::Stop("computing_products");
   return;



More information about the mlpack-svn mailing list