[mlpack-git] master: Rename weightVectors to weights and simplify API. (8ace79e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 14:07:03 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f3bd5e8853a795f4ff41849dd8ef844d53199412...2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0

>---------------------------------------------------------------

commit 8ace79ecba4e374a83792fb02dc8d7a26300b9ce
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 4 17:55:01 2015 +0000

    Rename weightVectors to weights and simplify API.


>---------------------------------------------------------------

8ace79ecba4e374a83792fb02dc8d7a26300b9ce
 .../learning_policies/simple_weight_update.hpp     | 36 ++++++++++++----------
 src/mlpack/methods/perceptron/perceptron.hpp       |  8 ++---
 src/mlpack/methods/perceptron/perceptron_impl.hpp  | 23 +++++++-------
 3 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index 603cd11..4fd79e4 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -30,26 +30,30 @@ class SimpleWeightUpdate
    * the weights of the incorrectly classified class while increasing the weight
    * of the correct class it should have been classified to.
    *
-   * @param trainData The training dataset.
-   * @param weightVectors Matrix of weight vectors.
-   * @param colIndex Index of the column which has been incorrectly predicted.
-   * @param labelIndex Index of the vector in trainData.
-   * @param vectorIndex Index of the class which should have been predicted.
-   * @param D Cost of mispredicting the labelIndex instance.
+   * @tparam VecType Type of vector (should be an Armadillo vector like
+   *      arma::vec or arma::sp_vec or something similar).
+   * @param trainingPoint Point that was misclassified.
+   * @param weights Matrix of weights.
+   * @param biases Vector of biases.
+   * @param incorrectClass Index of class that the point was incorrectly
+   *      classified as.
+   * @param correctClass Index of the true class of the point.
+   * @param instanceWeight Weight to be given to this particular point during
+   *      training (this is useful for boosting).
    */
-  void UpdateWeights(const arma::mat& trainData,
-                     arma::mat& weightVectors,
+  template<typename VecType>
+  void UpdateWeights(const VecType& trainingPoint,
+                     arma::mat& weights,
                      arma::vec& biases,
-                     const size_t labelIndex,
-                     const size_t vectorIndex,
-                     const size_t colIndex,
-                     const arma::rowvec& D)
+                     const size_t incorrectClass,
+                     const size_t correctClass,
+                     const double instanceWeight = 1.0)
   {
-    weightVectors.col(colIndex) -= D(labelIndex) * trainData.col(labelIndex);
-    biases(colIndex) -= D(labelIndex);
+    weights.col(incorrectClass) -= instanceWeight * trainingPoint;
+    biases(incorrectClass) -= instanceWeight;
 
-    weightVectors.col(vectorIndex) += D(labelIndex) * trainData.col(labelIndex);
-    biases(vectorIndex) += D(labelIndex);
+    weights.col(correctClass) += instanceWeight * trainingPoint;
+    biases(correctClass) += instanceWeight;
   }
 };
 
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index d81d936..50eaa4d 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -32,7 +32,7 @@ class Perceptron
 {
  public:
   /**
-   * Constructor - constructs the perceptron by building the weightVectors
+   * Constructor - constructs the perceptron by building the weights
    * matrix, which is later used in Classification.  It adds a bias input vector
    * of 1 to the input data to take care of the bias weights.
    *
@@ -46,7 +46,7 @@ class Perceptron
              const int iterations);
 
   /**
-   * Classification function. After training, use the weightVectors matrix to
+   * Classification function. After training, use the weights matrix to
    * classify test, and put the predicted classes in predictedLabels.
    *
    * @param test Testing data or data to classify.
@@ -81,12 +81,12 @@ private:
   size_t iter;
 
   /**
-   * Stores the weight vectors for each of the input class labels.  Each column
+   * Stores the weights for each of the input class labels.  Each column
    * corresponds to the weights for one class label, and each row corresponds to
    * the weights for one dimension of the input data.  The biases are held in a
    * separate vector.
    */
-  arma::mat weightVectors;
+  arma::mat weights;
 
   //! The biases for each class.
   arma::vec biases;
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index c11f14f..5b5cbc2 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -13,10 +13,9 @@ namespace mlpack {
 namespace perceptron {
 
 /**
- * Constructor - constructs the perceptron. Or rather, builds the weightVectors
- * matrix, which is later used in Classification.
- * It adds a bias input vector of 1 to the input data to take care of the bias
- * weights.
+ * Constructor - constructs the perceptron. Or rather, builds the weights
+ * matrix, which is later used in classification.  The biases are kept in a
+ * separate vector rather than being added to the input data as a row of ones.
  *
  * @param data Input, training data.
  * @param labels Labels of dataset.
@@ -34,7 +33,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
     const int iterations)
 {
   WeightInitializationPolicy WIP;
-  WIP.Initialize(weightVectors, biases, data.n_rows, arma::max(labels) + 1);
+  WIP.Initialize(weights, biases, data.n_rows, arma::max(labels) + 1);
 
   // Start training.
   iter = iterations;
@@ -46,8 +45,8 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
 
 
 /**
- * Classification function. After training, use the weightVectors matrix to
- * classify test, and put the predicted classes in predictedLabels.
+ * Classification function. After training, use the weights matrix to classify
+ * test, and put the predicted classes in predictedLabels.
  *
  * @param test testing data or data to classify.
  * @param predictedLabels vector to store the predicted classes after
@@ -68,7 +67,7 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
   // Could probably be faster if done in batch.
   for (size_t i = 0; i < test.n_cols; i++)
   {
-    tempLabelMat = weightVectors.t() * test.col(i) + biases;
+    tempLabelMat = weights.t() * test.col(i) + biases;
     tempLabelMat.max(maxIndex);
     predictedLabels(0, i) = maxIndex;
   }
@@ -99,7 +98,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
 
   // Insert a row of ones at the top of the training data set.
   WeightInitializationPolicy WIP;
-  WIP.Initialize(weightVectors, biases, data.n_rows, arma::max(labels) + 1);
+  WIP.Initialize(weights, biases, data.n_rows, arma::max(labels) + 1);
 
   Train(data, labels, D);
 }
@@ -151,7 +150,7 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
     {
       // Multiply for each variable and check whether the current weight vector
       // correctly classifies this.
-      tempLabelMat = weightVectors.t() * data.col(j) + biases;
+      tempLabelMat = weights.t() * data.col(j) + biases;
 
       tempLabelMat.max(maxIndexRow, maxIndexCol);
 
@@ -164,8 +163,8 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
         // Send maxIndexRow for knowing which weight to update, send j to know
         // the value of the vector to update it with.  Send tempLabel to know
         // the correct class.
-        LP.UpdateWeights(data, weightVectors, biases, j, tempLabel, maxIndexRow,
-            D);
+        LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
+            D(j));
       }
     }
   }



More information about the mlpack-git mailing list