[mlpack-git] master: 1 : remove inputSize from the constructors and data member 2 : provide static function initializeWeight to simplify constructor task 3 : change labels from arma::vec to arma::Row<size_t> 4 : update test cases (7a8b85c)

gitdub at big.cc.gt.atl.ga.us
Wed Sep 30 09:27:19 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8a8b708650f72c2aecbd9b4a12c8b16b4e0a3508...dc2c5c68dc4bfcdd2075b1a0fd2d641fce651669

>---------------------------------------------------------------

commit 7a8b85c49d76d546c8823032f59836dc5408d0be
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Tue Sep 29 12:34:38 2015 +0800

    1 : remove inputSize from the constructors and data member
    2 : provide static function initializeWeight to simplify constructor task
    3 : change labels from arma::vec to arma::Row<size_t>
    4 : update test cases


>---------------------------------------------------------------

7a8b85c49d76d546c8823032f59836dc5408d0be
 .../softmax_regression/softmax_regression.hpp      | 15 ++----
 .../softmax_regression_function.cpp                | 36 +++++++++------
 .../softmax_regression_function.hpp                | 32 +++++++------
 .../softmax_regression/softmax_regression_impl.hpp | 22 +++------
 src/mlpack/tests/softmax_regression_test.cpp       | 53 +++++++++++-----------
 5 files changed, 76 insertions(+), 82 deletions(-)
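
The sketch below is not part of the commit; it is an illustrative summary of
how caller code looks after these changes (the data, label values, and include
paths are assumptions). Labels are now passed as arma::Row<size_t>, inputSize
is no longer given to the constructors because the feature count is taken from
data.n_rows, and weights can be initialized through the new static
SoftmaxRegressionFunction::InitializeWeights() without building a dummy
objective function.

    #include <mlpack/core.hpp>
    #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
    #include <mlpack/methods/softmax_regression/softmax_regression_function.hpp>

    using namespace mlpack::regression;

    int main()
    {
      // Two-class toy problem: 5 features, 100 points (illustrative values).
      arma::mat data = arma::randu<arma::mat>(5, 100);
      arma::Row<size_t> labels(100);
      for (size_t i = 0; i < 100; ++i)
        labels[i] = (i < 50) ? 0 : 1;

      // inputSize is gone from the constructor: data, labels, numClasses,
      // lambda, fitIntercept.
      SoftmaxRegression<> sr(data, labels, 2, 0.0001, false);

      // ComputeAccuracy() also takes arma::Row<size_t> labels now.
      const double acc = sr.ComputeAccuracy(data, labels);

      // The new static helper initializes weights from the feature size alone;
      // the result is numClasses x featureSize (plus one column if an
      // intercept is fit).
      arma::mat initial =
          SoftmaxRegressionFunction::InitializeWeights(data.n_rows, 2, false);

      return (acc >= 0.0 && initial.n_rows == 2) ? 0 : 1;
    }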

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression.hpp b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
index bc44c36..fa7142a 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
@@ -87,8 +87,7 @@ class SoftmaxRegression
    * @param fitIntercept add intercept term or not.
    */
   SoftmaxRegression(const arma::mat& data,
-                    const arma::vec& labels,
-                    const size_t inputSize,
+                    const arma::Row<size_t>& labels,
                     const size_t numClasses,
                     const double lambda = 0.0001,
                     const bool fitIntercept = false);
@@ -122,7 +121,7 @@ class SoftmaxRegression
    * @param testData Matrix of data points using which predictions are made.
    * @param labels Vector of labels associated with the data.
    */
-  double ComputeAccuracy(const arma::mat& testData, const arma::vec& labels);
+  double ComputeAccuracy(const arma::mat& testData, const arma::Row<size_t>& labels);
 
   /**
    * Train the softmax regression model with the given optimizer.
@@ -141,14 +140,9 @@ class SoftmaxRegression
    * @param numClasses Number of classes for classification.
    * @return Objective value of the final point.
    */
-  double Train(const arma::mat &data, const arma::vec& labels,
+  double Train(const arma::mat &data, const arma::Row<size_t>& labels,
                const size_t numClasses); 
 
-  //! Sets the size of the input vector.
-  size_t& InputSize() { return inputSize; }
-  //! Gets the size of the input vector.
-  size_t InputSize() const { return inputSize; }
-
   //! Sets the number of classes.
   size_t& NumClasses() { return numClasses; }
   //! Gets the number of classes.
@@ -176,7 +170,6 @@ class SoftmaxRegression
     using mlpack::data::CreateNVP;
 
     ar & CreateNVP(parameters, "parameters");   
-    ar & CreateNVP(inputSize, "inputSize");
     ar & CreateNVP(numClasses, "numClasses");
     ar & CreateNVP(lambda, "lambda");
     ar & CreateNVP(fitIntercept, "fitIntercept");
@@ -185,8 +178,6 @@ class SoftmaxRegression
  private:
   //! Parameters after optimization.
   arma::mat parameters;  
-  //! Size of input feature vector.
-  size_t inputSize;
   //! Number of classes.
   size_t numClasses;
   //! L2-regularization constant.
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
index e317831..9482e86 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
@@ -10,13 +10,11 @@ using namespace mlpack;
 using namespace mlpack::regression;
 
 SoftmaxRegressionFunction::SoftmaxRegressionFunction(const arma::mat& data,
-                                                     const arma::vec& labels,
-                                                     const size_t inputSize,
+                                                     const arma::Row<size_t>& labels,
                                                      const size_t numClasses,
                                                      const double lambda,
                                                      const bool fitIntercept) :
     data(data),    
-    inputSize(inputSize),
     numClasses(numClasses),
     lambda(lambda),
     fitIntercept(fitIntercept)
@@ -35,17 +33,25 @@ SoftmaxRegressionFunction::SoftmaxRegressionFunction(const arma::mat& data,
  */
 const arma::mat SoftmaxRegressionFunction::InitializeWeights()
 {    
-  // Initialize values to 0.005 * r. 'r' is a matrix of random values taken from
-  // a Gaussian distribution with mean zero and variance one.
-  // If the fitIntercept flag is true, parameters.col(0) is the intercept.
-  arma::mat parameters;
-  if (fitIntercept)
-    parameters.randn(numClasses, inputSize + 1);
-  else
-    parameters.randn(numClasses, inputSize);
-  parameters = 0.005 * parameters;
+  return InitializeWeights(data.n_rows, numClasses, fitIntercept);
+}
 
-  return parameters;
+const arma::mat SoftmaxRegressionFunction::
+InitializeWeights(const size_t featureSize,
+                  const size_t numClasses,
+                  const bool fitIntercept)
+{
+    // Initialize values to 0.005 * r. 'r' is a matrix of random values taken from
+    // a Gaussian distribution with mean zero and variance one.
+    // If the fitIntercept flag is true, parameters.col(0) is the intercept.
+    arma::mat parameters;
+    if (fitIntercept)
+      parameters.randn(numClasses, featureSize + 1);
+    else
+      parameters.randn(numClasses, featureSize);
+    parameters = 0.005 * parameters;
+
+    return parameters;
 }
 
 /**
@@ -53,7 +59,7 @@ const arma::mat SoftmaxRegressionFunction::InitializeWeights()
  * labels. The output is in the form of a matrix, which leads to simpler
  * calculations in the Evaluate() and Gradient() methods.
  */
-void SoftmaxRegressionFunction::GetGroundTruthMatrix(const arma::vec& labels,
+void SoftmaxRegressionFunction::GetGroundTruthMatrix(const arma::Row<size_t>& labels,
                                                      arma::sp_mat& groundTruth)
 {
   // Calculate the ground truth matrix according to the labels passed. The
@@ -69,7 +75,7 @@ void SoftmaxRegressionFunction::GetGroundTruthMatrix(const arma::vec& labels,
  // number of cumulative entries made up until that column.
   for(size_t i = 0; i < labels.n_elem; i++)
   {
-    rowPointers(i) = labels(i, 0);
+    rowPointers(i) = labels(i);
     colPointers(i+1) = i + 1;
   }
 
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
index d45d07a..9cf7dbd 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
@@ -20,7 +20,7 @@ class SoftmaxRegressionFunction
    * Construct the Softmax Regression objective function with the given
    * parameters.
    *
-   * @param data Input training features.
+   * @param data Input training data; each column is one sample.
    * @param labels Labels associated with the feature data.
    * @param inputSize Size of the input feature vector.
    * @param numClasses Number of classes for classification.
@@ -28,8 +28,7 @@ class SoftmaxRegressionFunction
    * @param fitIntercept Intercept term flag.
    */
   SoftmaxRegressionFunction(const arma::mat& data,
-                            const arma::vec& labels,
-                            const size_t inputSize,
+                            const arma::Row<size_t>& labels,
                             const size_t numClasses,
                             const double lambda = 0.0001,
                             const bool fitIntercept = false);
@@ -38,12 +37,25 @@ class SoftmaxRegressionFunction
   const arma::mat InitializeWeights();
 
   /**
+   * Initialize the Softmax Regression weights (trainable parameters) with
+   * the given parameters.
+   * @param featureSize The feature size of the training set.
+   * @param numClasses Number of classes for classification.
+   * @param fitIntercept Intercept term flag.
+   * @return The initialized weights.
+   */
+  static const arma::mat InitializeWeights(const size_t featureSize,
+                                           const size_t numClasses,
+                                           const bool fitIntercept = false);
+
+  /**
    * Constructs the ground truth label matrix with the passed labels.
    *
    * @param labels Labels associated with the training data.
    * @param groundTruth Pointer to arma::mat which stores the computed matrix.
    */
-  void GetGroundTruthMatrix(const arma::vec& labels, arma::sp_mat& groundTruth);
+  void GetGroundTruthMatrix(const arma::Row<size_t>& labels,
+                            arma::sp_mat& groundTruth);
 
   /**
    * Evaluate the probabilities matrix with the passed parameters.
@@ -82,16 +94,12 @@ class SoftmaxRegressionFunction
   //! Return the initial point for the optimization.
   const arma::mat& GetInitialPoint() const { return initialPoint; }  
 
-  //! Sets the size of the input vector.
-  size_t& InputSize() { return inputSize; }
-  //! Gets the size of the input vector.
-  size_t InputSize() const { return inputSize; }
-
-  //! Sets the number of classes.
-  size_t& NumClasses() { return numClasses; }
   //! Gets the number of classes.
   size_t NumClasses() const { return numClasses; }
 
+  //! Gets the feature size of the training data.
+  size_t FeatureSize() const { return data.n_rows; }
+
   //! Sets the regularization parameter.
   double& Lambda() { return lambda; }
   //! Gets the regularization parameter.
@@ -107,8 +115,6 @@ class SoftmaxRegressionFunction
   arma::sp_mat groundTruth;
   //! Initial parameter point.
   arma::mat initialPoint;
-  //! Size of input feature vector.
-  size_t inputSize;
   //! Number of classes.
   size_t numClasses;
   //! L2-regularization constant.
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 46afc6c..122524f 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -18,32 +18,25 @@ SoftmaxRegression<OptimizerType>::
 SoftmaxRegression(const size_t inputSize,
                   const size_t numClasses,
                   const bool fitIntercept) :    
-    inputSize(inputSize),
     numClasses(numClasses),
     lambda(0.0001),
     fitIntercept(fitIntercept)
 {  
-  arma::mat tmp;
-  arma::vec tmplabels;
-  SoftmaxRegressionFunction regressor(tmp, tmplabels,
-                                      inputSize, numClasses,
-                                      lambda, fitIntercept);
-  parameters = regressor.GetInitialPoint();
+  parameters = SoftmaxRegressionFunction::InitializeWeights(inputSize, numClasses,
+                                                            fitIntercept);
 }
 
 template<template<typename> class OptimizerType>
 SoftmaxRegression<OptimizerType>::SoftmaxRegression(const arma::mat& data,
-                                                    const arma::vec& labels,
-                                                    const size_t inputSize,
+                                                    const arma::Row<size_t>& labels,
                                                     const size_t numClasses,
                                                     const double lambda,
                                                     const bool fitIntercept) :    
-    inputSize(inputSize),
     numClasses(numClasses),
     lambda(lambda),
     fitIntercept(fitIntercept)
 {
-  SoftmaxRegressionFunction regressor(data, labels, inputSize, numClasses,
+  SoftmaxRegressionFunction regressor(data, labels, numClasses,
                                       lambda, fitIntercept);
   OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
 
@@ -55,7 +48,6 @@ template<template<typename> class OptimizerType>
 SoftmaxRegression<OptimizerType>::SoftmaxRegression(
     OptimizerType<SoftmaxRegressionFunction>& optimizer) :
     parameters(optimizer.Function().GetInitialPoint()),    
-    inputSize(optimizer.Function().InputSize()),
     numClasses(optimizer.Function().NumClasses()),
     lambda(optimizer.Function().Lambda()),
     fitIntercept(optimizer.Function().FitIntercept())
@@ -115,7 +107,7 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
 template<template<typename> class OptimizerType>
 double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
     const arma::mat& testData,
-    const arma::vec& labels)
+    const arma::Row<size_t>& labels)
 {
   arma::vec predictions;
 
@@ -149,10 +141,10 @@ double SoftmaxRegression<OptimizerType>::Train(
 
 template<template<typename> class OptimizerType>
 double SoftmaxRegression<OptimizerType>::Train(const arma::mat& data,
-                                               const arma::vec& labels,
+                                               const arma::Row<size_t>& labels,
                                                const size_t numClasses)
 {
-  SoftmaxRegressionFunction regressor(data, labels, data.n_rows, numClasses,
+  SoftmaxRegressionFunction regressor(data, labels, numClasses,
                                       lambda, fitIntercept);
   OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
 
diff --git a/src/mlpack/tests/softmax_regression_test.cpp b/src/mlpack/tests/softmax_regression_test.cpp
index e1e57c7..a5e12f8 100644
--- a/src/mlpack/tests/softmax_regression_test.cpp
+++ b/src/mlpack/tests/softmax_regression_test.cpp
@@ -29,12 +29,12 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionEvaluate)
   data.randu(inputSize, points);
 
   // Create random class labels.
-  arma::vec labels(points);
+  arma::Row<size_t> labels(points);
   for(size_t i = 0; i < points; i++)
     labels(i) = math::RandInt(0, numClasses);
 
   // Create a SoftmaxRegressionFunction. Regularization term ignored.
-  SoftmaxRegressionFunction srf(data, labels, inputSize, numClasses, 0);
+  SoftmaxRegressionFunction srf(data, labels, numClasses, 0);
 
   // Run a number of trials.
   for(size_t i = 0; i < trials; i++)
@@ -74,14 +74,14 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionRegularizationEvaluate)
   data.randu(inputSize, points);
 
   // Create random class labels.
-  arma::vec labels(points);
+  arma::Row<size_t> labels(points);
   for(size_t i = 0; i < points; i++)
     labels(i) = math::RandInt(0, numClasses);
 
   // 3 objects for comparing regularization costs.
-  SoftmaxRegressionFunction srfNoReg(data, labels, inputSize, numClasses, 0);
-  SoftmaxRegressionFunction srfSmallReg(data, labels, inputSize, numClasses, 1);
-  SoftmaxRegressionFunction srfBigReg(data, labels, inputSize, numClasses, 20);
+  SoftmaxRegressionFunction srfNoReg(data, labels, numClasses, 0);
+  SoftmaxRegressionFunction srfSmallReg(data, labels, numClasses, 1);
+  SoftmaxRegressionFunction srfBigReg(data, labels, numClasses, 20);
 
   // Run a number of trials.
   for (size_t i = 0; i < trials; i++)
@@ -115,14 +115,14 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionFunctionGradient)
   data.randu(inputSize, points);
 
   // Create random class labels.
-  arma::vec labels(points);
+  arma::Row<size_t> labels(points);
   for(size_t i = 0; i < points; i++)
     labels(i) = math::RandInt(0, numClasses);
 
   // 2 objects for 2 terms in the cost function. Each term contributes towards
   // the gradient and thus need to be checked independently.
-  SoftmaxRegressionFunction srf1(data, labels, inputSize, numClasses, 0);
-  SoftmaxRegressionFunction srf2(data, labels, inputSize, numClasses, 20);
+  SoftmaxRegressionFunction srf1(data, labels, numClasses, 0);
+  SoftmaxRegressionFunction srf2(data, labels, numClasses, 20);
 
   // Create a random set of parameters.
   arma::mat parameters;
@@ -179,7 +179,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTwoClasses)
   GaussianDistribution g2(arma::vec("4.0 3.0 4.0"), arma::eye<arma::mat>(3, 3));
 
   arma::mat data(inputSize, points);
-  arma::vec labels(points);
+  arma::Row<size_t> labels(points);
 
   for (size_t i = 0; i < points/2; i++)
   {
@@ -193,7 +193,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTwoClasses)
   }
 
   // Train softmax regression object.
-  SoftmaxRegression<> sr(data, labels, inputSize, numClasses, lambda);
+  SoftmaxRegression<> sr(data, labels, numClasses, lambda);
 
   // Compare training accuracy to 100.
   const double acc = sr.ComputeAccuracy(data, labels);
@@ -224,20 +224,20 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionFitIntercept)
   GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
 
   arma::mat data(3, 1000);
-  arma::vec responses(1000);
+  arma::Row<size_t> responses(1000);
   for (size_t i = 0; i < 500; ++i)
   {
     data.col(i) = g1.Random();
     responses[i] = 0;
   }
-  for (size_t i = 501; i < 1000; ++i)
+  for (size_t i = 500; i < 1000; ++i)
   {
     data.col(i) = g2.Random();
     responses[i] = 1;
   }
 
   // Now train a logistic regression object on it.
-  SoftmaxRegression<> lr(data, responses, 3, 2, 0.01, true);
+  SoftmaxRegression<> lr(data, responses, 2, 0.01, true);
 
   // Ensure that the error is close to zero.
   const double acc = lr.ComputeAccuracy(data, responses);
@@ -249,7 +249,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionFitIntercept)
     data.col(i) = g1.Random();
     responses[i] = 0;
   }
-  for (size_t i = 501; i < 1000; ++i)
+  for (size_t i = 500; i < 1000; ++i)
   {
     data.col(i) = g2.Random();
     responses[i] = 1;
@@ -276,7 +276,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
   GaussianDistribution g5(arma::vec("1.0 0.0 1.0 8.0 3.0"), identity);
 
   arma::mat data(inputSize, points);
-  arma::vec labels(points);
+  arma::Row<size_t> labels(points);
 
   for (size_t i = 0; i < points/5; i++)
   {
@@ -305,7 +305,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
   }
 
   // Train softmax regression object.
-  SoftmaxRegression<> sr(data, labels, inputSize, numClasses, lambda);
+  SoftmaxRegression<> sr(data, labels, numClasses, lambda);
 
   // Compare training accuracy to 100.
   const double acc = sr.ComputeAccuracy(data, labels);
@@ -348,17 +348,16 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTrainTest)
   // Make sure a SoftmaxRegression object trained with Train() operates the same
   // as a SoftmaxRegression object trained in the constructor.
   arma::mat dataset = arma::randu<arma::mat>(5, 1000);
-  arma::vec labels(1000);
+  arma::Row<size_t> labels(1000);
   for (size_t i = 0; i < 500; ++i)
-    labels[i] = 0.0;
+    labels[i] = size_t(0.0);
   for (size_t i = 500; i < 1000; ++i)
-    labels[i] = 1.0;
+    labels[i] = size_t(1.0);
 
 
   // This should be the same as the default parameters given by
   // SoftmaxRegression.
-  SoftmaxRegressionFunction srf(dataset, labels, dataset.n_rows, 2, 0.0001,
-      false);
+  SoftmaxRegressionFunction srf(dataset, labels, 2, 0.0001, false);
   L_BFGS<SoftmaxRegressionFunction> lbfgs(srf);
   SoftmaxRegression<> sr(lbfgs);
 
@@ -382,13 +381,13 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionOptimizerTrainTest)
 {
   // The same as the previous test, just passing in an instantiated optimizer.
   arma::mat dataset = arma::randu<arma::mat>(5, 1000);
-  arma::vec labels(1000);
+  arma::Row<size_t> labels(1000);
   for (size_t i = 0; i < 500; ++i)
-    labels[i] = 0.0;
+    labels[i] = size_t(0.0);
   for (size_t i = 500; i < 1000; ++i)
-    labels[i] = 1.0;
+    labels[i] = size_t(1.0);
 
-  SoftmaxRegressionFunction srf(dataset, labels, dataset.n_rows, 2, 0.01, true);
+  SoftmaxRegressionFunction srf(dataset, labels, 2, 0.01, true);
   L_BFGS<SoftmaxRegressionFunction> lbfgs(srf);
   SoftmaxRegression<> sr(lbfgs);
 
@@ -406,6 +405,6 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionOptimizerTrainTest)
     else
       BOOST_REQUIRE_CLOSE(sr.Parameters()[i], sr2.Parameters()[i], 0.01);
   }
-}
+}//*/
 
 BOOST_AUTO_TEST_SUITE_END();


