[mlpack-git] master: add intercept term to softmax regression (12aa888)

gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 10:50:39 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0e51cd72212d267c0f299ecdb6f2edb127d60280...50fe86904c980df87e1db0067cb86ef73f3ccaa2

>---------------------------------------------------------------

commit 12aa8888f018066baf865dc8820e9b9ecc7b31ed
Author: apir8181 <kazenoyumechen at gmail.com>
Date:   Thu Mar 5 16:01:38 2015 +0800

    add intercept term to softmax regression


>---------------------------------------------------------------

12aa8888f018066baf865dc8820e9b9ecc7b31ed
 .../softmax_regression/softmax_regression.hpp      | 12 +++-
 .../softmax_regression_function.cpp                | 75 +++++++++++++++++-----
 .../softmax_regression_function.hpp                | 25 +++++++-
 .../softmax_regression/softmax_regression_impl.hpp | 25 ++++++--
 src/mlpack/tests/softmax_regression_test.cpp       | 44 +++++++++++++
 5 files changed, 160 insertions(+), 21 deletions(-)
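
As the comments in the diff below explain, the patch never builds the augmented matrix [1; data] explicitly; it adds the intercept column of the parameter matrix separately. A minimal standalone Armadillo sketch of that equivalence (not part of the patch; the sizes and variable names here are arbitrary):

#include <armadillo>
#include <iostream>

int main()
{
  const size_t numClasses = 3, inputSize = 5, numPoints = 10;

  arma::mat data, parameters;
  data.randn(inputSize, numPoints);
  // One extra column in the parameters for the intercept term.
  parameters.randn(numClasses, inputSize + 1);

  // Joined form: prepend a row of ones to the data, then multiply once.
  arma::mat joined = arma::join_cols(arma::ones<arma::mat>(1, numPoints), data);
  arma::mat hypothesisJoined = arma::exp(parameters * joined);

  // Split form used by the patch: add the intercept column separately, so the
  // original data matrix is never copied.
  arma::mat hypothesisSplit = arma::exp(
      arma::repmat(parameters.col(0), 1, numPoints) +
      parameters.cols(1, parameters.n_cols - 1) * data);

  // The two forms agree up to floating-point error.
  std::cout << arma::norm(hypothesisJoined - hypothesisSplit, "fro") << std::endl;

  return 0;
}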

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression.hpp b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
index 88db1c3..bf5e6c7 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
@@ -73,12 +73,14 @@ class SoftmaxRegression
    * @param inputSize Size of the input feature vector.
    * @param numClasses Number of classes for classification.
    * @param lambda L2-regularization constant.
+   * @param fitIntercept Whether to fit an intercept term.
    */
   SoftmaxRegression(const arma::mat& data,
                     const arma::vec& labels,
                     const size_t inputSize,
                     const size_t numClasses,
-                    const double lambda = 0.0001);
+                    const double lambda = 0.0001,
+                    const bool fitIntercept = false);
                     
   /**
    * Construct the softmax regression model with the given training data. This
@@ -147,6 +149,12 @@ class SoftmaxRegression
     return lambda;
   }
 
+  //! Gets the intercept term flag.
+  bool FitIntercept() const
+  {
+    return fitIntercept;
+  }
+                    
  private:
   //! Parameters after optimization.
   arma::mat parameters;
@@ -156,6 +164,8 @@ class SoftmaxRegression
   size_t numClasses;
   //! L2-regularization constant.
   double lambda;
+  //! Intercept term flag.
+  bool fitIntercept;
 };
 
 }; // namespace regression
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
index 97a4a2b..9fb5e64 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.cpp
@@ -13,11 +13,13 @@ SoftmaxRegressionFunction::SoftmaxRegressionFunction(const arma::mat& data,
                                                      const arma::vec& labels,
                                                      const size_t inputSize,
                                                      const size_t numClasses,
-                                                     const double lambda) :
+                                                     const double lambda,
+                                                     const bool fitIntercept) :
     data(data),
     inputSize(inputSize),
     numClasses(numClasses),
-    lambda(lambda)
+    lambda(lambda),
+    fitIntercept(fitIntercept)
 {
  // Initialize the parameters to suitable values.
   initialPoint = InitializeWeights();
@@ -35,8 +37,12 @@ const arma::mat SoftmaxRegressionFunction::InitializeWeights()
 {
   // Initialize values to 0.005 * r. 'r' is a matrix of random values taken from
   // a Gaussian distribution with mean zero and variance one.
+  // If the fitIntercept flag is true, parameters.col(0) is the intercept.
   arma::mat parameters;
-  parameters.randn(numClasses, inputSize);
+  if (fitIntercept)
+    parameters.randn(numClasses, inputSize + 1);
+  else
+    parameters.randn(numClasses, inputSize);
   parameters = 0.005 * parameters;
 
   return parameters;
@@ -77,6 +83,35 @@ void SoftmaxRegressionFunction::GetGroundTruthMatrix(const arma::vec& labels,
 }
 
 /**
+ * Evaluate the probabilities matrix. If the fitIntercept flag is true,
+ * parameters.col(0) is treated as the intercept term.
+ */
+void SoftmaxRegressionFunction::GetProbabilitiesMatrix(
+    const arma::mat& parameters, arma::mat& probabilities) const
+{
+  arma::mat hypothesis;
+
+  if (fitIntercept)
+  {
+    // In order to add the intercept term, we would compute the following:
+    //     [1; data] = arma::join_cols(ones(1, data.n_cols), data)
+    //     hypothesis = arma::exp(parameters * [1; data]).
+    //
+    // Since the cost of the join may be high due to the copy of the original
+    // data, split the hypothesis computation into two components.
+    hypothesis = arma::exp(arma::repmat(parameters.col(0), 1, data.n_cols) +
+        parameters.cols(1, parameters.n_cols - 1) * data);
+  }
+  else
+  {
+    hypothesis = arma::exp(parameters * data);
+  }
+
+  probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
+                                            numClasses, 1);
+}
+
+/**
  * Evaluates the objective function given the parameters.
  */
 double SoftmaxRegressionFunction::Evaluate(const arma::mat& parameters) const
@@ -97,11 +132,8 @@ double SoftmaxRegressionFunction::Evaluate(const arma::mat& parameters) const
   // The sum is calculated over all the classes.
   // x_i is the input vector for a particular training example.
   // theta_j is the parameter vector associated with a particular class.
-  arma::mat hypothesis, probabilities;
-
-  hypothesis = arma::exp(parameters * data);
-  probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
-                                            numClasses, 1);
+  arma::mat probabilities;
+  GetProbabilitiesMatrix(parameters, probabilities);
 
   // Calculate the log likelihood and regularization terms.
   double logLikelihood, weightDecay, cost;
@@ -129,13 +161,26 @@ void SoftmaxRegressionFunction::Gradient(const arma::mat& parameters,
   // The sum is calculated over all the classes.
   // x_i is the input vector for a particular training example.
   // theta_j is the parameter vector associated with a particular class.
-  arma::mat hypothesis, probabilities;
-
-  hypothesis = arma::exp(parameters * data);
-  probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
-                                            numClasses, 1);
+  arma::mat probabilities;
+  GetProbabilitiesMatrix(parameters, probabilities);
 
   // Calculate the parameter gradients.
-  gradient = (probabilities - groundTruth) * data.t() / data.n_cols +
-      lambda * parameters;
+  gradient.set_size(parameters.n_rows, parameters.n_cols);
+  if (fitIntercept)
+  {
+    // Treat the intercept term parameters.col(0) separately to avoid
+    // the cost of building the matrix [1; data].
+    arma::mat inner = probabilities - groundTruth;
+    gradient.col(0) =
+        inner * arma::ones<arma::mat>(data.n_cols, 1) / data.n_cols +
+        lambda * parameters.col(0);
+    gradient.cols(1, parameters.n_cols - 1) =
+        inner * data.t() / data.n_cols +
+        lambda * parameters.cols(1, parameters.n_cols - 1);
+  }
+  else
+  {
+    gradient = (probabilities - groundTruth) * data.t() / data.n_cols +
+        lambda * parameters;
+  }
 }
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
index fc22384..bd09041 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_function.hpp
@@ -25,12 +25,14 @@ class SoftmaxRegressionFunction
    * @param inputSize Size of the input feature vector.
    * @param numClasses Number of classes for classification.
    * @param lambda L2-regularization constant.
+   * @param fitIntercept Whether to fit an intercept term.
    */
   SoftmaxRegressionFunction(const arma::mat& data,
                             const arma::vec& labels,
                             const size_t inputSize,
                             const size_t numClasses,
-                            const double lambda = 0.0001);
+                            const double lambda = 0.0001,
+                            const bool fitIntercept = false);
 
   //! Initializes the parameters of the model to suitable values.
   const arma::mat InitializeWeights();
@@ -44,6 +46,18 @@ class SoftmaxRegressionFunction
   void GetGroundTruthMatrix(const arma::vec& labels, arma::sp_mat& groundTruth);
 
   /**
+   * Evaluate the probabilities matrix with the passed parameters.
+   * probabilities(i, j) =
+   *     exp(\theta_i * data_j) / sum_k(exp(\theta_k * data_j)).
+   * It represents the probability that data_j belongs to class i.
+   *
+   * @param parameters Current values of the model parameters.
+   * @param probabilities Matrix in which to store the computed probabilities.
+   */
+  void GetProbabilitiesMatrix(const arma::mat& parameters,
+                              arma::mat& probabilities) const;
+
+  /**
    * Evaluates the objective function of the softmax regression model using the
    * given parameters. The cost function has terms for the log likelihood error
    * and the regularization cost. The objective function takes a low value when
@@ -104,6 +118,12 @@ class SoftmaxRegressionFunction
     return lambda;
   }
 
+  //! Gets the intercept flag.
+  bool FitIntercept() const
+  {
+    return fitIntercept;
+  }
+
  private:
   //! Training data matrix.
   const arma::mat& data;
@@ -117,6 +137,9 @@ class SoftmaxRegressionFunction
   size_t numClasses;
   //! L2-regularization constant.
   double lambda;
+  //! Intercept term flag.
+  bool fitIntercept;
+
 };
 
 }; // namespace regression
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index b16ea0f..9a810ac 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -18,13 +18,15 @@ SoftmaxRegression<OptimizerType>::SoftmaxRegression(const arma::mat& data,
                                                     const arma::vec& labels,
                                                     const size_t inputSize,
                                                     const size_t numClasses,
-                                                    const double lambda) :
+                                                    const double lambda,
+                                                    const bool fitIntercept) :
     inputSize(inputSize),
     numClasses(numClasses),
-    lambda(lambda)
+    lambda(lambda),
+    fitIntercept(fitIntercept)
 {
   SoftmaxRegressionFunction regressor(data, labels, inputSize, numClasses,
-                                      lambda);
+                                      lambda, fitIntercept);
   OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
   
   parameters = regressor.GetInitialPoint();
@@ -61,8 +63,23 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
 {
   // Calculate the probabilities for each test input.
   arma::mat hypothesis, probabilities;
+  if (fitIntercept)
+  {
+    // In order to add the intercept term, we would compute the following:
+    //     [1; testData] = arma::join_cols(ones(1, testData.n_cols), testData)
+    //     hypothesis = arma::exp(parameters * [1; testData]).
+    //
+    // Since the cost of the join may be high due to the copy of the original
+    // data, split the hypothesis computation into two components.
+    hypothesis = arma::exp(
+        arma::repmat(parameters.col(0), 1, testData.n_cols) +
+        parameters.cols(1, parameters.n_cols - 1) * testData);
+  }
+  else
+  {
+    hypothesis = arma::exp(parameters * testData);
+  }
   
-  hypothesis = arma::exp(parameters * testData);
   probabilities = hypothesis / arma::repmat(arma::sum(hypothesis, 0),
                                             numClasses, 1);
   
diff --git a/src/mlpack/tests/softmax_regression_test.cpp b/src/mlpack/tests/softmax_regression_test.cpp
index da6a522..01b0eae 100644
--- a/src/mlpack/tests/softmax_regression_test.cpp
+++ b/src/mlpack/tests/softmax_regression_test.cpp
@@ -215,6 +215,50 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTwoClasses)
   BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6);
 }
 
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionFitIntercept)
+{
+  // Generate a two-Gaussian dataset that can't be separated without adding
+  // the intercept term.
+  GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
+
+  arma::mat data(3, 1000);
+  arma::vec responses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    data.col(i) = g1.Random();
+    responses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    data.col(i) = g2.Random();
+    responses[i] = 1;
+  }
+
+  // Now train a softmax regression object on it.
+  SoftmaxRegression<> lr(data, responses, 3, 2, 0.01, true);
+
+  // Ensure that the training set accuracy is close to 100%.
+  const double acc = lr.ComputeAccuracy(data, responses);
+  BOOST_REQUIRE_CLOSE(acc, 100.0, 2.0);
+
+  // Create a test set.
+  for (size_t i = 0; i < 500; ++i)
+  {
+    data.col(i) = g1.Random();
+    responses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    data.col(i) = g2.Random();
+    responses[i] = 1;
+  }
+
+  // Ensure that the test set accuracy is close to 100%.
+  const double testAcc = lr.ComputeAccuracy(data, responses);
+  BOOST_REQUIRE_CLOSE(testAcc, 100.0, 2.0);
+}
+
 BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
 {
   const size_t points = 5000;

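For reference, a minimal usage sketch of the new flag (the data here is a made-up placeholder, not taken from the patch). Passing true as the sixth constructor argument enables the intercept; the default is false, so existing callers are unaffected:

#include <mlpack/methods/softmax_regression/softmax_regression.hpp>

#include <iostream>

using namespace mlpack::regression;

int main()
{
  // Two toy classes in 3 dimensions; the second class is shifted away from
  // the origin so the intercept actually matters.
  arma::mat data;
  data.randn(3, 100);
  data.cols(50, 99) += 5.0;

  arma::vec labels = arma::zeros<arma::vec>(100);
  labels.subvec(50, 99).fill(1.0);

  // inputSize = 3, numClasses = 2, lambda = 0.01, fitIntercept = true.
  SoftmaxRegression<> sr(data, labels, 3, 2, 0.01, true);

  // ComputeAccuracy() is used here the same way as in the existing tests.
  std::cout << "training accuracy: " << sr.ComputeAccuracy(data, labels)
            << std::endl;
  return 0;
}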

