[mlpack-git] master: Merge branch 'master' of https://github.com/mlpack/mlpack (de88672)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:49 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c

>---------------------------------------------------------------

commit de88672879a1893ebfc131538c64e7755251337c
Merge: 0e910b1 7a8b0e1
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Oct 2 23:20:25 2015 +0000

    Merge branch 'master' of https://github.com/mlpack/mlpack


>---------------------------------------------------------------

de88672879a1893ebfc131538c64e7755251337c
 COPYRIGHT.txt                                      |   1 +
 src/mlpack/core.hpp                                |   1 +
 .../softmax_regression/softmax_regression.hpp      |  73 ++++++++++++---
 .../softmax_regression_function.cpp                |  62 ++++++++-----
 .../softmax_regression_function.hpp                |  55 ++++++++----
 .../softmax_regression/softmax_regression_impl.hpp |  84 +++++++++++------
 src/mlpack/tests/serialization_test.cpp            |  25 +++++-
 src/mlpack/tests/softmax_regression_test.cpp       | 100 +++++++++++++++++----
 8 files changed, 303 insertions(+), 98 deletions(-)

diff --cc src/mlpack/tests/serialization_test.cpp
index 424126c,18b8e81..316bdf6
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@@ -25,7 -25,7 +25,8 @@@
  #include <mlpack/methods/perceptron/perceptron.hpp>
  #include <mlpack/methods/logistic_regression/logistic_regression.hpp>
  #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+ #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
 +#include <mlpack/methods/det/dtree.hpp>
  
  using namespace mlpack;
  using namespace mlpack::distribution;
@@@ -800,217 -800,27 +801,239 @@@ BOOST_AUTO_TEST_CASE(AllkNNTest
    CheckMatrices(neighbors, xmlNeighbors, textNeighbors, binaryNeighbors);
  }
  
+ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)  // Serialization round trip.
+ {
+   using regression::SoftmaxRegression;
+   // 5 x 1000 uniform random data; labels: first 500 columns 0, the rest 1.
+   arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+   arma::Row<size_t> labels(1000);  // One label per column of dataset.
+   for (size_t i = 0; i < 500; ++i)
+     labels[i] = 0;
+   for (size_t i = 500; i < 1000; ++i)
+     labels[i] = 1;
+   // Reference model built from the labeled data with 2 classes.
+   SoftmaxRegression<> sr(dataset, labels, 2);
+   // Untrained placeholders; SerializeObjectAll() is expected to overwrite them.
+   SoftmaxRegression<> srXml(dataset.n_rows, 2);
+   SoftmaxRegression<> srText(dataset.n_rows, 2);
+   SoftmaxRegression<> srBinary(dataset.n_rows, 2);
+ 
+   SerializeObjectAll(sr, srXml, srText, srBinary);  // Targets: xml, text, binary.
+   // Every deserialized copy must carry the same parameters as sr.
+   CheckMatrices(sr.Parameters(), srXml.Parameters(), srText.Parameters(),
+       srBinary.Parameters());
+ }
+ 
 +BOOST_AUTO_TEST_CASE(DETTest)
 +{
 +  using det::DTree;
 +
 +  // Create a density estimation tree on a random dataset.
 +  arma::mat dataset = arma::randu<arma::mat>(25, 5000);
 +
 +  DTree tree(dataset);
 +
 +  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
 +  DTree xmlTree, binaryTree, textTree(otherDataset);
 +
 +  SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
 +
 +  std::stack<DTree*> stack, xmlStack, binaryStack, textStack;
 +  stack.push(&tree);
 +  xmlStack.push(&xmlTree);
 +  binaryStack.push(&binaryTree);
 +  textStack.push(&textTree);
 +
 +  while (!stack.empty())
 +  {
 +    // Get the top node from the stack.
 +    DTree* node = stack.top();
 +    DTree* xmlNode = xmlStack.top();
 +    DTree* binaryNode = binaryStack.top();
 +    DTree* textNode = textStack.top();
 +
 +    stack.pop();
 +    xmlStack.pop();
 +    binaryStack.pop();
 +    textStack.pop();
 +
 +    // Check that all the members are the same.
 +    BOOST_REQUIRE_EQUAL(node->Start(), xmlNode->Start());
 +    BOOST_REQUIRE_EQUAL(node->Start(), binaryNode->Start());
 +    BOOST_REQUIRE_EQUAL(node->Start(), textNode->Start());
 +
 +    BOOST_REQUIRE_EQUAL(node->End(), xmlNode->End());
 +    BOOST_REQUIRE_EQUAL(node->End(), binaryNode->End());
 +    BOOST_REQUIRE_EQUAL(node->End(), textNode->End());
 +
 +    BOOST_REQUIRE_EQUAL(node->SplitDim(), xmlNode->SplitDim());
 +    BOOST_REQUIRE_EQUAL(node->SplitDim(), binaryNode->SplitDim());
 +    BOOST_REQUIRE_EQUAL(node->SplitDim(), textNode->SplitDim());
 +
 +    if (std::abs(node->SplitValue()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->SplitValue(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->SplitValue(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->SplitValue(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->SplitValue(), xmlNode->SplitValue(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->SplitValue(), binaryNode->SplitValue(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->SplitValue(), textNode->SplitValue(), 1e-5);
 +    }
 +
 +    if (std::abs(node->LogNegError()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->LogNegError(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->LogNegError(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->LogNegError(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->LogNegError(), xmlNode->LogNegError(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->LogNegError(), binaryNode->LogNegError(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->LogNegError(), textNode->LogNegError(), 1e-5);
 +    }
 +
 +    if (std::abs(node->SubtreeLeavesLogNegError()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->SubtreeLeavesLogNegError(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->SubtreeLeavesLogNegError(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->SubtreeLeavesLogNegError(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
 +          xmlNode->SubtreeLeavesLogNegError(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
 +          binaryNode->SubtreeLeavesLogNegError(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
 +          textNode->SubtreeLeavesLogNegError(), 1e-5);
 +    }
 +
 +    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), xmlNode->SubtreeLeaves());
 +    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), binaryNode->SubtreeLeaves());
 +    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), textNode->SubtreeLeaves());
 +
 +    if (std::abs(node->Ratio()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->Ratio(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->Ratio(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->Ratio(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->Ratio(), xmlNode->Ratio(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->Ratio(), binaryNode->Ratio(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->Ratio(), textNode->Ratio(), 1e-5);
 +    }
 +
 +    if (std::abs(node->LogVolume()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->LogVolume(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->LogVolume(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->LogVolume(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->LogVolume(), xmlNode->LogVolume(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->LogVolume(), binaryNode->LogVolume(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->LogVolume(), textNode->LogVolume(), 1e-5);
 +    }
 +
 +    if (node->Left() == NULL)
 +    {
 +      BOOST_REQUIRE(xmlNode->Left() == NULL);
 +      BOOST_REQUIRE(binaryNode->Left() == NULL);
 +      BOOST_REQUIRE(textNode->Left() == NULL);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE(xmlNode->Left() != NULL);
 +      BOOST_REQUIRE(binaryNode->Left() != NULL);
 +      BOOST_REQUIRE(textNode->Left() != NULL);
 +
 +      // Push children onto stack.
 +      stack.push(node->Left());
 +      xmlStack.push(xmlNode->Left());
 +      binaryStack.push(binaryNode->Left());
 +      textStack.push(textNode->Left());
 +    }
 +
 +    if (node->Right() == NULL)
 +    {
 +      BOOST_REQUIRE(xmlNode->Right() == NULL);
 +      BOOST_REQUIRE(binaryNode->Right() == NULL);
 +      BOOST_REQUIRE(textNode->Right() == NULL);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE(xmlNode->Right() != NULL);
 +      BOOST_REQUIRE(binaryNode->Right() != NULL);
 +      BOOST_REQUIRE(textNode->Right() != NULL);
 +
 +      // Push children onto stack.
 +      stack.push(node->Right());
 +      xmlStack.push(xmlNode->Right());
 +      binaryStack.push(binaryNode->Right());
 +      textStack.push(textNode->Right());
 +    }
 +
 +    BOOST_REQUIRE_EQUAL(node->Root(), xmlNode->Root());
 +    BOOST_REQUIRE_EQUAL(node->Root(), binaryNode->Root());
 +    BOOST_REQUIRE_EQUAL(node->Root(), textNode->Root());
 +
 +    if (std::abs(node->AlphaUpper()) < 1e-5)
 +    {
 +      BOOST_REQUIRE_SMALL(xmlNode->AlphaUpper(), 1e-5);
 +      BOOST_REQUIRE_SMALL(binaryNode->AlphaUpper(), 1e-5);
 +      BOOST_REQUIRE_SMALL(textNode->AlphaUpper(), 1e-5);
 +    }
 +    else
 +    {
 +      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), xmlNode->AlphaUpper(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), binaryNode->AlphaUpper(), 1e-5);
 +      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), textNode->AlphaUpper(), 1e-5);
 +    }
 +
 +    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, xmlNode->MaxVals().n_elem);
 +    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, binaryNode->MaxVals().n_elem);
 +    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, textNode->MaxVals().n_elem);
 +    for (size_t i = 0; i < node->MaxVals().n_elem; ++i)
 +    {
 +      if (std::abs(node->MaxVals()[i]) < 1e-5)
 +      {
 +        BOOST_REQUIRE_SMALL(xmlNode->MaxVals()[i], 1e-5);
 +        BOOST_REQUIRE_SMALL(binaryNode->MaxVals()[i], 1e-5);
 +        BOOST_REQUIRE_SMALL(textNode->MaxVals()[i], 1e-5);
 +      }
 +      else
 +      {
 +        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], xmlNode->MaxVals()[i], 1e-5);
 +        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], binaryNode->MaxVals()[i], 1e-5);
 +        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], textNode->MaxVals()[i], 1e-5);
 +      }
 +    }
 +
 +    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, xmlNode->MinVals().n_elem);
 +    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, binaryNode->MinVals().n_elem);
 +    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, textNode->MinVals().n_elem);
 +    for (size_t i = 0; i < node->MinVals().n_elem; ++i)
 +    {
 +      if (std::abs(node->MinVals()[i]) < 1e-5)
 +      {
 +        BOOST_REQUIRE_SMALL(xmlNode->MinVals()[i], 1e-5);
 +        BOOST_REQUIRE_SMALL(binaryNode->MinVals()[i], 1e-5);
 +        BOOST_REQUIRE_SMALL(textNode->MinVals()[i], 1e-5);
 +      }
 +      else
 +      {
 +        BOOST_REQUIRE_CLOSE(node->MinVals()[i], xmlNode->MinVals()[i], 1e-5);
 +        BOOST_REQUIRE_CLOSE(node->MinVals()[i], binaryNode->MinVals()[i], 1e-5);
 +        BOOST_REQUIRE_CLOSE(node->MinVals()[i], textNode->MinVals()[i], 1e-5);
 +      }
 +    }
 +  }
- }
 +
  BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list