[mlpack-git] master: Refactor perceptron to not modify input dataset. Minimize internally-held variables. (424383c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 11:50:41 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f5893d5d190d5f5b4b6dc94e2593f50c56d406e4...424383cb02dcf2d73728e1c3c4b582bdb7cba627

>---------------------------------------------------------------

commit 424383cb02dcf2d73728e1c3c4b582bdb7cba627
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 4 15:50:04 2015 +0000

    Refactor perceptron to not modify input dataset.
    Minimize internally-held variables.


>---------------------------------------------------------------

424383cb02dcf2d73728e1c3c4b582bdb7cba627
 .../learning_policies/simple_weight_update.hpp     | 20 ++---
 src/mlpack/methods/perceptron/perceptron.hpp       | 18 ++---
 src/mlpack/methods/perceptron/perceptron_impl.hpp  | 85 +++++++++++-----------
 3 files changed, 60 insertions(+), 63 deletions(-)

diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index 12c67ea..b78adc9 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -26,9 +26,9 @@ class SimpleWeightUpdate
 {
  public:
   /**
-   * This function is called to update the weightVectors matrix.
-   *  It decreases the weights of the incorrectly classified class while
-   * increasing the weight of the correct class it should have been classified to.
+   * This function is called to update the weightVectors matrix.  It decreases
+   * the weights of the incorrectly classified class while increasing the weight
+   * of the correct class it should have been classified to.
    *
    * @param trainData The training dataset.
    * @param weightVectors Matrix of weight vectors.
@@ -44,15 +44,17 @@ class SimpleWeightUpdate
                      const size_t rowIndex,
                      const arma::rowvec& D)
   {
-    weightVectors.row(rowIndex) = weightVectors.row(rowIndex) -
-                                  D(labelIndex) * trainData.col(labelIndex).t();
+    weightVectors.row(rowIndex).subvec(1, weightVectors.n_cols - 1) -=
+        D(labelIndex) * trainData.col(labelIndex).t();
+    weightVectors(rowIndex, 0) -= D(labelIndex);
 
-    weightVectors.row(vectorIndex) = weightVectors.row(vectorIndex) +
-                                     D(labelIndex) * trainData.col(labelIndex).t();
+    weightVectors.row(vectorIndex).subvec(1, weightVectors.n_cols - 1) +=
+        D(labelIndex) * trainData.col(labelIndex).t();
+    weightVectors(vectorIndex, 0) += D(labelIndex);
   }
 };
 
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index b58d538..f2c9275 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -43,7 +43,7 @@ class Perceptron
    */
   Perceptron(const MatType& data,
              const arma::Row<size_t>& labels,
-             int iterations);
+             const int iterations);
 
   /**
    * Classification function. After training, use the weightVectors matrix to
@@ -66,7 +66,7 @@ class Perceptron
    * @param labels The labels of data.
    */
   Perceptron(const Perceptron<>& other,
-             MatType& data,
+             const MatType& data,
              const arma::rowvec& D,
              const arma::Row<size_t>& labels);
 
@@ -80,21 +80,17 @@ private:
   //! To store the number of iterations
   size_t iter;
 
-  //! Stores the class labels for the input data.
-  arma::Row<size_t> classLabels;
-
   //! Stores the weight vectors for each of the input class labels.
   arma::mat weightVectors;
 
-  //! Stores the training data to be used later on in UpdateWeights.
-  arma::mat trainData;
-
   /**
-   *  Training Function. It trains on trainData using the cost matrix D
+   * Training function. It trains on the given data and labels, using D.
    *
-   *  @param D Cost matrix. Stores the cost of mispredicting instances
+   * @param D Weight of each point; passed to LearnPolicy::UpdateWeights().
    */
-  void Train(const arma::rowvec& D);
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const arma::rowvec& D);
 };
 
 } // namespace perceptron
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 01a105f..c1aafc2 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -31,25 +31,17 @@ template<
 Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
     const MatType& data,
     const arma::Row<size_t>& labels,
-    int iterations)
+    const int iterations)
 {
   WeightInitializationPolicy WIP;
   WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
 
   // Start training.
-  classLabels = labels;
-
-  trainData = data;
-  // Insert a row of ones at the top of the training data set.
-  MatType zOnes(1, data.n_cols);
-  zOnes.fill(1);
-  trainData.insert_rows(0, zOnes);
-
   iter = iterations;
   arma::rowvec D(data.n_cols);
   D.fill(1.0);// giving equal weight to all the points.
 
-  Train(D);
+  Train(data, labels, D);
 }
 
 
@@ -61,53 +53,57 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
  * @param predictedLabels vector to store the predicted classes after
  *      classifying test
  */
-template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+template<
+    typename LearnPolicy,
+    typename WeightInitializationPolicy,
+    typename MatType
+>
 void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
     const MatType& test,
     arma::Row<size_t>& predictedLabels)
 {
-  arma::mat tempLabelMat;
-  arma::uword maxIndexRow, maxIndexCol;
+  arma::vec tempLabelMat;
+  arma::uword maxIndex;
 
+  // Column 0 of weightVectors is the bias.  Batching would likely be faster.
   for (size_t i = 0; i < test.n_cols; i++)
   {
     tempLabelMat = weightVectors.submat(0, 1, weightVectors.n_rows - 1,
                                         weightVectors.n_cols - 1) *
                                         test.col(i) + weightVectors.col(0);
-    tempLabelMat.max(maxIndexRow, maxIndexCol);
-    predictedLabels(0, i) = maxIndexRow;
+    tempLabelMat.max(maxIndex); // Highest-scoring row is the predicted label.
+    predictedLabels(0, i) = maxIndex;
   }
-  // predictedLabels.print("These are the labels predicted by the perceptron");
 }
 
 /**
- *  Alternate constructor which copies parameters from an already initiated
- *  perceptron.
+ * Alternate constructor which copies parameters from an already constructed
+ * perceptron.
  *
- *  @param other The other initiated Perceptron object from which we copy the
- *               values from.
- *  @param data The data on which to train this Perceptron object on.
- *  @param D Weight vector to use while training. For boosting purposes.
- *  @param labels The labels of data.
+ * @param other An existing Perceptron object; its iteration count is reused.
+ *     No weights are copied; they are re-initialized and retrained.
+ * @param data The data on which to train this Perceptron object.
+ * @param D Weight vector to use while training. For boosting purposes.
+ * @param labels The labels of data.
  */
-template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+template<
+    typename LearnPolicy,
+    typename WeightInitializationPolicy,
+    typename MatType
+>
 Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
-  const Perceptron<>& other, MatType& data, const arma::rowvec& D, const arma::Row<size_t>& labels)
+    const Perceptron<>& other,
+    const MatType& data,
+    const arma::rowvec& D,
+    const arma::Row<size_t>& labels)
 {
-
-  classLabels = labels;
-  trainData = data;
   iter = other.iter;
 
   // Insert a row of ones at the top of the training data set.
-  MatType zOnes(1, data.n_cols);
-  zOnes.fill(1);
-  trainData.insert_rows(0, zOnes);
-
   WeightInitializationPolicy WIP;
   WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
 
-  Train(D);
+  Train(data, labels, D);
 }
 
 //! Serialize the perceptron.
@@ -123,9 +119,9 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Serialize(
 }
 
 /**
- *  Training Function. It trains on trainData using the cost matrix D
+ * Training function. It trains on the given data and labels, using D.
  *
- *  @param D Cost matrix. Stores the cost of mispredicting instances
+ * @param D Weight of each point; passed to LearnPolicy::UpdateWeights().
+ * @param D Cost matrix. Stores the cost of mispredicting instances
  */
 template<
     typename LearnPolicy,
@@ -133,7 +129,9 @@ template<
     typename MatType
 >
 void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
-     const arma::rowvec& D)
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const arma::rowvec& D)
 {
   size_t j, i = 0;
   bool converged = false;
@@ -151,30 +149,31 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
     converged = true;
 
     // Now this inner loop is for going through the dataset in each iteration.
-    for (j = 0; j < trainData.n_cols; j++)
+    for (j = 0; j < data.n_cols; j++)
     {
       // Multiply for each variable and check whether the current weight vector
       // correctly classifies this.
-      tempLabelMat = weightVectors * trainData.col(j);
+      tempLabelMat = weightVectors.cols(1, weightVectors.n_cols - 1) *
+          data.col(j) + weightVectors.col(0);
 
       tempLabelMat.max(maxIndexRow, maxIndexCol);
 
       // Check whether prediction is correct.
-      if (maxIndexRow != classLabels(0, j))
+      if (maxIndexRow != labels(0, j))
       {
         // Due to incorrect prediction, convergence set to false.
         converged = false;
-        tempLabel = classLabels(0, j);
+        tempLabel = labels(0, j);
         // Send maxIndexRow for knowing which weight to update, send j to know
         // the value of the vector to update it with.  Send tempLabel to know
         // the correct class.
-        LP.UpdateWeights(trainData, weightVectors, j, tempLabel, maxIndexRow, D);
+        LP.UpdateWeights(data, weightVectors, j, tempLabel, maxIndexRow, D);
       }
     }
   }
 }
 
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list