[mlpack-git] master: - DET changes propagated to tests. (0b58fd9)

gitdub at mlpack.org gitdub at mlpack.org
Tue Nov 1 15:22:39 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit 0b58fd9a4542b00e12c4525b86d390154631a9ad
Author: theJonan <ivan at jonan.info>
Date:   Fri Oct 14 17:42:25 2016 +0300

    - DET changes propagated to tests.


>---------------------------------------------------------------

0b58fd9a4542b00e12c4525b86d390154631a9ad
 src/mlpack/methods/det/dtree_impl.hpp   | 29 ++++++++++++++++-------------
 src/mlpack/tests/det_test.cpp           | 27 ++++++++++++---------------
 src/mlpack/tests/serialization_test.cpp | 15 ++++++++-------
 3 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 58131fb..2261bbf 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -250,12 +250,12 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
 
   // Loop through each dimension.
 #ifdef _WIN32
-  #pragma omp parallel for default(none) \
-    shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+  #pragma omp parallel for default(shared) \
+    shared(splitValue, splitDim, data)
   for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
 #else
-  #pragma omp parallel for default(none) \
-    shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+  #pragma omp parallel for default(shared) \
+    shared(splitValue, splitDim, data)
   for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
 #endif
   {
@@ -327,17 +327,20 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
 
     double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
 
-#pragma omp critical
+#pragma omp atomic
     if ((actualMinDimError > minError) && dimSplitFound)
     {
-      // Calculate actual error (in logspace) by adding terms back to our
-      // estimate.
-      minError = actualMinDimError;
-      splitDim = dim;
-      splitValue = dimSplitValue;
-      leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-      rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-      splitFound = true;
+#pragma omp critical DTreeFindUpdate
+      {
+        // Calculate actual error (in logspace) by adding terms back to our
+        // estimate.
+        minError = actualMinDimError;
+        splitDim = dim;
+        splitValue = dimSplitValue;
+        leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+        rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+        splitFound = true;
+      }
     } // end if better split found in this dimension.
   }
 
diff --git a/src/mlpack/tests/det_test.cpp b/src/mlpack/tests/det_test.cpp
index 7e32def..2b0ef37 100644
--- a/src/mlpack/tests/det_test.cpp
+++ b/src/mlpack/tests/det_test.cpp
@@ -42,7 +42,7 @@ BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree tree(testData);
+  DTree<arma::mat, arma::vec> tree(testData);
 
   BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
   BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(TestComputeNodeError)
   arma::vec maxVals("7 7 8");
   arma::vec minVals("3 0 1");
 
-  DTree testDTree(maxVals, minVals, 5);
+  DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
   double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
 
   BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(TestWithinRange)
   arma::vec maxVals("7 7 8");
   arma::vec minVals("3 0 1");
 
-  DTree testDTree(maxVals, minVals, 5);
+  DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
 
   arma::vec testQuery(3);
   testQuery << 4.5 << 2.5 << 2;
@@ -95,11 +95,10 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
 
   size_t obDim, trueDim;
-  double trueLeftError, obLeftError, trueRightError, obRightError,
-      obSplit, trueSplit;
+  double trueLeftError, obLeftError, trueRightError, obRightError, obSplit, trueSplit;
 
   trueDim = 2;
   trueSplit = 5.5;
@@ -107,8 +106,7 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
   trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
 
   testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
-  BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
-      obRightError, 1));
+  BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError, obRightError, 1));
 
   BOOST_REQUIRE(trueDim == obDim);
   BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
@@ -125,7 +123,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
 
   arma::Col<size_t> oTest(5);
   oTest << 1 << 2 << 3 << 4 << 5;
@@ -133,8 +131,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
   size_t splitDim = 2;
   double trueSplitVal = 5.5;
 
-  size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
-      oTest);
+  size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal, oTest);
 
   BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
 
@@ -169,7 +166,7 @@ BOOST_AUTO_TEST_CASE(TestGrow)
   rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
   rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
 
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
   BOOST_REQUIRE_EQUAL(oTest[0], 0);
@@ -222,7 +219,7 @@ BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
 
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
   alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
 
@@ -255,7 +252,7 @@ BOOST_AUTO_TEST_CASE(TestComputeValue)
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
   double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
@@ -298,7 +295,7 @@ BOOST_AUTO_TEST_CASE(TestVariableImportance)
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree testDTree(testData);
+  DTree<arma::mat, arma::vec> testDTree(testData);
   testDTree.Grow(testData, oTest, false, 2, 1);
 
   arma::vec imps;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index e6aecc7..d7ea76e 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -853,18 +853,19 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)
 BOOST_AUTO_TEST_CASE(DETTest)
 {
   using det::DTree;
+  typedef DTree<arma::mat, arma::vec>   DTreeX;
 
   // Create a density estimation tree on a random dataset.
   arma::mat dataset = arma::randu<arma::mat>(25, 5000);
 
-  DTree tree(dataset);
+  DTreeX tree(dataset);
 
   arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
-  DTree xmlTree, binaryTree, textTree(otherDataset);
+  DTreeX xmlTree, binaryTree, textTree(otherDataset);
 
   SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
 
-  std::stack<DTree*> stack, xmlStack, binaryStack, textStack;
+  std::stack<DTreeX*> stack, xmlStack, binaryStack, textStack;
   stack.push(&tree);
   xmlStack.push(&xmlTree);
   binaryStack.push(&binaryTree);
@@ -873,10 +874,10 @@ BOOST_AUTO_TEST_CASE(DETTest)
   while (!stack.empty())
   {
     // Get the top node from the stack.
-    DTree* node = stack.top();
-    DTree* xmlNode = xmlStack.top();
-    DTree* binaryNode = binaryStack.top();
-    DTree* textNode = textStack.top();
+    DTreeX* node = stack.top();
+    DTreeX* xmlNode = xmlStack.top();
+    DTreeX* binaryNode = binaryStack.top();
+    DTreeX* textNode = textStack.top();
 
     stack.pop();
     xmlStack.pop();




More information about the mlpack-git mailing list