[mlpack-git] master: Refactor weight vectors to be column-major. (19a4620)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 13:32:58 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/424383cb02dcf2d73728e1c3c4b582bdb7cba627...f3bd5e8853a795f4ff41849dd8ef844d53199412

>---------------------------------------------------------------

commit 19a4620fc4081a63772b95987ad157b9d6dfd50b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 4 16:28:00 2015 +0000

    Refactor weight vectors to be column-major.


>---------------------------------------------------------------

19a4620fc4081a63772b95987ad157b9d6dfd50b
 .../perceptron/initialization_methods/random_init.hpp    | 10 +++++-----
 .../perceptron/initialization_methods/zero_init.hpp      | 10 +++++-----
 .../learning_policies/simple_weight_update.hpp           | 14 +++++++-------
 src/mlpack/methods/perceptron/perceptron.hpp             |  6 +++++-
 src/mlpack/methods/perceptron/perceptron_impl.hpp        | 16 ++++++++--------
 5 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp b/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
index f88e952..7ec04df 100644
--- a/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
+++ b/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
@@ -22,14 +22,14 @@ class RandomInitialization
   RandomInitialization() { }
 
   inline static void Initialize(arma::mat& W,
-                                const size_t row,
-                                const size_t col)
+                                const size_t numFeatures, // Rows of W.
+                                const size_t numClasses)  // Columns of W.
   {
-    W = arma::randu<arma::mat>(row, col);
+    W = arma::randu<arma::mat>(numFeatures, numClasses);
   }
 }; // class RandomInitialization
 
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp b/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
index 0a02459..7fdf009 100644
--- a/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
+++ b/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
@@ -21,17 +21,17 @@ class ZeroInitialization
   ZeroInitialization() { }
 
   inline static void Initialize(arma::mat& W,
-                                const size_t row,
-                                const size_t col)
+                                const size_t numFeatures, // Rows of W.
+                                const size_t numClasses)  // Columns of W.
   {
-    arma::mat tempWeights(row, col);
+    arma::mat tempWeights(numFeatures, numClasses);
     tempWeights.fill(0.0);
 
     W = tempWeights;
   }
 }; // class ZeroInitialization
 
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index b78adc9..262c2cd 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -32,7 +32,7 @@ class SimpleWeightUpdate
    *
    * @param trainData The training dataset.
    * @param weightVectors Matrix of weight vectors.
-   * @param rowIndex Index of the row which has been incorrectly predicted.
+   * @param colIndex Index of the class (column) that was wrongly predicted.
    * @param labelIndex Index of the vector in trainData.
    * @param vectorIndex Index of the class which should have been predicted.
    * @param D Cost of mispredicting the labelIndex instance.
@@ -41,15 +41,15 @@ class SimpleWeightUpdate
                      arma::mat& weightVectors,
                      const size_t labelIndex,
                      const size_t vectorIndex,
-                     const size_t rowIndex,
+                     const size_t colIndex,
                      const arma::rowvec& D)
   {
-    weightVectors.row(rowIndex).subvec(1, weightVectors.n_cols - 1) -=
-        D(labelIndex) * trainData.col(labelIndex).t();
-    weightVectors(rowIndex, 0) -= D(labelIndex);
+    weightVectors.col(colIndex).subvec(1, weightVectors.n_rows - 1) -=
+        D(labelIndex) * trainData.col(labelIndex);
+    weightVectors(0, colIndex) -= D(labelIndex); // Row 0 holds the bias.
 
-    weightVectors.row(vectorIndex).subvec(1, weightVectors.n_cols - 1) +=
-        D(labelIndex) * trainData.col(labelIndex).t();
-    weightVectors(vectorIndex, 0) += D(labelIndex);
+    weightVectors.col(vectorIndex).subvec(1, weightVectors.n_rows - 1) +=
+        D(labelIndex) * trainData.col(labelIndex);
+    weightVectors(0, vectorIndex) += D(labelIndex); // Row 0 holds the bias.
   }
 };
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index f2c9275..b556684 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -80,7 +80,11 @@ private:
   //! To store the number of iterations
   size_t iter;
 
-  //! Stores the weight vectors for each of the input class labels.
+  /**
+   * Stores the weight vectors for each of the input class labels.  Each column
+   * corresponds to one class label.  Row 0 holds the bias term of each class,
+   * and row i + 1 holds the weight for dimension i of the input data.
+   */
   arma::mat weightVectors;
 
   /**
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index c1aafc2..1f4bf6c 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -34,7 +34,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
     const int iterations)
 {
   WeightInitializationPolicy WIP;
-  WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
+  WIP.Initialize(weightVectors, data.n_rows + 1, arma::max(labels) + 1);
 
   // Start training.
   iter = iterations;
@@ -68,9 +68,9 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
   // Could probably be faster if done in batch.
   for (size_t i = 0; i < test.n_cols; i++)
   {
-    tempLabelMat = weightVectors.submat(0, 1, weightVectors.n_rows - 1,
-                                        weightVectors.n_cols - 1) *
-                                        test.col(i) + weightVectors.col(0);
+    tempLabelMat = weightVectors.submat(1, 0, weightVectors.n_rows - 1,
+                                        weightVectors.n_cols - 1).t() *
+                                        test.col(i) + weightVectors.row(0).t();
     tempLabelMat.max(maxIndex);
     predictedLabels(0, i) = maxIndex;
   }
@@ -101,7 +101,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
 
   // Insert a row of ones at the top of the training data set.
   WeightInitializationPolicy WIP;
-  WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
+  WIP.Initialize(weightVectors, data.n_rows + 1, arma::max(labels) + 1);
 
   Train(data, labels, D);
 }
@@ -153,8 +153,8 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
     {
       // Multiply for each variable and check whether the current weight vector
       // correctly classifies this.
-      tempLabelMat = weightVectors.cols(1, weightVectors.n_cols - 1) *
-          data.col(j) + weightVectors.col(0);
+      tempLabelMat = weightVectors.rows(1, weightVectors.n_rows - 1).t() *
+          data.col(j) + weightVectors.row(0).t();
 
       tempLabelMat.max(maxIndexRow, maxIndexCol);
 



More information about the mlpack-git mailing list