[mlpack-git] master: 1 : add function Serialize 2 : add function Train 3 : fix bug--did not initialize fitIntercept (11746e9)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:49 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit 11746e9a626c76bab9b94bbdf4976260490c2b99
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Mon Sep 28 10:05:56 2015 +0800

    1 : add function Serialize
    2 : add function Train
    3 : fix bug--did not initialize fitIntercept


>---------------------------------------------------------------

11746e9a626c76bab9b94bbdf4976260490c2b99
 .../softmax_regression/softmax_regression.hpp      | 103 +++++++++++++++++++--
 .../softmax_regression/softmax_regression_impl.hpp |  97 +++++++++++++------
 2 files changed, 162 insertions(+), 38 deletions(-)
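
For reference, a minimal usage sketch of the new API (not part of the
commit).  The random data and the three-class labels below are
placeholders, and the default L_BFGS optimizer is assumed.

#include <iostream>
#include <mlpack/core.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>

using namespace mlpack::regression;

int main()
{
  // Each column is one sample, per the updated documentation.
  arma::mat data(5, 100, arma::fill::randu);
  arma::vec labels(100);
  for (size_t i = 0; i < labels.n_elem; ++i)
    labels(i) = i % 3;  // Three classes, labeled 0, 1, 2.

  // New constructor: builds an untrained model; fitIntercept is now
  // initialized (false by default).
  SoftmaxRegression<> sr(data.n_rows, 3);

  // New Train() overload; returns the final objective value.
  const double objective = sr.Train(data, labels, 3);

  arma::vec predictions;
  sr.Predict(data, predictions);
  std::cout << "final objective: " << objective << std::endl;

  return 0;
}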

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression.hpp b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
index 1f95c81..0c0d2b0 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression.hpp
@@ -57,10 +57,34 @@ namespace regression {
 
 template<
   template<typename> class OptimizerType = mlpack::optimization::L_BFGS
->
+  >
 class SoftmaxRegression
 {
  public:
+  /**
+   * Initialize the SoftmaxRegression without performing training.  The
+   * default value of lambda is 0.0001.  Be sure to use Train() before
+   * calling Predict() or ComputeAccuracy(); otherwise, the results may be
+   * meaningless.
+   *
+   * @param inputSize Size of the input feature vector.
+   * @param numClasses Number of classes for classification.
+   * @param fitIntercept Whether to fit an intercept term.
+   */
+  SoftmaxRegression(const size_t inputSize,
+                    const size_t numClasses,
+                    const bool fitIntercept = false);
+
+  /**
+   * Construct the SoftmaxRegression class by loading a previously saved
+   * model from the given file.
+   *
+   * @param fileName Name of the file storing the model contents.
+   * @param name Name of the structure to be loaded.
+   * @exception std::runtime_error If the file cannot be loaded.
+   */
+  SoftmaxRegression(const std::string &fileName,
+                    const std::string& name);
 
   /**
    * Construct the SoftmaxRegression class with the provided data and labels.
@@ -68,7 +92,7 @@ class SoftmaxRegression
    * passed, which controls the amount of L2-regularization in the objective
    * function. By default, the model takes a small value.
    *
-   * @param data Input training features.
+   * @param data Input training features; each column corresponds to one sample.
    * @param labels Labels associated with the feature data.
    * @param inputSize Size of the input feature vector.
    * @param numClasses Number of classes for classification.
@@ -113,23 +137,84 @@ class SoftmaxRegression
    */
   double ComputeAccuracy(const arma::mat& testData, const arma::vec& labels);
 
+  /**
+   * Train the softmax regression model with the given optimizer.  The
+   * optimizer should hold an instantiated SoftmaxRegressionFunction object
+   * for the function to operate upon; this overload is preferred when the
+   * optimizer options need to be changed.
+   * @param optimizer Instantiated optimizer with instantiated error function.
+   * @return Objective value of the final point.
+   */
+  double Train(OptimizerType<SoftmaxRegressionFunction>& optimizer);
+
+  /**
+   * Train the softmax regression model with the given training data.
+   * @param data Input data with each column as one example.
+   * @param labels Labels associated with the feature data.
+   * @param numClasses Number of classes for classification.
+   * @return Objective value of the final point.
+   */
+  double Train(const arma::mat& data, const arma::vec& labels,
+               const size_t numClasses);
+
   //! Sets the size of the input vector.
-  size_t& InputSize() { return inputSize; }
+  size_t& InputSize() {
+    return inputSize;
+  }
   //! Gets the size of the input vector.
-  size_t InputSize() const { return inputSize; }
+  size_t InputSize() const {
+    return inputSize;
+  }
 
   //! Sets the number of classes.
-  size_t& NumClasses() { return numClasses; }
+  size_t& NumClasses() {
+    return numClasses;
+  }
   //! Gets the number of classes.
-  size_t NumClasses() const { return numClasses; }
+  size_t NumClasses() const {
+    return numClasses;
+  }
 
   //! Sets the regularization parameter.
-  double& Lambda() { return lambda; }
+  double& Lambda() {
+    return lambda;
+  }
   //! Gets the regularization parameter.
-  double Lambda() const { return lambda; }
+  double Lambda() const {
+    return lambda;
+  }
 
   //! Gets the intercept term flag.  We can't change this after training.
-  bool FitIntercept() const { return fitIntercept; }
+  bool FitIntercept() const {
+    return fitIntercept;
+  }
+
+  //! Sets the training parameters.
+  arma::mat& Parameters()
+  {
+    return parameters;
+  }
+
+  //! Gets the training parameters.
+  const arma::mat& Parameters() const
+  {
+    return parameters;
+  }
+
+  /**
+   * Serialize the SoftmaxRegression model.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using mlpack::data::CreateNVP;
+
+    ar & CreateNVP(parameters, "parameters");
+    ar & CreateNVP(inputSize, "inputSize");
+    ar & CreateNVP(numClasses, "numClasses");
+    ar & CreateNVP(lambda, "lambda");
+    ar & CreateNVP(fitIntercept, "fitIntercept");
+  }
 
  private:
   //! Parameters after optimization.
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 01e042c..2cf369f 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -14,47 +14,59 @@ namespace mlpack {
 namespace regression {
 
 template<template<typename> class OptimizerType>
+SoftmaxRegression<OptimizerType>::
+SoftmaxRegression(const size_t inputSize,
+                  const size_t numClasses,
+                  const bool fitIntercept) :
+  inputSize{inputSize},
+  numClasses{numClasses},
+  lambda{0.0001},
+  fitIntercept{fitIntercept}
+{
+  SoftmaxRegressionFunction regressor(arma::mat(), 1,
+                                      inputSize, numClasses,
+                                      lambda, fitIntercept);
+  parameters = regressor.GetInitialPoint();
+}
+
+template<template<typename> class OptimizerType>
+SoftmaxRegression<OptimizerType>::
+SoftmaxRegression(const std::string& fileName,
+                  const std::string& name)
+{
+  data::Load(fileName, name, *this, true);
+}
+
+template<template<typename> class OptimizerType>
 SoftmaxRegression<OptimizerType>::SoftmaxRegression(const arma::mat& data,
                                                     const arma::vec& labels,
                                                     const size_t inputSize,
                                                     const size_t numClasses,
                                                     const double lambda,
                                                     const bool fitIntercept) :
-    inputSize(inputSize),
-    numClasses(numClasses),
-    lambda(lambda),
-    fitIntercept(fitIntercept)
+  inputSize{inputSize},
+  numClasses{numClasses},
+  lambda{lambda},
+  fitIntercept{fitIntercept}
 {
   SoftmaxRegressionFunction regressor(data, labels, inputSize, numClasses,
                                       lambda, fitIntercept);
   OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
 
   parameters = regressor.GetInitialPoint();
-
-  // Train the model.
-  Timer::Start("softmax_regression_optimization");
-  const double out = optimizer.Optimize(parameters);
-  Timer::Stop("softmax_regression_optimization");
-
-  Log::Info << "SoftmaxRegression::SoftmaxRegression(): final objective of "
-      << "trained model is " << out << "." << std::endl;
+  Train(optimizer);
 }
 
 template<template<typename> class OptimizerType>
 SoftmaxRegression<OptimizerType>::SoftmaxRegression(
-    OptimizerType<SoftmaxRegressionFunction>& optimizer) :
-    parameters(optimizer.Function().GetInitialPoint()),
-    inputSize(optimizer.Function().InputSize()),
-    numClasses(optimizer.Function().NumClasses()),
-    lambda(optimizer.Function().Lambda())
+  OptimizerType<SoftmaxRegressionFunction>& optimizer) :
+  parameters(optimizer.Function().GetInitialPoint()),
+  inputSize{optimizer.Function().InputSize()},
+  numClasses{optimizer.Function().NumClasses()},
+  lambda{optimizer.Function().Lambda()},
+  fitIntercept{optimizer.Function().FitIntercept()}
 {
-  // Train the model.
-  Timer::Start("softmax_regression_optimization");
-  const double out = optimizer.Optimize(parameters);
-  Timer::Stop("softmax_regression_optimization");
-
-  Log::Info << "SoftmaxRegression::SoftmaxRegression(): final objective of "
-      << "trained model is " << out << "." << std::endl;
+  Train(optimizer);
 }
 
 template<template<typename> class OptimizerType>
@@ -72,8 +84,8 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
     // Since the cost of a join may be high due to copying the original data,
     // split the hypothesis computation into two components.
     hypothesis = arma::exp(
-        arma::repmat(parameters.col(0), 1, testData.n_cols) +
-        parameters.cols(1, parameters.n_cols - 1) * testData);
+      arma::repmat(parameters.col(0), 1, testData.n_cols) +
+      parameters.cols(1, parameters.n_cols - 1) * testData);
   }
   else
   {
@@ -97,7 +109,7 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
       if(probabilities(j, i) > maxProbability)
       {
         maxProbability = probabilities(j, i);
-        predictions(i) = j;
+        predictions(i) = static_cast<double>(j);
       }
     }
 
@@ -108,8 +120,8 @@ void SoftmaxRegression<OptimizerType>::Predict(const arma::mat& testData,
 
 template<template<typename> class OptimizerType>
 double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
-    const arma::mat& testData,
-    const arma::vec& labels)
+  const arma::mat& testData,
+  const arma::vec& labels)
 {
   arma::vec predictions;
 
@@ -126,6 +138,33 @@ double SoftmaxRegression<OptimizerType>::ComputeAccuracy(
   return (count * 100.0) / predictions.n_elem;
 }
 
+template<template<typename> class OptimizerType>
+double SoftmaxRegression<OptimizerType>::
+Train(OptimizerType<SoftmaxRegressionFunction>& optimizer)
+{
+  // Train the model.
+  Timer::Start("softmax_regression_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("softmax_regression_optimization");
+
+  Log::Info << "SoftmaxRegression::Train(): final objective of "
+            << "trained model is " << out << "." << std::endl;
+
+  return out;
+}
+
+template<template<typename> class OptimizerType>
+double SoftmaxRegression<OptimizerType>::
+Train(const arma::mat& data, const arma::vec& labels,
+      const size_t numClasses)
+{
+  SoftmaxRegressionFunction regressor(data, labels, data.n_rows, numClasses,
+                                      lambda, fitIntercept);
+  OptimizerType<SoftmaxRegressionFunction> optimizer(regressor);
+
+  return Train(optimizer);
+}
+
 }; // namespace regression
 }; // namespace mlpack
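
A companion persistence sketch (also not part of the commit), assuming the
matching generic data::Save() overload for models is available in this
tree; the file name "softmax_model.xml" and the structure name "softmax"
are placeholders.

#include <mlpack/core.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>

using namespace mlpack;
using namespace mlpack::regression;

int main()
{
  arma::mat data(5, 100, arma::fill::randu);
  arma::vec labels(100);
  for (size_t i = 0; i < labels.n_elem; ++i)
    labels(i) = i % 2;  // Two classes.

  SoftmaxRegression<> sr(data.n_rows, 2);
  sr.Train(data, labels, 2);

  // Serialize() lets the model round-trip through data::Save()/data::Load().
  data::Save("softmax_model.xml", "softmax", sr, true);

  // The new file-loading constructor calls data::Load() with fatal loading,
  // so it throws if the file cannot be read.
  SoftmaxRegression<> sr2("softmax_model.xml", "softmax");

  return 0;
}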
 



More information about the mlpack-git mailing list